Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 163 additions & 68 deletions src/usflows/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + res


class LayerNormVector(nn.Module):
"""LayerNorm for vector inputs shaped (batch, features)."""
def __init__(self, features: int, eps: float = 1e-5):
super().__init__()
self.layernorm = nn.LayerNorm(features, eps=eps)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Accept (batch, features) or (batch, features, 1) or (1, batch, features)
if x.dim() == 3 and x.shape[-1] == 1:
x = x.view(x.shape[0], x.shape[1])
if x.dim() == 3 and x.shape[0] == 1 and x.shape[2] != 1:
# shape (1, B, C) -> (B, C)
x = x.permute(1, 2, 0).contiguous().view(x.shape[1], x.shape[2])
return self.layernorm(x)


class GatedMLP(nn.Module):
"""Gated residual MLP block analogous to GatedConvND for vector inputs."""
def __init__(self, in_features: int, out_features: int, nonlinearity: callable = nn.ReLU()):
super().__init__()
self.net1 = nn.Sequential(
nonlinearity,
nn.Linear(in_features, out_features),
nonlinearity,
nn.Linear(out_features, 2 * out_features),
)
# projection for residual if dims differ
if in_features != out_features:
self.proj = nn.Linear(in_features, out_features)
else:
self.proj = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch, features)
out = self.net1(x)
val, gate = out.chunk(2, dim=1)
res = val * torch.sigmoid(gate)
if self.proj is not None:
x = self.proj(x)
return x + res


class ConvNet(nn.Module):
"""Generic ConvNet that adapts to input topology (1D/2D/3D) using `in_dims`.
"""Generic ConvNet that adapts to input topology (vector/1D/2D/3D) using `in_dims`.

- Vector case (len(in_dims) == 1): builds an MLP path using GatedMLP blocks,
LayerNormVector for normalization and Linear layers for in/out projections.
Conv-specific kwargs are ignored for this case.
- Spatial case (len(in_dims) > 1): preserves existing ConvND behavior using
GatedConvND, LayerNormChannelsND etc.

c_hidden: list[int] specifying channel size of each hidden layer (length replaces num_layers)
The forward normalizes a few common input layouts for backward compatibility:
- vector path accepts (batch, features), (batch, features, 1), (1, batch, features)
- conv path accepts (batch, channels, width/height/...) and also (L, batch, channels) or (1, batch, channels)
"""

def __init__(
Expand All @@ -227,84 +277,129 @@ def __init__(
if padding is None:
padding = kernel_size // 2

# infer channels and input rank from in_dims
if not isinstance(in_dims, Iterable):
raise ValueError("in_dims must be an iterable like [C, H, W]")
raise ValueError("in_dims must be an iterable like [C, H, W] or [C] for vector")
in_dims = list(in_dims)
c_in = in_dims[0]
input_rank = max(1, len(in_dims) - 1)
c_in = int(in_dims[0])
is_vector = len(in_dims) == 1
c_out = c_out if c_out > 0 else c_in

conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
if input_rank not in conv_map:
raise ValueError(f"Unsupported input rank {input_rank}")
Conv = conv_map[input_rank]

layers = []
# initial conv from c_in -> first hidden
assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints"
first_hidden = c_hidden[0]
layers.append(
Conv(
c_in,
first_hidden,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
if is_vector:
# Build MLP / vector path
assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints"
layers = []
# initial linear projection
first_hidden = int(c_hidden[0])
layers.append(nn.Linear(c_in, first_hidden))
# hidden blocks
for i in range(len(c_hidden)):
in_ch = int(c_hidden[i - 1]) if i > 0 else first_hidden
out_ch = int(c_hidden[i])
if gating:
layers.append(GatedMLP(in_ch, out_ch, nonlinearity=nonlinearity))
else:
layers.append(nn.Sequential(nonlinearity, nn.Linear(in_ch, out_ch)))
if normalize_layers:
layers.append(LayerNormVector(out_ch))
# final linear to c_out
layers.append(nn.Linear(int(c_hidden[-1]), c_out))
self.nn = nn.Sequential(*layers)
self.is_vector = True
self._vector_in_features = c_in
else:
# Spatial / conv path: preserve existing ConvND behavior
input_rank = max(1, len(in_dims) - 1)
conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
if input_rank not in conv_map:
raise ValueError(f"Unsupported input rank {input_rank}")
Conv = conv_map[input_rank]

assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints"
first_hidden = c_hidden[0]
layers = []
layers.append(
Conv(
c_in,
first_hidden,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
)
)
)

# hidden blocks
for i in range(len(c_hidden)):
in_ch = c_hidden[i - 1] if i > 0 else first_hidden
out_ch = c_hidden[i]
if gating:
layers.append(
GatedConvND(
in_ch,
out_ch,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
nonlinearity=nonlinearity,
input_rank=input_rank,
for i in range(len(c_hidden)):
in_ch = c_hidden[i - 1] if i > 0 else first_hidden
out_ch = c_hidden[i]
if gating:
layers.append(
GatedConvND(
in_ch,
out_ch,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
nonlinearity=nonlinearity,
input_rank=input_rank,
)
)
)
layers.append(nonlinearity)
else:
layers.append(
Conv(
in_ch,
out_ch,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
layers.append(nonlinearity)
else:
layers.append(
Conv(
in_ch,
out_ch,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
)
)
layers.append(nonlinearity)

if normalize_layers:
layers.append(LayerNormChannelsND(out_ch, num_spatial_dims=input_rank))

# final conv from last hidden -> c_out
layers.append(
Conv(
c_hidden[-1],
c_out,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
)
layers.append(nonlinearity)

if normalize_layers:
layers.append(LayerNormChannelsND(out_ch, num_spatial_dims=input_rank))

# final conv from last hidden -> c_out
layers.append(
Conv(
c_hidden[-1],
c_out,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
)
)

self.nn = nn.Sequential(*layers)
self.nn = nn.Sequential(*layers)
self.is_vector = False
self._spatial_rank = input_rank

def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
return self.nn(x)
if self.is_vector:
# Accept (batch, features) or (batch, features, 1) or (1, batch, features)
if x.dim() == 3 and x.shape[-1] == 1:
x = x.view(x.shape[0], x.shape[1])
if x.dim() == 3 and x.shape[0] == 1 and x.shape[2] != 1:
x = x.permute(1, 2, 0).contiguous().view(x.shape[1], x.shape[2])
# final ensure shape (batch, features)
if x.dim() != 2:
x = x.view(x.shape[0], -1)
return self.nn(x)
else:
# conv path: normalize shapes to (batch, channels, width/...)
if x.dim() == 3:
# detect (L, B, C) style where last dim matches channel count
if x.shape[2] == int(getattr(self, "_spatial_rank", 1) and self.nn[0].in_channels) and x.shape[0] != x.shape[1]:
try:
x = x.permute(1, 2, 0).contiguous()
except Exception:
pass
# also handle (1, B, C)
elif x.shape[0] == 1 and x.shape[2] == self.nn[0].in_channels:
x = x.permute(1, 2, 0).contiguous()
return self.nn(x)


class ConvNet2D(nn.Module):
Expand Down
Loading