初始化 InstructionTranslator Link to heading

清理后的字节码指令经由 transformations(instructions, code_options) 开始执行变换,transformations 是之前的 transform(),其中首先实例化了 InstructionTranslator,定义于 torch/_dynamo/symbolic_convert.py#L1747InstructionTranslator 中有一个 OutputGraph 的实例,OutputGraph 本身是一个 torch.fx.TracerOutputGraph 用于保存 InstructionTranslator字节码翻译 后的输出,以 torch.fx.Graph 表达。OutputGraph 中还包含一个 SideEffects 的实例,它用于跟踪带有副作用 (side effect) 的操作,比如修改列表、setattr()。此处的函数调用栈为:

InstructionTranslator 继承自 InstructionTranslatorBaseInstructionTranslatorBase 的本质是一个 Python 虚拟机的模拟器,它记录了待编译 Python 函数 栈帧 (Frame) 中的 运行时信息,类型是 PyFrameObject,包含了待编译函数在运行时的全局变量 f_globals、局部变量 f_locals、代码对象 f_code、预处理后的字节码指令,TorchDyanmo 在 torch/_dynamo/convert_frame.py#L264-L267 通过 frame 获得这些属性:

return _compile(
    frame.f_code,
    frame.f_globals,
    frame.f_locals,
    frame.f_builtins,
    compiler_fn,
    one_graph,
    export,
    hooks,
    frame,
)

除此之外,InstructionTranslatorBase 还包含了:

  • instruction_pointer: Python 虚拟机的 PC,表明当前正在执行的字节码指令所处位置;
  • stack: Python 虚拟机的 数据栈,因为 Python 虚拟机是 Stack Machine,而不是类似 X86 的 Register Machine,所以 Python 虚拟机中字节码之间通过数据栈交换数据;
  • block_stack: Python 虚拟机的 block stack,用于记录一些特殊的块结构,比如循环、上下文 (with);

torch/_dynamo/symbolic_convert.py#L1781-L1796InstructionTranslator 通过 Code 对象的 co_varnamesco_cellvarsco_freevars 3 个字段获取 待编译函数中用到的变量名,其中 co_varnames 是局部变量名,co_cellvarsco_freevars 是闭包中用到的变量名。对于案例 toy_example(),用到的变量名是 ('a', 'b', 'x')

vars = list(code_options["co_varnames"])
vars.extend(x for x in self.cell_and_freevars() if x not in vars)

self.symbolic_locals = collections.OrderedDict(
    (
        k,
        VariableBuilder(
            self,
            LocalInputSource(k, code_options["co_varnames"].index(k))
            if k in code_options["co_varnames"]
            else LocalSource((k)),
        )(f_locals[k]),
    )
    for k in vars
    if k in f_locals
)

在初始化 InstructionTranslator 的过程中,它为待编译函数中每一个被用到的变量创建一个 VariableTracker,作为 symbolic_localsTorchDynamo 捕获计算图的过程通过字节码翻译,而不单单是传统意义上的 tracing,例如 PyTorch 中的 JIT tracing 或者 FX 的 symbolic tracing。因此,TorchDynamo 需要有一个 类型系统,用于记录每个 Python 变量对应的类型信息,VariableTracker 的功能正是如此。

这个过程通过 VariableBuilder 实现,主要逻辑实现于builder.py#L745-L756:

tensor_proxy = self.tx.output.create_graph_input(
    re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
)
tensor_variable = wrap_fx_proxy(
    tx=self.tx,
    proxy=tensor_proxy,
    example_value=value,
    guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
    should_specialize=self.tensor_should_specialize(),
    ignore_subclass=ignore_subclass,
    source=self.get_source(),
)

其中通过 create_graph_input() 在 FX Graph 中创建了类型为 placeholder 的 FX Proxy,Proxy 是 FX symbloic tracing 中的 symbolmake_guards() 创建了类型为 GuardBuilder.TENSOR_MATCHGuardGuard 在 TorchDynamo 中负责检测被编译函数所引用外部数据的信息是否发生变化,如果没有发生变化则可以重用此前编译好的函数,否则当前输入对此前编译好的函数无效,需要重新编译该函数。在默认情况下,TENSOR_MATCH 主要负责检查输入的张量 shape、stride 等信息是否改变。

