
DecomposeAtenNativeBatchNormOp uses wrong dtype for running stats reshape #4480

@rkayaith

Description

Repo: llvm/torch-mlir


DecomposeAtenNativeBatchNormOp reshapes running_mean and running_var from [C] to [1,C,1,...] to broadcast with the input. The result type of the reshape uses the input dtype instead of the running stats dtype, producing an invalid aten.view when the types differ (e.g., bf16 input with f32 running stats).

The bug is at DecomposeComplexOps.cpp:L8497-L8499:

Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
Type reshapeType = ValueTensorType::get(
    context, llvm::ArrayRef(runningStatsShapeInt), dtype);

dtype should come from runningMean, not input.
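For illustration, here is a minimal NumPy sketch of the intended semantics (using f16 in place of bf16, since NumPy has no bf16 type): the running stats keep their own f32 dtype through the reshape, and the normalization is computed in f32 before casting back to the input dtype. This is not the decomposition's actual code path, just a model of the dtype behavior the reshape result type should reflect.

```python
import numpy as np

# Mixed-precision batch-norm inference: low-precision input (f16 here,
# standing in for bf16) with f32 running stats. Reshaping the stats to
# [1, C, 1, 1] must preserve their f32 dtype -- the reshape result type
# follows the operand being reshaped, not the input.
N, C, H, W = 8, 64, 56, 56
x = np.ones((N, C, H, W), dtype=np.float16)
running_mean = np.zeros(C, dtype=np.float32)
running_var = np.ones(C, dtype=np.float32)
eps = 1e-5

mean = running_mean.reshape(1, C, 1, 1)  # stays f32, not x.dtype
var = running_var.reshape(1, C, 1, 1)    # stays f32, not x.dtype
assert mean.dtype == np.float32

# Normalize in f32 (NumPy promotes f16 op f32 to f32), then cast back.
out = ((x - mean) / np.sqrt(var + eps)).astype(x.dtype)
assert out.dtype == np.float16
```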

Reproducer

// bn_mixed_precision.mlir
func.func @main(%arg0: !torch.vtensor<[8,64,56,56],bf16>)
    -> !torch.vtensor<[8,64,56,56],bf16> {
  %none = torch.constant.none
  %running_mean = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %running_var = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %false = torch.constant.bool false
  %momentum = torch.constant.float 1.000000e-01
  %eps = torch.constant.float 1.000000e-05
  %0:3 = torch.aten.native_batch_norm %arg0, %none, %none,
      %running_mean, %running_var, %false, %momentum, %eps :
      !torch.vtensor<[8,64,56,56],bf16>, !torch.none, !torch.none,
      !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
      !torch.bool, !torch.float, !torch.float
      -> !torch.vtensor<[8,64,56,56],bf16>,
         !torch.vtensor<[0],bf16>, !torch.vtensor<[0],bf16>
  return %0#0 : !torch.vtensor<[8,64,56,56],bf16>
}

Running just the decomposition pass shows the invalid view ops:

$ torch-mlir-opt bn_mixed_precision.mlir --torch-decompose-complex-ops | grep aten.view
%1 = torch.aten.view %running_mean, %0 : !torch.vtensor<[64],f32>, !torch.list<int> -> !torch.vtensor<[1,64,1,1],bf16>
                                                              ^^^                                                  ^^^^

Running the full pipeline crashes during constant folding of the invalid view (once #4479 is fixed, this will be a verification error instead):

$ torch-mlir-opt bn_mixed_precision.mlir --torch-function-to-torch-backend-pipeline
torch-mlir-opt: .../mlir/lib/IR/BuiltinAttributes.cpp:973:
  Assertion `floatAttr.getType() == eltType && "expected float attribute type to equal element type"' failed.
