Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 267 additions & 26 deletions src/usflows/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,197 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return ret


class LayerNormChannelsND(nn.Module):
    """Normalise over the channel axis of ``(batch, channel, *spatial)`` tensors.

    Every position (batch index and spatial location) is standardised using
    the mean and biased variance taken across dimension 1, then rescaled by a
    learnable per-channel ``gamma`` and shifted by a learnable per-channel
    ``beta``.

    Args:
        c_in: Number of channels; sets the size of ``gamma`` and ``beta``.
        num_spatial_dims: Number of trailing singleton axes appended to the
            parameter shape, giving ``(1, c_in, 1, ..., 1)``.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, c_in, num_spatial_dims: int = 2, eps=1e-5):
        super().__init__()
        # (1, C, 1, ..., 1): broadcasts against (batch, C, *spatial) inputs.
        param_shape = (1, c_in, *([1] * num_spatial_dims))
        self.gamma = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` standardised over dim 1, then scaled and shifted."""
        channel_mean = x.mean(dim=1, keepdim=True)
        # Biased (population) variance, i.e. divide by C rather than C - 1.
        channel_var = x.var(dim=1, unbiased=False, keepdim=True)
        normalised = (x - channel_mean) / torch.sqrt(channel_var + self.eps)
        return normalised * self.gamma + self.beta


