Skip to content

Commit 7ad75e2

Browse files
Add dynamic shared memory allocation
1 parent 7042fbf commit 7ad75e2

File tree

9 files changed

+40
-12
lines changed

9 files changed

+40
-12
lines changed

platforms/artic/intrinsics_thorin.impala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
#[import(cc = "thorin")] fn cmpxchg_weak[T](_addr: &mut T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool); // only for integer data types
1313
#[import(cc = "thorin")] fn fence(_order: u32, _scope: &[u8]) -> ();
1414
#[import(cc = "thorin")] fn pe_info[T](_src: &[u8], _val: T) -> ();
15-
#[import(cc = "thorin")] fn cuda(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
16-
#[import(cc = "thorin")] fn nvvm(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
17-
#[import(cc = "thorin")] fn opencl(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
18-
#[import(cc = "thorin")] fn amdgpu_hsa(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
19-
#[import(cc = "thorin")] fn amdgpu_pal(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
15+
// GPU launch intrinsics that take one extra i32 (`_lmem`) between the block size and the kernel body.
// Each is imported under the thorin name given in `name = "..."`.
// NOTE(review): `_lmem` is presumably the dynamic local/shared memory size that ends up in the runtime's `lmem` launch parameter — confirm against the thorin backend.
#[import(cc = "thorin", name = "cuda")] fn cuda_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
16+
#[import(cc = "thorin", name = "nvvm")] fn nvvm_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
17+
#[import(cc = "thorin", name = "opencl")] fn opencl_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
18+
#[import(cc = "thorin", name = "amdgpu_hsa")] fn amdgpu_hsa_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
19+
#[import(cc = "thorin", name = "amdgpu_pal")] fn amdgpu_pal_with_lmem(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _lmem: i32, _body: fn() -> ()) -> ();
20+
// Takes no arguments and yields a mutable pointer to a u8 array in address space 3.
// NOTE(review): presumably the region sized by the extra i32 of the *_with_lmem launchers — confirm against the thorin backend.
#[import(cc = "thorin")] fn local_memory() -> &mut addrspace(3)[u8];
2021
#[import(cc = "thorin")] fn reserve_shared[T](_size: i32) -> &mut addrspace(3)[T];
2122
#[import(cc = "thorin")] fn hls(_dev: i32, _body: fn() -> ()) -> ();
2223
#[import(cc = "thorin", name = "pipeline")] fn thorin_pipeline(_initiation_interval: i32, _lower: i32, _upper: i32, _body: fn(i32) -> ()) -> (); // only for HLS/OpenCL backend
@@ -36,6 +37,12 @@
3637
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p1[T](_addr: &mut addrspace(1)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3738
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p3[T](_addr: &mut addrspace(3)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3839

40+
// Compatibility wrappers that keep the four-argument launch signature (dev, grid, block, body).
// Each forwards to the matching *_with_lmem import and passes 0 for the extra i32 argument.
fn @cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = cuda_with_lmem(dev, grid, block, 0, body);
41+
fn @nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = nvvm_with_lmem(dev, grid, block, 0, body);
42+
fn @opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = opencl_with_lmem(dev, grid, block, 0, body);
43+
fn @amdgpu_hsa(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = amdgpu_hsa_with_lmem(dev, grid, block, 0, body);
44+
fn @amdgpu_pal(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = amdgpu_pal_with_lmem(dev, grid, block, 0, body);
45+
3946
fn @pipeline(body: fn(i32) -> ()) = @|initiation_interval: i32, lower: i32, upper: i32| thorin_pipeline(initiation_interval, lower, upper, body);
4047
fn @parallel(body: fn(i32) -> ()) = @|num_threads: i32, lower: i32, upper: i32| thorin_parallel(num_threads, lower, upper, body);
4148
fn @spawn(body: fn() -> ()) = @|| thorin_spawn(body);

platforms/impala/intrinsics_thorin.impala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ extern "thorin" {
1010
fn insert[T, U](T, i32, U) -> T;
1111
//fn shuffle[T](T, T, T) -> T;
1212

13-
fn cuda(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
14-
fn nvvm(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
15-
fn opencl(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
16-
fn amdgpu(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
13+
// GPU launch declarations: (device, grid, block, extra i32, body).
// NOTE(review): the leading string literal presumably selects the thorin intrinsic name, and the extra i32
// is presumably the dynamic local/shared memory size — confirm against the thorin backend.
fn "cuda" cuda_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
14+
fn "nvvm" nvvm_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
15+
fn "opencl" opencl_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
16+
fn "amdgpu_hsa" amdgpu_hsa_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
17+
fn "amdgpu_pal" amdgpu_pal_with_lmem(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
18+
// Takes no arguments and yields a mutable pointer to a u8 array in address space 3.
fn local_memory() -> &mut[3][u8];
1719
fn reserve_shared[T](i32) -> &mut[3][T];
1820

1921
fn hls(dev: i32, body: fn() -> ()) -> ();
@@ -42,3 +44,9 @@ extern "thorin" {
4244

4345
fn vectorize(vector_length: i32, body: fn(i32) -> ()) -> ();
4446
}
47+
48+
// Compatibility wrappers that keep the four-argument launch signature (dev, grid, block, body).
// Each forwards to the matching *_with_lmem declaration and passes 0 for the extra i32 argument.
fn @@cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { cuda_with_lmem(dev, grid, block, 0, body) }
49+
fn @@nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { nvvm_with_lmem(dev, grid, block, 0, body) }
50+
fn @@opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { opencl_with_lmem(dev, grid, block, 0, body) }
51+
fn @@amdgpu_hsa(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { amdgpu_hsa_with_lmem(dev, grid, block, 0, body) }
52+
fn @@amdgpu_pal(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { amdgpu_pal_with_lmem(dev, grid, block, 0, body) }

src/anydsl_runtime.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ void anydsl_copy(
118118
void anydsl_launch_kernel(
119119
int32_t mask, const char* file_name, const char* kernel_name,
120120
const uint32_t* grid, const uint32_t* block,
121+
uint32_t lmem,
121122
void** arg_data,
122123
const uint32_t* arg_sizes,
123124
const uint32_t* arg_aligns,
@@ -129,6 +130,7 @@ void anydsl_launch_kernel(
129130
kernel_name,
130131
grid,
131132
block,
133+
lmem,
132134
{
133135
arg_data,
134136
arg_sizes,

src/anydsl_runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ AnyDSL_runtime_API void anydsl_copy(int32_t, const void*, int64_t, int32_t, void
3737
AnyDSL_runtime_API void anydsl_launch_kernel(
3838
int32_t, const char*, const char*,
3939
const uint32_t*, const uint32_t*,
40+
uint32_t,
4041
void**, const uint32_t*, const uint32_t*, const uint32_t*, const uint8_t*,
4142
uint32_t);
4243
AnyDSL_runtime_API void anydsl_synchronize(int32_t);

src/cuda_platform.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ void CudaPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params
217217
launch_params.grid[1] / launch_params.block[1],
218218
launch_params.grid[2] / launch_params.block[2],
219219
launch_params.block[0], launch_params.block[1], launch_params.block[2],
220-
0, nullptr, launch_params.args.data, nullptr);
220+
launch_params.lmem,
221+
nullptr,
222+
launch_params.args.data, nullptr);
221223
CHECK_CUDA(err, "cuLaunchKernel()");
222224

223225
if (runtime_->profiling_enabled()) {

src/hsa_platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ void HSAPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params)
409409
aql.kernel_object = kernel_info.kernel;
410410
aql.kernarg_address = kernel_info.kernarg_segment;
411411
aql.private_segment_size = kernel_info.private_segment_size;
412-
aql.group_segment_size = kernel_info.group_segment_size;
412+
// Round the kernel's static group segment size up to the next multiple of 16, then append the
// dynamically requested local memory so it starts after the static allocation.
// (The multiplier must be the constant 16; multiplying by the size itself does not round.)
aql.group_segment_size = (kernel_info.group_segment_size + 15) / 16 * 16 + launch_params.lmem;
413413

414414
// write to command queue
415415
const uint64_t index = hsa_queue_load_write_index_relaxed(queue);

src/opencl_platform.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,8 @@ void OpenCLPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_para
375375
cl_mem struct_buf = clCreateBuffer(devices_[dev].ctx, flags, launch_params.args.sizes[i], launch_params.args.data[i], &err);
376376
CHECK_OPENCL(err, "clCreateBuffer()");
377377
kernel_structs[i] = struct_buf;
378-
clSetKernelArg(kernel, i, sizeof(cl_mem), &kernel_structs[i]);
378+
err = clSetKernelArg(kernel, i, sizeof(cl_mem), &kernel_structs[i]);
379+
CHECK_OPENCL(err, "clSetKernelArg()");
379380
} else {
380381
#ifdef CL_VERSION_2_0
381382
if (launch_params.args.types[i] == KernelArgType::Ptr && devices_[dev].version_major == 2) {
@@ -391,6 +392,11 @@ void OpenCLPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_para
391392
}
392393
}
393394

395+
// Only when dynamic local memory was requested: set the kernel argument at index num_args
// (one past the regular arguments) with size `lmem` and a null value. Per the clSetKernelArg
// contract, a null arg_value with a non-zero arg_size declares a __local buffer of that size.
// NOTE(review): assumes the generated kernel has a trailing __local pointer parameter at that index — confirm against the code generator.
if (launch_params.lmem != 0) {
396+
cl_int err = clSetKernelArg(kernel, launch_params.num_args, launch_params.lmem, nullptr);
397+
CHECK_OPENCL(err, "clSetKernelArg()");
398+
}
399+
394400
size_t global_work_size[] = {launch_params.grid [0], launch_params.grid [1], launch_params.grid [2]};
395401
size_t local_work_size[] = {launch_params.block[0], launch_params.block[1], launch_params.block[2]};
396402

src/pal_platform.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ void PALPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params)
215215
Pal::PipelineBindParams params = {};
216216
params.pipelineBindPoint = Pal::PipelineBindPoint::Compute;
217217
params.pPipeline = pipeline;
218+
params.cs.ldsBytesPerTg = launch_params.lmem; // TODO: add static LDS size
218219

219220
constexpr Pal::HwPipePoint pipe_point = Pal::HwPipePostCs;
220221
Pal::BarrierInfo barrier_info = {};

src/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct LaunchParams {
3535
const char* kernel_name;
3636
const uint32_t* grid;
3737
const uint32_t* block;
38+
uint32_t lmem;
3839
ParamsArgs args;
3940
uint32_t num_args;
4041
};

0 commit comments

Comments
 (0)