[MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant #147691

Open · MengmSun wants to merge 3 commits into base: main

@MengmSun MengmSun commented Jul 9, 2025

After ConvertToLLVMPass, our IR looks like this:

...
%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
%8 = vector.shape_cast %4 : vector<192xi8> to vector<1x192xi8>
%10 = vector.extract %8[0] : vector<192xi8> from vector<1x192xi8>
...

Our next pass is the Canonicalizer. Several months ago this worked fine, but recently we started hitting the following assertion:

mlir::DenseElementsAttr mlir::DenseElementsAttr::reshape(mlir::ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element type"' failed.

We found that this is because ShapeCastOp::fold(), called in the Canonicalizer pass, now folds vector.shape_cast(constant) -> constant by reshaping the constant attribute. That reshape fails if the element type of the source attribute differs from the element type of the result type.

So we want to add a constraint: the fold returns the reshaped attribute only when the element types of the source attribute and the result type are the same. This makes our case work as before and does not affect other cases.
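
To make the proposed constraint concrete, here is a minimal, self-contained C++ sketch. It is a toy model and not MLIR code: `ToyVectorType`, `ToySplatAttr` and `foldShapeCast` are hypothetical stand-ins for the vector type, the splat attribute and `ShapeCastOp::fold()`, and element types are modeled as plain strings. The `assert` in `reshape` mimics the precondition that fails in our case; the guard in `foldShapeCast` declines to fold instead of reaching it.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a vector type: a shape plus an element type name.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::string elementType; // e.g. "i8" or "f8E4M3FN"
};

// Hypothetical stand-in for a splat constant attribute.
struct ToySplatAttr {
  ToyVectorType type;
  double splatValue;

  // Mimics the precondition from the assertion message above: reshaping is
  // only allowed when the element type stays the same.
  ToySplatAttr reshape(const ToyVectorType &newType) const {
    assert(newType.elementType == type.elementType &&
           "expected the same element type");
    return ToySplatAttr{newType, splatValue};
  }
};

// Guarded fold: reshape only when the element types agree. Otherwise return
// std::nullopt ("do not fold") instead of tripping the assertion.
std::optional<ToySplatAttr> foldShapeCast(const ToySplatAttr &source,
                                          const ToyVectorType &resultType) {
  if (source.type.elementType != resultType.elementType)
    return std::nullopt;
  return source.reshape(resultType);
}
```

In this model, a splat typed `vector<192xf8E4M3FN>` cast to `vector<1x192xi8>` is simply left unfolded, while a splat whose element type already matches the result folds as before.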

llvmbot commented Jul 9, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Mengmeng Sun (MengmSun)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/147691.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..5bbe6704aac48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
       return bcastOp.getSource();
   }
 
-  // shape_cast(constant) -> constant
+  // shape_cast(constant) -> constant,
+  // if element type of the source and result are the same
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+    if (splatAttr.getElementType() == resultType.getElementType())
+      return splatAttr.reshape(getType());
+  }
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..69da8a31d2c9b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type
+func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> {
+  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
+  // CHECK-NOT: vector.shape_cast
+  %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8>
+  // CHECK-NOT: vector.extract
+  %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8>
+  return %2 : vector<12xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
 //       CHECK:   vector.broadcast
 //   CHECK-NOT:   vector.shape_cast

@newling newling left a comment

PR #133988 moved the canonicalization into a folder, which is probably what triggered your error.

I am not sure if

%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>

is valid. I can see that the element bitwidth and the number of elements are the same in the two vector types, but I'd have thought the types must be identical. Maybe a verification is missing in LogicalResult LLVM::ConstantOp::verify()?

When I run

mlir-opt --verify-diagnostics playtime.mlir

on

func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi128> {
  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi128>
  %1 = vector.shape_cast %0 : vector<12xi128> to vector<1x12xi128>
  %2 = vector.extract %1[0] : vector<12xi128> from vector<1x12xi128>
  return %2 : vector<12xi128>
}

I don't get an error either, but this example looks especially wrong because the number of bits is different between the element types.

@dcaballe

> %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>

The following question is unrelated, but I've seen this pop up multiple times lately: why does llvm.mlir.constant allow the attribute and the return type to mismatch?

newling commented Jul 11, 2025

> %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
>
> The following question is unrelated, but I've seen this pop up multiple times lately: why does llvm.mlir.constant allow the attribute and the return type to mismatch?

I think this is actually central to the PR!

@MengmSun

> > %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
> >
> > The following question is unrelated, but I've seen this pop up multiple times lately: why does llvm.mlir.constant allow the attribute and the return type to mismatch?
>
> I think this is actually central to the PR!

After investigating, this is probably caused by how ConvertToLLVMPass lowers arith constrained vectors. The lowering ends up calling LLVM::detail::vectorOneToOneRewrite() to lower %cst = arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>, and the attribute parameter it passes is the value returned by attrConvert.getAttrs(), which is the attributes of the source op. During the lowering, the typeConverter converts f8E4M3FN to i8. So after this pass,

%cst = arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>

will be

%4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>

From our perspective, if the logic of lowering constrained vectors needs to be reconsidered, it must be a big change and cannot land in the short term. We hope to just fix our current problem in this PR.

Comment on lines 5925 to 5931
-  // shape_cast(constant) -> constant
+  // shape_cast(constant) -> constant,
+  // if element type of the source and result are the same
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+    if (splatAttr.getElementType() == resultType.getElementType())
+      return splatAttr.reshape(getType());
+  }

Contributor

The code below will fold, but use the correct element type:

  // shape_cast(constant) -> constant
  if (auto splatAttr =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {

    // The shape and 'scalable dims' of the new attribute must match the result
    // of the shape_cast:
    auto newShape = resultType.getShape();
    auto newScalableDims = resultType.getScalableDims();

    // The element type must be retained. Note that this is to handle currently
    // valid IR like
    //
    // ```
    // %0 = llvm.mlir.constant(dense<0.> : vector<1xf8E4M3FN>) : vector<1xi8>
    // %1 = vector.shape_cast %0 : vector<1xi8> to vector<1x1xi8>
    // ```
    //
    // where the element types of the attribute and result do not match.
    auto newElementType = splatAttr.getElementType();

    // Build the folded constant's type: result shape, attribute element type.
    auto newType = VectorType::get(newShape, newElementType, newScalableDims);

    return DenseElementsAttr::get(newType,
                                  splatAttr.getSplatValue<Attribute>());
  }
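
The idea behind this suggestion can be sketched outside of MLIR as follows. This is a toy model under stated assumptions: `ToyVectorType`, `ToySplatAttr` and `foldShapeCastKeepingElementType` are hypothetical names, element types are plain strings, and scalable dims are left out. It only illustrates the rule "take the shape from the result type, keep the element type of the attribute".

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical stand-in for a vector type (scalable dims omitted).
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::string elementType;

  // Product of all dimensions.
  int64_t numElements() const {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
};

// Hypothetical stand-in for a splat constant attribute.
struct ToySplatAttr {
  ToyVectorType type;
  double splatValue;
};

// Always fold: the new attribute takes its shape from the shape_cast result
// and keeps the element type (and splat value) of the source attribute.
ToySplatAttr foldShapeCastKeepingElementType(const ToySplatAttr &source,
                                             const ToyVectorType &resultType) {
  assert(source.type.numElements() == resultType.numElements() &&
         "shape_cast must preserve the number of elements");
  ToyVectorType newType{resultType.shape, source.type.elementType};
  return ToySplatAttr{newType, source.splatValue};
}
```

Compared with skipping the fold when the element types differ, this variant still folds the mismatched case; the folded attribute keeps the attribute's element type (f8E4M3FN in the example from the comment) rather than adopting the result's element type (i8).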

newling commented Jul 14, 2025

> > > %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
> > >
> > > The following question is unrelated, but I've seen this pop up multiple times lately: why does llvm.mlir.constant allow the attribute and the return type to mismatch?
> >
> > I think this is actually central to the PR!
>
> After investigating, this is probably caused by how ConvertToLLVMPass lowers arith constrained vectors. The lowering ends up calling LLVM::detail::vectorOneToOneRewrite() to lower %cst = arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>, and the attribute parameter it passes is the value returned by attrConvert.getAttrs(), which is the attributes of the source op. During the lowering, the typeConverter converts f8E4M3FN to i8. So after this pass,
>
> %cst = arith.constant dense<0.000000e+00> : vector<192xf8E4M3FN>
>
> will be
>
> %4 = llvm.mlir.constant(dense<0.000000e+00> : vector<192xf8E4M3FN>) : vector<192xi8>
>
> From our perspective, if the logic of lowering constrained vectors needs to be reconsidered, it must be a big change and cannot land in the short term. We hope to just fix our current problem in this PR.

I agree, so I am happy to proceed with this PR. Please find my proposed change to this PR in my previous review.

@@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {

// -----

// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type

@newling newling Jul 14, 2025

Something along the lines of:

// CHECK: [...] = llvm.mlir.constant
// CHECK-NEXT: return %[[CONST]] 

would be more explicit.

I wonder whether it is OK to mix in the llvm dialect here; I guess so?

Author

Yes it's OK. Updated.

@banach-space

> if the logic of lowering constrained vectors needs to be reconsidered, it must be a big change

Why would it be a big change? Can we discuss what would be involved and why it isn't feasible?

Workarounds like the one proposed here are not scalable. And, to me, they create code to maintain ... other incorrect code.

MengmSun commented Jul 15, 2025

> > if the logic of lowering constrained vectors needs to be reconsidered, it must be a big change
>
> Why would it be a big change? Can we discuss what would be involved and why it isn't feasible?
>
> Workarounds like the one proposed here are not scalable. And, to me, they create code to maintain ... other incorrect code.

This logic was defined 3 years ago, and I'm afraid changing it will affect many files, especially tests. I cannot predict how long a thorough fix would take. Also, I don't think that updating the lowering of constrained vectors to LLVM is enough on its own to fix this thoroughly. As @dcaballe pointed out,

> why does llvm.mlir.constant allow the attribute and the return type to mismatch?

llvm.mlir.constant may be missing verification, and there may be other hidden problems.

@dcaballe

I'm probably missing something, but looking at https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp#L251:

  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
                                       adaptor.getOperands(), op->getAttrs(),
                                       *getTypeConverter(), rewriter);

Could we just convert the attributes here? The name of the oneToOneRewrite parameter is targetAttrs so I understand it's expected that we pass a target-converted attribute? It's probably good to look at what other invocations to this method are doing...

Another option:

template <typename SourceOp, typename TargetOp,
          template <typename, typename> typename AttrConvert =
              AttrConvertPassThrough>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:

It looks like AttrConvert could be instantiated with something that actually does the type conversion, instead of just passing the type through?
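
As a rough illustration of that second option, here is a self-contained C++ sketch of a lowering parameterized on an attribute-conversion policy through a template-template parameter. Everything in it is a simplified, hypothetical stand-in (`ToyAttr`, `ToyArithConstant`, `ToyLLVMConstant`, `PassThroughAttrs`, `ConvertAttrTypes`, `convertElementType`, `lower`); it does not reproduce the real MLIR classes or signatures and only shows the shape of the idea.

```cpp
#include <cassert>
#include <string>

// Toy attribute: a value plus the name of its element type.
struct ToyAttr {
  double value;
  std::string elementType;
};

// Toy source and target ops.
struct ToyArithConstant {
  ToyAttr attr;
};
struct ToyLLVMConstant {
  ToyAttr attr;
  std::string resultElementType;
};

// Stand-in for the type converter: f8E4M3FN is lowered to i8.
inline std::string convertElementType(const std::string &type) {
  return type == "f8E4M3FN" ? "i8" : type;
}

// Policy 1: forward the source attribute unchanged.
template <typename SourceOp, typename TargetOp>
struct PassThroughAttrs {
  explicit PassThroughAttrs(const SourceOp &op) : attr(op.attr) {}
  ToyAttr getAttr() const { return attr; }
  ToyAttr attr;
};

// Policy 2 (hypothetical): also run the attribute's element type through the
// converter. A real conversion would have to re-encode the value as well.
template <typename SourceOp, typename TargetOp>
struct ConvertAttrTypes {
  explicit ConvertAttrTypes(const SourceOp &op)
      : attr{op.attr.value, convertElementType(op.attr.elementType)} {}
  ToyAttr getAttr() const { return attr; }
  ToyAttr attr;
};

// Lowering parameterized on the attribute policy.
template <typename SourceOp, typename TargetOp,
          template <typename, typename> typename AttrConvert = PassThroughAttrs>
TargetOp lower(const SourceOp &op) {
  assert(!op.attr.elementType.empty() && "source constant needs a typed attribute");
  AttrConvert<SourceOp, TargetOp> convert(op);
  return TargetOp{convert.getAttr(), convertElementType(op.attr.elementType)};
}
```

In this toy, the default policy leaves an f8E4M3FN attribute next to an i8 result type, which is the mismatch discussed in this thread, while `lower<ToyArithConstant, ToyLLVMConstant, ConvertAttrTypes>(op)` makes the two agree. Whether the real AttrConvert hook can be used this way for constants is exactly the open question above.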