class GatedConvND(nn.Module):
    """Sigmoid-gated residual block built from 1-D, 2-D or 3-D convolutions.

    The inner network applies ``nonlinearity -> conv -> nonlinearity -> 1x1
    conv``; the last convolution emits ``2 * c_out`` channels that are split
    into a value half and a gate half.  The block returns
    ``skip(x) + value * sigmoid(gate)``, where ``skip`` is the identity when
    ``c_in == c_out`` and a 1x1 convolution otherwise.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        kernel_size, padding, stride, dilation: Forwarded to the first
            convolution.  ``stride`` must be 1.
        nonlinearity: Activation module; the same instance is used at both
            positions of the inner network.
        input_rank: Number of spatial dimensions (1, 2 or 3); selects
            ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``.

    Raises:
        AssertionError: If ``stride`` is not 1.
        ValueError: If ``input_rank`` is not 1, 2 or 3.
    """

    def __init__(
        self,
        c_in,
        c_out,
        kernel_size=3,
        padding=1,
        stride=1,
        dilation=1,
        nonlinearity: callable = nn.ReLU(),
        input_rank: int = 2,
    ):
        super().__init__()
        assert stride == 1, "Stride > 1 cannot be used to skip connection."

        conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(input_rank)
        if conv_cls is None:
            raise ValueError(f"Unsupported input rank {input_rank}")

        # Spatial convolution mapping c_in to the intermediate width (c_out).
        feature_conv = conv_cls(
            c_in,
            c_out,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            dilation=dilation,
        )
        # Pointwise convolution producing the concatenated value/gate channels.
        value_gate_conv = conv_cls(
            c_out,
            2 * c_out,
            kernel_size=1,
            padding=0,
            stride=1,
        )
        self.net = nn.Sequential(
            nonlinearity,
            feature_conv,
            nonlinearity,
            value_gate_conv,
        )

        # The skip path needs a 1x1 projection only when the widths differ.
        self.proj = (
            conv_cls(c_in, c_out, kernel_size=1, padding=0)
            if c_in != c_out
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (optionally projected) input plus the gated branch."""
        value, gate = self.net(x).chunk(2, dim=1)
        skip = x if self.proj is None else self.proj(x)
        return skip + value * torch.sigmoid(gate)


class ConvNet(nn.Module):
    """Generic convolutional network whose topology (1D/2D/3D) follows `in_dims`.

    The layer stack is, in order:

    1. an input convolution ``c_in -> c_hidden[0]``;
    2. one block per entry of ``c_hidden`` -- a ``GatedConvND`` residual block
       when ``gating`` is true, otherwise a plain convolution -- each followed
       by ``nonlinearity`` and, if ``normalize_layers`` is true, a
       ``LayerNormChannelsND``;
    3. an output convolution ``c_hidden[-1] -> c_out``.

    Args:
        in_dims: Per-sample input shape ``[C, *spatial]``.  ``C`` is the number
            of input channels and the number of remaining entries (at least 1)
            selects ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``.
        c_hidden: Channel width of each hidden block; its length determines
            the number of hidden blocks.
        c_out: Number of output channels.  Values ``<= 0`` mean "same as the
            input channel count".
        nonlinearity: Activation module; the same instance is reused after
            every hidden block.
        kernel_size: Kernel size shared by all convolutions.
        stride: Stride shared by all convolutions.  ``GatedConvND`` only
            accepts ``stride == 1``, so other values require ``gating=False``.
        dilation: Dilation shared by all convolutions.
        padding: Padding shared by all convolutions.  ``None`` selects
            ``dilation * (kernel_size // 2)``, which keeps the spatial extent
            unchanged for odd kernel sizes at stride 1.
        normalize_layers: Whether to append a ``LayerNormChannelsND`` after
            every hidden block.
        gating: Whether hidden blocks are ``GatedConvND`` (true) or plain
            convolutions (false).

    Raises:
        ValueError: If ``in_dims`` is not iterable or implies more than three
            spatial dimensions.
        AssertionError: If ``c_hidden`` is empty or has a non-positive entry.
    """

    def __init__(
        self,
        in_dims: Iterable[int],
        c_hidden: List[int],
        c_out: int = -1,
        nonlinearity: nn.Module = nn.ReLU(),
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        padding: Optional[int] = None,
        normalize_layers: bool = True,
        gating: bool = True,
    ):
        super().__init__()

        if padding is None:
            # A dilated kernel spans dilation * (kernel_size - 1) + 1 samples,
            # so the size-preserving padding must scale with the dilation.
            # Without this the gated residual add fails for dilation > 1.
            # For dilation == 1 this equals the previous kernel_size // 2.
            padding = dilation * (kernel_size // 2)

        # Infer the channel count and the convolution rank from in_dims.
        if not isinstance(in_dims, Iterable):
            raise ValueError("in_dims must be an iterable like [C, H, W]")
        in_dims = list(in_dims)
        c_in = in_dims[0]
        input_rank = max(1, len(in_dims) - 1)
        c_out = c_out if c_out > 0 else c_in

        conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        if input_rank not in conv_map:
            raise ValueError(f"Unsupported input rank {input_rank}")
        Conv = conv_map[input_rank]

        assert len(c_hidden) > 0 and all(h > 0 for h in c_hidden), (
            "c_hidden must be non-empty list of positive ints"
        )

        # Every convolution in the network shares these hyper-parameters.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            dilation=dilation,
        )

        # Input convolution: c_in -> first hidden width.
        layers = [Conv(c_in, c_hidden[0], **conv_kwargs)]

        # Hidden blocks.  The first block maps c_hidden[0] -> c_hidden[0]
        # because the input convolution already produced that width.
        block_inputs = [c_hidden[0], *c_hidden[:-1]]
        for in_ch, out_ch in zip(block_inputs, c_hidden):
            if gating:
                block = GatedConvND(
                    in_ch,
                    out_ch,
                    nonlinearity=nonlinearity,
                    input_rank=input_rank,
                    **conv_kwargs,
                )
            else:
                block = Conv(in_ch, out_ch, **conv_kwargs)
            layers.append(block)
            layers.append(nonlinearity)

            if normalize_layers:
                layers.append(
                    LayerNormChannelsND(out_ch, num_spatial_dims=input_rank)
                )

        # Output convolution: last hidden width -> c_out.
        layers.append(Conv(c_hidden[-1], c_out, **conv_kwargs))

        self.nn = nn.Sequential(*layers)

    def forward(
        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the network to ``x``; ``context`` is accepted but not used."""
        return self.nn(x)


class ConvNet2D(nn.Module):
def __init__(
self,
c_in: int,
c_hidden: int = 3,
rescale_hidden: int = 2,
c_out: int = -1,
num_layers: int = 3,
nonlinearity: any = nn.ReLU(),
Expand Down Expand Up @@ -173,8 +358,6 @@ def __init__(
dilation=dilation,
),
]
if rescale_hidden != 1:
layers += [nn.MaxPool2d(rescale_hidden)]

for layer_index in range(num_layers):
if gating:
Expand Down Expand Up @@ -208,27 +391,6 @@ def __init__(
]

# compute padding and output padding for rescaling via transposed convolutions
if rescale_hidden != 1:
diff = rescale_hidden - kernel_size
if diff < 0:
outpad = diff % 2
pad = ceil(abs(diff) / 2.0)
else:
outpad = diff
pad = 0

layers += [
nn.ConvTranspose2d(
c_hidden,
c_hidden,
kernel_size=kernel_size,
stride=stride,
output_padding=outpad,
padding=pad,
),
nonlinearity,
]

layers += [
nn.Conv2d(
c_hidden,
Expand All @@ -253,12 +415,92 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
return self.nn(x)


class CondConvNet(ConvNet):
    """Conditional ConvNet that appends a context channel to the input.

    Mirrors `ConvNet` but increases the input channel count by one and expands
    the supplied `context` tensor to match the spatial topology of `x` before
    concatenation.

    Args:
        in_dims: Per-sample shape ``[C, *spatial]`` of ``x`` *without* the
            context channel; the parent is constructed with ``C + 1`` input
            channels.
        c_hidden, nonlinearity, kernel_size, stride, dilation, padding,
            normalize_layers, gating: Forwarded unchanged to
            ``ConvNet.__init__``.
        c_out: Forwarded unchanged to ``ConvNet.__init__``.
            NOTE(review): when ``c_out <= 0`` the parent falls back to the
            channel count it was given, i.e. ``C + 1`` including the context
            channel -- confirm this is the intended default.
        **kwargs: Forwarded to ``ConvNet.__init__``.

    Raises:
        ValueError: If ``in_dims`` is not iterable.
    """

    def __init__(
        self,
        in_dims: Iterable[int],
        c_hidden: List[int],
        c_out: int = -1,
        nonlinearity: nn.Module = nn.ReLU(),
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        padding: Optional[int] = None,
        normalize_layers: bool = True,
        gating: bool = True,
        **kwargs,
    ):
        if not isinstance(in_dims, Iterable):
            raise ValueError("in_dims must be an iterable like [C, H, W]")
        in_dims = list(in_dims)
        # One extra input channel carries the broadcast context.
        in_dims_with_ctx = [in_dims[0] + 1] + in_dims[1:]

        super().__init__(
            in_dims=in_dims_with_ctx,
            c_hidden=c_hidden,
            c_out=c_out,
            nonlinearity=nonlinearity,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            normalize_layers=normalize_layers,
            gating=gating,
            **kwargs,
        )

        # Keep the original in_dims (without the context channel).
        self._orig_in_dims = in_dims

    def forward(
        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Append expanded context as an extra channel and forward through network.

        Args:
            x: Input of shape ``(batch, C, *spatial)``.
            context: ``None`` (treated as zero), a scalar, a sequence, or a
                tensor.  It is converted to ``x``'s dtype and device, padded
                with trailing singleton axes up to ``x``'s rank and expanded
                to ``(batch, 1, *spatial)``.  If it cannot be expanded to that
                shape a warning is issued and a zero context is used instead.

        Returns:
            The output of the underlying network applied to ``x`` concatenated
            with the context channel along dim 1.
        """
        if context is None:
            context = x.new_zeros(1)
        else:
            # Match x's dtype/device so concatenation cannot promote the input
            # to a dtype the convolution weights do not have.
            context = torch.as_tensor(context, dtype=x.dtype, device=x.device)

        # Pad with trailing singleton axes so the context has x's rank.
        missing_dims = x.dim() - context.dim()
        if missing_dims > 0:
            context = context.reshape(*context.shape, *([1] * missing_dims))

        target_shape = (x.shape[0], 1, *x.shape[2:])
        try:
            context = context.expand(target_shape)
        except RuntimeError:
            # Best-effort fallback retained from the original behaviour, but
            # no longer silent: an incompatible context is replaced by zeros.
            import warnings

            warnings.warn(
                f"context of shape {tuple(context.shape)} cannot be expanded "
                f"to {target_shape}; using a zero context instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            context = x.new_zeros(target_shape)

        return self.nn(torch.cat([x, context], dim=1))


class CondConvNet2D(ConvNet2D):
def __init__(
self,
c_in: int,
c_hidden: int = 3,
rescale_hidden: int = 2,
c_out: int = -1,
num_layers: int = 3,
nonlinearity: any = nn.ReLU(),
Expand Down Expand Up @@ -294,7 +536,6 @@ def __init__(
super().__init__(
c_in=c_in + 1,
c_hidden=c_hidden,
rescale_hidden=rescale_hidden,
c_out=c_out,
num_layers=num_layers,
nonlinearity=nonlinearity,
Expand Down
Loading