前段时间因为工作原因重新浏览了 torch.compile 中几个重要的函数调用栈,现将这些内容分享给有兴趣通过源代码理解 torch.compile 行为的朋友。

本文使用 PyTorch 官方提供的 docker 镜像: pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel,PyTorch 版本为 2.4.0。

本文所用示例代码:

import torch
import torch.nn as nn
from collections import OrderedDict

@torch._inductor.config.patch(
    post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}}
)
def test():
    n, h = 32, 128
    repeats = 3
    layers = OrderedDict()
    for i in range(repeats):
        layers[f"fc_{i}"] = nn.Linear(h, h)
        layers[f"ln_{i}"] = nn.LayerNorm(h)
        layers[f"silu_{i}"] = nn.SiLU()
    model = nn.Sequential(layers).cuda().half()
    x = torch.randn((n, h), device="cuda", dtype=torch.float16, requires_grad=True)
    dy = torch.randn_like(x)

    compiled = torch.compile(model, mode="reduce-overhead")
    for _ in range(4):
        y = compiled(x)
        y.backward(dy)

if __name__ == "__main__":
    test()

TorchDynamo Link to heading

使用 TorchDynamo 来 trace 函数:

在 AOTAutograd 之前运行 pre_grad_passes,替换捕获到的计算图中的算子:

AOTAutograd & PrimTorch Link to heading

AOTAutograd 通过 trace 得到正向传播和反向传播的 joint graph:

通过 functorch 把待编译函数转为符合函数式编程原则的函数:

通过 PrimTorch 分解算子:

AOTAutograd 划分 joint graph 为 forward graph 和 backward graph:

TorchInductor Link to heading

运行 post_grad_passes,优化计算图:

get_fusion_candidates 以 BFS 的方式搜索依次搜索输入节点,检查是否能够融合算子:

通过 graph lowering 把计算图转为 TorchInductor 的 IR:

分析图中的依赖关系:

TorchInductor 通过 fuse_nodes 做 vertical & horizontal fusion,它先通过 get_possible_fusions 获取可以融合的算子组合,先决条件是算子用到了相同的 buffer,然后通过 can_fuse 检查是否可以融合。

Backend 生成 ATen:

使用 Triton 编译:

为编译过的子图生成调用它们的 wrapper code:

Guard Link to heading

为编译好的子图生成 guard:

生成 check_fn,确保编译过的子图可用,否则重新编译:

CUDA Graph Link to heading

torch.compilereduce-overhead 模式下会自动添加 CUDA Graph 来减小运行时开销,发生在 TorchInductor 中。首先检查子图中是否包含与 CUDA Graph 不兼容的算子:

把子图转为 CUDA Graph:

捕获子图为 CUDA Graph 发生在:

在运行时 replay CUDA Graph: