Commit c324519

Expand test_op with triton-fp8 provider
Signed-off-by: Whitney Tsang <[email protected]>
1 parent: c10df5d


python/tutorials/06-fused-attention.py

1 file changed: 24 additions & 11 deletions
@@ -583,20 +583,29 @@ def backward(ctx, do):
 
 attention = _attention.apply
 
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+
 
 @pytest.mark.parametrize("Z", [1, 4])
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment.
 @pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
-def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
+@pytest.mark.parametrize("mode", ["fwd", "bwd"])
+@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    if mode == "fwd" and "fp8" in provider:
+        q = q.to(torch.float8_e5m2)
+        k = k.to(torch.float8_e5m2)
+        v = v.permute(0, 1, 3, 2).contiguous()
+        v = v.permute(0, 1, 3, 2)
+        v = v.to(torch.float8_e5m2)
     sm_scale = 0.5
-    dout = torch.randn_like(q)
     # reference implementation
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
@@ -605,18 +614,23 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    if mode == "bwd":
+        dout = torch.randn_like(q)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
     # triton implementation
     tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
-    tri_out.backward(dout)
-    tri_dv, v.grad = v.grad.clone(), None
-    tri_dk, k.grad = k.grad.clone(), None
-    tri_dq, q.grad = q.grad.clone(), None
+    if mode == "bwd":
+        tri_out.backward(dout)
+        tri_dv, v.grad = v.grad.clone(), None
+        tri_dk, k.grad = k.grad.clone(), None
+        tri_dq, q.grad = q.grad.clone(), None
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    if mode == "fwd":
+        return
     rtol = 0.0
     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
@@ -634,7 +648,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
 except BaseException:
     HAS_FLASH = False
 
-TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
 BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
 # vary seq length for fixed head and batch=4
 configs = []
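
The interesting part of this change is the fp8 input preparation: q and k are cast to torch.float8_e5m2 directly, while v is first transposed in memory (permute, contiguous, permute back) so that it keeps its logical (Z, H, N_CTX, HEAD_DIM) shape but carries a transposed memory layout into the cast. Below is a minimal standalone sketch of that preparation; the toy shapes, float32 inputs, and CPU execution are placeholder assumptions for illustration, not part of the diff, and the block only runs when the PyTorch build exposes float8_e5m2 (the same check the TORCH_HAS_FP8 guard performs).

    import torch

    # Hedged sketch of the fp8 input preparation added above.
    # Assumptions (not from the diff): toy shapes, float32 inputs, CPU device.
    if hasattr(torch, 'float8_e5m2'):  # same availability check as TORCH_HAS_FP8
        Z, H, N_CTX, HEAD_DIM = 1, 2, 128, 64
        q = torch.empty((Z, H, N_CTX, HEAD_DIM)).normal_(mean=0.0, std=0.5)
        k = torch.empty_like(q).normal_(mean=0.0, std=0.5)
        v = torch.empty_like(q).normal_(mean=0.0, std=0.5)

        # q and k are cast to fp8 directly.
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)

        # v: swap the last two dims, materialize that layout, then swap back.
        # The logical shape is unchanged, but the strides now describe a
        # (Z, H, HEAD_DIM, N_CTX) memory layout.
        v = v.permute(0, 1, 3, 2).contiguous()
        v = v.permute(0, 1, 3, 2)
        print(v.shape)     # torch.Size([1, 2, 128, 64])
        print(v.stride())  # (16384, 8192, 1, 128)
        v = v.to(torch.float8_e5m2)

Note, from the diff, that this cast is applied only when mode == "fwd"; in "bwd" mode the triton-fp8 provider runs on the original fp16 tensors, and forward-only cases return right after the torch.testing.assert_close on the outputs, so the gradient comparison is skipped.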
