From d0ec304e319150b6d8f576c1baf64c2dc2945aa8 Mon Sep 17 00:00:00 2001 From: Connor Lane Date: Wed, 1 May 2024 12:52:44 -0400 Subject: [PATCH 1/9] Initial blocksparse linear outline --- columnformers/models/layers.py | 81 ++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 8f33fe8..e56502f 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -5,6 +5,13 @@ from timm.layers import trunc_normal_ from torch import nn +try: + from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa + + triton_available = True +except ImportError: + triton_available = False + Layer = Callable[..., nn.Module] @@ -280,3 +287,77 @@ def init_weights(module: nn.Module): nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() + + +class BlockSparseLinear(nn.Module): + """ + A linear layer with block sparse connectivity. + + Args: + connectivity: a binary tensor of shape (out_features, in_features) representing + the connectivity between input and output units. + bias: use bias + blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of + connectivity + + TODO: + [ ] initialize weight and bias. weight should be masked by connectivity at init. + think about what the appropriate init std should be. + [ ] create a dsd sparse matmul kernel following xformers.BlockSparseAttention: + https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96 + [ ] implement forward that should mask weight by connectivity and then call the + blocksparse matmul kernel + """ + + def __init__( + self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16 + ): + assert triton_available, "blocksparse linear requires triton" + super().__init__() + self.in_features = connectivity.shape[1] + self.out_features = connectivity.shape[0] + self.blocksize = blocksize + + # convert to torch blocksparse representation if not already + connectivity = connectivity.to_sparse_bsr(blocksize) + + # block sparse layout as expected by triton + # shape (1, out_features // block, in_features // block) + # must be dtype int64 + layout = torch.sparse_csr_tensor( + connectivity.crow_indices(), + connectivity.col_indices(), + torch.ones_like(connectivity.col_indices()), + ) + layout = layout.to_dense().unsqueeze(0) + + # only keep raw values, don't need indices since we have layout + # shape (nnz_blocks, block, block) + connectivity = (connectivity.values() > 0).float() + + self.register_buffer("connectivity", connectivity) + self.register_buffer("layout", layout) + + # TODO: initialize weight and bias + + def reset_parameters(self): + raise NotImplementedError + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def extra_repr(self) -> str: + return ( + f"{self.in_features}, {self.out_features}, " + f"bias={self.bias is not None}, blocksize={self.blocksize}" + ) + + +class BlockSparseLocallyConnected(nn.Module): + """ + A locally connected layer implemented using block sparse linear. + + TODO: main step is just computing the connectivity based on conv params. shape + should be something like: (out_height * out_width * out_channels, in_height * + in_width * in_channels). Then we just use BlockSparseLinear. + """ From bc7e7484c9f90a88d0bc509dfc07ffcbf2e218c5 Mon Sep 17 00:00:00 2001 From: alismil Date: Thu, 6 Jun 2024 19:06:35 +0100 Subject: [PATCH 2/9] BlockSparseLinear complete --- columnformers/models/layers.py | 131 ++++++++++++++------------------- 1 file changed, 54 insertions(+), 77 deletions(-) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index e56502f..fcd9dfc 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F +import torch.nn.utils.prune as prune from timm.layers import trunc_normal_ from torch import nn @@ -67,17 +68,17 @@ def __init__( self, in_features: int, out_features: int, - coef: "MixtureCoefficients", + rank: int = 16, bias: bool = True, ): super().__init__() self.in_features = in_features self.out_features = out_features - self.coef = coef + self.rank = rank - self.weight = nn.Parameter(torch.empty((out_features, in_features, coef.rank))) + self.weight = nn.Parameter(torch.empty((out_features, in_features, rank))) if bias: - self.bias = nn.Parameter(torch.empty(out_features, coef.rank)) + self.bias = nn.Parameter(torch.empty(out_features, rank)) else: self.register_parameter("bias", None) self.reset_parameters() @@ -87,10 +88,9 @@ def reset_parameters(self) -> None: if self.bias is not None: nn.init.zeros_(self.bias) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: # input: (B, N, C) # coef: (N, R) - coef = self.coef() # Nb, this implementation for some reason uses significantly fewer flops # compared to equivalent alternatives (e.g. einsum) for some reason. weight = (coef @ self.weight.transpose(1, 2)).transpose(0, 1) @@ -102,7 +102,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return output def extra_repr(self) -> str: - return f"{self.in_features}, {self.out_features}, bias={self.bias is not None}" + return ( + f"{self.in_features}, {self.out_features}, {self.rank}, " + f"bias={self.bias is not None}" + ) class MixtureCoefficients(nn.Module): @@ -198,52 +201,6 @@ def extra_repr(self) -> str: ) -class MixtureLayerNorm(nn.Module): - def __init__( - self, - dim: int, - coef: "MixtureCoefficients", - eps: float = 1e-5, - elementwise_affine: bool = True, - ): - super().__init__() - self.dim = dim - self.eps = eps - self.elementwise_affine = elementwise_affine - self.coef = coef - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(dim, coef.rank)) - self.bias = nn.Parameter(torch.empty(dim, coef.rank)) - else: - self.register_parameter("weight", None) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self) -> None: - if self.elementwise_affine: - nn.init.ones_(self.weight) - nn.init.zeros_(self.bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = F.layer_norm(input, (self.dim,), eps=self.eps) - if self.elementwise_affine: - coef = self.coef() - weight = coef @ self.weight.t() - bias = coef @ self.bias.t() - input = input * weight + bias - return input - - def no_weight_decay(self) -> List[str]: - # Nb, not excluded by default since 2d - return ["weight", "bias"] - - def extra_repr(self) -> str: - return ( - f"{self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" - ) - - class SpatialPool(nn.Module): """ Pool a sequence of features with a learned attention weight per class. @@ -299,14 +256,6 @@ class BlockSparseLinear(nn.Module): bias: use bias blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of connectivity - - TODO: - [ ] initialize weight and bias. weight should be masked by connectivity at init. - think about what the appropriate init std should be. - [ ] create a dsd sparse matmul kernel following xformers.BlockSparseAttention: - https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96 - [ ] implement forward that should mask weight by connectivity and then call the - blocksparse matmul kernel """ def __init__( @@ -314,37 +263,65 @@ def __init__( ): assert triton_available, "blocksparse linear requires triton" super().__init__() - self.in_features = connectivity.shape[1] - self.out_features = connectivity.shape[0] + self.in_features = connectivity.shape[0] + self.out_features = connectivity.shape[1] self.blocksize = blocksize + self.connectivity = connectivity + + self.linear = nn.Linear(self.in_features, self.out_features, bias=False).to( + self.connectivity.device + ) + + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + prune.custom_from_mask(self.linear, name="weight", mask=connectivity) # convert to torch blocksparse representation if not already - connectivity = connectivity.to_sparse_bsr(blocksize) + sparse_connectivity = connectivity.to_sparse_bsr(blocksize) # block sparse layout as expected by triton # shape (1, out_features // block, in_features // block) # must be dtype int64 layout = torch.sparse_csr_tensor( - connectivity.crow_indices(), - connectivity.col_indices(), - torch.ones_like(connectivity.col_indices()), + sparse_connectivity.crow_indices(), + sparse_connectivity.col_indices(), + torch.ones_like(sparse_connectivity.col_indices()), ) layout = layout.to_dense().unsqueeze(0) - # only keep raw values, don't need indices since we have layout - # shape (nnz_blocks, block, block) - connectivity = (connectivity.values() > 0).float() - - self.register_buffer("connectivity", connectivity) - self.register_buffer("layout", layout) - - # TODO: initialize weight and bias + self.sparse_dot_dds = blocksparse_matmul( + layout, + blocksize, + "dds", + trans_a=False, + trans_b=False, + device=self.connectivity.device, + ) def reset_parameters(self): - raise NotImplementedError + # TODO: decide how to best init the weights + nn.init.xavier_normal_(self.linear.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError + batch = x.shape[0] + + weight = self.linear.weight + sparse_weight = weight.to_sparse_bsr(self.blocksize).values() + sparse_weight = sparse_weight.unsqueeze(0).repeat(batch, 1, 1, 1) + + x = self.sparse_dot_dds(x.to(torch.float16), sparse_weight.to(torch.float16)) + + if self.bias is not None: + x += self.bias + + return x def extra_repr(self) -> str: return ( From 4489743c9f55b78b2e71ae75b89bfe635e1ff8f9 Mon Sep 17 00:00:00 2001 From: alismil Date: Thu, 6 Jun 2024 19:10:15 +0100 Subject: [PATCH 3/9] added MixtureLayerNorm --- columnformers/models/layers.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index fcd9dfc..62fad8e 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -201,6 +201,52 @@ def extra_repr(self) -> str: ) +class MixtureLayerNorm(nn.Module): + def __init__( + self, + dim: int, + coef: "MixtureCoefficients", + eps: float = 1e-5, + elementwise_affine: bool = True, + ): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + self.coef = coef + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(dim, coef.rank)) + self.bias = nn.Parameter(torch.empty(dim, coef.rank)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = F.layer_norm(input, (self.dim,), eps=self.eps) + if self.elementwise_affine: + coef = self.coef() + weight = coef @ self.weight.t() + bias = coef @ self.bias.t() + input = input * weight + bias + return input + + def no_weight_decay(self) -> List[str]: + # Nb, not excluded by default since 2d + return ["weight", "bias"] + + def extra_repr(self) -> str: + return ( + f"{self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" + ) + + class SpatialPool(nn.Module): """ Pool a sequence of features with a learned attention weight per class. From d320747bceb4ed6b8b1330d342fcb9e663725ae3 Mon Sep 17 00:00:00 2001 From: alismil Date: Thu, 6 Jun 2024 19:12:03 +0100 Subject: [PATCH 4/9] matched MixtureLinear to main --- columnformers/models/layers.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 62fad8e..6c4d147 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -68,17 +68,17 @@ def __init__( self, in_features: int, out_features: int, - rank: int = 16, + coef: "MixtureCoefficients", bias: bool = True, ): super().__init__() self.in_features = in_features self.out_features = out_features - self.rank = rank + self.coef = coef - self.weight = nn.Parameter(torch.empty((out_features, in_features, rank))) + self.weight = nn.Parameter(torch.empty((out_features, in_features, coef.rank))) if bias: - self.bias = nn.Parameter(torch.empty(out_features, rank)) + self.bias = nn.Parameter(torch.empty(out_features, coef.rank)) else: self.register_parameter("bias", None) self.reset_parameters() @@ -88,9 +88,10 @@ def reset_parameters(self) -> None: if self.bias is not None: nn.init.zeros_(self.bias) - def forward(self, input: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: # input: (B, N, C) # coef: (N, R) + coef = self.coef() # Nb, this implementation for some reason uses significantly fewer flops # compared to equivalent alternatives (e.g. einsum) for some reason. weight = (coef @ self.weight.transpose(1, 2)).transpose(0, 1) @@ -102,10 +103,7 @@ def forward(self, input: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: return output def extra_repr(self) -> str: - return ( - f"{self.in_features}, {self.out_features}, {self.rank}, " - f"bias={self.bias is not None}" - ) + return f"{self.in_features}, {self.out_features}, bias={self.bias is not None}" class MixtureCoefficients(nn.Module): From f38640dea6a921f107bd89479fa58c66dcee4de5 Mon Sep 17 00:00:00 2001 From: alismil Date: Mon, 10 Jun 2024 19:54:58 +0100 Subject: [PATCH 5/9] completed BlockSparseLocallyConnected --- columnformers/models/layers.py | 112 +++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 6 deletions(-) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 6c4d147..2bda7ad 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -1,4 +1,4 @@ -from typing import Callable, List +from typing import Callable, List, Tuple import torch import torch.nn.functional as F @@ -312,7 +312,7 @@ def __init__( self.blocksize = blocksize self.connectivity = connectivity - self.linear = nn.Linear(self.in_features, self.out_features, bias=False).to( + self.linear = nn.Linear(self.out_features, self.in_features, bias=False).to( self.connectivity.device ) @@ -377,8 +377,108 @@ def extra_repr(self) -> str: class BlockSparseLocallyConnected(nn.Module): """ A locally connected layer implemented using block sparse linear. - - TODO: main step is just computing the connectivity based on conv params. shape - should be something like: (out_height * out_width * out_channels, in_height * - in_width * in_channels). Then we just use BlockSparseLinear. """ + + def __init__( + self, + kernel_dims: Tuple[int, int], + in_dims: Tuple[int, int], + padding: Tuple[int, int], + stride: Tuple[int, int], + bias: bool, + blocksize: int, + ): + super().__init__() + assert ( + kernel_dims[0] <= in_dims[0] + 2 * padding[0] + ), "Kernel height exceeds input height + padding" + assert ( + kernel_dims[1] <= in_dims[1] + 2 * padding[1] + ), "Kernel width exceeds input width + padding" + assert torch.cuda.is_available(), "Triton BlockSparse operations require a GPU" + + self.in_height, self.in_width = in_dims + self.stride_h, self.stride_w = stride + self.padding_h, self.padding_w = padding + self.kernel_height, self.kernel_width = kernel_dims + self.num_kernels_h = ( + 1 + + (self.in_height + 2 * self.padding_h - self.kernel_height) + // self.stride_h + ) + self.num_kernels_w = ( + 1 + + (self.in_width + 2 * self.padding_w - self.kernel_width) // self.stride_w + ) + self.num_kernels = self.num_kernels_w * self.num_kernels_h + + connectivity = self._create_connectivity_matrix().cuda() + + self.bsl = BlockSparseLinear( + connectivity=connectivity, bias=bias, blocksize=blocksize + ) + + def _create_connectivity_matrix(self) -> torch.Tensor: + """Create a 2D binary connectivity matrix which will mask the linear layer""" + + connectivity_height = (self.in_height + 2 * self.padding_h) * ( + self.in_width + 2 * self.padding_w + ) + connectivity_width = self.num_kernels + + connectivity = torch.zeros( + connectivity_height, + connectivity_width, + ) + + full_in_width = self.in_width + 2 * self.padding_w + + idx = [] + idx_start = 0 + for _ in range(self.kernel_height): + idx.extend(range(idx_start, idx_start + self.kernel_width)) + idx_start += full_in_width + connectivity[idx, 0] = 1 + + start = 0 + current_w = 1 + num_rows = 1 + + for i in range(1, self.num_kernels): + if current_w < self.num_kernels_w: + start += self.stride_w + current_w += 1 + else: + start += ( + (full_in_width * num_rows) + - start + + (self.stride_h - 1) * full_in_width + ) + current_w = 1 + num_rows += self.stride_h + connectivity[[i + start for i in idx], i] = 1 + + return connectivity + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.shape[2] == self.in_height + assert x.shape[3] == self.in_width + + padding = (self.padding_w, self.padding_w, self.padding_h, self.padding_h) + x = F.pad(x, padding, "constant", 0) + + x = ( + x.view( + x.shape[0], + (self.in_height + 2 * self.padding_h) + * (self.in_width + 2 * self.padding_w), + ) + .unsqueeze(1) + .unsqueeze(1) + ) + + out = self.bsl(x) + + out = out.reshape(x.shape[0], self.num_kernels_h, self.num_kernels_w) + + return out From 6356947c191081ed9e5ba51fc80d621cb7638c9c Mon Sep 17 00:00:00 2001 From: alismil Date: Mon, 10 Jun 2024 20:09:49 +0100 Subject: [PATCH 6/9] updated comments in _create_connectivity_matrix --- columnformers/models/layers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 2bda7ad..4441c93 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -435,6 +435,7 @@ def _create_connectivity_matrix(self) -> torch.Tensor: idx = [] idx_start = 0 + # find the nonzero indices of the flattened input tensor for the first kernel for _ in range(self.kernel_height): idx.extend(range(idx_start, idx_start + self.kernel_width)) idx_start += full_in_width @@ -445,6 +446,9 @@ def _create_connectivity_matrix(self) -> torch.Tensor: num_rows = 1 for i in range(1, self.num_kernels): + # to find the nonzero indices for each subsequent kernel, first find the index + # of the top left corner of the kernel then add this to all the elements of the + # first set of indices, effectively shifting the window over the input if current_w < self.num_kernels_w: start += self.stride_w current_w += 1 From f6581f4d5dfc1b8b9210d95cc5f1a68ce0ed2d3d Mon Sep 17 00:00:00 2001 From: Connor Lane Date: Tue, 2 Jul 2024 10:33:42 -0700 Subject: [PATCH 7/9] Update blocksparse linear implementation - Use pytorch native blocksparse tensors following [this blog](https://pytorch.org/blog/speeding-up-vits/) rather than triton blocksparse, which seems not very stable. See e.g. [this recent PR](https://github.com/triton-lang/triton/pull/4156) where `triton.ops` were deprecated. - Represent weights in `BlockSparseLinear` as block-sparse tensor rather than dense tensor. This will save a lot of gpu memory. - Change `BlockSparseLocallyConnected` interface to more closely match `nn.Conv2d`. Except remove support for padding and stride. For now we should be able to restrict to `padding="same"`, `stride=1` (?) - Rewrite function to construct local connectivity matrix to directly construct a sparse rather than dense matrix. Use vectorized operations rather than for loops for construction. This should save a lot of memory and run faster. - Add support for multiple input and output channels and depthwise convolution. The channels axis can either be first or last. For depthwise convolution, first should be more efficient (more block sparsity). TODO: - Finish testing on cuda. native blocksparse matmul is not implemented on CPU. --- columnformers/models/layers.py | 310 +++++++++++++++++-------------- tests/test_models/test_layers.py | 35 ++++ 2 files changed, 203 insertions(+), 142 deletions(-) create mode 100644 tests/test_models/test_layers.py diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 4441c93..74553b6 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -1,17 +1,12 @@ -from typing import Callable, List, Tuple +import warnings +from typing import Callable, List, Literal, Optional, Tuple import torch import torch.nn.functional as F -import torch.nn.utils.prune as prune +from einops import rearrange from timm.layers import trunc_normal_ from torch import nn - -try: - from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa - - triton_available = True -except ImportError: - triton_available = False +from torch.types import _device, _dtype Layer = Callable[..., nn.Module] @@ -302,18 +297,43 @@ class BlockSparseLinear(nn.Module): connectivity """ + connectivity: torch.Tensor + def __init__( - self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16 + self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 32 ): - assert triton_available, "blocksparse linear requires triton" super().__init__() - self.in_features = connectivity.shape[0] - self.out_features = connectivity.shape[1] + device_capability = _cuda_get_device_capability() + if device_capability is None or device_capability < (8, 0): + warnings.warn( + "BlockSparseLinear only supported for CUDA A100 or higher", + RuntimeWarning, + ) + + self.in_features = connectivity.shape[1] + self.out_features = connectivity.shape[0] self.blocksize = blocksize - self.connectivity = connectivity - self.linear = nn.Linear(self.out_features, self.in_features, bias=False).to( - self.connectivity.device + # convert to torch blocksparse representation if not already + connectivity = connectivity.to_sparse_bsr(blocksize).float() + self.register_buffer("connectivity", connectivity) + + n_blocks = (self.out_features // blocksize) * (self.in_features // blocksize) + nnz_blocks = connectivity.values().size(0) + self.sparsity = 1 - (nnz_blocks / n_blocks) + + # Nb, we are using pytorch native block-sparse tensors following this blog: + # https://pytorch.org/blog/speeding-up-vits/ + # We were previously using triton blocksparse matmul, but it seems not very + # stable. See here: + # https://github.com/triton-lang/triton/pull/4156 + self.weight = nn.Parameter( + torch.sparse_bsr_tensor( + crow_indices=connectivity.crow_indices(), + col_indices=connectivity.col_indices(), + values=torch.empty_like(connectivity.values()), + size=connectivity.size(), + ) ) if bias: @@ -323,54 +343,22 @@ def __init__( self.reset_parameters() - prune.custom_from_mask(self.linear, name="weight", mask=connectivity) - - # convert to torch blocksparse representation if not already - sparse_connectivity = connectivity.to_sparse_bsr(blocksize) - - # block sparse layout as expected by triton - # shape (1, out_features // block, in_features // block) - # must be dtype int64 - layout = torch.sparse_csr_tensor( - sparse_connectivity.crow_indices(), - sparse_connectivity.col_indices(), - torch.ones_like(sparse_connectivity.col_indices()), - ) - layout = layout.to_dense().unsqueeze(0) - - self.sparse_dot_dds = blocksparse_matmul( - layout, - blocksize, - "dds", - trans_a=False, - trans_b=False, - device=self.connectivity.device, - ) - def reset_parameters(self): - # TODO: decide how to best init the weights - nn.init.xavier_normal_(self.linear.weight) + trunc_normal_(self.weight.values(), std=0.02) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - batch = x.shape[0] - - weight = self.linear.weight - sparse_weight = weight.to_sparse_bsr(self.blocksize).values() - sparse_weight = sparse_weight.unsqueeze(0).repeat(batch, 1, 1, 1) - - x = self.sparse_dot_dds(x.to(torch.float16), sparse_weight.to(torch.float16)) - - if self.bias is not None: - x += self.bias - + # apply sparse connectivity mask + self.weight.values().data.mul_(self.connectivity.values()) + x = F.linear(x, self.weight, self.bias) return x def extra_repr(self) -> str: return ( f"{self.in_features}, {self.out_features}, " - f"bias={self.bias is not None}, blocksize={self.blocksize}" + f"bias={self.bias is not None}, blocksize={self.blocksize}, " + f"sparsity={self.sparsity:.2f}" ) @@ -381,108 +369,146 @@ class BlockSparseLocallyConnected(nn.Module): def __init__( self, - kernel_dims: Tuple[int, int], - in_dims: Tuple[int, int], - padding: Tuple[int, int], - stride: Tuple[int, int], - bias: bool, - blocksize: int, + in_channels: int, + out_channels: int, + kernel_size: int, + height: int, + depthwise: bool = False, + bias: bool = True, + blocksize: int = 32, + in_shape: Literal["nlc", "nchw"] = "nchw", ): super().__init__() - assert ( - kernel_dims[0] <= in_dims[0] + 2 * padding[0] - ), "Kernel height exceeds input height + padding" - assert ( - kernel_dims[1] <= in_dims[1] + 2 * padding[1] - ), "Kernel width exceeds input width + padding" - assert torch.cuda.is_available(), "Triton BlockSparse operations require a GPU" - - self.in_height, self.in_width = in_dims - self.stride_h, self.stride_w = stride - self.padding_h, self.padding_w = padding - self.kernel_height, self.kernel_width = kernel_dims - self.num_kernels_h = ( - 1 - + (self.in_height + 2 * self.padding_h - self.kernel_height) - // self.stride_h - ) - self.num_kernels_w = ( - 1 - + (self.in_width + 2 * self.padding_w - self.kernel_width) // self.stride_w + assert isinstance(kernel_size, int), "only square kernels supported" + self.in_channels = in_channels + self.out_channels = out_channels + self.height = height + self.kernel_size = kernel_size + self.depthwise = depthwise + self.blocksize = blocksize + self.in_shape = in_shape + self.channels_last = not depthwise + + connectivity = _sparse_local_connectivity( + in_channels, + out_channels, + kernel_size, + height, + depthwise=depthwise, + channels_last=self.channels_last, ) - self.num_kernels = self.num_kernels_w * self.num_kernels_h - - connectivity = self._create_connectivity_matrix().cuda() - self.bsl = BlockSparseLinear( connectivity=connectivity, bias=bias, blocksize=blocksize ) - def _create_connectivity_matrix(self) -> torch.Tensor: - """Create a 2D binary connectivity matrix which will mask the linear layer""" - - connectivity_height = (self.in_height + 2 * self.padding_h) * ( - self.in_width + 2 * self.padding_w + def forward(self, input: torch.Tensor) -> torch.Tensor: + in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w" + out_pattern = "n (h w c)" if self.channels_last else "n (c h w)" + output = rearrange( + input, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height ) - connectivity_width = self.num_kernels - - connectivity = torch.zeros( - connectivity_height, - connectivity_width, + output = self.bsl(output) + output = rearrange( + output, + f"{out_pattern} -> {in_pattern}", + c=self.out_channels, + h=self.height, + w=self.height, ) + return output - full_in_width = self.in_width + 2 * self.padding_w - - idx = [] - idx_start = 0 - # find the nonzero indices of the flattened input tensor for the first kernel - for _ in range(self.kernel_height): - idx.extend(range(idx_start, idx_start + self.kernel_width)) - idx_start += full_in_width - connectivity[idx, 0] = 1 - - start = 0 - current_w = 1 - num_rows = 1 - - for i in range(1, self.num_kernels): - # to find the nonzero indices for each subsequent kernel, first find the index - # of the top left corner of the kernel then add this to all the elements of the - # first set of indices, effectively shifting the window over the input - if current_w < self.num_kernels_w: - start += self.stride_w - current_w += 1 - else: - start += ( - (full_in_width * num_rows) - - start - + (self.stride_h - 1) * full_in_width - ) - current_w = 1 - num_rows += self.stride_h - connectivity[[i + start for i in idx], i] = 1 - - return connectivity - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.shape[2] == self.in_height - assert x.shape[3] == self.in_width +def _sparse_local_connectivity( + in_channels: int, + out_channels: int, + kernel_size: int, + height: int, + depthwise: bool = False, + channels_last: bool = False, + dtype: _dtype = None, + device: _device = None, +) -> torch.Tensor: + """ + Construct sparse local connectivity matrix, shape + (out_channels * height * height, in_channels * height * height). The returned + connectivity will have sparse COO layout. - padding = (self.padding_w, self.padding_w, self.padding_h, self.padding_h) - x = F.pad(x, padding, "constant", 0) + The connectivity pattern is equivalent to + `nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")` - x = ( - x.view( - x.shape[0], - (self.in_height + 2 * self.padding_h) - * (self.in_width + 2 * self.padding_w), - ) - .unsqueeze(1) - .unsqueeze(1) + If channels_last is True, the shape of connectivity in einops notation is + "(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should + be more efficient when depthwise is True (connectivity will be more block sparse). + """ + assert kernel_size % 2 == 1, "kernel_size must be odd" + assert ( + not depthwise or out_channels == in_channels + ), "in channels must match out channels for depthwise" + N = height * height + + # ij indices of input grid + # (h^2, 2) + col_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height)) + + # conv kernel index offsets. note that the kernel width is required to be odd. + # (k^2, 2) + kernel_half_width = (kernel_size - 1) // 2 + kernel_indices = torch.cartesian_prod( + torch.arange(-kernel_half_width, kernel_half_width + 1), + torch.arange(-kernel_half_width, kernel_half_width + 1), + ) + + # input edge indices for each output unit. these will be the column indices for the + # sparse COO connectivity. + # (h^2, k^2, 2) + col_indices = col_indices.unsqueeze(1) + kernel_indices.unsqueeze(0) + + # input edge row indices + # (h^2, k^2) + row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2) + + # exclude edges falling outside grid + mask = ((col_indices >= 0) & (col_indices < height)).all(axis=-1) + col_indices = col_indices[mask] + row_indices = row_indices[mask] + + # rasterize column indices + col_indices = height * col_indices[..., 0] + col_indices[..., 1] + + # add channel blocks with full or depthwise (diagonal) connectivity + if depthwise: + channel_indices = torch.arange(out_channels).unsqueeze(1).repeat(1, 2) + else: + channel_indices = torch.cartesian_prod( + torch.arange(out_channels), torch.arange(in_channels) ) - out = self.bsl(x) + # we can insert the channels axis either at the front or the back + # front is better for depthwise=True, back is better for depthwise=False + if channels_last: + row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0] + col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1] + else: + row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices + col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices + + row_indices = row_indices.flatten() + col_indices = col_indices.flatten() + + # construct sparse connectivity tensor + connectivity = ( + torch.sparse_coo_tensor( + torch.stack([row_indices, col_indices]), + torch.ones(len(row_indices), dtype=dtype), + size=(out_channels * N, in_channels * N), + ) + .coalesce() + .to(device) + ) + return connectivity - out = out.reshape(x.shape[0], self.num_kernels_h, self.num_kernels_w) - return out +def _cuda_get_device_capability() -> Optional[Tuple[int, int]]: + if not torch.cuda.is_available(): + return None + return torch.cuda.get_device_capability() diff --git a/tests/test_models/test_layers.py b/tests/test_models/test_layers.py new file mode 100644 index 0000000..62a78cb --- /dev/null +++ b/tests/test_models/test_layers.py @@ -0,0 +1,35 @@ +import logging +import torch +from torch import nn + +import columnformers.models.layers as L + + +def test_block_sparse_locally_connected(): + loc = L.BlockSparseLocallyConnected( + in_channels=8, + out_channels=16, + kernel_size=3, + height=16, + depthwise=False, + ) + logging.info("%s", loc) + + conv = nn.Conv2d( + in_channels=8, + out_channels=16, + kernel_size=3, + stride=1, + padding="same", + ) + + nn.init.ones_(loc.bsl.weight.values()) + nn.init.ones_(conv.weight) + nn.init.zeros_(loc.bsl.bias) + nn.init.zeros_(conv.bias) + + # TODO: finish testing on cuda + input = torch.randn(2, 8, 16, 16) + output_loc = loc(input) + output_conv = conv(input) + assert torch.allclose(output_loc, output_conv) From b32e1fab2189398cdf0f54bdbc919422a30d608b Mon Sep 17 00:00:00 2001 From: Connor Lane Date: Tue, 2 Jul 2024 20:19:08 -0400 Subject: [PATCH 8/9] Debug block sparse linear Could not use `sparse_bsr` layout weight parameter. It fails to map to cuda correctly when calling `model.cuda()`. This is probably a bug that should be reported. But as a workaround I just unpack the `crow_indices`, `col_indices` as buffers and store the sparse bsr weight values as a standard strided parameter. Then I construct the sparse bsr weight tensor on the fly during forward. TODO: backward does not work. Raises ``` RuntimeError: addmm: computation on CUDA is not implemented for Strided + Strided @ SparseBsr ``` --- columnformers/models/layers.py | 42 +++++++++++++++++--------------- tests/test_models/test_layers.py | 33 +++++++++++++++++++------ 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 74553b6..8253df2 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -297,7 +297,9 @@ class BlockSparseLinear(nn.Module): connectivity """ - connectivity: torch.Tensor + crow_indices: torch.Tensor + col_indices: torch.Tensor + mask: torch.Tensor def __init__( self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 32 @@ -307,7 +309,7 @@ def __init__( if device_capability is None or device_capability < (8, 0): warnings.warn( "BlockSparseLinear only supported for CUDA A100 or higher", - RuntimeWarning, + UserWarning, ) self.in_features = connectivity.shape[1] @@ -316,26 +318,19 @@ def __init__( # convert to torch blocksparse representation if not already connectivity = connectivity.to_sparse_bsr(blocksize).float() - self.register_buffer("connectivity", connectivity) + self.register_buffer("crow_indices", connectivity.crow_indices()) + self.register_buffer("col_indices", connectivity.col_indices()) + self.register_buffer("mask", connectivity.values()) n_blocks = (self.out_features // blocksize) * (self.in_features // blocksize) nnz_blocks = connectivity.values().size(0) self.sparsity = 1 - (nnz_blocks / n_blocks) - # Nb, we are using pytorch native block-sparse tensors following this blog: - # https://pytorch.org/blog/speeding-up-vits/ - # We were previously using triton blocksparse matmul, but it seems not very - # stable. See here: - # https://github.com/triton-lang/triton/pull/4156 - self.weight = nn.Parameter( - torch.sparse_bsr_tensor( - crow_indices=connectivity.crow_indices(), - col_indices=connectivity.col_indices(), - values=torch.empty_like(connectivity.values()), - size=connectivity.size(), - ) - ) - + # The weight parameter is just the sparse bsr values. We construct the sparse + # bsr tensor on the fly. Trying to use a sparse bsr layout parameter doesn't + # work. Mapping the model to cuda fails. The tensor container is mapped but not + # the underlying values. + self.weight = nn.Parameter(torch.empty_like(connectivity.values())) if bias: self.bias = nn.Parameter(torch.empty(self.out_features)) else: @@ -344,14 +339,21 @@ def __init__( self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.weight.values(), std=0.02) + trunc_normal_(self.weight, std=0.02) + self.weight.data.mul_(self.mask) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: # apply sparse connectivity mask - self.weight.values().data.mul_(self.connectivity.values()) - x = F.linear(x, self.weight, self.bias) + weight = self.mask * self.weight + weight = torch.sparse_bsr_tensor( + crow_indices=self.crow_indices, + col_indices=self.col_indices, + values=weight, + size=(self.out_features, self.in_features), + ) + x = F.linear(x, weight, self.bias) return x def extra_repr(self) -> str: diff --git a/tests/test_models/test_layers.py b/tests/test_models/test_layers.py index 62a78cb..f388183 100644 --- a/tests/test_models/test_layers.py +++ b/tests/test_models/test_layers.py @@ -1,3 +1,5 @@ +import pytest + import logging import torch from torch import nn @@ -5,7 +7,11 @@ import columnformers.models.layers as L +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not available") def test_block_sparse_locally_connected(): + torch.manual_seed(42) + device = torch.device("cuda") + loc = L.BlockSparseLocallyConnected( in_channels=8, out_channels=16, @@ -23,13 +29,26 @@ def test_block_sparse_locally_connected(): padding="same", ) - nn.init.ones_(loc.bsl.weight.values()) - nn.init.ones_(conv.weight) + nn.init.ones_(loc.bsl.weight) nn.init.zeros_(loc.bsl.bias) + nn.init.ones_(conv.weight) nn.init.zeros_(conv.bias) - # TODO: finish testing on cuda - input = torch.randn(2, 8, 16, 16) - output_loc = loc(input) - output_conv = conv(input) - assert torch.allclose(output_loc, output_conv) + loc = loc.to(device) + conv = conv.to(device) + + input = torch.randn((2, 8, 16, 16), device=device) + input_loc = input.clone().requires_grad_(True) + input_conv = input.clone().requires_grad_(True) + + output_loc = loc(input_loc) + output_conv = conv(input_conv) + assert torch.allclose(output_loc, output_conv, rtol=1e-4) + + loss_loc = (output_loc**2).mean() + loss_conv = (output_conv**2).mean() + loss_loc.backward() + loss_conv.backward() + grad_loc = input_loc.grad.data + grad_conv = input_conv.grad.data + assert torch.allclose(grad_loc, grad_conv, rtol=1e-4) From b5b9a02d2a25d51d90a01d61168d9869adb12f0c Mon Sep 17 00:00:00 2001 From: Connor Lane Date: Wed, 3 Jul 2024 18:00:01 -0400 Subject: [PATCH 9/9] Add `in_shape="nd"` option for no rearranging --- columnformers/models/layers.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 8253df2..035084c 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -378,7 +378,7 @@ def __init__( depthwise: bool = False, bias: bool = True, blocksize: int = 32, - in_shape: Literal["nlc", "nchw"] = "nchw", + in_shape: Literal["nlc", "nchw", "nd"] = "nchw", ): super().__init__() assert isinstance(kernel_size, int), "only square kernels supported" @@ -404,19 +404,25 @@ def __init__( ) def forward(self, input: torch.Tensor) -> torch.Tensor: - in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w" - out_pattern = "n (h w c)" if self.channels_last else "n (c h w)" - output = rearrange( - input, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height - ) - output = self.bsl(output) - output = rearrange( - output, - f"{out_pattern} -> {in_pattern}", - c=self.out_channels, - h=self.height, - w=self.height, - ) + needs_reshape = self.in_shape != "nd" + + if needs_reshape: + in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w" + out_pattern = "n (h w c)" if self.channels_last else "n (c h w)" + input = rearrange( + input, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height + ) + + output = self.bsl(input) + + if needs_reshape: + output = rearrange( + output, + f"{out_pattern} -> {in_pattern}", + c=self.out_channels, + h=self.height, + w=self.height, + ) return output