@@ -580,6 +580,7 @@ namespace {
580
580
SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
581
581
EVT VT, SDValue N0, SDValue N1,
582
582
SDNodeFlags Flags = SDNodeFlags());
583
+ SDValue foldReductionWithUndefLane(SDNode *N);
583
584
584
585
SDValue visitShiftByConstant(SDNode *N);
585
586
@@ -1347,6 +1348,75 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1347
1348
return SDValue();
1348
1349
}
1349
1350
1351
+ // Convert:
1352
+ // (op.x2 (vector_shuffle<i,u> A), B) -> <(op A:i, B:0) undef>
1353
+ // ...or...
1354
+ // (op.x2 (vector_shuffle<u,i> A), B) -> <undef (op A:i, B:1)>
1355
+ // ...where i is a valid index and u is poison.
1356
+ SDValue DAGCombiner::foldReductionWithUndefLane(SDNode *N) {
1357
+ const EVT VectorVT = N->getValueType(0);
1358
+
1359
+ // Only support 2-packed vectors for now.
1360
+ if (!VectorVT.isVector() || VectorVT.isScalableVector()
1361
+ || VectorVT.getVectorNumElements() != 2)
1362
+ return SDValue();
1363
+
1364
+ // If the operation is already unsupported, we don't need to do this
1365
+ // operation.
1366
+ if (!TLI.isOperationLegal(N->getOpcode(), VectorVT))
1367
+ return SDValue();
1368
+
1369
+ // If vector shuffle is supported on the target, this optimization may
1370
+ // increase register pressure.
1371
+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::VECTOR_SHUFFLE, VectorVT))
1372
+ return SDValue();
1373
+
1374
+ SDLoc DL(N);
1375
+
1376
+ SDValue ShufOp = N->getOperand(0);
1377
+ SDValue VectOp = N->getOperand(1);
1378
+ bool Swapped = false;
1379
+
1380
+ // canonicalize shuffle op
1381
+ if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
1382
+ std::swap(ShufOp, VectOp);
1383
+ Swapped = true;
1384
+ }
1385
+
1386
+ if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
1387
+ return SDValue();
1388
+
1389
+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
1390
+ int LiveLane; // exclusively live lane
1391
+ for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
1392
+ // check if the current lane is live and the other lane is dead
1393
+ if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
1394
+ ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
1395
+ break;
1396
+ }
1397
+ if (LiveLane == 2)
1398
+ return SDValue();
1399
+
1400
+ const int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
1401
+ const EVT ScalarVT = VectorVT.getScalarType();
1402
+ SDValue Lanes[2] = {};
1403
+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
1404
+ if (LaneID == (unsigned)LiveLane) {
1405
+ SDValue Operands[2] = {
1406
+ DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
1407
+ ElementIdx),
1408
+ DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
1409
+ // preserve the order of operands
1410
+ if (Swapped)
1411
+ std::swap(Operands[0], Operands[1]);
1412
+ LaneVal = DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
1413
+ } else {
1414
+ LaneVal = DAG.getUNDEF(ScalarVT);
1415
+ }
1416
+ }
1417
+ return DAG.getBuildVector(VectorVT, DL, Lanes);
1418
+ }
1419
+
1350
1420
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1351
1421
bool AddTo) {
1352
1422
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -3056,6 +3126,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
3056
3126
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3057
3127
}
3058
3128
3129
+ if (SDValue R = foldReductionWithUndefLane(N))
3130
+ return R;
3131
+
3059
3132
return SDValue();
3060
3133
}
3061
3134
@@ -5999,6 +6072,9 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5999
6072
SDLoc(N), VT, N0, N1))
6000
6073
return SD;
6001
6074
6075
+ if (SDValue SD = foldReductionWithUndefLane(N))
6076
+ return SD;
6077
+
6002
6078
// Simplify the operands using demanded-bits information.
6003
6079
if (SimplifyDemandedBits(SDValue(N, 0)))
6004
6080
return SDValue(N, 0);
@@ -7267,6 +7343,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
7267
7343
}
7268
7344
}
7269
7345
}
7346
+
7347
+ if (SDValue R = foldReductionWithUndefLane(N))
7348
+ return R;
7270
7349
}
7271
7350
7272
7351
// fold (and x, -1) -> x
@@ -8242,6 +8321,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
8242
8321
}
8243
8322
}
8244
8323
}
8324
+
8325
+ if (SDValue R = foldReductionWithUndefLane(N))
8326
+ return R;
8245
8327
}
8246
8328
8247
8329
// fold (or x, 0) -> x
@@ -9923,6 +10005,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
9923
10005
if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9924
10006
return Combined;
9925
10007
10008
+ if (SDValue R = foldReductionWithUndefLane(N))
10009
+ return R;
10010
+
9926
10011
return SDValue();
9927
10012
}
9928
10013
@@ -17529,6 +17614,10 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
17529
17614
AddToWorklist(Fused.getNode());
17530
17615
return Fused;
17531
17616
}
17617
+
17618
+ if (SDValue R = foldReductionWithUndefLane(N))
17619
+ return R;
17620
+
17532
17621
return SDValue();
17533
17622
}
17534
17623
@@ -17897,6 +17986,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
17897
17986
if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17898
17987
return R;
17899
17988
17989
+ if (SDValue R = foldReductionWithUndefLane(N))
17990
+ return R;
17991
+
17900
17992
return SDValue();
17901
17993
}
17902
17994
@@ -19002,6 +19094,9 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19002
19094
Opc, SDLoc(N), VT, N0, N1, Flags))
19003
19095
return SD;
19004
19096
19097
+ if (SDValue SD = foldReductionWithUndefLane(N))
19098
+ return SD;
19099
+
19005
19100
return SDValue();
19006
19101
}
19007
19102
0 commit comments