
fix(rtx): reintroduce BF16 support, fall back depthwise conv to PyTorch #4178

Open
tp5uiuc wants to merge 1 commit into pytorch:main from tp5uiuc:bf16-depthwise-fallback

Conversation

Contributor

@tp5uiuc tp5uiuc commented Apr 9, 2026

Description

TensorRT-RTX does not support depthwise convolutions (grouped conv where groups == out_channels) in BF16. Previously, BF16 was globally disabled on TensorRT-RTX as a workaround. This was overly broad — all non-depthwise ops support BF16 correctly.

This PR removes the global BF16 disable and instead adds a targeted capability_validator (depthwise_bf16_validator) on the aten.convolution.default converter. When a depthwise convolution with a BF16 input tensor is detected on TensorRT-RTX, the validator returns False, causing the partitioner to fall back to PyTorch for that specific node. All other convolutions remain on TRT.

Changes:

  • _TRTInterpreter.py: Remove the global RuntimeError for BF16 on RTX
  • aten_ops_converters.py: Add depthwise_bf16_validator and register it on the convolution converter
  • Test files: Remove pytest.skip / @unittest.skipIf guards that blocked BF16 tests on RTX
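The validator logic described above can be sketched as follows. This is a hedged, self-contained illustration, not the actual Torch-TensorRT source: `FakeNode`, the argument layout, and the dtype encoding are illustrative assumptions standing in for real `torch.fx.Node` metadata.

```python
from dataclasses import dataclass

# Illustrative stand-in for a torch.fx convolution node (assumption:
# `groups` is the last positional argument of aten.convolution.default).
@dataclass
class FakeNode:
    args: tuple        # (..., groups)
    input_dtype: str   # e.g. "bf16", "fp16"
    out_channels: int

def depthwise_bf16_validator(node: FakeNode, is_rtx: bool) -> bool:
    """Return False to make the partitioner fall back to PyTorch.

    A convolution is depthwise when groups == out_channels. On
    TensorRT-RTX, that combination is unsupported in BF16, so only
    that case is rejected; every other convolution stays on TRT.
    """
    if not is_rtx:
        return True
    groups = node.args[-1]
    is_depthwise = groups == node.out_channels
    is_bf16 = node.input_dtype == "bf16"
    return not (is_depthwise and is_bf16)
```

In the real converter registry, a validator like this would be attached to the `aten.convolution.default` converter so that only the offending nodes are excluded from the TRT partition.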

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

Remove the global BF16 disable on TensorRT-RTX and instead add a
targeted capability_validator that rejects only depthwise convolutions
(groups == out_channels) when the input tensor is BF16. This causes
the partitioner to fall back to PyTorch for those specific nodes while
all other convolutions remain on TRT.

Root cause: TensorRT-RTX does not support depthwise conv/deconv in
BF16. The previous global disable was overly broad — all non-depthwise
ops support BF16 correctly.
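The partitioning effect described above can be sketched in a few lines. This is a hedged illustration of the behavior, not the real Torch-TensorRT partitioner: the node dicts and helper names are assumptions for demonstration.

```python
def partition(nodes, validator):
    """Split nodes into (trt_nodes, pytorch_fallback_nodes).

    Nodes whose validator returns False fall back to PyTorch;
    all others are kept on TensorRT.
    """
    trt, fallback = [], []
    for node in nodes:
        (trt if validator(node) else fallback).append(node)
    return trt, fallback

def bf16_depthwise_ok(n):
    # Reject only depthwise (groups == out_channels) convs in BF16.
    return not (n["groups"] == n["out_channels"] and n["dtype"] == "bf16")

# Hypothetical graph: one depthwise BF16 conv, one regular BF16 conv.
nodes = [
    {"name": "conv_depthwise_bf16", "groups": 32, "out_channels": 32, "dtype": "bf16"},
    {"name": "conv_regular_bf16", "groups": 1, "out_channels": 32, "dtype": "bf16"},
]
trt, fallback = partition(nodes, bf16_depthwise_ok)
```

Only the depthwise BF16 node lands in the fallback set; the regular BF16 conv stays on TRT, which is the targeted behavior this commit aims for.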

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@meta-cla meta-cla bot added the cla signed label Apr 9, 2026
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 9, 2026
@github-actions github-actions bot requested a review from zewenli98 April 9, 2026 21:16
@narendasan narendasan requested a review from lanluo-nvidia April 9, 2026 22:07
Collaborator

@lanluo-nvidia lanluo-nvidia left a comment

lgtm
waiting for the ci to pass.


Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests


2 participants