[mlir][vector] Refactor vector linearization patterns #142685
@llvm/pr-subscribers-mlir-core

Author: James Newling (newling)

Changes

This PR separates out the vector linearization API and testing from other vector rewrite patterns. There is no functional change (although the API changes).

API change: There is currently a partition into 2 groups of linearization patterns:

Test change: The file

Patch is 72.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142685.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..de6a441249695
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,251 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materializations that
+/// use shape_cast for converting to and from 1D (linearized) vectors.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1, into a new single operation whose vector operands
+/// and results are all rank<=1.
+///
+/// This function initializes `conversionTarget` with a definition of which
+/// operations are illegal and consequently must be converted to a linearized
+/// (legal) form. It also populates `patterns` with patterns that will be run to
+/// convert illegal operations, and sets the priority/benefit patterns have.
+///
+/// Note: the set of legal operations can be extended by a user by adding
+/// additional legality rules to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to enable reusing generic structural type conversions for
+/// linearizing scf/cf/func operations.
+void populateForFullVectorLinearize(const TypeConverter &,
+ ConversionTarget &conversionTarget,
+ RewritePatternSet &patterns);
+
+/// The set of patterns available for linearization.
+enum class LinearizePattern {
+
+ /// This pattern converts a constant (or poison) vector of rank>1 into a
+ /// 1D vector, followed by a shape_cast.
+ ///
+ /// BEFORE
+ /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+ LinearizeConstantLike = 0,
+
+ /// BEFORE
+ /// %2 = math.sin %arg0 : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+ /// %1 = math.sin %0 : vector<4xf32>
+ /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+ LinearizeVectorizable,
+
+ /// BEFORE
+ /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+ /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+ /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+ LinearizeVectorBitCast,
+
+ /// BEFORE
+ /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+ ///
+ /// AFTER
+ /// [...]
+ /// %mask_1d = vector.create_mask %mul : vector<4xi1>
+ /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+ ///
+ /// where `%mul` is a function of `%arg0` and `%arg1`.
+ ///
+ /// This pattern currently only supports 2D masks with a unit outer
+ /// dimension.
+ LinearizeVectorCreateMask,
+
+ /// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to
+ /// one that works on linearized vectors.
+ ///
+ /// BEFORE
+ /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+ ///
+ /// AFTER
+ /// %v1_1d = vector.shape_cast %v1_3d : [...]
+ /// %v2_1d = vector.shape_cast %v2_3d : [...]
+ /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+ /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+ ///
+ /// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+ LinearizeVectorShuffle,
+
+ /// BEFORE
+ /// %1 = vector.splat %value : vector<4x4xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.splat %value : vector<16xf32>
+ /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+ LinearizeVectorSplat,
+
+ /// This pattern converts a vector.extract_strided_slice operation into a
+ /// vector.shuffle operation that has rank-1 (linearized) operand and
+ /// result.
+ ///
+ /// BEFORE
+ /// %out_nd = vector.extract_strided_slice %source_nd
+ /// { offsets = [..], strides = [..], sizes = [..] }
+ ///
+ /// AFTER
+ /// %source_1d = vector.shape_cast %source_nd [...]
+ /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d [...]
+ ///
+ /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+ /// original vector.extract_strided_slice operation.
+ VectorExtractStridedSliceToRankOneShuffle,
+
+ /// BEFORE
+ /// %extract = vector.extract %src [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+ /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// `shuffle_indices` is computed from `position` of original extract.
+ VectorExtractToRankOneShuffle,
+
+ /// This pattern converts a vector.insert_strided_slice operation into a
+ /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+ ///
+ /// BEFORE
+ /// %0 = vector.insert_strided_slice %to_store, %into
+ /// {offsets = [1, 0, 0, 0], strides = [1, 1]}
+ /// : vector<2x2xi8> into vector<2x1x3x2xi8>
+ /// AFTER
+ /// %to_store_1d
+ /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+ /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+ /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+ ///
+ /// where shuffle_indices_1d in this case is
+ /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+ /// ^^^^^^^^^^^^^^
+ /// to_store_1d
+ VectorInsertStridedSliceToRankOneShuffle,
+
+ /// BEFORE
+ /// %insert = vector.insert %src, %dst [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+ /// %dst_1d = vector.shape_cast %dst : [...]
+ /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// `shuffle_indices` is computed from `position`.
+ VectorInsertToRankOneShuffle,
+
+ /// The number of patterns in this enum.
+ N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
+struct VectorLinearizePatterns {
+
+public:
+ /// By default all patterns are enabled and have benefit 1.
+ VectorLinearizePatterns() {
+ enabled.fill(true);
+ benefits.fill(PatternBenefit(1));
+ }
+
+ /// Add the patterns enabled for the conversion to `patterns`.
+ void addToPatternSet(const TypeConverter &,
+ RewritePatternSet &patterns) const;
+
+ VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+ enabled[static_cast<unsigned>(id)] = e;
+ return *this;
+ }
+
+ VectorLinearizePatterns &enableAll(bool e = true) {
+ enabled.fill(e);
+ return *this;
+ }
+
+ bool isEnabled(LinearizePattern id) const {
+ return enabled[static_cast<unsigned>(id)];
+ }
+
+ PatternBenefit getBenefit(LinearizePattern id) const {
+ return benefits[static_cast<unsigned>(id)];
+ }
+
+ VectorLinearizePatterns &setBenefit(LinearizePattern id,
+ PatternBenefit benefit) {
+ getBenefitRef(id) = benefit;
+ return *this;
+ }
+
+ VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+ unsigned inc = 1) {
+ getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+ return *this;
+ }
+
+private:
+ std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+ std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+ benefits;
+
+ PatternBenefit &getBenefitRef(LinearizePattern id) {
+ unsigned idInt = static_cast<unsigned>(id);
+ assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+ "invalid linearization pattern id");
+ return benefits[idInt];
+ }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
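The index arithmetic documented for `getStridedSliceInsertionIndices` and for `VectorInsertStridedSliceToRankOneShuffle` can be checked with a small dependency-free sketch. This is not the implementation from the patch (the `.cpp` side is truncated here): `insertionIndices` and `insertShuffleMask` are hypothetical names, `std::vector` stands in for `ArrayRef`/`SmallVector`, and the sketch assumes that `offsets` has one entry per dimension of `large` and that `small` is aligned to the trailing dimensions of `large`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-in for getStridedSliceInsertionIndices: returns the
// row-major linearized indices of `large` that are written when a vector of
// shape `small` is inserted at `offsets`.
std::vector<int64_t> insertionIndices(const std::vector<int64_t> &small,
                                      const std::vector<int64_t> &large,
                                      const std::vector<int64_t> &offsets) {
  assert(small.size() <= large.size() && offsets.size() == large.size());
  const std::size_t rank = large.size();

  // Assumption: pad `small` with leading unit dimensions up to `rank`.
  std::vector<int64_t> padded(rank - small.size(), 1);
  padded.insert(padded.end(), small.begin(), small.end());

  // Row-major strides of `large`.
  std::vector<int64_t> strides(rank, 1);
  for (std::size_t d = rank; d-- > 1;)
    strides[d - 1] = strides[d] * large[d];

  int64_t count = 1;
  for (int64_t s : padded)
    count *= s;

  std::vector<int64_t> indices;
  indices.reserve(static_cast<std::size_t>(count));
  std::vector<int64_t> pos(rank, 0); // current position within `small`
  for (int64_t n = 0; n < count; ++n) {
    int64_t linear = 0;
    for (std::size_t d = 0; d < rank; ++d)
      linear += (offsets[d] + pos[d]) * strides[d];
    indices.push_back(linear);
    // Advance `pos` like an odometer, innermost dimension fastest.
    for (std::size_t d = rank; d-- > 0;) {
      if (++pos[d] < padded[d])
        break;
      pos[d] = 0;
    }
  }
  return indices;
}

// Builds the 1D shuffle mask for `vector.shuffle %into_1d, %to_store_1d`:
// identity on `large`, except that written positions select elements of the
// second operand, whose indices start at the element count of `large`.
std::vector<int64_t> insertShuffleMask(const std::vector<int64_t> &small,
                                       const std::vector<int64_t> &large,
                                       const std::vector<int64_t> &offsets) {
  int64_t numLarge = 1;
  for (int64_t s : large)
    numLarge *= s;
  std::vector<int64_t> mask(static_cast<std::size_t>(numLarge));
  std::iota(mask.begin(), mask.end(), int64_t{0});
  const std::vector<int64_t> written = insertionIndices(small, large, offsets);
  for (std::size_t i = 0; i < written.size(); ++i)
    mask[static_cast<std::size_t>(written[i])] =
        numLarge + static_cast<int64_t>(i);
  return mask;
}
```

With the shapes from the header comments, `insertionIndices({1, 2}, {4, 5}, {1, 3})` gives `[8, 9]`, and `insertShuffleMask({2, 2}, {2, 1, 3, 2}, {1, 0, 0, 0})` gives `[0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]`.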
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
- ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
} // namespace vector
} // namespace mlir
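The three functions removed above are superseded by per-pattern control in the new `VectorLinearizePatterns` struct. The chained enable/benefit style can be tried out with a dependency-free mirror of that struct. This is only a sketch: `PatternId` and `PatternConfig` are hypothetical names, a plain `unsigned` stands in for `PatternBenefit`, and only a subset of the pattern ids is listed. In this sketch `incrementBenefit` adds its `inc` argument.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical subset of the pattern ids; `N` is the number of ids.
enum class PatternId { ConstantLike = 0, Vectorizable, BitCast, Shuffle, N };

// Dependency-free mirror of the enable/benefit bookkeeping.
class PatternConfig {
public:
  // By default all patterns are enabled and have benefit 1.
  PatternConfig() {
    enabled.fill(true);
    benefits.fill(1);
  }

  PatternConfig &enable(PatternId id, bool e = true) {
    enabled[index(id)] = e;
    return *this;
  }

  PatternConfig &enableAll(bool e = true) {
    enabled.fill(e);
    return *this;
  }

  PatternConfig &setBenefit(PatternId id, unsigned benefit) {
    benefits[index(id)] = benefit;
    return *this;
  }

  PatternConfig &incrementBenefit(PatternId id, unsigned inc = 1) {
    benefits[index(id)] += inc;
    return *this;
  }

  bool isEnabled(PatternId id) const { return enabled[index(id)]; }
  unsigned getBenefit(PatternId id) const { return benefits[index(id)]; }

private:
  static constexpr std::size_t kNumPatterns =
      static_cast<std::size_t>(PatternId::N);

  // Bounds-checked conversion from pattern id to array index.
  static std::size_t index(PatternId id) {
    const std::size_t i = static_cast<std::size_t>(id);
    assert(i < kNumPatterns && "invalid pattern id");
    return i;
  }

  std::array<bool, kNumPatterns> enabled;
  std::array<unsigned, kNumPatterns> benefits;
};
```

Because every setter returns `*this`, a caller can write `cfg.enableAll(false).enable(PatternId::Shuffle).incrementBenefit(PatternId::Shuffle, 3)` to select a single pattern and raise its priority in one expression.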
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..0c11c9b5c8740 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
@@ -47,12 +48,21 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
namespace {
+/// This pattern converts a constant (or poison) vector of rank>1 into a
+/// 1D vector, followed by a shape_cast.
+///
+/// BEFORE
+/// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+///
+/// AFTER
+/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
struct LinearizeConstantLike final
: OpTraitConversionPattern<OpTrait::ConstantLike> {
using OpTraitConversionPattern::OpTraitConversionPattern;
LinearizeConstantLike(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -88,13 +98,20 @@ struct LinearizeConstantLike final
}
};
+/// BEFORE
+/// %2 = math.sin %arg0 : vector<2x2xf32>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+/// %1 = math.sin %0 : vector<4xf32>
+/// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
public:
LinearizeVectorizable(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,17 +126,178 @@ struct LinearizeVectorizable final
}
};
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
- static_assert(
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
- "expected vector.extract_strided_slice or vector.insert_strided_slice");
- ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneInteger);
-}
+/// BEFORE
+/// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+/// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+/// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+struct LinearizeVectorBitCast final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorBitCast(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resType = getTypeConverter()->convertType(castOp.getType());
+ assert(resType && "expected 1-D vector type");
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+ adaptor.getSource());
+ return success();
+ }
+};
+
+/// This pattern converts the vector.create_mask to work on a linearized vector.
+/// It currently supports only 2D masks with a unit outer dimension.
+///
+/// BEFORE
+/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+///
+/// AFTER
+/// %zero = arith.constant 0 : index
+/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
+/// %index = arith.index_cast %cmpi : i1 to index
+/// %mul = arith.andi %index, %arg1 : index
+/// %mask = vector.create_mask %mul : vector<4xi1>
+/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = createMaskOp.getLoc();
+ VectorType srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it is greater than 0, the
+ // corresponding mask element is set to true, otherwise false.
+ // The result of the comparison is then combined (bitwise AND) with
+ // the second operand of create_mask to get the size of the 1D mask.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), isNonZero);
+ auto secondOperand = adaptor.getOperands().back();
+ auto maskSize = rewriter.createOrFold<arith::AndIOp>(
+ loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
+
+ auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
+ rewriter.replaceOp(createMaskOp, newMask);
+ return success();
+ }
+};
+
+/// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to
+/// one that works on linearized vectors.
+///
+/// BEFORE
+/// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+///
+/// AFTER
+/// %v1_1d = vector.shape_cast %v1_3d : [...]
+/// %v2_1d = vector.shape_cast %v2_3d : [...]
+/// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+/// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+///
+/// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+struct LinearizeVectorShuffle final
+ : public OpConversionPattern<vector::ShuffleOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstType =
+ getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
+ assert(dstType && "vector type destination expected.");
+
+ Value vec1 = adaptor.getV1();
+ Value vec2 = adaptor.getV2();
+ int shuffleSliceLen = 1;
+ int rank = shuffleOp.getV1().getType().getRank();
-/// Convert an array of attributes into a vector of integers, if possible.
+ // If rank > 1, we need to do the shuffle in the granu...
[truncated]
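The arithmetic in `LinearizeVectorCreateMask` can be sanity-checked outside MLIR. The sketch below models a `vector<1xNxi1>` mask and its linearized `vector<Nxi1>` counterpart with plain integers. It rests on two assumptions that should be checked against the dialect documentation: that `arith.index_cast` sign-extends the `i1` comparison result (so `true` becomes an all-ones index), and that `vector.create_mask` sets element `j` exactly when `j` is below the corresponding operand. All function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed reference semantics of `vector.create_mask %a, %b : vector<1xNxi1>`:
// element (0, j) is set iff 0 < a and j < b.
std::vector<bool> mask2dUnitOuter(int64_t a, int64_t b, int64_t n) {
  std::vector<bool> mask(static_cast<std::size_t>(n));
  for (int64_t j = 0; j < n; ++j)
    mask[static_cast<std::size_t>(j)] = (0 < a) && (j < b);
  return mask;
}

// Assumed reference semantics of `vector.create_mask %size : vector<Nxi1>`.
std::vector<bool> mask1d(int64_t size, int64_t n) {
  std::vector<bool> mask(static_cast<std::size_t>(n));
  for (int64_t j = 0; j < n; ++j)
    mask[static_cast<std::size_t>(j)] = j < size;
  return mask;
}

// Models the op sequence emitted by the pattern:
//   %cmpi  = arith.cmpi sgt, %a, %zero
//   %index = arith.index_cast %cmpi   (assumed to sign-extend: true -> -1)
//   %size  = arith.andi %index, %b
int64_t linearizedMaskSize(int64_t a, int64_t b) {
  const int64_t isNonZeroIndex = (a > 0) ? int64_t{-1} : int64_t{0};
  return isNonZeroIndex & b;
}
```

Under these assumptions, `mask1d(linearizedMaskSize(a, b), n)` matches `mask2dUnitOuter(a, b, n)` for every `a` and `b`: the AND passes `b` through unchanged when `a > 0` and yields 0 otherwise.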
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
public:
LinearizeVectorizable(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,17 +126,178 @@ struct LinearizeVectorizable final
}
};
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
- static_assert(
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
- "expected vector.extract_strided_slice or vector.insert_strided_slice");
- ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneInteger);
-}
+/// BEFORE
+/// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+/// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+/// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+struct LinearizeVectorBitCast final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorBitCast(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resType = getTypeConverter()->convertType(castOp.getType());
+ assert(resType && "expected 1-D vector type");
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+ adaptor.getSource());
+ return success();
+ }
+};
+
+/// This pattern converts the vector.create_mask to work on a linearized vector.
+/// It currently supports only 2D masks with a unit outer dimension.
+///
+/// BEFORE
+/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+///
+/// AFTER
+/// %zero = arith.constant 0 : index
+/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
+/// %index = arith.index_cast %cmpi : i1 to index
+/// %mul = arith.andi %index, %arg1 : index
+/// %mask = vector.create_mask %mul : vector<4xi1>
+/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = createMaskOp.getLoc();
+ VectorType srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it is greater than 0, the
+ // corresponding mask element is set to true, otherwise false.
+ // The result of the comparison is then multiplied with
+ // the second operand of create_mask to get the 1D mask.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), isNonZero);
+ auto secondOperand = adaptor.getOperands().back();
+ auto maskSize = rewriter.createOrFold<arith::AndIOp>(
+ loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
+
+ auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
+ rewriter.replaceOp(createMaskOp, newMask);
+ return success();
+ }
+};
+
+/// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to
+/// a one that works on linearized vectors.
+///
+/// BEFORE
+/// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+///
+/// AFTER
+/// %v1_1d = vector.shape_cast %v1_3d : [...]
+/// %v2_1d = vector.shape_cast %v2_3d : [...]
+/// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+/// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+///
+/// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+struct LinearizeVectorShuffle final
+ : public OpConversionPattern<vector::ShuffleOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstType =
+ getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
+ assert(dstType && "vector type destination expected.");
+
+ Value vec1 = adaptor.getV1();
+ Value vec2 = adaptor.getV2();
+ int shuffleSliceLen = 1;
+ int rank = shuffleOp.getV1().getType().getRank();
-/// Convert an array of attributes into a vector of integers, if possible.
+ // If rank > 1, we need to do the shuffle in the granu...
[truncated]
|
@llvm/pr-subscribers-mlir-vector Author: James Newling (newling) Changes: This PR separates out the vector linearization API and testing from other vector rewrite patterns. There is no functional change (although the API changes). API change: There is currently a partition into 2 groups of linearization patterns: Test change: The file Patch is 72.86 KiB, truncated to 20.00 KiB, full version: https://github.com/llvm/llvm-project/pull/142685.diff 9 Files Affected:
|
Hey, thanks for bringing this up! Could you elaborate a bit more on the final state you are working towards? Why do we need such a fine-grained level of optionality for these patterns? |
The final state I'm working towards is basically this PR plus 4 additional patterns. These are 4 patterns for linearizing towards insert_strided_slice and extract_strided_slice, instead of towards shuffle (see for example VectorExtractToRankOneStrided in #142672). The motivation for linearizing this way is that the strided ops are 'higher' level than shuffle, because they retain the fact that the extract/insert is contiguous. This alternative lowering isn't a 'must have' for linearization in general, but it was my preferred linearization and what I had before trying to integrate with upstream. I could add these 4 patterns without exposing all the patterns in a public API. That would entail adding just 1 new API for this alternative lowering. Perhaps I should post a series of smaller separate PRs with those new patterns. My motivation for the fine-grained control was mostly a personal preference, and not seeing an obvious partition of the patterns (as mentioned in this PR summary). I have been in situations using upstream patterns before where I wanted only a subset of the patterns added by the public APIs... But I know that consistency in Vector is worth a lot, so I am happy to not go with this fine-grained approach :) |
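To make the contiguity point in the comment above concrete, here is a minimal standalone sketch of how an insert of a small vector into a large one turns into a rank-1 shuffle mask. It is not the PR's code: `sliceInsertionIndices` and `insertShuffleMask` are hypothetical names, plain `std::vector` stands in for the MLIR types, unit strides are assumed, and `offsets` is assumed to have one entry per dimension of `large`. The two examples quoted in the usage note are the ones given in the doc comments of the PR's header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (an assumption for this sketch, not the PR's function):
// enumerate, in row-major order, the linearized positions of `large` that are
// written when a vector of shape `small` is inserted at `offsets`.
// `small` is aligned with the trailing dimensions of `large`.
std::vector<int64_t> sliceInsertionIndices(const std::vector<int64_t> &small,
                                           const std::vector<int64_t> &large,
                                           const std::vector<int64_t> &offsets) {
  assert(small.size() <= large.size());
  assert(offsets.size() == large.size());
  const std::size_t rank = large.size();

  // Row-major strides of `large`: the last dimension is the fastest varying.
  std::vector<int64_t> strides(rank, 1);
  for (std::size_t d = rank; d-- > 1;)
    strides[d - 1] = strides[d] * large[d];

  // Pad `small` with leading unit dimensions up to the rank of `large`.
  std::vector<int64_t> padded(rank - small.size(), 1);
  padded.insert(padded.end(), small.begin(), small.end());

  int64_t count = 1;
  for (int64_t s : padded)
    count *= s;

  std::vector<int64_t> result;
  result.reserve(static_cast<std::size_t>(count));
  std::vector<int64_t> idx(rank, 0);
  for (int64_t n = 0; n < count; ++n) {
    int64_t linear = 0;
    for (std::size_t d = 0; d < rank; ++d)
      linear += (offsets[d] + idx[d]) * strides[d];
    result.push_back(linear);
    // Odometer-style increment of the multi-index, last dimension first.
    for (std::size_t d = rank; d-- > 0;) {
      if (++idx[d] < padded[d])
        break;
      idx[d] = 0;
    }
  }
  return result;
}

// Hypothetical helper: the mask for a rank-1 shuffle of (into_1d, to_store_1d).
// Positions not written keep their own index; written positions select
// consecutive elements of the second operand, whose indices start at the
// number of elements of `large`.
std::vector<int64_t> insertShuffleMask(const std::vector<int64_t> &small,
                                       const std::vector<int64_t> &large,
                                       const std::vector<int64_t> &offsets) {
  int64_t numLarge = 1;
  for (int64_t s : large)
    numLarge *= s;
  std::vector<int64_t> mask(static_cast<std::size_t>(numLarge));
  std::iota(mask.begin(), mask.end(), int64_t{0});
  const std::vector<int64_t> written =
      sliceInsertionIndices(small, large, offsets);
  for (std::size_t k = 0; k < written.size(); ++k)
    mask[static_cast<std::size_t>(written[k])] =
        numLarge + static_cast<int64_t>(k);
  return mask;
}
```

With these definitions, inserting a 1x2 vector into a 4x5 vector at (1,3) yields the written positions [8, 9], and inserting a 2x2 vector into a 2x1x3x2 vector at offsets [1, 0, 0, 0] yields the mask [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. The mask is just a list of indices: the fact that positions 6 to 9 form one contiguous run is no longer stated anywhere, which is the information the strided ops keep.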
Hey @newling, I don't really have a strong opinion here. In general, I am in favour of grouping patterns/logic this way. The diff is a bit tricky to review - is this merely moving code around?
IIUC, this would be the main change? How can a user achieve this? Could you point me in the right direction? Thanks! |
Yes, there is no change to the patterns, type converter, or conversion target. The main change is in how the user is expected to initialize them.
Yes, that's the main change. It's shown in a test in the end goal draft PR #142672. This current PR was pulled out of that larger PR. I now see that I've 'stacked' the bigger change in a way that makes it harder than necessary to make sense of. I will put this PR into draft mode while I reassess, hope to be back soon with something easier to digest! Thanks for your comments @banach-space and @dcaballe, I'll let you know when I'm ready for feedback again. |
Did you mean |
Not quite, because the user doesn't choose exactly what patterns run there. I was thinking of the following: In the new file TestVectorLinearize.cpp in #142672 there are 3 passes for testing patterns, the one which shows how a user can 'mix-and-match' is
VectorLinearizePatterns() // 1
    .enableAll(false) // 2
    .enable(LinearizePattern::RankReduceInsertStridedSlice) // 3
    .enable(LinearizePattern::RankReduceExtractStridedSlice) // 4
    .addToPatternSet(typeConverter, patterns); // 5
where
As for the tests using preference (shuffle vs strided) like
I hope that makes a bit more sense! FWIW I suppose this proposed change to expose more fine grained pattern control is related to the second point of the discourse thread https://discourse.llvm.org/t/finding-rewrite-patterns/85810 |
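The chained calls in the comment above follow a fluent-builder style: every setter returns a reference to the object, and the final call consumes the selection. The sketch below models only that mechanism, without any MLIR dependency. All names in it are hypothetical stand-ins: `PatternId` models three entries of the PR's `LinearizePattern` enum, a plain `unsigned` replaces `PatternBenefit`, and `collect()` plays the role of `addToPatternSet`, which in the PR takes a `TypeConverter` and a `RewritePatternSet`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for the PR's LinearizePattern enum (three entries
// only). `N` is the number of patterns, as in the PR.
enum class PatternId { ConstantLike = 0, Vectorizable, Shuffle, N };

// Hypothetical stand-in for the PR's VectorLinearizePatterns.
class PatternSelection {
  static constexpr std::size_t kCount = static_cast<std::size_t>(PatternId::N);

public:
  // By default every pattern is enabled with benefit 1.
  PatternSelection() {
    enabled.fill(true);
    benefits.fill(1);
  }

  PatternSelection &enable(PatternId id, bool e = true) {
    enabled[index(id)] = e;
    return *this;
  }

  PatternSelection &enableAll(bool e = true) {
    enabled.fill(e);
    return *this;
  }

  PatternSelection &setBenefit(PatternId id, unsigned benefit) {
    benefits[index(id)] = benefit;
    return *this;
  }

  bool isEnabled(PatternId id) const { return enabled[index(id)]; }
  unsigned getBenefit(PatternId id) const { return benefits[index(id)]; }

  // Returns the enabled patterns with their benefits, in enum order. In this
  // sketch it stands in for adding the patterns to a pattern set.
  std::vector<std::pair<PatternId, unsigned>> collect() const {
    std::vector<std::pair<PatternId, unsigned>> out;
    for (std::size_t i = 0; i < kCount; ++i)
      if (enabled[i])
        out.emplace_back(static_cast<PatternId>(i), benefits[i]);
    return out;
  }

private:
  static std::size_t index(PatternId id) {
    const std::size_t i = static_cast<std::size_t>(id);
    assert(i < kCount && "invalid pattern id");
    return i;
  }

  std::array<bool, kCount> enabled;
  std::array<unsigned, kCount> benefits;
};
```

A caller who wants a single pattern with a raised priority would write `PatternSelection().enableAll(false).enable(PatternId::Shuffle).setBenefit(PatternId::Shuffle, 2).collect()`, mirroring steps 1 to 5 in the comment. Indexing two fixed-size arrays by enum value keeps the selection cheap to copy and makes "everything enabled" a natural default.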
This PR separates out the vector linearization API and testing from other vector rewrite patterns. There is no functional change (although the API changes).
API change:
There is currently a partition into 2 groups of linearization patterns: populateVectorLinearizeBasePatterns and populateVectorLinearizeShuffleLikeOpsPatterns. I would like to add more patterns for linearization (draft PR #142672) but don't want to add a third group of patterns because I don't see any obvious grouping. I think it'd be less opinionated if any sub-group of the patterns can be used. That's introduced in this PR. With this PR there is an API which adds all patterns to the RewritePatternSet (populateForFullVectorLinearize), but a user can also bypass this API and mix-and-match whichever patterns they want (as well as control the pattern benefits).
Test change:
The file TestVectorTransforms.cpp was getting large (~1,000 lines) so I split it up (SPIRV is an example of a dialect that has multiple test util .cpp files like this).