初始化 InstructionTranslator
Link to heading
清理后的字节码指令经由 transformations(instructions, code_options) 开始执行变换,transformations 是之前的 transform(),其中首先实例化了 InstructionTranslator,定义于 torch/_dynamo/symbolic_convert.py#L1747。InstructionTranslator 中有一个 OutputGraph 的实例,OutputGraph 本身是一个 torch.fx.Tracer, OutputGraph 用于保存 InstructionTranslator 做 字节码翻译 后的输出,以 torch.fx.Graph 表达。OutputGraph 中还包含一个 SideEffects 的实例,它用于跟踪带有副作用 (side effect) 的操作,比如修改列表、setattr()。此处的函数调用栈为:
- [P049] > torch/_dynamo/side_effects.py#L76 [New]
- [P045] > torch/_dynamo/output_graph.py#L171 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1747 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
InstructionTranslator 继承自 InstructionTranslatorBase,InstructionTranslatorBase 的本质是一个 Python 虚拟机的模拟器,它记录了待编译 Python 函数 栈帧 (Frame) 中的 运行时信息,类型是 PyFrameObject,包含了待编译函数在运行时的全局变量 f_globals、局部变量 f_locals、代码对象 f_code、预处理后的字节码指令,TorchDyanmo 在 torch/_dynamo/convert_frame.py#L264-L267 通过 frame 获得这些属性:
return _compile(
frame.f_code,
frame.f_globals,
frame.f_locals,
frame.f_builtins,
compiler_fn,
one_graph,
export,
hooks,
frame,
)
除此之外,InstructionTranslatorBase 还包含了:
instruction_pointer: Python 虚拟机的 PC,表明当前正在执行的字节码指令所处位置;stack: Python 虚拟机的 数据栈,因为 Python 虚拟机是 Stack Machine,而不是类似 X86 的 Register Machine,所以 Python 虚拟机中字节码之间通过数据栈交换数据;block_stack: Python 虚拟机的 block stack,用于记录一些特殊的块结构,比如循环、上下文 (with);
在 torch/_dynamo/symbolic_convert.py#L1781-L1796,InstructionTranslator 通过 Code 对象的 co_varnames、co_cellvars、co_freevars 3 个字段获取 待编译函数中用到的变量名,其中 co_varnames 是局部变量名,co_cellvars 和 co_freevars 是闭包中用到的变量名。对于案例 toy_example(),用到的变量名是 ('a', 'b', 'x')。
vars = list(code_options["co_varnames"])
vars.extend(x for x in self.cell_and_freevars() if x not in vars)
self.symbolic_locals = collections.OrderedDict(
(
k,
VariableBuilder(
self,
LocalInputSource(k, code_options["co_varnames"].index(k))
if k in code_options["co_varnames"]
else LocalSource((k)),
)(f_locals[k]),
)
for k in vars
if k in f_locals
)
在初始化 InstructionTranslator 的过程中,它为待编译函数中每一个被用到的变量创建一个 VariableTracker,作为 symbolic_locals。TorchDynamo 捕获计算图的过程通过字节码翻译,而不单单是传统意义上的 tracing,例如 PyTorch 中的 JIT tracing 或者 FX 的 symbolic tracing。因此,TorchDynamo 需要有一个 类型系统,用于记录每个 Python 变量对应的类型信息,VariableTracker 的功能正是如此。
这个过程通过 VariableBuilder 实现,主要逻辑实现于builder.py#L745-L756:
tensor_proxy = self.tx.output.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
)
tensor_variable = wrap_fx_proxy(
tx=self.tx,
proxy=tensor_proxy,
example_value=value,
guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
should_specialize=self.tensor_should_specialize(),
ignore_subclass=ignore_subclass,
source=self.get_source(),
)
其中通过 create_graph_input() 在 FX Graph 中创建了类型为 placeholder 的 FX Proxy,Proxy 是 FX symbloic tracing 中的 symbol。make_guards() 创建了类型为 GuardBuilder.TENSOR_MATCH 的 Guard。Guard 在 TorchDynamo 中负责检测被编译函数所引用外部数据的信息是否发生变化,如果没有发生变化则可以重用此前编译好的函数,否则当前输入对此前编译好的函数无效,需要重新编译该函数。在默认情况下,TENSOR_MATCH 主要负责检查输入的张量 shape、stride 等信息是否改变。
wrap_fx_proxy() 为刚创建的 FX Proxy 建立对应的 VariableTracker,核心逻辑在 wrap_fx_proxy_cls()。其中先通过 wrap_to_fake_tensor_and_record() 为运行时获得的 torch.Tensor 创建 FakeTensor。默认情况下,TorchDynamo 使用 FakeTensor 捕获计算图而不是真实的 torch.Tensor,FakeTensor 和真实的 torch.Tensor 有相同的张量信息,但没有张量内存分配。随后通过 specialize() 特化张量的信息,包括 dtype、device、layout 等。在 static shape 模式下,还会特化 size、stride、is_contiguous 这些信息,而 dynamic shape 模式下则不会特化它们。最后,wrap_fx_proxy_cls() 通过 tensor.py#L69 创建 TensorVariable,它是 VariableTracker 的子类,用于记录 PyTorch 中的 torch.Tensor 类型。
可见,在创建 VariableTracker 的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor,初始化了 VariableTracker。由于并不是所有的局部变量都会被 frame 用到,为了减少不必要的 Guard 带来的运行时开销,TorchDynamo 在逐条运行 Python 字节码的过程中有选择性的更新 FX graph 所需要的 Guard。创建 VariableTracker 时的调用栈如下:
- [P089] > torch/_dynamo/variables/tensor.py#L69
- [P085] > torch/_dynamo/variables/base.py#L26
- [P081] > torch/_dynamo/variables/base.py#L61
- [P077] > torch/_dynamo/variables/base.py#L72
- [P073] > torch/_dynamo/variables/base.py#L270
- [P069] > torch/_dynamo/variables/base.py#L26
- [P065] > torch/_dynamo/variables/builder.py#L876 [New]
- [P061] > torch/_dynamo/variables/builder.py#L864 [New]
- [P057] > torch/_dynamo/variables/builder.py#L704 [New]
- [P053] > torch/_dynamo/variables/builder.py#L294 [New]
- [P049] > torch/_dynamo/variables/builder.py#L170 [New]
- [P045] > torch/_dynamo/symbolic_convert.py#L1784 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1747 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
不断重复 symbolic_convert.py#L1784-L1796,最终我们得到 symbolic_locals,它包含了待编译函数中所有引用变量的 VariableTracker。至此,InstructionTranslator 初始化完毕,返回到 transform() 中。
字节码翻译 Link to heading
在 transform() 中,TorchDynamo 的计算图捕获逻辑从此处正式开始:
with tracing(tracer.output.tracing_context):
tracer.run()
实现在 InstructionTranslatorBase 中的 run() 包含一个 while 循环,每次调用 step() 处理一条字节码指令:
while (
self.instruction_pointer is not None
and not self.output.should_exit
and self.step()
):
pass
step() 用 instruction_pointer 获取当前 step 要处理的字节码指令,并把 instruction_pointer 加一。在数据栈为空、且应该编译部分图 (partial graph) 的条件下,InstructionTranslatorBase 会备份当前状态为 checkpoint,以便以后用于恢复。应不应该编译部分图,取决于当前 block stack 中的所有条目是否都可以恢复、并且用户没有通过 one_graph 或 nopython 指定全图编译。
此时的函数调用栈为:
- [P049] > torch/_dynamo/symbolic_convert.py#L537 [New]
- [P045] > torch/_dynamo/symbolic_convert.py#L590 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1838 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
逐条翻译字节码实现在 symbolic_convert.py#L560 处的 getattr(self, inst.opname)(inst),直到碰见异常或者 RETURN_VALUE 指令:
try:
if not hasattr(self, inst.opname):
unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
return inst.opname != "RETURN_VALUE"
except BackendCompilerFailed:
raise
except Unsupported as exc:
exc.real_stack.append(self.frame_summary())
if self.empty_checkpoint():
raise
log.debug("step triggered compile", exc_info=True)
except Exception as exc:
real_stack = getattr(exc, "real_stack", [])
real_stack.append(self.frame_summary())
exc.real_stack = real_stack # type: ignore[attr-defined]
raise
对于案例中的 toy_example(),TorchDynamo 初始化阶段处理过后的字节码如下:
14 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_ATTR 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_FUNCTION 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
15 18 LOAD_FAST 1 (b)
20 LOAD_ATTR 2 (sum)
22 CALL_FUNCTION 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
16 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
17 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
它们依次在 InstructionTranslatorBase 中被翻译,并随之建立 FX Graph。对其进行 字节码翻译 的过程如下:
Offset 0, LOAD_FAST: inst.argval 为 a,从 symbolic_locals 取出 变量 a 的 TensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(a)]。
def LOAD_FAST(self, inst):
name = inst.argval
# ...
self.push(self.symbolic_locals[name])
# ...
Offset 2, LOAD_GLOBAL: inst.argval 为 torch,它从 f_globals 中取出库 torch,调用 VariableBuilder(self, source)(value) 在 builder.py#L391-L395 创建了 TorchVariable,含义是 PyTorch 中的某个 package,并创建了 FUNCTION_MATCH 类型的 Guard。最后把该 TorchVariable 入栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch)]。
def LOAD_GLOBAL(self, inst):
# ...
name = inst.argval
# ...
try:
value = self.f_globals[name]
except KeyError:
return self.load_builtin(inst)
source = self.get_global_source(name)
self.push(VariableBuilder(self, source)(value))
Offset 4, LOAD_ATTR: 先出栈一个元素,即 TorchVariable(torch),然后为 getattr 创建了 BuiltinVariable、为 inst.argval(即 abs) 创建了 ConstantVariable,转入 call_function(),对应着在 TorchDynamo 中执行类似 Python 的 getattr(torch, abs) 的功能。
def LOAD_ATTR(self, inst):
obj = self.pop()
result = BuiltinVariable(getattr).call_function(
self, [obj, ConstantVariable(inst.argval)], {}
)
self.push(result)
其中进行基本的检查后,它通过 propagate() 收集输入中的所有 Guard,builtin.py#L548,这里获取 call_getattr 属性成功,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs) 检查是否能够成功将参数绑定到 call_getattr 上,再通过 result = handler(tx, *args, **kwargs) 调用 call_getattr()。
call_getattr() 通过 propagate() 收集了输入参数的 Guard,因为此处的 obj(即 torch) 是 TorchVariable,所以在 builtin.py#L1010-L1013 为 torch.abs 创建了新的 TorchVariable,注意此时附加在 TorchVariable(torch) 上的 Guard 已经被传播(propagete)到了 TorchVariable(torch.abs),并且记录了来源 AttrSource(base=GlobalSource(global_name='torch'), member='abs')。随后一路返回到 LOAD_ATTR,并将结果压栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs)]。
elif isinstance(obj, TorchVariable):
member = getattr(obj.value, name)
if is_allowed(member):
return TorchVariable(member, **options)
Offset 6, LOAD_FAST: 再次加载变量 a 的 TensorVariable,随后栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs), TensorVariable(a)]。
Offset 8, CALL_FUNCTION: 因为 CALL_FUNCTION 被装饰器 break_graph_if_unsupported() 装饰,所以执行 CALL_FUNCTION() 会先经过其中的 wrapper()。这里首先创建 checkpoint,保存了所有的状态,以便在后面出现异常时从 checkpoint 中恢复。
state = self.copy_graphstate()
reason = None
try:
return inner_fn(self, inst)
except Unsupported as excp:
# ...
self.restore_graphstate(state)
然后执行 inner_fn(self, inst),inner_fn 就是 CALL_FUNCTION(),其中先出栈 N 个元素作为函数参数,N 由 inst.argval 指定,这里是 1,然后再出栈 1 个元素作为函数,通过 InstructionTranslatorBase.call_function() 进行函数调用:
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
args = self.popn(inst.argval)
fn = self.pop()
self.call_function(fn, args, {})
InstructionTranslatorBase.call_function() 调用 TorchVariable.call_function(),TorchDynamo 在此处模拟执行 torch.abs(a)。首先用 propagate() 收集所有参数中的 Guard,然后匹配到 torch.py##L565-L573,此处 proxy_args_kwargs() 从 TensorVariable(a) 获取 torch.fx.Proxy(a),它是在初始化 symbolic_locals 时创建的,然后通过 create_proxy() 创建了新的 Proxy,类型是 call_function,目标是 torch.abs,参数是 a。
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
**options,
)
最后通过 wrap_fx_proxy 创建了新的 TensorVariable 来保存结果,收集到的 Guard 信息也附加了上去,一路返回后并在 call_function() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a))]。可见,TorchDynamo 在字节码分析过程中没有真正地执行指令,而是以符号分析的方式从字节码中逐步提取用到的符号和函数,为它们创建 Proxy 并添加到 FX graph 中。
对应的函数调用栈为:
- [P069] > torch/_dynamo/utils.py#L428 [New]
- [P065] > torch/_dynamo/variables/torch.py#L181 [New]
- [P061] > torch/_dynamo/symbolic_convert.py#L469 [New]
- [P057] > torch/_dynamo/symbolic_convert.py#L988 [New]
- [P053] > torch/_dynamo/symbolic_convert.py#L341 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Offset 10, LOAD_CONST: 加载常量 1,为其创建 ConstantVariable 并压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)), ConstantVariable(1)]。
def LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))
Offset 12, BINARY_ADD: InstructionTranslatorBase 对常见的 Python 内建函数用 stack_op 做了封装和转发,其中 BINARY_ADD = stack_op(operator.add):
BINARY_REMAINDER = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)
所以 symbolic_convert.py#L560 在执行 getattr(self, 'BINARY_ADD')(inst) 时,会转到 impl():
def stack_op(fn: typing.Callable[..., object]):
nargs = len(inspect.signature(fn).parameters)
fn_var = BuiltinVariable(fn)
@functools.wraps(fn)
def impl(self: "InstructionTranslatorBase", inst: Instruction):
self.push(fn_var.call_function(self, self.popn(nargs), {}))
return impl
其中,fn_var 是创建闭包时创建的 BuiltinVariable,需要出栈的参数个数由 inspect.signature(fn) 确定,对于 operator.add 来说需要 2 个参数,因此出栈 TensorVariable(torch.abs(a)) 和 ConstantVariable(1),随后转到 BuiltinVariable.call_function(),这与 Offset 4 处的 LOAD_ATTR 调用的是同一个函数。其中调用 propagate() 从输入的 VariableTracker 中收集 Guard,分别是针对变量 a 的 TENSOR_MATCH 和针对 torch 的 FUNCTION_MATCH。
proxy = tx.output.create_proxy(
"call_function",
fn,
*proxy_args_kwargs(args, kwargs),
)
builtin.py#L476-L480 处先调用 proxy_args_kwargs() 获取输入参数的 Proxy,然后在 FX Graph 中创建输出节点对应的 Proxy,通过 create_proxy() 创建了类型是 call_function 的 Proxy,目标是 operator.add,参数是 Proxy(a) 和 1。
最后用 wrap_fx_proxy() 创建了新的 TensorVariable 来保存结果,收集到的 Guard 信息也附加了上去,一路返回后在 impl() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)+1)]。
- [P061] > torch/_dynamo/variables/builder.py#L864
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L148 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Offset 14, BINARY_TRUE_DIVIDE: 与 Offset 12 的 BINARY_ADD 类似,BINARY_TRUE_DIVIDE 在 InstructionTranslatorBase 中被设为 stack_op(operator.truediv),随后到代码路径与 BINARY_ADD 一致。出栈 [TensorVariable(a) 和 TensorVariable(torch.abs(a)+1),为输出创建新的 Proxy,类型是 call_function,目标是 operator.truediv,参数是 Proxy(a) 和 Proxy(add)。再为其创建了新的 TensorVariable 来跟踪收集到的 Guard,最后把结果压栈,栈上的内容为 [TensorVariable(a/(torch.abs(a)+1))]。
BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
Offset 16, STORE_FAST: 出栈最后一个元素 TensorVariable(a/(torch.abs(a)+1)),并把它存在 self.symbolic_locals[inst.argval] 中,这里 self.symbolic_locals 跟踪当前 frame 的局部变量,变量名由 inst.argval 指定,即 x。此后栈上内容为空。
def STORE_FAST(self, inst):
self.symbolic_locals[inst.argval] = self.pop()
Offset 18, LOAD_FAST: 再次回到 step() 时,因为 self.block_stack 和 self.stack 都为空,会创建当前状态的 checkpoint。随后进入 LOAD_FAST,需要加载的变量名由 inst.argval 指定,即 b。从 self.symbolic_locals 中获取 TensorVariable(b) 并压栈,栈上的内容为 [TensorVariable(b)]。
Offset 20, LOAD_ATTR: 这与 Offset 4 的 LOAD_ATTR 类似,先出栈一个元素,即 TensorVariable(b),然后为 getattr 创建了 BuiltinVariable、为 inst.argval(即 sum) 创建了 ConstantVariable,转入 BuiltinVariable.call_function(),对应着在 TorchDynamo 中执行类似 Python 的 getattr(b, "sum") 的功能。其中进行基本的检查后,它通过 VariableTracker.propagate() 收集了输入参数中的所有 Guard。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L548
handler = getattr(self, f"call_{self.fn.__name__}", None)
if handler:
try:
inspect.signature(handler).bind(tx, *args, **kwargs)
except TypeError as exc:
# ...
if handler:
try:
result = handler(tx, *args, **kwargs)
if result is not None:
return result.add_options(options)
except Unsupported as exc:
if not has_constant_handler:
raise
# Actually, we will handle this just fine
exc.remove_from_stats()
builtin.py#L548 获取了 call_getattr 属性,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs) 检查是否能够成功将参数绑定到 call_getattr 上。builtin.py#L561 通过 result = handler(tx, *args, **kwargs) 调用 BuiltinVariable.call_getattr(),其中再次通过 VariableTracker.propagate() 收集了输入参数的 Guard,此处的 obj(即 b) 是 TensorVariable,在 builtin.py#L1006 进入 TensorVariable.var_getattr(),但在 tensor.py#L215 抛出 NotImplementedError 异常,它被 builtin.py#L1009 捕获,进而为 b.sum 创建 GetAttrVariable,builtin.py#L563 把输入中的 Guard 附加在输出上。最后返回到 LOAD_ATTR,并将结果压栈,此时栈上的内容为 [GetAttrVariable(TensorVariable(b),sum)]。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1004
try:
return (
obj.var_getattr(tx, name).clone(source=source).add_options(options)
)
except NotImplementedError:
return GetAttrVariable(obj, name, **options)
对应的函数调用栈为:
- [P069] > torch/_dynamo/variables/misc.py#L549 [New]
- [P065] > torch/_dynamo/variables/base.py#L26
- [P061] > torch/_dynamo/variables/builtin.py#L938
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L1063
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Call Function Link to heading
Offset 22, CALL_FUNCTION: 与 Offset 8 处的 CALL_FUNCTION 相似,经 wrapper() 来到 InstructionTranslatorBase.CALL_FUNCTION(),出栈函数参数和函数后转到 InstructionTranslatorBase.call_function(),因为 fn 是 GetAttrVariable,因此执行 fn.call_function(self, args, kwargs) 会转到 GetAttrVariable.call_function()。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L988
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
args = self.popn(inst.argval)
fn = self.pop()
self.call_function(fn, args, {})
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L469
def call_function(
self,
fn: VariableTracker,
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
):
# ...
self.push(fn.call_function(self, args, kwargs))
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/misc.py#L581
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# ...
return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
在 misc.py#L666,self.obj.call_method(tx, self.name, ...) 中的 self.obj 的类型是 TensorVariable、self.name 是 sum,因此来到 TensorVariable.call_method()。TensorVariable.call_method() 针对 PyTorch Tensor 下的许多方法进行了特殊处理,这里函数名 sum 未能成功匹配该函数中的大多数逻辑,最终来到 tensor.py#L422,此处为 sum 函数创建 torch.fx.Proxy,类型为 call_method,输入节点为 Proxy(b),目标为 sum(),并以此 Proxy 通过 wrap_fx_proxy() 创建 VariableTracker,并附着从输入 VariableTracker 中收集到的 Guard。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/tensor.py#L238
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# ...
else:
# ...
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
此前已经分析过 wrap_fx_proxy() 中的大部分内容,builder.py#L900 处通过 example_value = get_fake_value(proxy.node, tx) 从 FX Node 中创建 FakeTensor、并以 FakeTensor 运行该节点所代表的算子,实现在 get_fake_value()。在算子运行前,TorchDynamo 打开了 Python Dispatcher,它可以在 Python 层面捕获到 PyTorch 最底层的 ATen 算子,详见 TorchDispatchMode。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1152
try:
with tx.fake_mode, enable_python_dispatcher():
return wrap_fake_exception(
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
)
except Unsupported:
raise
except RuntimeError as e:
# ...
以 FakeTensor 运行算子实现在 run_node(),此时 op 为 call_method,getattr(args[0], node.target)(*args[1:], **kwargs) 是 getattr(b, 'sum')(),也就是执行 b.sum(),因为 Tensor.sum() 是实现在 C++ 中的函数,这会进入到函数 THPVariable_sum() 中,它是 PyTorch 在编译过程中生成的,生成后的 THPVariable_sum() 位于 torch/csrc/autograd/generated/python_variable_methods.cpp#L15065。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1178
def run_node(output_graph, node, args, kwargs, nnmodule):
op = node.op
try:
if op == "call_function":
return node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return output_graph.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except Exception as e:
raise RuntimeError(
f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)"
) from e
raise AssertionError(op)
此处的函数调用栈为:
- [C098] > build/aten/src/ATen/Operators_2.cpp#L5915:call [New]
- [C097] > build/aten/src/ATen/Operators_2.cpp#L5912:call [New]
- [C096] > build/aten/src/ATen/core/TensorBody.h#L3574:sum [New]
- [C095] > torch/csrc/autograd/generated/python_variable_methods.cpp#L15083:operator() [New]
- [C094] > torch/csrc/autograd/generated/python_variable_methods.cpp#L15065:THPVariable_sum [New]
- [P093] > torch/_dynamo/utils.py#L1178
- [P089] > torch/_dynamo/utils.py#L1155
- [P085] > torch/_dynamo/utils.py#L806
- [P081] > torch/_dynamo/utils.py#L1120
- [P077] > torch/_dynamo/variables/builder.py#L876
- [P073] > torch/_dynamo/variables/builder.py#L864
- [P069] > torch/_dynamo/variables/tensor.py#L238 [New]
- [P065] > torch/_dynamo/variables/misc.py#L581 [New]
- [P061] > torch/_dynamo/symbolic_convert.py#L469
- [P057] > torch/_dynamo/symbolic_convert.py#L988
- [P053] > torch/_dynamo/symbolic_convert.py#L341
- [P049] > torch/_dynamo/symbolic_convert.py#L537
回到 builder.py#L900,通过 specialize() 特化张量的信息,执行 return target_cls(proxy, **options) 创建 TensorVariable,它是 VariableTracker 的子类,用于记录 PyTorch 中的 torch.Tensor 类型,实现在 tensor.py#L69。
因此,在为计算节点的输出创建 VariableTracker 的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor,以 FakeTensor 为输入执行了算子,初始化了 VariableTracker。
最后回到 symbolic_convert.py#L494,把新创建的 TensorVariable 压栈,此时栈上的内容为 [TensorVariable(b.sum())]。
Offset 24, LOAD_CONST: 加载常量 0,创建 ConstantVariable 并压栈,此时栈上的内容为 [TensorVariable(b.sum()), ConstantVariable(0)]。
Offset 26, COMPARE_OP:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L925
def COMPARE_OP(self, inst):
left, right = self.popn(2)
left = left.as_specialized(self)
right = right.as_specialized(self)
options = VariableTracker.propagate([left, right])
op = inst.argval
# ...
else:
self.push(
BuiltinVariable(supported_any[op], **options).call_function(
self, [left, right], {}
)
)
出栈要比较的两个元素 left 和 right,调用 VariableTracker.propagate() 收集输入中的 Guard,inst.argval 是小于号 <。symbolic_convert.py#L980 处的 supported_tensors[op] 是 operator.lt,此处首先创建 BuiltinVariable,在 BuiltinVariable.call_function() 经 result = handler(tx, *args, **kwargs) 到 _comparison():
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1178-L1186
if isinstance(left, TensorVariable):
from .builder import wrap_fx_proxy
if op not in supported_tensor_comparison_ops.values():
_unimplemented()
return wrap_fx_proxy(
tx,
op(left.as_proxy(), right.as_proxy()),
)
这里的 op 为 operator.lt,在 FX 初始化阶段被修改为 impl(),此处创建新的 FX Proxy,op 为 call_function,target 为 operator.lt。wrap_fx_proxy() 的功能与前面一致,从 Proxy 中创建了 FakeTensor,以 FakeTensor 为输入执行算子 THPVariable_lt(),创建新的 TensorVariable。最后 symbolic_convert.py#L979 把新的 TensorVariable 压栈,此时栈上的内容为 [TensorVariable(lt)]。
这个过程的函数调用栈为:
- [P065] > torch/_dynamo/variables/builder.py#L864
- [P061] > torch/_dynamo/variables/builtin.py#L1135 [New]
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L925 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537