Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions crates/rustc_codegen_nvvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
}
}

/// Declare a function. All functions use the default ABI, NVVM ignores any calling convention markers.
/// All function calls are generated according to the PTX calling convention.
/// Declare a function with appropriate PTX calling conventions.
/// <https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#calling-conventions>
pub fn declare_fn(
&self,
Expand All @@ -332,8 +331,12 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {

trace!("Declaring function `{}` with ty `{:?}`", name, ty);

// TODO(RDambrosio016): we should probably still generate accurate calling conv for functions
// just to make it easier to debug IR and/or make it more compatible with compiling using llvm
// Set PTX device calling convention for all functions declared here.
// Kernel functions will have their calling convention overridden in mono_item.rs
unsafe {
llvm::LLVMSetFunctionCallConv(llfn, llvm::PtxCallConv::Device as u32);
}

llvm::SetUnnamedAddress(llfn, llvm::UnnamedAddr::Global);
if let Some(abi) = fn_abi {
abi.apply_attrs_llfn(self, llfn);
Expand Down
14 changes: 14 additions & 0 deletions crates/rustc_codegen_nvvm/src/llvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ pub(crate) enum Visibility {
Protected = 2,
}

/// PTX/NVPTX calling conventions from LLVM.
///
/// The discriminants mirror the numeric IDs of `PTX_Kernel` and `PTX_Device`
/// in LLVM's `CallingConv.h`, so a variant can be cast with `as u32` and
/// handed directly to `LLVMSetFunctionCallConv`.
/// See: <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
///
/// While NVVM doesn't strictly require these calling conventions to be set
/// (it generates PTX according to its own rules), we set them anyway to
/// make the generated LLVM IR more accurate and easier to debug.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u32)]
pub(crate) enum PtxCallConv {
    /// PTX kernel calling convention (LLVM `PTX_Kernel`).
    Kernel = 71,
    /// PTX device calling convention (LLVM `PTX_Device`).
    Device = 72,
}

/// LLVMUnnamedAddr
#[repr(C)]
pub(crate) enum UnnamedAddr {
Expand Down
3 changes: 3 additions & 0 deletions crates/rustc_codegen_nvvm/src/mono_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
// to nvvm.annotations per the nvvm ir docs.
if nvvm_attrs.kernel {
trace!("Marking function `{:?}` as a kernel", symbol_name);
llvm::LLVMSetFunctionCallConv(lldecl, llvm::PtxCallConv::Kernel as u32);

// Add kernel metadata for NVVM
let kernel = llvm::LLVMMDStringInContext(self.llcx, "kernel".as_ptr().cast(), 6);
let mdvals = &[lldecl, kernel, self.const_i32(1)];
let node =
Expand Down
Loading