Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,9 @@ pub struct AddressSpace(pub u32);
impl AddressSpace {
/// LLVM's `0` address space.
pub const ZERO: Self = AddressSpace(0);
/// The address space for workgroup memory on nvptx and amdgpu.
/// See e.g. the `gpu_launch_sized_workgroup_mem` intrinsic for details.
pub const GPU_WORKGROUP: Self = AddressSpace(3);
}

/// The way we represent values to the backend
Expand Down
25 changes: 25 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,31 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
val
}

/// Builds a GEP carrying the `InBounds` no-wrap flag and optionally names the result.
///
/// `ty` is the type the GEP indexes through, `ptr` the base pointer and `indices`
/// the index operands. When `name` is non-empty it is attached to the produced
/// value after the instruction has been built; an empty `name` leaves the value
/// unnamed.
pub(crate) fn named_inbounds_gep(
    &mut self,
    ty: &'ll Type,
    ptr: &'ll Value,
    indices: &[&'ll Value],
    name: &str,
) -> &'ll Value {
    // The instruction is created unnamed; the name (if any) is applied afterwards
    // through the length-based `set_value_name` helper.
    let val = unsafe {
        llvm::LLVMBuildGEPWithNoWrapFlags(
            self.llbuilder,
            ty,
            ptr,
            indices.as_ptr(),
            indices.len() as c_uint,
            UNNAMED,
            GEPNoWrapFlags::InBounds,
        )
    };
    if !name.is_empty() {
        // `set_value_name` receives a byte slice, so no NUL-terminated copy of
        // `name` is needed (and no panic on an interior NUL can occur here).
        llvm::set_value_name(val, name.as_bytes());
    }
    val
}

pub(crate) fn inbounds_gep(
&mut self,
ty: &'ll Type,
Expand Down
52 changes: 33 additions & 19 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,25 +319,26 @@ impl KernelArgsTy {
geps: [&'ll Value; 3],
workgroup_dims: &'ll Value,
thread_dims: &'ll Value,
) -> [(Align, &'ll Value); 13] {
dyn_cache: &'ll Value,
) -> [(Align, &'ll str, &'ll Value); 13] {
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
let eight = Align::EIGHT;

[
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
(four, cx.get_const_i32(num_args)),
(eight, geps[0]),
(eight, geps[1]),
(eight, geps[2]),
(eight, memtransfer_types),
// The next two are debug infos. FIXME(offload): set them
(eight, cx.const_null(cx.type_ptr())), // dbg
(eight, cx.const_null(cx.type_ptr())), // dbg
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
(four, workgroup_dims),
(four, thread_dims),
(four, cx.get_const_i32(0)),
(four, "Version", cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
(four, "NumArgs", cx.get_const_i32(num_args)),
(eight, "ArgBasePtrs", geps[0]),
(eight, "ArgPtrs", geps[1]),
(eight, "ArgSizes", geps[2]),
(eight, "ArgTypes", memtransfer_types),
// The next two are debug infos. FIXME(offload): set them
(eight, "ArgNames", cx.const_null(cx.type_ptr())), // dbg
(eight, "ArgMappers", cx.const_null(cx.type_ptr())), // dbg
(eight, "Tripcount", cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
(eight, "Flags", cx.get_const_i64(KernelArgsTy::FLAGS)),
(four, "NumTeams", workgroup_dims),
(four, "ThreadLimit", thread_dims),
(four, "DynCGroupMem", dyn_cache),
]
}
}
Expand Down Expand Up @@ -576,6 +577,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
offload_dims: &OffloadKernelDims<'ll>,
dyn_cache: &'ll Value,
) {
let cx = builder.cx;
let OffloadKernelGlobals {
Expand Down Expand Up @@ -740,14 +742,26 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
num_args,
s_ident_t,
);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_kernel, geps, workgroup_dims, thread_dims);
let values = KernelArgsTy::new(
&cx,
num_args,
memtransfer_kernel,
geps,
workgroup_dims,
thread_dims,
dyn_cache,
);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
for (i, value) in values.iter().enumerate() {
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
builder.store(value.1, ptr, value.0);
let ptr = builder.named_inbounds_gep(
tgt_kernel_decl,
a5,
&[i32_0, cx.get_const_i32(i as u64)],
value.1,
);
builder.store(value.2, ptr, value.0);
}

let args = vec![
Expand Down
23 changes: 23 additions & 0 deletions compiler/rustc_codegen_llvm/src/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
use std::borrow::Borrow;

use itertools::Itertools;
use rustc_abi::AddressSpace;
use rustc_codegen_ssa::traits::TypeMembershipCodegenMethods;
use rustc_data_structures::fx::FxIndexSet;
use rustc_middle::ty::{Instance, Ty};
Expand Down Expand Up @@ -97,6 +98,28 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
)
}
}

/// Declare a global value in a specific address space.
///
/// `name` is the symbol name, `ty` the LLVM type of the global and `addr_space`
/// the address space requested for it.
///
/// If there's a value with the same name already declared, the function will
/// return its Value instead.
pub(crate) fn declare_global_in_addrspace(
    &self,
    name: &str,
    ty: &'ll Type,
    addr_space: AddressSpace,
) -> &'ll Value {
    debug!("declare_global_in_addrspace(name={name:?}, addrspace={addr_space:?})");
    // SAFETY: the name pointer and length both come from the same `&str`, which is
    // borrowed for the whole call; the module and type are references handed out by
    // this context.
    unsafe {
        llvm::LLVMRustGetOrInsertGlobalInAddrspace(
            (**self).borrow().llmod,
            name.as_c_char_ptr(),
            name.len(),
            ty,
            addr_space.0,
        )
    }
}
}

impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
Expand Down
64 changes: 60 additions & 4 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::ffi::c_uint;
use std::{assert_matches, ptr};

use rustc_abi::{
Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange,
AddressSpace, Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size,
WrappingRange,
};
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
Expand All @@ -23,7 +24,7 @@ use rustc_session::config::CrateType;
use rustc_span::{Span, Symbol, sym};
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
use rustc_target::callconv::PassMode;
use rustc_target::spec::Os;
use rustc_target::spec::{Arch, Os};
use tracing::debug;

use crate::abi::FnAbiLlvmExt;
Expand Down Expand Up @@ -590,6 +591,44 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
return Ok(());
}

sym::gpu_launch_sized_workgroup_mem => {
    // Each call produces a global with the following properties:
    //   1. it lives in the workgroup-memory address space,
    //   2. it is an `external` global,
    //   3. it is aligned suitably for the pointee `T`.
    // The LLVM backend merges all extern addrspace(gpu_workgroup) globals, so the
    // name does not matter.
    // See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared
    let is_nvptx = tcx.sess.target.arch == Arch::Nvptx64;

    // FIXME Workaround an nvptx backend issue that extern globals must have a name
    let name = if is_nvptx { "gpu_launch_sized_workgroup_mem" } else { "" };

    let zero_len_bytes = self.type_array(self.type_i8(), 0);
    let global =
        self.declare_global_in_addrspace(name, zero_len_bytes, AddressSpace::GPU_WORKGROUP);

    let ty::RawPtr(inner_ty, _) = result.layout.ty.kind() else { unreachable!() };

    // The global's alignment tells the GPU runtime the *minimum* alignment it has to
    // honour. If a kernel uses several of these globals, the largest alignment wins.
    // See https://github.com/llvm/llvm-project/blob/a271d07488a85ce677674bbe8101b10efff58c95/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp#L821
    let alignment = self.align_of(*inner_ty).bytes() as u32;

    // FIXME Workaround the above issue by taking maximum alignment if the global existed:
    // on nvptx the (shared, named) global only ever has its alignment raised; on every
    // other target the alignment is set unconditionally.
    let must_set = !is_nvptx || alignment > unsafe { llvm::LLVMGetAlignment(global) };
    if must_set {
        unsafe { llvm::LLVMSetAlignment(global, alignment) };
    }

    self.cx().const_pointercast(global, self.type_ptr())
}

sym::amdgpu_dispatch_ptr => {
let val = self.call_intrinsic("llvm.amdgcn.dispatch.ptr", &[], &[]);
// Relying on `LLVMBuildPointerCast` to produce an addrspacecast
Expand Down Expand Up @@ -1423,7 +1462,15 @@ fn codegen_offload<'ll, 'tcx>(
};

let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
let dyn_cache = match args[3].val {
OperandValue::Immediate(val) => val,
_ => panic!("unparsable"),
};
//let dyn_cache = args[3]; //bx.const_i32(512);
dbg!(&dyn_cache);
let args = get_args_from_tuple(bx, args[4], fn_target);
//let dyn_cache = args[3];
//llvm::Dump(&dyn_cache);
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);

