[mlir][tensor] Relax input type requirement on tensor.splat #145893

Open · wants to merge 3 commits into `main`
Conversation

zero9178
Member

`tensor.splat` is currently restricted to accepting only input values of integer, index, or float type.

This is much more restrictive than the tensor type itself, as well as any lowerings of it.

This PR therefore removes the restriction by using `AnyType` for the input value. Whether a given type is actually valid as a tensor element type is still verified through the required equality between the result tensor's element type and the input type.
@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Markus Böck (zero9178)

Changes

tensor.splat is currently restricted to only accepting input values that are of integer, index or float type.

This is much more restrictive than the tensor type itself as well as any lowerings of it.

This PR therefore removes this restriction by using AnyType for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type.


Full diff: https://github.com/llvm/llvm-project/pull/145893.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+2-4)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+15)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+4-3)
  • (modified) mlir/test/Dialect/Tensor/ops.mlir (+5-1)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 47962f75558ea..7d396e5c64c28 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
   ]> {
   let summary = "tensor splat or broadcast operation";
   let description = [{
-    Broadcast the operand to all elements of the result tensor. The operand is
-    required to be of integer/index/float type.
+    Broadcast the operand to all elements of the result tensor.
 
     An additional argument of type `index` must be provided for each dynamic
     dimension present in the result type.
@@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input,
+  let arguments = (ins AnyType:$input,
                        Variadic<Index>:$dynamicSizes);
   let results = (outs AnyRankedTensor:$aggregate);
 
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c0adc8a49bf70..e202a6b3f3e7a 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
+// CHECK-LABEL:   func @tensor.splat_other(
+// CHECK-SAME:        %[[F:.*]]: !llvm.ptr)
+// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!llvm.ptr>
+// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           %[[MAPPED:.*]] = linalg.map
+// CHECK:                 outs(%[[ALLOC_T]] : tensor<10x2x4x!llvm.ptr>)
+// CHECK:             linalg.yield %[[F]]
+// CHECK:           return %[[MAPPED]] : tensor<10x2x4x!llvm.ptr>
+func.func @tensor.splat_other(%f: !llvm.ptr) -> tensor<10x2x4x!llvm.ptr> {
+  %t = tensor.splat %f : tensor<10x2x4x!llvm.ptr>
+  return %t : tensor<10x2x4x!llvm.ptr>
+}
+
+// -----
+
 // CHECK-LABEL:   func @tensor.concat(
 // CHECK-SAME:        %[[F:.*]]: tensor<8xf32>)
 // CHECK:           %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f35d52e700084..665657a67dc61 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
 
 // -----
 
-func.func @invalid_splat(%v : vector<8xf32>) {
-  // expected-error@+1 {{must be integer/index/float type}}
-  %w = tensor.splat %v : tensor<8xvector<8xf32>>
+// expected-note@+1 {{prior use here}}
+func.func @invalid_splat(%v : f32) {
+  // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
+  %w = tensor.splat %v : tensor<1xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 930986211cb6d..0fd4b87508a79 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -314,12 +314,16 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
 
 // CHECK-LABEL: func @test_splat_op
 // CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
+// CHECK-SAME: [[P:%arg[0-9]+]]: !llvm.ptr
+func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
   // CHECK: tensor.splat [[S]] : tensor<8xf32>
   %v = tensor.splat %s : tensor<8xf32>
 
   // CHECK: tensor.splat [[S]] : tensor<4xf32>
   %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+
+  // CHECK: tensor.splat [[P]] : tensor<8x!llvm.ptr>
+  %w = tensor.splat %p : tensor<8x!llvm.ptr>
   return
 }
 

@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir

@kuhar kuhar requested a review from nicolasvasilache June 26, 2025 13:51
Member

@rengolin rengolin left a comment


Makes sense, thanks!

4 participants