From 55888c52c0a3cd91a728978a6f594acda84e53f5 Mon Sep 17 00:00:00 2001 From: Faried Abu Zaid Date: Wed, 10 Sep 2025 12:15:48 +0200 Subject: [PATCH] Bigfix: handle 1D case correctly --- src/usflows/networks.py | 231 ++++++++++++++++++++++++++++------------ 1 file changed, 163 insertions(+), 68 deletions(-) diff --git a/src/usflows/networks.py b/src/usflows/networks.py index dba12d3..833ed80 100644 --- a/src/usflows/networks.py +++ b/src/usflows/networks.py @@ -203,10 +203,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + res +class LayerNormVector(nn.Module): + """LayerNorm for vector inputs shaped (batch, features).""" + def __init__(self, features: int, eps: float = 1e-5): + super().__init__() + self.layernorm = nn.LayerNorm(features, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Accept (batch, features) or (batch, features, 1) or (1, batch, features) + if x.dim() == 3 and x.shape[-1] == 1: + x = x.view(x.shape[0], x.shape[1]) + if x.dim() == 3 and x.shape[0] == 1 and x.shape[2] != 1: + # shape (1, B, C) -> (B, C) + x = x.permute(1, 2, 0).contiguous().view(x.shape[1], x.shape[2]) + return self.layernorm(x) + + +class GatedMLP(nn.Module): + """Gated residual MLP block analogous to GatedConvND for vector inputs.""" + def __init__(self, in_features: int, out_features: int, nonlinearity: callable = nn.ReLU()): + super().__init__() + self.net1 = nn.Sequential( + nonlinearity, + nn.Linear(in_features, out_features), + nonlinearity, + nn.Linear(out_features, 2 * out_features), + ) + # projection for residual if dims differ + if in_features != out_features: + self.proj = nn.Linear(in_features, out_features) + else: + self.proj = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (batch, features) + out = self.net1(x) + val, gate = out.chunk(2, dim=1) + res = val * torch.sigmoid(gate) + if self.proj is not None: + x = self.proj(x) + return x + res + + class ConvNet(nn.Module): - """Generic ConvNet that adapts to input topology (1D/2D/3D) using `in_dims`. + """Generic ConvNet that adapts to input topology (vector/1D/2D/3D) using `in_dims`. + + - Vector case (len(in_dims) == 1): builds an MLP path using GatedMLP blocks, + LayerNormVector for normalization and Linear layers for in/out projections. + Conv-specific kwargs are ignored for this case. + - Spatial case (len(in_dims) > 1): preserves existing ConvND behavior using + GatedConvND, LayerNormChannelsND etc. - c_hidden: list[int] specifying channel size of each hidden layer (length replaces num_layers) + The forward normalizes a few common input layouts for backward compatibility: + - vector path accepts (batch, features), (batch, features, 1), (1, batch, features) + - conv path accepts (batch, channels, width/height/...) and also (L, batch, channels) or (1, batch, channels) """ def __init__( @@ -227,84 +277,129 @@ def __init__( if padding is None: padding = kernel_size // 2 - # infer channels and input rank from in_dims if not isinstance(in_dims, Iterable): - raise ValueError("in_dims must be an iterable like [C, H, W]") + raise ValueError("in_dims must be an iterable like [C, H, W] or [C] for vector") in_dims = list(in_dims) - c_in = in_dims[0] - input_rank = max(1, len(in_dims) - 1) + c_in = int(in_dims[0]) + is_vector = len(in_dims) == 1 c_out = c_out if c_out > 0 else c_in - conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} - if input_rank not in conv_map: - raise ValueError(f"Unsupported input rank {input_rank}") - Conv = conv_map[input_rank] - - layers = [] - # initial conv from c_in -> first hidden - assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints" - first_hidden = c_hidden[0] - layers.append( - Conv( - c_in, - first_hidden, - kernel_size=kernel_size, - padding=padding, - stride=stride, - dilation=dilation, + if is_vector: + # Build MLP / vector path + assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints" + layers = [] + # initial linear projection + first_hidden = int(c_hidden[0]) + layers.append(nn.Linear(c_in, first_hidden)) + # hidden blocks + for i in range(len(c_hidden)): + in_ch = int(c_hidden[i - 1]) if i > 0 else first_hidden + out_ch = int(c_hidden[i]) + if gating: + layers.append(GatedMLP(in_ch, out_ch, nonlinearity=nonlinearity)) + else: + layers.append(nn.Sequential(nonlinearity, nn.Linear(in_ch, out_ch))) + if normalize_layers: + layers.append(LayerNormVector(out_ch)) + # final linear to c_out + layers.append(nn.Linear(int(c_hidden[-1]), c_out)) + self.nn = nn.Sequential(*layers) + self.is_vector = True + self._vector_in_features = c_in + else: + # Spatial / conv path: preserve existing ConvND behavior + input_rank = max(1, len(in_dims) - 1) + conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} + if input_rank not in conv_map: + raise ValueError(f"Unsupported input rank {input_rank}") + Conv = conv_map[input_rank] + + assert len(c_hidden) > 0 and all([h > 0 for h in c_hidden]), "c_hidden must be non-empty list of positive ints" + first_hidden = c_hidden[0] + layers = [] + layers.append( + Conv( + c_in, + first_hidden, + kernel_size=kernel_size, + padding=padding, + stride=stride, + dilation=dilation, + ) ) - ) - # hidden blocks - for i in range(len(c_hidden)): - in_ch = c_hidden[i - 1] if i > 0 else first_hidden - out_ch = c_hidden[i] - if gating: - layers.append( - GatedConvND( - in_ch, - out_ch, - kernel_size=kernel_size, - padding=padding, - stride=stride, - dilation=dilation, - nonlinearity=nonlinearity, - input_rank=input_rank, + for i in range(len(c_hidden)): + in_ch = c_hidden[i - 1] if i > 0 else first_hidden + out_ch = c_hidden[i] + if gating: + layers.append( + GatedConvND( + in_ch, + out_ch, + kernel_size=kernel_size, + padding=padding, + stride=stride, + dilation=dilation, + nonlinearity=nonlinearity, + input_rank=input_rank, + ) ) - ) - layers.append(nonlinearity) - else: - layers.append( - Conv( - in_ch, - out_ch, - kernel_size=kernel_size, - padding=padding, - stride=stride, - dilation=dilation, + layers.append(nonlinearity) + else: + layers.append( + Conv( + in_ch, + out_ch, + kernel_size=kernel_size, + padding=padding, + stride=stride, + dilation=dilation, + ) ) + layers.append(nonlinearity) + + if normalize_layers: + layers.append(LayerNormChannelsND(out_ch, num_spatial_dims=input_rank)) + + # final conv from last hidden -> c_out + layers.append( + Conv( + c_hidden[-1], + c_out, + kernel_size=kernel_size, + padding=padding, + stride=stride, + dilation=dilation, ) - layers.append(nonlinearity) - - if normalize_layers: - layers.append(LayerNormChannelsND(out_ch, num_spatial_dims=input_rank)) - - # final conv from last hidden -> c_out - layers.append( - Conv( - c_hidden[-1], - c_out, - kernel_size=kernel_size, - padding=padding, - stride=stride, - dilation=dilation, ) - ) - - self.nn = nn.Sequential(*layers) + self.nn = nn.Sequential(*layers) + self.is_vector = False + self._spatial_rank = input_rank def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: - return self.nn(x) + if self.is_vector: + # Accept (batch, features) or (batch, features, 1) or (1, batch, features) + if x.dim() == 3 and x.shape[-1] == 1: + x = x.view(x.shape[0], x.shape[1]) + if x.dim() == 3 and x.shape[0] == 1 and x.shape[2] != 1: + x = x.permute(1, 2, 0).contiguous().view(x.shape[1], x.shape[2]) + # final ensure shape (batch, features) + if x.dim() != 2: + x = x.view(x.shape[0], -1) + return self.nn(x) + else: + # conv path: normalize shapes to (batch, channels, width/...) + if x.dim() == 3: + # detect (L, B, C) style where last dim matches channel count + if x.shape[2] == int(getattr(self, "_spatial_rank", 1) and self.nn[0].in_channels) and x.shape[0] != x.shape[1]: + try: + x = x.permute(1, 2, 0).contiguous() + except Exception: + pass + # also handle (1, B, C) + elif x.shape[0] == 1 and x.shape[2] == self.nn[0].in_channels: + x = x.permute(1, 2, 0).contiguous() + return self.nn(x) class ConvNet2D(nn.Module):