-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Relax input type requirement on tensor.splat
#145893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
`tensor.splat` is currently restricted to only accepting input values that are of integer, index or float type. This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using `AnyType` for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type.
@llvm/pr-subscribers-mlir-tensor Author: Markus Böck (zero9178) Changes
This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using Full diff: https://github.com/llvm/llvm-project/pull/145893.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 47962f75558ea..7d396e5c64c28 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
]> {
let summary = "tensor splat or broadcast operation";
let description = [{
- Broadcast the operand to all elements of the result tensor. The operand is
- required to be of integer/index/float type.
+ Broadcast the operand to all elements of the result tensor.
An additional argument of type `index` must be provided for each dynamic
dimension present in the result type.
@@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input,
+ let arguments = (ins AnyType:$input,
Variadic<Index>:$dynamicSizes);
let results = (outs AnyRankedTensor:$aggregate);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c0adc8a49bf70..e202a6b3f3e7a 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.splat_other(
+// CHECK-SAME: %[[F:.*]]: !llvm.ptr)
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!llvm.ptr>
+// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[MAPPED:.*]] = linalg.map
+// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!llvm.ptr>)
+// CHECK: linalg.yield %[[F]]
+// CHECK: return %[[MAPPED]] : tensor<10x2x4x!llvm.ptr>
+func.func @tensor.splat_other(%f: !llvm.ptr) -> tensor<10x2x4x!llvm.ptr> {
+ %t = tensor.splat %f : tensor<10x2x4x!llvm.ptr>
+ return %t : tensor<10x2x4x!llvm.ptr>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f35d52e700084..665657a67dc61 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
// -----
-func.func @invalid_splat(%v : vector<8xf32>) {
- // expected-error@+1 {{must be integer/index/float type}}
- %w = tensor.splat %v : tensor<8xvector<8xf32>>
+// expected-note@+1 {{prior use here}}
+func.func @invalid_splat(%v : f32) {
+ // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
+ %w = tensor.splat %v : tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 930986211cb6d..0fd4b87508a79 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -314,12 +314,16 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[P:%arg[0-9]+]]: !llvm.ptr
+func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
// CHECK: tensor.splat [[S]] : tensor<8xf32>
%v = tensor.splat %s : tensor<8xf32>
// CHECK: tensor.splat [[S]] : tensor<4xf32>
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+
+ // CHECK: tensor.splat [[P]] : tensor<8x!llvm.ptr>
+ %w = tensor.splat %p : tensor<8x!llvm.ptr>
return
}
|
@llvm/pr-subscribers-mlir Author: Markus Böck (zero9178) Changes
This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using Full diff: https://github.com/llvm/llvm-project/pull/145893.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 47962f75558ea..7d396e5c64c28 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
]> {
let summary = "tensor splat or broadcast operation";
let description = [{
- Broadcast the operand to all elements of the result tensor. The operand is
- required to be of integer/index/float type.
+ Broadcast the operand to all elements of the result tensor.
An additional argument of type `index` must be provided for each dynamic
dimension present in the result type.
@@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input,
+ let arguments = (ins AnyType:$input,
Variadic<Index>:$dynamicSizes);
let results = (outs AnyRankedTensor:$aggregate);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c0adc8a49bf70..e202a6b3f3e7a 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.splat_other(
+// CHECK-SAME: %[[F:.*]]: !llvm.ptr)
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!llvm.ptr>
+// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[MAPPED:.*]] = linalg.map
+// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!llvm.ptr>)
+// CHECK: linalg.yield %[[F]]
+// CHECK: return %[[MAPPED]] : tensor<10x2x4x!llvm.ptr>
+func.func @tensor.splat_other(%f: !llvm.ptr) -> tensor<10x2x4x!llvm.ptr> {
+ %t = tensor.splat %f : tensor<10x2x4x!llvm.ptr>
+ return %t : tensor<10x2x4x!llvm.ptr>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f35d52e700084..665657a67dc61 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
// -----
-func.func @invalid_splat(%v : vector<8xf32>) {
- // expected-error@+1 {{must be integer/index/float type}}
- %w = tensor.splat %v : tensor<8xvector<8xf32>>
+// expected-note@+1 {{prior use here}}
+func.func @invalid_splat(%v : f32) {
+ // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
+ %w = tensor.splat %v : tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 930986211cb6d..0fd4b87508a79 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -314,12 +314,16 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[P:%arg[0-9]+]]: !llvm.ptr
+func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
// CHECK: tensor.splat [[S]] : tensor<8xf32>
%v = tensor.splat %s : tensor<8xf32>
// CHECK: tensor.splat [[S]] : tensor<4xf32>
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+
+ // CHECK: tensor.splat [[P]] : tensor<8x!llvm.ptr>
+ %w = tensor.splat %p : tensor<8x!llvm.ptr>
return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks!
tensor.splat
is currently restricted to only accepting input values that are of integer, index or float type.This is much more restrictive than the tensor type itself as well as any lowerings of it.
This PR therefore removes this restriction by using
AnyType
for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type.