From 1904014c25780cced8c6abaea4dcb655a4f95832 Mon Sep 17 00:00:00 2001
From: Faried Abu Zaid
Date: Tue, 9 Sep 2025 14:17:57 +0200
Subject: [PATCH] Implement ConvNet (topology adaptive)

---
 src/usflows/networks.py | 299 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 273 insertions(+), 26 deletions(-)

diff --git a/src/usflows/networks.py b/src/usflows/networks.py
index 6e83904..dba12d3 100644
--- a/src/usflows/networks.py
+++ b/src/usflows/networks.py
@@ -121,12 +121,194 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return ret
 
 
+class LayerNormChannelsND(nn.Module):
+    """Channel-wise LayerNorm for N-D spatial tensors (batch, channel, *spatial).
+
+    Every position is normalized across the channel axis (``dim=1``) and then
+    scaled by ``gamma`` and shifted by ``beta``. Both parameters are shaped
+    (1, C, 1, ..., 1) with `num_spatial_dims` trailing ones so they broadcast
+    over the batch and spatial axes.
+    """
+
+    def __init__(self, c_in, num_spatial_dims: int = 2, eps=1e-5):
+        super().__init__()
+        shape = (1, c_in) + (1,) * num_spatial_dims
+        self.gamma = nn.Parameter(torch.ones(*shape))
+        self.beta = nn.Parameter(torch.zeros(*shape))
+        self.eps = eps
+
+    def forward(self, x):
+        # mean and biased variance over the channel axis
+        mean = x.mean(dim=1, keepdim=True)
+        var = x.var(dim=1, unbiased=False, keepdim=True)
+        # eps keeps the division finite when the variance is zero
+        y = (x - mean) / torch.sqrt(var + self.eps)
+        y = y * self.gamma + self.beta
+        return y
+
+
+class GatedConvND(nn.Module):
+    """Gated residual block that adapts to 1/2/3-D convolutions.
+
+    Computes ``x + val * sigmoid(gate)``, where ``val`` and ``gate`` are the
+    two channel halves produced by the internal conv stack. The residual is
+    automatically projected with a 1x1 conv if channel counts differ.
+
+    `kernel_size`, `padding` and `dilation` must be chosen so that the conv
+    preserves the spatial size; otherwise `x` and the gated branch differ in shape.
+    """
+
+    def __init__(
+        self,
+        c_in,
+        c_out,
+        kernel_size=3,
+        padding=1,
+        stride=1,
+        dilation=1,
+        nonlinearity: nn.Module = nn.ReLU(),
+        input_rank: int = 2,
+    ):
+        super().__init__()
+        assert stride == 1, "Stride > 1 cannot be used to skip connection."
+
+        conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
+        if input_rank not in conv_map:
+            raise ValueError(f"Unsupported input rank {input_rank}")
+        Conv = conv_map[input_rank]
+
+        # first conv maps c_in to the intermediate channel count (we choose c_out)
+        self.net = nn.Sequential(
+            nonlinearity,
+            Conv(
+                c_in,
+                c_out,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride,
+                dilation=dilation,
+            ),
+            nonlinearity,
+            # final 1x1 conv produces 2 * c_out channels to split into val/gate
+            Conv(
+                c_out,
+                2 * c_out,
+                kernel_size=1,
+                padding=0,
+                stride=1,
+            ),
+        )
+
+        # projection for residual if channel counts differ
+        if c_in != c_out:
+            self.proj = Conv(c_in, c_out, kernel_size=1, padding=0)
+        else:
+            self.proj = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.net(x)
+        val, gate = out.chunk(2, dim=1)
+        res = val * torch.sigmoid(gate)
+        if self.proj is not None:
+            x = self.proj(x)
+        return x + res
+
+
+class ConvNet(nn.Module):
+    """Generic ConvNet that adapts to input topology (1D/2D/3D) using `in_dims`.
+
+    The stack is: input conv -> one block per entry of `c_hidden` -> output conv.
+
+    in_dims: per-sample input shape like [C, H, W]; the first entry is the
+        channel count and the number of remaining entries selects
+        Conv1d/Conv2d/Conv3d.
+    c_hidden: list[int] specifying channel size of each hidden layer (length replaces num_layers)
+    c_out: number of output channels; values <= 0 fall back to the input channel count.
+    padding: conv padding; None selects dilation * (kernel_size // 2), which
+        preserves the spatial size for odd kernels at stride 1.
+    normalize_layers: append a LayerNormChannelsND after every hidden block.
+    gating: use GatedConvND hidden blocks (requires stride == 1) instead of plain convs.
+    """
+
+    def __init__(
+        self,
+        in_dims: Iterable[int],
+        c_hidden: List[int],
+        c_out: int = -1,
+        nonlinearity: nn.Module = nn.ReLU(),
+        kernel_size: int = 3,
+        stride: int = 1,
+        dilation: int = 1,
+        padding: Optional[int] = None,
+        normalize_layers: bool = True,
+        gating: bool = True,
+    ):
+        super().__init__()
+
+        if padding is None:
+            # Scale the default with the dilation: kernel_size // 2 alone shrinks
+            # the feature maps when dilation > 1, which breaks the residual
+            # addition inside GatedConvND.
+            padding = dilation * (kernel_size // 2)
+
+        # infer channels and input rank from in_dims
+        if not isinstance(in_dims, Iterable):
+            raise ValueError("in_dims must be an iterable like [C, H, W]")
+        in_dims = list(in_dims)
+        c_in = in_dims[0]
+        input_rank = max(1, len(in_dims) - 1)
+        c_out = c_out if c_out > 0 else c_in
+
+        conv_map = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
+        if input_rank not in conv_map:
+            raise ValueError(f"Unsupported input rank {input_rank}")
+        Conv = conv_map[input_rank]
+
+        assert len(c_hidden) > 0 and all(h > 0 for h in c_hidden), "c_hidden must be non-empty list of positive ints"
+        conv_kwargs = dict(
+            kernel_size=kernel_size,
+            padding=padding,
+            stride=stride,
+            dilation=dilation,
+        )
+
+        # initial conv from c_in -> first hidden
+        layers = [Conv(c_in, c_hidden[0], **conv_kwargs)]
+
+        # hidden blocks; the first one maps c_hidden[0] -> c_hidden[0]
+        for in_ch, out_ch in zip([c_hidden[0], *c_hidden[:-1]], c_hidden):
+            if gating:
+                layers.append(
+                    GatedConvND(
+                        in_ch,
+                        out_ch,
+                        nonlinearity=nonlinearity,
+                        input_rank=input_rank,
+                        **conv_kwargs,
+                    )
+                )
+            else:
+                layers.append(Conv(in_ch, out_ch, **conv_kwargs))
+            layers.append(nonlinearity)
+
+            if normalize_layers:
+                layers.append(LayerNormChannelsND(out_ch, num_spatial_dims=input_rank))
+
+        # final conv from last hidden -> c_out
+        layers.append(Conv(c_hidden[-1], c_out, **conv_kwargs))
+
+        self.nn = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
+        # `context` is accepted but not used by the unconditional network
+        return self.nn(x)
+
+
 class ConvNet2D(nn.Module):
     def __init__(
         self,
         c_in: int,
         c_hidden: int = 3,
-        rescale_hidden: int = 2,
         c_out: int = -1,
         num_layers: int = 3,
         nonlinearity: any = nn.ReLU(),
@@ -173,8 +355,6 @@ def __init__(
                 dilation=dilation,
             ),
         ]
-        if rescale_hidden != 1:
-            layers += [nn.MaxPool2d(rescale_hidden)]
 
         for layer_index in range(num_layers):
             if gating:
@@ -208,27 +388,6 @@ def __init__(
                 ]
 
         # compute padding and output padding for rescaling via transposed convolutions
-        if rescale_hidden != 1:
-            diff = rescale_hidden - kernel_size
-            if diff < 0:
-                outpad = diff % 2
-                pad = ceil(abs(diff) / 2.0)
-            else:
-                outpad = diff
-                pad = 0
-
-            layers += [
-                nn.ConvTranspose2d(
-                    c_hidden,
-                    c_hidden,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    output_padding=outpad,
-                    padding=pad,
-                ),
-                nonlinearity,
-            ]
-
         layers += [
             nn.Conv2d(
                 c_hidden,
@@ -253,12 +412,101 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         return self.nn(x)
 
 
+class CondConvNet(ConvNet):
+    """Conditional ConvNet that appends a context channel to the input.
+
+    Mirrors `ConvNet` but increases the input channel count by one and expands
+    the supplied `context` tensor to match the spatial topology of `x` before
+    concatenation.
+
+    NOTE(review): when `c_out` is not positive, `ConvNet` falls back to its own
+    input channel count, which here already includes the context channel
+    (C + 1). Pass `c_out` explicitly if C output channels are wanted.
+    """
+
+    def __init__(
+        self,
+        in_dims: Iterable[int],
+        c_hidden: List[int],
+        c_out: int = -1,
+        nonlinearity: nn.Module = nn.ReLU(),
+        kernel_size: int = 3,
+        stride: int = 1,
+        dilation: int = 1,
+        padding: Optional[int] = None,
+        normalize_layers: bool = True,
+        gating: bool = True,
+        **kwargs,
+    ):
+        # increase input channels by 1 to make room for the context channel
+        if not isinstance(in_dims, Iterable):
+            raise ValueError("in_dims must be an iterable like [C, H, W]")
+        in_dims = list(in_dims)
+        in_dims_with_ctx = [in_dims[0] + 1] + in_dims[1:]
+
+        super().__init__(
+            in_dims=in_dims_with_ctx,
+            c_hidden=c_hidden,
+            c_out=c_out,
+            nonlinearity=nonlinearity,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+            padding=padding,
+            normalize_layers=normalize_layers,
+            gating=gating,
+            **kwargs,
+        )
+
+        # Keep original in_dims for forward-time spatial handling
+        self._orig_in_dims = in_dims
+
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Append expanded context as an extra channel and forward through network.
+
+        context can be scalar, vector, or tensor with batch dim. It will be
+        converted to the dtype and device of `x` and reshaped/expanded to
+        (batch,1,*spatial) to match `x` before concatenation. A missing context,
+        or one that cannot be expanded to that shape, is replaced by zeros; the
+        latter case emits a warning.
+        """
+        import warnings
+
+        batch, spatial = x.shape[0], x.shape[2:]
+        if context is None:
+            # default context
+            context = x.new_zeros(1)
+        else:
+            # Match dtype/device of x so that torch.cat below does not change
+            # the dtype of the network input.
+            context = torch.as_tensor(context, dtype=x.dtype, device=x.device)
+
+        # reshape context to have trailing singleton spatial dims to match x
+        n_dims = x.dim() - context.dim()
+        if n_dims > 0:
+            context = context.reshape(*context.shape, *([1] * n_dims))
+
+        # now expand to (batch, 1, *spatial). If context already has batch dim, expand will keep it.
+        try:
+            context = context.expand(batch, 1, *spatial)
+        except RuntimeError:
+            # final fallback: use zeros, but warn so shape mistakes stay visible
+            warnings.warn(
+                f"context of shape {tuple(context.shape)} cannot be expanded to "
+                f"{(batch, 1, *spatial)}; using zeros instead",
+                RuntimeWarning,
+            )
+            context = x.new_zeros((batch, 1, *spatial))
+
+        x = torch.cat([x, context], dim=1)
+        return self.nn(x)
+
+
 class CondConvNet2D(ConvNet2D):
     def __init__(
         self,
         c_in: int,
         c_hidden: int = 3,
-        rescale_hidden: int = 2,
         c_out: int = -1,
         num_layers: int = 3,
         nonlinearity: any = nn.ReLU(),
@@ -294,7 +542,6 @@ def __init__(
         super().__init__(
             c_in=c_in + 1,
             c_hidden=c_hidden,
-            rescale_hidden=rescale_hidden,
             c_out=c_out,
             num_layers=num_layers,
             nonlinearity=nonlinearity,