
Fix dtype handling in attention masks#28

Open
ray24777 wants to merge 1 commit into Kiteretsu77:main from ray24777:patch-1

Conversation


@ray24777 ray24777 commented Feb 13, 2026

Fix RuntimeError during DAT model FP16 inference

Description

This PR fixes a bug where using DAT models with half-precision (fp16) inference would crash due to a data type mismatch.

Fixes #29.

The Issue

When running inference with weight_dtype=torch.float16, the dynamically generated attention mask in Adaptive_Spatial_Attention remained in float32 (the default), while the input tensor v was in float16. This caused a RuntimeError during the matrix multiplication attn @ v.

Error Log

Traceback (most recent call last):
  File "/workspace/./test_code/inference.py", line 271, in <module>
    inner_loop(os.path.join(input_dir, filename))
  File "/workspace/./test_code/inference.py", line 254, in inner_loop
    super_resolve_img(generator, process_dir, output_path, weight_dtype, downsample_threshold, crop_for_4x=True)
  ...
  File "/workspace/architecture/dat.py", line 404, in forward
    x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
  ...
  File "/workspace/architecture/dat.py", line 244, in forward
    x = (attn @ v)
RuntimeError: expected scalar type Half but found Float
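The mismatch can be reproduced in isolation (a minimal sketch with hypothetical shapes, independent of the DAT code; torch.matmul refuses operands of different dtypes):

```python
import torch

# Stand-in for attn + mask, left in torch's default dtype (float32)
attn = torch.randn(4, 4)
# fp16 inference tensor, as when weight_dtype=torch.float16
v = torch.randn(4, 8, dtype=torch.float16)

try:
    attn @ v  # mixed float32 @ float16
except RuntimeError as e:
    print(f"RuntimeError: {e}")
```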

Solution

Modified the forward method of Adaptive_Spatial_Attention, specifically where mask_tmp is handled, to explicitly cast the generated mask to the input tensor's dtype. This ensures compatibility whether the model runs in full or half precision.

Fix a bug when using DAT models for fp16-precision inference
@ray24777
Author

Hey @Kiteretsu77 , would you please review this PR?



Development

Successfully merging this pull request may close these issues.

RuntimeError: expected scalar type Half but found Float during FP16 inference
