
Getting Started with Triton: L2-Friendly Matrix Multiplication

Posted on 22 Nov 2024

This post implements a matmul operator step by step with Triton. It follows the official tutorials; on top of the official example I added a non-optimized, naive mapping that is used when GROUP_SIZE_M == 0. The rationale, roughly, is to remap pid → (pid_m, pid_n) so that consecutive program IDs land on C-matrix tiles with good locality, which improves memory accesses to the B matrix. This scheme, however, relies on the GPU scheduling CTAs in round-robin order, a behavior NVIDIA has never officially guaranteed, so it leans on an unofficial, tacitly-blessed assumption. The official tutorial claims:

In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).

However, I could not reproduce the claimed 245 TFLOPS (the plot in the official tutorial itself never even reaches 220 TFLOPS). Judging from my results, the strategy only seems to help around (M, N, K) == (10240, 10240, 10240).
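To make the remapping concrete, here is a small plain-Python sketch (not part of matmul.py; the 4×4 tile grid and GROUP_SIZE_M = 2 are toy values) that prints the order in which (pid_m, pid_n) tiles are visited with and without grouping, mirroring the index arithmetic used in the kernel below.

# Toy sketch of the pid -> (pid_m, pid_n) remapping used in the kernel below.
def tile_order(num_pid_m, num_pid_n, GROUP_SIZE_M):
    order = []
    for pid in range(num_pid_m * num_pid_n):
        if GROUP_SIZE_M == 0:  # naive row-major order
            pid_m, pid_n = divmod(pid, num_pid_n)
        else:  # L2-friendly grouped order
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
            pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order

print(tile_order(4, 4, 0)[:8])  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), ...]
print(tile_order(4, 4, 2)[:8])  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), ...]

With grouping, consecutive program IDs stay within GROUP_SIZE_M rows of C tiles, so the B tiles they share are more likely to still be resident in L2 when they are needed again.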

Environment

  • NVIDIA-A100-PCIE-40GB
  • Debian 11
  • spack@0.23.0
  • cuda@12.6.2
  • py-triton@2.1.0
  • py-torch@2.4.1+cuda
  • py-matplotlib@3.7.5
  • py-pandas@1.5.3

Source code: matmul.py

# spack load py-triton@2.1.0 py-torch@2.4.1+cuda py-matplotlib@3.7.5 py-pandas@1.5.3
# PATH=/usr/sbin:$PATH python3 matmul.py
import triton
import triton.language as tl
import torch


def get_cuda_autotune_config():
    return [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        # Good config for fp8 inputs.
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
            num_stages=4,
            num_warps=4,
        ),
    ]


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def kernel_matmul(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    T_C: tl.constexpr,
    T_ACC: tl.constexpr,
):
    pid = tl.program_id(0)
    if GROUP_SIZE_M == 0:  # naive mapping, no L2-friendly grouping
        pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
        pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    else:
        # L2-friendly mapping: consecutive pids walk down GROUP_SIZE_M rows of
        # C tiles before moving on to the next column of tiles.
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        # The last group may contain fewer than GROUP_SIZE_M rows of tiles.
        group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)

        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
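    # Note: no masks on the loads/stores below; the kernel assumes M, N and K
    # are multiples of the block sizes, which holds for all sizes used here.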

    accumulator = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=T_ACC)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs + k * BLOCK_SIZE_K * stride_ak)
        b = tl.load(b_ptrs + k * BLOCK_SIZE_K * stride_bk)
        accumulator += tl.dot(a, b)
    c = accumulator.to(T_C)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    tl.store(c_ptrs, c)


def triton_matmul(a: torch.Tensor, b: torch.Tensor, group_size_m=8):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    assert a.device == b.device
    assert len(a.shape) == 2
    assert len(b.shape) == 2
    M, N, K = a.shape[0], b.shape[1], a.shape[1]
    c = torch.empty([M, N], device=a.device, dtype=torch.result_type(a, b))
    assert c.dtype == torch.float16
    gridDim = lambda META: [
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
    ]
    kernel_matmul[gridDim](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        GROUP_SIZE_M=group_size_m,
        T_C=tl.float16,
        T_ACC=tl.float32,
    )
    return c


def test():
    DEVICE = "cuda"  # triton.runtime.driver.active.get_active_torch_device()
    M, N, K = 2**12, 2**11, 2**10
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
    torch_c = torch.matmul(a, b)
    triton_c = triton_matmul(a, b)
    print("Maxdiff is {}".format(torch.max(torch.abs(torch_c - triton_c))))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[1024 * i for i in range(1, 33)],
        line_arg="provider",
        line_vals=["torch"] + ["triton_gm_" + str(i * 8) for i in range(0, 3)],
        line_names=["Torch"] + ["Triton_gm_" + str(i * 8) for i in range(0, 3)],
        plot_name="matmul-tflops",
        args={},
    )
)
def benchmark(M, N, K, provider):
    DEVICE = "cuda"  # triton.runtime.driver.active.get_active_torch_device()
    a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
    b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
    mp = {"torch": lambda: torch.matmul(a, b)}
    for i in range(0, 3):
        # Bind i as a default argument: a bare lambda would capture the loop
        # variable by reference, so every provider would run with the last i.
        mp["triton_gm_" + str(i * 8)] = lambda i=i: triton_matmul(a, b, group_size_m=i * 8)
    ms = triton.testing.do_bench(mp[provider])
    tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return tflops


if __name__ == "__main__":
    torch.manual_seed(3407)
    test()
    benchmark.run(print_data=True, show_plots=False, save_path=".")

Program output

Maxdiff is 0.0
matmul-tflops:
          M       Torch  Triton_gm_0  Triton_gm_8  Triton_gm_16
0    1024.0  106.146670    99.944910   100.177621    100.285490
1    2048.0  223.024312   190.019240   192.286181    192.835966
2    3072.0  207.204224   218.600118   211.457078    202.498213
3    4096.0  209.966972   208.352882   205.708176    203.854366
4    5120.0  198.518736   204.148371   204.548559    205.186599
5    6144.0  200.941588   209.902595   206.997657    210.406590
6    7168.0  253.372013   208.320646   209.778622    191.931345
7    8192.0  215.637153   209.329088   213.237731    200.643129
8    9216.0  226.996859   211.730382   211.534334    207.674463
9   10240.0  218.977978   208.736454   210.179943    211.921554
10  11264.0  218.514899   201.457546   212.384554    212.280699
11  12288.0  212.998932   208.715113   209.656970    210.661206
12  13312.0  216.675550   209.229494   210.838007    199.165423
13  14336.0  217.004038   210.731481   207.639317    207.726755
14  15360.0  220.227365   203.148239   209.640658    206.536071
15  16384.0  210.870985   204.128583   201.964038    207.621751
16  17408.0  217.617265   210.061521   205.695903    210.022982
17  18432.0  209.392064   207.010429   207.720621    207.175245
18  19456.0  193.783631   204.407571   200.792386    201.611363
19  20480.0  211.705219   205.941333   204.070112    204.050251
20  21504.0  209.615616   198.929896   199.225789    199.785262
21  22528.0  194.640101   199.477188   195.409970    194.942547
22  23552.0  209.458697   191.897608   193.983802    192.512928
23  24576.0  209.800194   187.619909   189.899639    190.173693
24  25600.0  209.259851   192.309491   192.844827    189.251787
25  26624.0  212.281754   189.735535   188.293298    189.955548
26  27648.0  212.429522   185.660376   184.001049    185.342768
27  28672.0  201.034415   188.316045   184.910028    186.445225
28  29696.0  207.239106   182.942540   184.113413    182.231425
29  30720.0  213.894056   180.319742   179.993019    182.207880
30  31744.0  215.419762   178.461770   178.591866    181.018181
31  32768.0  216.530680   179.372236   178.546413    177.216643

[Figure: matmul-tflops]

The data above was collected while autotuning over all of the BLOCK_SIZE configs, so it does not necessarily isolate the effect of the L2 optimization. To control for that variable, I fixed the triton.autotune config to

triton.Config(
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
    num_stages=4,
    num_warps=4,
)
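Concretely, one way to pin the autotuner to this single config (a hypothetical edit to the matmul.py above, not shown in full) is to have get_cuda_autotune_config() return only that entry:

def get_cuda_autotune_config():
    # A single-entry list leaves @triton.autotune nothing to choose from,
    # so every problem size runs with exactly this tile shape.
    return [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
    ]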

The resulting plot is shown below. The conclusion is unchanged: this L2 optimization only seems to help around (M, N, K) == (10240, 10240, 10240).

[Figure: matmul-tflops1]

Do I need some other magic parameters?
