@huweim commented on Mar 18, 2025

Motivation

Marlin-Sparse is a nice repository. While running the sparsity and packing code, I encountered a few bugs that cause negative values to be pruned, leading to extremely large outputs.

Modifications

1. Incorrect Shape of self.B

In Layer_2_4, the shape of self.B should be (self.k // 16 // 2, self.n * 16 // 8): 2:4 sparsity keeps only half of the values along the k dimension, so the first dimension of the packed buffer must be halved as well. A quick size check follows the snippet below.

# before
self.register_buffer(
    "B", torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)
)

# after
self.register_buffer(
    "B", torch.empty((self.k // 16 // 2, self.n * 16 // 8), dtype=torch.int)
)
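
As a sanity check, the corrected shape matches the element counts implied by 4-bit packing (8 values per int32) and 2:4 sparsity. This is a minimal sketch; k and n are hypothetical layer dimensions, multiples of the Marlin tile size:

k, n = 256, 256               # hypothetical layer dimensions
vals_per_int32 = 32 // 4      # 8 four-bit values packed per int32

# Dense: k * n values -> k * n / 8 int32 words
assert (k // 16) * (n * 16 // 8) == k * n // vals_per_int32

# 2:4 sparse: only (k // 2) * n values survive -> half the int32 words
assert (k // 16 // 2) * (n * 16 // 8) == (k // 2) * n // vals_per_int32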

2. Issue with Sparsity Logic

The quantized weights in Layer_2_4 carry a +8 offset and are clamped to the range [0, 15]. However, mask_creator does not account for this offset: it prunes the smallest-magnitude elements in each group of M on the raw quantized values. Since large negative weights map to values near 0 after the offset, they look small in magnitude and are unintentionally pruned. A worked example follows the fix below.

# before
def mask_creator(tensor):
    ...
    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)

# after
def mask_creator(tensor):
    ...
    num_groups = tensor.numel() // M

    # Subtract the offset value for pruning
    maxq = 2**4 - 1
    ZERO_VALUE = (maxq + 1) // 2
    tensor = tensor - ZERO_VALUE

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
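
To make the failure concrete, here is a minimal sketch with one made-up group of M = 4 quantized values (the numbers are hypothetical, and topk stands in for the selection done inside mask_creator):

import torch

# Original values -7, -1, +2, +5 become 1, 7, 10, 13 after the +8 offset.
q = torch.tensor([[1, 7, 10, 13]])

# Pruning on the raw magnitudes keeps 13 and 10, i.e. it drops the -7,
# which is the largest-magnitude weight in the group.
print(q.abs().topk(2).indices)        # tensor([[3, 2]])

# Subtracting ZERO_VALUE = 8 first restores the signed magnitudes
# (-7, -1, 2, 5), so -7 and +5 are correctly kept.
print((q - 8).abs().topk(2).indices)  # tensor([[0, 3]])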

3. Handling of Pruned Elements

After applying sparsity, mask * w.T zeroes out the pruned elements, and sparse_semi_structured_from_dense_cutlass then generates the indices of the pruned elements from those zeros. However, the quantized weights lie in [0, 15] after the +8 offset, so a genuine weight of 0 (an original value of -8) is indistinguishable from a pruned element, producing an incorrect sparsity pattern. A small example follows the fix below.

# before
def pack(self, linear, scales, trans=False):
    ...

    mask = mask_creator(w.T).cuda().bool()
    w = mask * w.T
    w, meta = sparse_semi_structured_from_dense_cutlass(w)
    w = w.t()

# after
def pack(self, linear, scales, trans=False):
    ...

    mask = mask_creator(w.T).cuda().bool()
    # Shift by +1 so a genuine zero (-8 + (maxq + 1) // 2 == 0) is not confused with a pruned element
    w += 1

    w = mask * w.T
    w, meta = sparse_semi_structured_from_dense_cutlass(w)
    
    # Recover the original values
    w -= 1
    w = w.t()
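
A minimal made-up example of why the +1 shift matters (the values are hypothetical): sparse_semi_structured_from_dense_cutlass reads the pruned positions off the zeros in its input, and an unshifted genuine 0 adds an ambiguous extra zero:

import torch

# One group after the +8 offset; the first element is a genuine -8,
# i.e. quantized value 0.
q = torch.tensor([[0, 9, 3, 12]])
mask = torch.tensor([[True, False, False, True]])  # keep positions 0 and 3

# Without the shift, three of four entries are zero even though only
# two positions were pruned -- the pattern is ambiguous.
print(mask * q)        # tensor([[ 0,  0,  0, 12]])

# With the shift, every kept element is >= 1, so zeros mark exactly the
# pruned positions; subtracting 1 afterwards restores the values.
print(mask * (q + 1))  # tensor([[ 1,  0,  0, 13]])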
