Skip to content

Commit 9baa06d

Browse files
Add Blackwell MLA forward (shape: d=192, dv=128) implementation in example_77 (#2472)
1 parent ebe98c5 commit 9baa06d

13 files changed

+3323
-40
lines changed

examples/77_blackwell_fmha/77_blackwell_fmha.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -853,7 +853,7 @@ struct FwdRunner {
 flops *= static_cast<double>(size<1>(problem_shape));
 flops *= static_cast<double>(size<3,1>(problem_shape));
 }
-flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask> ? 0.5 : 1.0);
+flops *= 4.0 * (std::is_same_v<ActiveMask, CausalMask<true>> || std::is_same_v<ActiveMask, CausalMask<false>> ? 0.5 : 1.0);
 flops *= static_cast<double>(size<2>(problem_shape));
 flops *= static_cast<double>(size<3,0>(problem_shape));
 double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/);

0 commit comments

Comments
 (0)