Skip to content

Commit 8dfd754

Browse files
[fixup] Rename a member function and chanege some allocs to allocas
1 parent 5b7e081 commit 8dfd754

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class VectorContractRewriter {
201201
}
202202

203203
public:
204-
void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
204+
void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
205205
// Create some convenience types.
206206
auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
207207
auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
@@ -460,7 +460,7 @@ class LowerContractionToNeonI8MMPattern
460460
VectorContractRewriterI8MM vcr;
461461
if (failed(vcr.matchAndInit(op, rewriter)))
462462
return failure();
463-
vcr.rewrite(op, rewriter);
463+
vcr.lower(op, rewriter);
464464

465465
return success();
466466
}
@@ -476,7 +476,7 @@ class LowerContractionToNeonBFMMLAPattern
476476
VectorContractRewriterBFMMLA vcr;
477477
if (failed(vcr.matchAndInit(op, rewriter)))
478478
return failure();
479-
vcr.rewrite(op, rewriter);
479+
vcr.lower(op, rewriter);
480480

481481
return success();
482482
}

mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
5858
[ 0.5, -1.3, -2.2, 0.1],
5959
[-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
6060

61-
%acc_mem = memref.alloc() : memref<4x4xf32>
61+
%acc_mem = memref.alloca() : memref<4x4xf32>
6262
vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
6363
%acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
6464

@@ -68,7 +68,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
6868
[-0.4, 0.6, 0.8, -0.5],
6969
[-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
7070

71-
%lhs_mem = memref.alloc() : memref<4x4xbf16>
71+
%lhs_mem = memref.alloca() : memref<4x4xbf16>
7272
vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
7373
%lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
7474

@@ -78,7 +78,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
7878
[-0.2, 0.4, 1.0, 0.4],
7979
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
8080

81-
%rhs_mem = memref.alloc() : memref<4x4xbf16>
81+
%rhs_mem = memref.alloca() : memref<4x4xbf16>
8282
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
8383
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
8484

@@ -121,14 +121,14 @@ func.func @vector_by_matrix_mul_and_acc() {
121121
// Accumulator test data
122122
%acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
123123

124-
%acc_mem = memref.alloc() : memref<4xf32>
124+
%acc_mem = memref.alloca() : memref<4xf32>
125125
vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
126126
%acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
127127

128128
// LHS test data
129129
%lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
130130

131-
%lhs_mem = memref.alloc() : memref<4xbf16>
131+
%lhs_mem = memref.alloca() : memref<4xbf16>
132132
vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
133133
%lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
134134

@@ -138,7 +138,7 @@ func.func @vector_by_matrix_mul_and_acc() {
138138
[-0.2, 0.4, 1.0, 0.4],
139139
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
140140

141-
%rhs_mem = memref.alloc() : memref<4x4xbf16>
141+
%rhs_mem = memref.alloca() : memref<4x4xbf16>
142142
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
143143
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
144144

0 commit comments

Comments
 (0)