let sig = tcx.fn_sig(fn_target.def_id()).skip_binder();
Expand All @@ -1444,7 +1491,16 @@ fn codegen_offload<'ll, 'tcx>(
};
register_offload(cx);
let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
gen_call_handling(
bx,
&offload_data,
&args,
&types,
&metadata,
offload_globals,
&offload_dims,
&dyn_cache,
);
}

fn get_args_from_tuple<'ll, 'tcx>(
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,13 @@ unsafe extern "C" {
NameLen: size_t,
T: &'a Type,
) -> &'a Value;
/// Looks up the global variable called `Name` (`NameLen` bytes long) in module `M`
/// and returns it; if no such global exists, a new external global of type `T` is
/// created in address space `AddressSpace`.
///
/// NOTE(review): description follows the visible part of the C++ wrapper in
/// `RustWrapper.cpp`; confirm against the full implementation.
pub(crate) fn LLVMRustGetOrInsertGlobalInAddrspace<'a>(
M: &'a Module,
Name: *const c_char,
NameLen: size_t,
T: &'a Type,
AddressSpace: c_uint,
) -> &'a Value;
pub(crate) fn LLVMRustGetNamedValue(
M: &Module,
Name: *const c_char,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
sym::abort
| sym::unreachable
| sym::cold_path
| sym::gpu_launch_sized_workgroup_mem
| sym::breakpoint
| sym::amdgpu_dispatch_ptr
| sym::assert_zero_valid
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::forget
| sym::frem_algebraic
| sym::fsub_algebraic
| sym::gpu_launch_sized_workgroup_mem
| sym::is_val_statically_known
| sym::log2f16
| sym::log2f32
Expand Down Expand Up @@ -301,6 +302,7 @@ pub(crate) fn check_intrinsic_type(
sym::field_offset => (1, 0, vec![], tcx.types.usize),
sym::rustc_peek => (1, 0, vec![param(0)], param(0)),
sym::caller_location => (0, 0, vec![], tcx.caller_location_ty()),
sym::gpu_launch_sized_workgroup_mem => (1, 0, vec![], Ty::new_mut_ptr(tcx, param(0))),
sym::assert_inhabited | sym::assert_zero_valid | sym::assert_mem_uninitialized_valid => {
(1, 0, vec![], tcx.types.unit)
}
Expand Down Expand Up @@ -357,6 +359,7 @@ pub(crate) fn check_intrinsic_type(
param(0),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
tcx.types.u32,
param(1),
],
param(2),
Expand Down
26 changes: 21 additions & 5 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,12 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
.getCallee());
}

extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
const char *Name,
size_t NameLen,
LLVMTypeRef Ty) {
// Get the global variable with the given name if it exists or create a new
// external global.
extern "C" LLVMValueRef
LLVMRustGetOrInsertGlobalInAddrspace(LLVMModuleRef M, const char *Name,
size_t NameLen, LLVMTypeRef Ty,
unsigned int AddressSpace) {
Module *Mod = unwrap(M);
auto NameRef = StringRef(Name, NameLen);

Expand All @@ -308,10 +310,24 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true);
if (!GV)
GV = new GlobalVariable(*Mod, unwrap(Ty), false,
GlobalValue::ExternalLinkage, nullptr, NameRef);
GlobalValue::ExternalLinkage, nullptr, NameRef,
nullptr, GlobalValue::NotThreadLocal, AddressSpace);
return wrap(GV);
}

// Look up the global variable called `Name`, or create a new external global
// when there is none. The address space handed on is the module's default
// globals address space, as reported by its data layout.
extern "C" LLVMValueRef LLVMRustGetOrInsertGlobal(LLVMModuleRef M,
                                                  const char *Name,
                                                  size_t NameLen,
                                                  LLVMTypeRef Ty) {
  const auto &DL = unwrap(M)->getDataLayout();
  return LLVMRustGetOrInsertGlobalInAddrspace(
      M, Name, NameLen, Ty, DL.getDefaultGlobalsAddressSpace());
}

// Must match the layout of `rustc_codegen_llvm::llvm::ffi::AttributeKind`.
enum class LLVMRustAttributeKind {
AlwaysInline = 0,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ symbols! {
global_asm,
global_registration,
globs,
gpu_launch_sized_workgroup_mem,
gt,
guard_patterns,
half_open_range_patterns,
Expand Down
Loading