@@ -828,7 +828,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
-                       ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
+                       ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});

   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
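
Adding ISD::STORE to this list is what makes the rest of the patch reachable: the generic DAGCombiner invokes the target's PerformDAGCombine hook only for opcodes registered via setTargetDAGCombine, so this line is what routes plain stores to the new PerformStoreCombine further down. A minimal sketch of the combiner's dispatch (paraphrased from DAGCombiner, not verbatim LLVM code):

```cpp
// Inside the generic combiner, for each node N being visited:
if (TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
  // For NVPTX this lands in NVPTXTargetLowering::PerformDAGCombine, which
  // this patch teaches to handle ISD::STORE.
  if (SDValue RV = TLI.PerformDAGCombine(N, DagCombineInfo))
    return RV; // the combiner replaces N with the returned value
}
```
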
@@ -2992,10 +2992,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);

-  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
-  // unaligned loads and have to handle it here.
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned loads and have to handle it here.
   EVT VT = Op.getValueType();
-  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
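
Since v2f32 is now a legal type, the type legalizer never visits these loads, so under-aligned ones must be expanded here in LowerLOAD instead. The body guarded by this check (just beyond the hunk) follows the existing v2x16/v4i8 pattern; a hedged sketch of that path using the standard TargetLowering helper:

```cpp
// If the memory access doesn't meet the target's alignment requirements,
// expand it into smaller legal operations. expandUnalignedLoad returns the
// loaded value and the output chain as a pair.
SDValue Ops[2];
std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
return DAG.getMergeValues(Ops, SDLoc(Op));
```
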
@@ -3039,15 +3039,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);

-  // v2f16 is legal, so we can't rely on legalizer to handle unaligned
-  // stores and have to handle it here.
-  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);

-  // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT) || VT == MVT::v4i8)
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
     return SDValue();

   if (VT.isVector())
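
The store path mirrors the load path: under-aligned v2f32 stores go through expandUnalignedStore, while properly aligned ones fall through (return SDValue()) so instruction selection can match them directly. For context, Isv2x16VT is a small NVPTX helper covering the packed 16-bit pair types; reproduced here from memory as illustration, so treat the exact definition as an assumption:

```cpp
// True for the packed 16-bit pair types that, like v2f32, are legal on
// NVPTX and therefore bypass the type legalizer entirely.
inline bool Isv2x16VT(EVT VT) {
  return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
}
```
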
@@ -3056,8 +3056,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   return SDValue();
 }

-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const SmallVectorImpl<SDValue> &Elements) {
   SDNode *N = Op.getNode();
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
@@ -3124,6 +3124,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
       Ops.push_back(SubVector);
     }
+  } else if (!Elements.empty()) {
+    Ops.insert(Ops.end(), Elements.begin(), Elements.end());
   } else {
     for (unsigned i = 0; i < NumElts; ++i) {
       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3141,10 +3143,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                               MemSD->getMemoryVT(), MemSD->getMemOperand());

-  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }

+// Default variant where we don't pass in elements.
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
+  return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
+}
+
+SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  return convertVectorStore(Op, DAG);
+}
+
 // st i1 v, addr
 // =>
 //    v1 = zxt v to i16
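
The member function becomes a thin wrapper over a static helper that optionally takes pre-extracted elements. The point of the Elements fast path is to skip the per-lane extraction the helper's else-branch otherwise performs; a sketch of that branch (shape taken from the surrounding code, details assumed):

```cpp
// Without pre-supplied Elements, every scalar must be re-extracted from the
// vector value before it can feed the NVPTX store node.
for (unsigned I = 0; I < NumElts; ++I)
  Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                            DAG.getIntPtrConstant(I, DL)));
```

When the caller is the store combine below, it already holds the BUILD_VECTOR's operands, so forwarding them directly avoids creating extract nodes that would only be folded away again.
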
@@ -5289,6 +5300,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
   //   -->
   //   StoreRetvalV2 {a, b}
   // likewise for V2 -> V4 case
+  //
+  // We also handle target-independent stores, which require us to first
+  // convert to StoreV2.

   std::optional<NVPTXISD::NodeType> NewOpcode;
   switch (N->getOpcode()) {
@@ -5314,8 +5328,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
     SDValue CurrentOp = N->getOperand(I);
     if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
       assert(CurrentOp.getValueType() == MVT::v2f32);
-      NewOps.push_back(CurrentOp.getNode()->getOperand(0));
-      NewOps.push_back(CurrentOp.getNode()->getOperand(1));
+      NewOps.push_back(CurrentOp.getOperand(0));
+      NewOps.push_back(CurrentOp.getOperand(1));
     } else {
       NewOps.clear();
       break;
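
A small cleanup rides along in this hunk: SDValue::getOperand already forwards to the underlying node, so routing through getNode() first was redundant. The forwarding wrapper in SelectionDAGNodes.h is essentially:

```cpp
// CurrentOp.getOperand(i) and CurrentOp.getNode()->getOperand(i) are
// equivalent; SDValue just forwards to its SDNode.
inline SDValue SDValue::getOperand(unsigned i) const {
  return Node->getOperand(i);
}
```
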
@@ -6086,6 +6100,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
   return SDValue();
 }

+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  // Check if the stored value can be scalarized.
+  SDValue StoredVal = N->getOperand(1);
+  if (StoredVal.getValueType() == MVT::v2f32 &&
+      StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
+    SmallVector<SDValue> Elements(StoredVal->op_values());
+    return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
+  }
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
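
The new combine fires on plain ISD::STORE nodes whose stored value is a v2f32 assembled from scalars, and re-emits the store through convertVectorStore so the scalars feed the NVPTX store node directly. An illustrative before/after of the DAG shape (simplified; chain and address operands abbreviated):

```cpp
// Before the combine:
//   t1: v2f32 = BUILD_VECTOR t2, t3
//   t4: ch    = store t1, addr
// After convertVectorStore with Elements = {t2, t3}:
//   t5: ch    = NVPTXISD::StoreV2 t2, t3, addr
// The BUILD_VECTOR loses its only user and dies, so no v2f32 value is
// materialized just to be torn apart by the store.
```
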
@@ -6115,6 +6141,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case NVPTXISD::LoadParam:
   case NVPTXISD::LoadParamV2:
     return PerformLoadCombine(N, DCI);
+  case ISD::STORE:
+    return PerformStoreCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4: