简介 Link to heading

在 PyTorch 2.0 以前,用户通过 PyTorch 可以直接捕获到正向传播的计算图,比如 JIT trace 和 TorchFX 的 symbolic trace。虽然 PyTorch 的每个算子都包含正向传播和反向传播的实现,但用户并不能直接在反向传播的计算图上面做优化,也无法把正向传播和反向传播的计算图合并在一张计算图中。PyTorch 2.0 中引入了 AOTAutograd,它的出现解决了这个问题,从而使得一些针对 training 的优化变得可能。

有了 AOTAutograd,用户可以做以下事情:

  • 获取反向传播计算图、甚至是正向传播和反向传播联合的计算图;
  • 用不同的后端编译器分别编译正向传播和反向传播计算图;
  • 针对训练 (training) 做正向传播、反向传播联合优化,比如通过在反向传播中重算 (recompute) 来减少正向传播为反向传播保留的 tensor,从而削减内存需求;

PyTorch 2.0

用法 Link to heading

截止目前为止,使用 AOTAutograd 的方式有几种,但最为基础的是 aot_function。以下面的代码片段为例:

import torch
from functorch.compile import aot_function, \
    make_boxed_func, ts_compile

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

def run_func(func, *inputs):
    res = func(*inputs)
    loss = res.sum()
    loss.backward()

def compiler_fn(fx_module: torch.fx.GraphModule, _):
    print(fx_module.code)
    return make_boxed_func(fx_module.forward)

a, b, c, d = [torch.randn(2, 4, requires_grad=True,
    device="cuda") for _ in range(4)]
run_func(fn, a, b, c, d)

aot_print_fn = aot_function(fn, fw_compiler=compiler_fn,
    bw_compiler=compiler_fn)
run_func(aot_print_fn, a, b, c, d)

因为 AOTAutograd 现阶段还是 functorch 的一部分,所以要先从 functorch 中导入 aot_function。函数 fn(a, b, c, d) 是我们待优化的函数,compiler_fn() 是自定义的后端编译器,接受一张 fx.Graph,这里只是打印出来 fx.GraphModule 对应的 Python 函数。使用 aot_function 优化 fn,并在 run_func() 中执行正向传播和反向传播。执行上面的代码片段,我们得到以下输出:

