简介 Link to heading

PyTorch 最大的优点是其灵活性,但缺点也很显著,PyTorch 生态对图编译器 (graph compiler) 的支持非常有限,现有的基于 TorchScript 的方案非常难用。为了解决这些问题,PyTorch 2.0 引入了 torch.compile,它包含几个新的组件: TorchDynamo, AOTAutograd, PrimTorch, TorchInductor。其中,TorchDynamo 用于从用户的 PyTorch 代码中以最小的代价 捕获计算图,更多关于 PyTorch 2.0 的简介请参考 PyTorch 2.0

PyTorch 2.0

TorchDynamo 与其他 PyTorch 代码优化工具的对比:

  • torch.jit.trace(): trace 失败时报错,不允许存在不支持的算子;
  • torch.jit.script(): 需要添加类型注释(type annotation)并移除非 PyTorch 代码;
  • torch.fx.symbolic_trace():无法处理控制流,也没有 fallback 机制,trace 失败时中断执行;
  • torch._dynamo: 默认生成部分图,设置 nopython=True 后在满足要求时可以生成全图,部分图不需要修改用户代码;

TorchDynamo 的局限性:

  • 不支持生成器(generator);
  • Python 字节码兼容性和可移植性问题;

为了熟悉 TorchDynamo 的实现原理,本文用一个简单的例子,跟踪它在 PyTorch 2.0 中的执行过程。本次源码剖析所用软硬件系统如下:

Item Value
Docker Image nvcr.io/nvidia/pytorch:23.04-py3
PyTorch Version 2.1.0a0+fe05266
GPU A100-SXM4-80GB
Date 2023/05/13

源码剖析所用案例如下:

#!/usr/bin/env python

from typing import List
import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print(">>> my_compiler() invoked:")
    print(">>> FX graph:")
    gm.graph.print_tabular()
    print(f">>> Code:\n{gm.code}")
    return gm.forward  # return a python callable

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def test(a, b):
    for i in range(4):
        toy_example(a, b * (-1) ** i)

if __name__ == "__main__":
    a, b = torch.randn(10), torch.ones(10)
    toy_example = torch.compile(toy_example, backend=my_compiler)
    test(a, b)

执行上述代码,得到如下输出:

$ python test.py
>>> my_compiler() invoked:
>>> FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f1c83e6c720>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}
>>> Code:
def forward(self, a : torch.Tensor, b : torch.Tensor):
    abs_1 = torch.abs(a)
    add = abs_1 + 1;  abs_1 = None
    truediv = a / add;  a = add = None
    sum_1 = b.sum();  b = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)

>>> my_compiler() invoked:
>>> FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}
>>> Code:
def forward(self, b : torch.Tensor, x : torch.Tensor):
    mul = x * b;  x = b = None
    return (mul,)

>>> my_compiler() invoked:
>>> FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}
>>> Code:
def forward(self, b : torch.Tensor, x : torch.Tensor):
    mul = b * -1;  b = None
    mul_1 = x * mul;  x = mul = None
    return (mul_1,)

在案例中,toy_example() 包含一个 if 语句,自定义的后端编译器 my_compiler() 打印出 FX graph 和生成的 Python 代码。可以看到函数 toy_example()if 语句处被分解为 3 张计算图,分别是 if 前的 x = a / (torch.abs(a) + 1); b.sum() < 0ifFalse 时的 x * b,以及 ifTrue 时的 b = b * -1; x * b

初始化 Link to heading

PyTorch 2.0 的入口是 torch.compile(),它允许用户通过 backend 指定后端编译器,可以是 PyTorch 内置的编译器,例如 inductor,也可以是用户自定义的编译器函数,案例中的 my_compiler() 既是如此,它接受一个 FX Graph,返回一个编译好的函数。torch.compile() 实现在 torch/__init__.py#L1385,它只是对 torch._dynamo.optimize() 的简单包装,定义在 torch/_dynamo/eval_frame.py#L443nopython 参数指定在 TorchDynamo 抓取计算图的过程中 是否允许出现 graph break,比如 TorchDynamo 无法支持的某些 if 语句,为 True 时碰到 graph break 会报错,为 False 时碰到 graph break 则 返回 Python 解释器nopython 的默认参数为 Fasle。参数 dynamic 指定 是否允许 dynamic shape,默认情况为 False

torch._dynamo.optimize() 只是通过 _optimize_catch_errors() 返回了一个 OptimizeContext 对象。因此,经 torch.compile 装饰过的函数成为了一个 OptimizeContext,其中 catch_errors() 成为了 OptimizeContext 中的一个回调函数 (callback)。

由此可见,执行 torch.compile(toy_example, backend=my_compiler) 只是做了一些初始化工作,并没有实际的编译。值得一提的是,在此初始化阶段 TorchDynamo 通过 torch._dynamo.disable() 为很多 PyTorch 内部函数禁用了 TorchDynamo。因此,TorchDynamo 在遇到这些函数时会产生 graph break。详细过程参考 stack:

PEP 523 Link to heading

TorchDynamo

