Skip to content

Commit 30e5848

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8
1 parent 00f5b69 commit 30e5848

File tree

6 files changed

+76
-50
lines changed

6 files changed

+76
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12761276
EVT OrigType = N->getValueType(0);
12771277
EVT EltVT = Mem->getMemoryVT();
12781278
unsigned NumElts = 1;
1279+
1280+
std::optional<unsigned> Opcode;
1281+
12791282
if (EltVT.isVector()) {
12801283
NumElts = EltVT.getVectorNumElements();
12811284
EltVT = EltVT.getVectorElementType();
@@ -1288,6 +1291,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12881291
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12891292
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
12901293
"NumElts must be divisible by the number of elts in subvectors");
1294+
if (N->getOpcode() == ISD::LOAD ||
1295+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1296+
switch (OrigType.getSimpleVT().SimpleTy) {
1297+
case MVT::v2f32:
1298+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1299+
: NVPTX::INT_PTX_LDU_GLOBAL_b64;
1300+
break;
1301+
case MVT::v2f16:
1302+
case MVT::v2bf16:
1303+
case MVT::v2i16:
1304+
case MVT::v4i8:
1305+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1306+
: NVPTX::INT_PTX_LDU_GLOBAL_b32;
1307+
break;
1308+
default:
1309+
llvm_unreachable("Unhandled packed vector type");
1310+
}
1311+
}
12911312
EltVT = OrigType;
12921313
NumElts /= OrigType.getVectorNumElements();
12931314
}
@@ -1309,50 +1330,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13091330
SelectADDR(Op1, Base, Offset);
13101331
SDValue Ops[] = {Base, Offset, Chain};
13111332

1312-
std::optional<unsigned> Opcode;
1313-
switch (N->getOpcode()) {
1314-
default:
1315-
return false;
1316-
case ISD::LOAD:
1317-
Opcode = pickOpcodeForVT(
1318-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1319-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1320-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1321-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1322-
break;
1323-
case ISD::INTRINSIC_W_CHAIN:
1324-
Opcode = pickOpcodeForVT(
1325-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1326-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1327-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1328-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1329-
break;
1330-
case NVPTXISD::LoadV2:
1331-
Opcode = pickOpcodeForVT(
1332-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1333-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1334-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1335-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1336-
break;
1337-
case NVPTXISD::LDUV2:
1338-
Opcode = pickOpcodeForVT(
1339-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1340-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1341-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1342-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1343-
break;
1344-
case NVPTXISD::LoadV4:
1345-
Opcode = pickOpcodeForVT(
1346-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1347-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1348-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1349-
break;
1350-
case NVPTXISD::LDUV4:
1351-
Opcode = pickOpcodeForVT(
1352-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1353-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1354-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1355-
break;
1333+
if (!Opcode) {
1334+
switch (N->getOpcode()) {
1335+
default:
1336+
return false;
1337+
case ISD::LOAD:
1338+
Opcode = pickOpcodeForVT(
1339+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1340+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1341+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1342+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1343+
break;
1344+
case ISD::INTRINSIC_W_CHAIN:
1345+
Opcode = pickOpcodeForVT(
1346+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1347+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1348+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1349+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1350+
break;
1351+
case NVPTXISD::LoadV2:
1352+
Opcode = pickOpcodeForVT(
1353+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1354+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1355+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1356+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1357+
break;
1358+
case NVPTXISD::LDUV2:
1359+
Opcode = pickOpcodeForVT(
1360+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1361+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1362+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1363+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1364+
break;
1365+
case NVPTXISD::LoadV4:
1366+
Opcode = pickOpcodeForVT(
1367+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1368+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1369+
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1370+
break;
1371+
case NVPTXISD::LDUV4:
1372+
Opcode = pickOpcodeForVT(
1373+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1374+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1375+
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1376+
break;
1377+
}
13561378
}
13571379
if (!Opcode)
13581380
return false;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,9 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
27022702
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
27032703
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
27042704
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2705+
def INT_PTX_LDU_GLOBAL_b32 : LDU_G<"b32", Int32Regs>;
27052706
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2707+
def INT_PTX_LDU_GLOBAL_b64 : LDU_G<"b64", Int64Regs>;
27062708
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
27072709
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
27082710

@@ -2752,7 +2754,9 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
27522754
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
27532755
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
27542756
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2757+
def INT_PTX_LDG_GLOBAL_b32 : LDG_G<"b32", Int32Regs>;
27552758
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2759+
def INT_PTX_LDG_GLOBAL_b64 : LDG_G<"b64", Int64Regs>;
27562760
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
27572761
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
27582762

llvm/test/CodeGen/NVPTX/ldg-invariant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ define half @ld_global_v2f16(ptr addrspace(1) %ptr) {
3232
; CHECK-EMPTY:
3333
; CHECK-NEXT: // %bb.0:
3434
; CHECK-NEXT: ld.param.u64 %rd1, [ld_global_v2f16_param_0];
35-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
35+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
3636
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
3737
; CHECK-NEXT: cvt.f32.f16 %f1, %rs2;
3838
; CHECK-NEXT: cvt.f32.f16 %f2, %rs1;

llvm/test/CodeGen/NVPTX/ldu-ldg.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ define <2 x half> @test_ldu_v2f16(ptr addrspace(1) %ptr) {
154154
; CHECK-EMPTY:
155155
; CHECK-NEXT: // %bb.0:
156156
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f16_param_0];
157-
; CHECK-NEXT: ldu.global.u32 %r1, [%rd1];
157+
; CHECK-NEXT: ldu.global.b32 %r1, [%rd1];
158158
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
159159
; CHECK-NEXT: ret;
160160
%val = tail call <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
@@ -291,7 +291,7 @@ define <2 x half> @test_ldg_v2f16(ptr addrspace(1) %ptr) {
291291
; CHECK-EMPTY:
292292
; CHECK-NEXT: // %bb.0:
293293
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f16_param_0];
294-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
294+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
295295
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
296296
; CHECK-NEXT: ret;
297297
%val = tail call <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ define ptx_kernel void @foo7(ptr noalias readonly %from, ptr %to) {
8282
; SM20-LABEL: .visible .entry foo8(
8383
; SM20: ld.global.u32
8484
; SM35-LABEL: .visible .entry foo8(
85-
; SM35: ld.global.nc.u32
85+
; SM35: ld.global.nc.b32
8686
define ptx_kernel void @foo8(ptr noalias readonly %from, ptr %to) {
8787
%1 = load <2 x i16>, ptr %from
8888
store <2 x i16> %1, ptr %to
@@ -132,7 +132,7 @@ define ptx_kernel void @foo12(ptr noalias readonly %from, ptr %to) {
132132
; SM20-LABEL: .visible .entry foo13(
133133
; SM20: ld.global.u32
134134
; SM35-LABEL: .visible .entry foo13(
135-
; SM35: ld.global.nc.u32
135+
; SM35: ld.global.nc.b32
136136
define ptx_kernel void @foo13(ptr noalias readonly %from, ptr %to) {
137137
%1 = load <4 x i8>, ptr %from
138138
store <4 x i8> %1, ptr %to

llvm/test/CodeGen/NVPTX/read-global-variable-constant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ define float @test_gv_float() {
1717

1818
; CHECK-LABEL: test_gv_float2()
1919
define <2 x float> @test_gv_float2() {
20-
; CHECK: ld.global.nc.v2.f32
20+
; CHECK: ld.global.nc.b64
2121
%v = load <2 x float>, ptr @gv_float2
2222
ret <2 x float> %v
2323
}

0 commit comments

Comments
 (0)