Skip to content
This repository was archived by the owner on Oct 15, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 175 additions & 34 deletions bitmat/bitlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,187 @@
from .utils.bitmat import bitmat
from .utils.rmsnorm import RMSLayerNorm
from .utils.packing import pack_ternary, unpack_ternary
from .utils.bitmat import terniarize
from .utils.bitmat import terniarize,bitmat_

class BitLinear(torch.nn.Module):
"""
A linear layer that uses packed terniary matrix multiplication.
"""

def __init__(self, in_features, out_features, bias=None, eps=1e-5, keep_rms_in_32b=False, dtype=torch.float16):
import torch
import torch.nn as nn
from torch import Tensor

def weight_quant(weight: Tensor, num_bits=1):
    """Ternarize *weight* to {-1, 0, 1} int8 values plus a per-tensor scale.

    The scale is the reciprocal of the mean absolute weight (clamped away
    from zero), so ``weight ≈ quantized / s``.

    Args:
        weight: floating-point weight tensor of any shape.
        num_bits: unused; kept for interface compatibility with callers.

    Returns:
        Tuple ``(quantized, s)``: an int8 tensor of -1/0/1 with the same
        shape as *weight*, and the float scale ``s``.
    """
    weight_f = weight.float()
    # Clamp keeps an all-zero tensor from producing a division by zero.
    s = 1 / weight_f.abs().mean().clamp(min=1e-5)
    # Round to nearest, then clamp into the ternary set {-1, 0, 1}.
    # (The original cast back to the input dtype before `.to(torch.int8)`
    # was redundant: the result is immediately converted to int8.)
    result = (weight_f * s).round().clamp(-1, 1)
    return result.to(torch.int8), s


def activation_quant(x, num_bits=8):
    """Fake-quantize activations per row to signed ``num_bits`` integers.

    Each row of *x* is scaled so its absolute maximum maps onto the largest
    positive quantized value, rounded, and clamped to the signed range.
    The integer-valued result is returned in the original dtype (it is NOT
    divided back by the scale here).
    """
    orig_dtype = x.dtype
    x_f = x.float()
    q_max = (1 << (num_bits - 1)) - 1
    q_min = -(q_max + 1)
    # Per-row scale; clamp avoids division by zero for all-zero rows.
    row_max = x_f.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    scale = q_max / row_max
    quantized = torch.clamp(torch.round(x_f * scale), q_min, q_max)
    return quantized.type(orig_dtype)


