Skip to content

[mlir][vector] canonicalize vector.from_elements from ascending extracts #139819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 2, 2025
Merged
121 changes: 121 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite from_elements on multiple scalar extracts as a shape_cast
/// on a single extract. Example:
///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
///   %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
///   %1 = vector.extract %source[0] : vector<2xi8> from vector<2x2xi8>
///   %2 = vector.shape_cast %1 : vector<2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
///
///   i) The elements are extracted from the same vector (%source).
///
///  ii) The elements form a suffix of %source. Specifically, the number
///      of elements is the same as the product of the last N dimension sizes
///      of %source, for some N.
///
/// iii) The elements are extracted contiguously in ascending order.
///
///  iv) Every extract uses a fully static position. Positions that involve
///      SSA index operands cannot be compared through the static position
///      array alone, so such extracts are rejected.

class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Matches a from_elements whose operands are all scalar vector.extract ops
  /// satisfying conditions (i)-(iv) above and replaces it with a single
  /// (possibly folded) vector.extract followed by a vector.shape_cast.
  /// Returns failure, with a diagnostic note, if any condition is violated.
  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {

    // Handled by `rewriteFromElementsAsSplat`
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source that all elements are extracted from, if one exists.
    TypedValue<VectorType> source;
    // The position of the combined extract operation, if one is created.
    ArrayRef<int64_t> combinedPosition;
    // The expected index of extraction of the current element in the loop, if
    // elements are extracted contiguously in ascending order.
    SmallVector<int64_t> expectedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {

      // Check that the element is from a vector.extract operation.
      auto extractOp =
          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }

      // Check condition (iv). The comparisons below only look at the static
      // position, which holds a placeholder for every dynamic index. Without
      // this guard, extracts at [%i, 0] and [%j, 1] would be treated as
      // adjacent, and the combined extract would be built from a static
      // position that still contains the placeholder but has no matching
      // index operand.
      if (extractOp.hasDynamicPosition()) {
        return rewriter.notifyMatchFailure(
            fromElements, "element extracted at dynamic position");
      }

      // Check condition (i) by checking that all elements have the same source
      // as the first element.
      if (insertIndex == 0) {
        source = extractOp.getVector();
      } else if (extractOp.getVector() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // Check condition (ii) by checking that the position that the first
      // element is extracted from has sufficient trailing 0s. For example, in
      //
      //   %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
      //   [...]
      //   %elms = vector.from_elements %elm0, [...] : vector<12xi8>
      //
      // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
      // elements, which is the number of elements of %elms, so this is valid.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        // Leading indices that are not part of the covered suffix become the
        // position of the single combined extract.
        combinedPosition = position.drop_back(rank - index);
      }

      // Check condition (iii).
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      // Advance to the position the next element must be extracted from.
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);

    return success();
  }

  /// Increments n-D `indices` by 1 starting from the innermost dimension.
  /// A dimension that reaches its extent in `shape` wraps to 0 and carries
  /// into the next outer dimension.
  static void increment(MutableArrayRef<int64_t> indices,
                        ArrayRef<int64_t> shape) {
    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
      indices[dim] += 1;
      if (indices[dim] < shape[dim])
        break;
      indices[dim] = 0;
    }
  }
};

/// Registers the canonicalization patterns of vector.from_elements: first the
/// function-based `rewriteFromElementsAsSplat`, then the
/// `FromElementsToShapeCast` pattern class.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Registration order is kept as-is: splat rewrite first, shape_cast second.
  results.add(rewriteFromElementsAsSplat)
      .add<FromElementsToShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
69 changes: 0 additions & 69 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2943,75 +2943,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,

// -----

// Scalar extracts from a vector.from_elements result are expected to fold to
// the scalar operand at the extracted position, leaving only the return.
// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D. Operands are listed in row-major order, so position
// [1, 1] of the 2x3 vector corresponds to the fifth operand (%b).
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// Extracting a 1-D row from a 2-D vector.from_elements: each row here is built
// from a single repeated scalar, so each extract is expected to become a
// vector.splat of that scalar.
// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// Extracting a 2-D slice from a 3-D vector.from_elements: each extract is
// expected to become a smaller vector.from_elements over the four operands
// that make up that slice. The slices mix %a and %b, so no splat is expected.
// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// vector.from_elements whose operands are all the same scalar is expected to
// become vector.splat (including the single-operand 0-D case); a
// from_elements with differing operands is expected to stay as it is.
// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}

// -----

// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert
Expand Down
Loading
Loading