This post builds a matmul kernel step by step with Triton, following the official tutorials. On top of the official example I added a non-optimized, naive assignment strategy that kicks in when GROUP_SIZE_M == 0. The idea behind the optimized path is to remap pid to (pid_m, pid_n) so that consecutive program ids correspond to C-matrix tiles with locality, improving memory accesses to the B matrix. However, this scheme relies on the GPU scheduling CTAs in round-robin order, a behaviour NVIDIA has never officially guaranteed, so it has a somewhat unofficial, hand-picked feel to it.
The tutorial claims: "In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100)."
However, I could not reproduce the advertised 245 TFLOPS (the official plot itself never even reaches 220 TFLOPS). Judging from my results, the strategy only seems to help around (M, N, K) == (10240, 10240, 10240).
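To make the remapping concrete, here is a small host-side sketch (plain Python, not kernel code; the helper remap and the toy 4x4 grid are purely illustrative) that mirrors the index arithmetic used in the kernel below:

def remap(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    # same arithmetic as in kernel_matmul, executed on the host for clarity
    if GROUP_SIZE_M == 0:  # naive: walk C row by row
        return pid // num_pid_n, pid % num_pid_n
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# toy 4x4 grid of C tiles
print([remap(p, 4, 4, 0) for p in range(8)])  # (0,0) (0,1) (0,2) (0,3) (1,0) ...
print([remap(p, 4, 4, 2) for p in range(8)])  # (0,0) (1,0) (0,1) (1,1) (0,2) ...

With GROUP_SIZE_M == 2, consecutive program ids stay inside a two-row group of C tiles, so neighbouring CTAs touch the same column block of B while it is (hopefully) still resident in L2; the naive order instead streams through a whole row of C before any B block is revisited.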
matmul.py
# spack load py-triton@2.1.0 py-torch@2.4.1+cuda py-matplotlib@3.7.5 py-pandas@1.5.3
# PATH=/usr/sbin:$PATH python3 matmul.py
import triton
import triton.language as tl
import torch
def get_cuda_autotune_config():
return [
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
# Good config for fp8 inputs.
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
]
@triton.autotune(
configs=get_cuda_autotune_config(),
key=["M", "N", "K"],
)
@triton.jit
def kernel_matmul(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
T_C: tl.constexpr,
T_ACC: tl.constexpr,
):
pid = tl.program_id(0)
    if GROUP_SIZE_M == 0:  # no L2-friendly grouping: naive row-major assignment
pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
else:
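        # L2-friendly "grouped" ordering from the official tutorial: consecutive
        # program ids are packed into groups spanning GROUP_SIZE_M rows of C
        # tiles, so blocks of B (and A) are reused while they are still cached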
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
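    # pointers to the first (BLOCK_SIZE_M, BLOCK_SIZE_K) tile of A and the first
    # (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of B; they are advanced along K in the loop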
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=T_ACC)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
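        # NOTE: no boundary masks here -- the kernel assumes M, N and K are
        # multiples of the corresponding BLOCK_SIZE_*, which holds for the
        # sizes used in this post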
a = tl.load(a_ptrs + k * BLOCK_SIZE_K * stride_ak)
b = tl.load(b_ptrs + k * BLOCK_SIZE_K * stride_bk)
accumulator += tl.dot(a, b)
c = accumulator.to(T_C)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
tl.store(c_ptrs, c)
def triton_matmul(a: torch.Tensor, b: torch.Tensor, group_size_m=8):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
assert b.is_contiguous(), "Matrix B must be contiguous"
assert a.device == b.device
assert len(a.shape) == 2
assert len(b.shape) == 2
M, N, K = a.shape[0], b.shape[1], a.shape[1]
c = torch.empty([M, N], device=a.device, dtype=torch.result_type(a, b))
assert c.dtype == torch.float16
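    # 1D launch grid: one Triton program (CTA) per BLOCK_SIZE_M x BLOCK_SIZE_N tile of C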
gridDim = lambda META: [
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
]
kernel_matmul[gridDim](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_SIZE_M=group_size_m,
T_C=tl.float16,
T_ACC=tl.float32,
)
return c
def test():
DEVICE = "cuda" # triton.runtime.driver.active.get_active_torch_device()
M, N, K = 2**12, 2**11, 2**10
a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
torch_c = torch.matmul(a, b)
triton_c = triton_matmul(a, b)
print("Maxdiff is {}".format(torch.max(torch.abs(torch_c - triton_c))))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=[1024 * i for i in range(1, 33)],
line_arg="provider",
line_vals=["torch"] + ["triton_gm_" + str(i * 8) for i in range(0, 3)],
line_names=["Torch"] + ["Triton_gm_" + str(i * 8) for i in range(0, 3)],
plot_name="matmul-tflops",
args={},
)
)
def benchmark(M, N, K, provider):
DEVICE = "cuda" # triton.runtime.driver.active.get_active_torch_device()
a = torch.rand([M, K], device=DEVICE, dtype=torch.float16)
b = torch.rand([K, N], device=DEVICE, dtype=torch.float16)
mp = {"torch": lambda: torch.matmul(a, b)}
    for i in range(0, 3):
        # bind i * 8 at definition time; a bare closure would capture i by
        # reference and every provider would end up benchmarking i == 2
        mp["triton_gm_" + str(i * 8)] = lambda gm=i * 8: triton_matmul(a, b, group_size_m=gm)
ms = triton.testing.do_bench(mp[provider])
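    # each C element takes K multiply-adds = 2K FLOPs, so the GEMM is 2*M*N*K
    # FLOPs in total; do_bench returns milliseconds, hence the 1e-3 factor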
tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
return tflops
if __name__ == "__main__":
torch.manual_seed(3407)
test()
benchmark.run(print_data=True, show_plots=False, save_path=".")
Maxdiff is 0.0
matmul-tflops:
M Torch Triton_gm_0 Triton_gm_8 Triton_gm_16
0 1024.0 106.146670 99.944910 100.177621 100.285490
1 2048.0 223.024312 190.019240 192.286181 192.835966
2 3072.0 207.204224 218.600118 211.457078 202.498213
3 4096.0 209.966972 208.352882 205.708176 203.854366
4 5120.0 198.518736 204.148371 204.548559 205.186599
5 6144.0 200.941588 209.902595 206.997657 210.406590
6 7168.0 253.372013 208.320646 209.778622 191.931345
7 8192.0 215.637153 209.329088 213.237731 200.643129
8 9216.0 226.996859 211.730382 211.534334 207.674463
9 10240.0 218.977978 208.736454 210.179943 211.921554
10 11264.0 218.514899 201.457546 212.384554 212.280699
11 12288.0 212.998932 208.715113 209.656970 210.661206
12 13312.0 216.675550 209.229494 210.838007 199.165423
13 14336.0 217.004038 210.731481 207.639317 207.726755
14 15360.0 220.227365 203.148239 209.640658 206.536071
15 16384.0 210.870985 204.128583 201.964038 207.621751
16 17408.0 217.617265 210.061521 205.695903 210.022982
17 18432.0 209.392064 207.010429 207.720621 207.175245
18 19456.0 193.783631 204.407571 200.792386 201.611363
19 20480.0 211.705219 205.941333 204.070112 204.050251
20 21504.0 209.615616 198.929896 199.225789 199.785262
21 22528.0 194.640101 199.477188 195.409970 194.942547
22 23552.0 209.458697 191.897608 193.983802 192.512928
23 24576.0 209.800194 187.619909 189.899639 190.173693
24 25600.0 209.259851 192.309491 192.844827 189.251787
25 26624.0 212.281754 189.735535 188.293298 189.955548
26 27648.0 212.429522 185.660376 184.001049 185.342768
27 28672.0 201.034415 188.316045 184.910028 186.445225
28 29696.0 207.239106 182.942540 184.113413 182.231425
29 30720.0 213.894056 180.319742 179.993019 182.207880
30 31744.0 215.419762 178.461770 178.591866 181.018181
31 32768.0 216.530680 179.372236 178.546413 177.216643
The numbers above were autotuned over all of the BLOCK_SIZE configs, so they do not necessarily isolate the effect of the L2 optimization. To control for this, I fixed the triton.autotune configuration to
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
)
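One simple way to pin the tuner to this configuration (a sketch; any equivalent edit to the configs list works) is to have get_cuda_autotune_config return a single-element list:

def get_cuda_autotune_config():
    # single fixed config so that only GROUP_SIZE_M differs between providers
    return [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        )
    ]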
The resulting plot is shown below. The conclusion does not change: this L2 optimization strategy only helps around (M, N, K) == (10240, 10240, 10240).
Do I need some other magic parameters?