Skip to content

Commit d538558

Browse files
Add dynamic shared memory allocation
1 parent 17b1fb7 commit d538558

File tree

8 files changed

+35
-11
lines changed

8 files changed

+35
-11
lines changed

platforms/artic/intrinsics_thorin.impala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
#[import(cc = "thorin")] fn cmpxchg_weak[T](_addr: &mut T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool); // only for integer data types
1313
#[import(cc = "thorin")] fn fence(_order: u32, _scope: &[u8]) -> ();
1414
#[import(cc = "thorin")] fn pe_info[T](_src: &[u8], _val: T) -> ();
15-
#[import(cc = "thorin")] fn cuda(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
16-
#[import(cc = "thorin")] fn nvvm(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
17-
#[import(cc = "thorin")] fn opencl(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
18-
#[import(cc = "thorin")] fn amdgpu(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
15+
#[import(cc = "thorin", name = "cuda")] fn cuda_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
16+
#[import(cc = "thorin", name = "nvvm")] fn nvvm_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
17+
#[import(cc = "thorin", name = "opencl")] fn opencl_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
18+
#[import(cc = "thorin", name = "amdgpu")] fn amdgpu_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
19+
#[import(cc = "thorin")] fn local_memory() -> &mut addrspace(3)[u8];
1920
#[import(cc = "thorin")] fn reserve_shared[T](_size: i32) -> &mut addrspace(3)[T];
2021
#[import(cc = "thorin")] fn hls(_dev: i32, _body: fn() -> ()) -> ();
2122
#[import(cc = "thorin", name = "pipeline")] fn thorin_pipeline(_initiation_interval: i32, _lower: i32, _upper: i32, _body: fn(i32) -> ()) -> (); // only for HLS/OpenCL backend
@@ -35,6 +36,11 @@
3536
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p1[T](_addr: &mut addrspace(1)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3637
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p3[T](_addr: &mut addrspace(3)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3738

39+
fn @cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = cuda_with_lmem(dev, grid, block, 0, body);
40+
fn @nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = nvvm_with_lmem(dev, grid, block, 0, body);
41+
fn @opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = opencl_with_lmem(dev, grid, block, 0, body);
42+
fn @amdgpu(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = amdgpu_with_lmem(dev, grid, block, 0, body);
43+
3844
fn @pipeline(body: fn(i32) -> ()) = @|initiation_interval: i32, lower: i32, upper: i32| thorin_pipeline(initiation_interval, lower, upper, body);
3945
fn @parallel(body: fn(i32) -> ()) = @|num_threads: i32, lower: i32, upper: i32| thorin_parallel(num_threads, lower, upper, body);
4046
fn @spawn(body: fn() -> ()) = @|| thorin_spawn(body);

platforms/impala/intrinsics_thorin.impala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ extern "thorin" {
1010
fn insert[T, U](T, i32, U) -> T;
1111
//fn shuffle[T](T, T, T) -> T;
1212

13-
fn cuda(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
14-
fn nvvm(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
15-
fn opencl(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
16-
fn amdgpu(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
13+
fn "cuda" cuda_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
14+
fn "nvvm" nvvm_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
15+
fn "opencl" opencl_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
16+
fn "amdgpu" amdgpu_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
17+
fn local_memory() -> &mut[3][u8];
1718
fn reserve_shared[T](i32) -> &mut[3][T];
1819

1920
fn hls(dev: i32, body: fn() -> ()) -> ();
@@ -42,3 +43,8 @@ extern "thorin" {
4243

4344
fn vectorize(vector_length: i32, body: fn(i32) -> ()) -> ();
4445
}
46+
47+
fn @@cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { cuda_with_lmem(dev, grid, block, 0, body) }
48+
fn @@nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { nvvm_with_lmem(dev, grid, block, 0, body) }
49+
fn @@opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { opencl_with_lmem(dev, grid, block, 0, body) }
50+
fn @@amdgpu(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { amdgpu_with_lmem(dev, grid, block, 0, body) }

src/anydsl_runtime.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ void anydsl_copy(
117117
void anydsl_launch_kernel(
118118
int32_t mask, const char* file_name, const char* kernel_name,
119119
const uint32_t* grid, const uint32_t* block,
120+
uint32_t smem,
120121
void** arg_data,
121122
const uint32_t* arg_sizes,
122123
const uint32_t* arg_aligns,
@@ -128,6 +129,7 @@ void anydsl_launch_kernel(
128129
kernel_name,
129130
grid,
130131
block,
132+
smem,
131133
{
132134
arg_data,
133135
arg_sizes,

src/anydsl_runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ AnyDSL_runtime_API void anydsl_copy(int32_t, const void*, int64_t, int32_t, void
3636
AnyDSL_runtime_API void anydsl_launch_kernel(
3737
int32_t, const char*, const char*,
3838
const uint32_t*, const uint32_t*,
39+
uint32_t,
3940
void**, const uint32_t*, const uint32_t*, const uint32_t*, const uint8_t*,
4041
uint32_t);
4142
AnyDSL_runtime_API void anydsl_synchronize(int32_t);

src/cuda_platform.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ void CudaPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params
225225
launch_params.grid[1] / launch_params.block[1],
226226
launch_params.grid[2] / launch_params.block[2],
227227
launch_params.block[0], launch_params.block[1], launch_params.block[2],
228-
0, nullptr, launch_params.args.data, nullptr);
228+
launch_params.lmem,
229+
nullptr,
230+
launch_params.args.data, nullptr);
229231
CHECK_CUDA(err, "cuLaunchKernel()");
230232

231233
if (runtime_->profiling_enabled()) {

src/hsa_platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ void HSAPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params)
409409
aql.kernel_object = kernel_info.kernel;
410410
aql.kernarg_address = kernel_info.kernarg_segment;
411411
aql.private_segment_size = kernel_info.private_segment_size;
412-
aql.group_segment_size = kernel_info.group_segment_size;
412+
aql.group_segment_size = (kernel_info.group_segment_size + 15) / 16 * kernel_info.group_segment_size + launch_params.lmem;
413413

414414
// write to command queue
415415
const uint64_t index = hsa_queue_load_write_index_relaxed(queue);

src/opencl_platform.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ void OpenCLPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_para
375375
cl_mem struct_buf = clCreateBuffer(devices_[dev].ctx, flags, launch_params.args.sizes[i], launch_params.args.data[i], &err);
376376
CHECK_OPENCL(err, "clCreateBuffer()");
377377
kernel_structs[i] = struct_buf;
378-
clSetKernelArg(kernel, i, sizeof(cl_mem), &kernel_structs[i]);
378+
err = clSetKernelArg(kernel, i, sizeof(cl_mem), &kernel_structs[i]);
379+
CHECK_OPENCL(err, "clSetKernelArg()");
379380
} else {
380381
#ifdef CL_VERSION_2_0
381382
if (launch_params.args.types[i] == KernelArgType::Ptr && devices_[dev].version_major == 2) {
@@ -391,6 +392,11 @@ void OpenCLPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_para
391392
}
392393
}
393394

395+
if (launch_params.lmem != 0) {
396+
cl_int err = clSetKernelArg(kernel, launch_params.num_args, launch_params.lmem, nullptr);
397+
CHECK_OPENCL(err, "clSetKernelArg()");
398+
}
399+
394400
size_t global_work_size[] = {launch_params.grid [0], launch_params.grid [1], launch_params.grid [2]};
395401
size_t local_work_size[] = {launch_params.block[0], launch_params.block[1], launch_params.block[2]};
396402

src/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct LaunchParams {
2727
const char* kernel_name;
2828
const uint32_t* grid;
2929
const uint32_t* block;
30+
uint32_t lmem;
3031
struct {
3132
void** data;
3233
const uint32_t* sizes;

0 commit comments

Comments
 (0)