# Block-sparse layers added to columnformers/models/layers.py.
# This section relies on that module's top-level imports: warnings; Literal,
# Optional and Tuple from typing; torch; torch.nn.functional as F; nn from torch;
# rearrange from einops; trunc_normal_ from timm.layers; _device and _dtype from
# torch.types.


class BlockSparseLinear(nn.Module):
    """
    A linear layer with block sparse connectivity.

    Only the values of the non-zero blocks are stored as the weight parameter. The
    sparse BSR weight tensor is rebuilt from them on every forward pass.

    Args:
        connectivity: a binary tensor of shape (out_features, in_features)
            representing the connectivity between input and output units.
        bias: use bias
        blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of
            connectivity
    """

    crow_indices: torch.Tensor
    col_indices: torch.Tensor
    mask: torch.Tensor

    def __init__(
        self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 32
    ):
        super().__init__()
        device_capability = _cuda_get_device_capability()
        if device_capability is None or device_capability < (8, 0):
            # stacklevel=2 attributes the warning to the code constructing the layer.
            warnings.warn(
                "BlockSparseLinear only supported for CUDA A100 or higher",
                UserWarning,
                stacklevel=2,
            )

        self.in_features = connectivity.shape[1]
        self.out_features = connectivity.shape[0]
        self.blocksize = blocksize

        # convert to torch blocksparse representation if not already
        connectivity = connectivity.to_sparse_bsr(blocksize).float()
        block_values = connectivity.values()
        self.register_buffer("crow_indices", connectivity.crow_indices())
        self.register_buffer("col_indices", connectivity.col_indices())
        # The block values double as the within-block connectivity mask.
        self.register_buffer("mask", block_values)

        # Fraction of blocks that are entirely absent from the BSR representation.
        n_blocks = (self.out_features // blocksize) * (self.in_features // blocksize)
        nnz_blocks = block_values.size(0)
        self.sparsity = 1 - (nnz_blocks / n_blocks)

        # The weight parameter is just the sparse bsr values. We construct the sparse
        # bsr tensor on the fly. Trying to use a sparse bsr layout parameter doesn't
        # work. Mapping the model to cuda fails. The tensor container is mapped but not
        # the underlying values.
        self.weight = nn.Parameter(torch.empty_like(block_values))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the weight (truncated normal, masked) and zero the bias."""
        trunc_normal_(self.weight, std=0.02)
        # Zero the entries that the connectivity mask disables.
        with torch.no_grad():
            self.weight.mul_(self.mask)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the masked block sparse weight (and bias, if any) to ``x``."""
        # apply sparse connectivity mask
        weight = self.mask * self.weight
        weight = torch.sparse_bsr_tensor(
            crow_indices=self.crow_indices,
            col_indices=self.col_indices,
            values=weight,
            size=(self.out_features, self.in_features),
        )
        return F.linear(x, weight, self.bias)

    def extra_repr(self) -> str:
        return (
            f"{self.in_features}, {self.out_features}, "
            f"bias={self.bias is not None}, blocksize={self.blocksize}, "
            f"sparsity={self.sparsity:.2f}"
        )


class BlockSparseLocallyConnected(nn.Module):
    """
    A locally connected layer implemented using block sparse linear.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: side length of the square local neighborhood. Must be an odd int
        height: side length of the square input grid
        depthwise: connect each output channel only to the matching input channel.
            Requires ``in_channels == out_channels``
        bias: use bias
        blocksize: sparse block size passed to ``BlockSparseLinear``
        in_shape: layout of the input. "nlc" and "nchw" inputs are flattened before
            the linear layer and restored afterwards; "nd" inputs are passed through
            unchanged

    Raises:
        TypeError: if ``kernel_size`` is not an int.
        ValueError: if ``in_shape`` is not one of "nlc", "nchw", "nd".
    """

    _IN_SHAPES = ("nlc", "nchw", "nd")

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        height: int,
        depthwise: bool = False,
        bias: bool = True,
        blocksize: int = 32,
        in_shape: Literal["nlc", "nchw", "nd"] = "nchw",
    ):
        super().__init__()
        if not isinstance(kernel_size, int):
            raise TypeError("only square kernels supported")
        if in_shape not in self._IN_SHAPES:
            raise ValueError(
                f"in_shape must be one of {self._IN_SHAPES}, got {in_shape!r}"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.height = height
        self.kernel_size = kernel_size
        self.depthwise = depthwise
        self.blocksize = blocksize
        self.in_shape = in_shape
        # Channel axis placement in the flattened features: last for full channel
        # mixing, first for depthwise.
        self.channels_last = not depthwise

        connectivity = _sparse_local_connectivity(
            in_channels,
            out_channels,
            kernel_size,
            height,
            depthwise=depthwise,
            channels_last=self.channels_last,
        )
        self.bsl = BlockSparseLinear(
            connectivity=connectivity, bias=bias, blocksize=blocksize
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer, returning the output in the same layout as the input."""
        if self.in_shape == "nd":
            # Input is already flattened; no reshaping on either side.
            return self.bsl(input)

        in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w"
        out_pattern = "n (h w c)" if self.channels_last else "n (c h w)"
        flat = rearrange(
            input, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height
        )
        output = self.bsl(flat)
        return rearrange(
            output,
            f"{out_pattern} -> {in_pattern}",
            c=self.out_channels,
            h=self.height,
            w=self.height,
        )


def _sparse_local_connectivity(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    height: int,
    depthwise: bool = False,
    channels_last: bool = False,
    dtype: Optional[_dtype] = None,
    device: Optional[_device] = None,
) -> torch.Tensor:
    """
    Construct sparse local connectivity matrix, shape
    (out_channels * height * height, in_channels * height * height). The returned
    connectivity will have sparse COO layout.

    The connectivity pattern is equivalent to
    `nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")`

    If channels_last is True, the shape of connectivity in einops notation is
    "(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should
    be more efficient when depthwise is True (connectivity will be more block sparse).

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: odd side length of the square kernel
        height: side length of the square grid
        depthwise: use diagonal rather than full channel connectivity
        channels_last: place the channel axis after the spatial axes
        dtype: dtype of the connectivity values (torch default if None)
        device: device to move the result to (left where it was built if None)

    Raises:
        ValueError: if ``kernel_size`` is even, or if ``depthwise`` is set and
            ``in_channels != out_channels``.
    """
    if kernel_size % 2 != 1:
        raise ValueError("kernel_size must be odd")
    if depthwise and out_channels != in_channels:
        raise ValueError("in channels must match out channels for depthwise")
    N = height * height

    # ij indices of input grid
    # (h^2, 2)
    grid_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height))

    # conv kernel index offsets. note that the kernel width is required to be odd.
    # (k^2, 2)
    kernel_half_width = (kernel_size - 1) // 2
    offsets = torch.arange(-kernel_half_width, kernel_half_width + 1)
    kernel_indices = torch.cartesian_prod(offsets, offsets)

    # input edge indices for each output unit. these will be the column indices for
    # the sparse COO connectivity.
    # (h^2, k^2, 2)
    col_indices = grid_indices.unsqueeze(1) + kernel_indices.unsqueeze(0)

    # input edge row indices
    # (h^2, k^2)
    row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2)

    # exclude edges falling outside grid
    mask = ((col_indices >= 0) & (col_indices < height)).all(dim=-1)
    col_indices = col_indices[mask]
    row_indices = row_indices[mask]

    # rasterize column indices
    col_indices = height * col_indices[..., 0] + col_indices[..., 1]

    # add channel blocks with full or depthwise (diagonal) connectivity
    if depthwise:
        channel_indices = torch.arange(out_channels).unsqueeze(1).repeat(1, 2)
    else:
        channel_indices = torch.cartesian_prod(
            torch.arange(out_channels), torch.arange(in_channels)
        )

    # we can insert the channels axis either at the front or the back
    # front is better for depthwise=True, back is better for depthwise=False
    if channels_last:
        row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0]
        col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1]
    else:
        row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices
        col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices

    row_indices = row_indices.flatten()
    col_indices = col_indices.flatten()

    # construct sparse connectivity tensor
    connectivity = torch.sparse_coo_tensor(
        torch.stack([row_indices, col_indices]),
        torch.ones(len(row_indices), dtype=dtype),
        size=(out_channels * N, in_channels * N),
    ).coalesce()
    if device is not None:
        connectivity = connectivity.to(device)
    return connectivity


def _cuda_get_device_capability() -> Optional[Tuple[int, int]]:
    """Return the CUDA device capability, or None when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()
import logging

import pytest
import torch
from torch import nn

import columnformers.models.layers as L


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not available")
def test_block_sparse_locally_connected():
    """Check BlockSparseLocallyConnected against an equivalent nn.Conv2d.

    Both layers get all-ones weights and zero biases, so their outputs and their
    input gradients are expected to agree.
    """
    torch.manual_seed(42)
    device = torch.device("cuda")

    loc = L.BlockSparseLocallyConnected(
        in_channels=8,
        out_channels=16,
        kernel_size=3,
        height=16,
        depthwise=False,
    )
    logging.info("%s", loc)

    conv = nn.Conv2d(
        in_channels=8,
        out_channels=16,
        kernel_size=3,
        stride=1,
        padding="same",
    )

    nn.init.ones_(loc.bsl.weight)
    nn.init.zeros_(loc.bsl.bias)
    nn.init.ones_(conv.weight)
    nn.init.zeros_(conv.bias)

    loc = loc.to(device)
    conv = conv.to(device)

    # Separate leaf copies so each layer accumulates its own input gradient.
    x = torch.randn((2, 8, 16, 16), device=device)
    x_loc = x.clone().requires_grad_(True)
    x_conv = x.clone().requires_grad_(True)

    output_loc = loc(x_loc)
    output_conv = conv(x_conv)
    # Same tolerances as torch.allclose(..., rtol=1e-4), with a readable failure.
    torch.testing.assert_close(output_loc, output_conv, rtol=1e-4, atol=1e-8)

    loss_loc = (output_loc**2).mean()
    loss_conv = (output_conv**2).mean()
    loss_loc.backward()
    loss_conv.backward()
    torch.testing.assert_close(x_loc.grad, x_conv.grad, rtol=1e-4, atol=1e-8)