Skip to content

Commit 1e183f9

Browse files
committed
[NVPTX] add more comments to PerformLoadCombine
1 parent a4bd7e9 commit 1e183f9

File tree

1 file changed

+104
-32
lines changed

1 file changed

+104
-32
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 104 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5205,6 +5205,55 @@ getMachineMemOperandForType(const SelectionDAG &DAG,
52055205
LLT(VT));
52065206
}
52075207

5208+
// These are Combiner rules for expanding v2f32 load results when they are
5209+
// really being used as their individual f32 components. Now that v2f32 is a
5210+
// legal type for a register, LowerFormalArguments() and ReplaceLoadVector()
5211+
// will pack two f32s into a single 64-bit register, leading to ld.b64 instead
5212+
// of ld.v2.f32 or ld.v2.b64 instead of ld.v4.f32. Sometimes this is ideal if
5213+
// the results stay packed because they're passed to another instruction that
5214+
// supports packed f32s (e.g. fmul.f32x2) or (rarely) if v2f32 really is being
5215+
// reinterpreted as an i64, and then stored.
5216+
//
5217+
// Otherwise, SelectionDAG will unpack the results with a sequence of bitcasts,
5218+
// extensions, and extracts if they go through any other kind of instruction.
5219+
// This is not ideal, so we undo these patterns and rewrite the load to output
5220+
// twice as many registers: two f32s for every one i64. This preserves PTX
5221+
// codegen for programs that don't use packed f32s.
5222+
//
5223+
// Also, LowerFormalArguments() and ReplaceLoadVector() happen too early for us
5224+
// to know whether the def-use chain for a particular load will eventually
5225+
// include instructions supporting packed f32s. That is why we prefer to resolve
5226+
// this problem within DAG Combiner.
5227+
//
5228+
// This rule proceeds in three general steps:
5229+
//
5230+
// 1. Identify the pattern, by traversing the def-use chain.
5231+
// 2. Rewrite the load, by splitting each 64-bit result into two f32 registers.
5232+
// 3. Rewrite all uses of the load, including chain and glue uses.
5233+
//
5234+
// This has the effect of combining multiple instructions into a single load.
5235+
// For example:
5236+
//
5237+
// (before, ex1)
5238+
// v: v2f32 = LoadParam [p]
5239+
// f1: f32 = extractelt v, 0
5240+
// f2: f32 = extractelt v, 1
5241+
// r = add.f32 f1, f2
5242+
//
5243+
// ...or...
5244+
//
5245+
// (before, ex2)
5246+
// i: i64 = LoadParam [p]
5247+
// v: v2f32 = bitcast i
5248+
// f1: f32 = extractelt v, 0
5249+
// f2: f32 = extractelt v, 1
5250+
// r = add.f32 f1, f2
5251+
//
5252+
// ...will become...
5253+
//
5254+
// (after for both)
5255+
// vf: f32,f32 = LoadParamV2 [p]
5256+
// r = add.f32 vf:0, vf:1
52085257
static SDValue PerformLoadCombine(SDNode *N,
52095258
TargetLowering::DAGCombinerInfo &DCI) {
52105259
if (DCI.DAG.getOptLevel() == CodeGenOptLevel::None)
@@ -5223,6 +5272,7 @@ static SDValue PerformLoadCombine(SDNode *N,
52235272
return VT == MVT::i64 || VT == MVT::f32 || VT.isVector();
52245273
});
52255274

5275+
// (1) All we are doing here is looking for patterns.
52265276
SmallDenseMap<SDNode *, unsigned> ExtractElts;
52275277
SmallVector<SDNode *> ProxyRegs(OrigNumResults, nullptr);
52285278
SmallVector<std::pair<SDNode *, unsigned>> WorkList{{N, {}}};
@@ -5274,24 +5324,18 @@ static SDValue PerformLoadCombine(SDNode *N,
52745324
ProcessingInitialLoad = false;
52755325
}
52765326

5277-
// (2) If the load's value is only used as f32 elements, replace all
5278-
// extractelts with individual elements of the newly-created load. If there's
5279-
// a ProxyReg, handle that too. After this check, we'll proceed in the
5280-
// following way:
5281-
// 1. Determine which type of load to create, which will split the results
5282-
// of the original load into f32 components.
5283-
// 2. If there's a ProxyReg, split that too.
5284-
// 3. Replace all extractelts with references to the new load / proxy reg.
5285-
// 4. Replace all glue/chain references with references to the new load /
5286-
// proxy reg.
5327+
// Did we find any patterns? All patterns we're interested in end with an
5328+
// extractelt.
52875329
if (ExtractElts.empty())
52885330
return SDValue();
52895331

5332+
// (2) Now, we will decide what load to create.
5333+
52905334
// Do we have to tweak the opcode for an NVPTXISD::Load* or do we have to
52915335
// rewrite an ISD::LOAD?
52925336
std::optional<NVPTXISD::NodeType> NewOpcode;
52935337

5294-
// LoadV's are handled slightly different in ISelDAGToDAG.
5338+
// LoadV's are handled slightly different in ISelDAGToDAG. See below.
52955339
bool IsLoadV = false;
52965340
switch (N->getOpcode()) {
52975341
case NVPTXISD::LoadV2:
@@ -5306,7 +5350,15 @@ static SDValue PerformLoadCombine(SDNode *N,
53065350
break;
53075351
}
53085352

5309-
SDValue OldChain, OldGlue;
5353+
// We haven't created the new load yet, but we're saving some information
5354+
// about the old load because we will need to replace all uses of it later.
5355+
// Because our pattern is generic, we're matching ISD::LOAD and
5356+
// NVPTXISD::Load*, and we just search for the chain and glue outputs rather
5357+
// than have a case for each type of load.
5358+
const bool HaveProxyRegs =
5359+
llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; });
5360+
5361+
SDValue OldChain, OldGlue /* optional */;
53105362
for (unsigned I = 0, E = N->getNumValues(); I != E; ++I) {
53115363
if (N->getValueType(I) == MVT::Other)
53125364
OldChain = SDValue(N, I);
@@ -5316,7 +5368,8 @@ static SDValue PerformLoadCombine(SDNode *N,
53165368

53175369
SDValue NewLoad, NewChain, NewGlue /* (optional) */;
53185370
unsigned NumElts = 0;
5319-
if (NewOpcode) { // tweak NVPTXISD::Load* opcode
5371+
if (NewOpcode) {
5372+
// Here, we are tweaking a NVPTXISD::Load* opcode to output N*2 results.
53205373
SmallVector<EVT> VTs;
53215374

53225375
// should always be non-null after this
@@ -5357,6 +5410,15 @@ static SDValue PerformLoadCombine(SDNode *N,
53575410
if (NewGlueIdx)
53585411
NewGlue = NewLoad.getValue(*NewGlueIdx);
53595412
} else if (N->getOpcode() == ISD::LOAD) { // rewrite a load
5413+
// Here, we are lowering an ISD::LOAD to an NVPTXISD::Load*. For example:
5414+
//
5415+
// (before)
5416+
// v2f32,ch,glue = ISD::LOAD [p]
5417+
//
5418+
// ...becomes...
5419+
//
5420+
// (after)
5421+
// f32,f32,ch,glue = NVPTXISD::LoadV2 [p]
53605422
std::optional<EVT> CastToType;
53615423
EVT ResVT = N->getValueType(0);
53625424
if (ResVT == MVT::i64) {
@@ -5374,23 +5436,41 @@ static SDValue PerformLoadCombine(SDNode *N,
53745436
}
53755437
}
53765438

5439+
// If this was some other type of load we couldn't handle, we bail.
53775440
if (!NewLoad)
5378-
return SDValue(); // could not match pattern
5441+
return SDValue();
53795442

5380-
// (3) begin rewriting uses
5443+
// (3) We successfully rewrote the load. Now we must rewrite all uses of the
5444+
// old load.
53815445
SmallVector<SDValue> NewOutputsF32;
53825446

5383-
if (llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; })) {
5384-
// scalarize proxy regs, but first rewrite all uses of chain and glue from
5385-
// the old load to the new load
5447+
if (!HaveProxyRegs) {
5448+
// The case without proxy registers in the def-use chain is simple. Each
5449+
// extractelt is matched to an output of the new load (see calls to
5450+
// DCI.CombineTo() below).
5451+
for (unsigned I = 0, E = NumElts; I != E; ++I)
5452+
if (NewLoad->getValueType(I) == MVT::f32)
5453+
NewOutputsF32.push_back(NewLoad.getValue(I));
5454+
5455+
// replace all glue and chain nodes
5456+
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5457+
if (OldGlue)
5458+
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
5459+
} else {
5460+
// The case with proxy registers is slightly more complicated. We have to
5461+
// expand those too.
5462+
5463+
// First, rewrite all uses of chain and glue from the old load to the new
5464+
// load. This is one less thing to worry about.
53865465
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
53875466
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
53885467

5468+
// Now we will expand all the proxy registers for each output.
53895469
for (unsigned ProxyI = 0, ProxyE = ProxyRegs.size(); ProxyI != ProxyE;
53905470
++ProxyI) {
53915471
SDNode *ProxyReg = ProxyRegs[ProxyI];
53925472

5393-
// no proxy reg might mean this result is unused
5473+
// No proxy reg might mean this result is unused.
53945474
if (!ProxyReg)
53955475
continue;
53965476

@@ -5404,12 +5484,12 @@ static SDValue PerformLoadCombine(SDNode *N,
54045484
if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
54055485
NewGlue = OldInGlue;
54065486

5407-
// update OldChain, OldGlue to the outputs of ProxyReg, which we will
5408-
// replace later
5487+
// Update OldChain, OldGlue to the outputs of ProxyReg, which we will
5488+
// replace later.
54095489
OldChain = SDValue(ProxyReg, 1);
54105490
OldGlue = SDValue(ProxyReg, 2);
54115491

5412-
// generate the scalar proxy regs
5492+
// Generate the scalar proxy regs.
54135493
for (unsigned I = 0, E = 2; I != E; ++I) {
54145494
SDValue ProxyRegElem = DCI.DAG.getNode(
54155495
NVPTXISD::ProxyReg, SDLoc(ProxyReg),
@@ -5424,18 +5504,10 @@ static SDValue PerformLoadCombine(SDNode *N,
54245504
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
54255505
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
54265506
}
5427-
} else {
5428-
for (unsigned I = 0, E = NumElts; I != E; ++I)
5429-
if (NewLoad->getValueType(I) == MVT::f32)
5430-
NewOutputsF32.push_back(NewLoad.getValue(I));
5431-
5432-
// replace all glue and chain nodes
5433-
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5434-
if (OldGlue)
5435-
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
54365507
}
54375508

5438-
// replace all extractelts with the new outputs
5509+
// Replace all extractelts with the new outputs. This leaves the old load and
5510+
// unpacking instructions dead.
54395511
for (auto &[Extract, Index] : ExtractElts)
54405512
DCI.CombineTo(Extract, NewOutputsF32[Index], false);
54415513

0 commit comments

Comments
 (0)