wrap_fx_proxy() 为刚创建的 FX Proxy 建立对应的 VariableTracker,核心逻辑在 wrap_fx_proxy_cls()。其中先通过 wrap_to_fake_tensor_and_record() 为运行时获得的 torch.Tensor 创建 FakeTensor。默认情况下,TorchDynamo 使用 FakeTensor 捕获计算图而不是真实的 torch.TensorFakeTensor 和真实的 torch.Tensor 有相同的张量信息,但没有张量内存分配。随后通过 specialize() 特化张量的信息,包括 dtypedevicelayout 等。在 static shape 模式下,还会特化 sizestrideis_contiguous 这些信息,而 dynamic shape 模式下则不会特化它们。最后,wrap_fx_proxy_cls() 通过 tensor.py#L69 创建 TensorVariable,它是 VariableTracker 的子类,用于记录 PyTorch 中的 torch.Tensor 类型。

可见,在创建 VariableTracker 的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor,初始化了 VariableTracker由于并不是所有的局部变量都会被 frame 用到,为了减少不必要的 Guard 带来的运行时开销,TorchDynamo 在逐条运行 Python 字节码的过程中有选择性的更新 FX graph 所需要的 Guard。创建 VariableTracker 时的调用栈如下:

不断重复 symbolic_convert.py#L1784-L1796,最终我们得到 symbolic_locals,它包含了待编译函数中所有引用变量的 VariableTracker。至此,InstructionTranslator 初始化完毕,返回到 transform() 中。

字节码翻译 Link to heading

transform() 中,TorchDynamo 的计算图捕获逻辑从此处正式开始:

with tracing(tracer.output.tracing_context):
    tracer.run()

实现在 InstructionTranslatorBase 中的 run() 包含一个 while 循环,每次调用 step() 处理一条字节码指令:

while (
    self.instruction_pointer is not None
    and not self.output.should_exit
    and self.step()
):
    pass

step()instruction_pointer 获取当前 step 要处理的字节码指令,并把 instruction_pointer 加一。在数据栈为空、且应该编译部分图 (partial graph) 的条件下,InstructionTranslatorBase 会备份当前状态为 checkpoint,以便以后用于恢复。应不应该编译部分图,取决于当前 block stack 中的所有条目是否都可以恢复、并且用户没有通过 one_graphnopython 指定全图编译。

此时的函数调用栈为:

逐条翻译字节码实现在 symbolic_convert.py#L560 处的 getattr(self, inst.opname)(inst),直到碰见异常或者 RETURN_VALUE 指令:

try:
    if not hasattr(self, inst.opname):
        unimplemented(f"missing: {inst.opname}")
    getattr(self, inst.opname)(inst)

    return inst.opname != "RETURN_VALUE"
except BackendCompilerFailed:
    raise
except Unsupported as exc:
    exc.real_stack.append(self.frame_summary())
    if self.empty_checkpoint():
        raise
    log.debug("step triggered compile", exc_info=True)
except Exception as exc:
    real_stack = getattr(exc, "real_stack", [])
    real_stack.append(self.frame_summary())
    exc.real_stack = real_stack  # type: ignore[attr-defined]
    raise

对于案例中的 toy_example(),TorchDynamo 初始化阶段处理过后的字节码如下:

14           0 LOAD_FAST                0 (a)
             2 LOAD_GLOBAL              0 (torch)
             4 LOAD_ATTR                1 (abs)
             6 LOAD_FAST                0 (a)
             8 CALL_FUNCTION            1
            10 LOAD_CONST               1 (1)
            12 BINARY_ADD
            14 BINARY_TRUE_DIVIDE
            16 STORE_FAST               2 (x)

15          18 LOAD_FAST                1 (b)
            20 LOAD_ATTR                2 (sum)
            22 CALL_FUNCTION            0
            24 LOAD_CONST               2 (0)
            26 COMPARE_OP               0 (<)
            28 POP_JUMP_IF_FALSE       38

16          30 LOAD_FAST                1 (b)
            32 LOAD_CONST               3 (-1)
            34 BINARY_MULTIPLY
            36 STORE_FAST               1 (b)

17     >>   38 LOAD_FAST                2 (x)
            40 LOAD_FAST                1 (b)
            42 BINARY_MULTIPLY
            44 RETURN_VALUE

它们依次在 InstructionTranslatorBase 中被翻译,并随之建立 FX Graph。对其进行 字节码翻译 的过程如下:

Offset 0, LOAD_FAST: inst.argvala,从 symbolic_locals 取出 变量 aTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(a)]

def LOAD_FAST(self, inst):
    name = inst.argval
    # ...
    self.push(self.symbolic_locals[name])
    # ...

Offset 2, LOAD_GLOBAL: inst.argvaltorch,它从 f_globals 中取出库 torch,调用 VariableBuilder(self, source)(value)builder.py#L391-L395 创建了 TorchVariable,含义是 PyTorch 中的某个 package,并创建了 FUNCTION_MATCH 类型的 Guard。最后把该 TorchVariable 入栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch)]

def LOAD_GLOBAL(self, inst):
    # ...
    name = inst.argval
    # ...
    try:
        value = self.f_globals[name]
    except KeyError:
        return self.load_builtin(inst)

    source = self.get_global_source(name)
    self.push(VariableBuilder(self, source)(value))

Offset 4, LOAD_ATTR: 先出栈一个元素,即 TorchVariable(torch),然后为 getattr 创建了 BuiltinVariable、为 inst.argval(即 abs) 创建了 ConstantVariable,转入 call_function()对应着在 TorchDynamo 中执行类似 Python 的 getattr(torch, abs) 的功能

def LOAD_ATTR(self, inst):
    obj = self.pop()
    result = BuiltinVariable(getattr).call_function(
        self, [obj, ConstantVariable(inst.argval)], {}
    )
    self.push(result)

其中进行基本的检查后,它通过 propagate() 收集输入中的所有 Guardbuiltin.py#L548,这里获取 call_getattr 属性成功,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs) 检查是否能够成功将参数绑定到 call_getattr 上,再通过 result = handler(tx, *args, **kwargs) 调用 call_getattr()

call_getattr() 通过 propagate() 收集了输入参数的 Guard,因为此处的 obj(即 torch) 是 TorchVariable,所以在 builtin.py#L1010-L1013torch.abs 创建了新的 TorchVariable,注意此时附加在 TorchVariable(torch) 上的 Guard 已经被传播(propagete)到了 TorchVariable(torch.abs),并且记录了来源 AttrSource(base=GlobalSource(global_name='torch'), member='abs')。随后一路返回到 LOAD_ATTR,并将结果压栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs)]

elif isinstance(obj, TorchVariable):
    member = getattr(obj.value, name)
    if is_allowed(member):
        return TorchVariable(member, **options)

Offset 6, LOAD_FAST: 再次加载变量 aTensorVariable,随后栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs), TensorVariable(a)]

Offset 8, CALL_FUNCTION: 因为 CALL_FUNCTION 被装饰器 break_graph_if_unsupported() 装饰,所以执行 CALL_FUNCTION() 会先经过其中的 wrapper()。这里首先创建 checkpoint,保存了所有的状态,以便在后面出现异常时从 checkpoint 中恢复。

state = self.copy_graphstate()
reason = None
try:
    return inner_fn(self, inst)
except Unsupported as excp:
    # ...
self.restore_graphstate(state)

然后执行 inner_fn(self, inst)inner_fn 就是 CALL_FUNCTION(),其中先出栈 N 个元素作为函数参数,N 由 inst.argval 指定,这里是 1,然后再出栈 1 个元素作为函数,通过 InstructionTranslatorBase.call_function() 进行函数调用:

@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
    args = self.popn(inst.argval)
    fn = self.pop()
    self.call_function(fn, args, {})