$ python test.py
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add.Tensor(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add.Tensor(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos.default(add_2)
    cos_1 = torch.ops.aten.cos.default(cos)
    return [cos_1, cos, add_2]

def forward(self, cos, add_2, tangents_1):
    sin = torch.ops.aten.sin.default(cos);  cos = None
    neg = torch.ops.aten.neg.default(sin);  sin = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin.default(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]

可以看到,自定义的编译器 compiler_fn() 被调用了两次,分别打印了正向传播和反向传播计算图对应的 Python 代码,cosadd_2 被保留给了反向传播。其中的 primalstangents 是微分几何中的概念,这里可以把 primals 理解为用户函数的输入,它是正向传播的输入,把 tangents 理解为用户函数输出的梯度,它是反向传播的输入。两张计算图是 FX Graph,其中包含的是 ATen 算子,它们是 low-level 算子,而不是 Torch 级别的算子,例如 Linear

我们还可以使用 PyTorch 2.0 内置的编译器,比如 ts_compile:

from functorch.compile import ts_compile

aot_compiled_fn = aot_function(fn, fw_compiler=ts_compile,
   bw_compiler=ts_compile)
run_func(aot_compiled_fn, a, b, c, d)

AOTAutograd 还提供了 min_cut_rematerialization_partition,它的作用是针对正向传播计算图和反向传播计算图做联合优化,从而 降低内存需求:

from functorch.compile import min_cut_rematerialization_partition

aot_mincut_fn = aot_function(fn, fw_compiler=compiler_fn,
    bw_compiler=compiler_fn,
    partition_fn=min_cut_rematerialization_partition,
    decompositions=None)
run_func(aot_mincut_fn, a, b, c, d)

为了简化上述过程,AOTAutograd 提供了 memory_efficient_fusion,它合并了 aot_functionmin_cut_rematerialization_partition:

from functorch.compile import memory_efficient_fusion

aot_memeff_fn = memory_efficient_fusion(fn)
run_func(aot_memeff_fn, a, b, c, d)

除此之外,用户还可以从 torch.compile 调用 AOTAutograd,默认的 inductor 后端会调用 AOTAutograd,backend 包含 aot_ 字段的编译器也会调用 AOTAutograd,例如 aot_ts_nvfuser:

aot_ts_nvfuser_fn = torch.compile(fn, backend="aot_ts_nvfuser")
run_func(aot_ts_nvfuser_fn, a, b, c, d)

原理 Link to heading

为什么叫 AOTAutograd?因为 PyTorch 反向传播的计算图是在执行正向传播的过程中动态构建的,反向传播的计算图在正向传播结束时才能确定下来。AOTAutograd 以 Ahead-of-Time 的方式同时 trace 正向传播和反向传播,从而在函数真正执行之前拿到正向传播和反向传播的计算图

总的来说,AOTAutograd 的工作流程 如下:

  • 以 AOT 方式通过 __torch_dispatch__ 机制 trace 正向传播和反向传播,生成联合计算图 (joint forward and backward graph),它是包含 Aten/Prim 算子的 FX Graph;
  • partition_fn 把 joint graph 划分为正向传播计算图和反向传播计算图;
  • 可选: 通过 decompositions 把 high-level 算子分解、下沉到粒度更小的算子;
  • 调用 fw_compilerbw_compiler 分别编译正向传播计算图和反向传播计算图,通过 TorchFX 生成编译后的 Python 代码,并整合为一个 torch.autograd.Function;

对于上面的案例 fn,tracing 得到的 joint graph 为:

def forward(self, primals, tangents):
    primals_1: f32[2, 4], primals_2: f32[2, 4], primals_3: f32[2, 4], primals_4: f32[2, 4], tangents_1: f32[2, 4], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # No stacktrace found for following nodes
    add: f32[2, 4] = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1: f32[2, 4] = torch.ops.aten.add.Tensor(add, primals_3);  add = primals_3 = None
    add_2: f32[2, 4] = torch.ops.aten.add.Tensor(add_1, primals_4);  add_1 = primals_4 = None
    cos: f32[2, 4] = torch.ops.aten.cos.default(add_2)
    cos_1: f32[2, 4] = torch.ops.aten.cos.default(cos)
    sin: f32[2, 4] = torch.ops.aten.sin.default(cos);  cos = None
    neg: f32[2, 4] = torch.ops.aten.neg.default(sin);  sin = None
    mul: f32[2, 4] = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None

    #
    sin_1: f32[2, 4] = torch.ops.aten.sin.default(add_2);  add_2 = None
    neg_1: f32[2, 4] = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1: f32[2, 4] = torch.ops.aten.mul.Tensor(mul, neg_1);  mul = neg_1 = None
    return pytree.tree_unflatten([cos_1, mul_1, mul_1, mul_1, mul_1], self._out_spec)

在上述 joint graph 的基础上,通过 partition_fn 将其划分为正向传播计算图和反向传播计算图。

Torch Dispatch Link to heading

AOTAutograd 得以工作的核心是 __torch_dispatch__PyTorch 的核心是一个 dispatcher,它的功能是根据输入 tensor 的属性把算子 dispatch 到具体的 kernel 上,比如根据 tensor 的 device 属性决定是调用 CUDA kernel 还是 CPU 函数执行该算子。

一个算子在 PyTorch 中往往要经过多次 dispatch,__torch_dispatch__ 给了用户提供了一个入口,使得用户能够在算子最终 dispatch 前获取对应的算子和输入。比如 torch.sin 在经过多次 dispatch 之后会调用 __torch_dispatch__:

torch-dispatch

使用 __torch_dispatch__ 的方法有两种,以下面的代码片段为例,PyTensor 继承了 torch.Tensor,并定制了类方法 __torch_dispatch__,其中可以获取要执行的 ATen 算子和对应的参数:

import torch

class PyTensor(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError

x = PyTensor(torch.randn(8))
x * 2

执行上面的代码片段,x * 2 会触发 __torch_dispatch__:

$ python test-1.py
Traceback (most recent call last):
  File "test-1.py", line 13, in <module>
    x * 2
  File "test-1.py", line 10, in __torch_dispatch__
    raise NotImplementedError
NotImplementedError

此处的原理是: PyTorch 每个的 Tensor 都包含一组 dispatch key (参考 TensorImpl 中的 DispatchKeySet),初始化 Tensor 的时候,如果存在定制的类属性 __torch_dispatch__ (check_has_torch_dispatch()),则给该 tensor 添加 DispatchKey::PythonDispatchKey::PythonTLSSnapshot。在执行算子时,dispatcher 从 Tensor 中收集 DispatchKeySet,dispatcher 看到 DispatchKey::Python 时会执行 Tensor__torch_dispatch__ 函数。

在上面的例子中,为了使用 __torch_dispatch__ 而不得不继承 torch.Tensor 实现子类是比较麻烦的,好在 PyTorch 提供了另一种更优雅的方式 TorchDispatchMode。在 TorchDispatchMode 的 context 里,执行算子会调用 TorchDispatchMode.__torch_dispatch__,从而避免用户自定义 Tensor 类。例如下面的代码:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

x = torch.randn(8)
with TorchDispatchMode():
    x * 2

执行该代码片段时会触发 TorchDispatchMode.__torch_dispatch__,它需要用户自己实现:

$ python test-2.py
Traceback (most recent call last):
  File "test-2.py", line 8, in <module>
    x * 2
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_python_dispatch.py", line 45, in __torch_dispatch__
    raise NotImplementedError()
NotImplementedError

此处的原理是: PyTorch 的每个线程有一个 DispatchKeySetTorchDispatchModeDispatchKeySet 中添加了 DispatchKey::PythonDispatchKey::PythonTLSSnapshot,在 C++ 中把 torchDispatchModeState 压栈。执行具体的算子时,dispatcher 把当前线程的 DispatchKeySetTensorDispatchKeySet 合并,然后根据合并后的 DispatchKeySet 来 dispatch 该算子,并根据 DispatchKey::Python 优先调用 TorchDispatchMode__torch_dispatch__ 方法。

通过 __torch_dispatch__,用户有机会在 kernel 执行前获取算子和参数,从而可以做很多事情,基于 __torch_dispatch__ 的 tracing 正是其中之一。

去重 Link to heading

TorchFX 中实现了 make_fx,与常规的 symbolic tracing 不同,make_fx 是通过 __torch_dispatch__ 实现的,AOTAutograd 的 tracing 正是用的 make_fx。以下面的代码为例:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return x + y

x = torch.randn(8)
y = torch.randn(8)
g = make_fx(f)(x, y)
print(g.code)

执行上面的代码片段,可以看到 make_fx 成功 trace 了函数 f 的计算图,并保存在 FX Graph 中:

def forward(self, x_1, y_1):
    add = torch.ops.aten.add.Tensor(x_1, y_1);  x_1 = y_1 = None
    return add

作为对比,使用 symbolic trace 得到的计算图包含的是高层次的算子:

from torch.fx import symbolic_trace
h = symbolic_trace(f)
print(h.code)

可见 symbolic_trace 得到的是 operator.add,而 make_fx 得到的是 torch.ops.aten.add.Tensor:

def forward(self, x, y):
    add = x + y;  x = y = None
    return add

make_fx 并不是完美的,特别是当用于 tracing 的 输入参数中包含重复的 tensor 时,例如:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return x + y

x = torch.randn(8)
g = make_fx(f)(x, x)
print(g.code)

可以看到,make_fx(f)(x, x) trace 出来的是 y + y 而不是预期中的 x + y:

def forward(self, x_1, y_1):
    add = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
    return add

除此之外,如果我们使用 torch.autograd.grad(f(x, y), (x, y)) 计算函数 f(x, y)(x, y) 的梯度,但如果 xy 是相同的 tensor,trace 出来的梯度就是错的:

>>> x = torch.randn(1, requires_grad=True)
>>> torch.autograd.grad(x + x, (x, x))
(tensor([2.]), tensor([2.]))

把函数参数 去重 (deduplicate) 可以得到正确的梯度:

>>> y = torch.randn(1, requires_grad=True)
>>> torch.autograd.grad(x + y, (x, y))
(tensor([1.]), tensor([1.]))

为什么基于 __torch_dispatch__ 的 tracing 机制会存在这样的问题?因为使用 __torch_dispatch__ 进行 tracing 时使用的是 tensor,而要建立的是 fx.Graph,怎么把 tensor 映射到 fx.Graph 中的节点?答案是通过 tensor 的 ID,相同的 tensor 会被映射到 fx.Graph 中的同一个 Proxy,因而给被 trace 的函数实际参数去重就很有必要。

AOTAutograd 在开始 tracing 前会给输入函数去重,它依次尝试下面两个策略:

  1. 通过 detach 把待 trace 函数的重复参数变为 leaf tensor: 缺点是待 trace 函数不能改变重复参数,例如在重复 tensor 上调用 in-place 算子;
  2. 把重复的参数从函数签名中移除: 捕获的计算图是针对重复参数特化的版本;

Joint Graph Link to heading

有了基于 __torch_dispatch__ 的 tracing 机制,AOTAutograd 就可以 trace 联合正向传播和反向传播计算图。这里的逻辑比较直接,如果用户要想优化的正向传播函数是 fn,AOTAutograd 则构建并 trace 一个 joint_forward_backward 函数,其中调用正向传播函数 fn 之后,再调用 torch.autograd.grad 执行反向传播。核心逻辑如 forward_or_joint() 所示:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_functorch/aot_autograd.py#L920
outs = fn(*primals_after_cloning)

if grad_primals:
    with fx_traceback.preserve_node_meta():
        backward_out = torch.autograd.grad(
            needed_outs,
            grad_primals,
            grad_outputs=needed_tangents,
            allow_unused=True,
        )

AOTAutograd 通过 make_fx 来 trace 该 joint_forward_backward 函数,对于其中的每个算子,都会触发 __torch_dispatch__,从 tensor 获取 proxy,在 fx.Graph 中创建算子对应的 proxy,类型为 call_function,目标是算子本身,然后以真实 tensor 运行算子,并把结果 tensor 绑定到 proxy 上:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/fx/experimental/proxy_tensor.py#L259
f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))

proxy_args, proxy_kwargs = pytree.tree_map_only(
    (SymInt, SymFloat, SymBool),
    fetch_sym_proxy(proxy_mode.tracer),
    pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
)

proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
                                           name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))

out = func(*args, **kwargs)

track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)

如此往复,直到 AOTAutograd trace 完正向传播和反向传播中的所有算子,得到一张完整的 joint graph。

Partition Link to heading

AOTAutograd 用 partition_fn 把 joint graph 划分为正向传播计算图和反向传播计算图,目前内置了两种 partition_fn:

  • default_partition: 模拟了 PyTorch 的默认行为,找出从 forward 的输入到 forward 的输出的所有算子输出,其中被 backward 用到的 tensor 也作为 forward 的输出,是 forward 保留给 backward 的 tensor;
  • min_cut_rematerialization_partition: 通过在 backward 中引入重算,减少 forward 给 backward 保留的 tensor,这种重算的思路与 gradient/activation checkpointing 一致;

此前所述案例对应的 joint graph 为:

def forward(self, primals, tangents):
    primals_1: f32[2, 4], primals_2: f32[2, 4], primals_3: f32[2, 4], primals_4: f32[2, 4], tangents_1: f32[2, 4], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # No stacktrace found for following nodes
    add: f32[2, 4] = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1: f32[2, 4] = torch.ops.aten.add.Tensor(add, primals_3);  add = primals_3 = None
    add_2: f32[2, 4] = torch.ops.aten.add.Tensor(add_1, primals_4);  add_1 = primals_4 = None
    cos: f32[2, 4] = torch.ops.aten.cos.default(add_2)
    cos_1: f32[2, 4] = torch.ops.aten.cos.default(cos)
    sin: f32[2, 4] = torch.ops.aten.sin.default(cos);  cos = None
    neg: f32[2, 4] = torch.ops.aten.neg.default(sin);  sin = None
    mul: f32[2, 4] = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None

    #
    sin_1: f32[2, 4] = torch.ops.aten.sin.default(add_2);  add_2 = None
    neg_1: f32[2, 4] = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1: f32[2, 4] = torch.ops.aten.mul.Tensor(mul, neg_1);  mul = neg_1 = None
    return pytree.tree_unflatten([cos_1, mul_1, mul_1, mul_1, mul_1], self._out_spec)

