Graph Break Link to heading

Offset 28, POP_JUMP_IF_FALSE:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L853
POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)

POP_JUMP_IF_FALSEsymbolic_convert.py#L853 被设置为 generic_jump(operator.not_, False),真实实现位于 inner()

Python 中的 if/else 编译为字节码后是 POP_JUMP_IF_FALSE 这类条件跳转 (JUMP_IF) 指令,它们在 TorchDynamo 中会引入 Graph Break,触发子图编译 (subgraph compile)。

首先出栈一个元素,即 TensorVariable(lt),然后把附加在其上的 Guard 聚合到 Output Graph 中,最重要的地方是 symbolic_convert.py#L262-L289:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L262-L289
elif (
    isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
):
    # compile a partial subgraph prefix then jump into user code
    if self.has_backedge():
        msg = (
            "Skipping frame because there is a graph break in a for/while loop"
        )
        log.info(msg)
        raise exc.SkipFrame(msg)

    self.push(value)
    log.debug("generic_jump triggered compile")
    self.output.compile_subgraph(
        self,
        reason=GraphCompileReason(
            f"generic_jump {typestr(value)}", [self.frame_summary()]
        ),
    )
    self.pop()

    if_next = self.create_call_resume_at(self.next_instruction)
    push and self.push(value)
    if_jump = self.create_call_resume_at(inst.target)

    self.output.add_output_instructions(
        [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
    )

has_backedge() 用来判断待编译的函数中剩余的字节码指令中是否存在后向跳转(循环指令),如果 for/while 中存在 Graph Break,TorchDynamo 会跳过这个函数,并抛出 exc.SkipFrame 异常,它在此前所讲的 _compile() 中被捕获。

因为此处的 POP_JUMP_IF_FALSE 无法被处理,TorchDynamo 又把已出栈的元素 TensorVariable(lt) 压栈,然后调用 OutputGraph.compile_subgraph() 进行子图编译,进行子图编译的原因是 generic_jump 依赖 TensorVariable 的值

compile_subgraph() 中,tx.prune_dead_locals() 用于从 self.symbolic_locals 中剔除此后不再需要的局部变量,实现于 symbolic_convert.py#L458原理是通过活跃变量分析来收集在当前字节码指令之后被用到的局部变量,而不在这些局部变量中的局部变量就可以剔除,详见 livevars_analysis()。从 POP_JUMP_IF_FALSE 往后只有变量 b, x 被用到,它们是活跃变量,self.symbolic_locals 中的非活跃变量被丢弃。output_graph.py#L545stack_values[TensorVariable(lt), TensorVariable(truediv)],对应 b.sum()x

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L544
self.add_output_instructions(
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
)

compile_and_call_fx_graph() 用于 OutputGraph 中生成代码。其中首先根据传入的两个输出节点更新了 Guard,create_node() 在 FX Graph 中创建了类型为 output 的 FX Proxy,一张完整的 FX Graph 到此构建完毕remove_unused_graphargs() 从 FX Graph 中删除没有 user 的节点。fx.GraphModule(root, self.graph)fx.Graph 创建 fx.GraphModule (GraphModule),通过 recompile() 生成对应的 Python 代码,新编译的函数名为 __compiled_fn_0()。从 FX Graph 生成 Python 代码的过程在 TorchFX 源码解析 中分解过,这里不再赘述。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L609
gm = fx.GraphModule(root, self.graph)
gm.recompile()
gm.compile_subgraph_reason = self.compile_subgraph_reason
name = unique_id("__compiled_fn")

assert_no_fake_params_or_buffers(gm)
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)

counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)

到此为止捕获的 FX graph 为:

>>> FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f1c83e6c720>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

__compiled_fn_0() 对应的 Python 函数为:

def forward(self, a : torch.Tensor, b : torch.Tensor):
    abs_1 = torch.abs(a)
    add = abs_1 + 1;  abs_1 = None
    truediv = a / add;  a = add = None
    sum_1 = b.sum();  b = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)

call_user_compiler() 调用用户指定的 backend compiler 来编译 fx.GraphModule,实现在 output_graph.py#L644,编译 fx.GraphModule 时使用 FakeTensor,它们通过 fake_example_inputs()fx.Graph 的输入节点得到。为了便于调试,backend compiler 经过 debug_wrapper 包装,最终调用到案例代码提供的 my_compiler()调用 backend compiler 时发生任何异常,都会被捕获并作为 BackendCompilerFailed 抛出,外层的 InstructionTranslatorsymbolic_convert.py#L600 捕获到 BackendCompilerFailed 并继续抛出异常。

