6 changes: 3 additions & 3 deletions pyproject.toml
@@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.11.9"

dependencies = [
"torch==2.6",
"torch>=2.6",
"torchvision",
"onnx",
"black",
@@ -15,7 +15,7 @@ dependencies = [
    "colorlog",
    "pytest",
    "pytorch-lightning",
-    "transformers==4.51",
+    "transformers==4.57",
    "timm",
    "pytorch-nlp",
    "datasets==3.3.2",
@@ -45,7 +45,7 @@ dependencies = [
    "opencv-python",
    "kornia",
    "ghp-import",
-    "optimum==1.24.0",
+    "optimum[onnxruntime]>=1.25",
    "pytest-profiling",
    "myst_parser",
    "pytest-cov",
24 changes: 24 additions & 0 deletions src/chop/nn/quantized/functional/__init__.py
@@ -42,6 +42,8 @@
    bmm_minifloat_ieee,
    bmm_binary,
    bmm_ternary,
    bmm_mxfp,
    bmm_mxint,
    matmul_block_fp,
    matmul_block_log,
    matmul_block_minifloat,
@@ -51,7 +53,14 @@
    matmul_minifloat_ieee,
    matmul_binary,
    matmul_ternary,
    matmul_mxfp,
    matmul_mxint,
)
from .softmax import softmax_mxfp, softmax_mxint, softmax_minifloat
from .silu import silu_mxfp, silu_mxint, silu_minifloat
from .rope import rope_mxfp, rope_mxint, rope_minifloat
from .kvcache import kv_cache_mxfp, kv_cache_mxint

from .mult import (
    mult_block_fp,
    mult_block_log,
@@ -203,6 +212,10 @@
    "matmul_block_log": matmul_block_log,
    "matmul_binary": matmul_binary,
    "matmul_ternary": matmul_ternary,
    "matmul_mxfp": matmul_mxfp,
    "matmul_mxint": matmul_mxint,
    "bmm_mxfp": bmm_mxfp,
    "bmm_mxint": bmm_mxint,
    "relu_block_minifloat": relu_block_minifloat,
    "relu_integer": relu_integer,
    "relu_fixed": relu_integer,
@@ -277,4 +290,15 @@
    "linear_ternary": linearTernary,
    "linear_lutnet": linearLUT,
    "linear_logicnets": linearLogicNets,
    "softmax_mxfp": softmax_mxfp,
    "softmax_mxint": softmax_mxint,
    "softmax_minifloat": softmax_minifloat,
    "silu_mxfp": silu_mxfp,
    "silu_mxint": silu_mxint,
    "silu_minifloat": silu_minifloat,
    "rope_mxfp": rope_mxfp,
    "rope_mxint": rope_mxint,
    "rope_minifloat": rope_minifloat,
    "kv_cache_mxfp": kv_cache_mxfp,
    "kv_cache_mxint": kv_cache_mxint,
}
43 changes: 43 additions & 0 deletions src/chop/nn/quantized/functional/kvcache.py
@@ -0,0 +1,43 @@
from functools import partial

from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer


def kv_cache_mxfp(
    key_states: Tensor,
    value_states: Tensor,
    config: dict = None,
) -> tuple[Tensor, Tensor]:
    x_block_size = config["data_in_block_size"]
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]

    x_quantizer = partial(
        mxfp_quantizer,
        block_size=x_block_size,
        element_exp_bits=x_exp_bits,
        element_frac_bits=x_frac_bits,
        block_dim=-1,
    )

    return x_quantizer(key_states), x_quantizer(value_states)


def kv_cache_mxint(
    key_states: Tensor,
    value_states: Tensor,
    config: dict = None,
) -> tuple[Tensor, Tensor]:
    x_block_size = config["data_in_block_size"]
    x_element_bits = config["data_in_width"]

    x_quantizer = partial(
        mxint_quantizer,
        block_size=x_block_size,
        element_bits=x_element_bits,
        block_dim=-1,
    )

    return x_quantizer(key_states), x_quantizer(value_states)
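A minimal usage sketch for the new KV-cache quantisers (not part of the diff). The tensor shapes, block size, and bit-widths below are illustrative assumptions; the config keys match the ones read by `kv_cache_mxfp` and `kv_cache_mxint` above.

```python
import torch

from chop.nn.quantized.functional import kv_cache_mxfp, kv_cache_mxint

# Illustrative shapes only: (batch, heads, seq_len, head_dim).
key_states = torch.randn(1, 8, 128, 64)
value_states = torch.randn(1, 8, 128, 64)

# MXFP: block-shared scale with per-element exponent/fraction bits.
mxfp_cfg = {
    "data_in_block_size": 32,
    "data_in_exponent_width": 4,
    "data_in_frac_width": 3,
}
k_q, v_q = kv_cache_mxfp(key_states, value_states, config=mxfp_cfg)

# MXINT: block-shared scale with fixed-width integer elements.
mxint_cfg = {"data_in_block_size": 32, "data_in_width": 8}
k_q, v_q = kv_cache_mxint(key_states, value_states, config=mxint_cfg)
```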
2 changes: 2 additions & 0 deletions src/chop/nn/quantized/functional/linear.py
@@ -16,6 +16,8 @@
    binary_quantizer,
    ternary_quantizer,
    mxint_hardware,
    mxfp_quantizer,
    mxint_quantizer,
)


80 changes: 80 additions & 0 deletions src/chop/nn/quantized/functional/matmul.py
@@ -15,6 +15,8 @@
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
    mxfp_quantizer,
    mxint_quantizer,
)

# PyTorch has torch.matmul and torch.bmm for matrix multiplication
@@ -430,3 +432,81 @@ def bmm_block_minifloat(x, y, config):

def bmm_block_log(x, y, config):
    return generic_matmul_block_log(x, y, config, style="bmm")


def generic_matmul_mxfp(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)

    x_block_size = config["data_in_block_size"]
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]
    y_block_size = config["weight_block_size"]
    y_exp_bits = config["weight_exponent_width"]
    y_frac_bits = config["weight_frac_width"]

    x_quantizer = partial(
        mxfp_quantizer,
        block_size=x_block_size,
        element_exp_bits=x_exp_bits,
        element_frac_bits=x_frac_bits,
        block_dim=-1,
    )
    y_quantizer = partial(
        mxfp_quantizer,
        block_size=y_block_size,
        element_exp_bits=y_exp_bits,
        element_frac_bits=y_frac_bits,
        block_dim=-1,
    )

    x = x_quantizer(x)
    y = y_quantizer(y)
    return matmul(x, y)


def generic_matmul_mxint(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)

    x_block_size = config["data_in_block_size"]
    x_element_bits = config["data_in_width"]
    y_block_size = config["weight_block_size"]
    y_element_bits = config["weight_width"]

    x_quantizer = partial(
        mxint_quantizer,
        block_size=x_block_size,
        element_bits=x_element_bits,
        block_dim=-1,
    )
    y_quantizer = partial(
        mxint_quantizer,
        block_size=y_block_size,
        element_bits=y_element_bits,
        block_dim=-1,
    )

    x = x_quantizer(x)
    y = y_quantizer(y)
    return matmul(x, y)


def matmul_mxfp(x, y, config):
    return generic_matmul_mxfp(x, y, config, "matmul")


def matmul_mxint(x, y, config):
    return generic_matmul_mxint(x, y, config, "matmul")


def bmm_mxfp(x, y, config):
    return generic_matmul_mxfp(x, y, config, "bmm")


def bmm_mxint(x, y, config):
    return generic_matmul_mxint(x, y, config, "bmm")
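A minimal sketch of how the MX matmul/bmm functionals might be called (not part of the diff). Shapes, block sizes, and bit-widths are illustrative assumptions; the config keys mirror the ones read by `generic_matmul_mxfp` and `generic_matmul_mxint` above.

```python
import torch

from chop.nn.quantized.functional import matmul_mxfp, bmm_mxint

# Illustrative batched operands; last dims are multiples of the block size.
x = torch.randn(4, 64, 32)
y = torch.randn(4, 32, 32)

mxfp_cfg = {
    "bypass": False,
    "data_in_block_size": 32,
    "data_in_exponent_width": 4,
    "data_in_frac_width": 3,
    "weight_block_size": 32,
    "weight_exponent_width": 4,
    "weight_frac_width": 3,
}
out = matmul_mxfp(x, y, mxfp_cfg)

mxint_cfg = {
    "data_in_block_size": 32,
    "data_in_width": 8,
    "weight_block_size": 32,
    "weight_width": 8,
}
out = bmm_mxint(x, y, mxint_cfg)
```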
101 changes: 101 additions & 0 deletions src/chop/nn/quantized/functional/rope.py
@@ -0,0 +1,101 @@
from functools import partial

import torch
from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer
from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim


def rotate_half(x: Tensor) -> Tensor:
    """Rotate half the last dimension (for RoPE)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _apply_rope(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    seq_len = q.size(-2)
    cos = cos[..., :seq_len, :]
    sin = sin[..., :seq_len, :]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rope_mxfp(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    config: dict = None,
    unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
    x_block_size = config["data_in_block_size"]
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]

    x_quantizer = partial(
        mxfp_quantizer,
        block_size=x_block_size,
        element_exp_bits=x_exp_bits,
        element_frac_bits=x_frac_bits,
        block_dim=-1,
    )

    cos = x_quantizer(cos)
    sin = x_quantizer(sin)
    return _apply_rope(q, k, cos, sin, unsqueeze_dim)


def rope_mxint(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    config: dict = None,
    unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
    x_block_size = config["data_in_block_size"]
    x_element_bits = config["data_in_width"]

    x_quantizer = partial(
        mxint_quantizer,
        block_size=x_block_size,
        element_bits=x_element_bits,
        block_dim=-1,
    )

    cos = x_quantizer(cos)
    sin = x_quantizer(sin)
    return _apply_rope(q, k, cos, sin, unsqueeze_dim)


def rope_minifloat(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    config: dict = None,
    unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]

    x_quantizer = partial(
        minifloat_quantizer_sim,
        minifloat_meta=MinifloatMeta(
            exp_bits=x_exp_bits,
            frac_bits=x_frac_bits,
            is_finite=config.get("data_in_is_finite", True),
            round_mode=config.get("data_in_round_mode", "rn"),
        ),
    )

    cos = x_quantizer(cos)
    sin = x_quantizer(sin)
    return _apply_rope(q, k, cos, sin, unsqueeze_dim)
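A minimal sketch of calling `rope_mxfp` (not part of the diff). The cos/sin tables here are built with a standard RoPE recipe purely for illustration; the shapes, base frequency, block size, and bit-widths are all assumptions.

```python
import torch

from chop.nn.quantized.functional import rope_mxfp

# Illustrative shapes: q/k are (batch, heads, seq_len, head_dim);
# cos/sin are (batch, seq_len, head_dim) and gain a head axis via unsqueeze_dim=1.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)

pos = torch.arange(128, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
angles = torch.outer(pos, inv_freq)        # (seq_len, head_dim // 2)
emb = torch.cat((angles, angles), dim=-1)  # (seq_len, head_dim)
cos, sin = emb.cos().unsqueeze(0), emb.sin().unsqueeze(0)

cfg = {
    "data_in_block_size": 32,
    "data_in_exponent_width": 4,
    "data_in_frac_width": 3,
}
q_rot, k_rot = rope_mxfp(q, k, cos, sin, config=cfg)
```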
57 changes: 57 additions & 0 deletions src/chop/nn/quantized/functional/silu.py
@@ -0,0 +1,57 @@
from functools import partial

import torch
from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer
from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim


def silu_mxfp(x: Tensor, config: dict = None) -> Tensor:
    x_block_size = config["data_in_block_size"]
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]

    x_quantizer = partial(
        mxfp_quantizer,
        block_size=x_block_size,
        element_exp_bits=x_exp_bits,
        element_frac_bits=x_frac_bits,
        block_dim=-1,
    )

    x = x_quantizer(x)
    return torch.nn.functional.silu(x)


def silu_mxint(x: Tensor, config: dict = None) -> Tensor:
    x_block_size = config["data_in_block_size"]
    x_element_bits = config["data_in_width"]

    x_quantizer = partial(
        mxint_quantizer,
        block_size=x_block_size,
        element_bits=x_element_bits,
        block_dim=-1,
    )

    x = x_quantizer(x)
    return torch.nn.functional.silu(x)


def silu_minifloat(x: Tensor, config: dict = None) -> Tensor:
    x_exp_bits = config["data_in_exponent_width"]
    x_frac_bits = config["data_in_frac_width"]

    x_quantizer = partial(
        minifloat_quantizer_sim,
        minifloat_meta=MinifloatMeta(
            exp_bits=x_exp_bits,
            frac_bits=x_frac_bits,
            is_finite=config.get("data_in_is_finite", True),
            round_mode=config.get("data_in_round_mode", "rn"),
        ),
    )

    x = x_quantizer(x)
    return torch.nn.functional.silu(x)
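A minimal sketch of the quantised SiLU wrappers (not part of the diff). The shape and numeric parameters are illustrative; note that `silu_minifloat` takes no block size because its quantiser is configured per element, and `data_in_is_finite` / `data_in_round_mode` fall back to True and "rn" when omitted.

```python
import torch

from chop.nn.quantized.functional import silu_mxfp, silu_minifloat

x = torch.randn(2, 128)  # illustrative; the last dim is the blocked dim for MXFP

mxfp_cfg = {
    "data_in_block_size": 32,
    "data_in_exponent_width": 4,
    "data_in_frac_width": 3,
}
y = silu_mxfp(x, config=mxfp_cfg)

minifloat_cfg = {
    "data_in_exponent_width": 5,
    "data_in_frac_width": 2,
}
y = silu_minifloat(x, config=minifloat_cfg)
```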