CuTe 是一个 Header-only 的库,在提供抽象简化 CUDA 的同时保持了对底层的比较精确的控制,可以一定程度上提高 CUDA 开发的效率。之前一直是用纯 CUDA+ptx 的方式写 kernel,有时会觉得比较繁琐和容易出错,又听说不少工作(比如 FlashAttention)都是用 CuTe 写的,听起来很强,所以决定学习一下 CuTe。
Layout
这篇文章就从 CuTe 最基础的部分 Layout 开始。Layout 的作用是将高维 Tensor 的索引映射到一维的物理地址。通过这种抽象,用户可以从手动计算索引的繁琐过程中解放出来(真的很容易写错……)。
首先,Layout = Shape + Stride,其中 Shape 指张量的形状,Stride 是指各维度上每增加一个单位,物理地址上偏移的单位。
Basic
可以用 make_layout 声明一个 Layout,举个例子,如果要声明一个 4*8 的行主序矩阵,可以写:
#include <cute/tensor.hpp>
using namespace cute;
int main() {
auto layout = make_layout(Shape<_4, _8>{}, Stride<_8, _1>{});
print_layout(layout);
}这里的 Shape<_4, _8> 就是声明一个 4*8 的矩阵。Stride<_8, _1> 说明在第一维增加 1 个单位,物理地址应该增加 8 个单位,第二维同理。
这里另外再补充一下,CuTe 提供了相当方便的输出调试功能,比如这里用的 print_layout,可以输出整齐的矩阵和各个单位对应的物理索引,非常便于调试。
如果要改成列主序也很简单,只需要调整 Stride 即可:
// auto layout = make_layout(Shape<_4, _8>{}, Stride<_8, _1>{});
auto layout = make_layout(Shape<_4, _8>{}, Stride<_1, _4>{});Hierarchy
如果是简单的行主序或列主序,直接用 CUDA 写也不算麻烦,但 CuTe 的好处在于它也可以方便地表示分层的 Tensor,这在各种矩阵分块中非常有用。
举个例子,假如我们想表示一个这样的分块矩阵:
可以写
auto layout2 = make_layout(Shape<_4, Shape<_2, _4>>{}, Stride<_2, Stride<_1, _8>>{});此处形状可以理解为 ((内部行维度, 外部行维度), (内部列维度,外部列维度)),暂时没碰到过更高维的情况,所以我暂时没理 :)
通过 print_layout 可以检查:
0 1 2 3 4 5 6 7
+----+----+----+----+----+----+----+----+
0 | 0 | 1 | 8 | 9 | 16 | 17 | 24 | 25 |
+----+----+----+----+----+----+----+----+
1 | 2 | 3 | 10 | 11 | 18 | 19 | 26 | 27 |
+----+----+----+----+----+----+----+----+
2 | 4 | 5 | 12 | 13 | 20 | 21 | 28 | 29 |
+----+----+----+----+----+----+----+----+
3 | 6 | 7 | 14 | 15 | 22 | 23 | 30 | 31 |
+----+----+----+----+----+----+----+----+
确实符合预计。
一个更复杂但更常见的例子是矩阵的分块,比如我们要把一个 4*8 的矩阵分为 2*4 的 Tile,Tile 内部和 Tile 间都以行主序方式组织:
可以写作:
auto layout4 = make_layout(Shape<Shape<_2, _2>, Shape<_4, _2>>{}, Stride<Stride<_4, _16>, Stride<_1, _8>>{});检查一下:
0 1 2 3 4 5 6 7
+----+----+----+----+----+----+----+----+
0 | 0 | 1 | 2 | 3 | 8 | 9 | 10 | 11 |
+----+----+----+----+----+----+----+----+
1 | 4 | 5 | 6 | 7 | 12 | 13 | 14 | 15 |
+----+----+----+----+----+----+----+----+
2 | 16 | 17 | 18 | 19 | 24 | 25 | 26 | 27 |
+----+----+----+----+----+----+----+----+
3 | 20 | 21 | 22 | 23 | 28 | 29 | 30 | 31 |
+----+----+----+----+----+----+----+----+
此外也可以利用 CuTe 原生提供的一种 Layout Algebra 实现:
auto outer = make_layout(Shape<_2, _2>{}, Stride<_2, _1>{});
print_layout(outer);
auto inner = make_layout(Shape<_2, _4>{}, Stride<_4, _1>{});
print_layout(inner);
auto result = blocked_product(inner, outer);
print_layout(result);这里的 blocked_product 就可以理解为两个 Layout 嵌套。
其他 Layout 变换暂时还没遇到应用场景,此处不再介绍,可以参考 NVIDIA 的官方文档。后续会继续学习 CuTe 的 MMA 和 Copy 操作的抽象,争取再用 CuTe 实现一版 GEMM 和 FlashAttention。