@dynamo_timed(phase_name="backend_compile")
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
    # ...
    try:
        # ...
        compiler_fn = self.compiler_fn
        # ...
        if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
            compiled_fn = compiler_fn(gm, self.example_inputs())
        elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
            compiled_fn = compiler_fn(gm, self.example_inputs())
        else:
            compiled_fn = compiler_fn(gm, self.fake_example_inputs())
        _step_logger()(logging.INFO, f"done compiler function {name}")
        assert callable(compiled_fn), "compiler_fn did not return callable"
    except Exception as e:
        raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
            e.__traceback__
        ) from None
    return compiled_fn

调用 backend compiler 时的函数调用栈如下:

编译完成后返回 output_graph.py#L615disable(compiled_fn) 设置了禁止 TorchDynamo 再次编译已经编译过的函数。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L640
cg = PyCodegen(tx)
cg.make_call_generated_code(name)
return cg.get_instructions()

PyCodegen 用于辅助生成 Python 字节码,make_call_generated_code() 用于生成调用函数 fn_name 的字节码指令。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/codegen.py#L345
def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
    """Call the generated code function stored in fn_name"""
    self.extend_output(self.load_function_name(fn_name, True))

    graphargs = self.tx.output.graphargs
    for arg in graphargs:
        if arg.is_unspecialized:
            self.extend_output(
                [
                    self.create_load_python_module(torch, True),
                    self.create_load_attr("tensor"),
                ]
            )
            self.extend_output(arg.load(self))
            self.extend_output(create_call_function(1, False))
        else:
            self.extend_output(arg.load(self))

    self.extend_output(create_call_function(len(graphargs), False))

make_call_generated_code() 首先用 load_function_name 生成从全局空间加载 __compiled_fn_0 的字节码指令,以结构化的 Instruction 表示在 PyCodegen 中:

Instruction(opcode=116, opname='LOAD_GLOBAL', arg=3, argval='__compiled_fn_0', offset=None, starts_line=None, is_jump_target=False, target=None)

这里的函数调用栈为:

对于 OuputGraph 中的输入参数 argarg.load(self) 为编译好的子图生成加载输入参数的字节码指令,这通过 create_load() 实现,生成两条 Instruction:

Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=None, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None)

这里的函数调用栈为:

create_call_function() 生成一条函数调用的字节码指令,len(graphargs) 表明调用该函数需要多少个函数参数:

Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None)

此后返回到 compile_subgraphoutput_graph.py#L546 追加了一条 UNPACK_SEQUENCE 指令,用于拆分作为元组返回的函数结果。create_store() 把此时的活跃变量 x 保存起来,通过生成 STORE_FAST 指令实现。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/output_graph.py#L544
if (
    # ...
):
    # optimization to generate better code in a common case
    self.add_output_instructions(
        self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
        + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
    )
else:
    # ...

# restore all the live local vars
self.add_output_instructions(
    [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
)

这里的函数调用栈为:

通过 compile_subgraph() 实现的子图编译到此完成,回到 symbolic_convert.py#L275,此时处于栈顶的 TensorVariable(lt) 已经无用,因为此前通过 STORE_FAST 指令把变量 x 保存了起来:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L273
self.push(value)
log.debug("generic_jump triggered compile")
self.output.compile_subgraph(
    self,
    reason=GraphCompileReason(
        f"generic_jump {typestr(value)}", [self.frame_summary()]
    ),
)
self.pop()

if_next = self.create_call_resume_at(self.next_instruction)
push and self.push(value)
if_jump = self.create_call_resume_at(inst.target)

self.output.add_output_instructions(
    [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
)

随后,self.create_call_resume_at(self.next_instruction) 生成 ifTrue 时的字节码指令,self.create_call_resume_at(inst.target) 生成 ifFalse 时的字节码指令。应当注意的是,生成这两部分字节码指令的时机是在运行前,此时我们并不知道 if 分支和 else 分支的子图是什么样子。在 create_call_resume_at 中,首先通过 livevars_analysis 分析在当前字节码指令处还处在活跃状态的变量,随后生成名为 __resume_at_30_1 的函数的字节码。

ContinueExecutionCache.lookup() 生成从 offset 30 开始恢复执行时的字节码,如果缓存中没有已经生成好的字节码片段,则通过 generate() 生成,它调用 transform_code_object() 来完成此功能,此前我们已经详细介绍过 transform_code_object() 的绝大多数功能,这里不再赘述。

transform_code_object() 利用 update() 来更新字节码,create_jump_absolute(target) 创建了直接跳转到 offset 30 处的 JUMP_ABSOLUTE 指令,并作为待编译字节码的首条指令。此处的函数调用栈为:

transform_code_object() 的最后,clean_and_assemble_instructions(instructions, keys, code_options)[1] 执行 字节码清理和汇编,实现在 clean_and_assemble_instructions()。在执行基本的清理过后,assemble bytes([(inst.opcode, inst.arg)]) 从结构化的 Instruction 生成二进制字节码,并最终通过 types.CodeType() 生成编译好的 Python Code。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/bytecode_transformation.py#L547
bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])

