CuTe 是一个 Header-only 的库,在提供抽象简化 CUDA 的同时保持了对底层的比较精确的控制,可以一定程度上提高 CUDA 开发的效率。之前一直是用纯 CUDA+ptx 的方式写 kernel,有时会觉得比较繁琐和容易出错,又听说不少工作(比如 FlashAttention)都是用 CuTe 写的,听起来很强,所以决定学习一下 CuTe。

Layout

这篇文章就从 CuTe 最基础的部分 Layout 开始。Layout 的作用是将高维 Tensor 的索引映射到一维的物理地址。通过这种抽象,用户可以从手动计算索引的繁琐过程中解放出来(真的很容易写错……)。

首先,Layout = Shape + Stride,其中 Shape 指张量的形状,Stride 是指各维度上每增加一个单位,物理地址上偏移的单位。

Basic

可以用 make_layout 声明一个 Layout,举个例子,如果要声明一个 4*8 的行主序矩阵,可以写:

#include <cute/tensor.hpp>
using namespace cute;
int main() {
	auto layout = make_layout(Shape<_4, _8>{}, Stride<_8, _1>{});
	print_layout(layout);
}

这里的 Shape<_4, _8> 就是声明一个 4*8 的矩阵。Stride<_8, _1> 说明在第一维增加 1 个单位,物理地址应该增加 8 个单位,第二维同理。

CuTe Layout 2026-03-02 21.21.19CuTe Layout 2026-03-02 21.21.19

这里另外再补充一下,CuTe 提供了相当方便的输出调试功能,比如这里用的 print_layout,可以输出整齐的矩阵和各个单位对应的物理索引,非常便于调试。

如果要改成列主序也很简单,只需要调整 Stride 即可:

// auto layout = make_layout(Shape<_4, _8>{}, Stride<_8, _1>{});
auto layout = make_layout(Shape<_4, _8>{}, Stride<_1, _4>{});

Hierarchy

如果是简单的行主序或列主序,直接用 CUDA 写也不算麻烦,但 CuTe 的好处在于它也可以方便地表示分层的 Tensor,这在各种矩阵分块中非常有用。

举个例子,假如我们想表示一个这样的分块矩阵:

CuTe Layout 2026-03-02 20.39.54CuTe Layout 2026-03-02 20.39.54

可以写

auto layout2 = make_layout(Shape<_4, Shape<_2, _4>>{}, Stride<_2, Stride<_1, _8>>{});

此处形状可以理解为 ((内部行维度, 外部行维度), (内部列维度,外部列维度)),暂时没碰到过更高维的情况,所以我暂时没理 :)

通过 print_layout 可以检查:

  0    1    2    3    4    5    6    7 
   +----+----+----+----+----+----+----+----+
0  |  0 |  1 |  8 |  9 | 16 | 17 | 24 | 25 |
   +----+----+----+----+----+----+----+----+
1  |  2 |  3 | 10 | 11 | 18 | 19 | 26 | 27 |
   +----+----+----+----+----+----+----+----+
2  |  4 |  5 | 12 | 13 | 20 | 21 | 28 | 29 |
   +----+----+----+----+----+----+----+----+
3  |  6 |  7 | 14 | 15 | 22 | 23 | 30 | 31 |
   +----+----+----+----+----+----+----+----+

确实符合预计。

一个更复杂但更常见的例子是矩阵的分块,比如我们要把一个 4*8 的矩阵分为 2*4 的 Tile,Tile 内部和 Tile 间都以行主序方式组织:

CuTe Layout 2026-03-02 21.00.14CuTe Layout 2026-03-02 21.00.14

可以写作:

    auto layout4 = make_layout(Shape<Shape<_2, _2>, Shape<_4, _2>>{}, Stride<Stride<_4, _16>, Stride<_1, _8>>{});

检查一下:

      0    1    2    3    4    5    6    7 
   +----+----+----+----+----+----+----+----+
0  |  0 |  1 |  2 |  3 |  8 |  9 | 10 | 11 |
   +----+----+----+----+----+----+----+----+
1  |  4 |  5 |  6 |  7 | 12 | 13 | 14 | 15 |
   +----+----+----+----+----+----+----+----+
2  | 16 | 17 | 18 | 19 | 24 | 25 | 26 | 27 |
   +----+----+----+----+----+----+----+----+
3  | 20 | 21 | 22 | 23 | 28 | 29 | 30 | 31 |
   +----+----+----+----+----+----+----+----+

此外也可以利用 CuTe 原生提供的一种 Layout Algebra 实现:

auto outer = make_layout(Shape<_2, _2>{}, Stride<_2, _1>{});
print_layout(outer);
auto inner = make_layout(Shape<_2, _4>{}, Stride<_4, _1>{});
print_layout(inner);
auto result = blocked_product(inner, outer);
print_layout(result);

这里的 blocked_product 就可以理解为两个 Layout 嵌套。

其他 Layout 变换暂时还没遇到应用场景,此处不再介绍,可以参考 NVIDIA 的官方文档。后续会继续学习 CuTe 的 MMA 和 Copy 操作的抽象,争取再用 CuTe 实现一版 GEMM 和 FlashAttention。

Reference