class BitLinear(nn.Module):
    """Linear layer with ternary (packed int8) weights and 8-bit fake-quantized activations.

    Weights start as a trainable float parameter and are lazily converted to
    a packed int8 buffer plus a scale (`scale_w`); the matrix product is then
    performed by the project's `bitmat_` kernel. RMSNorm is expected to be
    applied by the caller (see the note in `__init__`).
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=None,
        eps=1e-5,
        keep_rms_in_32b=False,
        dtype=torch.float16,
        packed=False,
        *args,
        **kwargs,
    ):
        super(BitLinear, self).__init__()
        print("Using Fast Bitmat")
        """
        RMSNorm is placed outside BitLinear
        """

        # Trainable float weight; replaced by a packed int8 buffer later.
        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features)))
        if packed:
            self.convert_weights_to_packed()
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("bias", None)
        # NOTE(review): `num_bits` is a local that is never used, and `eps`,
        # `keep_rms_in_32b` and `dtype` are accepted but not stored here even
        # though convert_weights_to_parameters reads `self.dtype` — confirm.
        num_bits = 8
        self.Qp = 127  # largest positive value of a signed 8-bit integer


    def convert_weights_to_parameters(self):
        # Convert the packed int8 weights back into a float torch.nn.Parameter
        # for training (unpack, then dequantize by the stored scale).
        if self.weight.dtype == torch.int8:
            unpacked_weight = unpack_ternary(self.weight)
            half_weight = (unpacked_weight / self.scale_w).to(self.dtype)
            self.weight = torch.nn.Parameter(half_weight)
            self.scale_w = (
                None  # <- this is done so that the bitmat kernel knows we're training
            )

    def convert_weights_to_packed(self):
        print("Packing")
        # Convert the float weights into a packed int8 buffer after training.
        if not isinstance(self.weight, torch.nn.Parameter):
            return

        terniarized_weight, scale_weight = terniarize(self.weight.data)
        packed_weights = pack_ternary(terniarized_weight)

        del self.weight  # <- this is done so that torch doesn't throw an error when trying to convert the nn.Parameter to PackedParameter
        self.register_buffer("weight", packed_weights)
        self.register_buffer("scale_w", scale_weight)

    def forward(self, input):
        # NOTE(review): stashing the last input on self retains the tensor
        # between calls — presumably for debugging; confirm it is intended.
        self.input = input
        # 8-bit fake-quantized activations (integer-valued, not yet rescaled).
        quant_input = activation_quant(input)


        # Per-row activation scale: 127 / max|input| (clamped away from 0).
        s = 127 / (
            input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        )

        # Lazily pack the weights on the first call if still floating point.
        if self.weight.dtype!=torch.int8:
            self.convert_weights_to_packed()
            # quant_weight, scale = terniarize(self.weight)
            # wp = pack_ternary(quant_weight.to(torch.int8))
            # del self.weight
            # self.register_buffer("weight",wp)
            # self.scale_w = scale
            # print("COLD")

        wp = self.weight
        scale = self.scale_w
        # Dequantize the activations (divide by s) and run the packed kernel.
        out = bitmat_(quant_input.half() / s, wp.t().contiguous()).to(torch.float)

        out = out.to(input.dtype)
        # terniarize returns q ≈ w * scale_w, so dividing undoes the scale.
        out = out / scale
        if not self.bias is None:
            out += self.bias.view(1, -1).expand_as(out)

        return out

    def _post_init(self):
        # Pack the weights once construction/initialization is complete.
        self.convert_weights_to_packed()


class BitLinearx(torch.nn.Module):
    """Alternate BitLinear variant with in-module RMSNorm.

    NOTE(review): this block appears to contain fused old/new lines from a
    diff — `weight`, `scale_w` and `bias` are each set up twice below.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=None,
        eps=1e-5,
        keep_rms_in_32b=False,
        dtype=torch.float16,
        packed=False,
        *args,
        **kwargs,
    ):
        # NOTE(review): `super(BitLinear, self)` names a different class —
        # likely a copy-paste slip; plain `super()` would be safer. Confirm.
        super(BitLinear, self).__init__()
        print("Using Fast Bitmat")
        self.eps = eps
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.register_buffer('weight', torch.zeros((out_features, in_features), dtype=torch.int8))
        self.scale_w = torch.nn.Parameter(torch.Tensor(1))

        if packed:
            # NOTE(review): re-registering `scale_w` as a buffer while it is
            # already an nn.Parameter may raise in torch — confirm this path.
            self.register_buffer("weight" , torch.zeros((out_features, in_features)))
            self.register_buffer("scale_w",torch.Tensor(1))
        else:
            self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features)))

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.norm = RMSLayerNorm(in_features, eps)
        # NOTE(review): this second registration resets `bias` to None even
        # when a bias parameter was created above — looks like a fused diff.
        self.register_parameter("bias", None)
        self.keep_rms_in_32b = keep_rms_in_32b
        self._post_init()

        # self._post_init()

        """
        def a .to() to keep eps precision
        """


    def _post_init(self):
        # Initialize the weight matrix (Kaiming), then ternarize and pack it.
        print("POST init")
        # Create a parameter tensor so the weights and biases can be initialized.
        # Initialize the weights using Kaiming initialization.
        params = torch.nn.Parameter(torch.zeros((self.out_features, self.in_features), dtype=self.dtype))
        torch.nn.init.kaiming_normal_(params, mode='fan_out', nonlinearity='relu')
        terniarized_val, self.scale_w.data = terniarize(params)
        del params
        # NOTE(review): register_buffer over a name currently held by an
        # nn.Parameter may raise in torch — confirm `weight` is a buffer here.
        self.register_buffer('weight',pack_ternary(terniarized_val))
        # params = torch.nn.Parameter(
        #     torch.zeros((self.out_features, self.in_features), dtype=self.dtype)
        # )
        # torch.nn.init.kaiming_normal_(params, mode="fan_out", nonlinearity="relu")
        # terniarized_val, self.scale_w.data = terniarize(params)
        # del params
        # self.register_buffer("weight", pack_ternary(terniarized_val))

        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0)
        # if self.bias is not None:
        #     torch.nn.init.constant_(self.bias, 0)

    def convert_weights_to_parameters(self):
        # Convert the packed int8 weights back into a float torch.nn.Parameter
        # for training (unpack, then dequantize by the stored scale).
        if self.weight.dtype == torch.int8:
            unpacked_weight = unpack_ternary(self.weight)
            half_weight = (unpacked_weight / self.scale_w).to(self.dtype)
            self.weight = torch.nn.Parameter(half_weight)
            # NOTE(review): the two assignments below look like fused old/new
            # diff lines; both set scale_w to None, so the net effect is one.
            self.scale_w = None# <- this is done so that the bitmat kernel knows we're training
            self.scale_w = (
                None  # <- this is done so that the bitmat kernel knows we're training
            )

    def convert_weights_to_packed(self):
        # Convert the float weights into a packed int8 buffer after training.
        if isinstance(self.weight, torch.nn.Parameter):
            terniarized_weight, scale_weight = terniarize(self.weight.data)
            packed_weights = pack_ternary(terniarized_weight)
            self.scale_w = torch.nn.Parameter(scale_weight)
            # NOTE(review): the del/register pairs below look like fused
            # old/new diff lines; the second pair re-deletes and re-registers
            # the same buffer, leaving the same end state.
            del self.weight # <- this is done so that torch doesn't throw an error when trying to convert the nn.Parameter to PackedParameter
            self.register_buffer('weight', packed_weights)
            del self.weight  # <- this is done so that torch doesn't throw an error when trying to convert the nn.Parameter to PackedParameter
            self.register_buffer("weight", packed_weights)

