@@ -833,7 +833,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
-                       ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
+                       ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3091,10 +3091,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
 
-  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
-  // unaligned loads and have to handle it here.
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned loads and have to handle it here.
   EVT VT = Op.getValueType();
-  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -3138,15 +3138,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);
 
-  // v2f16 is legal, so we can't rely on legalizer to handle unaligned
-  // stores and have to handle it here.
-  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
 
-  // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT) || VT == MVT::v4i8)
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
     return SDValue();
 
   if (VT.isVector())
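Because these vector types are marked legal, the generic legalizer assumes any such load or store is fine as-is and will not split an under-aligned access; the target has to expand it explicitly, which is what the two hunks above do. A minimal sketch of that pattern using the `TargetLowering` helpers named in the diff (the standalone function and its name are hypothetical, not part of this patch):

```cpp
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"

using namespace llvm;

// Hypothetical helper: expand a store of a target-legal vector type when the
// actual memory operand is less aligned than the type wants.
static SDValue expandStoreIfUnderAligned(StoreSDNode *Store, SelectionDAG &DAG,
                                         const TargetLowering &TLI) {
  EVT VT = Store->getValue().getValueType();
  // allowsMemoryAccessForAlignment() consults the data layout and the
  // MachineMemOperand's alignment; for legal types the legalizer never
  // re-checks this, so the target must.
  if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
                                          DAG.getDataLayout(), VT,
                                          *Store->getMemOperand()))
    return TLI.expandUnalignedStore(Store, DAG); // split into narrower stores
  return SDValue(); // sufficiently aligned; keep the wide store
}
```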
@@ -3155,8 +3155,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   return SDValue();
 }
 
-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const SmallVectorImpl<SDValue> &Elements) {
   SDNode *N = Op.getNode();
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
@@ -3223,6 +3223,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
       Ops.push_back(SubVector);
     }
+  } else if (!Elements.empty()) {
+    Ops.insert(Ops.end(), Elements.begin(), Elements.end());
   } else {
     for (unsigned i = 0; i < NumElts; ++i) {
       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3240,10 +3242,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                               MemSD->getMemoryVT(), MemSD->getMemOperand());
 
-  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }
 
+// Default variant where we don't pass in elements.
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
+  return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
+}
+
+SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  return convertVectorStore(Op, DAG);
+}
+
 // st i1 v, addr
 // =>
 //    v1 = zxt v to i16
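This refactor turns the old member function into a static helper with an optional pre-extracted element list: ordinary lowering goes through the no-`Elements` overload, while a combine that already sees the scalars (for example, the operands of a `BUILD_VECTOR`) can pass them in and skip re-extraction. A hedged usage sketch; the wrapper function, `A`, and `B` are hypothetical, and the LLVM includes from the earlier sketch are assumed:

```cpp
// Hypothetical call site showing both overloads funneling into one helper.
static SDValue lowerOrCombineStore(SDValue Op, SelectionDAG &DAG,
                                   SDValue A, SDValue B, bool HaveScalars) {
  if (!HaveScalars)
    return convertVectorStore(Op, DAG); // helper extracts elements itself
  SmallVector<SDValue> Elts = {A, B};   // e.g. BUILD_VECTOR operands
  return convertVectorStore(Op, DAG, Elts); // reuse the known scalars
}
```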
@@ -5402,6 +5413,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
   //   -->
   // StoreRetvalV2 {a, b}
   // likewise for V2 -> V4 case
+  //
+  // We also handle target-independent stores, which require us to first
+  // convert to StoreV2.
 
   std::optional<NVPTXISD::NodeType> NewOpcode;
   switch (N->getOpcode()) {
@@ -5427,8 +5441,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
     SDValue CurrentOp = N->getOperand(I);
     if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
       assert(CurrentOp.getValueType() == MVT::v2f32);
-      NewOps.push_back(CurrentOp.getNode()->getOperand(0));
-      NewOps.push_back(CurrentOp.getNode()->getOperand(1));
+      NewOps.push_back(CurrentOp.getOperand(0));
+      NewOps.push_back(CurrentOp.getOperand(1));
     } else {
       NewOps.clear();
       break;
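For context, this loop implements the widening described in the comment above: only if every stored value is a `v2f32 BUILD_VECTOR` are its scalar operands forwarded, so that, for example, `StoreRetval` can become `StoreRetvalV2`. A standalone sketch of that flattening step; the helper name, signature, and `FirstValIdx` parameter are hypothetical, since the real logic lives inline in `PerformStoreCombineHelper`:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"

using namespace llvm;

// Hypothetical standalone version of the operand-flattening loop.
static bool flattenV2F32Operands(SDNode *N, unsigned FirstValIdx,
                                 SmallVectorImpl<SDValue> &NewOps) {
  for (unsigned I = FirstValIdx, E = N->getNumOperands(); I != E; ++I) {
    SDValue Op = N->getOperand(I);
    // Only fully decomposable stores are widened; bail on anything else.
    if (Op.getOpcode() != ISD::BUILD_VECTOR ||
        Op.getValueType() != MVT::v2f32) {
      NewOps.clear();
      return false;
    }
    NewOps.push_back(Op.getOperand(0)); // first f32 lane
    NewOps.push_back(Op.getOperand(1)); // second f32 lane
  }
  return true;
}
```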
@@ -6199,6 +6213,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  // Check if the stored value can be scalarized.
+  SDValue StoredVal = N->getOperand(1);
+  if (StoredVal.getValueType() == MVT::v2f32 &&
+      StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
+    SmallVector<SDValue> Elements(StoredVal->op_values());
+    return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
+  }
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
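To make the effect of the new combine concrete, here is a rough before/after of the DAG shapes it targets. The node syntax below is schematic rather than an exact SelectionDAG `-debug` dump, and the `StoreV2` operand order is abbreviated:

```cpp
// Schematic DAG shapes (illustrative, not an exact dump):
//
// Before the combine: a generic 64-bit store of a freshly built v2f32.
//   t5: v2f32 = BUILD_VECTOR tA, tB
//   t7: ch    = store<(store (s64))> tChain, t5, tPtr, undef
//
// After PerformStoreCombine hands {tA, tB} to convertVectorStore:
//   t8: ch    = NVPTXISD::StoreV2 tChain, tA, tB, tPtr, ...
//
// The BUILD_VECTOR is consumed directly, so no EXTRACT_VECTOR_ELT nodes
// ever need to be created (and later combined away) for the stored lanes.
```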
@@ -6228,6 +6254,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case NVPTXISD::LoadParam:
   case NVPTXISD::LoadParamV2:
     return PerformLoadCombine(N, DCI);
+  case ISD::STORE:
+    return PerformStoreCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4: