@@ -36,6 +36,10 @@ def supports_host_descriptor():
36
36
return is_cuda () and torch .cuda .get_device_capability ()[0 ] >= 9
37
37
38
38
39
def is_blackwell():
    """Return True when running on an NVIDIA Blackwell GPU.

    Blackwell is identified by CUDA compute capability major version 10
    (cf. supports_host_descriptor, which checks for Hopper-or-newer, >= 9).
    """
    if not is_cuda():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10
39
43
@triton .jit
40
44
def _attn_fwd_inner (acc , l_i , m_i , q , #
41
45
desc_k , desc_v , #
@@ -115,7 +119,7 @@ def _host_descriptor_pre_hook(nargs):
115
119
if "PYTEST_VERSION" in os .environ :
116
120
# Use a single config in testing for reproducibility
117
121
configs = [
118
- triton .Config (dict (BLOCK_M = 64 , BLOCK_N = 64 ), num_stages = 4 , num_warps = 4 , pre_hook = _host_descriptor_pre_hook ),
122
+ triton .Config (dict (BLOCK_M = 64 , BLOCK_N = 64 ), num_stages = 2 , num_warps = 4 , pre_hook = _host_descriptor_pre_hook ),
119
123
]
120
124
121
125
@@ -483,10 +487,10 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
483
487
y_dim = q .shape [0 ] * q .shape [1 ] * q .shape [2 ]
484
488
485
489
dummy_block = [1 , 1 ]
486
- desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
487
- desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
488
- desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
489
- desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM , 1 ], block_shape = dummy_block )
490
+ desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
491
+ desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
492
+ desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
493
+ desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = dummy_block )
490
494
else :
491
495
desc_q = q
492
496
desc_v = v
@@ -509,7 +513,7 @@ def grid(META):
509
513
q .shape [0 ], q .shape [1 ], #
510
514
desc_q , desc_k , desc_v , desc_o , #
511
515
N_CTX = q .shape [2 ], #
512
- HEAD_DIM = HEAD_DIM , #
516
+ HEAD_DIM = HEAD_DIM_K , #
513
517
FP8_OUTPUT = q .dtype == torch .float8_e5m2 , #
514
518
STAGE = stage , #
515
519
warp_specialize = warp_specialize , #
@@ -567,17 +571,12 @@ def backward(ctx, do):
567
571
attention = _attention .apply
568
572
569
573
570
- @pytest .mark .parametrize ('Z, H, N_CTX, HEAD_DIM' , [
571
- (1 , 2 , 1024 , 64 ),
572
- (4 , 48 , 128 , 64 ),
573
- (4 , 48 , 256 , 64 ),
574
- (4 , 48 , 512 , 64 ),
575
- (4 , 48 , 1024 , 64 ),
576
- (4 , 48 , 2048 , 64 ),
577
- (4 , 48 , 4096 , 64 ),
578
- ])
579
- @pytest .mark .parametrize ("causal" , [True ])
580
- @pytest .mark .parametrize ("warp_specialize" , [False , True ])
574
+ @pytest .mark .parametrize ("Z" , [1 , 4 ])
575
+ @pytest .mark .parametrize ("H" , [2 , 48 ])
576
+ @pytest .mark .parametrize ("N_CTX" , [128 , 1024 , (2 if is_hip () else 4 ) * 1024 ])
577
+ @pytest .mark .parametrize ("HEAD_DIM" , [64 , 128 ])
578
+ @pytest .mark .parametrize ("causal" , [True ]) # FIXME: Non-causal tests do not pass at the moment.
579
+ @pytest .mark .parametrize ("warp_specialize" , [False , True ] if is_blackwell () else [False ])
581
580
def test_op (Z , H , N_CTX , HEAD_DIM , causal , warp_specialize , dtype = torch .float16 ):
582
581
torch .manual_seed (20 )
583
582
q = (torch .empty ((Z , H , N_CTX , HEAD_DIM ), dtype = dtype , device = DEVICE ).normal_ (mean = 0.0 , std = 0.5 ).requires_grad_ ())
0 commit comments