简介 Link to heading
PyTorch 最大的优点是其灵活性,但缺点也很显著,PyTorch 生态对图编译器 (graph compiler) 的支持非常有限,现有的基于 TorchScript 的方案非常难用。为了解决这些问题,PyTorch 2.0 引入了 torch.compile
,它包含几个新的组件: TorchDynamo, AOTAutograd, PrimTorch, TorchInductor。其中,TorchDynamo 用于从用户的 PyTorch 代码中以最小的代价 捕获计算图,更多关于 PyTorch 2.0 的简介请参考 PyTorch 2.0。
TorchDynamo 与其他 PyTorch 代码优化工具的对比:
torch.jit.trace()
: trace 失败时报错,不允许存在不支持的算子;torch.jit.script()
: 需要添加类型注释(type annotation)并移除非 PyTorch 代码;torch.fx.symbolic_trace()
:无法处理控制流,也没有 fallback 机制,trace 失败时中断执行;torch._dynamo
: 默认生成部分图,设置nopython=True
后在满足要求时可以生成全图,部分图不需要修改用户代码;
TorchDynamo 的局限性:
- 不支持生成器(generator);
- Python 字节码兼容性和可移植性问题;
为了熟悉 TorchDynamo 的实现原理,本文用一个简单的例子,跟踪它在 PyTorch 2.0 中的执行过程。本次源码剖析所用软硬件系统如下:
Item | Value |
---|---|
Docker Image | nvcr.io/nvidia/pytorch:23.04-py3 |
PyTorch Version | 2.1.0a0+fe05266 |
GPU | A100-SXM4-80GB |
Date | 2023/05/13 |
源码剖析所用案例如下:
#!/usr/bin/env python
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print(">>> my_compiler() invoked:")
print(">>> FX graph:")
gm.graph.print_tabular()
print(f">>> Code:\n{gm.code}")
return gm.forward # return a python callable
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def test(a, b):
for i in range(4):
toy_example(a, b * (-1) ** i)
if __name__ == "__main__":
a, b = torch.randn(10), torch.ones(10)
toy_example = torch.compile(toy_example, backend=my_compiler)
test(a, b)
执行上述代码,得到如下输出:
$ python test.py
>>> my_compiler() invoked:
>>> FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f1c83e6c720> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
>>> Code:
def forward(self, a : torch.Tensor, b : torch.Tensor):
abs_1 = torch.abs(a)
add = abs_1 + 1; abs_1 = None
truediv = a / add; a = add = None
sum_1 = b.sum(); b = None
lt = sum_1 < 0; sum_1 = None
return (truediv, lt)
>>> my_compiler() invoked:
>>> FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
>>> Code:
def forward(self, b : torch.Tensor, x : torch.Tensor):
mul = x * b; x = b = None
return (mul,)
>>> my_compiler() invoked:
>>> FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
>>> Code:
def forward(self, b : torch.Tensor, x : torch.Tensor):
mul = b * -1; b = None
mul_1 = x * mul; x = mul = None
return (mul_1,)
在案例中,toy_example()
包含一个 if
语句,自定义的后端编译器 my_compiler()
打印出 FX graph 和生成的 Python 代码。可以看到函数 toy_example()
在 if
语句处被分解为 3 张计算图,分别是 if
前的 x = a / (torch.abs(a) + 1); b.sum() < 0
,if
为 False
时的 x * b
,以及 if
为 True
时的 b = b * -1; x * b
。
初始化 Link to heading
PyTorch 2.0 的入口是 torch.compile()
,它允许用户通过 backend
指定后端编译器,可以是 PyTorch 内置的编译器,例如 inductor
,也可以是用户自定义的编译器函数,案例中的 my_compiler()
既是如此,它接受一个 FX Graph,返回一个编译好的函数。torch.compile()
实现在 torch/__init__.py#L1385,它只是对 torch._dynamo.optimize()
的简单包装,定义在 torch/_dynamo/eval_frame.py#L443。nopython
参数指定在 TorchDynamo 抓取计算图的过程中 是否允许出现 graph break,比如 TorchDynamo 无法支持的某些 if
语句,为 True
时碰到 graph break 会报错,为 False
时碰到 graph break 则 返回 Python 解释器,nopython
的默认参数为 Fasle
。参数 dynamic
指定 是否允许 dynamic shape,默认情况为 False
。
torch._dynamo.optimize()
只是通过 _optimize_catch_errors() 返回了一个 OptimizeContext 对象。因此,经 torch.compile
装饰过的函数成为了一个 OptimizeContext
,其中 catch_errors() 成为了 OptimizeContext
中的一个回调函数 (callback)。
由此可见,执行 torch.compile(toy_example, backend=my_compiler)
只是做了一些初始化工作,并没有实际的编译。值得一提的是,在此初始化阶段 TorchDynamo 通过 torch._dynamo.disable()
为很多 PyTorch 内部函数禁用了 TorchDynamo。因此,TorchDynamo 在遇到这些函数时会产生 graph break。详细过程参考 stack:
- [P005] > torch/_dynamo/eval_frame.py#L886 [New]
- [P004] > torch/_dynamo/eval_frame.py#L172 [New]
- [P003] > torch/_dynamo/eval_frame.py#L317 [New]
- [P002] > torch/_dynamo/eval_frame.py#L397 [New]
- [P001] > torch/_dynamo/eval_frame.py#L443 [New]
- [P000] > torch/init.py#L1385 [New]
PEP 523 Link to heading
TorchDynamo 的 编译过程发生在将要执行前,它是一个 JIT 编译器,对于案例来说,在执行 toy_example(a, b * (-1) ** i)
时开始编译,因此此时 toy_example
已经被修改为 _fn。
TorchDynamo 的入口是 PEP 523。熟悉 CPython 的朋友知道,CPython 在执行 Python 函数前会把 Python 函数编译为字节码。在 Python 虚拟机 (PVM) 中有一个非常重要的函数 _PyEval_EvalFrameDefault,它的功能是在 PVM 中逐条执行编译好的字节码。PEP 523 的功能是给 Python 开发者提供一个接口,让用户在 Python 虚拟机执行字节码前获得字节码,从而可以在 Python 中实现 即时编译器 (JIT Compiler) 的功能。
TorchDynamo 正是通过 PEP 523 把 TorchDynamo 的编译逻辑引入到 Python 虚拟机中,通过 CPython 提供的 API _PyInterpreterState_SetEvalFrameFunc()
把 CPython 用于执行字节码的函数替换为 custom_eval_frame_shim
:
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
&custom_eval_frame_shim) {
_PyInterpreterState_SetEvalFrameFunc(
tstate->interp, &custom_eval_frame_shim);
}
#else
if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
// First call
tstate->interp->eval_frame = &custom_eval_frame_shim;
}
#endif
}
详细的函数调用栈为:
- [C005] > torch/csrc/dynamo/eval_frame.c#L366 [New]
- [C004] > torch/csrc/dynamo/eval_frame.c#L740 [New]
- [C003] > torch/csrc/dynamo/eval_frame.c#L758 [New]
- [C002] > torch/csrc/dynamo/eval_frame.c#L784 [New]
- [P001] > torch/_dynamo/eval_frame.py#L233 [New]
- [P000] > test.py#L17:test [New]
因此,随后在 _fn
中执行用户想要编译的函数时,会触发 custom_eval_frame_shim
:
try:
return fn(*args, **kwargs)
finally:
set_eval_frame(prior)
dynamic_ctx.__exit__(None, None, None)
backend_ctx.__exit__(None, None, None)
在 _custom_eval_frame() 中,通过 lookup(extra, frame, NULL)
检查缓存 (cache) 中是否已经存在编译好的的函数,存在则直接通过 eval_custom_code()
执行编译好的函数,从而避免再次编译相同的函数,否则则通过 call_callback(callback, frame, cache_size(extra))
来调用此前设置好的回调函数来编译用户的函数,编译好后通过 set_extra(frame->f_code, extra)
将编译好的函数存放在用户的 Python frame 中。
因为初始化阶段设置的回调函数是 catch_errors(),所以执行 call_callback
会来到这里。需要注意的是,因为 catch_errors()
也是一个 Python 函数,而 PEP 523 设置的 eval frame 函数依然有效,所以此后的每个 Python 函数都要通过 custom_eval_frame_shim() 来执行。例如,此时的函数调用栈为:
- [P009] > torch/_dynamo/eval_frame.py#L362 [New]
- [C008] > torch/csrc/dynamo/eval_frame.c#L355 [New]
- [C007] > torch/csrc/dynamo/eval_frame.c#L621
- [C006] > torch/csrc/dynamo/eval_frame.c#L346
- [C005] > torch/csrc/dynamo/eval_frame.c#L399 [New]
- [C004] > torch/csrc/dynamo/eval_frame.c#L640 [New]
- [C003] > torch/csrc/dynamo/eval_frame.c#L621 [New]
- [C002] > torch/csrc/dynamo/eval_frame.c#L346 [New]
- [P001] > torch/_dynamo/eval_frame.py#L233 [New]
- [P000] > test.py#L17:test [New]
_convert_frame_assert() 中包含了很多检查,其中有两个比较重要,一是 TorchDynamo 不支持生成器,二是 如果缓存大小超过配置会有警告信息,缓存大小由 torch._dynamo.config.cache_size_limit
指定,默认值为 64,含义是对于同一个 Python 函数,如果函数的输入张量信息组合变化超过 64 种,TorchDynamo 则不会继续编译用户指定的函数,特别是 TorchDynamo 以 static shape 模式运行、而用户的模型是 dynamic shape 时。
字节码 Link to heading
TorchDynamo 用于编译用户函数的入口在 _compile(),此时的函数调用栈为 (移除了额外的 custom_eval_frame_shim
):
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
- [P025] > torch/_dynamo/utils.py#L158 [New]
- [P021] > torch/_dynamo/convert_frame.py#L200 [New]
- [P017] > torch/_dynamo/convert_frame.py#L96 [New]
- [P013] > torch/_dynamo/convert_frame.py#L403 [New]
- [P009] > torch/_dynamo/eval_frame.py#L362
- [C005] > torch/csrc/dynamo/eval_frame.c#L399
- [C004] > torch/csrc/dynamo/eval_frame.c#L640
- [C003] > torch/csrc/dynamo/eval_frame.c#L621
- [C002] > torch/csrc/dynamo/eval_frame.c#L346
- [P001] > torch/_dynamo/eval_frame.py#L233 [New]
- [P000] > test.py#L17:test [New]
_compile()
中有一个循环,其中调用了 transform_code_object() 来转换用户想要优化的函数。在没有发生异常的情况下,该循环只执行一次;遇到 exc.RestartAnalysis
异常会再次尝试 transform_code_object()
;遇到 exc.SkipFrame
时直接返回,表明跳过该用户函数的优化:
for attempt in itertools.count():
try:
out_code = transform_code_object(code, transform)
orig_code_map[out_code] = code
break
except exc.RestartAnalysis:
log.debug("Restarting analysis ...")
if attempt > 100:
unimplemented("100+ RestartAnalysis() calls")
except exc.SkipFrame as e:
log.debug(
f"Skipping frame {e} {code.co_name} \
{code.co_filename} {code.co_firstlineno}"
)
if one_graph:
log.debug("No graph captured with one_graph=True")
return None
output_codes.add(out_code)
transform_code_object()
中的 code
是 PyCodeObject 类型,其中包含 待编译的 Python 函数在由 Python 虚拟机编译为字节码后能确定的诸多信息,包括编译好的字节码 co_code
、常量 co_consts
、用到的符号 co_names
等等,这些都是 编译时 (compile time) 确定的信息,在 torch/_dynamo/bytecode_transformation.py#L524 由 TorchDynamo 捕获:
ipdb> pprint(code_options)
{'co_argcount': 2,
'co_cellvars': (),
'co_code': b'|\x00t\x00\xa0\x01|\x00\xa1\x01d\x01\x17\x00\x1b\x00}\x02|\x01'
b'\xa0\x02\xa1\x00d\x02k\x00r&|\x01d\x03\x14\x00}\x01|\x02'
b'|\x01\x14\x00S\x00',
'co_consts': (None, 1, 0, -1),
'co_filename': 'test.py',
'co_firstlineno': 13,
'co_flags': 67,
'co_freevars': (),
'co_kwonlyargcount': 0,
'co_lnotab': b'\x00\x01\x12\x01\x0c\x01\x08\x01',
'co_name': 'toy_example',
'co_names': ('torch', 'abs', 'sum'),
'co_nlocals': 3,
'co_posonlyargcount': 0,
'co_stacksize': 4,
'co_varnames': ('a', 'b', 'x')}
随后 TorchDynamo 通过 cleaned_instructions() 来预处理编译好的字节码指令,获取字节码指令通过 Python 标准库 dis.get_instructions(code)
实现,案例中的 toy_example()
经过反汇编后得到的字节码指令如下:
ipdb> dis.dis(code)
14 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
15 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
16 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
17 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
cleaned_instructions()
对捕获到的字节码进行清洗,例如对跳转指令做标准化处理,经过清洗后的指令以结构化的数据表示为 Instruction:
ipdb> pprint(instructions)
[Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=0, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=0, argval='torch', offset=2, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=106, opname='LOAD_ATTR', arg=1, argval='abs', offset=4, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=6, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=8, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=100, opname='LOAD_CONST', arg=1, argval=1, offset=10, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=23, opname='BINARY_ADD', arg=None, argval=None, offset=12, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=27, opname='BINARY_TRUE_DIVIDE', arg=None, argval=None, offset=14, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='x', offset=16, starts_line=14, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=18, starts_line=15, is_jump_target=False, target=None),
Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='sum', offset=20, starts_line=15, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=0, argval=0, offset=22, starts_line=15, is_jump_target=False, target=None),
Instruction(opcode=100, opname='LOAD_CONST', arg=2, argval=0, offset=24, starts_line=15, is_jump_target=False, target=None),
Instruction(opcode=107, opname='COMPARE_OP', arg=0, argval='<', offset=26, starts_line=15, is_jump_target=False, target=None),
Instruction(opcode=114, opname='POP_JUMP_IF_FALSE', arg=38, argval=38, offset=28, starts_line=15, is_jump_target=False, target=Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=38, starts_line=17, is_jump_target=True, target=None)),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=30, starts_line=16, is_jump_target=False, target=None),
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=32, starts_line=16, is_jump_target=False, target=None),
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=34, starts_line=16, is_jump_target=False, target=None),
Instruction(opcode=125, opname='STORE_FAST', arg=1, argval='b', offset=36, starts_line=16, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=38, starts_line=17, is_jump_target=True, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=40, starts_line=17, is_jump_target=False, target=None),
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=42, starts_line=17, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=44, starts_line=17, is_jump_target=False, target=None)]
Hello World 的字节码 Link to heading
为了便于理解 Python 虚拟机、字节码和 TorchDynamo 的行为,下面用 hello()
函数简要介绍下 Python 字节码的行为:
import dis
def hello():
print("Hello, world!")
for k in ["co_names", "co_varnames", "co_consts"]:
print(k, getattr(hello.__code__, k))
print(dis.dis(hello))
"""
co_names ('print',)
co_varnames ()
co_consts (None, 'Hello, world!')
0 LOAD_GLOBAL 0 (print)
2 LOAD_CONST 1 ('Hello, world!')
4 CALL_FUNCTION 1
6 POP_TOP
8 LOAD_CONST 0 (None)
10 RETURN_VALUE
"""
其中包含了 6 条 Python 字节码,它们的功能如下:
LOAD_GLOBAL 0
: 从f_builtins
和f_globals
中加载由下标 0 所引用的全局对象,把它压到数据栈上;LOAD_CONST 1
: 从co_consts
中加载由下标 1 所引用的常量,把它压到数据栈上;CALL_FUNCTION 1
: 从栈顶出栈 1 个元素作为函数参数,再出栈一个元素作为被调函数,调用该函数并把返回值压到数据栈上;POP_TOP
: 从栈顶移除一个元素;LOAD_CONST 0
: 从co_consts
中加载由下标 0 所引用的常量,把它压到数据栈上;RETURN_VALUE
: 从栈顶出栈 1 个元素,把它作为返回值返回给主调函数;
Python 虚拟机是 Stack Machine,它维护了 3 个 stack:
- Call Stack: 其中的条目是 Python frame,类似 C 的函数调用栈;
- Evaluation Stack (or Data Stack): 每个 Python frame 都有一个 evaluation stack,执行 Python 字节码时的数据由该 stack 管理,这与常见的 Register Machine 有所区别;
- Block Stack: 每个 Python frame 都有一个 block stack,目的是跟踪 Python 中的控制结构,例如循环、
try / except
、with
语句等,进入/退出这类控制结构时会有对应的条目被 push/pop。Block stack 帮助 Python 在任意时刻都知道当前活跃的 block,continue
和break
会影响当前活跃的 block;
更多 Python 字节码和虚拟机的细节可以参考 _PyEval_EvalFrameDefault。