TorchDynamo 的 编译过程发生在将要执行前,它是一个 JIT 编译器,对于案例来说,在执行 toy_example(a, b * (-1) ** i) 时开始编译,因此此时 toy_example 已经被修改为 _fn

TorchDynamo 的入口是 PEP 523。熟悉 CPython 的朋友知道,CPython 在执行 Python 函数前会把 Python 函数编译为字节码。在 Python 虚拟机 (PVM) 中有一个非常重要的函数 _PyEval_EvalFrameDefault,它的功能是在 PVM 中逐条执行编译好的字节码。PEP 523 的功能是给 Python 开发者提供一个接口,让用户在 Python 虚拟机执行字节码前获得字节码,从而可以在 Python 中实现 即时编译器 (JIT Compiler) 的功能。

TorchDynamo 正是通过 PEP 523 把 TorchDynamo 的编译逻辑引入到 Python 虚拟机中,通过 CPython 提供的 API _PyInterpreterState_SetEvalFrameFunc() 把 CPython 用于执行字节码的函数替换为 custom_eval_frame_shim:

inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
      &custom_eval_frame_shim) {
    _PyInterpreterState_SetEvalFrameFunc(
        tstate->interp, &custom_eval_frame_shim);
  }
#else
  if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
    // First call
    tstate->interp->eval_frame = &custom_eval_frame_shim;
  }
#endif
}

详细的函数调用栈为:

因此,随后在 _fn 中执行用户想要编译的函数时,会触发 custom_eval_frame_shim:

try:
    return fn(*args, **kwargs)
finally:
    set_eval_frame(prior)
    dynamic_ctx.__exit__(None, None, None)
    backend_ctx.__exit__(None, None, None)

_custom_eval_frame() 中,通过 lookup(extra, frame, NULL) 检查缓存 (cache) 中是否已经存在编译好的的函数,存在则直接通过 eval_custom_code() 执行编译好的函数,从而避免再次编译相同的函数,否则则通过 call_callback(callback, frame, cache_size(extra)) 来调用此前设置好的回调函数来编译用户的函数,编译好后通过 set_extra(frame->f_code, extra) 将编译好的函数存放在用户的 Python frame 中。

因为初始化阶段设置的回调函数是 catch_errors(),所以执行 call_callback 会来到这里。需要注意的是,因为 catch_errors() 也是一个 Python 函数,而 PEP 523 设置的 eval frame 函数依然有效,所以此后的每个 Python 函数都要通过 custom_eval_frame_shim() 来执行。例如,此时的函数调用栈为:

_convert_frame_assert() 中包含了很多检查,其中有两个比较重要,一是 TorchDynamo 不支持生成器,二是 如果缓存大小超过配置会有警告信息,缓存大小由 torch._dynamo.config.cache_size_limit 指定,默认值为 64,含义是对于同一个 Python 函数,如果函数的输入张量信息组合变化超过 64 种,TorchDynamo 则不会继续编译用户指定的函数,特别是 TorchDynamo 以 static shape 模式运行、而用户的模型是 dynamic shape 时。

字节码 Link to heading

TorchDynamo 用于编译用户函数的入口在 _compile(),此时的函数调用栈为 (移除了额外的 custom_eval_frame_shim):

_compile() 中有一个循环,其中调用了 transform_code_object() 来转换用户想要优化的函数。在没有发生异常的情况下,该循环只执行一次;遇到 exc.RestartAnalysis 异常会再次尝试 transform_code_object();遇到 exc.SkipFrame 时直接返回,表明跳过该用户函数的优化:

for attempt in itertools.count():
    try:
        out_code = transform_code_object(code, transform)
        orig_code_map[out_code] = code
        break
    except exc.RestartAnalysis:
        log.debug("Restarting analysis ...")
        if attempt > 100:
            unimplemented("100+ RestartAnalysis() calls")
    except exc.SkipFrame as e:
        log.debug(
            f"Skipping frame {e} {code.co_name} \
            {code.co_filename} {code.co_firstlineno}"
        )
        if one_graph:
            log.debug("No graph captured with one_graph=True")
        return None
output_codes.add(out_code)

transform_code_object() 中的 codePyCodeObject 类型,其中包含 待编译的 Python 函数在由 Python 虚拟机编译为字节码后能确定的诸多信息,包括编译好的字节码 co_code、常量 co_consts、用到的符号 co_names 等等,这些都是 编译时 (compile time) 确定的信息,在 torch/_dynamo/bytecode_transformation.py#L524 由 TorchDynamo 捕获:

ipdb> pprint(code_options)
{'co_argcount': 2,
 'co_cellvars': (),
 'co_code': b'|\x00t\x00\xa0\x01|\x00\xa1\x01d\x01\x17\x00\x1b\x00}\x02|\x01'
            b'\xa0\x02\xa1\x00d\x02k\x00r&|\x01d\x03\x14\x00}\x01|\x02'
            b'|\x01\x14\x00S\x00',
 'co_consts': (None, 1, 0, -1),
 'co_filename': 'test.py',
 'co_firstlineno': 13,
 'co_flags': 67,
 'co_freevars': (),
 'co_kwonlyargcount': 0,
 'co_lnotab': b'\x00\x01\x12\x01\x0c\x01\x08\x01',
 'co_name': 'toy_example',
 'co_names': ('torch', 'abs', 'sum'),
 'co_nlocals': 3,
 'co_posonlyargcount': 0,
 'co_stacksize': 4,
 'co_varnames': ('a', 'b', 'x')}

