diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index f36404ca6..1cb9c9cae 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -48,7 +48,7 @@ def _fwd_kernel( q = tl.load(Q + off_q, mask=offs_m[:, None] < seq_len, other=0.0) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) for start_n in range(0, seq_len, BLOCK_N): @@ -68,7 +68,7 @@ def _fwd_kernel( qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf")) # -- compute m_ij, p, l_ij - m_ij = tl.maximum(tl.max(qk, 1), l_i) + m_ij = tl.maximum(tl.max(qk, 1), m_i) p = tl.exp(qk - m_ij[:, None]) l_ij = tl.sum(p, 1)