From 5e9aa6b6b1fe332f5e0c92b14835484558c00732 Mon Sep 17 00:00:00 2001 From: MengmengSun Date: Wed, 9 Jul 2025 02:53:36 -0700 Subject: [PATCH 1/2] [mlir][Vector]Add constraints to vector.shape_cast(constant) -> constant --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++++++--- mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 214d2ba7e1b8e..5bbe6704aac48 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return bcastOp.getSource(); } - // shape_cast(constant) -> constant + // shape_cast(constant) -> constant, + // if element type of the source and result are the same if (auto splatAttr = - llvm::dyn_cast_if_present(adaptor.getSource())) - return splatAttr.reshape(getType()); + llvm::dyn_cast_if_present(adaptor.getSource())) { + if (splatAttr.getElementType() == resultType.getElementType()) + return splatAttr.reshape(getType()); + } // shape_cast(poison) -> poison if (llvm::dyn_cast_if_present(adaptor.getSource())) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8a9e27378df61..69da8a31d2c9b 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type +func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> { + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8> + // CHECK-NOT: vector.shape_cast + %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8> + // CHECK-NOT: vector.extract + %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8> + return %2 : vector<12xi8> +} + +// ----- + // CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar // CHECK: vector.broadcast // CHECK-NOT: vector.shape_cast From 5604b262bae4dabbcba9803abf52f57eb95edaf1 Mon Sep 17 00:00:00 2001 From: MengmengSun Date: Tue, 15 Jul 2025 03:11:26 -0700 Subject: [PATCH 2/2] Update based on comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26 ++++++++++++++++++---- mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5bbe6704aac48..4cc4eed08f2da 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5922,12 +5922,30 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return bcastOp.getSource(); } - // shape_cast(constant) -> constant, - // if element type of the source and result are the same + // shape_cast(constant) -> constant if (auto splatAttr = llvm::dyn_cast_if_present(adaptor.getSource())) { - if (splatAttr.getElementType() == resultType.getElementType()) - return splatAttr.reshape(getType()); + + // The shape and 'scalable dims' of the new attribute must match the result + // of the shape_cast: + auto newShape = resultType.getShape(); + auto newScalableDims = resultType.getScalableDims(); + + // The element type must be retained. Note that this is to handle currently + // valid IR like + // + // ``` + // %0 = llvm.mlir.constant(dense<0.> : vector<1xf8E4M3FN>) : vector<1xi8> + // %1 = vector.shape_cast %0 : vector<1xi8> to vector<1x1xi8> + // ``` + // + // where the element types of the attribute and result do not match. + auto newElementType = splatAttr.getElementType(); + + auto newAttr = VectorType::get(newShape, newElementType, newScalableDims); + + return DenseElementsAttr::get(newAttr, + splatAttr.getSplatValue()); } // shape_cast(poison) -> poison diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 69da8a31d2c9b..b0114905db742 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1003,11 +1003,11 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> { // ----- // CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type +// CHECK: %[[CONST:.*]] = llvm.mlir.constant +// CHECK-NEXT: return %[[CONST]] func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> { %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8> - // CHECK-NOT: vector.shape_cast %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8> - // CHECK-NOT: vector.extract %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8> return %2 : vector<12xi8> }