diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..4d49e52b21563 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
-
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
-    Each source dim is expanded (or contiguous sequence of source dims combined)
-    in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
-    sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    Casts to a vector with the same number of elements, element type, and
+    number of scalable dimensions.
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
@@ -2265,15 +2255,13 @@
     2-D MLIR vector to a 1-D flattened LLVM vector. shape_cast lowering to LLVM
     is supported in that particular case, for now.
 
-    Example:
+    Examples:
 
     ```mlir
-    // Example casting to a lower vector rank.
-    %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
-    // Example casting to a higher vector rank.
-    %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
+    %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
 
+    // With 2 scalable dimensions (the number of which must be preserved).
+    %3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
     ```
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..237ff17819063 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5505,124 +5505,56 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
+LogicalResult ShapeCastOp::verify() {
 
-  unsigned i = 0;
-  unsigned j = 0;
-  while (i < rankA && j < rankB) {
-    int64_t dimA = a[i];
-    int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
-      dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+  VectorType sourceType = getSourceVectorType();
+  VectorType resultType = getResultVectorType();
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
-  }
+  // Check that element type is preserved
+  if (sourceType.getElementType() != resultType.getElementType())
+    return emitOpError("has different source and result element types");
 
-  return i == rankA && j == rankB;
-}
-
-static LogicalResult verifyVectorShapeCast(Operation *op,
-                                           VectorType sourceVectorType,
-                                           VectorType resultVectorType) {
-  // Check that element type is the same.
-  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
-
-  // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+  // Check that number of elements is preserved
+  int64_t sourceNElms = sourceType.getNumElements();
+  int64_t resultNElms = resultType.getNumElements();
+  if (sourceNElms != resultNElms) {
+    return emitOpError() << "has different number of elements at source ("
+                         << sourceNElms << ") and result (" << resultNElms
+                         << ")";
   }
 
   // Check that (non-)scalability is preserved
-  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
-  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
+  int64_t resultNScalableDims = resultType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
-           << sourceNScalableDims << ") and result (" << resultNScalableDims
-           << ")";
-  sourceVectorType.getNumDynamicDims();
-
-  return success();
-}
-
-LogicalResult ShapeCastOp::verify() {
-  auto sourceVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
-  auto resultVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
-  // Check if source/result are of vector type.
-  if (sourceVectorType && resultVectorType)
-    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
+    return emitOpError() << "has different number of scalable dims at source ("
+                         << sourceNScalableDims << ") and result ("
+                         << resultNScalableDims << ")";
 
   return success();
 }
 
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
+  VectorType resultType = getType();
+
   // No-op shape cast.
-  if (getSource().getType() == getType())
+  if (getSource().getType() == resultType)
     return getSource();
 
-  VectorType resultType = getType();
-
-  // Canceling shape casts.
+  // Y = shape_cast(shape_cast(X))
+  //     -> X, if X and Y have the same type
+  //     -> shape_cast(X) otherwise.
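+  //
+  // For example (see fold_expand_collapse in canonicalize.mlir below):
+  //   %0 = vector.shape_cast %x : vector<1x1x64xf32> to vector<1x1x8x8xf32>
+  //   %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
+  // folds to a single shape_cast from vector<1x1x64xf32> to vector<8x8xf32>.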
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
-    // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
-      return {};
-    }
     setOperand(otherOp.getSource());
     return getResult();
   }
 
-  // Cancelling broadcast and shape cast ops.
+  // Y = shape_cast(broadcast(X))
+  //     -> X, if X and Y have the same type
   if (auto bcastOp = getSource().getDefiningOp<vector::BroadcastOp>()) {
     if (bcastOp.getSourceType() == resultType)
       return bcastOp.getSource();
diff --git a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
index ae2b5393ca449..60ad54bf5c370 100644
--- a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
   // CHECK-NEXT: vector.insert {{.*}}[1]
   // CHECK-NEXT: vector.insert {{.*}}[2]
   // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
   %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
   return %0 : vector<8x4xf32>
 }
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
   // CHECK-NEXT: vector.insert {{.*}}[1]
   // CHECK-NEXT: vector.insert {{.*}}[2]
   // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
+  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
   return %0 : vector<1x8x4xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..5bf8b9338c498 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
+// CHECK-LABEL: fold_expand_collapse
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+// CHECK: return %[[A]] : vector<8x8xf32>
+func.func @fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
   %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
   %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a8320971bac4..fa4837126accb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1131,34 +1131,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 // -----
 
+
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result vectors must have same element type}}
+  // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result number of elements must match}}
+  // expected-error@+1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
-
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
   // expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
   %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..f3220aed4360c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -564,6 +564,17 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
   return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
+// A vector.shape_cast can cast between any 2 shapes as long as the
+// number of elements is preserved. For those familiar with the tensor
+// dialect: this behaviour is like the tensor.reshape operation, i.e.
+// less restrictive than tensor.collapse_shape and tensor.expand_shape.
+// CHECK-LABEL: @shape_cast_general_reshape
+func.func @shape_cast_general_reshape(%arg0 : vector<2x3xf32>) -> (vector<3x1x2xf32>) {
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x3xf32> to vector<3x1x2xf32>
+  %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x1x2xf32>
+  return %0 : vector<3x1x2xf32>
+}
+
 // CHECK-LABEL: @shape_cast_0d
 func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
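
Note: casts that the old verifier rejected as `invalid shape cast` are accepted under the relaxed rules, since only the element count, element type, and number of scalable dimensions must match. A minimal illustration, taken directly from the two tests removed from invalid.mlir above (operand names adapted); both ops are now expected to verify:

```mlir
// 30 elements on both sides; neither shape factors contiguously into the
// other, which the old verifier required.
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
%1 = vector.shape_cast %arg1 : vector<15x2xf32> to vector<5x2x3x1xf32>
```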