生成字节码时的函数调用栈为:

TorchDynamo 针对 ifTrue 时在原函数 toy_example() 的基础上生成函数 __resume_at_30_1(),它在 toy_example() 字节码的 offset 0 处插入了一条 JUMP_ABSOLUTE 指令,直接跳转到 if 分支所在的指令 offset 32 处:

      0 JUMP_ABSOLUTE           32
      2 LOAD_FAST                2 (a)
      4 LOAD_GLOBAL              0 (torch)
      6 LOAD_ATTR                1 (abs)
      8 LOAD_FAST                2 (a)
     10 CALL_FUNCTION            1
     12 LOAD_CONST               1 (1)
     14 BINARY_ADD
     16 BINARY_TRUE_DIVIDE
     18 STORE_FAST               1 (x)
     20 LOAD_FAST                0 (b)
     22 LOAD_ATTR                2 (sum)
     24 CALL_FUNCTION            0
     26 LOAD_CONST               2 (0)
     28 COMPARE_OP               0 (<)
     30 POP_JUMP_IF_FALSE       40

>>   32 LOAD_FAST                0 (b)
     34 LOAD_CONST               3 (-1)
     36 BINARY_MULTIPLY
     38 STORE_FAST               0 (b)

>>   40 LOAD_FAST                1 (x)
     42 LOAD_FAST                0 (b)
     44 BINARY_MULTIPLY
     46 RETURN_VALUE

为了能够调用 __resume_at_30_1() 来执行 if 分支,TorchDynamo 在 symbolic_convert.py#L1903 处通过 load_function_name() 生成 LOAD_GLOBAL 指令来从全局空间加载函数 __resume_at_30_1()create_load() 生成两条 LOAD_FAST 指令生成调用函数 __resume_at_30_1() 所需要的两个参数 bxcreate_call_function() 生成 CALL_FUNCTION 指令以调用 __resume_at_30_1()create_instruction(“RETURN_VALUE”) 生成 RETURN_VALUE 指令从 if 分支返回主调函数。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L1903
else:
    # ...
    cg.extend_output(cg.load_function_name(name, True, stack_len))

cg.extend_output([cg.create_load(k) for k in argnames])
cg.extend_output(create_call_function(nargs, False))
cg.append_output(create_instruction("RETURN_VALUE"))
return cg.get_instructions()

这里生成的调用 __resume_at_30_1() 的字节码指令为:

Instruction(opcode=116, opname='LOAD_GLOBAL', arg=4, argval='__resume_at_30_1', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)

回到 symbolic_convert.py#L283-L285,TorchDynamo 再次利用 create_call_resume_at()else 分支创建字节码,其过程与上面的 if 分支过程一致,这里不再赘述:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L283
if_next = self.create_call_resume_at(self.next_instruction)
push and self.push(value)
if_jump = self.create_call_resume_at(inst.target)

TorchDynamo 针对 ifFalse 时在原函数 toy_example() 的基础上生成函数 __resume_at_38_2(),它在 toy_example() 字节码的 offset 0 处插入了一条 JUMP_ABSOLUTE 指令,直接跳转到 if 分支所在的指令 offset 40 处:

      0 JUMP_ABSOLUTE           40
      2 LOAD_FAST                2 (a)
      4 LOAD_GLOBAL              0 (torch)
      6 LOAD_ATTR                1 (abs)
      8 LOAD_FAST                2 (a)
     10 CALL_FUNCTION            1
     12 LOAD_CONST               1 (1)
     14 BINARY_ADD
     16 BINARY_TRUE_DIVIDE
     18 STORE_FAST               1 (x)
     20 LOAD_FAST                0 (b)
     22 LOAD_ATTR                2 (sum)
     24 CALL_FUNCTION            0
     26 LOAD_CONST               2 (0)
     28 COMPARE_OP               0 (<)
     30 POP_JUMP_IF_FALSE       40
     32 LOAD_FAST                0 (b)
     34 LOAD_CONST               3 (-1)
     36 BINARY_MULTIPLY
     38 STORE_FAST               0 (b)

