
fix: ScaledFP8WeightTensor crashes on detach/to when inner tensors are inference tensors#1629

Open
Gunther-Schulz wants to merge 5 commits into deepbeepmeep:main from Gunther-Schulz:fix/scaled-fp8-inference-tensor-copy

Conversation

Gunther-Schulz commented Mar 21, 2026

I created this PR to be able to run a fine-tuned LTX-2.3 Dev model that's only available as FP8.

🤖 Generated with Claude Code

Problem

Running FP8-quantized checkpoints (e.g. LTX-2.3 22B) crashes with:

RuntimeError: Cannot set version_counter for inference tensor

The crash occurs in ScaledFP8WeightTensor.__torch_function__ when LTX-2's _duplicate_timestep_embedder calls .detach().to(device) on a quantized weight during the prewarm step. A secondary crash follows in _linear_fallback:

RuntimeError: Expected all tensors to be on the same device

Root cause

ScaledFP8WeightTensor is a _make_wrapper_subclass tensor whose inner _data (fp8) and _scale (float32) tensors are created as inference tensors during model loading (under torch.inference_mode()). The wrapper itself is not an inference tensor, but its inner tensors are.
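The restriction at the heart of this can be reproduced without the wrapper subclass at all. A minimal sketch (the variable names merely stand in for the wrapper's inner fields):

```python
import torch

# Tensors materialised under torch.inference_mode() — as model weights are
# during loading — become "inference tensors" and keep that status after
# the mode is exited.
with torch.inference_mode():
    data = torch.zeros(4)   # stand-in for the fp8 payload (_data)
    scale = torch.ones(())  # stand-in for the float32 _scale

print(data.is_inference(), scale.is_inference())  # True True

# Outside inference mode, PyTorch refuses operations that would need to
# track versions of these tensors, e.g. in-place updates:
try:
    data.add_(1.0)
    mutated = True
except RuntimeError:
    mutated = False
print("in-place allowed:", mutated)  # in-place allowed: False
```

The same version-tracking machinery is what detach() trips over when it tries to set up view/version sharing with an inference inner tensor.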

When .detach() is called on the wrapper, PyTorch attempts to set up version-counter sharing on the inner tensors — which is illegal for inference tensors. Two dispatch layers were affected, across three handlers:

  1. __torch_function__: receives the C-level TensorBase.detach method (not torch.Tensor.detach), so identity checks fail. Must match by func.__name__.
  2. __torch_dispatch__: the existing aten.detach handler called op(t._data) directly, which triggers the same version-counter error.
  3. __torch_dispatch__ _to_copy: same issue — called op(t._data) without guarding against inference tensors.
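The name-matching point in item 1 can be sketched with a toy subclass (illustrative only, not the real ScaledFP8WeightTensor):

```python
import torch

class NameMatched(torch.Tensor):
    """Toy stand-in for ScaledFP8WeightTensor's __torch_function__ dispatch."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # The incoming func may be the C-level TensorBase method, so an
        # identity check (func is torch.Tensor.detach) is unreliable;
        # match by name instead.
        if getattr(func, "__name__", None) == "detach":
            # Hypothetical handling: return the same object, since model
            # weights never require grad and detaching is then a no-op.
            return args[0]
        return super().__torch_function__(func, types, args, kwargs)

t = torch.zeros(2).as_subclass(NameMatched)
d = t.detach()
print(d is t)  # True
```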

The _linear_fallback device mismatch was a separate latent bug: the fallback path converted weights to the right dtype but never moved them to the input's device. It only surfaced once the fallback was actually exercised.
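A sketch of the corrected fallback logic (function and variable names are illustrative, not the repo's actual API):

```python
import torch

def linear_fallback(x, qweight, scales):
    # Capture the activation device once, then pass it to every .to()
    # call so the dequantised weights end up alongside the input.
    target_device = x.device
    w = qweight.to(dtype=x.dtype, device=target_device)
    s = scales.to(device=target_device)
    return x @ (w * s).t()

x = torch.randn(2, 3)
w = torch.randn(4, 3)   # stand-in for the stored weight payload
s = torch.ones(())      # stand-in for the per-tensor scale
print(linear_fallback(x, w, s).shape)  # torch.Size([2, 4])
```

The previous version converted dtype only, so a CPU-resident weight met a CUDA input inside torch.matmul.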

Fix

  • __torch_function__: intercept detach and to by name before they reach C++ dispatch. For detach, return a new wrapper sharing the same inner tensors (safe — model weights never require grad). For to, use torch.inference_mode(False) to move inner tensors without version-counter errors.
  • __torch_dispatch__ detach: stop calling op(t._data) / op(t._scale); share inner tensors directly instead.
  • __torch_dispatch__ _to_copy: wrap op(t._data) / op(t._scale) in torch.inference_mode(False) for consistency with the __torch_function__ path.
  • _linear_fallback: capture target_device = input.device and pass it to both .to() calls so weights land on the correct device.
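The `to` handling in the first bullet can be sketched like this (assumed names; the real code also rebuilds the wrapper via its constructor):

```python
import torch

with torch.inference_mode():
    data = torch.zeros(4)   # inference inner tensor, like _data
    scale = torch.ones(())  # inference inner tensor, like _scale

# Rebuild the inner tensors inside torch.inference_mode(False) so the
# transfer never tries to attach a version counter to an inference
# tensor; a fresh wrapper would then be constructed around the results.
with torch.inference_mode(False):
    new_data = data.to(dtype=torch.float16)
    new_scale = scale.to(dtype=torch.float32)

print(new_data.dtype, new_scale.dtype)  # torch.float16 torch.float32
```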

Why this only surfaces with FP8 checkpoints

The bug requires ScaledFP8WeightTensor to exist, which only happens for FP8-quantized weights. bf16 checkpoints never create these wrappers so the crash path is never reached.

Testing

Verified end-to-end generation with LTX-2.3 22B FP8 checkpoint. No regressions expected for bf16 models as the changed code paths are only reached for ScaledFP8WeightTensor instances.


Gunther-Schulz and others added 5 commits March 21, 2026 20:58
In __torch_dispatch__ for _to_copy/to, only dtype and device were
popped from kwargs — copy=True was left in and forwarded to the inner
op() calls on t._data and t._scale. Those are plain inference tensors
(loaded under torch.inference_mode()), and _to_copy with copy=True
triggers PyTorch's version_counter guard on inference tensors:
  "Cannot set version_counter for inference tensor"

This surfaced when _duplicate_timestep_embedder (transformer_args.py)
duplicates the adaln linear weights (which are ScaledFP8WeightTensors)
to a specific CUDA device before the denoising prewarm step.

Since ScaledFP8WeightTensor.create() always constructs a new object,
the copy flag is redundant at the inner level and safe to drop.
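The kwargs handling this describes can be sketched as follows (handler name is hypothetical; aten._to_copy takes no copy argument, so a forwarded copy=True must be stripped):

```python
import torch

def to_copy_inner(t, **kwargs):
    # Drop the redundant copy flag before forwarding kwargs to the inner
    # op; the wrapper is always rebuilt from fresh inner tensors anyway.
    kwargs.pop("copy", None)
    return torch.ops.aten._to_copy(t, **kwargs)

with torch.inference_mode():
    inner = torch.zeros(3)  # plain inference tensor, like _data

out = to_copy_inner(inner, dtype=torch.float64, copy=True)
print(out.dtype)  # torch.float64
```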

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…Tensor.to()

The previous fix was wrong: for inference tensors, __torch_dispatch__ is
bypassed entirely, so patching it had no effect. The error originates in
C++ code reached via __torch_function__'s fallback at line 413.

Root cause: torch.Tensor.to(copy=True) on an inference ScaledFP8WeightTensor
hits a C++ code path that attempts to set version counters on inference
tensors, which PyTorch forbids ("Cannot set version_counter for inference
tensor"). This surfaces when _duplicate_timestep_embedder() copies the
adaln linear weights (FP8-quantized, loaded as inference tensors) to a
specific CUDA device before the denoising prewarm step.

Fix: intercept torch.Tensor.to() in __torch_function__ before it reaches
C++. Manually transfer _data and _scale under torch.inference_mode(False)
(which allows creating new normal tensors from inference tensor sources),
then reconstruct the ScaledFP8WeightTensor. The copy= flag is dropped
since we always construct a fresh tensor.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…inner tensors

Diagnostics revealed the real root cause:
- The ScaledFP8WeightTensor wrapper itself is NOT an inference tensor
  (is_inference()=False), but its inner _data/_scale ARE inference tensors
- The failing op was detach(), not to() — func was the C-level
  <method 'detach' of 'torch._C.TensorBase' objects>, so identity
  checks against torch.Tensor.detach/to never matched
- aten::detach on an inference inner tensor tries to set up version-counter
  sharing between the new view and the source, which fails on inference tensors

Fix:
1. __torch_function__: match detach and to by func.__name__ (not identity),
   check arg0 is ScaledFP8WeightTensor, then handle both ops manually:
   - detach: return new wrapper sharing same _data/_scale (no-op since
     weights never require grad; avoids touching inner tensor version counters)
   - to: move _data/_scale under inference_mode(False), reconstruct wrapper
2. __torch_dispatch__ detach handler: stop calling op(t._data)/op(t._scale);
   use inner tensors directly as defence-in-depth for other dispatch paths

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When mmgp's RAM offload keeps a ScaledFP8WeightTensor weight on CPU and
the input arrives on CUDA, QLinearScaledFP8.forward falls through to
_linear_fallback because the device check fails:
  input.device == qweight._data.device  →  cuda:0 != cpu

_linear_fallback previously called weights.to(target_type) which converts
dtype but not device, then crashed in torch.matmul with:
  "Expected all tensors to be on the same device"

Fix: capture input.device before any dtype conversion and pass it to
both weights.to() and output_scales.to(), matching the device-aware
behaviour already present in _linear_scaled.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The __torch_dispatch__ _to_copy handler was calling op(t._data) directly,
which fails when inner tensors are inference tensors (same bug fixed in
__torch_function__). Now both paths are consistent.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>