随后 TorchDynamo 通过 cleaned_instructions() 来预处理编译好的字节码指令,获取字节码指令通过 Python 标准库 dis.get_instructions(code) 实现,案例中的 toy_example() 经过反汇编后得到的字节码指令如下:

ipdb> dis.dis(code)
 14           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)

 15          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       38

 16          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)

 17     >>   38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE

cleaned_instructions() 对捕获到的字节码进行清洗,例如对跳转指令做标准化处理,经过清洗后的指令以结构化的数据表示为 Instruction:

ipdb> pprint(instructions)
[Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=0, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=116, opname='LOAD_GLOBAL', arg=0, argval='torch', offset=2, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=106, opname='LOAD_ATTR', arg=1, argval='abs', offset=4, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=6, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=8, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=100, opname='LOAD_CONST', arg=1, argval=1, offset=10, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=23, opname='BINARY_ADD', arg=None, argval=None, offset=12, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=27, opname='BINARY_TRUE_DIVIDE', arg=None, argval=None, offset=14, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='x', offset=16, starts_line=14, is_jump_target=False, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=18, starts_line=15, is_jump_target=False, target=None),
 Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='sum', offset=20, starts_line=15, is_jump_target=False, target=None),
 Instruction(opcode=131, opname='CALL_FUNCTION', arg=0, argval=0, offset=22, starts_line=15, is_jump_target=False, target=None),
 Instruction(opcode=100, opname='LOAD_CONST', arg=2, argval=0, offset=24, starts_line=15, is_jump_target=False, target=None),
 Instruction(opcode=107, opname='COMPARE_OP', arg=0, argval='<', offset=26, starts_line=15, is_jump_target=False, target=None),
 Instruction(opcode=114, opname='POP_JUMP_IF_FALSE', arg=38, argval=38, offset=28, starts_line=15, is_jump_target=False, target=Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=38, starts_line=17, is_jump_target=True, target=None)),
 Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=30, starts_line=16, is_jump_target=False, target=None),
 Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=32, starts_line=16, is_jump_target=False, target=None),
 Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=34, starts_line=16, is_jump_target=False, target=None),
 Instruction(opcode=125, opname='STORE_FAST', arg=1, argval='b', offset=36, starts_line=16, is_jump_target=False, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=38, starts_line=17, is_jump_target=True, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=40, starts_line=17, is_jump_target=False, target=None),
 Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=42, starts_line=17, is_jump_target=False, target=None),
 Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=44, starts_line=17, is_jump_target=False, target=None)]

Hello World 的字节码 Link to heading

为了便于理解 Python 虚拟机、字节码和 TorchDynamo 的行为,下面用 hello() 函数简要介绍下 Python 字节码的行为:

import dis

def hello():
    print("Hello, world!")

for k in ["co_names", "co_varnames", "co_consts"]:
    print(k, getattr(hello.__code__, k))
print(dis.dis(hello))

"""
co_names ('print',)
co_varnames ()
co_consts (None, 'Hello, world!')

0 LOAD_GLOBAL              0 (print)
2 LOAD_CONST               1 ('Hello, world!')
4 CALL_FUNCTION            1
6 POP_TOP
8 LOAD_CONST               0 (None)
10 RETURN_VALUE
"""

其中包含了 6 条 Python 字节码,它们的功能如下:

  • LOAD_GLOBAL 0: 从 f_builtinsf_globals 中加载由下标 0 所引用的全局对象,把它压到数据栈上;
  • LOAD_CONST 1: 从 co_consts 中加载由下标 1 所引用的常量,把它压到数据栈上;
  • CALL_FUNCTION 1: 从栈顶出栈 1 个元素作为函数参数,再出栈一个元素作为被调函数,调用该函数并把返回值压到数据栈上;
  • POP_TOP: 从栈顶移除一个元素;
  • LOAD_CONST 0: 从 co_consts 中加载由下标 0 所引用的常量,把它压到数据栈上;
  • RETURN_VALUE: 从栈顶出栈 1 个元素,把它作为返回值返回给主调函数;

Python 虚拟机是 Stack Machine,它维护了 3 个 stack

  • Call Stack: 其中的条目是 Python frame,类似 C 的函数调用栈;
  • Evaluation Stack (or Data Stack): 每个 Python frame 都有一个 evaluation stack,执行 Python 字节码时的数据由该 stack 管理,这与常见的 Register Machine 有所区别;
  • Block Stack: 每个 Python frame 都有一个 block stack,目的是跟踪 Python 中的控制结构,例如循环、try / exceptwith 语句等,进入/退出这类控制结构时会有对应的条目被 push/pop。Block stack 帮助 Python 在任意时刻都知道当前活跃的 block,continuebreak 会影响当前活跃的 block;

更多 Python 字节码和虚拟机的细节可以参考 _PyEval_EvalFrameDefault