
aten.view fold crashes on element type mismatch instead of verifying #4479

@rkayaith

Description

torch.aten.view currently verifies even when its input and output element types differ (e.g., an f32 input with a bf16 output). This is semantically invalid: aten.view maps to the shape overload of Tensor.view (aten::view(Tensor, SymInt[])), which always preserves dtype. Dtype reinterpretation is a separate op (aten::view.dtype(Tensor, ScalarType)).

There is no verifier to catch the mismatch, so the invalid IR reaches genericViewLikeFold, which crashes with an assertion failure when it tries to constant-fold the op. A verifier that rejects element type mismatches would catch this class of bug at the point where the invalid op is created, rather than letting it crash later during folding.
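
The check such a verifier would perform can be sketched in plain, standalone C++. This is a model of the logic only, not torch-mlir code: TensorTypeModel and verifyViewElementType are hypothetical names, dtypes are modeled as strings, and std::nullopt stands in for a tensor type whose dtype is not yet known (which, as far as we understand, torch-mlir tensor types can have before type refinement). In torch-mlir itself this would presumably live in the op's generated verify() hook (enabled via hasVerifier in ODS) and compare the dtypes of the input and result tensor types.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a torch-mlir value tensor type: a shape plus an
// optional dtype (std::nullopt models a tensor whose dtype is still unknown).
struct TensorTypeModel {
  std::vector<int64_t> shape;
  std::optional<std::string> dtype;
};

// Sketch of the proposed aten.view element-type check. Returns an error
// message when both dtypes are known and differ; returns std::nullopt
// (meaning "verifies") otherwise. Shapes are intentionally not inspected here.
std::optional<std::string> verifyViewElementType(const TensorTypeModel &input,
                                                 const TensorTypeModel &result) {
  // If either dtype is unknown, there is nothing to compare yet.
  if (!input.dtype || !result.dtype)
    return std::nullopt;
  // The shape overload of aten::view never changes dtype.
  if (*input.dtype != *result.dtype)
    return "'torch.aten.view' op element type of result (" + *result.dtype +
           ") must match element type of input (" + *input.dtype + ")";
  return std::nullopt;
}
```

With this check, the reproducer below (a [64] f32 input viewed as a [1,64,1,1] bf16 result) would be rejected at verification time instead of reaching the folder, while an f32 result, or a result whose dtype is still unknown, would pass.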

Reproducer

// view_type_mismatch.mlir
func.func @main() -> !torch.vtensor<[1,64,1,1],bf16> {
  %0 = torch.vtensor.literal(dense<1.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %1 = torch.constant.int 1
  %2 = torch.constant.int 64
  %3 = torch.prim.ListConstruct %1, %2, %1, %1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.view %0, %3 : !torch.vtensor<[64],f32>, !torch.list<int> -> !torch.vtensor<[1,64,1,1],bf16>
  return %4 : !torch.vtensor<[1,64,1,1],bf16>
}

Expected: verification error on the aten.view op.

Actual:

$ torch-mlir-opt view_type_mismatch.mlir --canonicalize
torch-mlir-opt: .../mlir/lib/IR/BuiltinAttributes.cpp:973:
  Assertion `floatAttr.getType() == eltType && "expected float attribute type to equal element type"' failed.
