@@ -583,20 +583,29 @@ def backward(ctx, do):
 
 attention = _attention.apply
 
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+
 
 @pytest.mark.parametrize("Z", [1, 4])
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [True])  # FIXME: Non-causal tests do not pass at the moment.
 @pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
-def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
+@pytest.mark.parametrize("mode", ["fwd", "bwd"])
+@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    if mode == "fwd" and "fp8" in provider:
+        q = q.to(torch.float8_e5m2)
+        k = k.to(torch.float8_e5m2)
+        v = v.permute(0, 1, 3, 2).contiguous()
+        v = v.permute(0, 1, 3, 2)
+        v = v.to(torch.float8_e5m2)
     sm_scale = 0.5
-    dout = torch.randn_like(q)
     # reference implementation
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
@@ -605,18 +614,23 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    if mode == "bwd":
+        dout = torch.randn_like(q)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
     # triton implementation
     tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
-    tri_out.backward(dout)
-    tri_dv, v.grad = v.grad.clone(), None
-    tri_dk, k.grad = k.grad.clone(), None
-    tri_dq, q.grad = q.grad.clone(), None
+    if mode == "bwd":
+        tri_out.backward(dout)
+        tri_dv, v.grad = v.grad.clone(), None
+        tri_dk, k.grad = k.grad.clone(), None
+        tri_dq, q.grad = q.grad.clone(), None
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    if mode == "fwd":
+        return
     rtol = 0.0
     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
@@ -634,7 +648,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
 except BaseException:
     HAS_FLASH = False
 
-TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
 BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
 # vary seq length for fixed head and batch=4
 configs = []
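Note (not part of the commit): in the fp8 branch added in the first hunk, q and k are cast to torch.float8_e5m2 directly, while v goes through a permute/contiguous/permute round trip before the cast. This keeps v's logical (Z, H, N_CTX, HEAD_DIM) shape but physically transposes the last two dimensions in memory. A minimal standalone sketch of that layout effect, with small illustrative shapes that are not taken from the test:

import torch

# Illustrative only: same double-permute trick as the diff's fp8 branch,
# on a small tensor so the stride change is visible.
Z, H, N_CTX, HEAD_DIM = 1, 2, 8, 4
v = torch.randn(Z, H, N_CTX, HEAD_DIM)
print(v.shape, v.stride())      # torch.Size([1, 2, 8, 4]) (64, 32, 4, 1)

v_t = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
print(v_t.shape, v_t.stride())  # torch.Size([1, 2, 8, 4]) (64, 32, 1, 8)
# Same logical shape, but N_CTX is now the innermost dimension in memory,
# i.e. the storage holds v transposed; only then is it cast to fp8
# (guarded, since older PyTorch builds lack the fp8 dtypes):
v_fp8 = v_t.to(torch.float8_e5m2) if hasattr(torch, 'float8_e5m2') else v_t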