其中必须由 backward 计算的节点是 mulmul_1,因为它们直接依赖于 tangents,也就是 backward 的输入,因此它们也叫做 tangent’s closure。在此基础上,我们可以有很多种切分 joint graph 的方法,比如:

  • forward 给 backward 保留 {neg, neg_1},从而 backward 计算 {mul, mul_1} 就可以得到最终结果;
  • forward 给 backward 保留 {primals_1, primals_2, primals_3, primals_4},从而 backward 需要计算 joint graph 中的所有节点;

那么如何选择 forward 保留给 backward 的算子,从而使内存需求最小?答案是通过求解 最大流最小割(max-flow/min-cut) 问题,流程如下:

  • 在源节点 (source) 和 primals 之间各添加一条边,在所有的 tangent’s closure 和目标节点 (sink) 之间各添加一条边,它们组成了一张从 source 到 sink 的有向图,边上的权重是 tensor size;
  • 我们需要找到一个合适的切分方法,把这个有向图分成两部分,使得 source 子图到 target 子图之间边上的权重之和最小,这是一个最小割问题;
  • 最小割问题的对等问题是最大流问题,已经有标准的解法,直接在该有向图上运行 max-flow 算法即可得到最佳划分方法;

关于最大流最小割的更多细节,可以参考 最大流最小割(Max-flow min-cut)定理

通过解最大流最小割问题,我们找到了最佳切分方法,即 forward 为 backward 保留 {add_2},需要的内存是最少的,最终生成的正向传播与反向传播为:

