Skip to content

Commit 9d84bd3

Browse files
Committed: Revert "[NVPTX] add combiner rule for final packed op in reduction"

This reverts commit 8cbda00.
1 parent: 8cbda00 · this commit: 9d84bd3

File tree

2 files changed: +244 additions, −210 deletions

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 6 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -852,13 +852,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
852852
if (STI.allowFP16Math() || STI.hasBF16Math())
853853
setTargetDAGCombine(ISD::SETCC);
854854

855-
// Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
856-
// shuffles when one of their lanes is a no-op.
857-
if (STI.allowFP16Math() || STI.hasBF16Math())
858-
// already added above: FADD, ADD, AND
859-
setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
860-
ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
861-
862855
// Promote fp16 arithmetic if fp16 hardware isn't available or the
863856
// user passed --nvptx-no-fp16-math. The flag is useful because,
864857
// although sm_53+ GPUs have some sort of FP16 support in
@@ -5076,102 +5069,20 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
50765069
return PerformStoreCombineHelper(N, 2, 0);
50775070
}
50785071

5079-
/// For vector reductions, the final result needs to be a scalar. The default
5080-
/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
5081-
/// This requires a packed operation where one of the lanes is undef.
5082-
///
5083-
/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
5084-
///
5085-
/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
5086-
/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
5087-
/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
5088-
/// vR: f16 = extractelt v3, 1
5089-
///
5090-
/// We wish to replace vR, v3, and v2 with:
5091-
/// vR: f16 = fadd reassoc (extractelt v1, 1) (extractelt v1, 0)
5092-
///
5093-
/// ...so that we get:
5094-
/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
5095-
/// s1: f16 = extractelt v1, 1
5096-
/// s2: f16 = extractelt v1, 0
5097-
/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
5098-
///
5099-
/// So for this example, this rule will replace v3 and v2, returning a vector
5100-
/// with the result in lane 0 and an undef in lane 1, which we expect will be
5101-
/// folded into the extractelt in vR.
5102-
static SDValue PerformPackedOpCombine(SDNode *N,
5103-
TargetLowering::DAGCombinerInfo &DCI) {
5104-
// Convert:
5105-
// (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
5106-
// ...or...
5107-
// (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
5108-
// ...where i is a valid index and u is poison.
5109-
const EVT VectorVT = N->getValueType(0);
5110-
if (!Isv2x16VT(VectorVT))
5111-
return SDValue();
5112-
5113-
SDLoc DL(N);
5114-
5115-
SDValue ShufOp = N->getOperand(0);
5116-
SDValue VectOp = N->getOperand(1);
5117-
bool Swapped = false;
5118-
5119-
// canonicalize shuffle to op0
5120-
if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
5121-
std::swap(ShufOp, VectOp);
5122-
Swapped = true;
5123-
}
5124-
5125-
if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
5126-
return SDValue();
5127-
5128-
auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
5129-
int LiveLane; // exclusively live lane
5130-
for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
5131-
// check if the current lane is live and the other lane is dead
5132-
if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
5133-
ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
5134-
break;
5135-
}
5136-
if (LiveLane == 2)
5137-
return SDValue();
5138-
5139-
int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
5140-
const EVT ScalarVT = VectorVT.getScalarType();
5141-
SDValue Lanes[2] = {};
5142-
for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
5143-
if (LaneID == (unsigned)LiveLane) {
5144-
SDValue Operands[2] = {
5145-
DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
5146-
ElementIdx),
5147-
DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
5148-
// preserve the order of operands
5149-
if (Swapped)
5150-
std::swap(Operands[0], Operands[1]);
5151-
LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
5152-
} else {
5153-
LaneVal = DCI.DAG.getUNDEF(ScalarVT);
5154-
}
5155-
}
5156-
return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
5157-
}
5158-
51595072
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
51605073
///
51615074
static SDValue PerformADDCombine(SDNode *N,
51625075
TargetLowering::DAGCombinerInfo &DCI,
51635076
CodeGenOptLevel OptLevel) {
5077+
if (OptLevel == CodeGenOptLevel::None)
5078+
return SDValue();
5079+
51645080
SDValue N0 = N->getOperand(0);
51655081
SDValue N1 = N->getOperand(1);
51665082

51675083
// Skip non-integer, non-scalar case
51685084
EVT VT = N0.getValueType();
5169-
if (VT.isVector())
5170-
return PerformPackedOpCombine(N, DCI);
5171-
if (VT != MVT::i32)
5172-
return SDValue();
5173-
5174-
if (OptLevel == CodeGenOptLevel::None)
5085+
if (VT.isVector() || VT != MVT::i32)
51755086
return SDValue();
51765087

51775088
// First try with the default operand order.
@@ -5191,10 +5102,7 @@ static SDValue PerformFADDCombine(SDNode *N,
51915102
SDValue N1 = N->getOperand(1);
51925103

51935104
EVT VT = N0.getValueType();
5194-
if (VT.isVector())
5195-
return PerformPackedOpCombine(N, DCI);
5196-
5197-
if (!(VT == MVT::f32 || VT == MVT::f64))
5105+
if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
51985106
return SDValue();
51995107

52005108
// First try with the default operand order.
@@ -5297,7 +5205,7 @@ static SDValue PerformANDCombine(SDNode *N,
52975205
DCI.CombineTo(N, Val, AddTo);
52985206
}
52995207

5300-
return PerformPackedOpCombine(N, DCI);
5208+
return SDValue();
53015209
}
53025210

53035211
static SDValue PerformREMCombine(SDNode *N,
@@ -5778,16 +5686,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57785686
return PerformADDCombine(N, DCI, OptLevel);
57795687
case ISD::FADD:
57805688
return PerformFADDCombine(N, DCI, OptLevel);
5781-
case ISD::FMUL:
5782-
case ISD::FMINNUM:
5783-
case ISD::FMAXIMUM:
5784-
case ISD::UMIN:
5785-
case ISD::UMAX:
5786-
case ISD::SMIN:
5787-
case ISD::SMAX:
5788-
case ISD::OR:
5789-
case ISD::XOR:
5790-
return PerformPackedOpCombine(N, DCI);
57915689
case ISD::MUL:
57925690
return PerformMULCombine(N, DCI, OptLevel);
57935691
case ISD::SHL:

Comments (0)