Graph Break
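The walkthrough below resumes at offset 28 of toy_example(). For reference, the function being compiled and the backend it is compiled with are presumably close to the canonical TorchDynamo example below, reconstructed from the FX graph and bytecode shown later in this section (the exact test.py used for the trace may differ slightly):

import torch
import torch._dynamo
from typing import List

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # The backend referenced later as my_compiler(): it prints the captured
    # FX graph and returns the GraphModule's forward as the compiled callable.
    gm.graph.print_tabular()
    return gm.forward

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

compiled_toy = torch._dynamo.optimize(my_compiler)(toy_example)
compiled_toy(torch.randn(10), torch.randn(10))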
Offset 28, POP_JUMP_IF_FALSE:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L853
POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
POP_JUMP_IF_FALSE is bound at symbolic_convert.py#L853 to generic_jump(operator.not_, False); the actual implementation lives in inner(). A Python if/else compiles to conditional jump instructions of this kind (the JUMP_IF family), and in TorchDynamo they introduce a Graph Break, which triggers a subgraph compile. inner() first pops one element off the stack, namely TensorVariable(lt), then aggregates the Guards attached to it into the Output Graph. The key part is symbolic_convert.py#L262-L289:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L262-L289
elif (
    isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
):
    # compile a partial subgraph prefix then jump into user code
    if self.has_backedge():
        msg = (
            "Skipping frame because there is a graph break in a for/while loop"
        )
        log.info(msg)
        raise exc.SkipFrame(msg)
    self.push(value)
    log.debug("generic_jump triggered compile")
    self.output.compile_subgraph(
        self,
        reason=GraphCompileReason(
            f"generic_jump {typestr(value)}", [self.frame_summary()]
        ),
    )
    self.pop()
    if_next = self.create_call_resume_at(self.next_instruction)
    push and self.push(value)
    if_jump = self.create_call_resume_at(inst.target)
    self.output.add_output_instructions(
        [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
    )
has_backedge() checks whether the remaining bytecode of the function being compiled contains a backward jump, i.e. a loop. If a Graph Break occurs inside a for/while loop, TorchDynamo skips the frame and raises exc.SkipFrame, which is caught by the _compile() discussed earlier.
Because this POP_JUMP_IF_FALSE cannot be handled symbolically, TorchDynamo pushes the previously popped TensorVariable(lt) back onto the stack and calls OutputGraph.compile_subgraph() to compile a subgraph; the subgraph compile is needed because generic_jump depends on the runtime value of the TensorVariable.
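This data dependence is exactly what shows up as the break reason when the function is inspected with torch._dynamo.explain() (a small sketch; the return format of explain() has changed across versions, so only the textual summary is printed):

import torch
import torch._dynamo as dynamo

# toy_example is the example function shown at the top of this section.
explanation = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
print(explanation[0] if isinstance(explanation, tuple) else explanation)
# Expected: 2 graphs, 1 graph break, break reason "generic_jump TensorVariable()"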
Inside compile_subgraph(), tx.prune_dead_locals() removes local variables that are no longer needed from self.symbolic_locals. It is implemented at symbolic_convert.py#L458 and works by liveness analysis: it collects the local variables that are still used after the current bytecode instruction, and any local outside that set can be dropped; see livevars_analysis(). From POP_JUMP_IF_FALSE onwards only b and x are used, so they are the live variables, and the non-live entries in self.symbolic_locals are discarded. At output_graph.py#L545, stack_values is [TensorVariable(lt), TensorVariable(truediv)], corresponding to b.sum() and x.
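A rough, hand-rolled approximation of what livevars_analysis() computes fits in a few lines (a sketch with a made-up helper name; the real pass in TorchDynamo performs a proper backwards dataflow analysis over the instruction list, while this version simply scans for LOAD_FAST at or after a given offset):

import dis

def roughly_live_locals_after(fn, offset):
    live = set()
    for inst in dis.get_instructions(fn):
        if inst.offset >= offset and inst.opname == "LOAD_FAST":
            live.add(inst.argval)
    return live

# Everything at or after POP_JUMP_IF_FALSE (offset 28 in toy_example) only
# reads b and x, so a can be pruned from symbolic_locals:
print(roughly_live_locals_after(toy_example, 28))   # {'b', 'x'}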
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L544
self.add_output_instructions(
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
)
compile_and_call_fx_graph() generates code from the OutputGraph. It first updates the Guards based on the two output nodes passed in, then create_node() creates an FX Proxy of type output in the FX Graph, completing the FX Graph. remove_unused_graphargs() removes nodes without users from the FX Graph. fx.GraphModule(root, self.graph) creates an fx.GraphModule from the fx.Graph, and recompile() generates the corresponding Python code; the newly compiled function is named __compiled_fn_0(). The process of generating Python code from an FX Graph was broken down in the TorchFX source code walkthrough, so it is not repeated here.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L609
gm = fx.GraphModule(root, self.graph)
gm.recompile()
gm.compile_subgraph_reason = self.compile_subgraph_reason
name = unique_id("__compiled_fn")
assert_no_fake_params_or_buffers(gm)
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)
The FX graph captured so far is:
>>> FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f1c83e6c720> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
The Python function corresponding to __compiled_fn_0() is:
def forward(self, a : torch.Tensor, b : torch.Tensor):
    abs_1 = torch.abs(a)
    add = abs_1 + 1;  abs_1 = None
    truediv = a / add;  a = add = None
    sum_1 = b.sum();  b = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)
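The same graph-to-source step can be reproduced through FX's public API (a sketch that symbolically traces an equivalent function rather than using the graph Dynamo built internally):

import torch
import torch.fx as fx

def prefix(a, b):
    # the part of toy_example before the data-dependent branch
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0

gm = fx.symbolic_trace(prefix)   # builds an fx.Graph wrapped in a GraphModule
gm.recompile()                   # regenerates gm.code / gm.forward from the graph
print(gm.code)                   # Python source comparable to __compiled_fn_0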
call_user_compiler() invokes the user-specified backend compiler to compile the fx.GraphModule; it is implemented at output_graph.py#L644. The fx.GraphModule is compiled with FakeTensor inputs, which fake_example_inputs() derives from the input nodes of the fx.Graph. To ease debugging, the backend compiler is wrapped by debug_wrapper, which ultimately calls the my_compiler() provided by the example code. Any exception raised while calling the backend compiler is caught and re-raised as BackendCompilerFailed; the outer InstructionTranslator catches BackendCompilerFailed at symbolic_convert.py#L600 and propagates it.
@dynamo_timed(phase_name="backend_compile")
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
    # ...
    try:
        # ...
        compiler_fn = self.compiler_fn
        # ...
        if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
            compiled_fn = compiler_fn(gm, self.example_inputs())
        elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
            compiled_fn = compiler_fn(gm, self.example_inputs())
        else:
            compiled_fn = compiler_fn(gm, self.fake_example_inputs())
        _step_logger()(logging.INFO, f"done compiler function {name}")
        assert callable(compiled_fn), "compiler_fn did not return callable"
    except Exception as e:
        raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
            e.__traceback__
        ) from None
    return compiled_fn
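What fake_example_inputs() hands to the backend are FakeTensors: tensors that carry shape, dtype and device metadata but own no real storage, so the backend can specialize on metadata without allocating memory. A minimal sketch using the internal, version-dependent torch._subclasses API:

import torch
from torch._subclasses import FakeTensorMode

# Tensors created under FakeTensorMode are FakeTensors with metadata only.
with FakeTensorMode():
    fake_a = torch.empty(8, 8)
    fake_b = torch.empty(8, 8)

print(type(fake_a))                 # a FakeTensor subclass
print(fake_a.shape, fake_a.dtype)   # torch.Size([8, 8]) torch.float32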
The call stack when the backend compiler is invoked:
- [P077] > test.py#L6:my_compiler [New]
- [P073] > torch/_dynamo/debug_utils.py#L1006 [New]
- [P069] > torch/_dynamo/output_graph.py#L644 [New]
- [P065] > torch/_dynamo/utils.py#L158
- [P061] > torch/_dynamo/output_graph.py#L583 [New]
- [P057] > torch/_dynamo/output_graph.py#L467 [New]
- [P053] > torch/_dynamo/symbolic_convert.py#L234 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
After compilation, control returns to output_graph.py#L615, where disable(compiled_fn) marks the compiled function so that TorchDynamo never tries to compile it again.
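disable() wraps a callable so that Dynamo skips it entirely if it is hit during tracing; here it is applied to the freshly compiled function so Dynamo does not re-trace its own output. The same wrapper is available to users (a small usage sketch):

import torch._dynamo as dynamo

@dynamo.disable
def do_not_trace_me(x):
    # Dynamo calls this function eagerly instead of tracing into it.
    return x + 1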
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L640
cg = PyCodegen(tx)
cg.make_call_generated_code(name)
return cg.get_instructions()
PyCodegen is a helper for generating Python bytecode; make_call_generated_code() emits the bytecode instructions that call the function fn_name.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/codegen.py#L345
def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
    """Call the generated code function stored in fn_name"""
    self.extend_output(self.load_function_name(fn_name, True))

    graphargs = self.tx.output.graphargs
    for arg in graphargs:
        if arg.is_unspecialized:
            self.extend_output(
                [
                    self.create_load_python_module(torch, True),
                    self.create_load_attr("tensor"),
                ]
            )
            self.extend_output(arg.load(self))
            self.extend_output(create_call_function(1, False))
        else:
            self.extend_output(arg.load(self))

    self.extend_output(create_call_function(len(graphargs), False))
make_call_generated_code() first uses load_function_name to generate the instruction that loads __compiled_fn_0 from the global namespace, represented inside PyCodegen as a structured Instruction:
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=3, argval='__compiled_fn_0', offset=None, starts_line=None, is_jump_target=False, target=None)
The call stack here is:
- [P081] > torch/_dynamo/bytecode_transformation.py#L52
- [P077] > torch/_dynamo/bytecode_transformation.py#L66 [New]
- [P073] > torch/_dynamo/codegen.py#L226 [New]
- [P069] > torch/_dynamo/codegen.py#L269 [New]
- [P065] > torch/_dynamo/codegen.py#L345 [New]
- [P061] > torch/_dynamo/output_graph.py#L583 [New]
- [P057] > torch/_dynamo/output_graph.py#L467 [New]
For each input argument arg of the OutputGraph, arg.load(self) generates the instructions that load the compiled subgraph's input arguments; this goes through create_load() and yields two Instructions:
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=None, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None)
The call stack here is:
- [P081] > torch/_dynamo/bytecode_transformation.py#L52
- [P077] > torch/_dynamo/codegen.py#L194 [New]
- [P073] > torch/_dynamo/source.py#L52 [New]
- [P069] > torch/_dynamo/variables/builder.py#L136
- [P065] > torch/_dynamo/codegen.py#L345 [New]
- [P061] > torch/_dynamo/output_graph.py#L583 [New]
- [P057] > torch/_dynamo/output_graph.py#L467 [New]
create_call_function() generates a single function-call instruction, with len(graphargs) indicating how many arguments the call takes:
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None)
Control then returns to compile_subgraph; output_graph.py#L546 appends an UNPACK_SEQUENCE instruction to unpack the function result, which comes back as a tuple. create_store() then saves the live variable x by emitting a STORE_FAST instruction.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L544
if (
    # ...
):
    # optimization to generate better code in a common case
    self.add_output_instructions(
        self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
        + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
    )
else:
    # ...

# restore all the live local vars
self.add_output_instructions(
    [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
)
The call stack here is:
- [P069] > torch/_dynamo/bytecode_transformation.py#L52
- [P065] > torch/_dynamo/codegen.py#L214 [New]
- [P061] > torch/_dynamo/output_graph.py#L580 [New]
- [P057] > torch/_dynamo/output_graph.py#L467 [New]
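Taken together, the instruction sequence emitted so far is essentially what CPython itself generates for x, lt = __compiled_fn_0(a, b). A quick way to check (a sketch, valid on Python 3.8-3.10, where CALL_FUNCTION and UNPACK_SEQUENCE are used):

import dis

def equivalent(a, b):
    # __compiled_fn_0 is assumed to exist as a global; dis never executes it.
    x, lt = __compiled_fn_0(a, b)
    return x, lt

dis.dis(equivalent)
# ... LOAD_GLOBAL __compiled_fn_0, LOAD_FAST a, LOAD_FAST b, CALL_FUNCTION 2,
#     UNPACK_SEQUENCE 2, STORE_FAST x, STORE_FAST lt ...

The only difference is that Dynamo's generated sequence stores only x and leaves lt on the stack, where the subsequent POP_JUMP_IF_FALSE consumes it.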
The subgraph compilation performed by compile_subgraph() is now complete, and control returns to symbolic_convert.py#L275. The TensorVariable(lt) at the top of the stack is no longer needed, since the variable x has already been saved by the STORE_FAST instruction:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L273
self.push(value)
log.debug("generic_jump triggered compile")
self.output.compile_subgraph(
    self,
    reason=GraphCompileReason(
        f"generic_jump {typestr(value)}", [self.frame_summary()]
    ),
)
self.pop()
if_next = self.create_call_resume_at(self.next_instruction)
push and self.push(value)
if_jump = self.create_call_resume_at(inst.target)
self.output.add_output_instructions(
    [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
)
Next, self.create_call_resume_at(self.next_instruction) generates the bytecode for the case where the if is True, and self.create_call_resume_at(inst.target) generates the bytecode for the case where the if is False. Note that both are generated before execution, when we do not yet know what the subgraphs for the if branch and the else branch will look like. In create_call_resume_at, livevars_analysis first determines which variables are still live at the current instruction, and then the bytecode for a function named __resume_at_30_1 is generated.
ContinueExecutionCache.lookup() produces the bytecode for resuming execution at offset 30; if no previously generated snippet is cached, it is generated by generate(), which delegates to transform_code_object(). Most of transform_code_object() has already been covered in detail, so it is not repeated here.
transform_code_object() rewrites the bytecode through update(); create_jump_absolute(target) creates a JUMP_ABSOLUTE instruction that jumps directly to offset 30 and places it as the first instruction of the code to be compiled. The call stack here is:
- [P077] > torch/_dynamo/bytecode_transformation.py#L61 [New]
- [P073] > torch/_dynamo/resume_execution.py#L261 [New]
- [P069] > torch/_dynamo/bytecode_transformation.py#L488
- [P065] > torch/_dynamo/resume_execution.py#L237 [New]
- [P061] > torch/_dynamo/resume_execution.py#L228 [New]
- [P057] > torch/_dynamo/symbolic_convert.py#L1853 [New]
- [P053] > torch/_dynamo/symbolic_convert.py#L234 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
At the end of transform_code_object(), clean_and_assemble_instructions(instructions, keys, code_options)[1] performs bytecode cleanup and assembly, implemented in clean_and_assemble_instructions(). After the basic cleanup, assemble packs each structured Instruction's (inst.opcode, inst.arg) pair into raw bytecode bytes, and the compiled Python code object is finally built via types.CodeType().
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/bytecode_transformation.py#L547
bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])
The call stack when generating the bytecode is:
- [P077] > torch/_dynamo/bytecode_transformation.py#L217 [New]
- [P073] > torch/_dynamo/bytecode_transformation.py#L534 [New]
- [P069] > torch/_dynamo/bytecode_transformation.py#L488
- [P065] > torch/_dynamo/resume_execution.py#L237 [New]
- [P061] > torch/_dynamo/resume_execution.py#L228 [New]
- [P057] > torch/_dynamo/symbolic_convert.py#L1853 [New]
- [P053] > torch/_dynamo/symbolic_convert.py#L234 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
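The essence of assemble() plus types.CodeType can be illustrated with a toy snippet (a sketch for Python 3.8-3.10 only, where every instruction is two bytes; the helper names are made up, and Dynamo's real assembler additionally fixes up jump targets and line-number tables):

import dis
import types

def template(x):
    return None

# Hand-assemble "LOAD_FAST 0; RETURN_VALUE" into raw co_code bytes and build a
# new code object on top of the template's metadata via CodeType.replace().
instructions = [
    (dis.opmap["LOAD_FAST"], 0),
    (dis.opmap["RETURN_VALUE"], 0),
]
co_code = bytes(b for op, arg in instructions for b in (op, arg))
new_code = template.__code__.replace(co_code=co_code)

identity = types.FunctionType(new_code, globals())
print(identity(42))   # 42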
For the case where the if is True, TorchDynamo generates the function __resume_at_30_1() from the original toy_example(): it inserts a JUMP_ABSOLUTE instruction at offset 0 of toy_example()'s bytecode that jumps directly to offset 32, where the if branch begins:
0 JUMP_ABSOLUTE 32
2 LOAD_FAST 2 (a)
4 LOAD_GLOBAL 0 (torch)
6 LOAD_ATTR 1 (abs)
8 LOAD_FAST 2 (a)
10 CALL_FUNCTION 1
12 LOAD_CONST 1 (1)
14 BINARY_ADD
16 BINARY_TRUE_DIVIDE
18 STORE_FAST 1 (x)
20 LOAD_FAST 0 (b)
22 LOAD_ATTR 2 (sum)
24 CALL_FUNCTION 0
26 LOAD_CONST 2 (0)
28 COMPARE_OP 0 (<)
30 POP_JUMP_IF_FALSE 40
>> 32 LOAD_FAST 0 (b)
34 LOAD_CONST 3 (-1)
36 BINARY_MULTIPLY
38 STORE_FAST 0 (b)
>> 40 LOAD_FAST 1 (x)
42 LOAD_FAST 0 (b)
44 BINARY_MULTIPLY
46 RETURN_VALUE
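In source form, __resume_at_30_1() behaves like the tail of toy_example() entered with the branch already taken (a readable sketch; the actual artifact is the generated bytecode above, not Python source):

def __resume_at_30_1(b, x):
    b = b * -1
    return x * b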
To call __resume_at_30_1() for the if branch, TorchDynamo, at symbolic_convert.py#L1903, uses load_function_name() to generate a LOAD_GLOBAL instruction that loads __resume_at_30_1() from the global namespace, create_load() generates the two LOAD_FAST instructions that supply its two arguments b and x, create_call_function() generates the CALL_FUNCTION instruction that calls __resume_at_30_1(), and create_instruction("RETURN_VALUE") generates the RETURN_VALUE instruction that returns from the if branch to the caller.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L1903
else:
    # ...
    cg.extend_output(cg.load_function_name(name, True, stack_len))
    cg.extend_output([cg.create_load(k) for k in argnames])
    cg.extend_output(create_call_function(nargs, False))
    cg.append_output(create_instruction("RETURN_VALUE"))
    return cg.get_instructions()
The bytecode instructions generated here to call __resume_at_30_1() are:
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=4, argval='__resume_at_30_1', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)
Back at symbolic_convert.py#L283-L285, TorchDynamo uses create_call_resume_at() once more to create the bytecode for the else branch; the process mirrors the if branch above, so it is not repeated here:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L283
if_next = self.create_call_resume_at(self.next_instruction)
push and self.push(value)
if_jump = self.create_call_resume_at(inst.target)
For the case where the if is False, TorchDynamo generates the function __resume_at_38_2() from the original toy_example(): it inserts a JUMP_ABSOLUTE instruction at offset 0 of toy_example()'s bytecode that jumps directly to offset 40, the first instruction after the if body:
0 JUMP_ABSOLUTE 40
2 LOAD_FAST 2 (a)
4 LOAD_GLOBAL 0 (torch)
6 LOAD_ATTR 1 (abs)
8 LOAD_FAST 2 (a)
10 CALL_FUNCTION 1
12 LOAD_CONST 1 (1)
14 BINARY_ADD
16 BINARY_TRUE_DIVIDE
18 STORE_FAST 1 (x)
20 LOAD_FAST 0 (b)
22 LOAD_ATTR 2 (sum)
24 CALL_FUNCTION 0
26 LOAD_CONST 2 (0)
28 COMPARE_OP 0 (<)
30 POP_JUMP_IF_FALSE 40
32 LOAD_FAST 0 (b)
34 LOAD_CONST 3 (-1)
36 BINARY_MULTIPLY
38 STORE_FAST 0 (b)
>> 40 LOAD_FAST 1 (x)
42 LOAD_FAST 0 (b)
44 BINARY_MULTIPLY
46 RETURN_VALUE
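In source form, __resume_at_38_2() is just the code after the if body (again a readable sketch of the generated bytecode above):

def __resume_at_38_2(b, x):
    return x * b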
The bytecode instructions generated to call __resume_at_38_2() are:
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)
Back at symbolic_convert.py#L287: the POP_JUMP_IF_FALSE produced by the if in the example code is what caused the graph break. Here that POP_JUMP_IF_FALSE instruction is copied, and the just-generated bytecode for the if branch (if_next) and for the else branch (if_jump) is appended to the current OutputGraph's instruction list:
self.output.add_output_instructions(
    [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
)
The bytecode instructions in the OutputGraph are now:
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=3, argval='__compiled_fn_0', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=92, opname='UNPACK_SEQUENCE', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=114, opname='POP_JUMP_IF_FALSE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None)),
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=4, argval='__resume_at_30_1', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)
The complete call stack at this point is shown below, omitting the three TorchDynamo frame-evaluation functions that every Python function call passes through:
- [P053] > torch/_dynamo/symbolic_convert.py#L234 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
- [P045] > torch/_dynamo/symbolic_convert.py#L590 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1838 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
- [P025] > torch/_dynamo/utils.py#L158 [New]
- [P021] > torch/_dynamo/convert_frame.py#L200 [New]
- [P017] > torch/_dynamo/convert_frame.py#L96 [New]
- [P013] > torch/_dynamo/convert_frame.py#L403 [New]
- [P009] > torch/_dynamo/eval_frame.py#L362
- [C008] > torch/csrc/dynamo/eval_frame.c#L355
- [C007] > torch/csrc/dynamo/eval_frame.c#L621
- [C006] > torch/csrc/dynamo/eval_frame.c#L346
- [C005] > torch/csrc/dynamo/eval_frame.c#L399
- [C004] > torch/csrc/dynamo/eval_frame.c#L640
- [C003] > torch/csrc/dynamo/eval_frame.c#L621
- [C002] > torch/csrc/dynamo/eval_frame.c#L346
- [P001] > torch/_dynamo/eval_frame.py#L233 [New]
- [P000] > test.py#L17:test [New]
Returning from inner() and step() to symbolic_convert.py#L594: because should_exit was set to True on the OutputGraph during the subgraph compile, bytecode translation stops here once a subgraph compile has happened.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L594
while (
    self.instruction_pointer is not None
    and not self.output.should_exit
    and self.step()
):
    pass
Following the call stack above back to convert_frame.py#L320, remove_dead_code() performs dead-code elimination and remove_pointless_jumps() removes pointless jumps.
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/convert_frame.py#L320
if config.dead_code_elimination:
    instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
Control then returns to bytecode_transformation.py#L531, where clean_and_assemble_instructions() assembles the structured Instructions above into bytecode that Python can execute; its flow was analyzed earlier, so it is not repeated here.
After TorchDynamo's subgraph compilation, the original bytecode of toy_example():
0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
>> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
is compiled into the following bytecode:
0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 24
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
The compiled bytecode calls three functions; the rewritten function as a whole is equivalent to the Python sketch that follows the list:
- __compiled_fn_0(): the subgraph compiled by TorchDynamo, corresponding to the code before the if statement;
- __resume_at_30_1(): a subgraph left uncompiled by TorchDynamo, corresponding to the if branch;
- __resume_at_38_2(): a subgraph left uncompiled by TorchDynamo, corresponding to the else branch.
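Expressed as Python (a readable sketch; the real artifact is the rewritten bytecode shown above), the transformed toy_example() is:

def toy_example_transformed(a, b):
    x, lt = __compiled_fn_0(a, b)
    if lt:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)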