Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ __global__ static void gather_or_scatter_along_first_dim_kernel(
SmemT& smem = *reinterpret_cast<SmemT*>(smem_raw);

int indexing_dim = IsGather ? N : M;
const int n_or_m_offset = blockIdx.x * kBlkNOrM;
const auto n_or_m_offset = blockIdx.x * kBlkNOrM;
if (n_or_m_offset >= indexing_dim) {
return;
}
Expand Down Expand Up @@ -344,17 +344,17 @@ void scatter_add_along_first_dim(
at::Tensor dst,
at::Tensor src,
at::Tensor index) {
const int M = src.size(0);
const int K = src.size(1);
const int N = index.size(0);
if (N == 0 || M == 0) {
assert(M == 0);
return;
}
if (dst.is_contiguous() && dst.dim() == 2 && src.is_contiguous() &&
src.dim() == 2 && index.is_contiguous() && index.dim() == 1) {
using T = cutlass::bfloat16_t;

const int M = src.size(0);
const int K = src.size(1);
const int N = index.size(0);
if (N == 0) {
assert(M == 0);
return;
}
assert(dst.size(1) == K);
// TODO(shikaili): Make it supports more configurations.
if (dst.dtype() == at::kBFloat16 && src.dtype() == at::kBFloat16 &&
Expand All @@ -377,7 +377,6 @@ void scatter_add_along_first_dim(
}
}

const int K = src.size(1);
dst.scatter_add_(0, index.to(at::kLong).unsqueeze(1).expand({-1, K}), src);
}

Expand Down
Loading