Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,17 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda(
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "linearize_cache_indices_kernel_2", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const char* func_name = "linearize_cache_indices_kernel";
#endif
linearize_cache_indices_kernel<<<
FBGEMM_LAUNCH_KERNEL(
(linearize_cache_indices_kernel<index_t, offset_t>),
div_round_up(num_indices, kMaxThreads),
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name, cache_hash_size_cumsum, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, table_offsets, offset_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, linear_cache_indices, int64_t, 1, 32),
at::cuda::getCurrentCUDAStream(),
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
PTA_B(table_offsets, offset_t, 1, 32),
PTA_B(linear_cache_indices, int64_t, 1, 32),
indices_base_offset);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return linear_cache_indices;
Expand Down Expand Up @@ -181,21 +176,16 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
update_row_indices.scalar_type(),
"linearize_cache_indices_from_row_idx_kernel",
[&] {
#ifdef FBGEMM_GPU_MEMCHECK
const char* func_name = "linearize_cache_indices_from_row_idx_kernel";
#endif
linearize_cache_indices_from_row_idx_kernel<<<
FBGEMM_LAUNCH_KERNEL(
(linearize_cache_indices_from_row_idx_kernel<index_t>),
div_round_up(num_indices, kMaxThreads),
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name, cache_hash_size_cumsum, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, update_table_indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, update_row_indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, linear_cache_indices, index_t, 1, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();
at::cuda::getCurrentCUDAStream(),
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
PTA_B(update_table_indices, index_t, 1, 32),
PTA_B(update_row_indices, index_t, 1, 32),
PTA_B(linear_cache_indices, index_t, 1, 32));
});
return linear_cache_indices;
}
Expand Down
Loading