@@ -5205,6 +5205,55 @@ getMachineMemOperandForType(const SelectionDAG &DAG,
                                      LLT(VT));
 }
 
+// These are Combiner rules for expanding v2f32 load results when they are
+// really being used as their individual f32 components. Now that v2f32 is a
+// legal type for a register, LowerFormalArguments() and ReplaceLoadVector()
+// will pack two f32s into a single 64-bit register, leading to ld.b64 instead
+// of ld.v2.f32 or ld.v2.b64 instead of ld.v4.f32. Sometimes this is ideal if
+// the results stay packed because they're passed to another instruction that
+// supports packed f32s (e.g. fmul.f32x2) or (rarely) if v2f32 really is being
+// reinterpreted as an i64, and then stored.
+//
+// Otherwise, SelectionDAG will unpack the results with a sequence of bitcasts,
+// extensions, and extracts if they go through any other kind of instruction.
+// This is not ideal, so we undo these patterns and rewrite the load to output
+// twice as many registers: two f32s for every one i64. This preserves PTX
+// codegen for programs that don't use packed f32s.
+//
+// Also, LowerFormalArguments() and ReplaceLoadVector() happen too early for us
+// to know whether the def-use chain for a particular load will eventually
+// include instructions supporting packed f32s. That is why we prefer to resolve
+// this problem within DAG Combiner.
+//
+// This rule proceeds in three general steps:
+//
+// 1. Identify the pattern, by traversing the def-use chain.
+// 2. Rewrite the load, by splitting each 64-bit result into two f32 registers.
+// 3. Rewrite all uses of the load, including chain and glue uses.
+//
+// This has the effect of combining multiple instructions into a single load.
+// For example:
+//
+// (before, ex1)
+//   v: v2f32 = LoadParam [p]
+//   f1: f32 = extractelt v, 0
+//   f2: f32 = extractelt v, 1
+//   r = add.f32 f1, f2
+//
+// ...or...
+//
+// (before, ex2)
+//   i: i64 = LoadParam [p]
+//   v: v2f32 = bitcast i
+//   f1: f32 = extractelt v, 0
+//   f2: f32 = extractelt v, 1
+//   r = add.f32 f1, f2
+//
+// ...will become...
+//
+// (after for both)
+//   vf: f32,f32 = LoadParamV2 [p]
+//   r = add.f32 vf:0, vf:1
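+//
+// In PTX terms the difference looks roughly like this (illustrative only;
+// exact opcodes and register names depend on the surrounding code):
+//
+// (before)
+//   ld.param.b64    %rd1, [p];
+//   mov.b64         {%f1, %f2}, %rd1;
+//   add.f32         %f3, %f1, %f2;
+//
+// (after)
+//   ld.param.v2.f32 {%f1, %f2}, [p];
+//   add.f32         %f3, %f1, %f2;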
 static SDValue PerformLoadCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI) {
   if (DCI.DAG.getOptLevel() == CodeGenOptLevel::None)
@@ -5223,6 +5272,7 @@ static SDValue PerformLoadCombine(SDNode *N,
         return VT == MVT::i64 || VT == MVT::f32 || VT.isVector();
       });
 
+  // (1) All we are doing here is looking for patterns.
   SmallDenseMap<SDNode *, unsigned> ExtractElts;
   SmallVector<SDNode *> ProxyRegs(OrigNumResults, nullptr);
   SmallVector<std::pair<SDNode *, unsigned>> WorkList{{N, {}}};
@@ -5274,24 +5324,18 @@ static SDValue PerformLoadCombine(SDNode *N,
       ProcessingInitialLoad = false;
   }
 
-  // (2) If the load's value is only used as f32 elements, replace all
-  // extractelts with individual elements of the newly-created load. If there's
-  // a ProxyReg, handle that too. After this check, we'll proceed in the
-  // following way:
-  // 1. Determine which type of load to create, which will split the results
-  //    of the original load into f32 components.
-  // 2. If there's a ProxyReg, split that too.
-  // 3. Replace all extractelts with references to the new load / proxy reg.
-  // 4. Replace all glue/chain references with references to the new load /
-  //    proxy reg.
+  // Did we find any patterns? All patterns we're interested in end with an
+  // extractelt.
   if (ExtractElts.empty())
     return SDValue();
 
+  // (2) Now, we will decide what load to create.
+
   // Do we have to tweak the opcode for an NVPTXISD::Load* or do we have to
   // rewrite an ISD::LOAD?
   std::optional<NVPTXISD::NodeType> NewOpcode;
 
-  // LoadV's are handled slightly different in ISelDAGToDAG.
+  // LoadV's are handled slightly different in ISelDAGToDAG. See below.
   bool IsLoadV = false;
   switch (N->getOpcode()) {
   case NVPTXISD::LoadV2:
@@ -5306,7 +5350,15 @@ static SDValue PerformLoadCombine(SDNode *N,
     break;
   }
 
-  SDValue OldChain, OldGlue;
+  // We haven't created the new load yet, but we're saving some information
+  // about the old load because we will need to replace all uses of it later.
+  // Because our pattern is generic, we're matching ISD::LOAD and
+  // NVPTXISD::Load*, and we just search for the chain and glue outputs rather
+  // than have a case for each type of load.
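+  // Remember up front whether any result of the old load was reached through
+  // a ProxyReg; the rewrite below handles that case separately.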
+  const bool HaveProxyRegs =
+      llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; });
+
+  SDValue OldChain, OldGlue /* optional */;
   for (unsigned I = 0, E = N->getNumValues(); I != E; ++I) {
     if (N->getValueType(I) == MVT::Other)
       OldChain = SDValue(N, I);
@@ -5316,7 +5368,8 @@ static SDValue PerformLoadCombine(SDNode *N,
 
   SDValue NewLoad, NewChain, NewGlue /* (optional) */;
   unsigned NumElts = 0;
-  if (NewOpcode) { // tweak NVPTXISD::Load* opcode
+  if (NewOpcode) {
+    // Here, we are tweaking a NVPTXISD::Load* opcode to output N*2 results.
     SmallVector<EVT> VTs;
 
     // should always be non-null after this
@@ -5357,6 +5410,15 @@ static SDValue PerformLoadCombine(SDNode *N,
     if (NewGlueIdx)
       NewGlue = NewLoad.getValue(*NewGlueIdx);
   } else if (N->getOpcode() == ISD::LOAD) { // rewrite a load
+    // Here, we are lowering an ISD::LOAD to an NVPTXISD::Load*. For example:
+    //
+    // (before)
+    //   v2f32,ch,glue = ISD::LOAD [p]
+    //
+    // ...becomes...
+    //
+    // (after)
+    //   f32,f32,ch,glue = NVPTXISD::LoadV2 [p]
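+    // If the original load produced an i64 (ex2 above), it presumably needs
+    // to be treated as a v2f32 first (see CastToType below) so the new load
+    // can yield two f32 results directly.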
     std::optional<EVT> CastToType;
     EVT ResVT = N->getValueType(0);
     if (ResVT == MVT::i64) {
@@ -5374,23 +5436,41 @@ static SDValue PerformLoadCombine(SDNode *N,
     }
   }
 
+  // If this was some other type of load we couldn't handle, we bail.
   if (!NewLoad)
-    return SDValue(); // could not match pattern
+    return SDValue();
 
-  // (3) begin rewriting uses
+  // (3) We successfully rewrote the load. Now we must rewrite all uses of the
+  // old load.
   SmallVector<SDValue> NewOutputsF32;
 
-  if (llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; })) {
-    // scalarize proxy regs, but first rewrite all uses of chain and glue from
-    // the old load to the new load
+  if (!HaveProxyRegs) {
+    // The case without proxy registers in the def-use chain is simple. Each
+    // extractelt is matched to an output of the new load (see calls to
+    // DCI.CombineTo() below).
+    for (unsigned I = 0, E = NumElts; I != E; ++I)
+      if (NewLoad->getValueType(I) == MVT::f32)
+        NewOutputsF32.push_back(NewLoad.getValue(I));
+
+    // replace all glue and chain nodes
+    DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
+    if (OldGlue)
+      DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
+  } else {
+    // The case with proxy registers is slightly more complicated. We have to
+    // expand those too.
+
+    // First, rewrite all uses of chain and glue from the old load to the new
+    // load. This is one less thing to worry about.
    DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
    DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
 
+    // Now we will expand all the proxy registers for each output.
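+    // Roughly (illustration only; operand order and value numbering are
+    // approximate), a proxy of the old packed value
+    //
+    //   pv: v2f32,ch,glue = ProxyReg ... <old load result>
+    //
+    // becomes one scalar proxy per component:
+    //
+    //   p0: f32,ch,glue = ProxyReg ... <new load result 0>
+    //   p1: f32,ch,glue = ProxyReg ... <new load result 1>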
     for (unsigned ProxyI = 0, ProxyE = ProxyRegs.size(); ProxyI != ProxyE;
          ++ProxyI) {
       SDNode *ProxyReg = ProxyRegs[ProxyI];
 
-      // no proxy reg might mean this result is unused
+      // No proxy reg might mean this result is unused.
       if (!ProxyReg)
         continue;
@@ -5404,12 +5484,12 @@ static SDValue PerformLoadCombine(SDNode *N,
       if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
         NewGlue = OldInGlue;
 
-      // update OldChain, OldGlue to the outputs of ProxyReg, which we will
-      // replace later
+      // Update OldChain, OldGlue to the outputs of ProxyReg, which we will
+      // replace later.
       OldChain = SDValue(ProxyReg, 1);
       OldGlue = SDValue(ProxyReg, 2);
 
-      // generate the scalar proxy regs
+      // Generate the scalar proxy regs.
       for (unsigned I = 0, E = 2; I != E; ++I) {
         SDValue ProxyRegElem = DCI.DAG.getNode(
             NVPTXISD::ProxyReg, SDLoc(ProxyReg),
@@ -5424,18 +5504,10 @@ static SDValue PerformLoadCombine(SDNode *N,
       DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
       DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
     }
-  } else {
-    for (unsigned I = 0, E = NumElts; I != E; ++I)
-      if (NewLoad->getValueType(I) == MVT::f32)
-        NewOutputsF32.push_back(NewLoad.getValue(I));
-
-    // replace all glue and chain nodes
-    DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
-    if (OldGlue)
-      DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
   }
 
-  // replace all extractelts with the new outputs
+  // Replace all extractelts with the new outputs. This leaves the old load and
+  // unpacking instructions dead.
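+  // (The dead nodes should then be removed by the DAG combiner's usual
+  // dead-node cleanup.)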
   for (auto &[Extract, Index] : ExtractElts)
     DCI.CombineTo(Extract, NewOutputsF32[Index], false);