[NVPTX] support packed f32 instructions for sm_100+ #126337

Open · wants to merge 1 commit into base: main

18 changes: 14 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -471,8 +471,17 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
// We only care about 16x2 as it's the only real vector type we
// need to deal with.
MVT VT = Vector.getSimpleValueType();
if (!Isv2x16VT(VT))
if (!isPackedVectorTy(VT) || VT.getVectorNumElements() != 2)
return false;

unsigned Opcode;
if (VT.is32BitVector())
Opcode = NVPTX::I32toV2I16;
else if (VT.is64BitVector())
Opcode = NVPTX::I64toV2I32;
else
llvm_unreachable("Unhandled packed type");

// Find and record all uses of this vector that extract element 0 or 1.
SmallVector<SDNode *, 4> E0, E1;
for (auto *U : Vector.getNode()->users()) {
@@ -496,11 +505,11 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
if (E0.empty() || E1.empty())
return false;

// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
// into f16,f16 SplitF16x2(V)
// Merge (EltTy extractelt(V, 0), EltTy extractelt(V,1))
// into EltTy,EltTy Split[EltTy]x2(V)
MVT EltVT = VT.getVectorElementType();
SDNode *ScatterOp =
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
CurDAG->getMachineNode(Opcode, SDLoc(N), EltVT, EltVT, Vector);
for (auto *Node : E0)
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
for (auto *Node : E1)
@@ -1035,6 +1044,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
case MVT::i32:
case MVT::f32:
return Opcode_i32;
case MVT::v2f32:
case MVT::i64:
case MVT::f64:
return Opcode_i64;
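
As a reviewer aid, a minimal IR sketch (mine, not from this PR's tests) of what the updated tryEXTRACT_VECTOR_ELEMENT should now handle: when both lanes of a v2f32 are extracted, the two extracts can be folded into a single I64toV2I32 unpack.

define float @split_v2f32(<2 x float> %v) {
  ; both lanes are used, so ISel should replace the two extracts with one
  ; 64-bit unpack: mov.b64 {%r1, %r2}, %rd1 (I64toV2I32)
  %e0 = extractelement <2 x float> %v, i32 0
  %e1 = extractelement <2 x float> %v, i32 1
  %s = fadd float %e0, %e1
  ret float %s
}
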
158 changes: 95 additions & 63 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Large diffs are not rendered by default.
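
Since the lowering diff is not rendered here, the intended effect is easiest to state by example. A hedged sketch (assuming the lowering changes make v2f32 a legal packed type when hasF32x2Instructions holds):

define <2 x float> @vadd(<2 x float> %a, <2 x float> %b) {
  ; should survive legalization as one packed op and select add.rn.f32x2
  ; on a single 64-bit register, rather than two scalar add.rn.f32
  %r = fadd <2 x float> %a, %b
  ret <2 x float> %r
}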

34 changes: 33 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -151,6 +151,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;

def True : Predicate<"true">;

@@ -185,6 +186,7 @@ class ValueToRegClass<ValueType T> {
!eq(name, "bf16"): Int16Regs,
!eq(name, "v2bf16"): Int32Regs,
!eq(name, "f32"): Float32Regs,
!eq(name, "v2f32"): Int64Regs,
!eq(name, "f64"): Float64Regs,
!eq(name, "ai32"): Int32ArgRegs,
!eq(name, "ai64"): Int64ArgRegs,
@@ -231,6 +233,7 @@ def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;

def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
def F32X2RT : RegTyInfo<v2f32, Int64Regs, ?, ?, supports_imm = 0>;


// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -462,6 +465,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
Requires<[useFP16Math]>;

def f32x2rr_ftz :
BasicNVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b),
op_str # ".ftz.f32x2",
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
Requires<[hasF32x2Instructions, doF32FTZ]>;
def f32x2rr :
BasicNVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b),
op_str # ".f32x2",
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
Requires<[hasF32x2Instructions]>;
def f16x2rr_ftz :
BasicNVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b),
@@ -840,6 +855,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
(SELP_b32rr $a, $b, $p)>;
}

def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
(SELP_b64rr $a, $b, $p)>;
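
A sketch of what this pattern covers (my example, not from the tests): selecting between two packed pairs becomes a single selp.b64 on the 64-bit registers holding both lanes.

define <2 x float> @sel(i1 %p, <2 x float> %a, <2 x float> %b) {
  ; both lanes selected at once: selp.b64 %rd3, %rd1, %rd2, %p1
  %r = select i1 %p, <2 x float> %a, <2 x float> %b
  ret <2 x float> %r
}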

//-----------------------------------
// Test Instructions
//-----------------------------------
@@ -1368,6 +1386,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
defm FMA64 : FMA<"fma.rn.f64", F64RT>;
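
For reference, IR along these lines (a sketch, not part of this diff) should be matched by the new FMA32x2 variants on sm_100 with PTX 8.6:

declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>)

define <2 x float> @fma2(<2 x float> %a, <2 x float> %b, <2 x float> %c) {
  ; expected to select fma.rn.f32x2, or fma.rn.ftz.f32x2 under doF32FTZ
  %r = call <2 x float> @llvm.fma.v2f32(<2 x float> %a, <2 x float> %b, <2 x float> %c)
  ret <2 x float> %r
}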

// sin/cos
@@ -2714,6 +2734,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
(CVT_INREG_s32_s16 $src)>;

// Handle extracting one element from the pair (32-bit types)
foreach vt = [v2f16, v2bf16, v2i16] in {
def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
@@ -2725,10 +2746,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
(V2I16toI32 $a, $b)>;
}

// Same thing for the 64-bit type v2f32.
foreach vt = [v2f32] in {
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;

def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;

def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
(V2I32toI64 $a, $b)>;
}
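
The packing direction, sketched (my example): building the pair from two scalars should collapse into a single V2I32toI64 mov.

define <2 x float> @pack(float %a, float %b) {
  ; expected: mov.b64 %rd1, {%r1, %r2} (V2I32toI64)
  %v0 = insertelement <2 x float> poison, float %a, i32 0
  %v1 = insertelement <2 x float> %v0, float %b, i32 1
  ret <2 x float> %v1
}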

def: Pat<(v2i16 (scalar_to_vector i16:$a)),
(CVT_u32_u16 $a, CvtNONE)>;


def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;

def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),
4 changes: 3 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -60,7 +60,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
def Int64Regs : NVPTXRegClass<[i64, v2f32, f64], 64,
(add (sequence "RL%u", 0, 4),
VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;

4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -116,6 +116,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {

return HasTcgen05 && PTXVersion >= 86;
}
// Packed f32x2 instructions, introduced with the Blackwell family (sm_100+)
bool hasF32x2Instructions() const {
return SmVersion >= 100 && PTXVersion >= 86;
}
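
To exercise the gate locally, something like the following should work (the flag spellings are the usual NVPTX ones and are my assumption, not stated in this diff):

; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_100 -mattr=+ptx86 \
; RUN:   | FileCheck %s
define <2 x float> @gate_check(<2 x float> %a, <2 x float> %b) {
; CHECK: add.rn.f32x2
  %r = fadd <2 x float> %a, %b
  ret <2 x float> %r
}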

// TMA G2S copy with cta_group::1/2 support
bool hasCpAsyncBulkTensorCTAGroupSupport() const {
5 changes: 3 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -129,8 +129,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
Insert = false;
}
}
if (Insert && Isv2x16VT(VT)) {
// Can be built in a single mov
if (Insert && isPackedVectorTy(VT) && VT.is32BitVector()) {
// Can be built in a single 32-bit mov (64-bit regs are emulated in SASS
// with 2x 32-bit regs)
Cost += 1;
Insert = false;
}
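
Concretely (my sketch, not from the PR): the 32-bit packed case that keeps the single-mov cost looks like the function below; v2f32 now intentionally falls outside this fast path because its 64-bit register is emulated with two 32-bit registers in SASS.

define <2 x half> @pack_f16(half %a, half %b) {
  ; built by one 32-bit mov: mov.b32 %r1, {%rs1, %rs2}
  %v0 = insertelement <2 x half> poison, half %a, i32 0
  %v1 = insertelement <2 x half> %v0, half %b, i32 1
  ret <2 x half> %v1
}
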
10 changes: 8 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -85,8 +85,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {

bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);

inline bool Isv2x16VT(EVT VT) {
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
inline bool isPackedVectorTy(EVT VT) {
return (VT == MVT::v4i8 || VT == MVT::v2f16 || VT == MVT::v2bf16 ||
VT == MVT::v2i16 || VT == MVT::v2f32);
}

inline bool isPackedElementTy(EVT VT) {
return (VT == MVT::i8 || VT == MVT::f16 || VT == MVT::bf16 ||
VT == MVT::i16 || VT == MVT::f32);
}

inline bool shouldPassAsArray(Type *Ty) {
7 changes: 3 additions & 4 deletions llvm/test/CodeGen/NVPTX/aggregate-return.ll
@@ -10,7 +10,8 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
; CHECK-LABEL: @test_v2f32
%call = tail call <2 x float> @barv(<2 x float> %input)
; CHECK: .param .align 8 .b8 retval0[8];
; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
; CHECK: ld.param.b64 [[E0_1:%rd[0-9]+]], [retval0];
; CHECK: mov.b64 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [[E0_1]]
store <2 x float> %call, ptr %output, align 8
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
ret void
@@ -27,9 +28,7 @@ define void @test_v3f32(<3 x float> %input, ptr %output) {
; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
store <3 x float> %call, ptr %output, align 8
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
; -- This is suboptimal. We should do st.v2.f32 instead
; of combining 2xf32 info i64.
; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
; CHECK-DAG: st.v2.b32 [{{%rd[0-9]}}],
; CHECK: ret;
ret void
}
136 changes: 76 additions & 60 deletions llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -707,108 +707,124 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
; SM70: {
; SM70-NEXT: .reg .b16 %rs<9>;
; SM70-NEXT: .reg .b32 %r<21>;
; SM70-NEXT: .reg .b64 %rd<2>;
; SM70-NEXT: .reg .b64 %rd<6>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
; SM70-NEXT: ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
; SM70-NEXT: ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r1;
; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r2;
; SM70-NEXT: mov.b32 {%rs5, %rs6}, %r3;
; SM70-NEXT: mov.b32 {%rs7, %rs8}, %r4;
; SM70-NEXT: cvt.u32.u16 %r5, %rs8;
; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r4;
; SM70-NEXT: cvt.u32.u16 %r5, %rs2;
; SM70-NEXT: shl.b32 %r6, %r5, 16;
; SM70-NEXT: cvt.u32.u16 %r7, %rs7;
; SM70-NEXT: cvt.u32.u16 %r7, %rs1;
; SM70-NEXT: shl.b32 %r8, %r7, 16;
; SM70-NEXT: cvt.u32.u16 %r9, %rs6;
; SM70-NEXT: mov.b64 %rd2, {%r8, %r6};
; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r3;
; SM70-NEXT: cvt.u32.u16 %r9, %rs4;
; SM70-NEXT: shl.b32 %r10, %r9, 16;
; SM70-NEXT: cvt.u32.u16 %r11, %rs5;
; SM70-NEXT: cvt.u32.u16 %r11, %rs3;
; SM70-NEXT: shl.b32 %r12, %r11, 16;
; SM70-NEXT: cvt.u32.u16 %r13, %rs4;
; SM70-NEXT: mov.b64 %rd3, {%r12, %r10};
; SM70-NEXT: mov.b32 {%rs5, %rs6}, %r2;
; SM70-NEXT: cvt.u32.u16 %r13, %rs6;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
; SM70-NEXT: cvt.u32.u16 %r15, %rs3;
; SM70-NEXT: cvt.u32.u16 %r15, %rs5;
; SM70-NEXT: shl.b32 %r16, %r15, 16;
; SM70-NEXT: cvt.u32.u16 %r17, %rs2;
; SM70-NEXT: mov.b64 %rd4, {%r16, %r14};
; SM70-NEXT: mov.b32 {%rs7, %rs8}, %r1;
; SM70-NEXT: cvt.u32.u16 %r17, %rs8;
; SM70-NEXT: shl.b32 %r18, %r17, 16;
; SM70-NEXT: cvt.u32.u16 %r19, %rs1;
; SM70-NEXT: cvt.u32.u16 %r19, %rs7;
; SM70-NEXT: shl.b32 %r20, %r19, 16;
; SM70-NEXT: st.param.v4.b32 [func_retval0], {%r20, %r18, %r16, %r14};
; SM70-NEXT: st.param.v4.b32 [func_retval0+16], {%r12, %r10, %r8, %r6};
; SM70-NEXT: mov.b64 %rd5, {%r20, %r18};
; SM70-NEXT: st.param.v2.b64 [func_retval0], {%rd5, %rd4};
; SM70-NEXT: st.param.v2.b64 [func_retval0+16], {%rd3, %rd2};
; SM70-NEXT: ret;
;
; SM80-LABEL: test_extload_bf16x8(
; SM80: {
; SM80-NEXT: .reg .b16 %rs<9>;
; SM80-NEXT: .reg .b32 %r<13>;
; SM80-NEXT: .reg .b64 %rd<2>;
; SM80-NEXT: .reg .b64 %rd<6>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
; SM80-NEXT: ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r1;
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r2;
; SM80-NEXT: mov.b32 {%rs5, %rs6}, %r3;
; SM80-NEXT: mov.b32 {%rs7, %rs8}, %r4;
; SM80-NEXT: cvt.f32.bf16 %r5, %rs8;
; SM80-NEXT: cvt.f32.bf16 %r6, %rs7;
; SM80-NEXT: cvt.f32.bf16 %r7, %rs6;
; SM80-NEXT: cvt.f32.bf16 %r8, %rs5;
; SM80-NEXT: cvt.f32.bf16 %r9, %rs4;
; SM80-NEXT: cvt.f32.bf16 %r10, %rs3;
; SM80-NEXT: cvt.f32.bf16 %r11, %rs2;
; SM80-NEXT: cvt.f32.bf16 %r12, %rs1;
; SM80-NEXT: st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
; SM80-NEXT: st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r4;
; SM80-NEXT: cvt.f32.bf16 %r5, %rs2;
; SM80-NEXT: cvt.f32.bf16 %r6, %rs1;
; SM80-NEXT: mov.b64 %rd2, {%r6, %r5};
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r3;
; SM80-NEXT: cvt.f32.bf16 %r7, %rs4;
; SM80-NEXT: cvt.f32.bf16 %r8, %rs3;
; SM80-NEXT: mov.b64 %rd3, {%r8, %r7};
; SM80-NEXT: mov.b32 {%rs5, %rs6}, %r2;
; SM80-NEXT: cvt.f32.bf16 %r9, %rs6;
; SM80-NEXT: cvt.f32.bf16 %r10, %rs5;
; SM80-NEXT: mov.b64 %rd4, {%r10, %r9};
; SM80-NEXT: mov.b32 {%rs7, %rs8}, %r1;
; SM80-NEXT: cvt.f32.bf16 %r11, %rs8;
; SM80-NEXT: cvt.f32.bf16 %r12, %rs7;
; SM80-NEXT: mov.b64 %rd5, {%r12, %r11};
; SM80-NEXT: st.param.v2.b64 [func_retval0], {%rd5, %rd4};
; SM80-NEXT: st.param.v2.b64 [func_retval0+16], {%rd3, %rd2};
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_extload_bf16x8(
; SM80-FTZ: {
; SM80-FTZ-NEXT: .reg .b16 %rs<9>;
; SM80-FTZ-NEXT: .reg .b32 %r<13>;
; SM80-FTZ-NEXT: .reg .b64 %rd<2>;
; SM80-FTZ-NEXT: .reg .b64 %rd<6>;
; SM80-FTZ-EMPTY:
; SM80-FTZ-NEXT: // %bb.0:
; SM80-FTZ-NEXT: ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
; SM80-FTZ-NEXT: ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r1;
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r2;
; SM80-FTZ-NEXT: mov.b32 {%rs5, %rs6}, %r3;
; SM80-FTZ-NEXT: mov.b32 {%rs7, %rs8}, %r4;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r5, %rs8;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r6, %rs7;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r7, %rs6;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r8, %rs5;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r9, %rs4;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r10, %rs3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r11, %rs2;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r12, %rs1;
; SM80-FTZ-NEXT: st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
; SM80-FTZ-NEXT: st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r4;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r5, %rs2;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r6, %rs1;
; SM80-FTZ-NEXT: mov.b64 %rd2, {%r6, %r5};
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r7, %rs4;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r8, %rs3;
; SM80-FTZ-NEXT: mov.b64 %rd3, {%r8, %r7};
; SM80-FTZ-NEXT: mov.b32 {%rs5, %rs6}, %r2;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r9, %rs6;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r10, %rs5;
; SM80-FTZ-NEXT: mov.b64 %rd4, {%r10, %r9};
; SM80-FTZ-NEXT: mov.b32 {%rs7, %rs8}, %r1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r11, %rs8;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r12, %rs7;
; SM80-FTZ-NEXT: mov.b64 %rd5, {%r12, %r11};
; SM80-FTZ-NEXT: st.param.v2.b64 [func_retval0], {%rd5, %rd4};
; SM80-FTZ-NEXT: st.param.v2.b64 [func_retval0+16], {%rd3, %rd2};
; SM80-FTZ-NEXT: ret;
;
; SM90-LABEL: test_extload_bf16x8(
; SM90: {
; SM90-NEXT: .reg .b16 %rs<9>;
; SM90-NEXT: .reg .b32 %r<13>;
; SM90-NEXT: .reg .b64 %rd<2>;
; SM90-NEXT: .reg .b64 %rd<6>;
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
; SM90-NEXT: ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r1;
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r2;
; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r3;
; SM90-NEXT: mov.b32 {%rs7, %rs8}, %r4;
; SM90-NEXT: cvt.f32.bf16 %r5, %rs8;
; SM90-NEXT: cvt.f32.bf16 %r6, %rs7;
; SM90-NEXT: cvt.f32.bf16 %r7, %rs6;
; SM90-NEXT: cvt.f32.bf16 %r8, %rs5;
; SM90-NEXT: cvt.f32.bf16 %r9, %rs4;
; SM90-NEXT: cvt.f32.bf16 %r10, %rs3;
; SM90-NEXT: cvt.f32.bf16 %r11, %rs2;
; SM90-NEXT: cvt.f32.bf16 %r12, %rs1;
; SM90-NEXT: st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
; SM90-NEXT: st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r4;
; SM90-NEXT: cvt.f32.bf16 %r5, %rs2;
; SM90-NEXT: cvt.f32.bf16 %r6, %rs1;
; SM90-NEXT: mov.b64 %rd2, {%r6, %r5};
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r3;
; SM90-NEXT: cvt.f32.bf16 %r7, %rs4;
; SM90-NEXT: cvt.f32.bf16 %r8, %rs3;
; SM90-NEXT: mov.b64 %rd3, {%r8, %r7};
; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r2;
; SM90-NEXT: cvt.f32.bf16 %r9, %rs6;
; SM90-NEXT: cvt.f32.bf16 %r10, %rs5;
; SM90-NEXT: mov.b64 %rd4, {%r10, %r9};
; SM90-NEXT: mov.b32 {%rs7, %rs8}, %r1;
; SM90-NEXT: cvt.f32.bf16 %r11, %rs8;
; SM90-NEXT: cvt.f32.bf16 %r12, %rs7;
; SM90-NEXT: mov.b64 %rd5, {%r12, %r11};
; SM90-NEXT: st.param.v2.b64 [func_retval0], {%rd5, %rd4};
; SM90-NEXT: st.param.v2.b64 [func_retval0+16], {%rd3, %rd2};
; SM90-NEXT: ret;
%load = load <8 x bfloat>, ptr addrspace(3) %arg, align 16
%res = fpext <8 x bfloat> %load to <8 x float>