Skip to content

[mlir][xegpu] Relax rank restriction of TensorDescType #145916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

chencha3
Copy link
Contributor

@chencha3 chencha3 commented Jun 26, 2025

This PR removes the rank restriction of TensorDescType, such that XeGPU can accept n-D tensor descriptor. Here is the summary of major changes:

  1. removed rank checks around TensorDescType, and size validations against hardware capabilities, which assumed an XeGPU operation is a 1:1 match to a hardware instruction.
  2. improved verifiers accordingly.
  3. added 3D unit tests
  4. added 3D tests for blocking pass.

@chencha3 chencha3 marked this pull request as ready for review June 26, 2025 20:20
@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Chao Chen (chencha3)

Changes

Patch is 41.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145916.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+12-5)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+8-11)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+17-36)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+30-38)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+2-4)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (-3)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+12-67)
  • (modified) mlir/test/Dialect/XeGPU/ops.mlir (+82-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+87-4)
  • (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+2-6)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84c1dc1373ee5..bcd5724835783 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -43,8 +43,8 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
 
   let parameters = (ins
     OptionalParameter<"MemorySpaceAttr">: $memory_space,
-    OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check
+    OptionalParameter<"IntegerAttr">: $array_length,
+    OptionalParameter<"BoolAttr">: $boundary_check
   );
 
   let builders = [
@@ -67,8 +67,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
         TensorDesc is located, `Global` device memory or `Shared` local memory.
         It is default to `Global`.
 
-    2.  `chunk_size`: indicates number of contiguous elements accessed for each
-        offset, default is 1. It is used with `scattered` attr only.
+    2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
+      The default value is 1. While XeGPU supports a range of chunk sizes, hardware
+      may only allow specific values (e.g., 1, 2, 3, 4, 8, 16, 32, 64, 128, 256).
+      Therefore, XeGPU will legalize the chunk size as needed prior to lowering to
+      hardware instructions.
   }];
 
   let parameters = (ins
@@ -91,7 +94,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
     )>
   ];
 
-  let genVerifyDecl = 1;
+  let extraClassDeclaration = [{
+    int64_t getChunkSizeAsInt() {
+      return getChunkSize().getInt();
+    }
+  }];
  }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index daab65ec893b8..b6f047d132c87 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,6 +757,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
   let assemblyFormat = [{
     $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
   }];
+
+  let hasVerifier = 1;
 }
 
 def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84314875c2ae5..bd30335ddc344 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
+def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
+def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
+def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
 
 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -184,10 +183,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     int getChunkSize() {
       auto attr = getEncoding();
       auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
-      assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
-      if (scatter_attr)
-        return scatter_attr.getChunkSize().getInt();
-      return 1;
+      assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
+      return scatter_attr.getChunkSizeAsInt();
     }
 
     /// Helper to drop all layout information from the TensorDesc type.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 649e0d453015f..32a4bf883829f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -125,18 +125,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
-LogicalResult ScatterTensorDescAttr::verify(
-    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
-  int64_t chunkSize = chunk_size.getInt();
-  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
-                                              16, 32, 64, 128, 256};
-  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
-    return emitError() << "invalid chunk size";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_LayoutAttr
 //===----------------------------------------------------------------------===//
@@ -310,15 +298,16 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
-  if (rank != 1 && rank != 2)
-    return emitError() << "expected 1D or 2D tensor";
+
+  if (rank == 0)
+    return emitError() << "expected non-zero rank tensor";
 
   auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
   if (blockAttr) {
     MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-    if (rank == 2 && memorySpaceAttr &&
+    if (rank > 1 && memorySpaceAttr &&
         memorySpaceAttr.getValue() == MemorySpace::SLM)
-      return emitError() << "SLM is not supported for 2D block tensor";
+      return emitError() << "SLM is only supported for 1D block tensor";
   }
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
@@ -329,22 +318,18 @@ LogicalResult TensorDescType::verify(
           : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    // Expected tensor ranks for scattered data:
-    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
-    //   - 2D tensor for scattered blocks (chunk size > 1)
-    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
     if (rank == 1 && chunkSize != 1)
       return emitError() << "expected non-contiguous elements for 1D tensor";
-    if (rank == 2 && chunkSize < 2)
-      return emitError() << "expected chunk blocks for 2D tensor";
+
     // If chunk size > 1, the second dimension of the tensor shape must be
-    // equal to chunk size and it must be a multiple of the packing factor.
+    // equal to chunk size and it must be a multiple of the
+    // chunkAlignmentFactor.
     if (chunkSize > 1) {
       if (shape.back() != chunkSize)
-        return emitError() << "expected tensor shape[1] to match chunk size";
+        return emitError() << "expected last dim of tensor to match chunk size";
       if (shape.back() % chunkAlignmentFactor != 0)
-        return emitError() << "expected tensor shape[1] to be a multiple of "
-                              "chunk alignment factor "
+        return emitError() << "expected last dim of tensor to be a multiple of "
                            << chunkAlignmentFactor;
     }
   }
@@ -357,17 +342,13 @@ LogicalResult TensorDescType::verify(
     auto laneData = layoutAttr.getLaneData();
     if (scatterAttr && laneData) {
       // Validate subgroup mapping rules for scattered tensors.
-      // A work-item's slice of the tensor with shape [sg_size] or
-      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
-      // respectively, the mapping should reflect that. This is because each
-      // work item access data in 32 bit granularity.
-
-      if (rank > 1 && laneData[0] != 1)
+      // if chunkSize > 1, the last dimension of the tensor should
+      // be distributed in the units divisible by chunkAlignmentFactor.
+      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
+      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
         return emitError()
-               << "cannot map over non-contiguous scattered row elements";
-      if (laneData[rank - 1] != chunkAlignmentFactor)
-        return emitError() << "work item data mapping must match the number of "
-                              "contiguous elements";
+               << "expected last dim of lane_data to be a multiple of: "
+               << chunkAlignmentFactor;
     }
 
     if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..3f6f596449429 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -87,9 +87,12 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
     return emitError()
            << "Value should have the same element type as TensorDesc.";
 
-  if (tdescShape[0] != maskShape[0])
+  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
     return emitError()
-           << "dim-0 of the Mask and TensorDesc should be the same.";
+           << "Mask should match TensorDesc except the chunk size dim.";
 
   // a valid shape for SIMT case
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
@@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() {
         "is a memref) should match with each other.");
 
   // check result TensorDesc rank
-  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
-  if (invalidRank)
+  if (getType().getRank() > rank)
     return emitOpError(
-        "Expecting the TensorDesc rank is up to 2 and not greater than the "
+        "Expecting the TensorDesc rank is not greater than the "
         "ranks of shape, strides, offsets or the memref source.");
 
   if (invalidElemTy)
@@ -247,9 +248,6 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
 
-  if (tdescTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
@@ -316,15 +314,13 @@ LogicalResult LoadNdOp::verify() {
   }
 
   auto array_len = tdescTy.getArrayLength();
-  if (array_len > 1) {
+  if (array_len > 1)
     tdescShape.insert(tdescShape.begin(), array_len);
-  }
 
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
-  }
 
   return success();
 }
@@ -336,9 +332,6 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
 
-  if (dstTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (dstTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
@@ -370,22 +363,21 @@ LogicalResult StoreNdOp::verify() {
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    if (tdescElems % valueElems) {
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
-    }
+
     return success();
   }
 
   // SIMD code should have the same shape as the tensor descriptor.
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << dstTy;
-  }
 
   return success();
 }
@@ -450,24 +442,7 @@ LogicalResult CreateDescOp::verify() {
 
   // check total size
   auto chunkSize = tdescTy.getChunkSize();
-  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
-  auto bitsPerLane = elemBits * chunkSize;
-  if (chunkSize > 1 && bitsPerLane % 32) {
-    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
-    // For 32-bit data, the hardware can support larger larger chunk size. So
-    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
-    // But this requires the total size is 32 bit aligned to make the
-    // optimization work.
-    return emitOpError(
-        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
-  }
-
-  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
-  if (elemBits * tdescTy.getNumElements() > lscConstraints)
-    return emitOpError("total access size (simd_lanes * chunk_size * "
-                       "sizeof(elemTy)) is upto 512 bytes.");
-
-  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
+  SmallVector<int64_t> shape(getOffsetsType().getShape());
   if (chunkSize != 1)
     shape.push_back(chunkSize);
 
@@ -563,6 +538,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tensorDesc, ofrs);
 }
 
+LogicalResult UpdateOffsetOp::verify() {
+  auto tdescTy = getTensorDescType();
+  if (!tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  auto expectedOffsetShape = getShapeOf(tdescTy);
+  auto offsetShape = getShapeOf(getOffsetsType());
+  if (tdescTy.getChunkSize() > 1)
+    expectedOffsetShape.pop_back();
+
+  if (expectedOffsetShape != offsetShape)
+    return emitOpError(
+        "Offsets should match TensorDesc except the chunk size dim.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3950e8f70d1ca..c6c4e3aaa41ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
       if (tdescTy.isScattered()) {
-        auto scatterAttr =
-            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
-        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+        int64_t chunkSize = tdescTy.getChunkSize();
 
         if (chunkSize > 1) {
           int64_t blockedChunkSize = chunkSize;
@@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+              ctx, tdescTy.getMemorySpace(), blockedChunkSize);
 
           encoding = newEncoding;
         }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 2c48a735bf956..66690f9e9a91a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -625,9 +625,6 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (tdescTy.getRank() > 2)
-      return failure();
-
     if (!tdescTy.isScattered())
       return failure();
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a2778cd94d963..a6f7d0992d7e7 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 // -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
-  // expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
+func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+  // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
   %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }
@@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
 
 // -----
 func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
-  // expected-error@+1 {{SLM is not supported for 2D block tensor}}
+  // expected-error@+1 {{SLM is only supported for 1D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }
@@ -199,15 +199,6 @@ func.func @create_tdesc_vc_1(%src: ui64) {
   return
 }
 
-// -----
-func.func @create_tdesc_vc_2(%src: ui64) {
-  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
-  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
-  // expected-error@+1 {{expected chunk blocks for 2D tensor}}
-          -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
-  return
-}
-
 // -----
 func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -221,25 +212,16 @@ func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
 func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
-  // expected-error@+1 {{invalid chunk size}}
-          -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
-  return
-}
-
-// -----
-func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
-  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
-  // expected-error@+1 {{expected tensor shape[1] to match chunk size}}
+  // expected-error@+1 {{expected last dim of tensor to match chunk size}}
           -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
   return
 }
 
 // -----
-func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
+func.func @create_tdesc_vc_5(%src: memref<?xf16>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
-  // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}}
+  // expected-error@+1 {{last dim of tensor to be a multiple of 2}}
           -> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
   return
 }
@@ -267,23 +249,15 @@ func.func @prefetch_vc_2(%src: ui64) {
 func.func @create_tdesc_layout_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   // expected-error@+1 {{expected layout rank to match tensor rank}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #...
[truncated]

@@ -184,10 +183,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
int getChunkSize() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe no need to expose this method to TensorDesc. You can always get it using TensorDesc->ScatterAttr->getChunk?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to keep it, since chunk size is logically part of the tensor_desc. It now provides a little convenience, also hides the implementation details in case of changes in future.

// if chunkSize > 1, the last dimension of the tensor should
// be distributed in the units divisible by chunkAlignmentFactor.
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we allow lane layout also be nD?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the layout rank matches the tensor/vector rank.

Comment on lines 546 to 547
auto expectedOffsetShape = getShapeOf(tdescTy);
auto offsetShape = getShapeOf(getOffsetsType());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don’t “almost always” use auto, but do use auto with initializers like cast(...) or other places where the type is already obvious from the context.

nit: make the types explicit here for readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment on lines -453 to -470
auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
auto bitsPerLane = elemBits * chunkSize;
if (chunkSize > 1 && bitsPerLane % 32) {
// For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
// For 32-bit data, the hardware can support larger larger chunk size. So
// we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
// But this requires the total size is 32 bit aligned to make the
// optimization work.
return emitOpError(
"access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
}

auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
if (elemBits * tdescTy.getNumElements() > lscConstraints)
return emitOpError("total access size (simd_lanes * chunk_size * "
"sizeof(elemTy)) is upto 512 bytes.");

SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are these verified now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are totally gone now. I suppose this will be checked in XeVM. They were appropriate when XeGPU was designed to match hardware abstraction. But now XeGPU is promoted to workgroup level.

auto scatterAttr =
llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
int64_t chunkSize = tdescTy.getChunkSize();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: safer to not directly expose getChunkSize to tensor_desc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind explaining the reason a little bit?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean now anyone with a tensor_desc can call this method. It not immediately clear from the API that it requires a scatter encoding. so anywhere you call this method you need to guard it by if (tensorDesc.hasScattert()), if not it as an unsafe call (assert will be removed in release build). So I don't see any direct benefit of exposing this to tensorDesc.

@@ -1,8 +1,8 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// -----
func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
// expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: drop test prefix in tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
// CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens with transpose attribute for nD case?

Copy link
Contributor Author

@chencha3 chencha3 Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are currently untouched. From semantic perspective, they are expressible. From hardware capability perspective, a pass is needed to legalize it for n-D case.

%block_id_x = gpu.block_id x
%m = arith.muli %block_id_x, %c32 : index

%a_tdesc = xegpu.create_nd_tdesc %A[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add check statement here also.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


gpu.func @test_3d_scattered_tensor_desc(%src: ui64) {

%cst = arith.constant dense<[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to show what happens to these offsets during blocking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added.

@charithaintc
Copy link
Contributor

shouldn't we add some verification for the transpose case (and transpose bit width) for loadNd?

//CHECK: [[block_id_x:%.*]] = gpu.block_id x
//CHECK: [[m:%.*]] = arith.muli [[block_id_x]], [[c32]] : index
%block_id_x = gpu.block_id x
%m = arith.muli %block_id_x, %c32 : index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this is block_id, not subgroup_id?

//CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
//CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
//CHECK: [[off1:%.*]] = arith.addi [[m]], [[c8]] : index
//CHECK: xegpu.create_nd_tdesc [[arg0]][[[off1]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How this tensor_desc[8, 32, 16] can be further lowered to 2d block loader?
I think that the inst_data should be [1, 32, 16], then the blocking will unroll it to 8 meaningful 2d tensor_desc.
xegpu.create_nd_tdesc [[arg0]][[[off1]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants