
Big results difference when using tl.store #2852

@yzhangcs

Description


Hi, I found a big difference in results when using tl.store (under bfloat16).

# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl


@triton.jit
def attention_nostore_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)

    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)

        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        # tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))

        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


@triton.jit
def attention_store_fwd_kernel(
    q,
    k,
    v,
    h,
    o,
    s_qh,
    s_qt,
    s_qd,
    s_hh,
    s_ht,
    H,
    T,
    TD,
    scale,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_bh = tl.program_id(0)

    # [BD, BD]
    b_h = tl.zeros([BD, BD], dtype=tl.float32)
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_hh, (TD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))

        # [BT, BD]
        b_q = tl.load(p_q)
        b_q = (b_q * scale).to(b_q.dtype)
        # [BD, BT]
        b_k = tl.load(p_k)
        # [BT, BD]
        b_v = tl.load(p_v)

        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        # [BT, BD]
        b_o = tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        tl.store(p_h, b_h.to(p_h.dtype.element_ty))
        tl.store(p_o, b_o.to(p_o.dtype.element_ty))

        # [BD, BD]
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)


class AttentionFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, store=False):
        batch_size, n_heads, seq_len, d_head = q.shape
        scale = d_head ** -0.5
        BD = q.shape[-1]
        BT = 32
        num_stages = 3 if d_head <= 64 else 2
        num_warps = 4

        h = q.new_empty(batch_size, n_heads, triton.cdiv(seq_len, BT) * BD, BD)
        o = torch.empty_like(q)
        grid = (batch_size * n_heads,)
        kernel = attention_store_fwd_kernel if store else attention_nostore_fwd_kernel
        kernel[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),
            n_heads, seq_len, h.shape[2], scale,
            BT=BT, BD=BD,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return o


if __name__ == '__main__':
    B, H, T, D = 2, 8, 1024, 128
    dtype = torch.bfloat16
    torch.manual_seed(42)
    # [batch_size, n_heads, seq_len, d_head]
    q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')

    print('Testing BFloat16...')
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

    print('Testing Float...')
    q, k, v = q.float(), k.float(), v.float()
    ref = AttentionFunction.apply(q, k, v, True)
    tri = AttentionFunction.apply(q, k, v, False)
    print(ref[0, 0])
    print(tri[0, 0])
    print('Diff:', (ref - tri).abs().max(), '\n\n')

I have pasted the tailored code here for ease of reproduction.
The only difference between attention_nostore_fwd_kernel and attention_store_fwd_kernel is tl.store(p_h, b_h.to(p_h.dtype.element_ty)), which saves the intermediate results to HBM. The output is:

Testing BFloat16...
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 33.0000, -45.5000, -27.5000,  ...,  47.5000,  11.7500, -42.5000],
        [-12.8125, -13.3125, -29.5000,  ..., -28.2500, -17.8750,  -8.5625],
        [ -8.0625,  19.0000, -25.1250,  ...,  44.0000,  31.8750,   0.7148]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000,   4.3125,  -5.2812,  ...,  -3.1094,   4.4062,   1.4141],
        [ -4.1875,   8.4375,   7.3750,  ...,  -4.2188,   0.7227,   4.2188],
        [ -8.3750,   9.2500,  -3.0938,  ..., -10.4375,   3.5312,  -1.4688],
        ...,
        [ 21.0000, -57.2500,  94.0000,  ...,  -6.8125, -43.5000,  -3.3281],
        [ 91.0000,  29.6250,   0.9414,  ...,  15.3750,  -4.5000,  13.4375],
        [  8.0625, -24.6250,  21.8750,  ...,   1.3672, -21.3750,  96.0000]],
       device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(223., device='cuda:0', dtype=torch.bfloat16) 


Testing Float...
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
tensor([[ -2.9934,   4.3165,  -5.2756,  ...,  -3.1058,   4.4168,   1.4156],
        [ -4.1945,   8.4496,   7.3791,  ...,  -4.2401,   0.7162,   4.2193],
        [ -8.4092,   9.2836,  -3.0875,  ..., -10.4867,   3.5389,  -1.4640],
        ...,
        [ 32.8595, -45.5166, -27.4276,  ...,  47.6876,  11.7916, -42.2014],
        [-12.8712, -13.3366, -29.3908,  ..., -28.2100, -17.8268,  -8.5858],
        [ -8.0902,  19.0255, -25.1759,  ...,  44.1477,  31.9002,   0.8230]],
       device='cuda:0')
Diff: tensor(0., device='cuda:0') 

The results are consistent under float32.
With this minor code change, however, there is an unacceptably large difference in the final outputs under bfloat16.
Also, the bfloat16 results do match if the inputs are restricted to a very small range, e.g., divided by 1024.
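
To illustrate why scaling the inputs down helps: bfloat16 keeps only 8 significand bits, so the gap between representable values grows quickly with magnitude. A tiny standalone snippet (not part of the repro above) showing the rounding granularity at the magnitudes seen in the outputs:

import torch

# bfloat16 has 8 significand bits, so its relative spacing is 2**-7.
print(torch.finfo(torch.bfloat16).eps)      # 0.0078125

# Around magnitude ~200 the spacing between representable bfloat16 values
# is already 1.0, so small changes in rounding/summation order become visible.
x = torch.tensor(200.0, dtype=torch.bfloat16)
print((x + 0.4) - x)                        # 0. -- the 0.4 is rounded away
print((x + 0.6) - x)                        # 1. -- rounds up by a full unit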

I guess the culprit is the limited precision of bfloat16, but I can't figure out why tl.store introduces such a big difference, or how to fix it.
Could you give me some hints?
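
In case it helps with debugging, here is a rough pure-PyTorch float32 reference of the same chunkwise recurrence (just a sketch, assuming the same chunk size BT=32 and the same q * scale scaling as in the kernels; differences on the order of normal bfloat16 rounding are expected against both kernels):

import torch

def chunk_attn_ref(q, k, v, BT=32):
    # Float32 reference of the recurrence computed by both kernels above:
    #   o_i = (q_i * scale) @ h + ((q_i * scale) @ k_i^T) @ v_i
    #   h  += k_i^T @ v_i
    B, H, T, D = q.shape
    scale = D ** -0.5
    q, k, v = q.float() * scale, k.float(), v.float()
    o = torch.empty_like(q)
    h = q.new_zeros(B, H, D, D)
    for i in range(0, T, BT):
        q_i, k_i, v_i = q[:, :, i:i+BT], k[:, :, i:i+BT], v[:, :, i:i+BT]
        s_i = q_i @ k_i.transpose(-1, -2)        # [B, H, BT, BT]
        o[:, :, i:i+BT] = q_i @ h + s_i @ v_i    # [B, H, BT, D]
        h = h + k_i.transpose(-1, -2) @ v_i      # [B, H, D, D]
    return o

Comparing (chunk_attn_ref(q, k, v) - out.float()).abs().max() for both the store and no-store variants should show whether one of them deviates far beyond ordinary bfloat16 rounding.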

Environment: Triton 2.1, A100-SXM4-40GB.

Thanks.
