
Commit aaea9f8

[Tutorial] Fix attention tutorial and enable pytests for DHEAD=128 (#7037)
1 parent bad2950 commit aaea9f8

File tree

1 file changed: 16 additions, 17 deletions

python/tutorials/06-fused-attention.py

Lines changed: 16 additions & 17 deletions
@@ -36,6 +36,10 @@ def supports_host_descriptor():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


+def is_blackwell():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,  #
                     desc_k, desc_v,  #
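
Note on the new helper: torch.cuda.get_device_capability() returns a (major, minor) compute-capability tuple, and Blackwell-class GPUs report major version 10 (Hopper reports 9), which is what gates the warp-specialization test axis further down. A minimal standalone sketch of the same check, assuming a CUDA build of PyTorch is available (the is_cuda() below is an illustrative stand-in, not the tutorial's helper):

# Illustrative sketch, not part of the diff.
import torch

def is_cuda() -> bool:
    # Stand-in check: a CUDA build of PyTorch with at least one visible GPU.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_blackwell() -> bool:
    # Blackwell-class GPUs report compute capability major version 10.
    return is_cuda() and torch.cuda.get_device_capability()[0] == 10

print("Blackwell:", is_blackwell())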
@@ -115,7 +119,7 @@ def _host_descriptor_pre_hook(nargs):
 if "PYTEST_VERSION" in os.environ:
     # Use a single config in testing for reproducibility
     configs = [
-        triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=4, num_warps=4, pre_hook=_host_descriptor_pre_hook),
+        triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
     ]

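
Note on the config change: num_stages sets the depth of Triton's software pipelining, and deeper pipelines buffer more tiles in shared memory, so dropping the pinned test config from 4 to 2 stages reduces shared-memory pressure, which is presumably relevant now that HEAD_DIM=128 is exercised in testing. A hedged sketch of how such a config list is typically consumed via triton.autotune (the toy kernel and its names are illustrative, not from the tutorial):

# Illustrative sketch, not part of the diff.
import triton
import triton.language as tl

configs = [
    # One pinned config, mirroring the tutorial's "single config for reproducibility" idea.
    triton.Config(dict(BLOCK=64), num_stages=2, num_warps=4),
]

@triton.autotune(configs=configs, key=["n"])
@triton.jit
def scale_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Doubles n elements of x into y; BLOCK is supplied by the selected config.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)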
@@ -483,10 +487,10 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
             y_dim = q.shape[0] * q.shape[1] * q.shape[2]

             dummy_block = [1, 1]
-            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
-            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
-            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
-            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
+            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
         else:
             desc_q = q
             desc_v = v
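
Note on the stride fix: each descriptor views its (Z, H, N_CTX, HEAD_DIM) tensor as a 2-D [y_dim, HEAD_DIM_K] buffer, and for a contiguous tensor the row stride of that view is exactly the head dimension, i.e. HEAD_DIM_K elements. A small sketch of that relationship, assuming contiguous layout (variable names here are illustrative):

# Illustrative sketch, not part of the diff.
import torch

Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64
q = torch.empty(Z, H, N_CTX, HEAD_DIM)

y_dim = Z * H * N_CTX
q2d = q.reshape(y_dim, HEAD_DIM)  # the 2-D view the descriptors describe

# Moving one row in the 2-D view advances HEAD_DIM elements; columns are unit-stride.
assert q2d.stride() == (HEAD_DIM, 1)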
@@ -509,7 +513,7 @@ def grid(META):
             q.shape[0], q.shape[1],  #
             desc_q, desc_k, desc_v, desc_o,  #
             N_CTX=q.shape[2],  #
-            HEAD_DIM=HEAD_DIM,  #
+            HEAD_DIM=HEAD_DIM_K,  #
             FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
             STAGE=stage,  #
             warp_specialize=warp_specialize,  #
@@ -567,17 +571,12 @@ def backward(ctx, do):
 attention = _attention.apply


-@pytest.mark.parametrize('Z, H, N_CTX, HEAD_DIM', [
-    (1, 2, 1024, 64),
-    (4, 48, 128, 64),
-    (4, 48, 256, 64),
-    (4, 48, 512, 64),
-    (4, 48, 1024, 64),
-    (4, 48, 2048, 64),
-    (4, 48, 4096, 64),
-])
-@pytest.mark.parametrize("causal", [True])
-@pytest.mark.parametrize("warp_specialize", [False, True])
+@pytest.mark.parametrize("Z", [1, 4])
+@pytest.mark.parametrize("H", [2, 48])
+@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
+@pytest.mark.parametrize("HEAD_DIM", [64, 128])
+@pytest.mark.parametrize("causal", [True])  # FIXME: Non-causal tests do not pass at the moment.
+@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
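
Note on the test rewrite: splitting the parameters into independent @pytest.mark.parametrize axes makes pytest generate the full Cartesian product of Z × H × N_CTX × HEAD_DIM × causal × warp_specialize, which is how HEAD_DIM=128 gets covered without enumerating every tuple by hand. A hedged sketch of the expansion pattern on a toy test (names are illustrative, not from the tutorial):

# Illustrative sketch, not part of the diff.
import pytest

@pytest.mark.parametrize("HEAD_DIM", [64, 128])
@pytest.mark.parametrize("N_CTX", [128, 1024])
def test_shapes(N_CTX, HEAD_DIM):
    # Stacked parametrize decorators expand to 2 x 2 = 4 cases:
    # (128, 64), (128, 128), (1024, 64), (1024, 128).
    assert N_CTX % HEAD_DIM == 0

Running something like python -m pytest python/tutorials/06-fused-attention.py -k test_op should then collect the expanded matrix, with warp_specialize=True only included on Blackwell hardware.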
