Skip to content

[CK_TILE][FMHA][Feature] Add support for large hdim #2607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ if(FMHA_FWD_FAST_EXP2)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero -fbracket-depth=512)

# conditionally enable call to the fwd_splitkv API in fmha_fwd example
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
Expand Down
31 changes: 29 additions & 2 deletions example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,33 @@
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});

// shared memory size needed by the fmha_bwd_dq_dk_dv_kernel_{F_idx} kernel
auto shared_mem_bytes = fmha_bwd_dq_dk_dv_kernel_{F_idx}::GetSmemSize();

// device properties
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
auto shared_mem_bytes_limit = deviceProps.sharedMemPerBlock;

// use dynamic shared memory if the required size is within the device limit;
// otherwise, fall back to a workspace buffer allocated in global memory

if (static_cast<size_t>(shared_mem_bytes) > shared_mem_bytes_limit) {{
// use workspace memory
char *workspace = nullptr;
auto workspace_size = shared_mem_bytes * grids.x * grids.y * grids.z;
HIP_CHECK_ERROR(hipMallocAsync(&workspace, workspace_size, s.stream_id_));

ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs, workspace)(
ck_tile::stream_config{{s.stream_id_}});

HIP_CHECK_ERROR(hipFreeAsync(workspace, s.stream_id_));
}} else {{
// use dynamic shared memory
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, shared_mem_bytes, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
}}

template <>
Expand Down Expand Up @@ -358,6 +383,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'512' : [FmhaBwdDQDKDVTileSize( 16, 64, 512, 16, 512, 16, 32, 512, 512, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
Expand Down
17 changes: 15 additions & 2 deletions include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,11 +1085,24 @@ struct FmhaBwdDQDKDVKernel
VGradEpiloguePipeline::GetSmemSize());
}

// Entry point used when the per-block scratch exceeds the device's
// shared-memory limit. `workspace` is a global-memory buffer, sized by the
// host launcher to hold one GetSmemSize()-byte slice per launched block;
// each block indexes into its own slice and otherwise runs unchanged.
CK_TILE_DEVICE void operator()(Kargs kargs, char* workspace) const
{
    const auto smem_size_per_block = GetSmemSize();
    // Linearize the 3D block index (x fastest, then y, then z). Promote to
    // 64-bit before multiplying: blockIdx/gridDim are 32-bit, and with the
    // large per-block sizes this path exists for, a 32-bit product of
    // block_id * smem_size_per_block can overflow on big grids.
    const unsigned long long block_id =
        (static_cast<unsigned long long>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x +
        blockIdx.x;
    return operator_core(kargs,
                         workspace + block_id * static_cast<unsigned long long>(smem_size_per_block));
}

// Entry point for the common case: scratch fits in LDS, which the launcher
// provides as dynamically-sized shared memory (launch parameter = GetSmemSize()).
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
    extern __shared__ char lds_buffer[];
    return operator_core(kargs, lds_buffer);
}

CK_TILE_DEVICE void operator_core(Kargs kargs, char* smem_ptr) const
{
// divide problem
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;

constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K1 = 32 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;

constexpr index_t M2 = get_warp_size() / K0;
Expand Down