diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt
index 1b004ec100..1b02059399 100644
--- a/example/ck_tile/01_fmha/CMakeLists.txt
+++ b/example/ck_tile/01_fmha/CMakeLists.txt
@@ -107,7 +107,7 @@ if(FMHA_FWD_FAST_EXP2)
 else()
     list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
 endif()
-list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
+list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero -fbracket-depth=512)
 
 # conditionally enable call to the fwd_splitkv API in fmha_fwd example
 if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
@@ -149,8 +149,36 @@
     auto [kargs, grids]                    = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
     constexpr dim3 blocks                  = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
-        ck_tile::stream_config{{s.stream_id_}});
+
+    // shared memory size needed by the fmha_bwd_dq_dk_dv_kernel_{F_idx} kernel,
+    // held as size_t so the workspace size below is not computed in 32-bit arithmetic
+    const size_t shared_mem_bytes = fmha_bwd_dq_dk_dv_kernel_{F_idx}::GetSmemSize();
+
+    // properties of the current device (not unconditionally device 0)
+    int device_id = 0;
+    HIP_CHECK_ERROR(hipGetDevice(&device_id));
+    hipDeviceProp_t deviceProps;
+    HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, device_id));
+    const size_t shared_mem_bytes_limit = deviceProps.sharedMemPerBlock;
+
+    // use dynamic shared memory if it is within the device limit
+    // otherwise, use workspace memory which is global memory
+
+    if (shared_mem_bytes > shared_mem_bytes_limit) {{
+        // use workspace memory: one shared_mem_bytes slice per block of the grid
+        char *workspace = nullptr;
+        const size_t workspace_size = shared_mem_bytes * grids.x * grids.y * grids.z;
+        HIP_CHECK_ERROR(hipMallocAsync(&workspace, workspace_size, s.stream_id_));
+
+        ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs, workspace)(
+            ck_tile::stream_config{{s.stream_id_}});
+
+        HIP_CHECK_ERROR(hipFreeAsync(workspace, s.stream_id_));
+    }} else {{
+        // use dynamic shared memory
+        ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, shared_mem_bytes, kargs)(
+            ck_tile::stream_config{{s.stream_id_}});
+    }}
 }}
 
 template <>
@@ -358,6 +386,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
             '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"],
             '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
+                    "kr_ktr_vr_iglp", "kr_ktr_vr"],
+            '512' : [FmhaBwdDQDKDVTileSize( 16, 64, 512, 16, 512, 16, 32, 512, 512, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
                     "kr_ktr_vr_iglp", "kr_ktr_vr"]
         }
     else:
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -1085,11 +1085,31 @@ struct FmhaBwdDQDKDVKernel
                            VGradEpiloguePipeline::GetSmemSize());
     }
 
+    // Entry point that takes its scratch memory from `workspace` instead of shared memory.
+    // `workspace` must hold one GetSmemSize()-byte slice for every block of the grid.
+    CK_TILE_DEVICE void operator()(Kargs kargs, char* workspace) const
+    {
+        // use workspace as shared memory
+        // Offsets are formed in size_t: block_id * smem_size_per_block can exceed 32 bits.
+        const size_t smem_size_per_block = GetSmemSize();
+        const size_t block_id = static_cast<size_t>(gridDim.x) * gridDim.y * blockIdx.z +
+                                static_cast<size_t>(gridDim.x) * blockIdx.y + blockIdx.x;
+        return operator_core(kargs, workspace + block_id * smem_size_per_block);
+    }
+
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
+        // use dynamic shared memory
+        // NOTE(review): the array is no longer sized here, so every launch of this overload
+        // must pass GetSmemSize() as the dynamic shared memory size - confirm all launch sites.
+        extern __shared__ char smem_ptr[];
+        return operator_core(kargs, smem_ptr);
+    }
+
+    // Kernel body shared by both operator() overloads; smem_ptr is this block's scratch buffer.
+    CK_TILE_DEVICE void operator_core(Kargs kargs, char* smem_ptr) const
+    {
 
         // divide problem
         const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
 
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
index d353203e0e..ef1e015f7a 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
@@ -587,7 +587,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kKPerBlock = Problem::kQKHeaddim;
 
-        constexpr index_t K1 = 16 / sizeof(AccDataType);
+        constexpr index_t K1 = 32 / sizeof(AccDataType);
         constexpr index_t K0 = kKPerBlock / K1;
         constexpr index_t M2 = get_warp_size() / K0;