@@ -158,6 +158,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158
158
// Subtarget predicates: each def wraps the quoted condition string shown.
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159
159
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160
160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161
// Holds when Subtarget->hasF32x2Instructions() returns true.
+ def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
161
162
162
163
// Constant predicates whose condition strings are the literals "true" and "false".
def True : Predicate<"true">;
163
164
def False : Predicate<"false">;
@@ -193,6 +194,7 @@ class ValueToRegClass<ValueType T> {
193
194
!eq(name, "bf16"): Int16Regs,
194
195
!eq(name, "v2bf16"): Int32Regs,
195
196
!eq(name, "f32"): Float32Regs,
197
+ !eq(name, "v2f32"): Int64Regs,
196
198
!eq(name, "f64"): Float64Regs,
197
199
!eq(name, "ai32"): Int32ArgRegs,
198
200
!eq(name, "ai64"): Int64ArgRegs,
@@ -239,6 +241,7 @@ def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
239
241
240
242
// Register-type info for the two-element vector types: v2f16 and v2bf16 use
// Int32Regs, v2f32 uses Int64Regs. All three set supports_imm = 0 and leave
// the two immediate-related arguments unset (?).
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
241
243
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
244
+ def F32X2RT : RegTyInfo<v2f32, Int64Regs, ?, ?, supports_imm = 0>;
242
245
243
246
244
247
// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -461,6 +464,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
461
464
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
462
465
Requires<[useFP16Math]>;
463
466
467
// f32x2 register-register forms: both inputs and the result are Int64Regs,
// and the selection pattern applies op_pat to v2f32 operands.
// This variant appends ".ftz.f32x2" to op_str and requires both
// hasF32x2Instructions and doF32FTZ.
+ def f32x2rr_ftz :
468
+ BasicNVPTXInst<(outs Int64Regs:$dst),
469
+ (ins Int64Regs:$a, Int64Regs:$b),
470
+ op_str # ".ftz.f32x2",
471
+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
472
+ Requires<[hasF32x2Instructions, doF32FTZ]>;
473
// Same operands and pattern as f32x2rr_ftz, but appends ".f32x2" (no ".ftz")
// and requires only hasF32x2Instructions.
+ def f32x2rr :
474
+ BasicNVPTXInst<(outs Int64Regs:$dst),
475
+ (ins Int64Regs:$a, Int64Regs:$b),
476
+ op_str # ".f32x2",
477
+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
478
+ Requires<[hasF32x2Instructions]>;
464
479
def f16x2rr_ftz :
465
480
BasicNVPTXInst<(outs Int32Regs:$dst),
466
481
(ins Int32Regs:$a, Int32Regs:$b),
@@ -839,6 +854,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
839
854
(SELP_b32rr $a, $b, $p)>;
840
855
}
841
856
857
// A select between two v2f32 values on an i1 predicate is matched to
// SELP_b64rr, with the operands passed in the order ($a, $b, $p).
+ def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
858
+ (SELP_b64rr $a, $b, $p)>;
859
+
842
860
//-----------------------------------
843
861
// Test Instructions
844
862
//-----------------------------------
@@ -1387,6 +1405,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
1387
1405
// FMA instantiations; each passes its instruction string, a RegTyInfo record,
// and an optional predicate list to the FMA multiclass.
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
1388
1406
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
1389
1407
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
1408
// f32x2 forms on F32X2RT: both are gated by hasF32x2Instructions, and the
// _ftz form additionally requires doF32FTZ.
+ defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
1409
+ defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
1390
1410
defm FMA64 : FMA<"fma.rn.f64", F64RT>;
1391
1411
1392
1412
// sin/cos
@@ -2739,6 +2759,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
2739
2759
// Sign-extending element 0 of a v2i16 to i32 is matched to CVT_INREG_s32_s16
// applied to the whole source value.
def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
2740
2760
(CVT_INREG_s32_s16 $src)>;
2741
2761
2762
+ // Handle extracting one element from the pair (32-bit types)
2742
2763
foreach vt = [v2f16, v2bf16, v2i16] in {
2743
2764
def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
2744
2765
def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
@@ -2750,10 +2771,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
2750
2771
(V2I16toI32 $a, $b)>;
2751
2772
}
2752
2773
2774
+ // Element extraction and build_vector patterns for the 64-bit type v2f32.
2775
+ foreach vt = [v2f32] in {
// Patterns gated on hasPTX<71>: element 0 uses I64toI32L_Sink and
// element 1 uses I64toI32H_Sink.
2776
+ def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
2777
+ def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
2778
+
// Patterns with no Requires clause: element 0 uses I64toI32L and element 1
// uses I64toI32H. NOTE(review): presumably these are the fallbacks when
// hasPTX<71> does not hold -- confirm against pattern priority rules.
2779
+ def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
2780
+ def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;
2781
+
// A build_vector of two values of vt.ElementType is matched to
// V2I32toI64, with the operands passed in the order ($a, $b).
2782
+ def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
2783
+ (V2I32toI64 $a, $b)>;
2784
+ }
2785
+
2753
2786
// scalar_to_vector from an i16 to v2i16 is matched to CVT_u32_u16 with the
// CvtNONE mode operand.
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
2754
2787
(CVT_u32_u16 $a, CvtNONE)>;
2755
2788
2756
-
2757
2789
// SelectionDAG node for NVPTXISD::BUILD_VECTOR. The type profile declares one
// result and two operands with an empty constraint list, and the node's
// property list is also empty.
def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;
2758
2790
2759
2791
def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),
0 commit comments