Repo: llvm/torch-mlir
Title: DecomposeAtenNativeBatchNormOp uses wrong dtype for running stats reshape
DecomposeAtenNativeBatchNormOp reshapes running_mean and running_var from [C] to [1,C,1,...] to broadcast with the input. The result type of the reshape uses the input dtype instead of the running stats dtype, producing an invalid aten.view when the types differ (e.g., bf16 input with f32 running stats).
The bug is at DecomposeComplexOps.cpp:L8497-L8499:
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
dtype should come from runningMean, not input.
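A minimal sketch of the suggested fix, reusing the names from the snippet above (this assumes the running-mean operand is bound to a local named runningMean, as referenced in this report; the surrounding code is unchanged):

```cpp
// Sketch: take the dtype from the running stats operand (runningMean),
// not from the input tensor, so the reshaped type stays e.g. f32 even
// when the input is bf16.
Type dtype =
    cast<ValueTensorType>(runningMean.getType()).getOptionalDtype();
Type reshapeType = ValueTensorType::get(
    context, llvm::ArrayRef(runningStatsShapeInt), dtype);
```

With this change the decomposition emits `aten.view` ops whose result element type matches the running stats operand (f32 in the reproducer below), so the view verifies and folds correctly.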
Reproducer
// bn_mixed_precision.mlir
func.func @main(%arg0: !torch.vtensor<[8,64,56,56],bf16>)
-> !torch.vtensor<[8,64,56,56],bf16> {
%none = torch.constant.none
%running_mean = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%running_var = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%false = torch.constant.bool false
%momentum = torch.constant.float 1.000000e-01
%eps = torch.constant.float 1.000000e-05
%0:3 = torch.aten.native_batch_norm %arg0, %none, %none,
%running_mean, %running_var, %false, %momentum, %eps :
!torch.vtensor<[8,64,56,56],bf16>, !torch.none, !torch.none,
!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>,
!torch.bool, !torch.float, !torch.float
-> !torch.vtensor<[8,64,56,56],bf16>,
!torch.vtensor<[0],bf16>, !torch.vtensor<[0],bf16>
return %0#0 : !torch.vtensor<[8,64,56,56],bf16>
}
Running just the decomposition pass shows the invalid view ops:
$ torch-mlir-opt bn_mixed_precision.mlir --torch-decompose-complex-ops | grep aten.view
%1 = torch.aten.view %running_mean, %0 : !torch.vtensor<[64],f32>, !torch.list<int> -> !torch.vtensor<[1,64,1,1],bf16>
^^^ ^^^^
Running the full pipeline crashes during constant folding of the invalid view (once #4479 is fixed, this will be a verification error instead):
$ torch-mlir-opt bn_mixed_precision.mlir --torch-function-to-torch-backend-pipeline
torch-mlir-opt: .../mlir/lib/IR/BuiltinAttributes.cpp:973:
Assertion `floatAttr.getType() == eltType && "expected float attribute type to equal element type"' failed.