fix: ScaledFP8WeightTensor crashes on detach/to when inner tensors are inference tensors (#1629)
Open

Gunther-Schulz wants to merge 5 commits into deepbeepmeep:main

Conversation
In __torch_dispatch__ for _to_copy/to, only dtype and device were popped from the kwargs; copy=True was left in and forwarded to the inner op() calls on t._data and t._scale. Those are plain inference tensors (loaded under torch.inference_mode()), and _to_copy with copy=True triggers PyTorch's version-counter guard on inference tensors: "Cannot set version_counter for inference tensor".

This surfaced when _duplicate_timestep_embedder (transformer_args.py) duplicates the adaln linear weights (which are ScaledFP8WeightTensors) to a specific CUDA device before the denoising prewarm step.

Since ScaledFP8WeightTensor.create() always constructs a new object, the copy flag is redundant at the inner level and safe to drop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
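The kwargs handling described above might look roughly like this. This is a minimal sketch: `to_copy_inner` is a hypothetical helper name, not the repo's actual code, and the real handler operates on fp8/float32 inner tensors.

```python
import torch

def to_copy_inner(inner: torch.Tensor, kwargs: dict) -> torch.Tensor:
    """Sketch: pop dtype/device for the inner .to() call and drop copy=,
    since the outer wrapper is always reconstructed from scratch anyway."""
    kwargs = dict(kwargs)                # don't mutate the caller's dict
    dtype = kwargs.pop("dtype", None)
    device = kwargs.pop("device", None)
    kwargs.pop("copy", None)             # redundant at the inner level
    return inner.to(dtype=dtype if dtype is not None else inner.dtype,
                    device=device if device is not None else inner.device)
```

With `copy=True` stripped before the inner call, the inference-tensor version-counter guard is never reached.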
…Tensor.to()
The previous fix was wrong: for inference tensors, __torch_dispatch__ is
bypassed entirely, so patching it had no effect. The error originates in
C++ code reached via __torch_function__'s fallback at line 413.
Root cause: torch.Tensor.to(copy=True) on an inference ScaledFP8WeightTensor
hits a C++ code path that attempts to set version counters on inference
tensors, which PyTorch forbids ("Cannot set version_counter for inference
tensor"). This surfaces when _duplicate_timestep_embedder() copies the
adaln linear weights (FP8-quantized, loaded as inference tensors) to a
specific CUDA device before the denoising prewarm step.
Fix: intercept torch.Tensor.to() in __torch_function__ before it reaches
C++. Manually transfer _data and _scale under torch.inference_mode(False)
(which allows creating new normal tensors from inference tensor sources),
then reconstruct the ScaledFP8WeightTensor. The copy= flag is dropped
since we always construct a fresh tensor.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
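The `inference_mode(False)` behavior relied on here can be demonstrated in isolation. A minimal CPU-only sketch; the dtype change simply forces a real copy so the result is a fresh tensor:

```python
import torch

# Tensors created under inference_mode are inference tensors, like the
# wrapper's _data/_scale after model loading.
with torch.inference_mode():
    data = torch.zeros(4)
assert data.is_inference()

# Even if the caller is itself inside inference_mode, re-entering normal
# mode via inference_mode(False) makes the copy an ordinary tensor.
with torch.inference_mode():
    with torch.inference_mode(False):
        moved = data.to(dtype=torch.float64)
assert not moved.is_inference()
```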
…inner tensors
Diagnostics revealed the real root cause:
- The ScaledFP8WeightTensor wrapper itself is NOT an inference tensor
(is_inference()=False), but its inner _data/_scale ARE inference tensors
- The failing op was detach(), not to() — func was the C-level
<method 'detach' of 'torch._C.TensorBase' objects>, so identity
checks against torch.Tensor.detach/to never matched
- aten::detach on an inference inner tensor tries to set up version-counter
sharing between the new view and the source, which fails on inference tensors
Fix:
1. __torch_function__: match detach and to by func.__name__ (not identity),
check arg0 is ScaledFP8WeightTensor, then handle both ops manually:
- detach: return new wrapper sharing same _data/_scale (no-op since
weights never require grad; avoids touching inner tensor version counters)
- to: move _data/_scale under inference_mode(False), reconstruct wrapper
2. __torch_dispatch__ detach handler: stop calling op(t._data)/op(t._scale);
use inner tensors directly as defence-in-depth for other dispatch paths
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
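Put together, the detach/to handling might be sketched like this. `WrapperSketch` is a simplified illustrative stand-in, not the real `_make_wrapper_subclass`-based ScaledFP8WeightTensor, and it omits the fp8 specifics:

```python
import torch

class WrapperSketch:
    """Illustrative stand-in for ScaledFP8WeightTensor's detach/to handling."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data, self._scale = data, scale

    @staticmethod
    def _matches(func, name: str) -> bool:
        # C-level methods (e.g. TensorBase.detach) fail identity checks
        # against torch.Tensor.detach, so match by name instead.
        return getattr(func, "__name__", None) == name

    def detach(self):
        # Share the same inner tensors: weights never require grad, so no
        # new view (and no version-counter sharing) is needed.
        return WrapperSketch(self._data, self._scale)

    def to(self, device=None, dtype=None, **_ignored):  # copy= is dropped
        # inference_mode(False) allows creating ordinary tensors from
        # inference-tensor sources before reconstructing the wrapper.
        with torch.inference_mode(False):
            data = self._data.to(device=device or self._data.device,
                                 dtype=dtype or self._data.dtype)
            scale = self._scale.to(device=device or self._scale.device,
                                   dtype=dtype or self._scale.dtype)
        return WrapperSketch(data, scale)
```

Name matching is the key difference from the previous attempt: it catches both the Python-level and the C-level method objects.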
When mmgp's RAM offload keeps a ScaledFP8WeightTensor weight on CPU and the input arrives on CUDA, QLinearScaledFP8.forward falls through to _linear_fallback because the device check fails: input.device == qweight._data.device → cuda:0 != cpu.

_linear_fallback previously called weights.to(target_type), which converts dtype but not device, then crashed in torch.matmul with: "Expected all tensors to be on the same device".

Fix: capture input.device before any dtype conversion and pass it to both weights.to() and output_scales.to(), matching the device-aware behaviour already present in _linear_scaled.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
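The device-aware fallback might be sketched as follows. A minimal version with illustrative names: the real _linear_fallback also handles the FP8 dequantization details.

```python
import torch

def linear_fallback(inp: torch.Tensor,
                    weights: torch.Tensor,
                    output_scales: torch.Tensor,
                    target_type: torch.dtype) -> torch.Tensor:
    # Capture the input's device *before* any dtype conversion, then move
    # both weights and scales to it, mirroring _linear_scaled's behaviour.
    target_device = inp.device
    w = weights.to(dtype=target_type, device=target_device)
    s = output_scales.to(dtype=target_type, device=target_device)
    return torch.matmul(inp.to(target_type), w.t()) * s
```

Capturing the device from the input (rather than the weight) is what keeps the matmul operands co-located when the weight was offloaded to CPU.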
The __torch_dispatch__ _to_copy handler was calling op(t._data) directly, which fails when inner tensors are inference tensors (same bug fixed in __torch_function__). Now both paths are consistent. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
I created this PR to be able to run a fine-tuned LTX2.3 Dev model that's only available as FP8.
🤖 Generated with Claude Code
Problem

Running FP8-quantized checkpoints (e.g. LTX-2.3 22B) crashes with:

"Cannot set version_counter for inference tensor"

The crash occurs in `ScaledFP8WeightTensor.__torch_function__` when LTX-2's `_duplicate_timestep_embedder` calls `.detach().to(device)` on a quantized weight during the prewarm step. A secondary crash follows in `_linear_fallback`: "Expected all tensors to be on the same device".

Root cause

`ScaledFP8WeightTensor` is a `_make_wrapper_subclass` tensor whose inner `_data` (fp8) and `_scale` (float32) tensors are created as inference tensors during model loading (under `torch.inference_mode()`). The wrapper itself is not an inference tensor, but its inner tensors are.

When `.detach()` is called on the wrapper, PyTorch attempts to set up version-counter sharing on the inner tensors, which is illegal for inference tensors. Three dispatch paths were affected:

- `__torch_function__`: receives the C-level `TensorBase.detach` method (not `torch.Tensor.detach`), so identity checks fail. Must match by `func.__name__`.
- `__torch_dispatch__`: the existing `aten.detach` handler called `op(t._data)` directly, which triggers the same version-counter error.
- `__torch_dispatch__` `_to_copy`: same issue; called `op(t._data)` without guarding against inference tensors.

The `_linear_fallback` device mismatch was a separate latent bug: the fallback path converted weights to the right dtype but forgot to move them to the right device, only surfacing when the fallback was actually exercised.

Fix

- `__torch_function__`: intercept `detach` and `to` by name before they reach C++ dispatch. For `detach`, return a new wrapper sharing the same inner tensors (safe: model weights never require grad). For `to`, use `torch.inference_mode(False)` to move inner tensors without version-counter errors.
- `__torch_dispatch__` detach: stop calling `op(t._data)`/`op(t._scale)`; share inner tensors directly instead.
- `__torch_dispatch__` `_to_copy`: wrap `op(t._data)`/`op(t._scale)` in `torch.inference_mode(False)` for consistency with the `__torch_function__` path.
- `_linear_fallback`: capture `target_device = input.device` and pass it to both `.to()` calls so weights land on the correct device.

Why this only surfaces with FP8 checkpoints

The bug requires `ScaledFP8WeightTensor` to exist, which only happens for FP8-quantized weights. bf16 checkpoints never create these wrappers, so the crash path is never reached.

Testing

Verified end-to-end generation with the LTX-2.3 22B FP8 checkpoint. No regressions expected for bf16 models, as the changed code paths are only reached for `ScaledFP8WeightTensor` instances.