From 8c7d760e59984ed295d7ec689f55982b6aa300ae Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 24 Jun 2025 06:33:22 +0000 Subject: [PATCH 1/4] [mlir][Vector] Add `vector.shuffle` tree transformation This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. (Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779). Example: ``` %0:4 = vector.to_elements %a : vector<4xf32> %1:4 = vector.to_elements %b : vector<4xf32> %2:4 = vector.to_elements %c : vector<4xf32> %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> ==> %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> ``` The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree. There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. --- .../Vector/Transforms/LoweringPatterns.h | 7 + .../mlir/Dialect/Vector/Transforms/Passes.h | 1 + .../mlir/Dialect/Vector/Transforms/Passes.td | 5 + .../Dialect/Vector/Transforms/CMakeLists.txt | 1 + ...LowerVectorToFromElementsToShuffleTree.cpp | 692 ++++++++++++++++++ ...m-elements-to-shuffle-tree-transforms.mlir | 329 +++++++++ 6 files changed, 1035 insertions(+) create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp create mode 100644 mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 14cff4ff893b5..6761cd65c5009 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, /// n > 1. void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); +/// Populate patterns to rewrite sequences of `vector.to_elements` + +/// `vector.from_elements` operations into a tree of `vector.shuffle` +/// operations. +void populateVectorToFromElementsToShuffleTreePatterns( + RewritePatternSet &patterns, PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir + #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h index 5667f4fa95ace..959c2fbf31f1a 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td index 7436998749791..9431a4d8e240f 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td @@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func ]; } +def LowerVectorToFromElementsToShuffleTree + : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> { + let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations"; +} + #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 8ca5cb6c6dfab..9e287fc109990 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp + LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp SubsetOpInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp new file mode 100644 index 0000000000000..53728d6dbe2a3 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -0,0 +1,692 @@ +//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements pattern rewrites to lower sequences of +// `vector.to_elements` and `vector.from_elements` operations into a tree of +// `vector.shuffle` operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace vector { + +#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE +#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" + +} // namespace vector +} // namespace mlir + +#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +// Indentation unit for debug output formatting. +constexpr unsigned kIndScale = 2; + +/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements). +using Interval = std::pair; +// Sentinel value for uninitialized intervals. +constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); + +/// The VectorShuffleTreeBuilder builds a balanced binary tree of +/// `vector.shuffle` operations from one or more `vector.to_elements` +/// operations feeding a single `vector.from_elements` operation. +/// +/// The implementation generates hardware-agnostic `vector.shuffle` operations +/// that minimize both the number of shuffle operations and the length of +/// intermediate vectors (to the extent possible). The tree has the +/// following properties: +/// +/// 1. Vectors are shuffled in pairs by order of appearance in +/// the `vector.from_elements` operand list. +/// 2. Each input vector to each level is used only once. +/// 3. The number of levels in the tree is: +/// ceil(log2(# `vector.to_elements` ops)). +/// 4. Vectors at each level of the tree have the same vector length. +/// 5. Vector positions that do not need to be shuffled are represented with +/// poison in the shuffle mask. +/// +/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>: +/// +/// %0:4 = vector.to_elements %a : vector<4xf32> +/// %1:4 = vector.to_elements %b : vector<4xf32> +/// %2:4 = vector.to_elements %c : vector<4xf32> +/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, +/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 +/// : vector<12xf32> +/// => +/// +/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] +/// : vector<4xf32>, vector<4xf32> +/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] +/// : vector<4xf32>, vector<4xf32> +/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5, +/// 6, 7, 8, 9, 10, 11] +/// : vector<8xf32>, vector<8xf32> +/// +/// Comments: +/// * The shuffle tree has two levels: +/// - Level 1 = (%shuffle0, %shuffle1) +/// - Level 2 = (%result) +/// * `%a` and `%b` are shuffled first because they appear first in the +/// `vector.from_elements` operand list (`%0#0` and `%1#0`). +/// * `%c` is shuffled with itself because the number of +/// `vector.from_elements` operands is odd. +/// * The vector length for the first and second levels are 8 and 16, +/// respectively. +/// * `%shuffle1` uses poison values to match the vector length of its +/// tree level (8). +/// +/// +/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// => +/// +/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] +/// : vector<5xf32>, vector<5xf32> +/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] +/// : vector<5xf32>, vector<5xf32> +/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// : vector<8xf32>, vector<8xf32> +/// +/// Comments: +/// * `%c` and `%b` are shuffled first because they appear first in the +/// `vector.from_elements` operand list (`%2#2` and `%1#1`). +/// * `%a` is shuffled with itself because the number of +/// `vector.from_elements` operands is odd. +/// * The vector length for the first and second levels are 8 and 9, +/// respectively. +/// * `%shuffle0` uses poison values to mark unused vector positions and +/// match the vector length of its tree level (8). +/// +/// TODO: Implement mask compression to reduce the number of intermediate poison +/// values. +/// +class VectorShuffleTreeBuilder { +public: + VectorShuffleTreeBuilder() = delete; + VectorShuffleTreeBuilder(FromElementsOp fromElemOp, + ArrayRef toElemDefs); + + /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence + /// and compute the shuffle tree configuration. This method does not generate + /// any IR. + LogicalResult computeShuffleTree(); + + /// Materialize the shuffle tree configuration computed by + /// `computeShuffleTree` in the IR. + Value generateShuffleTree(PatternRewriter &rewriter); + +private: + // IR input information. + FromElementsOp fromElementsOp; + SmallVector toElementsDefs; + + // Shuffle tree configuration. + unsigned numLevels; + SmallVector vectorSizePerLevel; + /// Holds the range of positions in the final output that each vector input + /// in the tree is contributing to. + SmallVector> inputIntervalsPerLevel; + + // Utility methods to compute the shuffle tree configuration. + void computeInputVectorIntervals(); + void computeOutputVectorSizePerLevel(); + + /// Dump the shuffle tree configuration. + void dump(); +}; + +VectorShuffleTreeBuilder::VectorShuffleTreeBuilder( + FromElementsOp fromElemOp, ArrayRef toElemDefs) + : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) { + + assert(fromElementsOp && "from_elements op is required"); + assert(!toElementsDefs.empty() && "At least one to_elements op is required"); + + // Duplicate the last vector if the number of `vector.to_elements` is odd to + // simplify the shuffle tree algorithm. + if (toElementsDefs.size() % 2 != 0) { + toElementsDefs.push_back(toElementsDefs.back()); + } +} + +// ===--------------------------------------------------------------------===// +// Shuffle Tree Analysis Utilities. +// ===--------------------------------------------------------------------===// + +/// Compute the intervals for all the input vectors in the shuffle tree. The +/// interval of an input vector is the range of positions in the final output +/// that the input vector contributes to. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the +/// number of inputs even) so we compute the interval for each input vector: +/// +/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] +/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] +/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] +/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] +/// +/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so +/// we compute the intervals for each input vector to level 1 as: +/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7] +/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8] +/// +void VectorShuffleTreeBuilder::computeInputVectorIntervals() { + // Map `vector.to_elements` ops to their ordinal position in the + // `vector.from_elements` operand list. Make sure duplicated + // `vector.to_elements` ops are mapped to the its first occurrence. + DenseMap toElementsToInputOrdinal; + for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs)) + toElementsToInputOrdinal.insert({toElementsOp, idx}); + + // Compute intervals for each input vector in the shuffle tree. The first + // level computation is special-cased to keep the implementation simpler. + + SmallVector firstLevelIntervals(toElementsDefs.size(), + {kMaxUnsigned, kMaxUnsigned}); + + for (const auto &[idx, element] : + llvm::enumerate(fromElementsOp.getElements())) { + auto toElementsOp = cast(element.getDefiningOp()); + unsigned inputIdx = toElementsToInputOrdinal[toElementsOp]; + Interval ¤tInterval = firstLevelIntervals[inputIdx]; + + // Set lower bound to the first occurrence of the `vector.to_elements`. + if (currentInterval.first == kMaxUnsigned) + currentInterval.first = idx; + + // Set upper bound to the last occurrence of the `vector.to_elements`. + currentInterval.second = idx; + } + + // If the number of `vector.to_elements` is odd and the last op was + // duplicated, the interval for the duplicated op was not computed in the + // previous step as all the input occurrences were mapped to the original op. + // We copy the interval of the original op to the interval of the duplicated + // op manually. + if (firstLevelIntervals.back().second == kMaxUnsigned) + firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2); + + inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals)); + + // Compute intervals for the remaining levels. + unsigned outputNumElements = + cast(fromElementsOp.getResult().getType()).getNumElements(); + for (unsigned level = 1; level < numLevels; ++level) { + const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1]; + SmallVector currentLevelIntervals( + llvm::divideCeil(prevLevelIntervals.size(), 2), + {kMaxUnsigned, kMaxUnsigned}); + + for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size(); + ++inputIdx) { + auto &interval = currentLevelIntervals[inputIdx]; + const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; + const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; + + // The interval of a vector at the current level is the union of the + // intervals of the two input vectors from the previous level being + // shuffled at this level. + interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); + interval.second = + std::min(std::max(prevLhsInterval.second, prevRhsInterval.second), + outputNumElements - 1); + } + + inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals)); + } +} + +/// Compute the uniform output vector size for each level of the shuffle tree, +/// given the intervals of the input vectors at that level. The output vector +/// size of a level is the size of the widest interval resulting from shuffling +/// each pair of input vectors. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// Intervals: +/// * Level 0: [0,6], [1,7], [2,8], [2,8] +/// * Level 1: [0,7], [2,8] +/// +/// Vector sizes: +/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8, +/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8 +/// +/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9 +/// +void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() { + // Compute vector size for each level. + for (unsigned level = 0; level < numLevels; ++level) { + const auto ¤tLevelIntervals = inputIntervalsPerLevel[level]; + unsigned currentVectorSize = 1; + for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) { + const auto &lhsInterval = currentLevelIntervals[i]; + const auto &rhsInterval = currentLevelIntervals[i + 1]; + unsigned combinedIntervalSize = + std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first + + 1; + currentVectorSize = std::max(currentVectorSize, combinedIntervalSize); + } + vectorSizePerLevel[level] = currentVectorSize; + } +} + +void VectorShuffleTreeBuilder::dump() { + LLVM_DEBUG({ + unsigned indLv = 0; + + llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n"; + ++indLv; + for (const auto &toElementsOp : toElementsDefs) + llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n"; + --indLv; + + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Total levels: " << numLevels << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Vector sizes per level: ["; + llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs()); + llvm::dbgs() << "]\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Input intervals per level:\n"; + ++indLv; + for (const auto &[level, intervals] : + llvm::enumerate(inputIntervalsPerLevel)) { + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level + << ": "; + llvm::interleaveComma(intervals, llvm::dbgs(), + [](const Interval &interval) { + llvm::dbgs() << "[" << interval.first << "," + << interval.second << "]"; + }); + llvm::dbgs() << "\n"; + } + }); +} + +/// Compute the shuffle tree configuration for the given `vector.to_elements` + +/// `vector.from_elements` input sequence. This method builds a balanced binary +/// shuffle tree that combines pairs of input vectors at each level. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// build a tree that looks like: +/// +/// %2 %1 %0 %0 +/// \ / \ / +/// %2_1 = vector.shuffle %0_0 = vector.shuffle +/// \ / +/// %2_1_0_0 =vector.shuffle +/// +/// The configuration comprises of computing the intervals of the input vectors +/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and +/// %2_1_0_0) and the output vector size for each level. For further details on +/// intervals and output vector size computation, please, take a look at the +/// corresponding utility functions. +LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { + // Initialize shuffle tree information based on its size. + assert(toElementsDefs.size() > 1 && + "At least two 'vector.to_elements' ops are required"); + numLevels = llvm::Log2_64(toElementsDefs.size()); + vectorSizePerLevel.resize(numLevels, 0); + inputIntervalsPerLevel.reserve(numLevels); + + computeInputVectorIntervals(); + computeOutputVectorSizePerLevel(); + dump(); + + return success(); +} + +// ===--------------------------------------------------------------------===// +// Shuffle Tree Code Generation Utilities. +// ===--------------------------------------------------------------------===// + +/// Compute the permutation mask for shuffling two input `vector.to_elements` +/// ops. The permutation mask is the mapping of the input vector elements to +/// their final position in the output vector, relative to the intermediate +/// output vector of the `vector.shuffle` operation combining the two inputs. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// => +/// +/// // Level 0, vector length = 8 +/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] +/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] +/// +/// TODO: Implement mask compression. +static SmallVector computePermutationShuffleMask( + ToElementsOp toElementOp0, const Interval &interval0, + ToElementsOp toElementOp1, const Interval &interval1, + FromElementsOp fromElementsOp, unsigned outputVectorSize) { + SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); + unsigned inputVectorSize = + toElementOp0.getSource().getType().getNumElements(); + + for (const auto &[inputIdx, element] : + llvm::enumerate(fromElementsOp.getElements())) { + auto currentToElemOp = cast(element.getDefiningOp()); + // Match `vector.from_elements` operands to the two input ops. + if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1) + continue; + + // The permutation value for a particular operand is the ordinal position of + // the operand in the `vector.to_elements` list of results. + unsigned permVal = cast(element).getResultNumber(); + unsigned maskIdx = inputIdx; + + // The mask index is the ordinal position of the operand in + // `vector.from_elements` operand list. We make this position relative to + // the interval of the output vector resulting from combining the two + // input vectors. + if (currentToElemOp == toElementOp0) { + maskIdx -= interval0.first; + } else { + // currentToElemOp == toElementOp1 + unsigned intervalOffset = interval1.first - interval0.first; + maskIdx += intervalOffset - interval1.first; + permVal += inputVectorSize; + } + + mask[maskIdx] = permVal; + } + + LLVM_DEBUG({ + unsigned indLv = 1; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: ["; + llvm::interleaveComma(mask, llvm::dbgs()); + llvm::dbgs() << "]\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Combining: " << toElementOp0 << " and " << toElementOp1 + << "\n"; + }); + + return mask; +} + +/// Compute the propagation shuffle mask for combining two intermediate shuffle +/// operations of the tree. The propagation shuffle mask is the mapping of the +/// intermediate vector elements, which have already been shuffled to their +/// relative output position using the mask generated by +/// `computePermutationShuffleMask`, to their next position in the tree. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// // Level 0, vector length = 8 +/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] +/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] +/// +/// => +/// +/// // Level 1, vector length = 9 +/// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// +/// TODO: Implement mask compression. +/// +static SmallVector computePropagationShuffleMask( + ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp, + const Interval &rhsInterval, unsigned outputVectorSize) { + ArrayRef lhsShuffleMask = lhsShuffleOp.getMask(); + ArrayRef rhsShuffleMask = rhsShuffleOp.getMask(); + unsigned inputVectorSize = lhsShuffleMask.size(); + assert(inputVectorSize == rhsShuffleMask.size() && + "Expected both shuffle masks to have the same size"); + + unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first; + SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); + + // Propagate any element from the input mask that is not poison. For the RHS + // input vector, the mask index is offset by the offset between the two + // intervals of the input vectors. + for (unsigned i = 0; i < inputVectorSize; ++i) { + if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) + mask[i] = i; + + unsigned rhsIdx = i + lhsRhsOffset; + if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) { + assert(rhsIdx < outputVectorSize && "RHS index out of bounds"); + assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set"); + mask[rhsIdx] = i + inputVectorSize; + } + } + + LLVM_DEBUG({ + unsigned indLv = 1; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Propagation shuffle mask computation:\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* LHS shuffle op: " << lhsShuffleOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* RHS shuffle op: " << rhsShuffleOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: ["; + llvm::interleaveComma(mask, llvm::dbgs()); + llvm::dbgs() << "]\n"; + }); + + return mask; +} + +/// Materialize the pre-computed shuffle tree configuration in the IR by +/// generating the corresponding `vector.shuffle` ops. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// with the pre-computed shuffle tree configuration: +/// +/// * Vector sizes per level: [8, 9] +/// * Input intervals per level: +/// * Level 0: [0,6], [1,7], [2,8], [2,8] +/// * Level 1: [0,7], [2,8] +/// +/// => +/// +/// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6] +/// : vector<5xf32>, vector<5xf32> +/// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1] +/// : vector<5xf32>, vector<5xf32> +/// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// : vector<8xf32>, vector<8xf32> +/// +/// The code generation comprises of combining pairs of input vectors for each +/// level of the tree, using the pre-computed per tree level intervals and +/// vector sizes. The algorithm generates two kinds of shuffle masks: +/// permutation masks and propagation masks. Permutation masks are computed for +/// the first level of the tree and permute the input vector elements to their +/// relative position in the final output. Propagation masks are computed for +/// subsequent levels and propagate the elements to the next level without +/// permutation. For further details on the shuffle mask computation, please, +/// take a look at the corresponding `computePermutationShuffleMask` and +/// `computePropagationShuffleMask` functions. +/// +Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n"); + + // Initialize work list with the `vector.to_elements` sources. + SmallVector levelInputs; + llvm::transform( + toElementsDefs, std::back_inserter(levelInputs), + [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); }); + + // Build shuffle tree by combining pairs of vectors. + Location loc = fromElementsOp.getLoc(); + unsigned currentLevel = 0; + for (const auto &[levelVectorSize, inputIntervals] : + llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) { + LLVM_DEBUG(llvm::dbgs() + << llvm::indent(1, kIndScale) << "* Processing level " + << currentLevel << " (vector size: " << levelVectorSize + << ", # inputs: " << levelInputs.size() << ")\n"); + + // Process level input vectors in pairs. + SmallVector levelOutputs; + for (size_t i = 0; i < levelInputs.size(); i += 2) { + Value lhsVector = levelInputs[i]; + Value rhsVector = levelInputs[i + 1]; + const Interval &lhsInterval = inputIntervals[i]; + const Interval &rhsInterval = inputIntervals[i + 1]; + + // For the first level of the tree, permute the vector elements to their + // relative position in the final output. For subsequent levels, we + // propagate the elements to the next level without permutation. + SmallVector shuffleMask; + if (currentLevel == 0) { + shuffleMask = computePermutationShuffleMask( + toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval, + fromElementsOp, levelVectorSize); + } else { + auto lhsShuffleOp = cast(lhsVector.getDefiningOp()); + auto rhsShuffleOp = cast(rhsVector.getDefiningOp()); + shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval, + rhsShuffleOp, rhsInterval, + levelVectorSize); + } + + Value shuffleVal = rewriter.create( + loc, lhsVector, rhsVector, shuffleMask); + levelOutputs.push_back(shuffleVal); + } + + levelInputs = std::move(levelOutputs); + ++currentLevel; + } + + assert(levelInputs.size() == 1 && "Should have exactly one result"); + return levelInputs.front(); +} + +/// Gather and unique all the `vector.to_elements` operations that feed the +/// `vector.from_elements` operation. The `vector.to_elements` operations are +/// returned in order of appearance in the `vector.from_elements`'s operand +/// list. +static LogicalResult +getToElementsDefiningOps(FromElementsOp fromElementsOp, + SmallVectorImpl &toElementsDefs) { + SetVector toElementsDefsSet; + for (Value element : fromElementsOp.getElements()) { + auto toElementsOp = element.getDefiningOp(); + if (!toElementsOp) + return failure(); + toElementsDefsSet.insert(toElementsOp); + } + + toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end()); + return success(); +} + +/// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into +/// a tree of `vector.shuffle` operations. +struct ToFromElementsToShuffleTreeRewrite final + : OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, + PatternRewriter &rewriter) const override { + VectorType resultType = fromElementsOp.getType(); + if (resultType.getRank() != 1 || resultType.isScalable()) + return failure(); + + SmallVector toElementsDefs; + if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs))) + return failure(); + + // Avoid generating a shuffle tree for trivial `vector.to_elements` -> + // `vector.from_elements` forwarding cases that do not require shuffling. + if (toElementsDefs.size() == 1) { + ToElementsOp toElementsOp0 = toElementsDefs.front(); + if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults())) + return failure(); + } + + VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs); + if (failed(shuffleTreeBuilder.computeShuffleTree())) + return failure(); + + Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter); + rewriter.replaceOp(fromElementsOp, finalShuffle); + return success(); + } +}; + +struct LowerVectorToFromElementsToShuffleTreePass + : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase< + LowerVectorToFromElementsToShuffleTreePass> { + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToFromElementsToShuffleTreePatterns(patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir new file mode 100644 index 0000000000000..3dc579be12f0f --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir @@ -0,0 +1,329 @@ +// RUN: mlir-opt -lower-vector-to-from-elements-to-shuffle-tree -split-input-file %s | FileCheck %s + +// Captured variable names for `vector.shuffle` operations follow the L#SH# convention, +// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to +// the shuffle index within that level. + +func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @to_from_elements_single_input_shuffle( +// CHECK-SAME: %[[A:.*]]: vector<8xf32> + // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> + // CHECK: return %[[L0SH0]] + +// ----- + +func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>, + %b: vector<8xf32>) -> vector<8xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1:8 = vector.to_elements %b : vector<8xf32> + %2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32> + return %2 : vector<8xf32> +} + +// CHECK-LABEL: func @from_elements_to_elements_single_shuffle( +// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [7, 8, 6, 9, 5, 10, 4, 11] : vector<8xf32> +// CHECK: return %[[L0SH0]] + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, + %b: vector<8xf32>, + %c: vector<8xf32>, + %d: vector<8xf32>) -> vector<32xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1:8 = vector.to_elements %b : vector<8xf32> + %2:8 = vector.to_elements %c : vector<8xf32> + %3:8 = vector.to_elements %d : vector<8xf32> + %4 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, + %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7, + %2#0, %2#1, %2#2, %2#3, %2#4, %2#5, %2#6, %2#7, + %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6, %3#7 : vector<32xf32> + return %4 : vector<32xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32( +// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: return %[[L1SH0]] : vector<32xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>) -> vector<12xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> + return %3 : vector<12xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> +// CHECK: return %[[L1SH0]] : vector<12xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_64x4_256( + %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>, + %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>, + %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>, + %m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>, + %q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>, + %u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>, + %y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>, + %ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>, + %ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>, + %ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>, + %ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>, + %as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>, + %aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>, + %ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>, + %be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>, + %bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3:4 = vector.to_elements %d : vector<4xf32> + %4:4 = vector.to_elements %e : vector<4xf32> + %5:4 = vector.to_elements %f : vector<4xf32> + %6:4 = vector.to_elements %g : vector<4xf32> + %7:4 = vector.to_elements %h : vector<4xf32> + %8:4 = vector.to_elements %i : vector<4xf32> + %9:4 = vector.to_elements %j : vector<4xf32> + %10:4 = vector.to_elements %k : vector<4xf32> + %11:4 = vector.to_elements %l : vector<4xf32> + %12:4 = vector.to_elements %m : vector<4xf32> + %13:4 = vector.to_elements %n : vector<4xf32> + %14:4 = vector.to_elements %o : vector<4xf32> + %15:4 = vector.to_elements %p : vector<4xf32> + %16:4 = vector.to_elements %q : vector<4xf32> + %17:4 = vector.to_elements %r : vector<4xf32> + %18:4 = vector.to_elements %s : vector<4xf32> + %19:4 = vector.to_elements %t : vector<4xf32> + %20:4 = vector.to_elements %u : vector<4xf32> + %21:4 = vector.to_elements %v : vector<4xf32> + %22:4 = vector.to_elements %w : vector<4xf32> + %23:4 = vector.to_elements %x : vector<4xf32> + %24:4 = vector.to_elements %y : vector<4xf32> + %25:4 = vector.to_elements %z : vector<4xf32> + %26:4 = vector.to_elements %aa : vector<4xf32> + %27:4 = vector.to_elements %ab : vector<4xf32> + %28:4 = vector.to_elements %ac : vector<4xf32> + %29:4 = vector.to_elements %ad : vector<4xf32> + %30:4 = vector.to_elements %ae : vector<4xf32> + %31:4 = vector.to_elements %af : vector<4xf32> + %32:4 = vector.to_elements %ag : vector<4xf32> + %33:4 = vector.to_elements %ah : vector<4xf32> + %34:4 = vector.to_elements %ai : vector<4xf32> + %35:4 = vector.to_elements %aj : vector<4xf32> + %36:4 = vector.to_elements %ak : vector<4xf32> + %37:4 = vector.to_elements %al : vector<4xf32> + %38:4 = vector.to_elements %am : vector<4xf32> + %39:4 = vector.to_elements %an : vector<4xf32> + %40:4 = vector.to_elements %ao : vector<4xf32> + %41:4 = vector.to_elements %ap : vector<4xf32> + %42:4 = vector.to_elements %aq : vector<4xf32> + %43:4 = vector.to_elements %ar : vector<4xf32> + %44:4 = vector.to_elements %as : vector<4xf32> + %45:4 = vector.to_elements %at : vector<4xf32> + %46:4 = vector.to_elements %au : vector<4xf32> + %47:4 = vector.to_elements %av : vector<4xf32> + %48:4 = vector.to_elements %aw : vector<4xf32> + %49:4 = vector.to_elements %ax : vector<4xf32> + %50:4 = vector.to_elements %ay : vector<4xf32> + %51:4 = vector.to_elements %az : vector<4xf32> + %52:4 = vector.to_elements %ba : vector<4xf32> + %53:4 = vector.to_elements %bb : vector<4xf32> + %54:4 = vector.to_elements %bc : vector<4xf32> + %55:4 = vector.to_elements %bd : vector<4xf32> + %56:4 = vector.to_elements %be : vector<4xf32> + %57:4 = vector.to_elements %bf : vector<4xf32> + %58:4 = vector.to_elements %bg : vector<4xf32> + %59:4 = vector.to_elements %bh : vector<4xf32> + %60:4 = vector.to_elements %bi : vector<4xf32> + %61:4 = vector.to_elements %bj : vector<4xf32> + %62:4 = vector.to_elements %bk : vector<4xf32> + %63:4 = vector.to_elements %bl : vector<4xf32> + %64 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3, %3#0, %3#1, %3#2, %3#3, %4#0, %4#1, %4#2, %4#3, + %5#0, %5#1, %5#2, %5#3, %6#0, %6#1, %6#2, %6#3, %7#0, %7#1, %7#2, %7#3, %8#0, %8#1, %8#2, %8#3, %9#0, %9#1, %9#2, %9#3, + %10#0, %10#1, %10#2, %10#3, %11#0, %11#1, %11#2, %11#3, %12#0, %12#1, %12#2, %12#3, %13#0, %13#1, %13#2, %13#3, %14#0, %14#1, %14#2, %14#3, + %15#0, %15#1, %15#2, %15#3, %16#0, %16#1, %16#2, %16#3, %17#0, %17#1, %17#2, %17#3, %18#0, %18#1, %18#2, %18#3, %19#0, %19#1, %19#2, %19#3, + %20#0, %20#1, %20#2, %20#3, %21#0, %21#1, %21#2, %21#3, %22#0, %22#1, %22#2, %22#3, %23#0, %23#1, %23#2, %23#3, %24#0, %24#1, %24#2, %24#3, + %25#0, %25#1, %25#2, %25#3, %26#0, %26#1, %26#2, %26#3, %27#0, %27#1, %27#2, %27#3, %28#0, %28#1, %28#2, %28#3, %29#0, %29#1, %29#2, %29#3, + %30#0, %30#1, %30#2, %30#3, %31#0, %31#1, %31#2, %31#3, %32#0, %32#1, %32#2, %32#3, %33#0, %33#1, %33#2, %33#3, %34#0, %34#1, %34#2, %34#3, + %35#0, %35#1, %35#2, %35#3, %36#0, %36#1, %36#2, %36#3, %37#0, %37#1, %37#2, %37#3, %38#0, %38#1, %38#2, %38#3, %39#0, %39#1, %39#2, %39#3, + %40#0, %40#1, %40#2, %40#3, %41#0, %41#1, %41#2, %41#3, %42#0, %42#1, %42#2, %42#3, %43#0, %43#1, %43#2, %43#3, %44#0, %44#1, %44#2, %44#3, + %45#0, %45#1, %45#2, %45#3, %46#0, %46#1, %46#2, %46#3, %47#0, %47#1, %47#2, %47#3, %48#0, %48#1, %48#2, %48#3, %49#0, %49#1, %49#2, %49#3, + %50#0, %50#1, %50#2, %50#3, %51#0, %51#1, %51#2, %51#3, %52#0, %52#1, %52#2, %52#3, %53#0, %53#1, %53#2, %53#3, %54#0, %54#1, %54#2, %54#3, + %55#0, %55#1, %55#2, %55#3, %56#0, %56#1, %56#2, %56#3, %57#0, %57#1, %57#2, %57#3, %58#0, %58#1, %58#2, %58#3, %59#0, %59#1, %59#2, %59#3, + %60#0, %60#1, %60#2, %60#3, %61#0, %61#1, %61#2, %61#3, %62#0, %62#1, %62#2, %62#3, %63#0, %63#1, %63#2, %63#3 : vector<256xf32> + return %64 : vector<256xf32> +} + +// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256( +// CHECK-SAME: %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>) +// CHECK: %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH2:.+]] = vector.shuffle %[[E]], %[[F]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH3:.+]] = vector.shuffle %[[G]], %[[H]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH4:.+]] = vector.shuffle %[[I]], %[[J]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH5:.+]] = vector.shuffle %[[K]], %[[L]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH6:.+]] = vector.shuffle %[[M]], %[[N]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH7:.+]] = vector.shuffle %[[O]], %[[P]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH8:.+]] = vector.shuffle %[[Q]], %[[R]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH9:.+]] = vector.shuffle %[[S]], %[[T]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH10:.+]] = vector.shuffle %[[U]], %[[V]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH11:.+]] = vector.shuffle %[[W]], %[[X]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH12:.+]] = vector.shuffle %[[Y]], %[[Z]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH13:.+]] = vector.shuffle %[[AA]], %[[AB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH14:.+]] = vector.shuffle %[[AC]], %[[AD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH15:.+]] = vector.shuffle %[[AE]], %[[AF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH16:.+]] = vector.shuffle %[[AG]], %[[AH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH17:.+]] = vector.shuffle %[[AI]], %[[AJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH18:.+]] = vector.shuffle %[[AK]], %[[AL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH19:.+]] = vector.shuffle %[[AM]], %[[AN]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH20:.+]] = vector.shuffle %[[AO]], %[[AP]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH21:.+]] = vector.shuffle %[[AQ]], %[[AR]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH22:.+]] = vector.shuffle %[[AS]], %[[AT]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH23:.+]] = vector.shuffle %[[AU]], %[[AV]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH24:.+]] = vector.shuffle %[[AW]], %[[AX]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH25:.+]] = vector.shuffle %[[AY]], %[[AZ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH26:.+]] = vector.shuffle %[[BA]], %[[BB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH27:.+]] = vector.shuffle %[[BC]], %[[BD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH28:.+]] = vector.shuffle %[[BE]], %[[BF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH29:.+]] = vector.shuffle %[[BG]], %[[BH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH30:.+]] = vector.shuffle %[[BI]], %[[BJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH31:.+]] = vector.shuffle %[[BK]], %[[BL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.+]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH1:.+]] = vector.shuffle %[[L0SH2]], %[[L0SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH2:.+]] = vector.shuffle %[[L0SH4]], %[[L0SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH3:.+]] = vector.shuffle %[[L0SH6]], %[[L0SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH4:.+]] = vector.shuffle %[[L0SH8]], %[[L0SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH5:.+]] = vector.shuffle %[[L0SH10]], %[[L0SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH6:.+]] = vector.shuffle %[[L0SH12]], %[[L0SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH7:.+]] = vector.shuffle %[[L0SH14]], %[[L0SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH8:.+]] = vector.shuffle %[[L0SH16]], %[[L0SH17]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH9:.+]] = vector.shuffle %[[L0SH18]], %[[L0SH19]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH10:.+]] = vector.shuffle %[[L0SH20]], %[[L0SH21]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH11:.+]] = vector.shuffle %[[L0SH22]], %[[L0SH23]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH12:.+]] = vector.shuffle %[[L0SH24]], %[[L0SH25]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH13:.+]] = vector.shuffle %[[L0SH26]], %[[L0SH27]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH14:.+]] = vector.shuffle %[[L0SH28]], %[[L0SH29]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH15:.+]] = vector.shuffle %[[L0SH30]], %[[L0SH31]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L2SH0:.+]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH1:.+]] = vector.shuffle %[[L1SH2]], %[[L1SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH2:.+]] = vector.shuffle %[[L1SH4]], %[[L1SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH3:.+]] = vector.shuffle %[[L1SH6]], %[[L1SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH4:.+]] = vector.shuffle %[[L1SH8]], %[[L1SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH5:.+]] = vector.shuffle %[[L1SH10]], %[[L1SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH6:.+]] = vector.shuffle %[[L1SH12]], %[[L1SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH7:.+]] = vector.shuffle %[[L1SH14]], %[[L1SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L3SH0:.+]] = vector.shuffle %[[L2SH0]], %[[L2SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH1:.+]] = vector.shuffle %[[L2SH2]], %[[L2SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH2:.+]] = vector.shuffle %[[L2SH4]], %[[L2SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH3:.+]] = vector.shuffle %[[L2SH6]], %[[L2SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L4SH0:.+]] = vector.shuffle %[[L3SH0]], %[[L3SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32> +// CHECK: %[[L4SH1:.+]] = vector.shuffle %[[L3SH2]], %[[L3SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32> +// CHECK: %[[L5SH0:.+]] = vector.shuffle %[[L4SH0]], %[[L4SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] : vector<128xf32>, vector<128xf32> +// CHECK: return %[[L5SH0]] : vector<256xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>, + %d: vector<4xf32>) -> vector<16xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3:4 = vector.to_elements %d : vector<4xf32> + %4 = vector.from_elements %3#3, %0#0, %2#2, %1#1, %3#0, %2#1, %0#3, %1#2, %0#1, %3#2, %1#0, %2#3, %1#3, %0#2, %3#1, %2#0 : vector<16xf32> + return %4 : vector<16xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 15, 16, 4, 18, 6, 20, 8, 9, 23, 24, 25, 13, 14, 28] : vector<15xf32>, vector<15xf32> +// CHECK: return %[[L1SH0]] : vector<16xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>) -> vector<12xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3 = vector.from_elements %0#2, %1#1, %2#0, %0#1, %1#0, %2#2, %0#0, %1#3, %2#3, %0#3, %1#2, %2#1 : vector<12xf32> + return %3 : vector<12xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 11, 3, 4, 14, 6, 7, 17, 9, 10, 20] : vector<11xf32>, vector<11xf32> +// CHECK: return %[[L1SH0]] : vector<12xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, + %b: vector<5xf32>, + %c: vector<5xf32>) -> vector<9xf32> { + %0:5 = vector.to_elements %a : vector<5xf32> + %1:5 = vector.to_elements %b : vector<5xf32> + %2:5 = vector.to_elements %c : vector<5xf32> + %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, %2#2, %2#0, %1#1, %0#4 : vector<9xf32> + return %3 : vector<9xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9( +// CHECK-SAME: %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 8, 9, 4, 5, 6, 7, 14] : vector<8xf32>, vector<8xf32> +// CHECK: return %[[L1SH0]] : vector<9xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, + %b: vector<2xf32>, + %c: vector<2xf32>, + %d: vector<2xf32>) -> vector<32xf32> { + %0:2 = vector.to_elements %a : vector<2xf32> + %1:2 = vector.to_elements %b : vector<2xf32> + %2:2 = vector.to_elements %c : vector<2xf32> + %3:2 = vector.to_elements %d : vector<2xf32> + %4 = vector.from_elements %0#0, %0#0, %0#0, %0#0, %0#1, %0#1, %0#1, %0#1, + %1#0, %1#0, %1#0, %1#0, %1#1, %1#1, %1#1, %1#1, + %2#0, %2#0, %2#0, %2#0, %2#1, %2#1, %2#1, %2#1, + %3#0, %3#0, %3#0, %3#0, %3#1, %3#1, %3#1, %3#1 : vector<32xf32> + return %4 : vector<32xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32( +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32> + // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> + // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> + // CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: return %[[L1SH0]] : vector<32xf32> From 6ed003b85024fd0ea0b0f0911a2fba9d2561d683 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 1 Jul 2025 05:05:30 +0000 Subject: [PATCH 2/4] Feedback --- ...LowerVectorToFromElementsToShuffleTree.cpp | 161 ++++++++++-------- ...m-elements-to-shuffle-tree-transforms.mlir | 103 +++++++++-- 2 files changed, 181 insertions(+), 83 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 53728d6dbe2a3..504103529cdcb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -124,7 +124,6 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); /// /// TODO: Implement mask compression to reduce the number of intermediate poison /// values. -/// class VectorShuffleTreeBuilder { public: VectorShuffleTreeBuilder() = delete; @@ -142,8 +141,8 @@ class VectorShuffleTreeBuilder { private: // IR input information. - FromElementsOp fromElementsOp; - SmallVector toElementsDefs; + FromElementsOp fromElemsOp; + SmallVector toElemsDefs; // Shuffle tree configuration. unsigned numLevels; @@ -162,16 +161,19 @@ class VectorShuffleTreeBuilder { VectorShuffleTreeBuilder::VectorShuffleTreeBuilder( FromElementsOp fromElemOp, ArrayRef toElemDefs) - : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) { - - assert(fromElementsOp && "from_elements op is required"); - assert(!toElementsDefs.empty() && "At least one to_elements op is required"); + : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) { + assert(fromElemsOp && "from_elements op is required"); + assert(!toElemsDefs.empty() && "At least one to_elements op is required"); +} - // Duplicate the last vector if the number of `vector.to_elements` is odd to - // simplify the shuffle tree algorithm. - if (toElementsDefs.size() % 2 != 0) { - toElementsDefs.push_back(toElementsDefs.back()); - } +/// Duplicate the last operation, value or interval if the total number of them +/// is odd. This is useful to simplify the shuffle tree algorithm given that +/// vectors are shuffled in pairs and duplication would lead to the last shuffle +/// to have a single (duplicated) input vector. +template +static void duplicateLastIfOdd(SmallVectorImpl &values) { + if (values.size() % 2 != 0) + values.push_back(values.back()); } // ===--------------------------------------------------------------------===// @@ -207,20 +209,20 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { // Map `vector.to_elements` ops to their ordinal position in the // `vector.from_elements` operand list. Make sure duplicated // `vector.to_elements` ops are mapped to the its first occurrence. - DenseMap toElementsToInputOrdinal; - for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs)) - toElementsToInputOrdinal.insert({toElementsOp, idx}); + DenseMap toElemsToInputOrdinal; + for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs)) + toElemsToInputOrdinal.insert({toElemsOp, idx}); // Compute intervals for each input vector in the shuffle tree. The first // level computation is special-cased to keep the implementation simpler. - SmallVector firstLevelIntervals(toElementsDefs.size(), + SmallVector firstLevelIntervals(toElemsDefs.size(), {kMaxUnsigned, kMaxUnsigned}); for (const auto &[idx, element] : - llvm::enumerate(fromElementsOp.getElements())) { - auto toElementsOp = cast(element.getDefiningOp()); - unsigned inputIdx = toElementsToInputOrdinal[toElementsOp]; + llvm::enumerate(fromElemsOp.getElements())) { + auto toElemsOp = cast(element.getDefiningOp()); + unsigned inputIdx = toElemsToInputOrdinal[toElemsOp]; Interval ¤tInterval = firstLevelIntervals[inputIdx]; // Set lower bound to the first occurrence of the `vector.to_elements`. @@ -231,19 +233,13 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { currentInterval.second = idx; } - // If the number of `vector.to_elements` is odd and the last op was - // duplicated, the interval for the duplicated op was not computed in the - // previous step as all the input occurrences were mapped to the original op. - // We copy the interval of the original op to the interval of the duplicated - // op manually. - if (firstLevelIntervals.back().second == kMaxUnsigned) - firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2); - + duplicateLastIfOdd(toElemsDefs); + duplicateLastIfOdd(firstLevelIntervals); inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals)); // Compute intervals for the remaining levels. unsigned outputNumElements = - cast(fromElementsOp.getResult().getType()).getNumElements(); + cast(fromElemsOp.getResult().getType()).getNumElements(); for (unsigned level = 1; level < numLevels; ++level) { const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1]; SmallVector currentLevelIntervals( @@ -265,6 +261,7 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { outputNumElements - 1); } + duplicateLastIfOdd(currentLevelIntervals); inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals)); } } @@ -311,9 +308,9 @@ void VectorShuffleTreeBuilder::dump() { ++indLv; llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n"; ++indLv; - for (const auto &toElementsOp : toElementsDefs) - llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n"; - llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n"; + for (const auto &toElemsOp : toElemsDefs) + llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n"; --indLv; llvm::dbgs() << llvm::indent(indLv, kIndScale) @@ -366,9 +363,7 @@ void VectorShuffleTreeBuilder::dump() { /// corresponding utility functions. LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { // Initialize shuffle tree information based on its size. - assert(toElementsDefs.size() > 1 && - "At least two 'vector.to_elements' ops are required"); - numLevels = llvm::Log2_64(toElementsDefs.size()); + numLevels = std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size())); vectorSizePerLevel.resize(numLevels, 0); inputIntervalsPerLevel.reserve(numLevels); @@ -402,17 +397,18 @@ LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] /// -/// TODO: Implement mask compression. +/// TODO: Implement mask compression to reduce the number of intermediate poison +/// values. static SmallVector computePermutationShuffleMask( ToElementsOp toElementOp0, const Interval &interval0, ToElementsOp toElementOp1, const Interval &interval1, - FromElementsOp fromElementsOp, unsigned outputVectorSize) { + FromElementsOp fromElemsOp, unsigned outputVectorSize) { SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); unsigned inputVectorSize = toElementOp0.getSource().getType().getNumElements(); for (const auto &[inputIdx, element] : - llvm::enumerate(fromElementsOp.getElements())) { + llvm::enumerate(fromElemsOp.getElements())) { auto currentToElemOp = cast(element.getDefiningOp()); // Match `vector.from_elements` operands to the two input ops. if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1) @@ -476,8 +472,8 @@ static SmallVector computePermutationShuffleMask( /// // Level 1, vector length = 9 /// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] /// -/// TODO: Implement mask compression. -/// +/// TODO: Implement mask compression to reduce the number of intermediate poison +/// values. static SmallVector computePropagationShuffleMask( ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp, const Interval &rhsInterval, unsigned outputVectorSize) { @@ -487,6 +483,7 @@ static SmallVector computePropagationShuffleMask( assert(inputVectorSize == rhsShuffleMask.size() && "Expected both shuffle masks to have the same size"); + bool hasSameInput = lhsShuffleOp == rhsShuffleOp; unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first; SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); @@ -497,6 +494,9 @@ static SmallVector computePropagationShuffleMask( if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) mask[i] = i; + if (hasSameInput) + continue; + unsigned rhsIdx = i + lhsRhsOffset; if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) { assert(rhsIdx < outputVectorSize && "RHS index out of bounds"); @@ -565,15 +565,19 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { // Initialize work list with the `vector.to_elements` sources. SmallVector levelInputs; - llvm::transform( - toElementsDefs, std::back_inserter(levelInputs), - [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); }); + llvm::transform(toElemsDefs, std::back_inserter(levelInputs), + [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); }); + // TODO: Check that every pair of input has the same vector size. Otherwise, + // promote the narrower one to the wider one. // Build shuffle tree by combining pairs of vectors. - Location loc = fromElementsOp.getLoc(); + Location loc = fromElemsOp.getLoc(); unsigned currentLevel = 0; for (const auto &[levelVectorSize, inputIntervals] : llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) { + + duplicateLastIfOdd(levelInputs); + LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale) << "* Processing level " << currentLevel << " (vector size: " << levelVectorSize @@ -593,8 +597,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { SmallVector shuffleMask; if (currentLevel == 0) { shuffleMask = computePermutationShuffleMask( - toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval, - fromElementsOp, levelVectorSize); + toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval, + fromElemsOp, levelVectorSize); } else { auto lhsShuffleOp = cast(lhsVector.getDefiningOp()); auto rhsShuffleOp = cast(rhsVector.getDefiningOp()); @@ -621,17 +625,17 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { /// returned in order of appearance in the `vector.from_elements`'s operand /// list. static LogicalResult -getToElementsDefiningOps(FromElementsOp fromElementsOp, - SmallVectorImpl &toElementsDefs) { - SetVector toElementsDefsSet; - for (Value element : fromElementsOp.getElements()) { - auto toElementsOp = element.getDefiningOp(); - if (!toElementsOp) +getToElementsDefiningOps(FromElementsOp fromElemsOp, + SmallVectorImpl &toElemsDefs) { + SetVector toElemsDefsSet; + for (Value element : fromElemsOp.getElements()) { + auto toElemsOp = element.getDefiningOp(); + if (!toElemsOp) return failure(); - toElementsDefsSet.insert(toElementsOp); + toElemsDefsSet.insert(toElemsOp); } - toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end()); + toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end()); return success(); } @@ -642,30 +646,53 @@ struct ToFromElementsToShuffleTreeRewrite final using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, + LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, PatternRewriter &rewriter) const override { - VectorType resultType = fromElementsOp.getType(); - if (resultType.getRank() != 1 || resultType.isScalable()) - return failure(); + VectorType resultType = fromElemsOp.getType(); + if (resultType.getRank() != 1) + return rewriter.notifyMatchFailure( + fromElemsOp, "Multi-dimensional vectors are not supported yet"); + if (resultType.isScalable()) + return rewriter.notifyMatchFailure( + fromElemsOp, + "'vector.from_elements' does not support scalable vectors"); + + SmallVector toElemsDefs; + if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs))) + return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources"); + + int64_t numElements = + toElemsDefs.front().getSource().getType().getNumElements(); + for (ToElementsOp toElemsOp : toElemsDefs) { + if (toElemsOp.getSource().getType().getNumElements() != numElements) + return rewriter.notifyMatchFailure( + fromElemsOp, "unsupported sources with different vector sizes"); + } - SmallVector toElementsDefs; - if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs))) - return failure(); + if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) { + return !toElemsOp.getSource().getType().hasRank(); + })) { + return rewriter.notifyMatchFailure(fromElemsOp, + "0-D vectors are not supported"); + } // Avoid generating a shuffle tree for trivial `vector.to_elements` -> // `vector.from_elements` forwarding cases that do not require shuffling. - if (toElementsDefs.size() == 1) { - ToElementsOp toElementsOp0 = toElementsDefs.front(); - if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults())) - return failure(); + if (toElemsDefs.size() == 1) { + ToElementsOp toElemsOp0 = toElemsDefs.front(); + if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) { + return rewriter.notifyMatchFailure( + fromElemsOp, "trivial forwarding case does not require shuffling"); + } } - VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs); + VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs); if (failed(shuffleTreeBuilder.computeShuffleTree())) - return failure(); + return rewriter.notifyMatchFailure(fromElemsOp, + "failed to compute shuffle tree"); Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter); - rewriter.replaceOp(fromElementsOp, finalShuffle); + rewriter.replaceOp(fromElemsOp, finalShuffle); return success(); } }; diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir index 3dc579be12f0f..a8d3d5278e893 100644 --- a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir @@ -4,13 +4,27 @@ // where L# refers to the level of the tree the shuffle belongs to, and SH# refers to // the shuffle index within that level. -func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { +func.func @trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<8xf32> + return %1 : vector<8xf32> +} + +// No shuffle tree needed for trivial forwarding case. + +// CHECK-LABEL: func @trivial_forwarding( +// CHECK-SAME: %[[A:.*]]: vector<8xf32> +// CHECK: return %[[A]] : vector<8xf32> + +// ----- + +func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { %0:8 = vector.to_elements %a : vector<8xf32> %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> return %1 : vector<8xf32> } -// CHECK-LABEL: func @to_from_elements_single_input_shuffle( +// CHECK-LABEL: func @single_input_shuffle( // CHECK-SAME: %[[A:.*]]: vector<8xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> // CHECK: return %[[L0SH0]] @@ -32,7 +46,7 @@ func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>, // ----- -func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, +func.func @shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, %b: vector<8xf32>, %c: vector<8xf32>, %d: vector<8xf32>) -> vector<32xf32> { @@ -47,7 +61,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, return %4 : vector<32xf32> } -// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32( +// CHECK-LABEL: func @shuffle_tree_concat_4x8_to_32( // CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> @@ -56,7 +70,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, // ----- -func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, +func.func @shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<12xf32> { %0:4 = vector.to_elements %a : vector<4xf32> @@ -66,7 +80,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, return %3 : vector<12xf32> } -// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12( +// CHECK-LABEL: func @shuffle_tree_concat_3x4_to_12( // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> @@ -75,7 +89,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, // ----- -func.func @to_from_elements_shuffle_tree_concat_64x4_256( +func.func @shuffle_tree_concat_64x4_256( %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>, %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>, %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>, @@ -172,7 +186,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256( return %64 : vector<256xf32> } -// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256( +// CHECK-LABEL: func.func @shuffle_tree_concat_64x4_256( // CHECK-SAME: %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>) // CHECK: %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK: %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> @@ -241,7 +255,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256( // ----- -func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, +func.func @shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>) -> vector<16xf32> { @@ -255,7 +269,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, // TODO: Implement mask compression to reduce the number of intermediate poison values. -// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16( +// CHECK-LABEL: func @shuffle_tree_arbitrary_4x4_to_16( // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32> @@ -264,7 +278,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, // ----- -func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, +func.func @shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<12xf32> { %0:4 = vector.to_elements %a : vector<4xf32> @@ -276,7 +290,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, // TODO: Implement mask compression to reduce the number of intermediate poison values. -// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12( +// CHECK-LABEL: func @shuffle_tree_arbitrary_3x4_to_12( // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32> @@ -285,7 +299,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, // ----- -func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, +func.func @shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, %b: vector<5xf32>, %c: vector<5xf32>) -> vector<9xf32> { %0:5 = vector.to_elements %a : vector<5xf32> @@ -297,7 +311,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, // TODO: Implement mask compression to reduce the number of intermediate poison values. -// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9( +// CHECK-LABEL: func @shuffle_tree_arbitrary_3x5_to_9( // CHECK-SAME: %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32> @@ -306,7 +320,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, // ----- -func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, +func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>, %d: vector<2xf32>) -> vector<32xf32> { @@ -321,9 +335,66 @@ func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, return %4 : vector<32xf32> } -// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32( +// CHECK-LABEL: func @shuffle_tree_broadcast_4x2_to_32( // CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32> // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> // CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> // CHECK: return %[[L1SH0]] : vector<32xf32> + +// ----- + + +func.func @shuffle_tree_arbitrary_mixed_sizes( + %a : vector<2xf32>, + %b : vector<1xf32>, + %c : vector<3xf32>, + %d : vector<1xf32>, + %e : vector<5xf32>) -> vector<6xf32> { + %0:2 = vector.to_elements %a : vector<2xf32> + %1 = vector.to_elements %b : vector<1xf32> + %2:3 = vector.to_elements %c : vector<3xf32> + %3 = vector.to_elements %d : vector<1xf32> + %4:5 = vector.to_elements %e : vector<5xf32> + %5 = vector.from_elements %0#0, %2#0, %3, %4#0, %1, %4#3 : vector<6xf32> + return %5 : vector<6xf32> +} + +// TODO: Support mixed vector sizes. + +// CHECK-LABEL: func @shuffle_tree_arbitrary_mixed_sizes( +// CHECK-COUNT-5: vector.to_elements +// CHECK: vector.from_elements + +// ----- + +func.func @shuffle_tree_odd_intermediate_vectors( + %a : vector<2xf32>, + %b : vector<2xf32>, + %c : vector<2xf32>, + %d : vector<2xf32>, + %e : vector<2xf32>, + %f : vector<2xf32>) -> vector<6xf32> { + %0:2 = vector.to_elements %a : vector<2xf32> + %1:2 = vector.to_elements %b : vector<2xf32> + %2:2 = vector.to_elements %c : vector<2xf32> + %3:2 = vector.to_elements %d : vector<2xf32> + %4:2 = vector.to_elements %e : vector<2xf32> + %5:2 = vector.to_elements %f : vector<2xf32> + %6 = vector.from_elements %0#0, %1#1, %2#0, %3#1, %4#0, %5#1 : vector<6xf32> + return %6 : vector<6xf32> +} + +// CHECK-LABEL: func @shuffle_tree_odd_intermediate_vectors( +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>, %[[E:.*]]: vector<2xf32>, %[[F:.*]]: vector<2xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 3] : vector<2xf32>, vector<2xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 3] : vector<2xf32>, vector<2xf32> +// CHECK: %[[L0SH2:.*]] = vector.shuffle %[[E]], %[[F]] [0, 3] : vector<2xf32>, vector<2xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3] : vector<2xf32>, vector<2xf32> +// CHECK: %[[L2SH0:.*]] = vector.shuffle %[[L0SH2]], %[[L0SH2]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32> +// CHECK: %[[L3SH0:.*]] = vector.shuffle %[[L1SH0]], %[[L2SH0]] [0, 1, 2, 3, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK: return %[[L3SH0]] : vector<6xf32> + + + + From 5f54d910998cd002ef49c5a0a473433b1293d3bd Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 2 Jul 2025 19:28:55 +0000 Subject: [PATCH 3/4] Doc++, enable mixed vectors, misc. improvements --- ...LowerVectorToFromElementsToShuffleTree.cpp | 165 +++++++++--------- ...m-elements-to-shuffle-tree-transforms.mlir | 33 +++- 2 files changed, 113 insertions(+), 85 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 504103529cdcb..766f89254a191 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -57,7 +57,7 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); /// /// 1. Vectors are shuffled in pairs by order of appearance in /// the `vector.from_elements` operand list. -/// 2. Each input vector to each level is used only once. +/// 2. Each vector at each level is used only once. /// 3. The number of levels in the tree is: /// ceil(log2(# `vector.to_elements` ops)). /// 4. Vectors at each level of the tree have the same vector length. @@ -147,13 +147,13 @@ class VectorShuffleTreeBuilder { // Shuffle tree configuration. unsigned numLevels; SmallVector vectorSizePerLevel; - /// Holds the range of positions in the final output that each vector input - /// in the tree is contributing to. - SmallVector> inputIntervalsPerLevel; + /// Holds the range of positions each vector in the tree contributes to the + /// final output vector. + SmallVector> intervalsPerLevel; // Utility methods to compute the shuffle tree configuration. - void computeInputVectorIntervals(); - void computeOutputVectorSizePerLevel(); + void computeShuffleTreeIntervals(); + void computeShuffleTreeVectorSizes(); /// Dump the shuffle tree configuration. void dump(); @@ -176,13 +176,13 @@ static void duplicateLastIfOdd(SmallVectorImpl &values) { values.push_back(values.back()); } -// ===--------------------------------------------------------------------===// +// ===---------------------------------------------------------------------===// // Shuffle Tree Analysis Utilities. -// ===--------------------------------------------------------------------===// +// ===---------------------------------------------------------------------===// -/// Compute the intervals for all the input vectors in the shuffle tree. The -/// interval of an input vector is the range of positions in the final output -/// that the input vector contributes to. +/// Compute the intervals for all the vectors in the shuffle tree. The interval +/// of a vector is the range of positions that vector contributes to the final +/// output vector. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// @@ -192,20 +192,20 @@ static void duplicateLastIfOdd(SmallVectorImpl &values) { /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> /// -/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the -/// number of inputs even) so we compute the interval for each input vector: +/// Level 0 has 4 vectors (%2, %1, %0, %0, the last one is duplicated to make +/// the number of inputs even) so we compute the interval for each vector: /// -/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] -/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] -/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] -/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] +/// * intervalsPerLevel[0][0] = interval(%2) = [0,6] +/// * intervalsPerLevel[0][1] = interval(%1) = [1,7] +/// * intervalsPerLevel[0][2] = interval(%0) = [2,8] +/// * intervalsPerLevel[0][3] = interval(%0) = [2,8] /// -/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so -/// we compute the intervals for each input vector to level 1 as: -/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7] -/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8] +/// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0 +/// so we compute the intervals for each vector at level 1 as: +/// * intervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7] +/// * intervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8] /// -void VectorShuffleTreeBuilder::computeInputVectorIntervals() { +void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { // Map `vector.to_elements` ops to their ordinal position in the // `vector.from_elements` operand list. Make sure duplicated // `vector.to_elements` ops are mapped to the its first occurrence. @@ -213,7 +213,7 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs)) toElemsToInputOrdinal.insert({toElemsOp, idx}); - // Compute intervals for each input vector in the shuffle tree. The first + // Compute intervals for each vector in the shuffle tree. The first // level computation is special-cased to keep the implementation simpler. SmallVector firstLevelIntervals(toElemsDefs.size(), @@ -235,13 +235,13 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { duplicateLastIfOdd(toElemsDefs); duplicateLastIfOdd(firstLevelIntervals); - inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals)); + intervalsPerLevel.push_back(std::move(firstLevelIntervals)); // Compute intervals for the remaining levels. unsigned outputNumElements = cast(fromElemsOp.getResult().getType()).getNumElements(); for (unsigned level = 1; level < numLevels; ++level) { - const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1]; + const auto &prevLevelIntervals = intervalsPerLevel[level - 1]; SmallVector currentLevelIntervals( llvm::divideCeil(prevLevelIntervals.size(), 2), {kMaxUnsigned, kMaxUnsigned}); @@ -253,8 +253,8 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; // The interval of a vector at the current level is the union of the - // intervals of the two input vectors from the previous level being - // shuffled at this level. + // intervals of the two vectors from the previous level being shuffled at + // this level. interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); interval.second = std::min(std::max(prevLhsInterval.second, prevRhsInterval.second), @@ -262,14 +262,14 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { } duplicateLastIfOdd(currentLevelIntervals); - inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals)); + intervalsPerLevel.push_back(std::move(currentLevelIntervals)); } } -/// Compute the uniform output vector size for each level of the shuffle tree, -/// given the intervals of the input vectors at that level. The output vector -/// size of a level is the size of the widest interval resulting from shuffling -/// each pair of input vectors. +/// Compute the uniform vector size for each level of the shuffle tree, given +/// the intervals of the vectors at that level. The vector size of a level is +/// the size of the widest interval resulting from shuffling each pair of +/// vectors. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// @@ -278,15 +278,16 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() { /// * Level 1: [0,7], [2,8] /// /// Vector sizes: -/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8, +/// * Level 0: Arbitrary sizes from input vectors. +/// * Level 1: max(size_of([0,6] U [1,7] = [0,7]) = 8, /// size_of([2,8] U [2,8] = [2,8]) = 7) = 8 /// -/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9 +/// * Level 2: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9 /// -void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() { +void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() { // Compute vector size for each level. - for (unsigned level = 0; level < numLevels; ++level) { - const auto ¤tLevelIntervals = inputIntervalsPerLevel[level]; + for (unsigned level = 1; level < numLevels; ++level) { + const auto ¤tLevelIntervals = intervalsPerLevel[level]; unsigned currentVectorSize = 1; for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) { const auto &lhsInterval = currentLevelIntervals[i]; @@ -338,7 +339,7 @@ void VectorShuffleTreeBuilder::dump() { /// Compute the shuffle tree configuration for the given `vector.to_elements` + /// `vector.from_elements` input sequence. This method builds a balanced binary -/// shuffle tree that combines pairs of input vectors at each level. +/// shuffle tree that combines pairs of vectors at each level. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// @@ -356,32 +357,32 @@ void VectorShuffleTreeBuilder::dump() { /// \ / /// %2_1_0_0 =vector.shuffle /// -/// The configuration comprises of computing the intervals of the input vectors -/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and -/// %2_1_0_0) and the output vector size for each level. For further details on -/// intervals and output vector size computation, please, take a look at the -/// corresponding utility functions. +/// The actual representation of the shuffle tree configuration is based on +/// intervals of each vector at each level of the shuffle tree (i.e., %2, %1, +/// %0, %0, %2_1, %0_0 and %2_1_0_0) and the output vector size for each level. +/// For further details on intervals and output vector size computation, please, +/// take a look at the corresponding utility functions. LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { // Initialize shuffle tree information based on its size. - numLevels = std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size())); + numLevels = 1 + llvm::Log2_64_Ceil(toElemsDefs.size()); vectorSizePerLevel.resize(numLevels, 0); - inputIntervalsPerLevel.reserve(numLevels); + intervalsPerLevel.reserve(numLevels); - computeInputVectorIntervals(); - computeOutputVectorSizePerLevel(); + computeShuffleTreeIntervals(); + computeShuffleTreeVectorSizes(); dump(); return success(); } -// ===--------------------------------------------------------------------===// +// ===---------------------------------------------------------------------===// // Shuffle Tree Code Generation Utilities. -// ===--------------------------------------------------------------------===// +// ===---------------------------------------------------------------------===// /// Compute the permutation mask for shuffling two input `vector.to_elements` -/// ops. The permutation mask is the mapping of the input vector elements to -/// their final position in the output vector, relative to the intermediate -/// output vector of the `vector.shuffle` operation combining the two inputs. +/// ops. The permutation mask is the mapping of the vector elements to their +/// final position in the output vector, relative to the intermediate output +/// vector of the `vector.shuffle` operation combining the two inputs. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// @@ -421,8 +422,7 @@ static SmallVector computePermutationShuffleMask( // The mask index is the ordinal position of the operand in // `vector.from_elements` operand list. We make this position relative to - // the interval of the output vector resulting from combining the two - // input vectors. + // the output interval resulting from combining the two input intervals. if (currentToElemOp == toElementOp0) { maskIdx -= interval0.first; } else { @@ -488,8 +488,7 @@ static SmallVector computePropagationShuffleMask( SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); // Propagate any element from the input mask that is not poison. For the RHS - // input vector, the mask index is offset by the offset between the two - // intervals of the input vectors. + // vector, offset mask index by the distance between the intervals. for (unsigned i = 0; i < inputVectorSize; ++i) { if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) mask[i] = i; @@ -549,9 +548,9 @@ static SmallVector computePropagationShuffleMask( /// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14] /// : vector<8xf32>, vector<8xf32> /// -/// The code generation comprises of combining pairs of input vectors for each -/// level of the tree, using the pre-computed per tree level intervals and -/// vector sizes. The algorithm generates two kinds of shuffle masks: +/// The code generation consists of combining pairs of vectors at each level of +/// the tree, using the pre-computed tree intervals and vector sizes. The +/// algorithm generates two kinds of shuffle masks: permutation masks and /// permutation masks and propagation masks. Permutation masks are computed for /// the first level of the tree and permute the input vector elements to their /// relative position in the final output. Propagation masks are computed for @@ -567,29 +566,32 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { SmallVector levelInputs; llvm::transform(toElemsDefs, std::back_inserter(levelInputs), [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); }); - // TODO: Check that every pair of input has the same vector size. Otherwise, - // promote the narrower one to the wider one. - // Build shuffle tree by combining pairs of vectors. + // Build shuffle tree by combining pairs of vectors (represented by their + // corresponding intervals) in one level and producing a new vector with the + // next level's vector length. Skip the interval from the last tree level + // (actual shuffle tree output) as it doesn't have to be combined with + // anything else. Location loc = fromElemsOp.getLoc(); unsigned currentLevel = 0; - for (const auto &[levelVectorSize, inputIntervals] : - llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) { + for (const auto &[nextLevelVectorSize, intervals] : + llvm::zip_equal(ArrayRef(vectorSizePerLevel).drop_front(), + ArrayRef(intervalsPerLevel).drop_back())) { duplicateLastIfOdd(levelInputs); - LLVM_DEBUG(llvm::dbgs() - << llvm::indent(1, kIndScale) << "* Processing level " - << currentLevel << " (vector size: " << levelVectorSize - << ", # inputs: " << levelInputs.size() << ")\n"); + LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale) + << "* Processing level " << currentLevel + << " (output vector size: " << nextLevelVectorSize + << ", # inputs: " << levelInputs.size() << ")\n"); // Process level input vectors in pairs. SmallVector levelOutputs; for (size_t i = 0; i < levelInputs.size(); i += 2) { Value lhsVector = levelInputs[i]; Value rhsVector = levelInputs[i + 1]; - const Interval &lhsInterval = inputIntervals[i]; - const Interval &rhsInterval = inputIntervals[i + 1]; + const Interval &lhsInterval = intervals[i]; + const Interval &rhsInterval = intervals[i + 1]; // For the first level of the tree, permute the vector elements to their // relative position in the final output. For subsequent levels, we @@ -598,13 +600,13 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { if (currentLevel == 0) { shuffleMask = computePermutationShuffleMask( toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval, - fromElemsOp, levelVectorSize); + fromElemsOp, nextLevelVectorSize); } else { auto lhsShuffleOp = cast(lhsVector.getDefiningOp()); auto rhsShuffleOp = cast(rhsVector.getDefiningOp()); shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval, rhsShuffleOp, rhsInterval, - levelVectorSize); + nextLevelVectorSize); } Value shuffleVal = rewriter.create( @@ -640,7 +642,8 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, } /// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into -/// a tree of `vector.shuffle` operations. +/// a tree of `vector.shuffle` operations. Only 1-D input vectors are supported +/// for now. struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern { @@ -651,22 +654,24 @@ struct ToFromElementsToShuffleTreeRewrite final VectorType resultType = fromElemsOp.getType(); if (resultType.getRank() != 1) return rewriter.notifyMatchFailure( - fromElemsOp, "Multi-dimensional vectors are not supported yet"); + fromElemsOp, + "multi-dimensional output vectors are not supported yet"); if (resultType.isScalable()) return rewriter.notifyMatchFailure( fromElemsOp, "'vector.from_elements' does not support scalable vectors"); + // Gather all the `vector.to_elements` operations that feed the + // `vector.from_elements` operation. Other op definitions are not supported. SmallVector toElemsDefs; if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs))) return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources"); - int64_t numElements = - toElemsDefs.front().getSource().getType().getNumElements(); - for (ToElementsOp toElemsOp : toElemsDefs) { - if (toElemsOp.getSource().getType().getNumElements() != numElements) - return rewriter.notifyMatchFailure( - fromElemsOp, "unsupported sources with different vector sizes"); + if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) { + return toElemsOp.getSource().getType().getRank() != 1; + })) { + return rewriter.notifyMatchFailure( + fromElemsOp, "multi-dimensional input vectors are not supported yet"); } if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) { diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir index a8d3d5278e893..593d28fcc4178 100644 --- a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir @@ -18,6 +18,26 @@ func.func @trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> { // ----- +func.func @unsupported_multi_dim_vector_inputs(%a: vector<2x4xf32>, %b: vector<2x4xf32>) -> vector<4xf32> { + %0:8 = vector.to_elements %a : vector<2x4xf32> + %1:8 = vector.to_elements %b : vector<2x4xf32> + %2 = vector.from_elements %0#0, %0#7, + %1#0, %1#7 : vector<4xf32> + return %2 : vector<4xf32> +} + +// ----- + +func.func @unsupported_multi_dim_vector_output(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<2x2xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1:8 = vector.to_elements %b : vector<8xf32> + %2 = vector.from_elements %0#0, %0#7, + %1#0, %1#7 : vector<2x2xf32> + return %2 : vector<2x2xf32> +} + +// ----- + func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { %0:8 = vector.to_elements %a : vector<8xf32> %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> @@ -344,7 +364,6 @@ func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, // ----- - func.func @shuffle_tree_arbitrary_mixed_sizes( %a : vector<2xf32>, %b : vector<1xf32>, @@ -360,11 +379,15 @@ func.func @shuffle_tree_arbitrary_mixed_sizes( return %5 : vector<6xf32> } -// TODO: Support mixed vector sizes. - // CHECK-LABEL: func @shuffle_tree_arbitrary_mixed_sizes( -// CHECK-COUNT-5: vector.to_elements -// CHECK: vector.from_elements +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<1xf32>, %[[C:.*]]: vector<3xf32>, %[[D:.*]]: vector<1xf32>, %[[E:.*]]: vector<5xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[C]] [0, 2, -1, -1] : vector<2xf32>, vector<3xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[D]], %[[E]] [0, 1, -1, 4] : vector<1xf32>, vector<5xf32> +// CHECK: %[[L0SH2:.*]] = vector.shuffle %[[B]], %[[B]] [0, -1, -1, -1] : vector<1xf32>, vector<1xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 4, 5, -1, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L2SH0:.*]] = vector.shuffle %[[L0SH2]], %[[L0SH2]] [0, -1, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L3SH0:.*]] = vector.shuffle %[[L1SH0]], %[[L2SH0]] [0, 1, 2, 3, 6, 5] : vector<6xf32>, vector<6xf32> +// CHECK: return %[[L3SH0]] : vector<6xf32> // ----- From bf1c747db73c176a7e2e2b4a9a581c21591393e0 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 8 Jul 2025 00:11:25 +0000 Subject: [PATCH 4/4] Simplify algorithm, improve doc, feedback --- ...LowerVectorToFromElementsToShuffleTree.cpp | 131 ++++++++++-------- 1 file changed, 77 insertions(+), 54 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 766f89254a191..050bce2a927a7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -59,7 +59,7 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); /// the `vector.from_elements` operand list. /// 2. Each vector at each level is used only once. /// 3. The number of levels in the tree is: -/// ceil(log2(# `vector.to_elements` ops)). +/// 1 (input vectors) + ceil(max(1,log2(# `vector.to_elements` ops))). /// 4. Vectors at each level of the tree have the same vector length. /// 5. Vector positions that do not need to be shuffled are represented with /// poison in the shuffle mask. @@ -83,15 +83,15 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); /// : vector<8xf32>, vector<8xf32> /// /// Comments: -/// * The shuffle tree has two levels: +/// * The shuffle tree has three levels: +/// - Level 0 = (%a, %b, %c, %c) /// - Level 1 = (%shuffle0, %shuffle1) /// - Level 2 = (%result) /// * `%a` and `%b` are shuffled first because they appear first in the /// `vector.from_elements` operand list (`%0#0` and `%1#0`). /// * `%c` is shuffled with itself because the number of /// `vector.from_elements` operands is odd. -/// * The vector length for the first and second levels are 8 and 16, -/// respectively. +/// * The vector length for level 1 and level 2 are 8 and 16, respectively. /// * `%shuffle1` uses poison values to match the vector length of its /// tree level (8). /// @@ -117,8 +117,7 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); /// `vector.from_elements` operand list (`%2#2` and `%1#1`). /// * `%a` is shuffled with itself because the number of /// `vector.from_elements` operands is odd. -/// * The vector length for the first and second levels are 8 and 9, -/// respectively. +/// * The vector length for level 1 and level 2 are 8 and 9, respectively. /// * `%shuffle0` uses poison values to mark unused vector positions and /// match the vector length of its tree level (8). /// @@ -181,8 +180,8 @@ static void duplicateLastIfOdd(SmallVectorImpl &values) { // ===---------------------------------------------------------------------===// /// Compute the intervals for all the vectors in the shuffle tree. The interval -/// of a vector is the range of positions that vector contributes to the final -/// output vector. +/// interval of a vector is the range of positions that the vector contributes +/// to the final output vector. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// @@ -192,8 +191,9 @@ static void duplicateLastIfOdd(SmallVectorImpl &values) { /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> /// -/// Level 0 has 4 vectors (%2, %1, %0, %0, the last one is duplicated to make -/// the number of inputs even) so we compute the interval for each vector: +/// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the +/// last one is duplicated to make the number of inputs even) so we compute the +/// interval for each vector: /// /// * intervalsPerLevel[0][0] = interval(%2) = [0,6] /// * intervalsPerLevel[0][1] = interval(%1) = [1,7] @@ -202,8 +202,15 @@ static void duplicateLastIfOdd(SmallVectorImpl &values) { /// /// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0 /// so we compute the intervals for each vector at level 1 as: -/// * intervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7] -/// * intervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8] +/// * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U +/// intervalsPerLevel[0][1] = [0,7] +/// * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U +/// intervalsPerLevel[0][3] = [2,8] +/// +/// Level 2 is the last level and only contains the output vector so the +/// interval should be the whole output vector: +/// * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U +/// intervalsPerLevel[1][1] = [0,8] /// void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { // Map `vector.to_elements` ops to their ordinal position in the @@ -241,13 +248,14 @@ void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { unsigned outputNumElements = cast(fromElemsOp.getResult().getType()).getNumElements(); for (unsigned level = 1; level < numLevels; ++level) { + bool isLastLevel = level == numLevels - 1; const auto &prevLevelIntervals = intervalsPerLevel[level - 1]; SmallVector currentLevelIntervals( llvm::divideCeil(prevLevelIntervals.size(), 2), {kMaxUnsigned, kMaxUnsigned}); - for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size(); - ++inputIdx) { + size_t currentNumLevels = currentLevelIntervals.size(); + for (size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) { auto &interval = currentLevelIntervals[inputIdx]; const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; @@ -255,48 +263,57 @@ void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { // The interval of a vector at the current level is the union of the // intervals of the two vectors from the previous level being shuffled at // this level. - interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); + interval.first = prevLhsInterval.first; interval.second = - std::min(std::max(prevLhsInterval.second, prevRhsInterval.second), - outputNumElements - 1); + std::max(prevLhsInterval.second, prevRhsInterval.second); } - duplicateLastIfOdd(currentLevelIntervals); + // Duplicate the last interval if the number of intervals is odd, except for + // the last level as it only contains the output vector, which doesn't have + // to be shuffled. + if (!isLastLevel) + duplicateLastIfOdd(currentLevelIntervals); + intervalsPerLevel.push_back(std::move(currentLevelIntervals)); } } /// Compute the uniform vector size for each level of the shuffle tree, given -/// the intervals of the vectors at that level. The vector size of a level is -/// the size of the widest interval resulting from shuffling each pair of -/// vectors. +/// the intervals of the vectors at each level. The vector size of a level is +/// the size of the widest interval at that level. /// /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: /// /// Intervals: /// * Level 0: [0,6], [1,7], [2,8], [2,8] /// * Level 1: [0,7], [2,8] +/// * Level 2: [0,8] /// /// Vector sizes: /// * Level 0: Arbitrary sizes from input vectors. -/// * Level 1: max(size_of([0,6] U [1,7] = [0,7]) = 8, -/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8 -/// -/// * Level 2: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9 +/// * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8 +/// * Level 2: max(size_of([0,8]) = 9) = 9 /// void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() { - // Compute vector size for each level. - for (unsigned level = 1; level < numLevels; ++level) { + // Compute vector size for each level. There are two direct cases: + // * First level: the vector size depends on the actual size of the input + // vectors and it's allowed to be non-uniform. We set it to 0. + // * Last level: the vector size is the output vector size so it doesn't + // have to be computed using intervals. + vectorSizePerLevel.front() = 0; + vectorSizePerLevel.back() = + cast(fromElemsOp.getResult().getType()).getNumElements(); + + for (unsigned level = 1; level < numLevels - 1; ++level) { const auto ¤tLevelIntervals = intervalsPerLevel[level]; unsigned currentVectorSize = 1; - for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) { - const auto &lhsInterval = currentLevelIntervals[i]; - const auto &rhsInterval = currentLevelIntervals[i + 1]; - unsigned combinedIntervalSize = - std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first + - 1; - currentVectorSize = std::max(currentVectorSize, combinedIntervalSize); + size_t numIntervals = currentLevelIntervals.size(); + for (size_t i = 0; i < numIntervals; ++i) { + const auto &interval = currentLevelIntervals[i]; + unsigned intervalSize = interval.second - interval.first + 1; + currentVectorSize = std::max(currentVectorSize, intervalSize); } + assert(currentVectorSize > 0 && "vector size must be positive"); vectorSizePerLevel[level] = currentVectorSize; } } @@ -317,14 +334,13 @@ void VectorShuffleTreeBuilder::dump() { llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Total levels: " << numLevels << "\n"; llvm::dbgs() << llvm::indent(indLv, kIndScale) - << "* Vector sizes per level: ["; + << "* Vector sizes per level: "; llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs()); - llvm::dbgs() << "]\n"; + llvm::dbgs() << "\n"; llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Input intervals per level:\n"; ++indLv; - for (const auto &[level, intervals] : - llvm::enumerate(inputIntervalsPerLevel)) { + for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) { llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level << ": "; llvm::interleaveComma(intervals, llvm::dbgs(), @@ -359,12 +375,14 @@ void VectorShuffleTreeBuilder::dump() { /// /// The actual representation of the shuffle tree configuration is based on /// intervals of each vector at each level of the shuffle tree (i.e., %2, %1, -/// %0, %0, %2_1, %0_0 and %2_1_0_0) and the output vector size for each level. -/// For further details on intervals and output vector size computation, please, -/// take a look at the corresponding utility functions. +/// %0, %0, %2_1, %0_0 and %2_1_0_0) and the vector size for each level. For +/// further details on intervals and vector size computation, please, take a +/// look at the corresponding utility functions. LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { - // Initialize shuffle tree information based on its size. - numLevels = 1 + llvm::Log2_64_Ceil(toElemsDefs.size()); + // Initialize shuffle tree information based on its size. For the number of + // levels, we add one to account for the input `vector.to_elements` as one + // tree level. We need the std::max(1) to account for a single element input. + numLevels = 1u + std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size())); vectorSizePerLevel.resize(numLevels, 0); intervalsPerLevel.reserve(numLevels); @@ -394,7 +412,7 @@ LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { /// /// => /// -/// // Level 0, vector length = 8 +/// // Level 1, vector length = 8 /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] /// @@ -463,13 +481,13 @@ static SmallVector computePermutationShuffleMask( /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> /// -/// // Level 0, vector length = 8 +/// // Level 1, vector length = 8 /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] /// /// => /// -/// // Level 1, vector length = 9 +/// // Level 2, vector length = 9 /// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] /// /// TODO: Implement mask compression to reduce the number of intermediate poison @@ -534,10 +552,11 @@ static SmallVector computePropagationShuffleMask( /// /// with the pre-computed shuffle tree configuration: /// -/// * Vector sizes per level: [8, 9] +/// * Vector sizes per level: 0, 8, 9 /// * Input intervals per level: /// * Level 0: [0,6], [1,7], [2,8], [2,8] /// * Level 1: [0,7], [2,8] +/// * Level 2: [0,8] /// /// => /// @@ -551,12 +570,15 @@ static SmallVector computePropagationShuffleMask( /// The code generation consists of combining pairs of vectors at each level of /// the tree, using the pre-computed tree intervals and vector sizes. The /// algorithm generates two kinds of shuffle masks: permutation masks and -/// permutation masks and propagation masks. Permutation masks are computed for -/// the first level of the tree and permute the input vector elements to their -/// relative position in the final output. Propagation masks are computed for -/// subsequent levels and propagate the elements to the next level without -/// permutation. For further details on the shuffle mask computation, please, -/// take a look at the corresponding `computePermutationShuffleMask` and +/// permutation masks and propagation masks: +/// * Permutation masks are computed for the first level of the tree and +/// permute the input vector elements to their relative position in the +/// final output. +/// * Propagation masks are computed for subsequent levels and propagate the +/// elements to the next level without permutation. +/// +/// For further details on the shuffle mask computation, please, take a look at +/// the corresponding `computePermutationShuffleMask` and /// `computePropagationShuffleMask` functions. /// Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { @@ -587,7 +609,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { // Process level input vectors in pairs. SmallVector levelOutputs; - for (size_t i = 0; i < levelInputs.size(); i += 2) { + for (size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs; + i += 2) { Value lhsVector = levelInputs[i]; Value rhsVector = levelInputs[i + 1]; const Interval &lhsInterval = intervals[i];