diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h index 7b53594a1c8e2..aefa50947f758 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage { /// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32> // TODO: consider moving this to ODS. class MMAMatrixType - : public Type::TypeBase { + : public Type::TypeBase { public: using Base::Base; @@ -163,6 +164,9 @@ class MMAMatrixType /// Get elementType of a single element. Type getElementType() const; + /// Implementation for MemRefElementTypeInterface. + unsigned getAnalysisSizeInBytes() const; + /// The general form of operation this type supports is given by the equation /// C += A*B. This function returns which operand in the given equation is /// held by this type. String returned can be one of"AOp", "BOp" and "COp". diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 857e68cec8c76..3f8bb0c6ea90a 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ return $_get(memorySpace.getContext(), memorySpace); }]> ]; + let extraClassDeclaration = [{ + /// Best effort size for analysis purposes. + unsigned getAnalysisSizeInBytes() { return 8; } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 8aa2c55570153..001d0d9f3e756 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { For example, scalar values such as integers can implement this interface, but indicator types such as `void` or `unit` should not. - The interface currently has no methods and is used by types to opt into - being memref elements. This may change in the future, in particular to - require types to provide their size or alignment given a data layout. + The interface currently has one method and is mainly used by types to opt + into being memref elements. This may change in the future, in particular to + require types to provide actual size or alignment given a data layout. }]; + + let methods = [ + InterfaceMethod<[{ + Returns the size of the element type in bytes for purposes such as + analysis. Such a size is meant to be used in analysis costs models as a + best effort in the absence of data layout, as opposed to for + target-specific lowering which would require a data layout. + }], + "unsigned", "getAnalysisSizeInBytes">, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 86aba7b187535..312eaedaa13c3 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) { vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); else return std::nullopt; + } else if (auto memrefEltType = dyn_cast( + memRefType.getElementType())) { + sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8; } else { return std::nullopt; } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 976432ea37120..04b8c901b50da 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) { elementType.isInteger(32); } +unsigned MMAMatrixType::getAnalysisSizeInBytes() const { + // The underlying element type is expected to always be int or float and + // typically divisible by 8 bits. + return ShapedType::getNumElements(getShape()) * + llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8); +} + LogicalResult MMAMatrixType::verifyInvariants(function_ref emitError, ArrayRef shape, Type elementType, diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index 4b9eca45492fb..e948f8ad74bc9 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -666,3 +666,31 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) { // PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}] return } + +// Test for fusion of affine load/store on memrefs of MMA type. + +// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast +func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) { + affine.for %i = 0 to 8 { + affine.for %j = 0 to 4 { + %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3> + affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3> + } + } + + affine.for %i = 0 to 8 { + affine.for %j = 0 to 4 { + %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3> + affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3> + } + } + // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 8 { + // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 4 { + // PRODUCER-CONSUMER-NEXT: affine.load + // PRODUCER-CONSUMER-NEXT: affine.store + // PRODUCER-CONSUMER-NEXT: affine.load + // PRODUCER-CONSUMER-NEXT: affine.store + + return + // PRODUCER-CONSUMER: return +} diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index f1c31658c13ac..c3aac18917ba7 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [ def TestMemRefElementType : Test_Type<"TestMemRefElementType", [MemRefElementTypeInterface]> { let mnemonic = "memref_element"; + + let extraClassDeclaration = [{ + unsigned getAnalysisSizeInBytes() const { return 1; } + }]; } def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;