InstructionTranslatorBase.call_function() 调用 TorchVariable.call_function()TorchDynamo 在此处模拟执行 torch.abs(a)。首先用 propagate() 收集所有参数中的 Guard,然后匹配到 torch.py##L565-L573,此处 proxy_args_kwargs()TensorVariable(a) 获取 torch.fx.Proxy(a),它是在初始化 symbolic_locals 时创建的,然后通过 create_proxy() 创建了新的 Proxy,类型是 call_function,目标是 torch.abs,参数是 a

tensor_variable = wrap_fx_proxy(
    tx=tx,
    proxy=tx.output.create_proxy(
        "call_function",
        fn_,
        *proxy_args_kwargs(args, kwargs),
    ),
    **options,
)

最后通过 wrap_fx_proxy 创建了新的 TensorVariable 来保存结果,收集到的 Guard 信息也附加了上去,一路返回后并在 call_function() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a))]。可见,TorchDynamo 在字节码分析过程中没有真正地执行指令,而是以符号分析的方式从字节码中逐步提取用到的符号和函数,为它们创建 Proxy 并添加到 FX graph 中

对应的函数调用栈为:

Offset 10, LOAD_CONST: 加载常量 1,为其创建 ConstantVariable 并压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)), ConstantVariable(1)]

def LOAD_CONST(self, inst):
    self.push(ConstantVariable(value=inst.argval))

Offset 12, BINARY_ADD: InstructionTranslatorBase 对常见的 Python 内建函数用 stack_op 做了封装和转发,其中 BINARY_ADD = stack_op(operator.add):

BINARY_REMAINDER = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)

所以 symbolic_convert.py#L560 在执行 getattr(self, 'BINARY_ADD')(inst) 时,会转到 impl():

def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl

其中,fn_var 是创建闭包时创建的 BuiltinVariable,需要出栈的参数个数由 inspect.signature(fn) 确定,对于 operator.add 来说需要 2 个参数,因此出栈 TensorVariable(torch.abs(a))ConstantVariable(1),随后转到 BuiltinVariable.call_function(),这与 Offset 4 处的 LOAD_ATTR 调用的是同一个函数。其中调用 propagate() 从输入的 VariableTracker 中收集 Guard,分别是针对变量 aTENSOR_MATCH 和针对 torchFUNCTION_MATCH

proxy = tx.output.create_proxy(
    "call_function",
    fn,
    *proxy_args_kwargs(args, kwargs),
)

builtin.py#L476-L480 处先调用 proxy_args_kwargs() 获取输入参数的 Proxy,然后在 FX Graph 中创建输出节点对应的 Proxy,通过 create_proxy() 创建了类型是 call_functionProxy,目标是 operator.add,参数是 Proxy(a)1

最后用 wrap_fx_proxy() 创建了新的 TensorVariable 来保存结果,收集到的 Guard 信息也附加了上去,一路返回后在 impl() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)+1)]

Offset 14, BINARY_TRUE_DIVIDE: 与 Offset 12BINARY_ADD 类似,BINARY_TRUE_DIVIDEInstructionTranslatorBase 中被设为 stack_op(operator.truediv),随后到代码路径与 BINARY_ADD 一致。出栈 [TensorVariable(a)TensorVariable(torch.abs(a)+1),为输出创建新的 Proxy,类型是 call_function,目标是 operator.truediv,参数是 Proxy(a)Proxy(add)。再为其创建了新的 TensorVariable 来跟踪收集到的 Guard,最后把结果压栈,栈上的内容为 [TensorVariable(a/(torch.abs(a)+1))]

BINARY_TRUE_DIVIDE = stack_op(operator.truediv)

Offset 16, STORE_FAST: 出栈最后一个元素 TensorVariable(a/(torch.abs(a)+1)),并把它存在 self.symbolic_locals[inst.argval] 中,这里 self.symbolic_locals 跟踪当前 frame 的局部变量,变量名由 inst.argval 指定,即 x。此后栈上内容为空。

def STORE_FAST(self, inst):
    self.symbolic_locals[inst.argval] = self.pop()

Offset 18, LOAD_FAST: 再次回到 step() 时,因为 self.block_stackself.stack 都为空,会创建当前状态的 checkpoint。随后进入 LOAD_FAST,需要加载的变量名由 inst.argval 指定,即 b。从 self.symbolic_locals 中获取 TensorVariable(b) 并压栈,栈上的内容为 [TensorVariable(b)]

Offset 20, LOAD_ATTR: 这与 Offset 4LOAD_ATTR 类似,先出栈一个元素,即 TensorVariable(b),然后为 getattr 创建了 BuiltinVariable、为 inst.argval(即 sum) 创建了 ConstantVariable,转入 BuiltinVariable.call_function()对应着在 TorchDynamo 中执行类似 Python 的 getattr(b, "sum") 的功能。其中进行基本的检查后,它通过 VariableTracker.propagate() 收集了输入参数中的所有 Guard

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L548
handler = getattr(self, f"call_{self.fn.__name__}", None)
if handler:
    try:
        inspect.signature(handler).bind(tx, *args, **kwargs)
    except TypeError as exc:
        # ...

if handler:
    try:
        result = handler(tx, *args, **kwargs)
        if result is not None:
            return result.add_options(options)
    except Unsupported as exc:
        if not has_constant_handler:
            raise
        # Actually, we will handle this just fine
        exc.remove_from_stats()

builtin.py#L548 获取了 call_getattr 属性,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs) 检查是否能够成功将参数绑定到 call_getattr 上。builtin.py#L561 通过 result = handler(tx, *args, **kwargs) 调用 BuiltinVariable.call_getattr(),其中再次通过 VariableTracker.propagate() 收集了输入参数的 Guard,此处的 obj(即 b) 是 TensorVariable,在 builtin.py#L1006 进入 TensorVariable.var_getattr(),但在 tensor.py#L215 抛出 NotImplementedError 异常,它被 builtin.py#L1009 捕获,进而为 b.sum 创建 GetAttrVariablebuiltin.py#L563 把输入中的 Guard 附加在输出上。最后返回到 LOAD_ATTR,并将结果压栈,此时栈上的内容为 [GetAttrVariable(TensorVariable(b),sum)]

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1004
try:
    return (
        obj.var_getattr(tx, name).clone(source=source).add_options(options)
    )
except NotImplementedError:
    return GetAttrVariable(obj, name, **options)

对应的函数调用栈为:

Call Function Link to heading

Offset 22, CALL_FUNCTION: 与 Offset 8 处的 CALL_FUNCTION 相似,经 wrapper() 来到 InstructionTranslatorBase.CALL_FUNCTION(),出栈函数参数和函数后转到 InstructionTranslatorBase.call_function(),因为 fnGetAttrVariable,因此执行 fn.call_function(self, args, kwargs) 会转到 GetAttrVariable.call_function()

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L988
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
    args = self.popn(inst.argval)
    fn = self.pop()
    self.call_function(fn, args, {})

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L469
def call_function(
    self,
    fn: VariableTracker,
    args: List[VariableTracker],
    kwargs: Dict[str, VariableTracker],
):
    # ...
    self.push(fn.call_function(self, args, kwargs))

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/misc.py#L581
def call_function(
    self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# ...
return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

misc.py#L666self.obj.call_method(tx, self.name, ...) 中的 self.obj 的类型是 TensorVariableself.namesum,因此来到 TensorVariable.call_method()TensorVariable.call_method() 针对 PyTorch Tensor 下的许多方法进行了特殊处理,这里函数名 sum 未能成功匹配该函数中的大多数逻辑,最终来到 tensor.py#L422,此处为 sum 函数创建 torch.fx.Proxy,类型为 call_method,输入节点为 Proxy(b),目标为 sum(),并以此 Proxy 通过 wrap_fx_proxy() 创建 VariableTracker,并附着从输入 VariableTracker 中收集到的 Guard

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/tensor.py#L238
def call_method(
    self,
    tx,
    name,
    args: "List[VariableTracker]",
    kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
    # ...
    else:
        # ...
        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )

此前已经分析过 wrap_fx_proxy() 中的大部分内容,builder.py#L900 处通过 example_value = get_fake_value(proxy.node, tx) 从 FX Node 中创建 FakeTensor、并以 FakeTensor 运行该节点所代表的算子,实现在 get_fake_value()在算子运行前,TorchDynamo 打开了 Python Dispatcher,它可以在 Python 层面捕获到 PyTorch 最底层的 ATen 算子,详见 TorchDispatchMode

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1152
try:
    with tx.fake_mode, enable_python_dispatcher():
        return wrap_fake_exception(
            lambda: run_node(tx.output, node, args, kwargs, nnmodule)
        )
except Unsupported:
    raise
except RuntimeError as e:
    # ...

FakeTensor 运行算子实现在 run_node(),此时 opcall_methodgetattr(args[0], node.target)(*args[1:], **kwargs)getattr(b, 'sum')(),也就是执行 b.sum(),因为 Tensor.sum() 是实现在 C++ 中的函数,这会进入到函数 THPVariable_sum() 中,它是 PyTorch 在编译过程中生成的,生成后的 THPVariable_sum() 位于 torch/csrc/autograd/generated/python_variable_methods.cpp#L15065

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1178
def run_node(output_graph, node, args, kwargs, nnmodule):
    op = node.op
    try:
        if op == "call_function":
            return node.target(*args, **kwargs)
        elif op == "call_method":
            return getattr(args[0], node.target)(*args[1:], **kwargs)
        elif op == "call_module":
            assert nnmodule is not None
            return nnmodule(*args, **kwargs)
        elif op == "get_attr":
            return output_graph.get_submodule(node.target)
        elif op == "placeholder":
            assert "example_value" in node.meta
            return node.meta["example_value"]
    except Exception as e:
        raise RuntimeError(
            f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)"
        ) from e
    raise AssertionError(op)

此处的函数调用栈为:

回到 builder.py#L900,通过 specialize() 特化张量的信息,执行 return target_cls(proxy, **options) 创建 TensorVariable,它是 VariableTracker 的子类,用于记录 PyTorch 中的 torch.Tensor 类型,实现在 tensor.py#L69

因此,在为计算节点的输出创建 VariableTracker 的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor,以 FakeTensor 为输入执行了算子,初始化了 VariableTracker

最后回到 symbolic_convert.py#L494,把新创建的 TensorVariable 压栈,此时栈上的内容为 [TensorVariable(b.sum())]

Offset 24, LOAD_CONST: 加载常量 0,创建 ConstantVariable 并压栈,此时栈上的内容为 [TensorVariable(b.sum()), ConstantVariable(0)]

Offset 26, COMPARE_OP:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L925
def COMPARE_OP(self, inst):
    left, right = self.popn(2)
    left = left.as_specialized(self)
    right = right.as_specialized(self)
    options = VariableTracker.propagate([left, right])
    op = inst.argval
    # ...
    else:
        self.push(
            BuiltinVariable(supported_any[op], **options).call_function(
                self, [left, right], {}
            )
        )

出栈要比较的两个元素 leftright,调用 VariableTracker.propagate() 收集输入中的 Guardinst.argval 是小于号 <symbolic_convert.py#L980 处的 supported_tensors[op]operator.lt,此处首先创建 BuiltinVariable,在 BuiltinVariable.call_function()result = handler(tx, *args, **kwargs)_comparison():

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1178-L1186
if isinstance(left, TensorVariable):
    from .builder import wrap_fx_proxy

    if op not in supported_tensor_comparison_ops.values():
        _unimplemented()
    return wrap_fx_proxy(
        tx,
        op(left.as_proxy(), right.as_proxy()),
    )

这里的 opoperator.lt,在 FX 初始化阶段被修改为 impl(),此处创建新的 FX Proxy,opcall_functiontargetoperator.ltwrap_fx_proxy() 的功能与前面一致,从 Proxy 中创建了 FakeTensor,以 FakeTensor 为输入执行算子 THPVariable_lt(),创建新的 TensorVariable。最后 symbolic_convert.py#L979 把新的 TensorVariable 压栈,此时栈上的内容为 [TensorVariable(lt)]

这个过程的函数调用栈为: