Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/chop/nn/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .quantizers_for_hw import (
integer_quantizer_for_hw,
integer_floor_quantizer_for_hw,
mxint_quantizer_for_hw,
)
from .mxint import mxint_quantizer

Expand Down
61 changes: 0 additions & 61 deletions src/chop/nn/quantizers/quantizers_for_hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,64 +35,3 @@ def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
fixed_point_value = fixed_point_value.to(torch.int)
fixed_point_value = fixed_point_value % (2**width)
return fixed_point_value


def mxint_quantizer_for_hw(
    x: Tensor,
    width: int,
    exponent_width: int,
    block_size: list[int],
    floor: bool = False,
):
    """
    - Convert IEEE FP32/64 to Microscaling Integer (MXINT), where an exponent is shared over all elements in a block.
    - https://arxiv.org/pdf/2310.10537.pdf
    - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf


    ---
    - forward: convert IEEE FP32/64 to MXINT
    - backward: STE

    ---
    - `width`: The number of mantissa bits + 1 (the sign bit)
    - `exponent_width`: the number of exponent bits
    - `block_size`: a list of integers where each integer is the block size on that dimension. See function `block`.
    - `floor`: if True, truncate (floor) the scaled mantissa instead of rounding to nearest.

    Returns a tuple `(quantized_value, per_block_exponent)`: the clamped integer
    mantissas and the biased per-block exponents.
    """
    # Partition `x` into blocks and obtain the per-block absolute maximum.
    # The padded shape / block shape outputs are not needed here.
    blocked_x, per_block_max, _, _ = block(
        x,
        block_shape=block_size,
    )

    # Guard so torch.log2 below never sees 0: if every block max is zero,
    # fall back to 1; otherwise replace zero maxima with the smallest
    # non-zero maximum. (The original repeated this guard twice; after the
    # first pass no zeros remain, so the duplicate was dead code.)
    if torch.all(per_block_max == 0):
        per_block_max = torch.ones_like(per_block_max)
    else:
        per_block_max[per_block_max == 0] = per_block_max[per_block_max != 0].min()

    # Biased exponent encoding: stored exponent = true exponent + bias.
    exponent_bias = 2 ** (exponent_width - 1) - 1

    per_block_exponent = my_floor(torch.log2(per_block_max)) + exponent_bias
    per_block_exponent = my_clamp(per_block_exponent, 0, 2**exponent_width - 1)

    # Largest representable mantissa magnitude (sign-magnitude with width-1
    # mantissa bits), and the fixed-point shift applied before rounding.
    element_max = 2 ** (width - 1) - 1
    shift = 2 ** (width - 2)

    # Scale each element by its block's shared (unbiased) exponent.
    scaled_value = shift * blocked_x / 2 ** (per_block_exponent - exponent_bias)

    if floor:
        quantized_value = my_floor(scaled_value)
    else:
        quantized_value = my_round(scaled_value)

    # Saturate to the symmetric representable range.
    quantized_value = my_clamp(quantized_value, -element_max, element_max)

    return quantized_value, per_block_exponent


# sw_quantizer_to_hw_quantizer = {integer_quantizer: integer_quantizer_for_hw}
3 changes: 0 additions & 3 deletions src/chop/passes/graph/transforms/verilog/emit_bram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from chop.nn.quantizers import (
integer_quantizer_for_hw,
integer_floor_quantizer_for_hw,
mxint_quantizer_for_hw,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -307,8 +306,6 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
else:
base_quantizer = integer_quantizer_for_hw

scale = 2**frac_width
thresh = 2**width
data_buff = ""
for i in range(0, out_depth):
line_buff = ""
Expand Down