Hi, I found a large difference in results when using tl.store under bfloat16.
# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl


@triton.jit
def attention_nostore_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)
    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)

    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)
        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        # tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))
        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
@triton.jit
def attention_store_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)
    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)

    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)
        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))
        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
class AttentionFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, store=False):
        batch_size, n_heads, seq_len, d_head = q.shape
        scale = d_head ** -0.5
        BD = q.shape[-1]
        BT = 32
        num_stages = 3 if d_head <= 64 else 2
        num_warps = 4

        h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * BD, BD)
        o = torch.empty_like(q)
        grid = (batch_size * n_heads,)
        kernel = attention_store_fwd_kernel if store else attention_nostore_fwd_kernel
        kernel[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),
            n_heads, seq_len, h.shape[2], scale,
            BT=BT, BD=BD,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return o
if __name__ == '__main__':
    B, H, T, D = 2, 8, 1024, 128
    dtype = torch.bfloat16
    torch.manual_seed(42)
    # [batch_size, n_heads, seq_len, d_head]
    q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')

    print('Testing BFloat16...')
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

    print('Testing Float...')
    q, k, v = q.float(), k.float(), v.float()
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')
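For cross-checking, a plain PyTorch float32 reference of the same chunked recurrence can help tell which of the two kernels drifts. This is just a sketch of mine, not part of the kernels; the name chunk_attn_ref is made up:

def chunk_attn_ref(q, k, v, chunk_size=32):
    # Float32 reference of what both kernels compute: per chunk i,
    # o_i = scale * q_i @ H_i with H_i = sum_{j <= i} k_j^T @ v_j
    # (no intra-chunk causal mask, matching b_s @ b_v above).
    B, H, T, D = q.shape
    scale = D ** -0.5
    q, k, v = q.float(), k.float(), v.float()
    o = torch.empty_like(q)
    h = q.new_zeros(B, H, D, D)
    for i in range(0, T, chunk_size):
        q_i = q[:, :, i:i + chunk_size] * scale
        k_i = k[:, :, i:i + chunk_size]
        v_i = v[:, :, i:i + chunk_size]
        # accumulate the running state, including the current chunk
        h = h + k_i.transpose(-1, -2) @ v_i
        o[:, :, i:i + chunk_size] = q_i @ h
    return o

# e.g. compare each kernel against the float32 reference:
# print((AttentionFunction.apply(q, k, v, True).float() - chunk_attn_ref(q, k, v)).abs().max())
# print((AttentionFunction.apply(q, k, v, False).float() - chunk_attn_ref(q, k, v)).abs().max())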
I have pasted the tailored code above for ease of reproduction. The only difference between attention_nostore_fwd_kernel and attention_store_fwd_kernel is the line tl.store(p_h, b_h.to(p_h.dtype.element_ty)), which saves the intermediate state b_h to HBM. The output is:
Testing BFloat16...
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
[ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
[ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
...,
[ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
[-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
[ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
[ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
[ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
...,
[ 21.0000, -57.2500, 94.0000, ..., -6.8125, -43.5000, -3.3281],
[ 91.0000, 29.6250, 0.9414, ..., 15.3750, -4.5000, 13.4375],
[ 8.0625, -24.6250, 21.8750, ..., 1.3672, -21.3750, 96.0000]],
device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(223., device='cuda:0', dtype=torch.bfloat16)
Testing Float...
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
[ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
[ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
...,
[ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
[-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
[ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
device='cuda:0')
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
[ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
[ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
...,
[ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
[-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
[ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
device='cuda:0')
Diff: tensor(0., device='cuda:0')
The results are consistent under float32.
Under bfloat16, however, this minor code change leads to a large, unacceptable difference in the final outputs.
Also, the bfloat16 results do match if the inputs are restricted to a very small range, e.g., divided by 1024, as in the quick check below.
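For reference, that small-range check looks roughly like this (a sketch, run on the original bfloat16 tensors from the script above):

# With the inputs scaled down, both kernels agree even in bfloat16.
q_s, k_s, v_s = q / 1024, k / 1024, v / 1024
ref_s = AttentionFunction.apply(q_s, k_s, v_s, True)   # store kernel
tri_s = AttentionFunction.apply(q_s, k_s, v_s, False)  # nostore kernel
print('Diff (inputs / 1024):', (ref_s - tri_s).abs().max())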
I suspect the culprit is the limited precision of bfloat16, but I cannot figure out why tl.store makes such a big difference, or how to fix it.
Could you give me some hints?
The environment is Triton 2.1 on an A100-SXM4-40GB.
Thanks.