def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add.Tensor(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add.Tensor(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos.default(add_2)
    cos_1 = torch.ops.aten.cos.default(cos);  cos = None
    return [cos_1, add_2]

def forward(self, add_2, tangents_1):
    cos = torch.ops.aten.cos.default(add_2)
    sin = torch.ops.aten.sin.default(cos);  cos = None
    neg = torch.ops.aten.neg.default(sin);  sin = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin.default(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg.default(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]

可以看到,使用 default_partition 时正向传播为反向传播保留了 cosadd_2,使用 min_cut_rematerialization_partition 时正向传播只为反向传播保留了 add_2,反向传播中通过重算获得 cos,从而降低了内存需求。

AOTAutograd 的最后是调用用户指定的编译器分别编译正向传播计算图和反向传播计算图,这里不再赘述。

AOTAutograd 默认 static shape。此外,aot_function 内置 编译缓存 (compilation cache),根据输入张量的特性决定是否需要重新编译 (recompilation),现阶段缓存只有一个条目。

控制语句 Link to heading

不是所有的函数都可以被 AOTAutograd 优化,这主要受限于 make_fx 的 tracing 机制。比如,data-dependent control flow 在 tracing 中是不支持的,下面的代码在 tracing 过程中会抛出异常,因为 if 语句依赖于 tensor 的值:

def fn(x):
    norm = x.norm()
    if norm > 5:
        x = x / norm
    return x

常规的循环在 tracing 的过程中被展开:

def fn(x):
    for i in range(1, 4):
        x = x * i
    return x

使用 AOTAutograd 优化上面这个函数会得到:

def forward(self, primals_1):
    mul = torch.ops.aten.mul.Tensor(primals_1, 0);  primals_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, 1);  mul = None
    mul_2 = torch.ops.aten.mul.Tensor(mul_1, 2);  mul_1 = None
    mul_3 = torch.ops.aten.mul.Tensor(mul_2, 3);  mul_2 = None
    return [mul_3]

def forward(self, tangents_1):
    mul_4 = torch.ops.aten.mul.Tensor(tangents_1, 3);  tangents_1 = None
    mul_5 = torch.ops.aten.mul.Tensor(mul_4, 2);  mul_4 = None
    mul_6 = torch.ops.aten.mul.Tensor(mul_5, 1);  mul_5 = None
    mul_7 = torch.ops.aten.mul.Tensor(mul_6, 0);  mul_6 = None
    return [mul_7]

函数调用在 tracing 的过程中也被展开,例如下面的例子:

def fn(x, n):
    if n > 1:
        return fn(x, n-1) * n
    else:
        return x

x = torch.randn(8, requires_grad=True)
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn,
    bw_compiler=compiler_fn)
y = aot_print_fn(x, 4)
loss = y.sum()
loss.backward()

使用 AOTAutograd 优化上面这个函数会得到:

def forward(self, primals_1, primals_2):
    mul = torch.ops.aten.mul.Tensor(primals_1, 2);  primals_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, 3);  mul = None
    mul_2 = torch.ops.aten.mul.Tensor(mul_1, 4);  mul_1 = None
    return [mul_2]

def forward(self, tangents_1):
    mul_3 = torch.ops.aten.mul.Tensor(tangents_1, 4);  tangents_1 = None
    mul_4 = torch.ops.aten.mul.Tensor(mul_3, 3);  mul_3 = None
    mul_5 = torch.ops.aten.mul.Tensor(mul_4, 2);  mul_4 = None
    return [mul_5, None]

调用栈 Link to heading

下面是 AOTAutograd 通过 __torch_dispatch__ 为算子创建 FX Node 的函数调用栈 (省略了 C/C++ 部分的 dispatch),有兴趣的朋友可以参考其中的关键实现:

总结 Link to heading

  • AOTAutograd 利用了 __torch_dispatch__ 机制通过 tracing 提前得到联合正向传播和反向传播计算图;
  • 经过 __torch_dispatch__ trace 得到的是最内层的 ATen 算子,AOTAutograd 将其保存在 FX Graph 中;
  • 如果用于 tracing 的 tensors 中有重复,那么通过 make_fx 得到的计算图与预期不符,AOTAutograd 会在 tracing 前去重;
  • AOTAutograd 用 partition_fn 把 trace 得到的 joint graph 划分为 foward graph 和 backward graph;
  • min_cut_rematerialization_partition 通过求解最大流/最小割问题最小化正向传播保留给反向传播的 tensor;
  • make_fx 的 tracing 不支持 data-dependent control flow,循环、函数调用在 tracing 后被展开;

参考文献 Link to heading