初始化 InstructionTranslator
Link to heading
清理后的字节码指令经由 transformations(instructions, code_options)
开始执行变换,transformations
是之前的 transform(),其中首先实例化了 InstructionTranslator
,定义于 torch/_dynamo/symbolic_convert.py#L1747。InstructionTranslator
中有一个 OutputGraph 的实例,OutputGraph
本身是一个 torch.fx.Tracer, OutputGraph
用于保存 InstructionTranslator
做 字节码翻译 后的输出,以 torch.fx.Graph 表达。OutputGraph
中还包含一个 SideEffects 的实例,它用于跟踪带有副作用 (side effect) 的操作,比如修改列表、setattr()
。此处的函数调用栈为:
- [P049] > torch/_dynamo/side_effects.py#L76 [New]
- [P045] > torch/_dynamo/output_graph.py#L171 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1747 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
InstructionTranslator
继承自 InstructionTranslatorBase,InstructionTranslatorBase
的本质是一个 Python 虚拟机的模拟器,它记录了待编译 Python 函数 栈帧 (Frame) 中的 运行时信息,类型是 PyFrameObject,包含了待编译函数在运行时的全局变量 f_globals
、局部变量 f_locals
、代码对象 f_code
、预处理后的字节码指令,TorchDyanmo 在 torch/_dynamo/convert_frame.py#L264-L267 通过 frame
获得这些属性:
return _compile(
frame.f_code,
frame.f_globals,
frame.f_locals,
frame.f_builtins,
compiler_fn,
one_graph,
export,
hooks,
frame,
)
除此之外,InstructionTranslatorBase
还包含了:
instruction_pointer
: Python 虚拟机的 PC,表明当前正在执行的字节码指令所处位置;stack
: Python 虚拟机的 数据栈,因为 Python 虚拟机是 Stack Machine,而不是类似 X86 的 Register Machine,所以 Python 虚拟机中字节码之间通过数据栈交换数据;block_stack
: Python 虚拟机的 block stack,用于记录一些特殊的块结构,比如循环、上下文 (with
);
在 torch/_dynamo/symbolic_convert.py#L1781-L1796,InstructionTranslator
通过 Code 对象的 co_varnames
、co_cellvars
、co_freevars
3 个字段获取 待编译函数中用到的变量名,其中 co_varnames
是局部变量名,co_cellvars
和 co_freevars
是闭包中用到的变量名。对于案例 toy_example()
,用到的变量名是 ('a', 'b', 'x')
。
vars = list(code_options["co_varnames"])
vars.extend(x for x in self.cell_and_freevars() if x not in vars)
self.symbolic_locals = collections.OrderedDict(
(
k,
VariableBuilder(
self,
LocalInputSource(k, code_options["co_varnames"].index(k))
if k in code_options["co_varnames"]
else LocalSource((k)),
)(f_locals[k]),
)
for k in vars
if k in f_locals
)
在初始化 InstructionTranslator
的过程中,它为待编译函数中每一个被用到的变量创建一个 VariableTracker
,作为 symbolic_locals
。TorchDynamo 捕获计算图的过程通过字节码翻译,而不单单是传统意义上的 tracing,例如 PyTorch 中的 JIT tracing 或者 FX 的 symbolic tracing。因此,TorchDynamo 需要有一个 类型系统,用于记录每个 Python 变量对应的类型信息,VariableTracker
的功能正是如此。
这个过程通过 VariableBuilder 实现,主要逻辑实现于builder.py#L745-L756:
tensor_proxy = self.tx.output.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
)
tensor_variable = wrap_fx_proxy(
tx=self.tx,
proxy=tensor_proxy,
example_value=value,
guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
should_specialize=self.tensor_should_specialize(),
ignore_subclass=ignore_subclass,
source=self.get_source(),
)
其中通过 create_graph_input() 在 FX Graph 中创建了类型为 placeholder
的 FX Proxy,Proxy 是 FX symbloic tracing 中的 symbol。make_guards() 创建了类型为 GuardBuilder.TENSOR_MATCH
的 Guard。Guard
在 TorchDynamo 中负责检测被编译函数所引用外部数据的信息是否发生变化,如果没有发生变化则可以重用此前编译好的函数,否则当前输入对此前编译好的函数无效,需要重新编译该函数。在默认情况下,TENSOR_MATCH
主要负责检查输入的张量 shape、stride 等信息是否改变。
wrap_fx_proxy() 为刚创建的 FX Proxy 建立对应的 VariableTracker
,核心逻辑在 wrap_fx_proxy_cls()。其中先通过 wrap_to_fake_tensor_and_record() 为运行时获得的 torch.Tensor
创建 FakeTensor。默认情况下,TorchDynamo 使用 FakeTensor
捕获计算图而不是真实的 torch.Tensor
,FakeTensor
和真实的 torch.Tensor
有相同的张量信息,但没有张量内存分配。随后通过 specialize() 特化张量的信息,包括 dtype
、device
、layout
等。在 static shape 模式下,还会特化 size
、stride
、is_contiguous
这些信息,而 dynamic shape 模式下则不会特化它们。最后,wrap_fx_proxy_cls()
通过 tensor.py#L69 创建 TensorVariable
,它是 VariableTracker
的子类,用于记录 PyTorch 中的 torch.Tensor
类型。
可见,在创建 VariableTracker
的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor
,初始化了 VariableTracker
。由于并不是所有的局部变量都会被 frame 用到,为了减少不必要的 Guard 带来的运行时开销,TorchDynamo 在逐条运行 Python 字节码的过程中有选择性的更新 FX graph 所需要的 Guard
。创建 VariableTracker
时的调用栈如下:
- [P089] > torch/_dynamo/variables/tensor.py#L69
- [P085] > torch/_dynamo/variables/base.py#L26
- [P081] > torch/_dynamo/variables/base.py#L61
- [P077] > torch/_dynamo/variables/base.py#L72
- [P073] > torch/_dynamo/variables/base.py#L270
- [P069] > torch/_dynamo/variables/base.py#L26
- [P065] > torch/_dynamo/variables/builder.py#L876 [New]
- [P061] > torch/_dynamo/variables/builder.py#L864 [New]
- [P057] > torch/_dynamo/variables/builder.py#L704 [New]
- [P053] > torch/_dynamo/variables/builder.py#L294 [New]
- [P049] > torch/_dynamo/variables/builder.py#L170 [New]
- [P045] > torch/_dynamo/symbolic_convert.py#L1784 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1747 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
不断重复 symbolic_convert.py#L1784-L1796,最终我们得到 symbolic_locals
,它包含了待编译函数中所有引用变量的 VariableTracker
。至此,InstructionTranslator
初始化完毕,返回到 transform() 中。
字节码翻译 Link to heading
在 transform() 中,TorchDynamo 的计算图捕获逻辑从此处正式开始:
with tracing(tracer.output.tracing_context):
tracer.run()
实现在 InstructionTranslatorBase
中的 run() 包含一个 while
循环,每次调用 step() 处理一条字节码指令:
while (
self.instruction_pointer is not None
and not self.output.should_exit
and self.step()
):
pass
step() 用 instruction_pointer
获取当前 step 要处理的字节码指令,并把 instruction_pointer
加一。在数据栈为空、且应该编译部分图 (partial graph) 的条件下,InstructionTranslatorBase
会备份当前状态为 checkpoint,以便以后用于恢复。应不应该编译部分图,取决于当前 block stack 中的所有条目是否都可以恢复、并且用户没有通过 one_graph
或 nopython
指定全图编译。
此时的函数调用栈为:
- [P049] > torch/_dynamo/symbolic_convert.py#L537 [New]
- [P045] > torch/_dynamo/symbolic_convert.py#L590 [New]
- [P041] > torch/_dynamo/symbolic_convert.py#L1838 [New]
- [P037] > torch/_dynamo/convert_frame.py#L298 [New]
- [P033] > torch/_dynamo/bytecode_transformation.py#L488 [New]
- [P029] > torch/_dynamo/convert_frame.py#L279 [New]
逐条翻译字节码实现在 symbolic_convert.py#L560 处的 getattr(self, inst.opname)(inst)
,直到碰见异常或者 RETURN_VALUE
指令:
try:
if not hasattr(self, inst.opname):
unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
return inst.opname != "RETURN_VALUE"
except BackendCompilerFailed:
raise
except Unsupported as exc:
exc.real_stack.append(self.frame_summary())
if self.empty_checkpoint():
raise
log.debug("step triggered compile", exc_info=True)
except Exception as exc:
real_stack = getattr(exc, "real_stack", [])
real_stack.append(self.frame_summary())
exc.real_stack = real_stack # type: ignore[attr-defined]
raise
对于案例中的 toy_example()
,TorchDynamo 初始化阶段处理过后的字节码如下:
14 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_ATTR 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_FUNCTION 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
15 18 LOAD_FAST 1 (b)
20 LOAD_ATTR 2 (sum)
22 CALL_FUNCTION 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38
16 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
17 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
它们依次在 InstructionTranslatorBase
中被翻译,并随之建立 FX Graph。对其进行 字节码翻译 的过程如下:
Offset 0, LOAD_FAST: inst.argval
为 a
,从 symbolic_locals
取出 变量 a
的 TensorVariable
,然后把它压到栈上,栈上的内容为 [TensorVariable(a)]
。
def LOAD_FAST(self, inst):
name = inst.argval
# ...
self.push(self.symbolic_locals[name])
# ...
Offset 2, LOAD_GLOBAL: inst.argval
为 torch
,它从 f_globals
中取出库 torch
,调用 VariableBuilder(self, source)(value)
在 builder.py#L391-L395 创建了 TorchVariable,含义是 PyTorch 中的某个 package,并创建了 FUNCTION_MATCH
类型的 Guard
。最后把该 TorchVariable
入栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch)]
。
def LOAD_GLOBAL(self, inst):
# ...
name = inst.argval
# ...
try:
value = self.f_globals[name]
except KeyError:
return self.load_builtin(inst)
source = self.get_global_source(name)
self.push(VariableBuilder(self, source)(value))
Offset 4, LOAD_ATTR: 先出栈一个元素,即 TorchVariable(torch)
,然后为 getattr
创建了 BuiltinVariable、为 inst.argval
(即 abs
) 创建了 ConstantVariable,转入 call_function(),对应着在 TorchDynamo 中执行类似 Python 的 getattr(torch, abs)
的功能。
def LOAD_ATTR(self, inst):
obj = self.pop()
result = BuiltinVariable(getattr).call_function(
self, [obj, ConstantVariable(inst.argval)], {}
)
self.push(result)
其中进行基本的检查后,它通过 propagate() 收集输入中的所有 Guard
,builtin.py#L548,这里获取 call_getattr
属性成功,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs)
检查是否能够成功将参数绑定到 call_getattr
上,再通过 result = handler(tx, *args, **kwargs)
调用 call_getattr()。
call_getattr()
通过 propagate()
收集了输入参数的 Guard
,因为此处的 obj
(即 torch
) 是 TorchVariable
,所以在 builtin.py#L1010-L1013 为 torch.abs
创建了新的 TorchVariable
,注意此时附加在 TorchVariable(torch)
上的 Guard
已经被传播(propagete)到了 TorchVariable(torch.abs)
,并且记录了来源 AttrSource(base=GlobalSource(global_name='torch'), member='abs')
。随后一路返回到 LOAD_ATTR
,并将结果压栈,此时栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs)]
。
elif isinstance(obj, TorchVariable):
member = getattr(obj.value, name)
if is_allowed(member):
return TorchVariable(member, **options)
Offset 6, LOAD_FAST: 再次加载变量 a
的 TensorVariable
,随后栈上的内容为 [TensorVariable(a), TorchVariable(torch.abs), TensorVariable(a)]
。
Offset 8, CALL_FUNCTION: 因为 CALL_FUNCTION
被装饰器 break_graph_if_unsupported() 装饰,所以执行 CALL_FUNCTION()
会先经过其中的 wrapper()
。这里首先创建 checkpoint,保存了所有的状态,以便在后面出现异常时从 checkpoint 中恢复。
state = self.copy_graphstate()
reason = None
try:
return inner_fn(self, inst)
except Unsupported as excp:
# ...
self.restore_graphstate(state)
然后执行 inner_fn(self, inst)
,inner_fn
就是 CALL_FUNCTION()
,其中先出栈 N 个元素作为函数参数,N 由 inst.argval
指定,这里是 1,然后再出栈 1 个元素作为函数,通过 InstructionTranslatorBase.call_function() 进行函数调用:
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
args = self.popn(inst.argval)
fn = self.pop()
self.call_function(fn, args, {})
InstructionTranslatorBase.call_function()
调用 TorchVariable.call_function(),TorchDynamo 在此处模拟执行 torch.abs(a)
。首先用 propagate()
收集所有参数中的 Guard
,然后匹配到 torch.py##L565-L573,此处 proxy_args_kwargs() 从 TensorVariable(a)
获取 torch.fx.Proxy(a)
,它是在初始化 symbolic_locals
时创建的,然后通过 create_proxy() 创建了新的 Proxy
,类型是 call_function
,目标是 torch.abs
,参数是 a
。
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
**options,
)
最后通过 wrap_fx_proxy 创建了新的 TensorVariable
来保存结果,收集到的 Guard
信息也附加了上去,一路返回后并在 call_function() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a))]
。可见,TorchDynamo 在字节码分析过程中没有真正地执行指令,而是以符号分析的方式从字节码中逐步提取用到的符号和函数,为它们创建 Proxy
并添加到 FX graph 中。
对应的函数调用栈为:
- [P069] > torch/_dynamo/utils.py#L428 [New]
- [P065] > torch/_dynamo/variables/torch.py#L181 [New]
- [P061] > torch/_dynamo/symbolic_convert.py#L469 [New]
- [P057] > torch/_dynamo/symbolic_convert.py#L988 [New]
- [P053] > torch/_dynamo/symbolic_convert.py#L341 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Offset 10, LOAD_CONST: 加载常量 1
,为其创建 ConstantVariable 并压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)), ConstantVariable(1)]
。
def LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))
Offset 12, BINARY_ADD: InstructionTranslatorBase
对常见的 Python 内建函数用 stack_op 做了封装和转发,其中 BINARY_ADD = stack_op(operator.add)
:
BINARY_REMAINDER = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)
所以 symbolic_convert.py#L560 在执行 getattr(self, 'BINARY_ADD')(inst)
时,会转到 impl():
def stack_op(fn: typing.Callable[..., object]):
nargs = len(inspect.signature(fn).parameters)
fn_var = BuiltinVariable(fn)
@functools.wraps(fn)
def impl(self: "InstructionTranslatorBase", inst: Instruction):
self.push(fn_var.call_function(self, self.popn(nargs), {}))
return impl
其中,fn_var
是创建闭包时创建的 BuiltinVariable,需要出栈的参数个数由 inspect.signature(fn)
确定,对于 operator.add
来说需要 2 个参数,因此出栈 TensorVariable(torch.abs(a))
和 ConstantVariable(1)
,随后转到 BuiltinVariable.call_function(),这与 Offset 4 处的 LOAD_ATTR
调用的是同一个函数。其中调用 propagate()
从输入的 VariableTracker
中收集 Guard
,分别是针对变量 a
的 TENSOR_MATCH
和针对 torch
的 FUNCTION_MATCH
。
proxy = tx.output.create_proxy(
"call_function",
fn,
*proxy_args_kwargs(args, kwargs),
)
builtin.py#L476-L480 处先调用 proxy_args_kwargs() 获取输入参数的 Proxy
,然后在 FX Graph 中创建输出节点对应的 Proxy
,通过 create_proxy() 创建了类型是 call_function
的 Proxy
,目标是 operator.add
,参数是 Proxy(a)
和 1
。
最后用 wrap_fx_proxy() 创建了新的 TensorVariable
来保存结果,收集到的 Guard
信息也附加了上去,一路返回后在 impl() 处将结果压栈,栈上的内容为 [TensorVariable(a), TensorVariable(torch.abs(a)+1)]
。
- [P061] > torch/_dynamo/variables/builder.py#L864
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L148 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Offset 14, BINARY_TRUE_DIVIDE: 与 Offset 12 的 BINARY_ADD
类似,BINARY_TRUE_DIVIDE
在 InstructionTranslatorBase
中被设为 stack_op(operator.truediv)
,随后到代码路径与 BINARY_ADD
一致。出栈 [TensorVariable(a)
和 TensorVariable(torch.abs(a)+1)
,为输出创建新的 Proxy
,类型是 call_function
,目标是 operator.truediv
,参数是 Proxy(a)
和 Proxy(add)
。再为其创建了新的 TensorVariable
来跟踪收集到的 Guard
,最后把结果压栈,栈上的内容为 [TensorVariable(a/(torch.abs(a)+1))]
。
BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
Offset 16, STORE_FAST: 出栈最后一个元素 TensorVariable(a/(torch.abs(a)+1))
,并把它存在 self.symbolic_locals[inst.argval]
中,这里 self.symbolic_locals
跟踪当前 frame 的局部变量,变量名由 inst.argval
指定,即 x
。此后栈上内容为空。
def STORE_FAST(self, inst):
self.symbolic_locals[inst.argval] = self.pop()
Offset 18, LOAD_FAST: 再次回到 step() 时,因为 self.block_stack
和 self.stack
都为空,会创建当前状态的 checkpoint。随后进入 LOAD_FAST,需要加载的变量名由 inst.argval
指定,即 b
。从 self.symbolic_locals
中获取 TensorVariable(b)
并压栈,栈上的内容为 [TensorVariable(b)]
。
Offset 20, LOAD_ATTR: 这与 Offset 4 的 LOAD_ATTR
类似,先出栈一个元素,即 TensorVariable(b)
,然后为 getattr
创建了 BuiltinVariable、为 inst.argval
(即 sum
) 创建了 ConstantVariable,转入 BuiltinVariable.call_function(),对应着在 TorchDynamo 中执行类似 Python 的 getattr(b, "sum")
的功能。其中进行基本的检查后,它通过 VariableTracker.propagate() 收集了输入参数中的所有 Guard
。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L548
handler = getattr(self, f"call_{self.fn.__name__}", None)
if handler:
try:
inspect.signature(handler).bind(tx, *args, **kwargs)
except TypeError as exc:
# ...
if handler:
try:
result = handler(tx, *args, **kwargs)
if result is not None:
return result.add_options(options)
except Unsupported as exc:
if not has_constant_handler:
raise
# Actually, we will handle this just fine
exc.remove_from_stats()
builtin.py#L548 获取了 call_getattr
属性,然后通过 inspect.signature(handler).bind(tx, *args, **kwargs)
检查是否能够成功将参数绑定到 call_getattr
上。builtin.py#L561 通过 result = handler(tx, *args, **kwargs)
调用 BuiltinVariable.call_getattr(),其中再次通过 VariableTracker.propagate()
收集了输入参数的 Guard
,此处的 obj
(即 b
) 是 TensorVariable
,在 builtin.py#L1006 进入 TensorVariable.var_getattr(),但在 tensor.py#L215 抛出 NotImplementedError
异常,它被 builtin.py#L1009 捕获,进而为 b.sum
创建 GetAttrVariable,builtin.py#L563 把输入中的 Guard
附加在输出上。最后返回到 LOAD_ATTR,并将结果压栈,此时栈上的内容为 [GetAttrVariable(TensorVariable(b),sum)]
。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1004
try:
return (
obj.var_getattr(tx, name).clone(source=source).add_options(options)
)
except NotImplementedError:
return GetAttrVariable(obj, name, **options)
对应的函数调用栈为:
- [P069] > torch/_dynamo/variables/misc.py#L549 [New]
- [P065] > torch/_dynamo/variables/base.py#L26
- [P061] > torch/_dynamo/variables/builtin.py#L938
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L1063
- [P049] > torch/_dynamo/symbolic_convert.py#L537
Call Function Link to heading
Offset 22, CALL_FUNCTION: 与 Offset 8 处的 CALL_FUNCTION
相似,经 wrapper() 来到 InstructionTranslatorBase.CALL_FUNCTION(),出栈函数参数和函数后转到 InstructionTranslatorBase.call_function(),因为 fn
是 GetAttrVariable
,因此执行 fn.call_function(self, args, kwargs)
会转到 GetAttrVariable.call_function()。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L988
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
args = self.popn(inst.argval)
fn = self.pop()
self.call_function(fn, args, {})
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L469
def call_function(
self,
fn: VariableTracker,
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
):
# ...
self.push(fn.call_function(self, args, kwargs))
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/misc.py#L581
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# ...
return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
在 misc.py#L666,self.obj.call_method(tx, self.name, ...)
中的 self.obj
的类型是 TensorVariable
、self.name
是 sum
,因此来到 TensorVariable.call_method()。TensorVariable.call_method()
针对 PyTorch Tensor 下的许多方法进行了特殊处理,这里函数名 sum
未能成功匹配该函数中的大多数逻辑,最终来到 tensor.py#L422,此处为 sum
函数创建 torch.fx.Proxy
,类型为 call_method
,输入节点为 Proxy(b)
,目标为 sum()
,并以此 Proxy 通过 wrap_fx_proxy() 创建 VariableTracker
,并附着从输入 VariableTracker
中收集到的 Guard
。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/tensor.py#L238
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# ...
else:
# ...
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
此前已经分析过 wrap_fx_proxy()
中的大部分内容,builder.py#L900 处通过 example_value = get_fake_value(proxy.node, tx)
从 FX Node 中创建 FakeTensor
、并以 FakeTensor
运行该节点所代表的算子,实现在 get_fake_value()。在算子运行前,TorchDynamo 打开了 Python Dispatcher,它可以在 Python 层面捕获到 PyTorch 最底层的 ATen 算子,详见 TorchDispatchMode。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1152
try:
with tx.fake_mode, enable_python_dispatcher():
return wrap_fake_exception(
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
)
except Unsupported:
raise
except RuntimeError as e:
# ...
以 FakeTensor
运行算子实现在 run_node(),此时 op
为 call_method
,getattr(args[0], node.target)(*args[1:], **kwargs)
是 getattr(b, 'sum')()
,也就是执行 b.sum()
,因为 Tensor.sum()
是实现在 C++ 中的函数,这会进入到函数 THPVariable_sum()
中,它是 PyTorch 在编译过程中生成的,生成后的 THPVariable_sum()
位于 torch/csrc/autograd/generated/python_variable_methods.cpp#L15065
。
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/utils.py#L1178
def run_node(output_graph, node, args, kwargs, nnmodule):
op = node.op
try:
if op == "call_function":
return node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return output_graph.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except Exception as e:
raise RuntimeError(
f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)"
) from e
raise AssertionError(op)
此处的函数调用栈为:
- [C098] > build/aten/src/ATen/Operators_2.cpp#L5915:call [New]
- [C097] > build/aten/src/ATen/Operators_2.cpp#L5912:call [New]
- [C096] > build/aten/src/ATen/core/TensorBody.h#L3574:sum [New]
- [C095] > torch/csrc/autograd/generated/python_variable_methods.cpp#L15083:operator() [New]
- [C094] > torch/csrc/autograd/generated/python_variable_methods.cpp#L15065:THPVariable_sum [New]
- [P093] > torch/_dynamo/utils.py#L1178
- [P089] > torch/_dynamo/utils.py#L1155
- [P085] > torch/_dynamo/utils.py#L806
- [P081] > torch/_dynamo/utils.py#L1120
- [P077] > torch/_dynamo/variables/builder.py#L876
- [P073] > torch/_dynamo/variables/builder.py#L864
- [P069] > torch/_dynamo/variables/tensor.py#L238 [New]
- [P065] > torch/_dynamo/variables/misc.py#L581 [New]
- [P061] > torch/_dynamo/symbolic_convert.py#L469
- [P057] > torch/_dynamo/symbolic_convert.py#L988
- [P053] > torch/_dynamo/symbolic_convert.py#L341
- [P049] > torch/_dynamo/symbolic_convert.py#L537
回到 builder.py#L900,通过 specialize() 特化张量的信息,执行 return target_cls(proxy, **options)
创建 TensorVariable
,它是 VariableTracker
的子类,用于记录 PyTorch 中的 torch.Tensor
类型,实现在 tensor.py#L69。
因此,在为计算节点的输出创建 VariableTracker
的过程中,TorchDynamo 在 FX Graph 中创建了 FX Proxy、添加了 Guard、创建了 FakeTensor
,以 FakeTensor
为输入执行了算子,初始化了 VariableTracker
。
最后回到 symbolic_convert.py#L494,把新创建的 TensorVariable
压栈,此时栈上的内容为 [TensorVariable(b.sum())]
。
Offset 24, LOAD_CONST: 加载常量 0,创建 ConstantVariable
并压栈,此时栈上的内容为 [TensorVariable(b.sum()), ConstantVariable(0)]
。
Offset 26, COMPARE_OP:
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/symbolic_convert.py#L925
def COMPARE_OP(self, inst):
left, right = self.popn(2)
left = left.as_specialized(self)
right = right.as_specialized(self)
options = VariableTracker.propagate([left, right])
op = inst.argval
# ...
else:
self.push(
BuiltinVariable(supported_any[op], **options).call_function(
self, [left, right], {}
)
)
出栈要比较的两个元素 left
和 right
,调用 VariableTracker.propagate()
收集输入中的 Guard
,inst.argval
是小于号 <
。symbolic_convert.py#L980 处的 supported_tensors[op]
是 operator.lt
,此处首先创建 BuiltinVariable
,在 BuiltinVariable.call_function() 经 result = handler(tx, *args, **kwargs)
到 _comparison():
# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/variables/builtin.py#L1178-L1186
if isinstance(left, TensorVariable):
from .builder import wrap_fx_proxy
if op not in supported_tensor_comparison_ops.values():
_unimplemented()
return wrap_fx_proxy(
tx,
op(left.as_proxy(), right.as_proxy()),
)
这里的 op
为 operator.lt
,在 FX 初始化阶段被修改为 impl(),此处创建新的 FX Proxy,op
为 call_function
,target
为 operator.lt
。wrap_fx_proxy() 的功能与前面一致,从 Proxy
中创建了 FakeTensor
,以 FakeTensor
为输入执行算子 THPVariable_lt()
,创建新的 TensorVariable
。最后 symbolic_convert.py#L979 把新的 TensorVariable
压栈,此时栈上的内容为 [TensorVariable(lt)]
。
这个过程的函数调用栈为:
- [P065] > torch/_dynamo/variables/builder.py#L864
- [P061] > torch/_dynamo/variables/builtin.py#L1135 [New]
- [P057] > torch/_dynamo/variables/builtin.py#L428
- [P053] > torch/_dynamo/symbolic_convert.py#L925 [New]
- [P049] > torch/_dynamo/symbolic_convert.py#L537