61 changes: 54 additions & 7 deletions shared/qtypes/scaled_fp8.py
@@ -287,12 +287,13 @@ def _linear_fallback(self, input, bias=None):
         qweight = self
         target_type = _normalize_default_dtype(qweight.dtype)
         weights, output_scales = qweight._data, qweight._scale
+        target_device = input.device
         input = input.to(target_type)
-        output_scales = output_scales.to(target_type)
+        output_scales = output_scales.to(device=target_device, dtype=target_type)
         in_features = input.shape[-1]
         out_features = weights.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
-        weights = weights.to(target_type)
+        weights = weights.to(device=target_device, dtype=target_type)
         weights *= output_scales
         out = torch.matmul(input.reshape(-1, in_features), weights.t())
         out = out.reshape(output_shape)
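
For orientation, a minimal standalone sketch of what this fallback computes after the patch: weights and scales are moved to the activation's device and dtype in one .to() call, dequantized by a row-wise multiply, then fed through a plain matmul. Function and argument names here are illustrative, not the actual method:

import torch

def linear_fallback_sketch(input, weights_fp8, output_scales, bias=None):
    target_dtype, target_device = input.dtype, input.device
    # Mirror the patched lines: one .to() moves and upcasts in a single step,
    # so the scales can no longer be left behind on the wrong device.
    w = weights_fp8.to(device=target_device, dtype=target_dtype)
    s = output_scales.to(device=target_device, dtype=target_dtype)
    w = w * s  # dequantize: scale each output row back into target_dtype range
    out = torch.matmul(input.reshape(-1, w.shape[1]), w.t())
    out = out.reshape(input.shape[:-1] + (w.shape[0],))
    return out if bias is None else out + bias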
@@ -409,6 +410,47 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
             if isinstance(weight, ScaledFP8WeightTensor):
                 return weight.linear(input, bias=bias)
+        # detach() and to() on a ScaledFP8WeightTensor whose inner _data/_scale are
+        # inference tensors fail deep in C++ with "Cannot set version_counter for
+        # inference tensor". The func received here is the C-level TensorBase method
+        # (not torch.Tensor.detach / torch.Tensor.to), so we match by name.
+        # Handle both ops before they reach C++ dispatch.
+        func_name = getattr(func, "__name__", None)
+        t = args[0] if args else None
+        if isinstance(t, ScaledFP8WeightTensor):
+            if func_name == "detach":
+                # Return a new wrapper sharing the same inner tensors (no-op detach).
+                # Calling op(t._data) on an inference _data would try to set version
+                # counters → error. Model weights don't need grad, so sharing is safe.
+                return ScaledFP8WeightTensor.create(
+                    weight=t._data,
+                    scale=t._scale,
+                    size=t.size(),
+                    stride=t.stride(),
+                    dtype=t.dtype,
+                    device=t.device,
+                    requires_grad=False,
+                )
+            if func_name == "to":
+                kw = dict(kwargs)
+                device = kw.pop("device", t.device)
+                dtype = kw.pop("dtype", t.dtype)
+                kw.pop("copy", None)  # we always create a new tensor, so the flag is redundant
+                if isinstance(device, str):
+                    device = torch.device(device)
+                if dtype != t.dtype:
+                    return t.dequantize(dtype=dtype, device=device)
+                with torch.inference_mode(False):
+                    out_data = t._data.to(device=device, **kw)
+                    out_scale = t._scale.to(device=device, **kw)
+                return ScaledFP8WeightTensor.create(
+                    weight=out_data,
+                    scale=out_scale,
+                    size=t.size(),
+                    stride=t.stride(),
+                    dtype=t.dtype,
+                    device=device,
+                )
         with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)

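A self-contained illustration of the interception pattern above, assuming nothing beyond stock PyTorch: when a method like t.detach() or t.to() is called on a subclass, __torch_function__ receives the bound C-level TensorBase method, so identity checks against torch.Tensor.detach fail and matching on func.__name__ is the workable hook. The NamePeek class is hypothetical:

import torch

class NamePeek(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func here is the C-level TensorBase.detach / .to, not torch.Tensor.detach
        if getattr(func, "__name__", None) in ("detach", "to"):
            print(f"intercepted {func.__name__}")  # custom handling would go here
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

t = torch.zeros(2).as_subclass(NamePeek)
t.detach()    # prints "intercepted detach"
t.to("cpu")   # prints "intercepted to"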
@@ -424,23 +466,28 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None):
                 return weight.linear(input, bias=bias)
         if op is torch.ops.aten.detach:
             t = args[0]
+            # Do NOT call op(t._data) / op(t._scale): if the inner tensors are
+            # inference tensors, aten::detach tries to set up version-counter sharing
+            # and raises "Cannot set version_counter for inference tensor".
+            # Model weights never require grad, so sharing inner tensors is safe.
             return ScaledFP8WeightTensor.create(
-                weight=op(t._data),
-                scale=op(t._scale),
+                weight=t._data,
+                scale=t._scale,
                 size=t.size(),
                 stride=t.stride(),
                 dtype=t.dtype,
                 device=t.device,
-                requires_grad=t.requires_grad,
+                requires_grad=False,
             )
         if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
             t = args[0]
             dtype = kwargs.pop("dtype", t.dtype) if kwargs else t.dtype
             device = kwargs.pop("device", t.device) if kwargs else t.device
             if dtype != t.dtype:
                 return t.dequantize(dtype=dtype, device=device)
-            out_data = op(t._data, device=device, **(kwargs or {}))
-            out_scale = op(t._scale, device=device, **(kwargs or {}))
+            with torch.inference_mode(False):
+                out_data = op(t._data, device=device, **(kwargs or {}))
+                out_scale = op(t._scale, device=device, **(kwargs or {}))
             return ScaledFP8WeightTensor.create(
                 weight=out_data,
                 scale=out_scale,
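
Both to-handlers lean on the same escape hatch, sketched here in isolation: with inference mode disabled via torch.inference_mode(False), functional ops on an inference tensor produce normal tensors, so the fresh copies of _data/_scale carry version counters and can be wrapped safely. Tensor names are illustrative:

import torch

with torch.inference_mode():
    data = torch.randn(8, 8)       # created as an inference tensor
    assert data.is_inference()
    with torch.inference_mode(False):
        plain = data.to(device="cpu", copy=True)  # functional op -> normal tensor
    assert not plain.is_inference()  # safe to wrap and version-count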