Skip to content

Commit 5ea3d25

Browse files
committed
[NVPTX] Add f32x2 instructions and register class
Also update some test cases to use the autogenerator.
1 parent 20fbbd7 commit 5ea3d25

27 files changed

+3523
-1111
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,17 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
459459
// We only care about 16x2 and f32x2, as they're the only real vector types we
460460
// need to deal with.
461461
MVT VT = Vector.getSimpleValueType();
462-
if (!Isv2x16VT(VT))
462+
if (!isPackedVectorTy(VT) || VT.getVectorNumElements() != 2)
463463
return false;
464+
465+
unsigned Opcode;
466+
if (VT.is32BitVector())
467+
Opcode = NVPTX::I32toV2I16;
468+
else if (VT.is64BitVector())
469+
Opcode = NVPTX::I64toV2I32;
470+
else
471+
llvm_unreachable("Unhandled packed type");
472+
464473
// Find and record all uses of this vector that extract element 0 or 1.
465474
SmallVector<SDNode *, 4> E0, E1;
466475
for (auto *U : Vector.getNode()->users()) {
@@ -484,11 +493,11 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
484493
if (E0.empty() || E1.empty())
485494
return false;
486495

487-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
488-
// into f16,f16 SplitF16x2(V)
496+
// Merge (EltTy extractelt(V, 0), EltTy extractelt(V,1))
497+
// into EltTy,EltTy Split[EltTy]x2(V)
489498
MVT EltVT = VT.getVectorElementType();
490499
SDNode *ScatterOp =
491-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
500+
CurDAG->getMachineNode(Opcode, SDLoc(N), EltVT, EltVT, Vector);
492501
for (auto *Node : E0)
493502
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
494503
for (auto *Node : E1)
@@ -1004,6 +1013,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
10041013
case MVT::i32:
10051014
case MVT::f32:
10061015
return Opcode_i32;
1016+
case MVT::v2f32:
10071017
case MVT::i64:
10081018
case MVT::f64:
10091019
return Opcode_i64;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 154 additions & 89 deletions
Large diffs are not rendered by default.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
151151
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
152152
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
153153
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
154+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
154155

155156
def True : Predicate<"true">;
156157

@@ -220,6 +221,7 @@ def BF16RT : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
220221

221222
def F16X2RT : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
222223
def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
224+
def F32X2RT : RegTyInfo<v2f32, B64, ?, ?, supports_imm = 0>;
223225

224226

225227
// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -451,6 +453,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
451453
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
452454
Requires<[useFP16Math]>;
453455

456+
def f32x2rr_ftz :
457+
BasicNVPTXInst<(outs B64:$dst),
458+
(ins B64:$a, B64:$b),
459+
op_str # ".ftz.f32x2",
460+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
461+
Requires<[hasF32x2Instructions, doF32FTZ]>;
462+
def f32x2rr :
463+
BasicNVPTXInst<(outs B64:$dst),
464+
(ins B64:$a, B64:$b),
465+
op_str # ".f32x2",
466+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
467+
Requires<[hasF32x2Instructions]>;
454468
def f16x2rr_ftz :
455469
BasicNVPTXInst<(outs B32:$dst),
456470
(ins B32:$a, B32:$b),
@@ -829,6 +843,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
829843
(SELP_b32rr $a, $b, $p)>;
830844
}
831845

846+
def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
847+
(SELP_b64rr $a, $b, $p)>;
848+
832849
//-----------------------------------
833850
// Test Instructions
834851
//-----------------------------------
@@ -1345,6 +1362,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
13451362
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
13461363
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
13471364
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
1365+
defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
1366+
defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
13481367
defm FMA64 : FMA<"fma.rn.f64", F64RT>;
13491368

13501369
// sin/cos
@@ -2585,6 +2604,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
25852604
def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
25862605
(CVT_INREG_s32_s16 $src)>;
25872606

2607+
// Handle extracting one element from the pair (32-bit types)
25882608
foreach vt = [v2f16, v2bf16, v2i16] in {
25892609
def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
25902610
def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
@@ -2596,10 +2616,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
25962616
(V2I16toI32 $a, $b)>;
25972617
}
25982618

2619+
// Same thing for the 64-bit type v2f32.
2620+
foreach vt = [v2f32] in {
2621+
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
2622+
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
2623+
2624+
def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
2625+
def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;
2626+
2627+
def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
2628+
(V2I32toI64 $a, $b)>;
2629+
}
2630+
25992631
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
26002632
(CVT_u32_u16 $a, CvtNONE)>;
26012633

2602-
26032634
def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;
26042635

26052636
def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def B16 : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
6060
def B32 : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
6161
(add (sequence "R%u", 0, 4),
6262
VRFrame32, VRFrameLocal32)>;
63-
def B64 : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
63+
def B64 : NVPTXRegClass<[i64, v2f32, f64], 64, (add (sequence "RL%u", 0, 4),
64+
VRFrame64, VRFrameLocal64)>;
6465
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6566
def B128 : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6667

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
116116

