def quantize_activations(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize activations to int8 and return the per-row scale.

    Each row is scaled so that its max-abs value maps onto the int8 range,
    rounded to the nearest integer, then clamped to [-128, 127].

    FIX: the previous code cast to int8 without rounding; ``.to(torch.int8)``
    truncates toward zero, which systematically biases small activations
    (e.g. 3.84 became 3 instead of 4). The sibling ``activation_quant``
    helper rounds, so this now matches it.

    Returns:
        (int8 quantized activations, per-row scale cast back to x's dtype)
    """
    dtype = x.dtype
    # 1e-5 floor guards against division blow-up on all-zero rows.
    scale = (128 / torch.max(x.abs().max(dim=-1).values, torch.tensor(1e-5))).unsqueeze(-1)
    # Round to nearest BEFORE clamping/casting; the int8 cast alone truncates.
    return torch.clamp((x * scale).round(), -128, 127).to(torch.int8), scale.to(dtype)
""" dtype = x.dtype - scale = (127 / torch.max(x.abs().max(dim=-1).values, torch.tensor(1e-5))).unsqueeze(-1) + scale = (128 / torch.max(x.abs().max(dim=-1).values, torch.tensor(1e-5))).unsqueeze(-1) return torch.clamp((x * scale), -128, 127).to(torch.int8), scale.to(dtype) class BitMat(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, W, X, scale_w=None): + def forward(ctx, W, X, scale_w=None,quant_activations=True): """ During the forward pass, we ternarize the weights, pack them and then quantize the activations. We then perform the bit matrix multiplication and return the scaled results. @@ -42,19 +43,27 @@ def forward(ctx, W, X, scale_w=None): Y = X @ w_packed.t() | dot product Y = Y / scale_w / scale_x) | STE """ + if quant_activations: + X, scale_x = quantize_activations(X) + if scale_w is None: dtype = W.dtype W, scale_w = terniarize(W) #packed_w = pack_ternary(W, 4) -> this is actually not efficent atm ctx.save_for_backward(X) - X, scale_x = quantize_activations(X) + y = X.to(dtype) @ W.to(dtype).t() #y = batched_bitmat(X, packed_w) -> this is actually not efficent atm - return y / scale_w / scale_x + out = y / scale_w else: - X, scale_x = quantize_activations(X) + y = bitmat_(X, W.t().contiguous()) - return y / scale_w / scale_x + out = y / scale_w + + if quant_activations: + out = out / scale_x + + return out @staticmethod @@ -65,5 +74,6 @@ def backward(ctx, grad_output): grad_W = (grad_output.transpose(1,2) @ X).mean(dim=0) return grad_W, None, None -def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w) -> torch.Tensor: - return BitMat.apply(W, X, scale_w) \ No newline at end of file +def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w,quant_activations=None) -> torch.Tensor: + quant_activations = quant_activations if quant_activations is not None else BITMAT_QUANT_ACTIVATIONS + return BitMat.apply(W, X, scale_w,quant_activations=quant_activations) \ No newline at end of file From 
def weight_quant(weight: Tensor, num_bits=1):
    """Ternarize *weight* to {-1, 0, +1} int8 plus its per-tensor scale.

    The scale is the reciprocal mean absolute value (clamped away from zero).
    ``num_bits`` is accepted for interface parity but is not used here.
    """
    orig_dtype = weight.dtype
    w = weight.float()
    inv_mean = 1 / w.abs().mean().clamp(min=1e-5)
    tern = (w * inv_mean).round().clamp(-1, 1)
    return tern.type(orig_dtype).to(torch.int8), inv_mean


def activation_quant(x, num_bits=8):
    """Fake-quantize activations to signed *num_bits* integer levels.

    Scaling is per row (last dim). Returns the rounded, clamped values in
    x's original dtype; the caller recomputes the scale for dequantization.
    """
    orig_dtype = x.dtype
    xf = x.float()
    lo = -(2 ** (num_bits - 1))
    hi = 2 ** (num_bits - 1) - 1
    row_scale = hi / xf.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (xf * row_scale).round().clamp(lo, hi).type(orig_dtype)
def forward(self, input):
    """
    Fast BitLinear forward: fake-quantize the activations to 8 bits, run the
    packed ternary triton kernel, then undo the ternarization scale.

    RMSNorm is applied outside this module (see class comment).
    """
    # Rounded/clamped activations in input dtype...
    quant_input = activation_quant(input)
    # ...and the matching per-row scale, recomputed here to dequantize
    # before handing the tensor to the kernel.
    s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)

    # Cold start: lazily ternarize + pack full-precision weights on first use.
    if self.weight.dtype != torch.int8:
        self.convert_weights_to_packed()

    # FIX: dropped `self.input = input` — retaining the activation batch on
    # the module pinned it in memory for the module's whole lifetime.
    out = bitmat_(quant_input.half() / s, self.weight.t().contiguous())
    # FIX: single dtype cast (old code bounced through torch.float first).
    out = out.to(input.dtype)
    out = out / self.scale_w  # undo weight scale (scale_w = 1 / mean|W|)
    if self.bias is not None:
        out = out + self.bias.view(1, -1).expand_as(out)
    return out
def __init__(
    self,
    in_features,
    out_features,
    bias=None,
    eps=1e-5,
    keep_rms_in_32b=False,
    dtype=torch.float16,
    packed=False,
    *args,
    **kwargs,
):
    """
    BitLinear layer (reference path). When ``packed`` is True the weight and
    its scale are registered as buffers (inference); otherwise the weight is
    a trainable float Parameter. RMSNorm is applied outside this module.
    """
    # FIX: was `super(BitLinear, self).__init__()` — inside class BitLinearx
    # that raises TypeError ("obj must be an instance or subtype of type")
    # because self is not a BitLinear instance.
    super().__init__()
    print("Using Fast Bitmat")
    self.eps = eps
    self.in_features = in_features
    self.out_features = out_features
    self.dtype = dtype

    if packed:
        # Inference layout: int8-packed weights arrive via load/convert.
        self.register_buffer("weight", torch.zeros((out_features, in_features)))
        self.register_buffer("scale_w", torch.Tensor(1))
    else:
        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features)))

    if bias:
        self.bias = torch.nn.Parameter(torch.Tensor(out_features))
    else:
        self.register_parameter("bias", None)
    self.keep_rms_in_32b = keep_rms_in_32b
def convert_weights_to_parameters(self):
    """Unpack int8 ternary weights back into a trainable float Parameter."""
    if self.weight.dtype != torch.int8:
        return  # already a float Parameter; nothing to convert
    dense = unpack_ternary(self.weight)
    restored = dense / self.scale_w
    self.weight = torch.nn.Parameter(restored.to(self.dtype))
    # A None scale signals the bitmat kernel that we are in training mode.
    self.scale_w = None
def forward(self, input):
    """
    BitLinear forward (packed kernel path): 8-bit fake-quantize the
    activations, run the ternary triton kernel, and rescale the output.
    """
    # FIX: the old code read self.input_bits / self.weight_bits / self.Qp,
    # none of which __init__ ever defines (AttributeError at runtime).
    num_bits = 8
    Qp = 2 ** (num_bits - 1) - 1  # 127
    quant_input = activation_quant(input, num_bits)
    # Per-row dequantization scale; the +2e-6 presumably pads the max for
    # numerical safety — TODO confirm against the kernel's expectations.
    s = Qp / (input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + 2e-6)

    if self.weight.dtype != torch.int8:
        # Cold start: ternarize and pack the float weights in place.
        quant_weight, scale = weight_quant(self.weight, num_bits)
        wp = pack_ternary(quant_weight.to(torch.int8))
        del self.weight
        self.register_buffer("weight", wp)
        # FIX: store under scale_w so it matches convert_weights_to_packed();
        # the old code wrote self.scale, which the packed path never set.
        self.scale_w = scale
        print("cold start")
    else:
        wp = self.weight
        scale = self.scale_w  # FIX: was self.scale (unset when packed elsewhere)

    out = bitmat_(quant_input.half() / s, wp.t().contiguous()).to(torch.float)
    out = out.to(input.dtype)
    # FIX: dequantize by *dividing* by the weight scale. weight_quant returns
    # s_w = 1/mean|W| with W_q ≈ W * s_w, so X @ W_q.T = (X @ W.T) * s_w and
    # the product must be divided (the sibling fast path already divides).
    out = out / scale
    if self.bias is not None:
        out = out + self.bias.view(1, -1).expand_as(out)
    return out
- """ - dtype = weights.dtype - scale = 1 / torch.max(weights.abs().mean(), torch.tensor(1e-5)) - return torch.clamp((weights * scale).to(torch.int8), -1, 1), scale.to(dtype) def quantize_activations(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -27,7 +27,7 @@ def quantize_activations(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class BitMat(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, W, X, scale_w=None,quant_activations=True): + def forward(ctx, W, X, scale_w=None,quant_8bit_activations=True): """ During the forward pass, we ternarize the weights, pack them and then quantize the activations. We then perform the bit matrix multiplication and return the scaled results. @@ -43,8 +43,9 @@ def forward(ctx, W, X, scale_w=None,quant_activations=True): Y = X @ w_packed.t() | dot product Y = Y / scale_w / scale_x) | STE """ - if quant_activations: - X, scale_x = quantize_activations(X) + X, scale_x = quantize_activations(X) + if not quant_8bit_activations: + X = X/scale_x if scale_w is None: dtype = W.dtype @@ -60,7 +61,7 @@ def forward(ctx, W, X, scale_w=None,quant_activations=True): y = bitmat_(X, W.t().contiguous()) out = y / scale_w - if quant_activations: + if quant_8bit_activations: out = out / scale_x return out @@ -74,6 +75,6 @@ def backward(ctx, grad_output): grad_W = (grad_output.transpose(1,2) @ X).mean(dim=0) return grad_W, None, None -def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w,quant_activations=None) -> torch.Tensor: - quant_activations = quant_activations if quant_activations is not None else BITMAT_QUANT_ACTIVATIONS - return BitMat.apply(W, X, scale_w,quant_activations=quant_activations) \ No newline at end of file +def bitmat(W: torch.Tensor, X: torch.Tensor, scale_w,quant_8bit_activations=None) -> torch.Tensor: + quant_8bit_activations = quant_8bit_activations if quant_8bit_activations is not None else BITMAT_QUANT_8BIT_ACTIVATIONS + return BitMat.apply(W, X, 
scale_w,quant_8bit_activations) \ No newline at end of file diff --git a/bitmat/utils/convert_hf_model.py b/bitmat/utils/convert_hf_model.py index 0924a21..c0d34f6 100644 --- a/bitmat/utils/convert_hf_model.py +++ b/bitmat/utils/convert_hf_model.py @@ -3,7 +3,7 @@ from tqdm import tqdm from ..bitlinear import BitLinear -from transformers import AutoModel, GemmaConfig, MistralConfig, LlamaConfig +from transformers import AutoModel, GemmaConfig, MistralConfig, LlamaConfig, AutoModelForCausalLM # Importing custom hijack classes for specific models from .modeling.model_hijacks.gemma_1_58b import Gemma158ForCausalLM @@ -28,7 +28,12 @@ def convert_hf_model(model: AutoModel) -> AutoModel: elif isinstance(model_config, LlamaConfig): hijacked_model = Llama158ForCausalLM(model_config) else: - raise RuntimeError("Unsupported model type. Please open an issue on GitHub citing the model you are using") + try: + hijacked_model = AutoModelForCausalLM(model_config) + hijacked_model = apply_bitlinear_to_hf_model(hijacked_model) + except Exception as e: + print(str(e)) + raise RuntimeError("Unsupported model type. 
Please open an issue on GitHub citing the model you are using") pbar.update(1) return hijacked_model diff --git a/bitmat/utils/modeling/automodel.py b/bitmat/utils/modeling/automodel.py index 85b4bee..a39a898 100644 --- a/bitmat/utils/modeling/automodel.py +++ b/bitmat/utils/modeling/automodel.py @@ -22,6 +22,7 @@ def from_pretrained(cls, *args, **kwargs): raise ValueError( f"The model {args[0]} was not found, this mean it has not been mapped yet, please open an issue on the github repository") + print(f"model CLASS {model_class_name}") # Ottieni la classe del modello utilizzando il suo nome model_class = globals()[model_class_name] diff --git a/bitmp/main.py b/bitmp/main.py new file mode 100644 index 0000000..e69de29 diff --git a/dev.py b/dev.py new file mode 100644 index 0000000..54007ef --- /dev/null +++ b/dev.py @@ -0,0 +1,7 @@ +d = {} +print(id(d)) +nd = {} +print(id(nd)) +d= d[0] =nd +print(id(d)) +print(d) \ No newline at end of file