diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h new file mode 100644 index 0000000000000..de6a441249695 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h @@ -0,0 +1,251 @@ +//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H +#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace vector { + +/// Initialize `typeConverter` with source and target materializations that +/// use shape_cast for converting to and from 1D (linearized) vectors. +void initializeForVectorLinearize(TypeConverter &typeConverter); + +/// Initialize `conversionTarget` and `patterns` for linearization. Here +/// linearization means converting a single operation with 1+ vector +/// operand/result of rank>1, into a new single operation whose vector operands +/// and results are all rank<=1. +/// +/// This function initializes `conversionTarget` with a definition of which +/// operations are illegal and consequently must be converted to a linearized +/// (legal) form. It also populates `patterns` with patterns that will be run to +/// convert illegal operations, and sets the priority/benefit patterns have. +/// +/// Note: the set of legal operations can be extended by a user by adding +/// additional legality rules to `conversionTarget`. 
+/// +/// Further note: the choice to use a dialect conversion design for +/// linearization is to enable reusing generic structural type conversions for +/// linearizing scf/cf/func operations. +void populateForFullVectorLinearize(const TypeConverter &, + ConversionTarget &conversionTarget, + RewritePatternSet &patterns); + +/// The set of patterns available for linearization. +enum class LinearizePattern { + + /// This pattern converts a constant (or poison) vector of rank>1 into a + /// 1D vector, followed by a shape_cast. + /// + /// BEFORE + /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> + /// + /// AFTER + /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> + LinearizeConstantLike = 0, + + /// BEFORE + /// %2 = math.sin %arg0 : vector<2x2xf32> + /// + /// AFTER + /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32> + /// %1 = math.sin %0 : vector<4xf32> + /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32> + LinearizeVectorizable, + + /// BEFORE + /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> + /// + /// AFTER + /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32> + /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16> + /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16> + LinearizeVectorBitCast, + + /// BEFORE + /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1> + /// + /// AFTER + /// [...] + /// %mask_1d = vector.create_mask %mul : vector<4xi1> + /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1> + /// + /// where `%mul` is a function of `%arg0` and `%arg1`. + /// + /// This pattern currently only supports 2D masks with a unit outer + /// dimension. + LinearizeVectorCreateMask, + + /// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to + /// one that works on linearized vectors. 
+ /// + /// BEFORE + /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ] + /// + /// AFTER + /// %v1_1d = vector.shape_cast %v1_3d : [...] + /// %v2_1d = vector.shape_cast %v2_3d : [...] + /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] + /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...] + /// + /// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`. + LinearizeVectorShuffle, + + /// BEFORE + /// %1 = vector.splat %value : vector<4x4xf32> + /// + /// AFTER + /// %0 = vector.splat %value : vector<16xf32> + /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32> + LinearizeVectorSplat, + + /// This pattern converts a vector.extract_strided_slice operation into a + /// vector.shuffle operation that has rank-1 (linearized) operand and + /// result. + /// + /// BEFORE + /// %out_nd = vector.extract_strided_slice %source_nd + /// { offsets = [..], strides = [..], sizes = [..] } + /// + /// AFTER + /// %source_1d = vector.shape_cast %source_nd [...] + /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] + /// %out_nd = vector.shape_cast %out_1d [...] + /// + /// `shuffle_indices_1d` is computed using the offsets and sizes of the + /// original vector.extract_strided_slice operation. + VectorExtractStridedSliceToRankOneShuffle, + + /// BEFORE + /// %out_nd = vector.extract %src [ position ] + /// + /// AFTER + /// %src_1d = vector.shape_cast %src : [...] + /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ] + /// %out_nd = vector.shape_cast %out_1d : [...] + /// + /// `shuffle_indices` is computed from `position` of the original extract. + VectorExtractToRankOneShuffle, + + /// This pattern converts a vector.insert_strided_slice operation into a + /// vector.shuffle operation that has rank-1 (linearized) operands and result. 
+ /// + /// BEFORE + /// %out_nd = vector.insert_strided_slice %to_store, %into + /// {offsets = [1, 0, 0, 0], strides = [1, 1]} + /// : vector<2x2xi8> into vector<2x1x3x2xi8> + /// AFTER + /// %to_store_1d + /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> + /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> + /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] + /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> + /// + /// where shuffle_indices_1d in this case is + /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. + /// ^^^^^^^^^^^^^^ + /// to_store_1d + VectorInsertStridedSliceToRankOneShuffle, + + /// BEFORE + /// %out_nd = vector.insert %src, %dst [ position ] + /// + /// AFTER + /// %src_1d = vector.shape_cast %src : [...] + /// %dst_1d = vector.shape_cast %dst : [...] + /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ] + /// %out_nd = vector.shape_cast %out_1d : [...] + /// + /// `shuffle_indices` is computed from `position`. + VectorInsertToRankOneShuffle, + + /// The number of patterns in this enum. + N +}; + +/// This class contains functions to control the set of linearization patterns +/// to include for the conversion, and their priority. +struct VectorLinearizePatterns { + +public: + /// By default all patterns are enabled and have benefit 1. + VectorLinearizePatterns() { + enabled.fill(true); + benefits.fill(PatternBenefit(1)); + } + + /// Add the patterns enabled for the conversion to `patterns`. 
+ void addToPatternSet(const TypeConverter &, + RewritePatternSet &patterns) const; + + VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) { + enabled[static_cast(id)] = e; + return *this; + } + + VectorLinearizePatterns &enableAll(bool e = true) { + enabled.fill(e); + return *this; + } + + bool isEnabled(LinearizePattern id) const { + return enabled[static_cast(id)]; + } + + PatternBenefit getBenefit(LinearizePattern id) const { + return benefits[static_cast(id)]; + } + + VectorLinearizePatterns &setBenefit(LinearizePattern id, + PatternBenefit benefit) { + getBenefitRef(id) = benefit; + return *this; + } + + VectorLinearizePatterns &incrementBenefit(LinearizePattern id, + unsigned inc = 1) { + getBenefitRef(id) = getBenefit(id).getBenefit() + inc; + return *this; + } + +private: + std::array(LinearizePattern::N)> enabled; + std::array(LinearizePattern::N)> + benefits; + + PatternBenefit &getBenefitRef(LinearizePattern id) { + unsigned idInt = static_cast(id); + assert(idInt < static_cast(LinearizePattern::N) && + "invalid linearization pattern id"); + return benefits[idInt]; + } +}; + +/// Consider inserting a vector of shape `small` into a vector of shape `large`, +/// at position `offsets`: this function enumerates all the indices in `large` +/// that are written to. The enumeration is with row-major ordering. +/// +/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 +/// positions written to are (1,3) and (1,4), which have linearized indices 8 +/// and 9. So [8,9] is returned. +/// +/// The length of the returned vector is equal to the number of elements in +/// the shape `small` (i.e. the product of dimensions of `small`). 
+SmallVector getStridedSliceInsertionIndices(ArrayRef small, + ArrayRef large, + ArrayRef offsets); + +} // namespace vector +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 34a94e6ea7051..6954cb7172129 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Initialize `typeConverter` and `conversionTarget` for vector linearization. -/// -/// Definition: here 'linearization' means converting a single operation with -/// 1+ vector operand/result of rank>1, into a new single operation whose -/// vector operands and results are all of rank<=1. -/// -/// This function registers (1) which operations are legal, and hence should not -/// be linearized, (2) what the converted types are (rank-1 vectors) and how to -/// materialze the conversion (with shape_cast) -/// -/// Note: the set of legal operations can be extended by a user if for example -/// certain rank>1 vectors are considered valid, by adding additional -/// dynamically legal ops to `conversionTarget`. -/// -/// Further note: the choice to use a dialect conversion design for -/// linearization is to make it easy to reuse generic structural type -/// conversions for linearizing scf/cf/func operations -void populateForVectorLinearize(TypeConverter &typeConverter, - ConversionTarget &conversionTarget); - -/// Populates `patterns` for ND vector (N >= 2) linearization. This currently -/// contains patterns for converting ConstantLike, Vectorizable, and -/// vector::BitCast ops. 
-void populateVectorLinearizeBasePatterns(const TypeConverter &, - const ConversionTarget &, - RewritePatternSet &patterns); - -/// Populates `patterns` for linearizing ND (N >= 2) vector operations -/// to 1D vector shuffle operations. -void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &, - const ConversionTarget &, - RewritePatternSet &patterns); - } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 678a88627ca82..0c11c9b5c8740 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -10,9 +10,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Operation.h" @@ -47,12 +48,21 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, namespace { +/// This pattern converts a constant (or poison) vector of rank>1 into a +/// 1D vector, followed by a shape_cast. 
+/// +/// BEFORE +/// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> +/// +/// AFTER +/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> +/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> struct LinearizeConstantLike final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; LinearizeConstantLike(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) + MLIRContext *context, PatternBenefit benefit) : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -88,13 +98,20 @@ struct LinearizeConstantLike final } }; +/// BEFORE +/// %2 = math.sin %arg0 : vector<2x2xf32> +/// +/// AFTER +/// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32> +/// %1 = math.sin %0 : vector<4xf32> +/// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32> struct LinearizeVectorizable final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; public: LinearizeVectorizable(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) + MLIRContext *context, PatternBenefit benefit) : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -109,17 +126,178 @@ struct LinearizeVectorizable final } }; -template -static bool stridesAllOne(TOp op) { - static_assert( - std::is_same_v || - std::is_same_v, - "expected vector.extract_strided_slice or vector.insert_strided_slice"); - ArrayAttr strides = op.getStrides(); - return llvm::all_of(strides, isOneInteger); -} +/// BEFORE +/// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> +/// +/// AFTER +/// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32> +/// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16> +/// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16> 
+struct LinearizeVectorBitCast final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorBitCast(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit) + : OpConversionPattern(typeConverter, context, benefit) {} + LogicalResult + matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resType = getTypeConverter()->convertType(castOp.getType()); + assert(resType && "expected 1-D vector type"); + rewriter.replaceOpWithNewOp(castOp, resType, + adaptor.getSource()); + return success(); + } +}; + +/// This pattern converts the vector.create_mask to work on a linearized vector. +/// It currently supports only 2D masks with a unit outer dimension. +/// +/// BEFORE +/// vector.create_mask %arg0, %arg1 : vector<1x4xi1> +/// +/// AFTER +/// %zero = arith.constant 0 : index +/// %cmpi = arith.cmpi sgt, %arg0, %zero : index +/// %index = arith.index_cast %cmpi : i1 to index +/// %mul = arith.andi %index, %arg1 : index +/// %mask = vector.create_mask %mul : vector<4xi1> +/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> +struct LinearizeVectorCreateMask final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorCreateMask(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = createMaskOp.getLoc(); + VectorType srcTy = createMaskOp.getType(); + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2) + return rewriter.notifyMatchFailure(createMaskOp, + "only 2D mask is supported."); + + if (srcShape[0] != 1) + return rewriter.notifyMatchFailure( + createMaskOp, "only unit outer dimension is supported."); + + auto 
dstTy = getTypeConverter()->convertType(srcTy); + if (!dstTy) + return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); + + // Compare the first operand with 0. If it is greater than 0, the + // corresponding mask element is set to true, otherwise false. + // The result of the comparison is then multiplied with + // the second operand of create_mask to get the 1D mask. + auto firstOperand = adaptor.getOperands().front(); + auto zero = rewriter.create(loc, 0); + auto isNonZero = rewriter.createOrFold( + loc, arith::CmpIPredicate::sgt, firstOperand, zero); + auto isNonZeroIndex = rewriter.createOrFold( + loc, rewriter.getIndexType(), isNonZero); + auto secondOperand = adaptor.getOperands().back(); + auto maskSize = rewriter.createOrFold( + loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); + + auto newMask = rewriter.create(loc, dstTy, maskSize); + rewriter.replaceOp(createMaskOp, newMask); + return success(); + } +}; + +/// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to +/// a one that works on linearized vectors. +/// +/// BEFORE +/// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ] +/// +/// AFTER +/// %v1_1d = vector.shape_cast %v1_3d : [...] +/// %v2_1d = vector.shape_cast %v2_3d : [...] +/// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] +/// %shuffle_3d = vector.shape_cast %shuffle_1d : [...] +/// +/// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`. 
+struct LinearizeVectorShuffle final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorShuffle(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType dstType = + getTypeConverter()->convertType(shuffleOp.getType()); + assert(dstType && "vector type destination expected."); + + Value vec1 = adaptor.getV1(); + Value vec2 = adaptor.getV2(); + int shuffleSliceLen = 1; + int rank = shuffleOp.getV1().getType().getRank(); -/// Convert an array of attributes into a vector of integers, if possible. + // If rank > 1, we need to do the shuffle in the granularity of slices + // instead of scalars. Size of the slice is equal to the rank-1 innermost + // dims. Mask of the shuffle op specifies which slice to take from the + // outermost dim. + if (rank > 1) { + llvm::ArrayRef shape = shuffleOp.getV1().getType().getShape(); + for (unsigned i = 1; i < shape.size(); ++i) { + shuffleSliceLen *= shape[i]; + } + } + + // For each value in the mask, we generate the indices of the source vectors + // that need to be shuffled to the destination vector. If shuffleSliceLen > + // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of + // elements) instead of scalars. 
+ ArrayRef mask = shuffleOp.getMask(); + int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; + llvm::SmallVector indices(totalSizeOfShuffledElmnts); + for (auto [i, value] : llvm::enumerate(mask)) { + std::iota(indices.begin() + shuffleSliceLen * i, + indices.begin() + shuffleSliceLen * (i + 1), + shuffleSliceLen * value); + } + + rewriter.replaceOpWithNewOp(shuffleOp, dstType, vec1, + vec2, indices); + return success(); + } +}; + +/// BEFORE +/// %1 = vector.splat %value : vector<4x4xf32> +/// +/// AFTER +/// %0 = vector.splat %value : vector<16xf32> +/// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32> +struct LinearizeVectorSplat final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = getTypeConverter()->convertType(splatOp.getType()); + if (!dstTy) + return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); + rewriter.replaceOpWithNewOp(splatOp, adaptor.getInput(), + dstTy); + return success(); + } +}; + +/// Convert an array of attributes into a vector of integers. static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { if (!attrs) return failure(); @@ -135,89 +313,27 @@ static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { return ints; } -/// Consider inserting a vector of shape `small` into a vector of shape `large`, -/// at position `offsets`: this function enumeratates all the indices in `large` -/// that are written to. The enumeration is with row-major ordering. -/// -/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 -/// positions written to are (1,3) and (1,4), which have linearized indices 8 -/// and 9. So [8,9] is returned. 
-/// -/// The length of the returned vector is equal to the number of elements in -/// the shape `small` (i.e. the product of dimensions of `small`). -SmallVector static getStridedSliceInsertionIndices( - ArrayRef small, ArrayRef large, - ArrayRef offsets) { - - // Example of alignment between, `large`, `small` and `offsets`: - // large = 4, 5, 6, 7, 8 - // small = 1, 6, 7, 8 - // offsets = 2, 3, 0 - // - // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. - assert((large.size() >= small.size()) && - "rank of 'large' cannot be lower than rank of 'small'"); - assert((large.size() >= offsets.size()) && - "rank of 'large' cannot be lower than the number of offsets"); - unsigned delta = large.size() - small.size(); - unsigned nOffsets = offsets.size(); - auto getSmall = [&](int64_t i) -> int64_t { - return i >= delta ? small[i - delta] : 1; - }; - auto getOffset = [&](int64_t i) -> int64_t { - return i < nOffsets ? offsets[i] : 0; - }; - - // Using 2 vectors of indices, at each iteration populate the updated set of - // indices based on the old set of indices, and the size of the small vector - // in the current iteration. - SmallVector indices{0}; - int64_t stride = 1; - for (int i = large.size() - 1; i >= 0; --i) { - int64_t currentSize = indices.size(); - int64_t smallSize = getSmall(i); - int64_t nextSize = currentSize * smallSize; - SmallVector nextIndices(nextSize); - int64_t *base = nextIndices.begin(); - int64_t offset = getOffset(i) * stride; - for (int j = 0; j < smallSize; ++j) { - for (int k = 0; k < currentSize; ++k) { - base[k] = indices[k] + offset; - } - offset += stride; - base += currentSize; - } - stride *= large[i]; - indices = std::move(nextIndices); - } - return indices; -} - /// This pattern converts a vector.extract_strided_slice operation into a -/// vector.shuffle operation that has a rank-1 (linearized) operand and result. 
-/// -/// For example, the following: +/// vector.shuffle operation that has rank-1 (linearized) operand and +/// result. /// -/// ``` -/// vector.extract_strided_slice %source +/// BEFORE +/// %out_nd = vector.extract_strided_slice %source_nd /// { offsets = [..], strides = [..], sizes = [..] } -/// ``` /// -/// is converted to : -/// ``` -/// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// ``` +/// AFTER +/// %source_1d = vector.shape_cast %source_nd [...] +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d [...] /// -/// `shuffle_indices_1d` is computed using the offsets and sizes of the original -/// vector.extract_strided_slice operation. -struct LinearizeVectorExtractStridedSlice final - : public mlir::OpConversionPattern { +/// `shuffle_indices_1d` is computed using the offsets and sizes of the +/// original vector.extract_strided_slice operation. +struct VectorExtractStridedSliceToRankOneShuffle final + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, - MLIRContext *context, - PatternBenefit benefit = 1) + VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult @@ -231,7 +347,7 @@ struct LinearizeVectorExtractStridedSlice final // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for extract_strided_slice allows non-1 strides). 
- if (!stridesAllOne(extractStridedSliceOp)) { + if (extractStridedSliceOp.hasNonUnitStrides()) { return rewriter.notifyMatchFailure( extractStridedSliceOp, "extract_strided_slice with strides != 1 not supported"); @@ -249,7 +365,7 @@ struct LinearizeVectorExtractStridedSlice final ArrayRef outputShape = extractStridedSliceOp.getType().getShape(); - SmallVector indices = getStridedSliceInsertionIndices( + SmallVector indices = vector::getStridedSliceInsertionIndices( outputShape, inputShape, offsets.value()); Value srcVector = adaptor.getVector(); @@ -259,36 +375,81 @@ struct LinearizeVectorExtractStridedSlice final } }; +/// BEFORE +/// %extract = vector.extract %src [ position ] +/// +/// AFTER +/// %src_1d = vector.shape_cast %src : [...] +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices ] +/// %out_nd = vector.shape_cast %out_1d : [...] +/// +/// `shuffle_indices` is computed from `position` of original extract. +struct VectorExtractToRankOneShuffle final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + VectorExtractToRankOneShuffle(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit) + : OpConversionPattern(typeConverter, context, benefit) {} + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Skip if result is not a vector type + if (!isa(extractOp.getType())) + return rewriter.notifyMatchFailure(extractOp, + "scalar extract not supported"); + Type dstTy = getTypeConverter()->convertType(extractOp.getType()); + assert(dstTy && "expected 1-D vector type"); + + // Dynamic position is not supported. 
+ if (extractOp.hasDynamicPosition()) + return rewriter.notifyMatchFailure(extractOp, + "dynamic position is not supported."); + + llvm::ArrayRef shape = extractOp.getVector().getType().getShape(); + int64_t size = extractOp.getVector().getType().getNumElements(); + + // Compute linearized offset. + int64_t linearizedOffset = 0; + llvm::ArrayRef offsets = extractOp.getStaticPosition(); + for (auto [i, off] : llvm::enumerate(offsets)) { + size /= shape[i]; + linearizedOffset += offsets[i] * size; + } + + Value v0 = adaptor.getVector(); + llvm::SmallVector indices(size); + std::iota(indices.begin(), indices.end(), linearizedOffset); + rewriter.replaceOpWithNewOp(extractOp, dstTy, v0, v0, + indices); + + return success(); + } +}; + /// This pattern converts a vector.insert_strided_slice operation into a /// vector.shuffle operation that has rank-1 (linearized) operands and result. /// -/// For example, the following: -/// ``` -/// %0 = vector.insert_strided_slice %to_store, %into +/// BEFORE +/// %0 = vector.insert_strided_slice %to_store, %into /// {offsets = [1, 0, 0, 0], strides = [1, 1]} /// : vector<2x2xi8> into vector<2x1x3x2xi8> -/// ``` -/// -/// is converted to -/// ``` -/// %to_store_1d -/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> -/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> -/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> -/// ``` +/// AFTER +/// %to_store_1d +/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> +/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> +/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> /// /// where shuffle_indices_1d in this case is /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. 
/// ^^^^^^^^^^^^^^ /// to_store_1d -/// -struct LinearizeVectorInsertStridedSlice final - : public mlir::OpConversionPattern { +struct VectorInsertStridedSliceToRankOneShuffle final + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, - MLIRContext *context, - PatternBenefit benefit = 1) + VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult @@ -298,7 +459,7 @@ struct LinearizeVectorInsertStridedSlice final // Expect a legalization failure if the strides are not all 1 (if ever the // verifier for insert_strided_slice allows non-1 strides). - if (!stridesAllOne(insertStridedSliceOp)) { + if (insertStridedSliceOp.hasNonUnitStrides()) { return rewriter.notifyMatchFailure( insertStridedSliceOp, "insert_strided_slice with strides != 1 not supported"); @@ -317,7 +478,7 @@ struct LinearizeVectorInsertStridedSlice final return rewriter.notifyMatchFailure(insertStridedSliceOp, "failed to get integer offsets"); } - SmallVector sliceIndices = getStridedSliceInsertionIndices( + SmallVector sliceIndices = vector::getStridedSliceInsertionIndices( inputShape, outputShape, offsets.value()); SmallVector indices(nOutputElements); @@ -335,131 +496,22 @@ struct LinearizeVectorInsertStridedSlice final } }; -/// This pattern converts the ShuffleOp that works on nD (n > 1) -/// vectors to a ShuffleOp that works on linearized vectors. -/// Following, -/// vector.shuffle %v1, %v2 [ shuffle_indices ] -/// is converted to : -/// %v1_1d = vector.shape_cast %v1 -/// %v2_1d = vector.shape_cast %v2 -/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` -/// of the original shuffle operation. 
-struct LinearizeVectorShuffle final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorShuffle(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - VectorType dstType = - getTypeConverter()->convertType(shuffleOp.getType()); - assert(dstType && "vector type destination expected."); - - Value vec1 = adaptor.getV1(); - Value vec2 = adaptor.getV2(); - int shuffleSliceLen = 1; - int rank = shuffleOp.getV1().getType().getRank(); - - // If rank > 1, we need to do the shuffle in the granularity of slices - // instead of scalars. Size of the slice is equal to the rank-1 innermost - // dims. Mask of the shuffle op specifies which slice to take from the - // outermost dim. - if (rank > 1) { - llvm::ArrayRef shape = shuffleOp.getV1().getType().getShape(); - for (unsigned i = 1; i < shape.size(); ++i) { - shuffleSliceLen *= shape[i]; - } - } - - // For each value in the mask, we generate the indices of the source vectors - // that need to be shuffled to the destination vector. If shuffleSliceLen > - // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of - // elements) instead of scalars. - ArrayRef mask = shuffleOp.getMask(); - int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; - llvm::SmallVector indices(totalSizeOfShuffledElmnts); - for (auto [i, value] : llvm::enumerate(mask)) { - std::iota(indices.begin() + shuffleSliceLen * i, - indices.begin() + shuffleSliceLen * (i + 1), - shuffleSliceLen * value); - } - - rewriter.replaceOpWithNewOp(shuffleOp, dstType, vec1, - vec2, indices); - return success(); - } -}; - -/// This pattern converts the ExtractOp to a ShuffleOp that works on a -/// linearized vector. 
-/// Following, -/// vector.extract %source [ position ] -/// is converted to : -/// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// `shuffle_indices_1d` is computed using the position of the original extract. -struct LinearizeVectorExtract final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtract(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Skip if result is not a vector type - if (!isa(extractOp.getType())) - return rewriter.notifyMatchFailure(extractOp, - "scalar extract not supported"); - Type dstTy = getTypeConverter()->convertType(extractOp.getType()); - assert(dstTy && "expected 1-D vector type"); - - // Dynamic position is not supported. - if (extractOp.hasDynamicPosition()) - return rewriter.notifyMatchFailure(extractOp, - "dynamic position is not supported."); - - llvm::ArrayRef shape = extractOp.getVector().getType().getShape(); - int64_t size = extractOp.getVector().getType().getNumElements(); - - // Compute linearized offset. - int64_t linearizedOffset = 0; - llvm::ArrayRef offsets = extractOp.getStaticPosition(); - for (auto [i, off] : llvm::enumerate(offsets)) { - size /= shape[i]; - linearizedOffset += offsets[i] * size; - } - - llvm::SmallVector indices(size); - std::iota(indices.begin(), indices.end(), linearizedOffset); - rewriter.replaceOpWithNewOp( - extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); - - return success(); - } -}; +/// BEFORE +/// %insert = vector.insert %src %dst [ position ] +/// +/// AFTER +/// %src_1d = vector.shape_cast %src : [...] +/// %dst_1d = vector.shape_cast %dst : [...] 
+/// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ] +/// %out_nd = vector.shape_cast %out_1d : [...] +/// +/// `shuffle_indices` is computed from `position`. -/// This pattern converts the InsertOp to a ShuffleOp that works on a -/// linearized vector. -/// Following, -/// vector.insert %source %destination [ position ] -/// is converted to : -/// %source_1d = vector.shape_cast %source -/// %destination_1d = vector.shape_cast %destination -/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d -/// ] %out_nd = vector.shape_cast %out_1d -/// `shuffle_indices_1d` is computed using the position of the original insert. -struct LinearizeVectorInsert final +struct VectorInsertToRankOneShuffle final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LinearizeVectorInsert(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) + VectorInsertToRankOneShuffle(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, @@ -514,215 +566,251 @@ struct LinearizeVectorInsert final } }; -/// This pattern converts the BitCastOp that works on nD (n > 1) -/// vectors to a BitCastOp that works on linearized vectors. 
-/// Following, -/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> -/// is converted to : -/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> -/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> -/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> -struct LinearizeVectorBitCast final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorBitCast(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resType = getTypeConverter()->convertType(castOp.getType()); - assert(resType && "expected 1-D vector type"); - rewriter.replaceOpWithNewOp(castOp, resType, - adaptor.getSource()); - return mlir::success(); - } -}; - -/// This pattern converts the SplatOp to work on a linearized vector. 
-/// Following, -/// vector.splat %value : vector<4x4xf32> -/// is converted to: -/// %out_1d = vector.splat %value : vector<16xf32> -/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> -struct LinearizeVectorSplat final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = getTypeConverter()->convertType(splatOp.getType()); - if (!dstTy) - return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); - rewriter.replaceOpWithNewOp(splatOp, adaptor.getInput(), - dstTy); - return success(); - } -}; - -/// This pattern converts the CreateMaskOp to work on a linearized vector. -/// It currently supports only 2D masks with a unit outer dimension. 
-/// Following, -/// vector.create_mask %arg0, %arg1 : vector<1x4xi1> -/// is converted to: -/// %zero = arith.constant 0 : index -/// %cmpi = arith.cmpi sgt, %arg0, %zero : index -/// %index = arith.index_cast %cmpi : i1 to index -/// %mul = arith.andi %index, %arg1 : index -/// %mask = vector.create_mask %mul : vector<4xi1> -/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> -struct LinearizeVectorCreateMask final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LinearizeVectorCreateMask(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - - LogicalResult - matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = createMaskOp.getLoc(); - VectorType srcTy = createMaskOp.getType(); - auto srcShape = srcTy.getShape(); - if (srcShape.size() != 2) - return rewriter.notifyMatchFailure(createMaskOp, - "only 2D mask is supported."); - - if (srcShape[0] != 1) - return rewriter.notifyMatchFailure( - createMaskOp, "only unit outer dimension is supported."); - - auto dstTy = getTypeConverter()->convertType(srcTy); - if (!dstTy) - return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); - - // Compare the first operand with 0. If it is greater than 0, the - // corresponding mask element is set to true, otherwise false. - // The result of the comparison is then multiplied with - // the second operand of create_mask to get the 1D mask. 
- auto firstOperand = adaptor.getOperands().front(); - auto zero = rewriter.create(loc, 0); - auto isNonZero = rewriter.createOrFold( - loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); - auto isNonZeroIndex = rewriter.createOrFold( - loc, rewriter.getIndexType(), isNonZero); - auto secondOperand = adaptor.getOperands().back(); - auto maskSize = rewriter.createOrFold( - loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); - - auto newMask = - rewriter.create(loc, dstTy, maskSize); - rewriter.replaceOp(createMaskOp, newMask); - return success(); - } -}; - } // namespace -/// This method defines the set of operations that are linearizable, and hence -/// that are considered illegal for the conversion target. -static bool isLinearizable(Operation *op) { - - // Only ops that are in the vector dialect, are ConstantLike, or - // are Vectorizable might be linearized currently. - StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace(); - StringRef opDialect = op->getDialect()->getNamespace(); - bool supported = (opDialect == vectorDialect) || - op->hasTrait() || - op->hasTrait(); - if (!supported) - return false; - +/// Return true if `op` is an insert, extract, insert_strided_slice, or +/// extract_strided_slice operation that operates on scalable vectors. +/// Otherwise return false. +static bool isScalableExtractOrInsertOrStrided(Operation *op) { return TypeSwitch(op) - // As type legalization is done with vector.shape_cast, shape_cast - // itself cannot be linearized (will create new shape_casts to linearize - // ad infinitum). - .Case([&](auto) { return false; }) - // The operations - // - vector.extract_strided_slice - // - vector.extract - // - vector.insert_strided_slice - // - vector.insert - // are linearized to a rank-1 vector.shuffle by the current patterns. - // vector.shuffle only supports fixed size vectors, so it is impossible to - // use this approach to linearize these ops if they operate on scalable - // vectors. 
.Case( [&](vector::ExtractStridedSliceOp extractOp) { - return !extractOp.getType().isScalable(); + return extractOp.getType().isScalable(); }) .Case( [&](vector::InsertStridedSliceOp insertOp) { - return !insertOp.getType().isScalable(); + return insertOp.getType().isScalable(); }) .Case([&](vector::InsertOp insertOp) { - return !insertOp.getType().isScalable(); + return insertOp.getType().isScalable(); }) .Case([&](vector::ExtractOp extractOp) { - return !extractOp.getSourceVectorType().isScalable(); + return extractOp.getSourceVectorType().isScalable(); }) - .Default([&](auto) { return true; }); + .Default([&](auto) { return false; }); } -void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, - ConversionTarget &target) { +SmallVector +vector::getStridedSliceInsertionIndices(ArrayRef small, + ArrayRef large, + ArrayRef offsets) { + + // Example of alignment between, `large`, `small` and `offsets`: + // large = 4, 5, 6, 7, 8 + // small = 1, 6, 7, 8 + // offsets = 2, 3, 0 + // + // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. + assert((large.size() >= small.size()) && + "rank of 'large' cannot be lower than rank of 'small'"); + assert((large.size() >= offsets.size()) && + "rank of 'large' cannot be lower than the number of offsets"); + unsigned delta = large.size() - small.size(); + unsigned nOffsets = offsets.size(); + auto getSmall = [&](int64_t i) -> int64_t { + return i >= delta ? small[i - delta] : 1; + }; + auto getOffset = [&](int64_t i) -> int64_t { + return i < nOffsets ? offsets[i] : 0; + }; + + // Using 2 vectors of indices, at each iteration populate the updated set of + // indices based on the old set of indices, and the size of the small vector + // in the current iteration. 
+ SmallVector indices{0}; + int64_t stride = 1; + for (int i = large.size() - 1; i >= 0; --i) { + int64_t currentSize = indices.size(); + int64_t smallSize = getSmall(i); + int64_t nextSize = currentSize * smallSize; + SmallVector nextIndices(nextSize); + int64_t *base = nextIndices.begin(); + int64_t offset = getOffset(i) * stride; + for (int j = 0; j < smallSize; ++j) { + for (int k = 0; k < currentSize; ++k) { + base[k] = indices[k] + offset; + } + offset += stride; + base += currentSize; + } + stride *= large[i]; + indices = std::move(nextIndices); + } + return indices; +} + +void vector::initializeForVectorLinearize(TypeConverter &typeConverter) { auto convertType = [](Type type) -> std::optional { VectorType vectorType = dyn_cast(type); - if (!vectorType || !isLinearizableVector(vectorType)) + + if (!vectorType || !vector::isLinearizableVector(vectorType)) return type; VectorType linearizedType = VectorType::get(vectorType.getNumElements(), vectorType.getElementType(), vectorType.isScalable()); + return linearizedType; }; typeConverter.addConversion(convertType); auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) + if (inputs.size() != 1) { return nullptr; - + } Value value = inputs.front(); - if (!isa(type) || !isa(value.getType())) + if (!isa(type) || !isa(value.getType())) { return nullptr; - + } return builder.create(loc, type, value); }; typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); +} + +void vector::populateForFullVectorLinearize(const TypeConverter &typeConverter, + ConversionTarget &target, + RewritePatternSet &patterns) { target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { - if (!isLinearizable(op)) + // Only ops that are in the vector dialect, are ConstantLike, or + // are Vectorizable might be linearized currently. 
+ StringLiteral vectorDialect = + vector::VectorDialect::getDialectNamespace(); + StringRef opDialect = op->getDialect()->getNamespace(); + bool supported = (opDialect == vectorDialect) || + op->hasTrait() || + op->hasTrait(); + if (!supported) + return true; + + // As type legalization is done with vector.shape_cast, shape_cast + // itself cannot be linearized (doing so would create new shape_casts to + // linearize ad infinitum). + if (isa(op)) + return true; + + // The operations extract_strided_slice, extract, insert_strided_slice, + // and insert are linearized to rank-1 operations that do not fully + // support scalable vectors, so it is not generally possible to + // linearize these ops if they operate on scalable vectors. + if (isScalableExtractOrInsertOrStrided(op)) return true; + // This will return true if, for all operand and result types `t`, // convertType(t) = t. This is true if there are no rank>=2 vectors. return typeConverter.isLegal(op); }); -} -void mlir::vector::populateVectorLinearizeBasePatterns( - const TypeConverter &typeConverter, const ConversionTarget &target, - RewritePatternSet &patterns) { - patterns - .add( - typeConverter, patterns.getContext()); + VectorLinearizePatterns linearizePatterns; + + // Mark extract_strided_slice, insert_strided_slice, extract with source + // rank > 1, and insert with result rank > 1 as illegal, as they must be + // converted to shuffle or rank-1 extract/insert. + // + // Note that the order of the calls to `markUnknownOpDynamicallyLegal` + // is important: the legality rule added here takes precedence over the + // generic one preceding it which marked these ops as legal. 
+ target.markUnknownOpDynamicallyLegal( + [](Operation *op) -> std::optional { + bool isStrided = + isa( + op); + + bool isHighRankExtractOrInsert = [&]() { + if (auto extractOp = dyn_cast(op)) { + return extractOp.getSourceVectorType().getRank() > 1; + } + if (auto insertOp = dyn_cast(op)) { + return insertOp.getType().getRank() > 1; + } + return false; + }(); + + bool isScalable = isScalableExtractOrInsertOrStrided(op); + + if ((isStrided || isHighRankExtractOrInsert) && !isScalable) { + return false; + } + return std::nullopt; + }); + + // Ensure that the benefit of patterns targeting shuffle is higher than + // the benefit of patterns targeting rank-1 strided slice operations. This + // will ensure that patterns for converting to rank-1 shuffle are run first. + linearizePatterns + .incrementBenefit( + LinearizePattern::VectorExtractStridedSliceToRankOneShuffle) + .incrementBenefit( + LinearizePattern::VectorInsertStridedSliceToRankOneShuffle) + .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle) + .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle); + + linearizePatterns.addToPatternSet(typeConverter, patterns); } -void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( - const TypeConverter &typeConverter, const ConversionTarget &target, - RewritePatternSet &patterns) { - patterns.add(typeConverter, - patterns.getContext()); +void vector::VectorLinearizePatterns::addToPatternSet( + const TypeConverter &typeConverter, RewritePatternSet &patterns) const { + + MLIRContext *context = patterns.getContext(); + + if (isEnabled(LinearizePattern::LinearizeConstantLike)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::LinearizeConstantLike)); + + if (isEnabled(LinearizePattern::LinearizeVectorizable)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::LinearizeVectorizable)); + + if (isEnabled(LinearizePattern::LinearizeVectorBitCast)) + patterns.add( + typeConverter, context, + 
getBenefit(LinearizePattern::LinearizeVectorBitCast)); + + if (isEnabled(LinearizePattern::LinearizeVectorCreateMask)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::LinearizeVectorCreateMask)); + + if (isEnabled(LinearizePattern::LinearizeVectorShuffle)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::LinearizeVectorShuffle)); + + if (isEnabled(LinearizePattern::LinearizeVectorSplat)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::LinearizeVectorSplat)); + + // ------------------------ // + // Extract related patterns // + // ------------------------ // + if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)) + patterns.add( + typeConverter, context, + getBenefit( + LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)); + + if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::VectorExtractToRankOneShuffle)); + + // ------------------------ // + // Insert related patterns // + // ------------------------ // + if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)); + + if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle)) + patterns.add( + typeConverter, context, + getBenefit(LinearizePattern::VectorInsertToRankOneShuffle)); } diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir similarity index 100% rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir similarity index 100% rename from mlir/test/Dialect/Vector/linearize.mlir rename to 
mlir/test/Dialect/Vector/linearize/linearize.mlir diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt index e16937029ac0e..1ce069599af43 100644 --- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRVectorTestPasses TestVectorTransforms.cpp + TestVectorLinearize.cpp EXCLUDE_FROM_LIBMLIR ) diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp new file mode 100644 index 0000000000000..67179c9f98e9b --- /dev/null +++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp @@ -0,0 +1,185 @@ +//===- TestVectorLinearize.cpp - Test Vector linearization ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math//IR/Math.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" +#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +struct TestVectorLinearize final + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + + TestVectorLinearize() = default; + + StringRef 
getArgument() const override { return "test-vector-linearize"; } + StringRef getDescription() const override { + return "Linearizes ND vectors for N >= 2 into 1D vectors"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + initializeForVectorLinearize(converter); + populateForFullVectorLinearize(converter, target, patterns); + + mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + converter, patterns, target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +struct TestVectorBitWidthLinearize final + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize) + + TestVectorBitWidthLinearize() = default; + TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const override { + return "test-bit-width-constrained-vector-linearize"; + } + StringRef getDescription() const override { + return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints " + "on inner-most dimension's bit width. 
If the inner-most dimension " + "exceded a threshold, the op is not linearized."; + } + Option targetVectorBitwidth{ + *this, "target-vector-bitwidth", + llvm::cl::desc( + "Minimum vector bitwidth to enable the flattening transformation"), + llvm::cl::init(std::numeric_limits::max())}; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter typeConverter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + populateWithBitWidthConstraints(typeConverter, target, patterns, + targetVectorBitwidth); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } + +private: + /// If `type` is VectorType with trailing dimension of (bit) size greater than + /// or equal to `targetBitWidth`, its defining op is considered legal. + static bool + isNotLinearizableBecauseLargeInnerDimension(Type type, + unsigned targetBitWidth) { + + VectorType vecType = dyn_cast(type); + + // Not linearizable for reasons other than what this function checks. + if (!vecType || vecType.getRank() == 0) + return false; + + // The width of the type 'index' is unbounded (and therefore potentially + // above the target width). + if (vecType.getElementType().isIndex()) + return true; + + unsigned finalDimSize = vecType.getShape().back(); + unsigned nbBitsPerElm = vecType.getElementTypeBitWidth(); + unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm; + return trailingVecDimBitWidth >= targetBitWidth; + } + + static bool + isNotLinearizableBecauseLargeInnerDimension(Operation *op, + unsigned targetBitWidth) { + // Check on bitwidths. 
+ SmallVector> toCheck = + getTypeBitWidthBoundPairs(op, targetBitWidth); + return std::any_of(toCheck.begin(), toCheck.end(), + [&](std::pair typeWidth) { + return isNotLinearizableBecauseLargeInnerDimension( + typeWidth.first, typeWidth.second); + }); + } + + static void populateWithBitWidthConstraints(TypeConverter &typeConverter, + ConversionTarget &target, + RewritePatternSet &patterns, + unsigned targetBitWidth) { + + initializeForVectorLinearize(typeConverter); + populateForFullVectorLinearize(typeConverter, target, patterns); + + // Extend the set of legal ops to include those with large inner-most + // dimensions on selected operands/results. + target.markUnknownOpDynamicallyLegal( + [=](Operation *op) -> std::optional { + if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) { + return true; + } + return {}; + }); + } + + /// Get the set of operand/result types to check for sufficiently + /// small inner-most dimension size. + static SmallVector> + getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) { + + if (auto insertOp = dyn_cast(op)) { + unsigned w = targetBitWidth < std::numeric_limits::max() + ? 
targetBitWidth + 1 + : targetBitWidth; + return {{insertOp.getValueToStoreType(), w}}; + } + + auto resultTypes = op->getResultTypes(); + SmallVector> resultsWithBitWidth; + resultsWithBitWidth.reserve(resultTypes.size()); + for (Type type : resultTypes) { + resultsWithBitWidth.push_back({type, targetBitWidth}); + } + return resultsWithBitWidth; + } +}; + +} // namespace + +namespace mlir { +namespace test { +extern void registerTestVectorLinearize() { + PassRegistration(); + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f4f32e9339870..5c75d32c22236 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final } }; -/// Get the set of operand/result types to check for sufficiently -/// small inner-most dimension size. -static SmallVector> -getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) { - - if (auto insertOp = dyn_cast(op)) { - unsigned w = targetBitWidth < std::numeric_limits::max() - ? 
targetBitWidth + 1 - : targetBitWidth; - return {{insertOp.getValueToStoreType(), w}}; - } - - auto resultTypes = op->getResultTypes(); - SmallVector> resultsWithBitWidth; - resultsWithBitWidth.reserve(resultTypes.size()); - for (Type type : resultTypes) { - resultsWithBitWidth.push_back({type, targetBitWidth}); - } - return resultsWithBitWidth; -} - -/// If `type` is VectorType with trailing dimension of (bit) size greater than -/// or equal to `targetBitWidth`, its defining op is considered legal. -static bool -isNotLinearizableBecauseLargeInnerDimension(Type type, - unsigned targetBitWidth) { - - VectorType vecType = dyn_cast(type); - - // Not linearizable for reasons other than what this function checks. - if (!vecType || vecType.getRank() == 0) - return false; - - // The width of the type 'index' is unbounded (and therefore potentially above - // the target width). - if (vecType.getElementType().isIndex()) - return true; - - unsigned finalDimSize = vecType.getShape().back(); - unsigned nbBitsPerElm = vecType.getElementTypeBitWidth(); - unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm; - return trailingVecDimBitWidth >= targetBitWidth; -} - -static bool -isNotLinearizableBecauseLargeInnerDimension(Operation *op, - unsigned targetBitWidth) { - // Check on bitwidths. - SmallVector> toCheck = - getTypeBitWidthBoundPairs(op, targetBitWidth); - return llvm::any_of(toCheck, [&](std::pair typeWidth) { - return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first, - typeWidth.second); - }); -} - -void populateWithBitWidthConstraints(TypeConverter &typeConverter, - ConversionTarget &target, - unsigned targetBitWidth) { - - // The general purpose definition of what ops are legal must come first. - populateForVectorLinearize(typeConverter, target); - - // Extend the set of legal ops to include those with large inner-most - // dimensions on selected operands/results. 
- target.markUnknownOpDynamicallyLegal( - [=](Operation *op) -> std::optional { - if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) { - return true; - } - return {}; - }); -} - -struct TestVectorBitWidthLinearize final - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize) - - TestVectorBitWidthLinearize() = default; - TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass) - : PassWrapper(pass) {} - - StringRef getArgument() const override { - return "test-bit-width-constrained-vector-linearize"; - } - StringRef getDescription() const override { - return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints " - "in inner-most dimension's bit width."; - } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - Option targetVectorBitwidth{ - *this, "target-vector-bitwidth", - llvm::cl::desc( - "Minimum vector bitwidth to enable the flattening transformation"), - llvm::cl::init(std::numeric_limits::max())}; - void runOnOperation() override { - auto *context = &getContext(); - - TypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - populateWithBitWidthConstraints(typeConverter, target, - targetVectorBitwidth); - - vector::populateVectorLinearizeBasePatterns(typeConverter, target, - patterns); - - vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target, - patterns); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); - } -}; - -struct TestVectorLinearize final - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) - - TestVectorLinearize() = default; - - StringRef getArgument() const override { return "test-vector-linearize"; } - StringRef getDescription() const override { - return "Linearizes ND vectors for N >= 2 into 1D vectors"; - } - void 
getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext &context = getContext(); - TypeConverter converter; - RewritePatternSet patterns(&context); - ConversionTarget target(context); - - vector::populateForVectorLinearize(converter, target); - - vector::populateVectorLinearizeBasePatterns(converter, target, patterns); - vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, - patterns); - mlir::scf::populateSCFStructuralTypeConversionsAndLegality( - converter, patterns, target); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); - } -}; - struct TestEliminateVectorMasks : public PassWrapper> { @@ -1062,10 +907,6 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - - PassRegistration(); - PassRegistration(); } } // namespace test diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 2e08ae6f37980..f52f36107e301 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass(); void registerTestTransformDialectEraseSchedulePass(); void registerTestPassStateExtensionCommunication(); void registerTestVectorLowerings(); +void registerTestVectorLinearize(); void registerTestVectorReductionToSPIRVDotProd(); void registerTestVulkanRunnerPipeline(); void registerTestWrittenToPass(); @@ -300,6 +301,7 @@ void registerTestPasses() { mlir::test::registerTestTransformDialectEraseSchedulePass(); mlir::test::registerTestPassStateExtensionCommunication(); mlir::test::registerTestVectorLowerings(); + mlir::test::registerTestVectorLinearize(); mlir::test::registerTestVectorReductionToSPIRVDotProd(); mlir::test::registerTestVulkanRunnerPipeline(); mlir::test::registerTestWrittenToPass();