Skip to content

Commit cdac087

Browse files
committed
[NVPTX] Add f32x2 instructions and register class
1 parent 432c5f2 commit cdac087

21 files changed

+3179
-852
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,13 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
471471
// We only care about 16x2 and v2f32 as they are the only real vector types we
472472
// need to deal with.
473473
MVT VT = Vector.getSimpleValueType();
474-
if (!Isv2x16VT(VT))
475-
return false;
474+
auto Opcode = NVPTX::I32toV2I16;
475+
if (!Isv2x16VT(VT)) {
476+
if (VT == MVT::v2f32)
477+
Opcode = NVPTX::I64toV2I32;
478+
else
479+
return false;
480+
}
476481
// Find and record all uses of this vector that extract element 0 or 1.
477482
SmallVector<SDNode *, 4> E0, E1;
478483
for (auto *U : Vector.getNode()->users()) {
@@ -496,11 +501,11 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
496501
if (E0.empty() || E1.empty())
497502
return false;
498503

499-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
500-
// into f16,f16 SplitF16x2(V)
504+
// Merge (EltTy extractelt(V, 0), EltTy extractelt(V,1))
505+
// into EltTy,EltTy Split[EltTy]x2(V)
501506
MVT EltVT = VT.getVectorElementType();
502507
SDNode *ScatterOp =
503-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
508+
CurDAG->getMachineNode(Opcode, SDLoc(N), EltVT, EltVT, Vector);
504509
for (auto *Node : E0)
505510
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
506511
for (auto *Node : E1)
@@ -1035,6 +1040,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
10351040
case MVT::i32:
10361041
case MVT::f32:
10371042
return Opcode_i32;
1043+
case MVT::v2f32:
10381044
case MVT::i64:
10391045
case MVT::f64:
10401046
return Opcode_i64;
@@ -1245,7 +1251,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12451251
EltVT = EltVT.getVectorElementType();
12461252
// Vectors of 8/16-bit types are loaded/stored as multiples of v4i8/v2x16
12471253
// elements.
1248-
if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
1254+
// Packed vector types are loaded/stored in a single register.
1255+
if ((EltVT == MVT::f32 && OrigType == MVT::v2f32) ||
1256+
(EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
12491257
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
12501258
(EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
12511259
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 158 additions & 51 deletions
Large diffs are not rendered by default.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158158
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159159
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
161162

162163
def True : Predicate<"true">;
163164
def False : Predicate<"false">;
@@ -193,6 +194,7 @@ class ValueToRegClass<ValueType T> {
193194
!eq(name, "bf16"): Int16Regs,
194195
!eq(name, "v2bf16"): Int32Regs,
195196
!eq(name, "f32"): Float32Regs,
197+
!eq(name, "v2f32"): Int64Regs,
196198
!eq(name, "f64"): Float64Regs,
197199
!eq(name, "ai32"): Int32ArgRegs,
198200
!eq(name, "ai64"): Int64ArgRegs,
@@ -239,6 +241,7 @@ def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
239241

240242
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
241243
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
244+
def F32X2RT : RegTyInfo<v2f32, Int64Regs, ?, ?, supports_imm = 0>;
242245

243246

244247
// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -461,6 +464,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
461464
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
462465
Requires<[useFP16Math]>;
463466

467+
def f32x2rr_ftz :
468+
BasicNVPTXInst<(outs Int64Regs:$dst),
469+
(ins Int64Regs:$a, Int64Regs:$b),
470+
op_str # ".ftz.f32x2",
471+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
472+
Requires<[hasF32x2Instructions, doF32FTZ]>;
473+
def f32x2rr :
474+
BasicNVPTXInst<(outs Int64Regs:$dst),
475+
(ins Int64Regs:$a, Int64Regs:$b),
476+
op_str # ".f32x2",
477+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
478+
Requires<[hasF32x2Instructions]>;
464479
def f16x2rr_ftz :
465480
BasicNVPTXInst<(outs Int32Regs:$dst),
466481
(ins Int32Regs:$a, Int32Regs:$b),
@@ -839,6 +854,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
839854
(SELP_b32rr $a, $b, $p)>;
840855
}
841856

857+
def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
858+
(SELP_b64rr $a, $b, $p)>;
859+
842860
//-----------------------------------
843861
// Test Instructions
844862
//-----------------------------------
@@ -1387,6 +1405,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
13871405
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
13881406
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
13891407
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
1408+
defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
1409+
defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
13901410
defm FMA64 : FMA<"fma.rn.f64", F64RT>;
13911411

13921412
// sin/cos
@@ -2739,6 +2759,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
27392759
def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
27402760
(CVT_INREG_s32_s16 $src)>;
27412761

2762+
// Handle extracting one element from the pair (32-bit types)
27422763
foreach vt = [v2f16, v2bf16, v2i16] in {
27432764
def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
27442765
def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
@@ -2750,10 +2771,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
27502771
(V2I16toI32 $a, $b)>;
27512772
}
27522773

2774+
// Same thing for the 64-bit type v2f32.
2775+
foreach vt = [v2f32] in {
2776+
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
2777+
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
2778+
2779+
def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
2780+
def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;
2781+
2782+
def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
2783+
(V2I32toI64 $a, $b)>;
2784+
}
2785+
27532786
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
27542787
(CVT_u32_u16 $a, CvtNONE)>;
27552788

2756-
27572789
def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;
27582790

27592791
def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6060
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
6161
(add (sequence "R%u", 0, 4),
6262
VRFrame32, VRFrameLocal32)>;
63-
def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
63+
def Int64Regs : NVPTXRegClass<[i64, v2f32, f64], 64,
64+
(add (sequence "RL%u", 0, 4),
65+
VRFrame64, VRFrameLocal64)>;
6466
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6567
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6668

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
116116

117117
return HasTcgen05 && PTXVersion >= 86;
118118
}
119+
// f32x2 instructions in Blackwell family
120+
bool hasF32x2Instructions() const {
121+
return SmVersion >= 100 && PTXVersion >= 86;
122+
}
119123

120124
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
121125
// terminates a basic block. Instead, it would assume that control flow

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
1010
; CHECK-LABEL: @test_v2f32
1111
%call = tail call <2 x float> @barv(<2 x float> %input)
1212
; CHECK: .param .align 8 .b8 retval0[8];
13-
; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
13+
; CHECK: ld.param.b64 [[E0_1:%rd[0-9]+]], [retval0];
14+
; CHECK: mov.b64 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [[E0_1]]
1415
store <2 x float> %call, ptr %output, align 8
1516
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
1617
ret void
@@ -27,9 +28,7 @@ define void @test_v3f32(<3 x float> %input, ptr %output) {
2728
; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
2829
store <3 x float> %call, ptr %output, align 8
2930
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
30-
; -- This is suboptimal. We should do st.v2.f32 instead
31-
; of combining 2xf32 into i64.
32-
; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
31+
; CHECK-DAG: st.v2.b32 [{{%rd[0-9]}}],
3332
; CHECK: ret;
3433
ret void
3534
}

0 commit comments

Comments
 (0)