Skip to content

Commit 8a0fe9c

Browse files
Address review
1 parent f2b88a1 commit 8a0fe9c

File tree

7 files changed

+36
-24
lines changed

7 files changed

+36
-24
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,9 @@ def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
257257
let parameters = (ins "uint64_t":$mask);
258258
let assemblyFormat = "`<` params `>`";
259259
let description = [{
260-
Attribute describing how to filter the processing units that a
261-
region is mapped to.
262-
263-
In the first implementation the masking is a bitfield that specifies for
264-
each processing unit whether it is active or not.
265-
266-
In the future, we may want to implement this as a symbol to refer to
267-
dynamically defined values.
268-
269-
Extending op semantics with an operand is deemed too intrusive at this time.
260+
Attribute describing how to filter the processing units that a region is
261+
mapped to. The masking is a bitfield that specifies for each processing
262+
unit whether it is active or not.
270263
}];
271264
}
272265

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
#include "mlir/Dialect/SCF/IR/SCF.h"
1313
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14-
#include "mlir/IR/OpImplementation.h"
1514
#include "mlir/IR/PatternMatch.h"
1615

1716
namespace mlir {
@@ -57,7 +56,7 @@ mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
5756
DiagnosedSilenceableFailure
5857
mapOneForallToThreadsImpl(RewriterBase &rewriter,
5958
std::optional<TransformOpInterface> transformOp,
60-
scf::ForallOp forallOp, ArrayRef<int64_t> blockDims,
59+
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
6160
int64_t warpSize, bool syncAfterDistribute);
6261

6362
/// Search `scf.forall` ops nested under `target` and map each such op to an

mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ struct GpuIdBuilder {
7777
/// used for indexing rewrites as well as 3D sizes for predicate generation.
7878
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
7979
/// used for indexing rewrites as well as 1D sizes for predicate generation.
80+
/// If `mask` is provided, it will be used to filter the active blocks.
8081
struct GpuBlockIdBuilder : public GpuIdBuilder {
8182
GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
8283
DeviceMaskingAttrInterface mask = nullptr);
@@ -87,6 +88,7 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
8788
/// used for indexing rewrites as well as 3D sizes for predicate generation.
8889
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
8990
/// used for indexing rewrites as well as 1D sizes for predicate generation.
91+
/// If `mask` is provided, it will be used to filter the active warpgroups.
9092
struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
9193
GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
9294
bool useLinearMapping = false,
@@ -101,6 +103,7 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
101103
/// used for indexing rewrites as well as 3D sizes for predicate generation.
102104
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
103105
/// used for indexing rewrites as well as 1D sizes for predicate generation.
106+
/// If `mask` is provided, it will be used to filter the active warps.
104107
struct GpuWarpIdBuilder : public GpuIdBuilder {
105108
GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
106109
bool useLinearMapping = false,
@@ -113,6 +116,7 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
113116
/// used for indexing rewrites as well as 3D sizes for predicate generation.
114117
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
115118
/// used for indexing rewrites as well as 1D sizes for predicate generation.
119+
/// If `mask` is provided, it will be used to filter the active threads.
116120
struct GpuThreadIdBuilder : public GpuIdBuilder {
117121
GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
118122
DeviceMaskingAttrInterface mask = nullptr);
@@ -122,6 +126,7 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
122126
/// The `idBuilder` method returns nD values used for indexing rewrites as well
123127
/// as 1D sizes for predicate generation.
124128
/// This `useLinearMapping` case is the only supported case.
129+
/// If `mask` is provided, it will be used to filter the active lanes.
125130
struct GpuLaneIdBuilder : public GpuIdBuilder {
126131
GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
127132
DeviceMaskingAttrInterface mask = nullptr);

mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,35 @@ def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
6666
Attribute interface describing how to filter the processing units that a
6767
region is mapped to.
6868

69-
A popcount can be applied to determine the logical linear index that a
70-
physical processing unit is responsible for.
69+
For instance, consider the following example mask which specifies processing
70+
units 2, 4 and 5 are active:
71+
```
72+
8 4 0
73+
mask : 0 0 0 1 1 0 1 0 0
74+
```
75+
The logical ID for an active processing unit is defined as its position
76+
relative to the other active processing units. In this example, we have:
77+
```
78+
Processing Unit LogicalID
79+
0 N/A
80+
1 N/A
81+
2 0
82+
3 N/A
83+
4 1
84+
5 2
85+
6 N/A
86+
7 N/A
87+
```
7188
}];
7289

7390
let methods = [
7491
InterfaceMethod<
7592
/*desc=*/[{
76-
Return the logical active id for a given physical id.
93+
Create the logical active id for a given physical id.
7794
Expects a physicalLinearMappingId of I64Type.
7895
}],
7996
/*retTy=*/"Value",
80-
/*methodName=*/"getLogicalLinearMappingId",
97+
/*methodName=*/"createLogicalLinearMappingId",
8198
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
8299
>,
83100
InterfaceMethod<
@@ -87,7 +104,7 @@ def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
87104
Expects a physicalLinearMappingId of I64Type.
88105
}],
89106
/*retTy=*/"Value",
90-
/*methodName=*/"getIsActiveIdPredicate",
107+
/*methodName=*/"createIsActiveIdPredicate",
91108
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
92109
>,
93110
InterfaceMethod<

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
133133
/// Example filter: 0 0 0 0 1 1 1 1 1
134134
/// Intersection : 0 0 0 0 1 0 1 0 0
135135
/// PopCnt : 2
136-
Value GPUMappingMaskAttr::getLogicalLinearMappingId(
136+
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
137137
OpBuilder &b, Value physicalLinearMappingId) const {
138138
Location loc = physicalLinearMappingId.getLoc();
139139
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
@@ -154,7 +154,7 @@ Value GPUMappingMaskAttr::getLogicalLinearMappingId(
154154
/// Example filter: 0 0 0 1 0 0 0 0 0
155155
/// Intersection : 0 0 0 1 0 0 0 0 0
156156
/// Cmp : 1
157-
Value GPUMappingMaskAttr::getIsActiveIdPredicate(
157+
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
158158
OpBuilder &b, Value physicalLinearMappingId) const {
159159
Location loc = physicalLinearMappingId.getLoc();
160160
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1414
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1615
#include "mlir/Dialect/Arith/IR/Arith.h"
17-
#include "mlir/Dialect/Func/IR/FuncOps.h"
1816
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1917
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
2018
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -758,7 +756,7 @@ static DiagnosedSilenceableFailure
758756
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
759757
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
760758
int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
761-
auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
759+
DeviceMappingAttrInterface mappingAttr = forallOp.getDeviceMappingAttrs().front();
762760
bool useLinearMapping = mappingAttr.isLinearMapping();
763761

764762
// Sanity checks that may result in runtime verification errors.

mlir/lib/Dialect/GPU/TransformOps/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1,
160160
scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>(
161161
loc, rewriter.getI64Type(), scaledLinearId);
162162
Value logicalLinearIdI64 =
163-
mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64);
163+
mask.createLogicalLinearMappingId(rewriter, scaledLinearIdI64);
164164
scaledLinearId = rewriter.create<arith::IndexCastUIOp>(
165165
loc, rewriter.getIndexType(), logicalLinearIdI64);
166166
LDBG("------adjusting linearId with mask: " << scaledLinearId);
@@ -184,7 +184,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1,
184184
// 4. If mask present, it takes precedence to determine predication.
185185
if (mask) {
186186
Value isActiveIdPredicate =
187-
mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64);
187+
mask.createIsActiveIdPredicate(rewriter, scaledLinearIdI64);
188188
LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
189189
predicateOps.push_back(isActiveIdPredicate);
190190
} else {

0 commit comments

Comments
 (0)