diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 84c1dc1373ee5..bcd5724835783 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -43,8 +43,8 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td let parameters = (ins OptionalParameter<"MemorySpaceAttr">: $memory_space, - OptionalParameter<"IntegerAttr", "1">: $array_length, - OptionalParameter<"BoolAttr", "true">: $boundary_check + OptionalParameter<"IntegerAttr">: $array_length, + OptionalParameter<"BoolAttr">: $boundary_check ); let builders = [ @@ -67,8 +67,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat TensorDesc is located, `Global` device memory or `Shared` local memory. It is default to `Global`. - 2. `chunk_size`: indicates number of contiguous elements accessed for each - offset, default is 1. It is used with `scattered` attr only. + 2. `chunk_size`: Specifies the number of contiguous elements accessed per offset. + The default value is 1. While XeGPU supports a range of chunk sizes, hardware + may only allow specific values (e.g., 1, 2, 3, 4, 8, 16, 32, 64, 128, 256). + Therefore, XeGPU will legalize the chunk size as needed prior to lowering to + hardware instructions. }]; let parameters = (ins @@ -91,7 +94,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat )> ]; - let genVerifyDecl = 1; + let extraClassDeclaration = [{ + int64_t getChunkSizeAsInt() { + return getChunkSize().getInt(); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index daab65ec893b8..b6f047d132c87 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -757,6 +757,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", let assemblyFormat = [{ $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets) }]; + + let hasVerifier = 1; } def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84314875c2ae5..bd30335ddc344 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>; def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>; -def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; -def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; -def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>; -def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>; -def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>; -def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>; +def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; +def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; +def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>; +def XeGPU_MaskType: 
FixedVectorOfNonZeroRankOf<[I1]>; +def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>; +def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>; // common base class for types in XeGPU dialect class XeGPUTypeDef traits = [], @@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", ]; let extraClassDeclaration = [{ - using TensorType::clone; using mlir::ShapedType::Trait::getElementTypeBitWidth; using mlir::ShapedType::Trait::getRank; using mlir::ShapedType::Trait::getNumElements; @@ -184,10 +183,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", int getChunkSize() { auto attr = getEncoding(); auto scatter_attr = mlir::dyn_cast_if_present(attr); - assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr."); - if (scatter_attr) - return scatter_attr.getChunkSize().getInt(); - return 1; + assert(scatter_attr && "invalid on non ScatterTensorDescAttr."); + return scatter_attr.getChunkSizeAsInt(); } /// Helper to drop all layout information from the TensorDesc type. diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 649e0d453015f..32a4bf883829f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -125,18 +125,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context, return Base::get(context, scopeAttr, chunkSizeAttr); } -LogicalResult ScatterTensorDescAttr::verify( - llvm::function_ref emitError, - MemorySpaceAttr memory_space, IntegerAttr chunk_size) { - int64_t chunkSize = chunk_size.getInt(); - SmallVector supportedChunkSizes = {1, 2, 3, 4, 8, - 16, 32, 64, 128, 256}; - if (!llvm::is_contained(supportedChunkSizes, chunkSize)) - return emitError() << "invalid chunk size"; - - return success(); -} - //===----------------------------------------------------------------------===// // XeGPU_LayoutAttr //===----------------------------------------------------------------------===// @@ -310,15 +298,16 @@ LogicalResult TensorDescType::verify( llvm::ArrayRef shape, mlir::Type elementType, mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); - if (rank != 1 && rank != 2) - return emitError() << "expected 1D or 2D tensor"; + + if (rank == 0) + return emitError() << "expected non-zero rank tensor"; auto blockAttr = mlir::dyn_cast_if_present(encoding); if (blockAttr) { MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace(); - if (rank == 2 && memorySpaceAttr && + if (rank > 1 && memorySpaceAttr && memorySpaceAttr.getValue() == MemorySpace::SLM) - return emitError() << "SLM is not supported for 2D block tensor"; + return emitError() << "SLM is only supported for 1D block tensor"; } // for gather and scatter ops, Low-precision types are packed in 32-bit units. 
@@ -329,22 +318,18 @@ LogicalResult TensorDescType::verify(
                 : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    // Expected tensor ranks for scattered data:
-    // - 1D tensor for fully non-contiguous elements (chunk size == 1)
-    // - 2D tensor for scattered blocks (chunk size > 1)
-    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
     if (rank == 1 && chunkSize != 1)
       return emitError() << "expected non-contiguous elements for 1D tensor";
-    if (rank == 2 && chunkSize < 2)
-      return emitError() << "expected chunk blocks for 2D tensor";
+
     // If chunk size > 1, the second dimension of the tensor shape must be
-    // equal to chunk size and it must be a multiple of the packing factor.
+    // equal to chunk size and it must be a multiple of the
+    // chunkAlignmentFactor.
     if (chunkSize > 1) {
       if (shape.back() != chunkSize)
-        return emitError() << "expected tensor shape[1] to match chunk size";
+        return emitError() << "expected last dim of tensor to match chunk size";
       if (shape.back() % chunkAlignmentFactor != 0)
-        return emitError() << "expected tensor shape[1] to be a multiple of "
-                              "chunk alignment factor "
+        return emitError() << "expected last dim of tensor to be a multiple of "
                            << chunkAlignmentFactor;
     }
   }
@@ -357,17 +342,13 @@ LogicalResult TensorDescType::verify(
     auto laneData = layoutAttr.getLaneData();
     if (scatterAttr && laneData) {
       // Validate subgroup mapping rules for scattered tensors.
-      // A work-item's slice of the tensor with shape [sg_size] or
-      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
-      // respectively, the mapping should reflect that. This is because each
-      // work item access data in 32 bit granularity.
-
-      if (rank > 1 && laneData[0] != 1)
+      // If chunkSize > 1, the last dimension of the tensor should be
+      // distributed in units divisible by chunkAlignmentFactor.
+ int64_t chunkSize = scatterAttr.getChunkSizeAsInt(); + if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor) return emitError() - << "cannot map over non-contiguous scattered row elements"; - if (laneData[rank - 1] != chunkAlignmentFactor) - return emitError() << "work item data mapping must match the number of " - "contiguous elements"; + << "expected last dim of lane_data to be a multiple of: " + << chunkAlignmentFactor; } if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 2793c7a35bc97..caef13b59f5c8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -87,9 +87,12 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, return emitError() << "Value should have the same element type as TensorDesc."; - if (tdescShape[0] != maskShape[0]) + llvm::SmallVector expectedMaskShape(tdescShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) return emitError() - << "dim-0 of the Mask and TensorDesc should be the same."; + << "Mask should match TensorDesc except the chunk size dim."; // a valid shape for SIMT case if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { @@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() { "is a memref) should match with each other."); // check result TensorDesc rank - invalidRank = (getType().getRank() > 2 || getType().getRank() > rank); - - if (invalidRank) + if (getType().getRank() > rank) return emitOpError( - "Expecting the TensorDesc rank is up to 2 and not greater than the " + "Expecting the TensorDesc rank is not greater than the " "ranks of shape, strides, offsets or the memref source."); if (invalidElemTy) @@ -247,9 +248,6 @@ LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); - if (tdescTy.getRank() > 2) - return emitOpError("Expecting a 1D/2D TensorDesc.\n"); - if (tdescTy.isScattered()) return emitOpError("Expects a non-scattered TensorDesc.\n"); @@ -316,15 +314,13 @@ LogicalResult LoadNdOp::verify() { } auto array_len = tdescTy.getArrayLength(); - if (array_len > 1) { + if (array_len > 1) tdescShape.insert(tdescShape.begin(), array_len); - } - if (tdescShape != valueShape) { + if (tdescShape != valueShape) return emitOpError() << "Result shape " << makeString(valueShape) << " is not consistent with tensor descriptor " << tdescTy; - } return success(); } @@ -336,9 +332,6 @@ LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector - if (dstTy.getRank() > 2) - return emitOpError("Expecting a 1D/2D TensorDesc.\n"); - if (dstTy.isScattered()) return emitOpError("Expects a non-scattered TensorDesc.\n"); @@ -370,22 +363,21 @@ LogicalResult StoreNdOp::verify() { return emitOpError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; - if (tdescElems % valueElems) { + if (tdescElems % valueElems) return emitOpError() << "Value shape " << makeString(getShapeOf(valTy)) << " is not a valid distribution for tensor descriptor " << dstTy; - } + return success(); } // SIMD code should have the same shape as the tensor descriptor. 
auto tdescShape = getShapeOf(dstTy); auto valueShape = getShapeOf(valTy); - if (tdescShape != valueShape) { + if (tdescShape != valueShape) return emitOpError() << "Value shape " << makeString(valueShape) << " is not consistent with tensor descriptor " << dstTy; - } return success(); } @@ -450,24 +442,7 @@ LogicalResult CreateDescOp::verify() { // check total size auto chunkSize = tdescTy.getChunkSize(); - auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth(); - auto bitsPerLane = elemBits * chunkSize; - if (chunkSize > 1 && bitsPerLane % 32) { - // For 8-bit and 16-bit data, the hardware only supports chunk size of 1. - // For 32-bit data, the hardware can support larger larger chunk size. So - // we can bitcast 8-bit/16-bit data to 32-bit data for better performance. - // But this requires the total size is 32 bit aligned to make the - // optimization work. - return emitOpError( - "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned."); - } - - auto lscConstraints = 512 * 8; // each access is upto 512 bytes. - if (elemBits * tdescTy.getNumElements() > lscConstraints) - return emitOpError("total access size (simd_lanes * chunk_size * " - "sizeof(elemTy)) is upto 512 bytes."); - - SmallVector shape({(int64_t)getNumOffsets()}); + SmallVector shape(getOffsetsType().getShape()); if (chunkSize != 1) shape.push_back(chunkSize); @@ -563,6 +538,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, build(builder, state, tensorDesc, ofrs); } +LogicalResult UpdateOffsetOp::verify() { + auto tdescTy = getTensorDescType(); + if (!tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + SmallVector expectedOffsetShape = getShapeOf(tdescTy); + SmallVector offsetShape = getShapeOf(getOffsetsType()); + if (tdescTy.getChunkSize() > 1) + expectedOffsetShape.pop_back(); + + if (expectedOffsetShape != offsetShape) + return emitOpError( + "Offsets should match TensorDesc except the chunk size dim."); + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_DpasOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 3950e8f70d1ca..c6c4e3aaa41ed 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() { // If the encoding is a ScatterTensorDescAttr, we need to // potentially adjust the chunk size based on the inst_data. 
if (tdescTy.isScattered()) { - auto scatterAttr = - llvm::dyn_cast_if_present(encoding); - int64_t chunkSize = scatterAttr.getChunkSize().getInt(); + int64_t chunkSize = tdescTy.getChunkSize(); if (chunkSize > 1) { int64_t blockedChunkSize = chunkSize; @@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() { // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( - ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize); + ctx, tdescTy.getMemorySpace(), blockedChunkSize); encoding = newEncoding; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 2c48a735bf956..66690f9e9a91a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -625,9 +625,6 @@ struct UnrollUpdateOffsetOp : public UnrollPattern { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() > 2) - return failure(); - if (!tdescTy.isScattered()) return failure(); diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index a2778cd94d963..77918c66b82af 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -2,7 +2,7 @@ // ----- func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) { - // expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}} + // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}} %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32> return } @@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) { // ----- func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) { - // expected-error@+1 {{SLM is not supported for 2D block tensor}} + // expected-error@+1 {{SLM is only supported for 1D block tensor}} %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> return } @@ -199,15 +199,6 @@ func.func @create_tdesc_vc_1(%src: ui64) { return } -// ----- -func.func @create_tdesc_vc_2(%src: ui64) { - %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex> - // expected-error@+1 {{expected chunk blocks for 2D tensor}} - -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>> - return -} - // ----- func.func @create_tdesc_vc_3(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -221,25 +212,16 @@ func.func @create_tdesc_vc_3(%src: memref) { func.func @create_tdesc_vc_4(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{invalid chunk size}} - -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr> - return -} - -// ----- -func.func @create_tdesc_vc_5(%src: memref) { - %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{expected tensor shape[1] to match chunk size}} + // expected-error@+1 {{expected last dim of tensor to match chunk size}} -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr> return } // ----- -func.func @create_tdesc_vc_6(%src: memref) { +func.func @create_tdesc_vc_5(%src: memref) { %0 = 
arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> - // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}} + // expected-error@+1 {{last dim of tensor to be a multiple of 2}} -> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr> return } @@ -267,23 +249,15 @@ func.func @prefetch_vc_2(%src: ui64) { func.func @create_tdesc_layout_1(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> // expected-error@+1 {{expected layout rank to match tensor rank}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout> return } // ----- func.func @create_tdesc_layout_2(%src: ui64) { %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> - return -} - -// ----- -func.func @create_tdesc_layout_3(%src: ui64) { - %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} - %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.layout> + // expected-error@+1 {{expected last dim of lane_data to be a multiple of: 2}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr, #xegpu.layout> return } @@ -406,18 +380,10 @@ func.func @atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi return } -// ----- -func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) { - %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{expected 1D or 2D tensor}} - !xegpu.tensor_desc<16x2x2xf32> - return -} - // ----- func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) { %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> - // expected-error@+1 {{expected 1D or 2D tensor}} + // expected-error@+1 {{expected non-zero rank tensor}} !xegpu.tensor_desc return } @@ -470,27 +436,6 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) { return } -// ----- -func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) { - %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> - // expected-error@+1 {{cannot map over non-contiguous scattered row elements}} - !xegpu.tensor_desc<4x2xf32, - #xegpu.scatter_tdesc_attr, - #xegpu.layout> - return -} - -// ----- -func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) { - %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> - // expected-error@+1 {{work item data mapping must match the number of contiguous elements}} - !xegpu.tensor_desc<16xf32, - #xegpu.scatter_tdesc_attr, - #xegpu.layout> - return -} - // ----- func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) { %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> @@ -504,9 +449,9 @@ func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vecto // ----- func.func 
@tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) { %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> -> - // expected-error@+1 {{expected chunk blocks for 2D tensor}} + // expected-error@+1 {{expected last dim of tensor to match chunk size}} !xegpu.tensor_desc<16x2xf32, - #xegpu.scatter_tdesc_attr, + #xegpu.scatter_tdesc_attr, #xegpu.layout> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index aff8f63adc05b..252c6eeaaf6ec 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -54,6 +54,13 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) { gpu.return } +// CHECK: gpu.func @create_nd_tdesc_7(%[[arg0:.*]]: memref<8x24x32x48x64xf32>) { +gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32> + gpu.return +} + // CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @prefetch_nd(%src: memref<24x32xf16>) { @@ -64,6 +71,14 @@ gpu.func @prefetch_nd(%src: memref<24x32xf16>) { gpu.return } +// CHECK: gpu.func @prefetch_nd_2(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) { +gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16> + // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<1x2x4x8x16xf16> + xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<1x2x4x8x16xf16> + gpu.return +} // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) { gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) { @@ -213,6 +228,15 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) { gpu.return } +// CHECK: func @subgroup_load_nd_9(%[[arg0:.*]]: memref<4x8x16xf16>) { +gpu.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16> + gpu.return +} + // CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16> @@ -257,6 +281,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) { gpu.return } +// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<8x24x32xf16>) { +gpu.func @subgroup_store_nd_3(%dst: memref<8x24x32xf16>) { + // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<8x24x32xf16> + %1 = arith.constant dense<1.0>: vector<8x24x32xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0] : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16> + %2 = xegpu.create_nd_tdesc 
%dst[0, 0, 0] : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16> + // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16> + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16> + gpu.return +} + // CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) { // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> @@ -266,6 +301,14 @@ gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) { gpu.return } +// CHECK: gpu.func @update_nd_tdesc_2(%[[arg0:.*]]: memref<8x24x32xf32>) { +gpu.func @update_nd_tdesc_2(%src: memref<8x24x32xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32> + %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32> + // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 0, 16] : !xegpu.tensor_desc<2x8x16xf32> + %2 = xegpu.update_nd_offset %1, [0, 0, 16]: !xegpu.tensor_desc<2x8x16xf32> + gpu.return +} // CHECK: gpu.func @create_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_tdesc(%src: ui64) { @@ -291,8 +334,8 @@ gpu.func @create_tdesc_1(%src: memref) { gpu.func @create_tdesc_2(%src: memref) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<> - %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } @@ -306,6 +349,15 @@ gpu.func @create_tdesc_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @create_tdesc_4(%[[arg0:.*]]: ui64) { +gpu.func @create_tdesc_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr> + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr> + gpu.return +} + // CHECK: gpu.func @subgroup_load(%[[arg0:.*]]: ui64) { gpu.func @subgroup_load(%src: ui64) { @@ -385,6 +437,19 @@ gpu.func @simt_load_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_4(%[[arg0:.*]]: ui64) { +gpu.func @subgroup_load_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<2x4xi1> + %1 = arith.constant dense<1>: vector<2x4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr> + 
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr>, vector<2x4xi1> -> vector<2x4x8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr>, vector<2x4xi1> -> vector<2x4x8xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -475,6 +540,21 @@ gpu.func @simt_store_3(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_4(%[[arg0:.*]]: ui64) { +gpu.func @subgroup_store_4(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex> + %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<2x4xi1> + %1 = arith.constant dense<1>: vector<2x4xi1> + //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf32> + %2 = arith.constant dense<2.9>: vector<2x4xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir index ac5fe89a67f9a..7da336272555e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir @@ -358,8 +358,8 @@ gpu.module @test_kernel { // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> + // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.func @test_prefetch_load_store_update(%src: ui64) { @@ -406,8 +406,8 @@ gpu.module @test_kernel { // CHECK-SAME: [[arg0:%.+]]: ui64 // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - // 
CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> - // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> + // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xindex> + // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> // CHECK-COUNT-4: xegpu.store {{.*}} : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) { @@ -446,4 +446,115 @@ gpu.module @test_kernel { } } +// ----- +#l = #xegpu.layout +gpu.module @test_kernel { + // CHECK-LABEL: test_3d_block_tensor_desc + // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024x1024xf16>, [[arg1:%.+]]: memref<1024x1024x1024xf16>, [[arg2:%.+]]: memref<1024x1024x1024xf16> + gpu.func @test_3d_block_tensor_desc(%A: memref<1024x1024x1024xf16>, %B: memref<1024x1024x1024xf16>, %C: memref<1024x1024x1024xf16>) { + //CHECK: [[c24:%.*]] = arith.constant 24 : index + //CHECK: [[c8:%.*]] = arith.constant 8 : index + //CHECK: [[c16:%.*]] = arith.constant 16 : index + //CHECK: [[c0:%.*]] = arith.constant 0 : index + //CHECK: [[c32:%.*]] = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + //CHECK: [[block_id_x:%.*]] = gpu.block_id x + //CHECK: [[m:%.*]] = arith.muli [[block_id_x]], [[c32]] : index + %block_id_x = gpu.block_id x + %m = arith.muli %block_id_x, %c32 : index + + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off1:%.*]] = arith.addi [[m]], [[c8]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off1]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off2:%.*]] = arith.addi [[m]], [[c8]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off2]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off3:%.*]] = arith.addi [[m]], [[c16]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off3]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off4:%.*]] = arith.addi [[m]], [[c16]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off4]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off5:%.*]] = arith.addi [[m]], [[c24]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off5]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + //CHECK: [[off6:%.*]] = arith.addi [[m]], [[c24]] : index + //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off6]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16> + + %a_tdesc = xegpu.create_nd_tdesc %A[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l> + %b_tdesc = xegpu.create_nd_tdesc %B[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l> + %c_tdesc = xegpu.create_nd_tdesc %C[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l> + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) 
+ -> (!xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>) { + //CHECK-COUNT-16: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x32x16xf16> -> vector<8x32x16xf16> + %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16> + %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16> + + //CHECK-COUNT-8: arith.addf {{.*}} : vector<8x32x16xf16> + %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32x32x32xf16> + + //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x32x16xf16>, !xegpu.tensor_desc<8x32x16xf16> + xegpu.store_nd %c, %arg2: vector<32x32x32xf16>, !xegpu.tensor_desc<32x32x32xf16, #l> + + //CHECK-COUNT-24: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x32x16xf16> + %a_next_tdesc = xegpu.update_nd_offset %arg0, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l> + %b_next_tdesc = xegpu.update_nd_offset %arg1, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l> + %c_next_tdesc = xegpu.update_nd_offset %arg2, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l> + scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc + : !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l> + } + gpu.return + } +} + +// ----- +#l = #xegpu.layout +gpu.module @test_kernel { + // CHECK-LABEL: test_3d_scattered_tensor_desc + // CHECK-SAME: [[arg0:%.+]]: ui64 + // CHECK: [[cst_1:%.+]] = arith.constant dense<{{.*}}[130, 138, 146, 154, 162, 170, 178, 186], [194, 202, 210, 218, 226, 234, 242, 250]]> : vector<2x8xindex> + // CHECK: [[cst_2:%.+]] = arith.constant dense<{{.*}}[2, 10, 18, 26, 34, 42, 50, 58], [66, 74, 82, 90, 98, 106, 114, 122]]> : vector<2x8xindex> + // CHECK: [[cst_3:%.+]] = arith.constant dense<{{.*}}[0, 8, 16, 24, 32, 40, 48, 56], [64, 72, 80, 88, 96, 104, 112, 120]]> : vector<2x8xindex> + // CHECK: [[cst_4:%.+]] = arith.constant dense<{{.*}}[128, 136, 144, 152, 160, 168, 176, 184], [192, 200, 208, 216, 224, 232, 240, 248]]> : vector<2x8xindex> + // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xindex> + // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xi1> -> vector<2x8x2xf32> + // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x8x2xf32>, !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr>, vector<2x8xi1> + + + gpu.func @test_3d_scattered_tensor_desc(%src: ui64) { + %cst = arith.constant dense<[ + [0, 8, 16, 24, 32, 40, 48, 56], + [64, 72, 80, 88, 96, 104, 112, 120], + [128, 136, 144, 152, 160, 168, 176, 184], + [192, 200, 208, 216, 224, 232, 240, 248] + ]> : vector<4x8xindex> + + %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<4x8xindex> -> !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l> + xegpu.prefetch %tdesc: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l> + + %delta = arith.constant dense<[ + [32, 32, 32, 32, 32, 32, 32, 32], + [32, 32, 32, 32, 32, 32, 32, 64], + [128, 128, 128, 128, 128, 128, 128, 128], + [128, 128, 128, 128, 128, 128, 128, 256] + ]> : vector<4x8xindex> + %new_tdesc = xegpu.update_offset %tdesc, %delta + : !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, vector<4x8xindex> + + %c4 = 
arith.constant 4: index + %mask = vector.create_mask %c4, %c4: vector<4x8xi1> + + %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, vector<4x8xi1> -> vector<4x8x4xf32> + + %st_vec = arith.addf %ld_vec, %ld_vec {layout_result_0 = #l} : vector<4x8x4xf32> + xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: + vector<4x8x4xf32>, + !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr, #l>, + vector<4x8xi1> + gpu.return + } +} diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index c84eb74198544..335f89f1826aa 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -107,10 +107,7 @@ struct TestXeGPUUnrollingPatterns // If the encoding is a ScatterTensorDescAttr, we need to // potentially adjust the chunk size based on the inst_data. if (tdescTy.isScattered()) { - auto scatterAttr = - llvm::dyn_cast_if_present( - encoding); - int64_t chunkSize = scatterAttr.getChunkSize().getInt(); + int64_t chunkSize = tdescTy.getChunkSize(); if (chunkSize > 1) { int64_t blockedChunkSize = chunkSize; @@ -120,8 +117,7 @@ struct TestXeGPUUnrollingPatterns // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( - ctx, scatterAttr.getMemorySpace().getValue(), - blockedChunkSize); + ctx, tdescTy.getMemorySpace(), blockedChunkSize); encoding = newEncoding; }
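
Taken together, the relaxed verifiers accept block descriptors of any non-zero rank and chunked scattered descriptors whose last dimension equals the chunk size. A minimal sketch mirroring the new ops.mlir tests (the values %src, %base, and %offsets are assumed to be defined elsewhere with the types shown; the chunk_size value is illustrative, printed in its i64 form):

  %bd = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32>
  %sd = xegpu.create_tdesc %base, %offsets : ui64, vector<2x4xindex>
        -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>

For the scattered case, TensorDescType::verify now only requires the last dimension to equal chunk_size and to be a multiple of the 32-bit packing factor of the element type (2 for f16). A larger chunk size may still be legalized later: the XeGPUBlocking pass and the unrolling patterns adjust the chunk dimension according to inst_data (e.g. a chunk of 4 blocked to 2 in the xegpu-blocking.mlir test above).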