def train(self, mode=True):
super().train(mode)
Expand All @@ -70,13 +194,30 @@ def train(self, mode=True):
self.convert_weights_to_packed()
return self.to(device)

    def forward(self, x):
        """RMS-normalize *x*, then apply the ternary linear transform via ``bitmat``.

        NOTE(review): this definition is shadowed by the later ``forward``
        defined further down in the same class body.
        """
        if self.training and (self.weight.dtype == torch.int8):
            # Just to make sure the weights are in the right format even if the user forgot to call train()
            self.convert_weights_to_parameters()
        x_dtype = x.dtype
        # Run the norm in its own parameter dtype, then restore the input dtype.
        x = self.norm(x.to(self.norm.weight.dtype)).to(x_dtype)
        output = bitmat(self.weight.data, x, scale_w=self.scale_w)
        if self.bias is not None:
            output += self.bias.unsqueeze(0).expand_as(output)
        return output
    def forward(self, input):
        # NOTE(review): stashing the last input on self retains the tensor
        # between calls — presumably for debugging; confirm it is intended.
        self.input = input
        # NOTE(review): `self.input_bits`, `self.weight_bits`, `self.Qp` and
        # `self.scale` are read here but are not assigned in this class's
        # visible __init__ — confirm they are set elsewhere.
        quant_input = activation_quant(input, self.input_bits)


        # Per-row activation scale; the extra 2e-6 nudges the denominator
        # slightly above the clamp floor.
        s = self.Qp / (
            input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + 2e-6
        )

        # Cold start: ternarize and pack the float weights once, caching the
        # packed buffer and its scale for subsequent calls.
        if self.weight.dtype!=torch.int8:
            quant_weight, scale = weight_quant(self.weight, self.weight_bits)
            wp = pack_ternary(quant_weight.to(torch.int8))
            del self.weight
            self.register_buffer("weight",wp)
            self.scale = scale
            print("cold start")
        else:
            wp = self.weight
            scale = self.scale
        # Dequantize the activations (divide by s) and run the packed kernel.
        out = bitmat_(quant_input.half() / s, wp.t().contiguous()).to(torch.float)

        out = out.to(input.dtype)
        # NOTE(review): this multiplies by `scale`, while BitLinear.forward
        # divides; with scale = 1/mean|w| from weight_quant, dividing would
        # dequantize — confirm which is intended.
        out = out * scale
        if not self.bias is None:
            out += self.bias.view(1, -1).expand_as(out)

        return out
45 changes: 28 additions & 17 deletions bitmat/utils/bitmat.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
from typing import Tuple

import os
import torch

from ..triton_kernels.bitmat_kernel import bitmat_
from .packing import pack_ternary

BITMAT_QUANT_8BIT_ACTIVATIONS = not os.getenv("BITMAT_QUANT_8BIT_ACTIVATIONS","True").lower() in ('false', '0', 'f')

def terniarize(weight: torch.Tensor):
    """Quantize *weight* to ternary int8 values {-1, 0, 1} plus a scale.

    The scale is 1 / mean(|weight|), clamped away from zero; the weights are
    scaled by it, rounded, and clamped into the ternary set. Returns the
    int8 ternary tensor and the scale cast back to the input dtype.
    """
    orig_dtype = weight.dtype
    w = weight.float()
    mean_abs = w.abs().mean().clamp(min=1e-5)
    scale = 1 / mean_abs
    ternary = torch.clamp(torch.round(w * scale), -1, 1)
    return ternary.type(orig_dtype).to(torch.int8), scale.to(orig_dtype)

def terniarize(weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Ternarizes the weights to {-1, 0, 1} int8 values and returns the scale.

    The scale is 1 / mean(|weights|), clamped away from zero. The scaled
    weights are rounded to the nearest integer and clamped into the ternary
    set *before* the int8 cast: casting first (as the previous version did)
    truncates toward zero instead of rounding — inconsistent with the other
    quantizers in this file — and can wrap on overflow before the clamp.
    """
    dtype = weights.dtype
    scale = 1 / torch.max(weights.abs().mean(), torch.tensor(1e-5))
    ternary = torch.clamp((weights * scale).round(), -1, 1).to(torch.int8)
    return ternary, scale.to(dtype)

def quantize_activations(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the activations and returns the scale for each row.
    """
    dtype = x.dtype
    # NOTE(review): the four lines below look like fused old/new diff lines —
    # `scale` is assigned twice and the second `return` is unreachable.
    # Also note the first return clamps to (-127, 128): a value of 128 is out
    # of int8 range and presumably wraps on the cast — confirm intent.
    scale = (127 / torch.max(x.abs().max(dim=-1).values, torch.tensor(1e-5))).unsqueeze(-1)
    return torch.clamp((x * scale), -127, 128).to(torch.int8), scale.to(dtype)
    scale = (128 / torch.max(x.abs().max(dim=-1).values, torch.tensor(1e-5))).unsqueeze(-1)
    return torch.clamp((x * scale), -128, 127).to(torch.int8), scale.to(dtype)


class BitMat(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, W, X, scale_w=None):
def forward(ctx, W, X, scale_w=None,quant_8bit_activations=True):
"""
During the forward pass, we ternarize the weights, pack them and then quantize the activations.
We then perform the bit matrix multiplication and return the scaled results.
Expand All @@ -42,19 +43,28 @@ def forward(ctx, W, X, scale_w=None):
Y = X @ w_packed.t() | dot product
Y = Y / scale_w / scale_x) | STE
"""
X, scale_x = quantize_activations(X)
if not quant_8bit_activations:
X = X/scale_x

if scale_w is None:
dtype = W.dtype
W, scale_w = terniarize(W)
#packed_w = pack_ternary(W, 4) -> this is actually not efficent atm
ctx.save_for_backward(X)
X, scale_x = quantize_activations(X)

y = X.to(dtype) @ W.to(dtype).t()
#y = batched_bitmat(X, packed_w) -> this is actually not efficent atm
return y / scale_w / scale_x
out = y / scale_w
else:
X, scale_x = quantize_activations(X)

y = bitmat_(X, W.t().contiguous())
return y / scale_w / scale_x
out = y / scale_w

if quant_8bit_activations:
out = out / scale_x

return out


@staticmethod
Expand All @@ -65,5 +75,6 @@ def backward(ctx, grad_output):
grad_W = (grad_output.transpose(1,2) @ X).mean(dim=0)
return grad_W, None, None

def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w) -> torch.Tensor:
    """Functional entry point: apply the BitMat autograd function to W and X.

    ``scale_w`` is the ternary weight scale; BitMat's forward ternarizes W
    itself when ``scale_w`` is None.
    """
    return BitMat.apply(W, X, scale_w)
def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w,quant_8bit_activations=None) -> torch.Tensor:
quant_8bit_activations = quant_8bit_activations if quant_8bit_activations is not None else BITMAT_QUANT_8BIT_ACTIVATIONS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that this is being passed as an env var

return BitMat.apply(W, X, scale_w,quant_8bit_activations)
9 changes: 7 additions & 2 deletions bitmat/utils/convert_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tqdm import tqdm
from ..bitlinear import BitLinear

from transformers import AutoModel, GemmaConfig, MistralConfig, LlamaConfig
from transformers import AutoModel, GemmaConfig, MistralConfig, LlamaConfig, AutoModelForCausalLM

# Importing custom hijack classes for specific models
from .modeling.model_hijacks.gemma_1_58b import Gemma158ForCausalLM
Expand All @@ -28,7 +28,12 @@ def convert_hf_model(model: AutoModel) -> AutoModel:
elif isinstance(model_config, LlamaConfig):
hijacked_model = Llama158ForCausalLM(model_config)
else:
raise RuntimeError("Unsupported model type. Please open an issue on GitHub citing the model you are using")
try:
hijacked_model = AutoModelForCausalLM(model_config)
hijacked_model = apply_bitlinear_to_hf_model(hijacked_model)
except Exception as e:
print(str(e))
raise RuntimeError("Unsupported model type. Please open an issue on GitHub citing the model you are using")

pbar.update(1)
return hijacked_model
Expand Down
1 change: 1 addition & 0 deletions bitmat/utils/modeling/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def from_pretrained(cls, *args, **kwargs):
raise ValueError(
f"The model {args[0]} was not found, this mean it has not been mapped yet, please open an issue on the github repository")

print(f"model CLASS {model_class_name}")
# Ottieni la classe del modello utilizzando il suo nome
model_class = globals()[model_class_name]

Expand Down
Empty file added bitmp/main.py
Empty file.
7 changes: 7 additions & 0 deletions dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Scratch script exploring dict identity and chained assignment.
d = {}
print(id(d))
nd = {}
print(id(nd))
# Chained assignment: the RHS (`nd`) is evaluated once and assigned to the
# targets left-to-right — first `d = nd`, then `d[0] = nd`. After this, `d`
# is the same object as `nd`, and that dict contains itself under key 0.
d= d[0] =nd
print(id(d))
print(d)