>>   40 LOAD_FAST                1 (x)
     42 LOAD_FAST                0 (b)
     44 BINARY_MULTIPLY
     46 RETURN_VALUE

生成的调用 __resume_at_38_2() 的字节码指令为:

Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)

回到 symbolic_convert.py#L287,示例代码因为 if 产生的 POP_JUMP_IF_FALSE 产生了 graph break,这里拷贝了 POP_JUMP_IF_FALSE 指令,并把刚才生成的 if 分支的字节码指令 if_nextelse 分支生成的 if_jump 添加到了当前 OutputGraph 的字节码指令中:

self.output.add_output_instructions(
    [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
)

此时 OutputGraph 中的的字节码指令如下:

Instruction(opcode=116, opname='LOAD_GLOBAL', arg=3, argval='__compiled_fn_0', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=92, opname='UNPACK_SEQUENCE', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=114, opname='POP_JUMP_IF_FALSE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None)),
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=4, argval='__resume_at_30_1', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=116, opname='LOAD_GLOBAL', arg=5, argval='__resume_at_38_2', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='x', offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=131, opname='CALL_FUNCTION', arg=2, argval=2, offset=None, starts_line=None, is_jump_target=False, target=None),
Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=None, starts_line=None, is_jump_target=False, target=None)

此时完整的函数调用栈如下,其中省略了执行每个 Python 函数需经过的 3 个 TorchDynamo 的 frame evaluation 函数:

inner()step() 返回到 symbolic_convert.py#L594,因为此前在做子图编译的时候设置了 OutputGraphshould_exitTrue,因此 碰到子图编译会在此停止字节码翻译

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L594
while (
    self.instruction_pointer is not None
    and not self.output.should_exit
    and self.step()
):
    pass

沿着上面的调用栈一路返回到 convert_frame.py#L320remove_dead_code() 执行 死码消除remove_pointless_jumps() 消除无意义的跳转。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/convert_frame.py#L320
if config.dead_code_elimination:
    instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

随后返回到 bytecode_transformation.py#L531clean_and_assemble_instructions() 用于把上面结构化的 Instruction 汇编为 Python 可执行的字节码,此前已经分析过它的流程,此处不再赘述。

经过 TorchDynamo 的子图编译,toy_example() 的原字节码:

      0 LOAD_FAST                0 (a)
      2 LOAD_GLOBAL              0 (torch)
      4 LOAD_METHOD              1 (abs)
      6 LOAD_FAST                0 (a)
      8 CALL_METHOD              1
     10 LOAD_CONST               1 (1)
     12 BINARY_ADD
     14 BINARY_TRUE_DIVIDE
     16 STORE_FAST               2 (x)

     18 LOAD_FAST                1 (b)
     20 LOAD_METHOD              2 (sum)
     22 CALL_METHOD              0
     24 LOAD_CONST               2 (0)
     26 COMPARE_OP               0 (<)
     28 POP_JUMP_IF_FALSE       38

     30 LOAD_FAST                1 (b)
     32 LOAD_CONST               3 (-1)
     34 BINARY_MULTIPLY
     36 STORE_FAST               1 (b)

>>   38 LOAD_FAST                2 (x)
     40 LOAD_FAST                1 (b)
     42 BINARY_MULTIPLY
     44 RETURN_VALUE

被编译为以下字节码:

      0 LOAD_GLOBAL              3 (__compiled_fn_0)
      2 LOAD_FAST                0 (a)
      4 LOAD_FAST                1 (b)
      6 CALL_FUNCTION            2
      8 UNPACK_SEQUENCE          2
     10 STORE_FAST               2 (x)
     12 POP_JUMP_IF_FALSE       24
     14 LOAD_GLOBAL              4 (__resume_at_30_1)
     16 LOAD_FAST                1 (b)
     18 LOAD_FAST                2 (x)
     20 CALL_FUNCTION            2
     22 RETURN_VALUE
>>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
     26 LOAD_FAST                1 (b)
     28 LOAD_FAST                2 (x)
     30 CALL_FUNCTION            2
     32 RETURN_VALUE

编译后的字节码中调用了 3 个函数:

  • __compiled_fn_0(): TorchDynamo 编译好的子图,对应 if 语句前面的部分;
  • __resume_at_30_1(): TorchDynamo 未编译的子图,对应 if 分支;
  • __resume_at_38_2(): TorchDynamo 未编译的子图,对应 else 分支;