117117
return HasTcgen05 && PTXVersion >= 86;
118118
}
119+
// f32x2 instructions in Blackwell family. Reported as available only when
// SmVersion >= 100 and PTXVersion >= 86.
120+
bool hasF32x2Instructions() const {
121+
return SmVersion >= 100 && PTXVersion >= 86;
122+
}
119123

120124
// TMA G2S copy with cta_group::1/2 support
121125
bool hasCpAsyncBulkTensorCTAGroupSupport() const {

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
129129
Insert = false;
130130
}
131131
}
132-
if (Insert && Isv2x16VT(VT)) {
133-
// Can be built in a single mov
132+
if (Insert && isPackedVectorTy(VT) && VT.is32BitVector()) {
133+
// Can be built in a single 32-bit mov (64-bit regs are emulated in SASS
134+
// with 2x 32-bit regs)
134135
Cost += 1;
135136
Insert = false;
136137
}

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
8585

8686
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
8787

88-
inline bool Isv2x16VT(EVT VT) {
89-
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
88+
// Returns true if VT is one of v4i8, v2f16, v2bf16, v2i16 or v2f32.
// Note that v4i8 has four elements, so callers that need a two-element
// vector must check the element count as well.
inline bool isPackedVectorTy(EVT VT) {
89+
return (VT == MVT::v4i8 || VT == MVT::v2f16 || VT == MVT::v2bf16 ||
90+
VT == MVT::v2i16 || VT == MVT::v2f32);
91+
}
92+
93+
inline bool isPackedElementTy(EVT VT) {
94+
return (VT == MVT::i8 || VT == MVT::f16 || VT == MVT::bf16 ||
95+
VT == MVT::i16 || VT == MVT::f32);
9096
}
9197

9298
inline bool shouldPassAsArray(Type *Ty) {
Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
12
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_35 | FileCheck %s
23
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_35 | %ptxas-verify %}
34

@@ -7,57 +8,105 @@ declare [2 x float] @bara([2 x float] %input)
78
declare {float, float} @bars({float, float} %input)
89

910
define void @test_v2f32(<2 x float> %input, ptr %output) {
10-
; CHECK-LABEL: @test_v2f32
11+
; CHECK-LABEL: test_v2f32(
12+
; CHECK: {
13+
; CHECK-NEXT: .reg .b64 %rd<5>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0:
16+
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
17+
; CHECK-NEXT: { // callseq 0, 0
18+
; CHECK-NEXT: .param .align 8 .b8 param0[8];
19+
; CHECK-NEXT: st.param.b64 [param0], %rd1;
20+
; CHECK-NEXT: .param .align 8 .b8 retval0[8];
21+
; CHECK-NEXT: call.uni (retval0), barv, (param0);
22+
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
23+
; CHECK-NEXT: } // callseq 0
24+
; CHECK-NEXT: ld.param.b64 %rd4, [test_v2f32_param_1];
25+
; CHECK-NEXT: st.b64 [%rd4], %rd2;
26+
; CHECK-NEXT: ret;
1127
%call = tail call <2 x float> @barv(<2 x float> %input)
12-
; CHECK: .param .align 8 .b8 retval0[8];
13-
; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
1428
store <2 x float> %call, ptr %output, align 8
15-
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
1629
ret void
1730
}
1831

1932
define void @test_v3f32(<3 x float> %input, ptr %output) {
20-
; CHECK-LABEL: @test_v3f32
21-
;
33+
; CHECK-LABEL: test_v3f32(
34+
; CHECK: {
35+
; CHECK-NEXT: .reg .b32 %r<10>;
36+
; CHECK-NEXT: .reg .b64 %rd<2>;
37+
; CHECK-EMPTY:
38+
; CHECK-NEXT: // %bb.0:
39+
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v3f32_param_0];
40+
; CHECK-NEXT: ld.param.b32 %r3, [test_v3f32_param_0+8];
41+
; CHECK-NEXT: { // callseq 1, 0
42+
; CHECK-NEXT: .param .align 16 .b8 param0[16];
43+
; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
44+
; CHECK-NEXT: st.param.b32 [param0+8], %r3;
45+
; CHECK-NEXT: .param .align 16 .b8 retval0[16];
46+
; CHECK-NEXT: call.uni (retval0), barv3, (param0);
47+
; CHECK-NEXT: ld.param.v2.b32 {%r4, %r5}, [retval0];
48+
; CHECK-NEXT: ld.param.b32 %r6, [retval0+8];
49+
; CHECK-NEXT: } // callseq 1
50+
; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_1];
51+
; CHECK-NEXT: st.b32 [%rd1+8], %r6;
52+
; CHECK-NEXT: st.v2.b32 [%rd1], {%r4, %r5};
53+
; CHECK-NEXT: ret;
2254
%call = tail call <3 x float> @barv3(<3 x float> %input)
23-
; CHECK: .param .align 16 .b8 retval0[16];
24-
; CHECK-DAG: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
25-
; CHECK-DAG: ld.param.b32 [[E2:%r[0-9]+]], [retval0+8];
2655
; Make sure we don't load more values than we need to.
27-
; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
2856
store <3 x float> %call, ptr %output, align 8
29-
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
30-
; -- This is suboptimal. We should do st.v2.f32 instead
31-
; of combining 2xf32 info i64.
32-
; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
33-
; CHECK: ret;
3457
ret void
3558
}
3659

3760
define void @test_a2f32([2 x float] %input, ptr %output) {
38-
; CHECK-LABEL: @test_a2f32
61+
; CHECK-LABEL: test_a2f32(
62+
; CHECK: {
63+
; CHECK-NEXT: .reg .b32 %r<7>;
64+
; CHECK-NEXT: .reg .b64 %rd<2>;
65+
; CHECK-EMPTY:
66+
; CHECK-NEXT: // %bb.0:
67+
; CHECK-NEXT: ld.param.b32 %r1, [test_a2f32_param_0];
68+
; CHECK-NEXT: ld.param.b32 %r2, [test_a2f32_param_0+4];
69+
; CHECK-NEXT: { // callseq 2, 0
70+
; CHECK-NEXT: .param .align 4 .b8 param0[8];
71+
; CHECK-NEXT: st.param.b32 [param0], %r1;
72+
; CHECK-NEXT: st.param.b32 [param0+4], %r2;
73+
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
74+
; CHECK-NEXT: call.uni (retval0), bara, (param0);
75+
; CHECK-NEXT: ld.param.b32 %r3, [retval0];
76+
; CHECK-NEXT: ld.param.b32 %r4, [retval0+4];
77+
; CHECK-NEXT: } // callseq 2
78+
; CHECK-NEXT: ld.param.b64 %rd1, [test_a2f32_param_1];
79+
; CHECK-NEXT: st.b32 [%rd1+4], %r4;
80+
; CHECK-NEXT: st.b32 [%rd1], %r3;
81+
; CHECK-NEXT: ret;
3982
%call = tail call [2 x float] @bara([2 x float] %input)
40-
; CHECK: .param .align 4 .b8 retval0[8];
41-
; CHECK-DAG: ld.param.b32 [[ELEMA1:%r[0-9]+]], [retval0];
42-
; CHECK-DAG: ld.param.b32 [[ELEMA2:%r[0-9]+]], [retval0+4];
4383
store [2 x float] %call, ptr %output, align 4
44-
; CHECK: }
45-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMA1]]
46-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMA2]]
4784
ret void
48-
; CHECK: ret
4985
}
5086

5187
define void @test_s2f32({float, float} %input, ptr %output) {
52-
; CHECK-LABEL: @test_s2f32
88+
; CHECK-LABEL: test_s2f32(
89+
; CHECK: {
90+
; CHECK-NEXT: .reg .b32 %r<7>;
91+
; CHECK-NEXT: .reg .b64 %rd<2>;
92+
; CHECK-EMPTY:
93+
; CHECK-NEXT: // %bb.0:
94+
; CHECK-NEXT: ld.param.b32 %r1, [test_s2f32_param_0];
95+
; CHECK-NEXT: ld.param.b32 %r2, [test_s2f32_param_0+4];
96+
; CHECK-NEXT: { // callseq 3, 0
97+
; CHECK-NEXT: .param .align 4 .b8 param0[8];
98+
; CHECK-NEXT: st.param.b32 [param0], %r1;
99+
; CHECK-NEXT: st.param.b32 [param0+4], %r2;
100+
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
101+
; CHECK-NEXT: call.uni (retval0), bars, (param0);
102+
; CHECK-NEXT: ld.param.b32 %r3, [retval0];
103+
; CHECK-NEXT: ld.param.b32 %r4, [retval0+4];
104+
; CHECK-NEXT: } // callseq 3
105+
; CHECK-NEXT: ld.param.b64 %rd1, [test_s2f32_param_1];
106+
; CHECK-NEXT: st.b32 [%rd1+4], %r4;
107+
; CHECK-NEXT: st.b32 [%rd1], %r3;
108+
; CHECK-NEXT: ret;
53109
%call = tail call {float, float} @bars({float, float} %input)
54-
; CHECK: .param .align 4 .b8 retval0[8];
55-
; CHECK-DAG: ld.param.b32 [[ELEMS1:%r[0-9]+]], [retval0];
56-
; CHECK-DAG: ld.param.b32 [[ELEMS2:%r[0-9]+]], [retval0+4];
57110
store {float, float} %call, ptr %output, align 4
58-
; CHECK: }
59-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMS1]]
60-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMS2]]
61111
ret void
62-
; CHECK: ret
63112
}

0 commit comments

Comments
 (0)