简介 Link to heading

torch.fx 用于对 torch.nn.Module 做图变换。它包含三部分:

  • 符号跟踪器(symbolic tracer):用于捕获 module 的语义,它以符号的方式执行 Python 代码(symbolic execution),通过给 module 提供虚假值(Proxies)并记录涉及到的运算;
  • 中间表示(intermediate representation):IR 是在 tracing 期间所记录算子的图,包含一系列节点(Node),节点代表输入(placeholder)、函数(get_attr, call_function, call_module, call_method)、输出(output),IR 是用 torch.fx 进行图变换(transformation)的基石;
  • Python 代码生成(code generation):直接生成 Python 代码使得 torch.fx 可以进行 Python-Python 或 Module-Module 的一一变换,此功能包含在 torch.fx.GraphModule 中,它是 torch.nn.Module 的实例,并保留有 torch.fx.Graph

总而言之,torch.fx 的流程是:符号跟踪,中间表示,代码变换,Python 代码生成。

使用方法 Link to heading

创建 GraphModule Link to heading

方法一:先获取 torch.fx.Graph,再创建 torch.fx.GraphModule:

graph : torch.fx.Graph = torch.fx.Tracer().trace(module)

# Modify graph
# <...>

graph_module = torch.fx.GraphModule(module, graph)

方法二:先获取 torch.fx.GraphModule,再修改 torch.fx.Graph,其中 graph_module.recompile() 用于同步在 .graph 上的更改,生成新的 forward() 函数:

from torch.fx import symbolic_trace

graph_module : torch.fx.GraphModule = symbolic_trace(module)

# Modify graph_module.graph
# <...>

# Recompile the forward() method of `graph_module` from its Graph
graph_module.recompile()

图属性 Link to heading

print(graph_module.graph)  # IR
print(graph_module.code)   # Generated Python code

辅助函数:

graph.print_tabular()  # Nodes in the graph
graph.lint()           # sanity check

图变换 Link to heading

替换算子:

for node in graph.nodes:
    if node.op == "call_function":
        if node.target == torch.relu:
            node.target = torch.sigmoid

torch.fx.replace_pattern() 可以用来执行简单的算子替换:

def pattern(a1, a2):
    val1 = torch.neg(a1)
    return torch.cat([val1, a2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

replace_pattern(traced, pattern, replacement)

插入算子:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with graph_module.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = graph_module.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

操纵图的另一种方式是 Proxy 机制进行 retracing,它可以自动化 graph rewriting,避免显示图修改,可以用 Python 函数的形式描述图重写规则。这里的关键是,把 Node 包装进 Proxy,传给图变换函数,返回新的 Proxy,用新的 Proxy 中的 Node 继续组建新图

def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

for node in graph.nodes:
    if node.op == 'call_function' and node.target in decomposition_rules:
        proxy_args = [fx.Proxy(env[x.name], tracer) if \
            isinstance(x, fx.Node) else x for x in node.args]
        output_proxy = decomposition_rules[node.target](*proxy_args)

高级API Link to heading

torch.fx.Interpretertorch.fx 的基础 API 做了封装,它可以逐节点执行 FX 图,可以用于辅助分析、子图变换、retracing等等。

torch.fx.Transformer 是特殊的 Interpreter,Interpreter 需要具体的输入数据才能运行,而 Transformer 完全以符号方式运行,它可以用来做子图变换。

调试 Link to heading

避免图变换出错和一些调试建议:

  • 不要使用 Python 中的 set() 管理 Node,集合类型是无序的;
  • 使用 torch.allclose() 对比变换前后 Module 的结果;
  • 使用 import pdb; pdb.set_trace() 在变换后的 Module 执行前暂停,然后单步调试;
  • 继承原 Module,把生成的 forward() 函数复制粘贴到继承的 Module 中,用继承的 Module 调试;
  • 使用 GraphModule.to_folder() 将 FX 代码导出到本地,然后导入模块进行调试;
  • 检查 .graph.code 属性,以及 graph.print_tabular()

局限性 Link to heading

torch.fx 基于符号跟踪(symbolic tracing),局限性有:

  • 不支持动态控制流(dynamic control flow),支持静态控制流,即 if/loop 的条件不随输入 tensor 变化,生成的静态控制流代码,图是特化和展开的,不包含控制语句。语义上是静态控制的动态控制流,可以通过 symbolic_trace()concrete_args 参数传入具体参数。真实的动态控制流可以通过 wrap() 避免 trace 它们;
  • torch.fx 目前支持 torch, operator, math 3个 Python 包下面的函数,除此之外的函数需要用 torch.fx.wrap() 将其声明为直接调用,例如 Python 内建函数 len()
  • 创建 Tensor 的 API 不可被 trace,例如 torch.ones(), torch.randn(),后者是非确定性的,需要 wrap()
  • 符号跟踪过程中捕获的 flag 变量,不可在执行阶段改变。例如,捕获 training 时的 dropout 函数 torch.nn.functional.dropout(x, training=self.training),以 eval 模式执行会出错,应该使用 torch.nn.Dropout,它是叶子模块,内部实现不会被符号跟踪;

典型的报错案例:

# Proxy cannot be used as inputs to control flow
func = lambda x, y: y if x > 0 else -y
# Proxy cannot be iterated
func = lambda x, y: sum([i*j for i,j in zip(x, y)])
# 'len' is not supported in symbolic tracing by default
func = lambda x: len(x)
torch.fx.wrap('len')

除此之外,实践中发现其局限性还有:

  • 符号跟踪的基本单位是 PyTorch 的算子(op),正向传播与反向传播非对称 fusion 无法实现;

源码剖析 Link to heading

本次源码剖析使用 2022/07/22 最新版本:

  • NVIDIA PyTorch container: nvcr.io/nvidia/pytorch:22.06-py3
  • PyTorch: v1.12.0

符号跟踪源码 Link to heading

import torch
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
print(symbolic_traced.graph)
print(symbolic_traced.code)

symbolic_trace 实现于 torch/fx/_symbolic_trace.py:827:symbolic_trace, 流程:

  • 创建 torch.fx.Tracer,进行简单的初始化;
  • 调用 trace() 来捕获 torch.fx.Graph
  • 创建 torch.fx.GraphModule

符号跟踪(symblic tracing)torch/fx/_symbolic_trace.py:499:trace 实现,符号跟踪整体流程如下:

  • torch/fx/_symbolic_trace.py#L539:创建 torch.fx.Graph,Graph 是 FX IR 的主要数据结构,包含一系列 torch.fx.Node,Node 代表了输入/输出/算子,一系列 Node 构成了 Python 函数。Graph 初始化在 torch/fx/graph.py:600:__init__,其中创建了 root 节点(torch.fx.Node)和 CodeGen 对象。Node 初始化于 torch/fx/node.py:123:__init__,Node 包含 .op 属性,表明该节点的功能。此处主要函数执行流程:

    [P002] > torch/fx/graph.py:600:__init__ [New]
    [P003] > torch/fx/node.py:123:__init__ [New]
    [P003] < torch/fx/node.py:123:__init__
    [P003] > torch/fx/graph.py:250:__init__ [New]
    [P003] < torch/fx/graph.py:250:__init__
    [P002] < torch/fx/graph.py:600:__init__
    
  • torch/fx/_symbolic_trace.py#L559:创建被 trace 的 root 函数(对 module 来说是 forward)的输入节点(placeholder),具体实现在 torch/fx/_symbolic_trace.py:376:create_args_for_root,流程如下:

    • 对 root 函数进行自省以获取函数签名,它通过标准库的 inspect.signature() 实现,从而获得参数名称等信息;
    • 对 root 函数的每个参数创建 torch.fx.Proxy 对象(torch/fx/proxy.py:232:__init__),其类型是 placeholder,含义是函数的输入。Proxy 是 Node 的 wrapper,每个 Proxy 都包含一个 Node,Proxy 在 symbolic tracing 过程中负责记录涉及到的算子。Proxy 是符号跟踪过程中的符号
    • 创建 Node 通过 graph.create_node() 实现(torch/fx/graph.py:691:create_node),Node 以双向链表的方式管理,每个 Node 都记录其前后节点,节点插入在 torch/fx/node.py:224:prepend 实现。这里 graphTracerBase类属性Tracer 继承自 TracerBase

    此处主要函数执行流程:

    [P002] > torch/fx/_symbolic_trace.py:376:create_args_for_root [New]
    [P003] > python3.8/inspect.py:493:unwrap [New]
    [P003] < python3.8/inspect.py:493:unwrap
    [P003] > python3.8/inspect.py:3103:signature [New]
    [P003] < python3.8/inspect.py:3103:signature
    [P003] > torch/fx/_symbolic_trace.py:443:<genexpr> [New]
    [P004] > torch/fx/_symbolic_trace.py:403:proxy_placeholder [New]
    [P005] > torch/fx/proxy.py:49:create_proxy [New]
    [P006] > torch/fx/_symbolic_trace.py:201:create_arg [New]
    [P007] > torch/fx/proxy.py:107:create_arg [New]
    [P007] < torch/fx/proxy.py:107:create_arg
    [P006] < torch/fx/_symbolic_trace.py:201:create_arg
    [P006] > torch/fx/proxy.py:29:create_node [New]
    [P007] > torch/fx/graph.py:691:create_node [New]
    [P008] > torch/fx/node.py:123:__init__
    [P008] < torch/fx/node.py:123:__init__
    [P008] > torch/fx/node.py:224:prepend [New]
    [P009] > torch/fx/node.py:257:_remove_from_list [New]
    [P009] < torch/fx/node.py:257:_remove_from_list
    [P008] < torch/fx/node.py:224:prepend
    [P007] < torch/fx/graph.py:691:create_node
    [P006] < torch/fx/proxy.py:29:create_node
    [P006] > torch/fx/proxy.py:45:proxy [New]
    [P007] > torch/fx/proxy.py:232:__init__ [New]
    [P007] < torch/fx/proxy.py:232:__init__
    [P006] < torch/fx/proxy.py:45:proxy
    [P005] < torch/fx/proxy.py:49:create_proxy
    [P004] < torch/fx/_symbolic_trace.py:403:proxy_placeholder
    [P003] < torch/fx/_symbolic_trace.py:443:<genexpr>
    [P003] > torch/utils/_pytree.py:126:tree_flatten [New]
    [P003] < torch/utils/_pytree.py:126:tree_flatten
    [P002] < torch/fx/_symbolic_trace.py:376:create_args_for_root
    
  • torch/fx/_symbolic_trace.py#L581-L582:以 monkey patch 的方式给 torch.nn.Module.__getattr__torch.nn.Module.__call__ 打补丁,通过 setattr() 将它们替换为 module_getattr_wrappermodule_call_wrapper

  • torch/fx/_symbolic_trace.py#L583-L586:给叶子函数(leaf function)打补丁,确保在符号跟踪时直接调用原函数。叶子函数包括通过 torch.fx.wrap(fn_or_name) 标记的函数,以及在 Tracer 实例化时由参数 autowrap_modules 显示指定的 Python 模块(默认为 math 模块)和由 autowrap_function 显示指定的 Python 函数。这些叶子函数在符号跟踪的过程中不会被跟踪,会直接调用原函数;

  • torch/fx/_symbolic_trace.py#L587-L588:以 Proxy 作为参数调用 root 函数,返回后为其创建 output 节点;

符号跟踪的核心正是 torch/fx/_symbolic_trace.py#L587-L588fn(*args),其中 args 是 Proxy 构成的列表,fn 是上述示例代码中的 forward()

forward() 中用到 self.param,而 selftorch.nn.Module 的实例,torch.nn.Module.__getattr__ 在此之前已被修改为 torch/fx/_symbolic_trace.py:565:module_getattr_wrapper,这里获取原属性后,转到 torch/fx/_symbolic_trace.py:473:_module_getattr。如果该属性的类型是 torch.nn.Parameter,则为其创建类型为 get_attr 的 Proxy,含义是从 module 中获取 Parameter。

然后执行 x + self.param,此时两个操作数都是 torch.fx.Proxy 类型,加法在这里需要调用 Python 中的魔术方法(magic method) __add__Proxy.__add__ 在导入 torch.fx 模块时被设置为 torch/fx/proxy.py:383:impl,被修改的魔术方法列表由 torch/fx/graph.py#L1417 指定,包含了常见的 Python 数学运算符。在 impl() 内,找到真正的算子 operator.add,由 tracer 为 operator.add 创建新的 Proxy,类型为 call_function。创建 Proxy 实现在 torch/fx/proxy.py:49:create_proxy,两个操作数都是 Proxy 类型,获取参数(torch/fx/proxy.py#L63)会直接通过 Porxy.node 获取其对应的 Node(torch/fx/proxy.py#L147),然后通过 graph 创建新的 Node,其 .targetoperator.add.opcall_function._input_nodes 是两个操作数的 Node,而新创建的 Node 也成了操作数 Node 的 .usersgraph 以环形双向链表的方式管理 Node,新的 Node 被插入到 graph 的 root 节点前(root._prev)、最后一个节点后,每个 Node 自己负责记录 input_nodes (producer) 和 users (consumer)。 最后为新 Node 创建 Proxy。

此处主要的函数执行流程:

[P003] > torch/fx/proxy.py:383:impl [New]
[P004] > torch/fx/proxy.py:49:create_proxy
[P005] > torch/fx/_symbolic_trace.py:201:create_arg
[P006] > torch/fx/proxy.py:107:create_arg
[P007] > torch/fx/proxy.py:125:<genexpr>
[P008] > torch/fx/_symbolic_trace.py:201:create_arg
[P008] < torch/fx/_symbolic_trace.py:201:create_arg
[P007] < torch/fx/proxy.py:125:<genexpr>
[P006] < torch/fx/proxy.py:107:create_arg
[P005] < torch/fx/_symbolic_trace.py:201:create_arg
[P005] > torch/fx/proxy.py:29:create_node
[P006] > torch/fx/graph.py:691:create_node
[P007] > torch/fx/node.py:123:__init__
[P008] > torch/fx/node.py:365:__update_args_kwargs
[P008] < torch/fx/node.py:365:__update_args_kwargs
[P007] < torch/fx/node.py:123:__init__
[P006] < torch/fx/graph.py:691:create_node
[P005] < torch/fx/proxy.py:29:create_node
[P004] < torch/fx/proxy.py:49:create_proxy
[P003] < torch/fx/proxy.py:383:impl

torch/fx/proxy.py:383:impl 正是符号跟踪的巧妙之处,调用加法运算符被转到了 Proxy.__add__,该函数并没有真正的执行加法运算,而是根据操作数的 Proxy 创建了新的 Proxy,其中 graph 记录了所有的 Node,每个 Node 记录了涉及到的运算符和输入 Node。

下一步执行 self.linear(x + self.param),调用 self.linear() 会执行 torch.nn.Module.__call__,而它在此前已被修改为 torch/fx/_symbolic_trace.py:570:module_call_wrapper,紧接着转到 torch/fx/_symbolic_trace.py:342:call_module。如果 module 不是 leaf module,即包含其他 torch.nn.Module 或容器的 module,则继续调用 forward() 函数,直到 module 是 leaf module,然后创建并返回新的 Proxy,类型为 call_module,创建 Proxy 的过程和前面相同。此处主要的函数执行流程:

[P003] > torch/fx/_symbolic_trace.py:570:module_call_wrapper [New]
[P004] > torch/fx/_symbolic_trace.py:759:_autowrap_check
[P004] < torch/fx/_symbolic_trace.py:759:_autowrap_check
[P004] > torch/fx/_symbolic_trace.py:342:call_module [New]
[P005] > torch/fx/_symbolic_trace.py:293:is_leaf_module [New]
[P005] < torch/fx/_symbolic_trace.py:293:is_leaf_module
[P005] > torch/fx/proxy.py:49:create_proxy
[P005] < torch/fx/proxy.py:49:create_proxy
[P004] < torch/fx/_symbolic_trace.py:342:call_module
[P003] < torch/fx/_symbolic_trace.py:570:module_call_wrapper

随后执行 .clamp(min=0.0, max=1.0),因为 self.linear() 返回的是 Proxy,.clamp() 会被转到 Proxy.__getattr__("clamp")(),它的实现在 torch/fx/proxy.py:243:__getattr__,其中直接返回一个 Attribute(torch/fx/proxy.py:325:__init__),Attribute 继承自 Proxy。在调用 Attribute() 时,tracer 创建并返回了新的 Proxy,类型是 call_method,target 是 clamp。此处主要的函数执行流程:

[P003] > torch/fx/proxy.py:243:__getattr__ [New]
[P004] > torch/fx/proxy.py:325:__init__ [New]
[P004] < torch/fx/proxy.py:325:__init__
[P003] < torch/fx/proxy.py:243:__getattr__
[P003] > torch/fx/proxy.py:340:__call__ [New]
[P004] > torch/fx/proxy.py:49:create_proxy
[P004] < torch/fx/proxy.py:49:create_proxy
[P003] < torch/fx/proxy.py:340:__call__

到此,示例代码中的 MyModule.forward() 执行完毕。Tracer.trace() 的最后一步是建立 output 节点,类型为 output,一个完整的 symbolic tracing 到此结束。

torch.fx.symbolic_trace 的最后一步是创建 torch.fx.GraphModule,实现于 graph_module.py:293:__new__GraphModule 继承自 torch.nn.Module,包含从 graph 中生成的 .graph, .code, .forward 属性。GraphModule 在初始化的过程中,会通过 setattr() 引用原 module 的属性。在 tracing 过程中捕获的 graph,也会设置为 GraphModule.graph,此时会触发 torch/fx/graph_module.py:624:recompile,其功能是根据 torch.fx.Graph 重新编译 GraphModule,每次对 .graph 的修改,都需要重新编译 GraphModule。

gm.recompile() 的第一步是 Python 代码生成torch/fx/graph_module.py#L634),代码生成的核心由 torch.fx.CodeGen 实现,具体过程详见 torch/fx/graph.py:297:_gen_python_code。它首先把一些内置名称添加到全局命名空间,例如 inf, None, torch。然后以逆序方式遍历图中的节点(torch/fx/graph.py#L389-L391),找到每个节点最后被使用的地方,从而在代码生成的过程中及时释放不用的节点。

生成 Python 代码在 torch/fx/graph.py#L475-L479代码生成逐节点进行,依次为每个 Node 生成对应的 Python 代码。上述示例代码的代码生成过程如下:

Index Node Op Target Args Kwargs Body Unused
0 x placeholder x () {}
1 param get_attr param () {} param = self.param
2 add call_function operator.add (x, param) {} add = x + param; x = param = None
3 linear call_module linear (add,) {} linear = self.linear(add); add = None
4 clamp call_method clamp (linear,) {‘min’: 0.0, ‘max’: 1.0} clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
5 output output output (clamp,) {} return clamp

以生成 add 节点对应的代码为例,进入到 emit_node() 后,node.opcall_functionnode.targetoperator.add,是 Python 中的魔术方法 __add__torch/fx/graph.py#L1400 定义 add 的模板是 '{} + {}',通过 format() 向上述模板中填入两个参数 xparam,最终得到字符串 add = x + params。在 delete_unused_values() 中,变量 xparam 是他们在 add 节点最后被用到的地方,因此生成代码 ; x = param = Noneadd 节点的代码生成完毕,按照此流程即可生成所有节点对应的代码。

body.append(f'{repr(node)}{maybe_type_annotation} = '
            f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')

如果存在通过 torch.fx.wrap() 定义的叶子函数,则在 torch/fx/graph.py#L490-L491 为其生成 wrap 语句。torch/fx/graph.py#L501 生成函数定义 def forward(self, x):。最后,一份完整的 Python 代码字符串生成完毕。

gm.recompile() 的第二步是 代码编译,它通过 torch/fx/graph_module.py#L69-L80 实现。核心代码是 exec(compile(src, key, 'exec'), globals),它调用 Python 内建函数 compile() 把生成的 Python 代码字符串编译为 Python 字节码,并调用 Python 内建函数 exec() 执行字节码。因为生成的代码是 forward() 函数的定义,所以在 exec() 之后,globals 中存在一个名为 forward 的可执行函数,该函数被直接设置为 graph module 的 forward 函数(torch/fx/graph_module.py#L638)。torch/fx/graph_module.py#L654cls.__call__ = call_wrappedGraphModule() 转发到 call_wrapped,最终会转到 forward() 函数。

到此,GraphModule 初始化完成,初始化 GraphModule 过程中的重要函数执行流程:

[P001] > torch/fx/graph_module.py:293:__new__ [New]
[P002] > torch/fx/graph_module.py:307:GraphModuleImpl [New]
[P002] < torch/fx/graph_module.py:307:GraphModuleImpl
[P001] < torch/fx/graph_module.py:293:__new__
[P001] > torch/fx/graph_module.py:311:__init__ [New]
[P002] > torch/fx/graph.py:632:nodes [New]
[P003] > torch/fx/graph.py:221:__init__ [New]
[P003] < torch/fx/graph.py:221:__init__
[P002] < torch/fx/graph.py:632:nodes
[P002] > torch/fx/graph.py:229:__iter__ [New]
[P002] < torch/fx/graph.py:229:__iter__
[P002] > torch/fx/graph_module.py:182:_copy_attr [New]
[P002] < torch/fx/graph_module.py:182:_copy_attr
[P002] > torch/nn/modules/module.py:1210:__setattr__
[P003] > torch/fx/graph_module.py:394:graph [New]
[P004] > torch/fx/graph_module.py:624:recompile [New]
[P005] > torch/fx/graph.py:1091:python_code [New]
[P006] > torch/fx/graph.py:115:__init__
[P006] < torch/fx/graph.py:115:__init__
[P006] > python3.8/contextlib.py:211:contextmanager [New]
[P006] < python3.8/contextlib.py:211:contextmanager
[P006] > python3.8/contextlib.py:108:__enter__ [New]
[P006] < python3.8/contextlib.py:108:__enter__
[P006] > torch/fx/graph.py:1153:_python_code [New]
[P007] > torch/fx/graph.py:297:_gen_python_code [New]
[P008] > torch/fx/graph.py:306:add_global [New]
[P008] < torch/fx/graph.py:306:add_global
[P008] > torch/fx/node.py:592:map_arg
[P008] < torch/fx/node.py:592:map_arg
[P008] > torch/fx/graph.py:412:emit_node
[P008] < torch/fx/graph.py:412:emit_node
[P008] > torch/fx/graph.py:393:delete_unused_values
[P008] < torch/fx/graph.py:393:delete_unused_values
[P008] > torch/fx/graph.py:290:additional_globals [New]
[P008] < torch/fx/graph.py:290:additional_globals
[P008] > torch/fx/graph.py:253:gen_fn_def [New]
[P008] < torch/fx/graph.py:253:gen_fn_def
[P007] < torch/fx/graph.py:297:_gen_python_code
[P006] < torch/fx/graph.py:1153:_python_code
[P006] > python3.8/contextlib.py:117:__exit__ [New]
[P006] < python3.8/contextlib.py:117:__exit__
[P005] < torch/fx/graph.py:1091:python_code
[P005] > torch/nn/modules/module.py:1210:__setattr__
[P005] < torch/nn/modules/module.py:1210:__setattr__
[P005] > torch/fx/graph_module.py:74:_forward_from_src [New]
[P006] > torch/fx/graph_module.py:69:_exec_with_source [New]
[P007] > torch/fx/graph_module.py:28:cache [New]
[P007] < torch/fx/graph_module.py:28:cache
[P007] > <eval_with_key>.0:4:<module> [New]
[P007] < <eval_with_key>.0:4:<module>
[P006] < torch/fx/graph_module.py:69:_exec_with_source
[P005] < torch/fx/graph_module.py:74:_forward_from_src
[P005] > torch/fx/graph_module.py:227:__init__ [New]
[P005] < torch/fx/graph_module.py:227:__init__
[P004] < torch/fx/graph_module.py:624:recompile
[P003] < torch/fx/graph_module.py:394:graph
[P002] < torch/nn/modules/module.py:1210:__setattr__
[P002] > torch/fx/graph_module.py:387:graph [New]
[P002] < torch/fx/graph_module.py:387:graph
[P001] < torch/fx/graph_module.py:311:__init__

到此,symbolic_trace() 的最后一步,整个符号跟踪完成。

符号跟踪总结:

  • 在初始化阶段,通过 monkey patch 把针对 Tensor 的函数修改为针对 Proxy 的函数;
  • 在跟踪阶段,用 Proxy 替换 Tensor 作为函数输入,依次在 Proxy 的方法中建立新的 Proxy,其中的 Node 构成了 Graph;
  • 在代码生成阶段,依次遍历 Node,逐个生成对应的 Python 代码字符串;
  • 在编译阶段,通过 Python 内建函数 exec(compile()) 将字符串编译为可执行函数;

子图重写源码 Link to heading

在符号跟踪示例代码的基础上,我们使用 subgraph_rewriter.replace_pattern() 来进行简单的图变换,复杂的图变换可以遵循和 replace_pattern() 相似的手段。在示例代码中,我们把 operator.add 替换为 operator.mul

from torch.fx import subgraph_rewriter

pattern = lambda x, y: x + y
replacement = lambda x, y: x * y

subgraph_rewriter.replace_pattern(symbolic_traced, pattern, replacement)

replace_pattern() 的源代码位于 torch/fx/subgraph_rewriter.py#L134子图重写的第一步是子图匹配,其中首先获取3张 torch.fx.Graph:GraphModule 中保存的 graph,pattern 和 replacement 对应的 graph。后两者通过 symbolic_trace() 获取 graph,symbolic_trace() 的实现细节参考前文。

在开始匹配前,首先创建了 matcher,它是 _SubgraphMatcher 类的实例,负责进行模式匹配,试图匹配由 pattern_graph 所确定的模式。

torch/fx/subgraph_rewriter.py#L260 依次遍历原图中的所有节点 anchor,调用 matcher.matches_subgraph_from_anchor(anchor) 来检查是否能够在 anchor 处成功匹配 pattern,核心实现在 torch/fx/subgraph_rewriter.py:45:_match_nodespattern_anchor_SubgraphMatcher 初始化时被设置为 output 节点,模式匹配从 pattern graph 的 output 节点开始匹配原图的 anchor 节点,并依次向上检查从 anchor 节点开始能否完整匹配 pattern graph 中的所有节点。匹配的逻辑如下:

  • pn (pattern node) 与 gn (graph node) 的属性相同,两者的 .op.target 必须相同。特例:pattern graph 中的 placeholder 节点可以匹配所有节点,pattern graph 中的 output 节点可以匹配原图中非 placeholder 的节点;
  • pn 如果是 output 节点,它的输入节点(producer)个数必须与 gn 的输入节点个数相同,因为 _SubgraphMatcher 目前只支持单个返回值的 patter graph,所以 output 节点的唯一的输入节点必须能成功匹配 gn 的其中一个输入节点,否则匹配失败,输入节点的匹配以递归调用 _match_nodes() 进行;
  • 对于其他类型的节点,匹配成功的标准是两者的输入节点数目相同,并且输入节点能够一一匹配,节点的匹配仍以递归方式进行;

子图匹配成功后,torch/fx/subgraph_rewriter.py#L296 还进一步限制 原图中成功匹配的节点,其 user 节点也必须在成功匹配的节点当中,外部节点引用子图中间节点会视为不匹配

对于每一个成功的子图匹配,torch/fx/subgraph_rewriter.py#L322 进一步限定了 子图匹配不能重叠(overlap),即当前成功匹配的子图中不能包含前面匹配成功的子图中所包含的节点。

子图重写的第二步是子图替换torch/fx/subgraph_rewriter.py#L328-L335 检查 pattern graph 和 replacement graph 中 placeholder 数目是否一致,torch/fx/subgraph_rewriter.py#L348-L349 把成功匹配的子图中的所有节点标记为已被替换,以防止其中的节点被后续成功匹配的子图重用。

子图替换的第一步是子图拷贝torch/fx/subgraph_rewriter.py#L362-L365 将替换图拷贝到原图中,在拷贝前整图如下:

graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

拷贝后的原图中既包含 pattern 子图,又包含 replacement 子图,拷贝后的整图如下,其中 xparam 节点各有两个 user,分别是 addmul,而新拷贝的 mul 节点此时没有 user:

graph():
    %x : [#users=2] = placeholder[target=x]
    %param : [#users=2] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %mul : [#users=0] = call_function[target=operator.mul](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

将替换图拷贝到原图中,需要修改原图中的插入点,默认插入点是 root 前,torch/fx/graph.py:777:inserting_before 把插入点修改为原图中配对替换图中 output 的节点,torch/fx/subgraph_rewriter.py#L364-L365 执行子图拷贝操作,将替换图拷贝到原图中,具体实现在 torch/fx/graph.py:647:graph_copy,其中会调用 torch/fx/graph.py:1028:node_copy 来逐个拷贝节点。

子图替换的第二步是在原图中引用替换图,实现在 torch/fx/subgraph_rewriter.py#L393-L400。这里 pn 是 pattern graph 中的 output 节点,rn 是替换图(replacement graph)中的 output 节点,对于上述示例代码,pn.all_input_nodes[add]rn.all_input_nodes[mul]gn_input 是原图(整图)中的 add 节点,rn_input_in_original_graph 是原图中新拷贝的 mul 节点。gn_input.replace_all_uses_with(rn_input_in_original_graph) 把原图中 add 节点的所有 user 的输入节点修正为 mul 节点,从而实现在原图中引用新拷贝的替换图。

修正后的原图如下,此时 add 节点没有 user,mul 节点有一个 user:

graph():
    %x : [#users=2] = placeholder[target=x]
    %param : [#users=2] = get_attr[target=param]
    %add : [#users=0] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %mul : [#users=1] = call_function[target=operator.mul](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%mul,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

修正原图引用的核心是 torch/fx/node.py:465:replace_all_uses_with,其中依次遍历当前节点的各个 user,把 user 中对当前节点的引用修正为替换节点。

子图替换的第三步是删除原图中的被替换节点,实现在 torch/fx/subgraph_rewriter.py#L414-L417。这里以逆序的方式依次遍历原图中的所有节点,将 output 节点以外的无 user 节点从原图中删除。删除节点由 torch/fx/graph.py:750:erase_node 完成,被删除节点的输入节点会更新其 user 信息。

子图替换到此结束,最终的原图如下:

graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %mul : [#users=1] = call_function[target=operator.mul](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%mul,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp

子图重写的第三步是整图编译,它由 gm.recompile() 实现,此前已详述其实现,这里不在赘述。

subgraph_rewriter.replace_pattern() 实现的子图重写到此结束,复杂的子图重写流程可以参照其实现完成。整个过程中涉及到的主要函数执行流程:

[P000] > torch/fx/subgraph_rewriter.py:134:replace_pattern [New]
[P001] > torch/fx/graph_module.py:387:graph [New]
[P001] < torch/fx/graph_module.py:387:graph
[P001] > torch/fx/_symbolic_trace.py:827:symbolic_trace [New]
[P001] < torch/fx/_symbolic_trace.py:827:symbolic_trace
[P001] > torch/fx/_symbolic_trace.py:827:symbolic_trace
[P001] < torch/fx/_symbolic_trace.py:827:symbolic_trace
[P001] > torch/fx/subgraph_rewriter.py:19:__init__ [New]
[P001] < torch/fx/subgraph_rewriter.py:19:__init__
[P001] > torch/fx/subgraph_rewriter.py:33:matches_subgraph_from_anchor
[P002] > torch/fx/subgraph_rewriter.py:45:_match_nodes
[P002] < torch/fx/subgraph_rewriter.py:45:_match_nodes
[P001] < torch/fx/subgraph_rewriter.py:33:matches_subgraph_from_anchor
[P001] > torch/fx/subgraph_rewriter.py:264:pattern_is_contained [New]
[P001] < torch/fx/subgraph_rewriter.py:264:pattern_is_contained
[P001] > torch/fx/subgraph_rewriter.py:312:overlaps_with_prev_match [New]
[P001] < torch/fx/subgraph_rewriter.py:312:overlaps_with_prev_match
[P001] > torch/fx/subgraph_rewriter.py:341:mark_node_as_replaced [New]
[P001] < torch/fx/subgraph_rewriter.py:341:mark_node_as_replaced
[P001] > torch/fx/graph.py:777:inserting_before [New]
[P002] > torch/fx/graph.py:210:__init__ [New]
[P002] < torch/fx/graph.py:210:__init__
[P001] < torch/fx/graph.py:777:inserting_before
[P001] > torch/fx/graph.py:214:__enter__ [New]
[P001] < torch/fx/graph.py:214:__enter__
[P001] > torch/fx/graph.py:647:graph_copy [New]
[P002] > torch/fx/graph.py:1028:node_copy [New]
[P002] < torch/fx/graph.py:1028:node_copy
[P001] < torch/fx/graph.py:647:graph_copy
[P001] > torch/fx/graph.py:217:__exit__ [New]
[P001] < torch/fx/graph.py:217:__exit__
[P001] > torch/fx/node.py:465:replace_all_uses_with [New]
[P001] < torch/fx/node.py:465:replace_all_uses_with
[P001] > torch/fx/graph.py:237:__reversed__
[P001] < torch/fx/graph.py:237:__reversed__
[P001] > torch/fx/graph.py:750:erase_node [New]
[P001] < torch/fx/graph.py:750:erase_node
[P001] > torch/fx/graph_module.py:624:recompile
[P001] < torch/fx/graph_module.py:624:recompile
[P000] < torch/fx/subgraph_rewriter.py:134:replace_pattern

图执行源码 Link to heading

在上述示例代码的基础上,我们增加以下代码来执行修改过后的图:

x = torch.randn((3, 4))
dy =torch.randn((3, 5))

y = symbolic_traced(x)
y.backward(dy)

检查 symbolic_traced.code,生成的 Python 代码如下:

def forward(self, x):
    param = self.param
    mul = x * param;  x = param = None
    linear = self.linear(mul);  mul = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp

symbolic_traced 的类型是 GraphModule,在执行 y = symbolic_traced(x),会调用 GraphModule.__call__,它在 gm.recompile() 时已经被设置为 call_wrappedtorch/fx/graph_module.py#L654)。其中 self._wrapped_call_WrappedCall 的实例,因而被转到 torch/fx/graph_module.py#L262

torch.fx.GraphModule 转到 torch.nn.Module 的核心在 torch/fx/graph_module.py#L267super(self.cls, obj).__call__(*args, **kwargs)self.clsGraphModule,它继承自 torch.nn.Module,因而转为调用 torch.nn.Module.__call__(*args, **kwargs),此后的执行流程与普通的 torch.nn.Module 执行过程一致。

值得注意的是,在执行 linear = self.linear(mul) 后,Python 执行了 mul = None这是在符号跟踪过程中添加的代码,在变量最后一次使用后显示释放变量,因而调用了 THPVariable_subclass_dealloc 来释放内存,它是在 torch/csrc/autograd/python_variable.cpp#L1800 所指定的 tp_dealloc 属性,在 Tensor 的引用计数变为零时自动调用。

图执行过程中涉及到的主要函数执行过程:

[P000] > torch/fx/graph_module.py:651:call_wrapped [New]
[P001] > torch/fx/graph_module.py:262:__call__ [New]
[P002] > torch/nn/modules/module.py:1124:_call_impl [New]
[P003] > <eval_with_key>.3:4:forward [New]
[C004] > torch/csrc/autograd/generated/python_variable_methods.cpp:16684:TypeError_to_NotImplemented_<torch::autograd::THPVariable_mul> [New]
[C004] < torch/csrc/autograd/generated/python_variable_methods.cpp:16684:TypeError_to_NotImplemented_<torch::autograd::THPVariable_mul>
[P004] > torch/nn/modules/module.py:1124:_call_impl
[P005] > torch/nn/modules/linear.py:113:forward [New]
[C006] > torch/csrc/autograd/generated/python_nn_functions.cpp:1798:THPVariable_linear [New]
[C006] < torch/csrc/autograd/generated/python_nn_functions.cpp:1798:THPVariable_linear
[P005] < torch/nn/modules/linear.py:113:forward
[P004] < torch/nn/modules/module.py:1124:_call_impl
[C004] > torch/csrc/autograd/python_variable.cpp:1498:THPVariable_subclass_dealloc [New]
[C005] > torch/csrc/autograd/python_variable.cpp:379:THPVariable_tryResurrect [New]
[C005] < torch/csrc/autograd/python_variable.cpp:379:THPVariable_tryResurrect
[C004] < torch/csrc/autograd/python_variable.cpp:1498:THPVariable_subclass_dealloc
[C004] > torch/csrc/autograd/generated/python_variable_methods.cpp:4920:THPVariable_clamp [New]
[C004] < torch/csrc/autograd/generated/python_variable_methods.cpp:4920:THPVariable_clamp
[P003] < <eval_with_key>.3:4:forward
[P002] < torch/nn/modules/module.py:1124:_call_impl
[P001] < torch/fx/graph_module.py:262:__call__
[P000] < torch/fx/graph_module.py:651:call_wrapped
[P000] > torch/_tensor.py:340:backward [New]
[P001] > torch/autograd/__init__.py:85:backward [New]
[C002] > torch/csrc/autograd/python_engine.cpp:153:THPEngine_run_backward [New]
[C002] < torch/csrc/autograd/python_engine.cpp:153:THPEngine_run_backward
[P001] < torch/autograd/__init__.py:85:backward
[P000] < torch/_tensor.py:340:backward

控制流 Link to heading

torch.fx 目前只支持 静态控制流,即基于 placeholder 计算出来的值不能作为控制语句的条件。以下面的代码为例:

def func(x, y):
    return y if x > 0 else -y

symbolic_traced = symbolic_trace(func)

执行该代码会抛出异常,tracer 检测到在符号跟踪过程中被跟踪的变量被用于控制流的输入:

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

在符号跟踪的过程中,x 是初始化时创建的 Proxyx > 0 触发 Proxy.__gt__(),它创建了一个新的 Proxy,target 属性是 operator.gt。Python 中 if 语句的条件是一个 bool 类型的值,因此 if x > 0 会调用 Proxy.__bool__()gt 这个 Proxy 转为 bool 类型,实现在 torch/fx/proxy.py#L262。该函数立刻跳到 torch/fx/proxy.py#L153,并抛出 TraceError

这里涉及的主要函数有:

[P000] > torch/fx/_symbolic_trace.py:827:symbolic_trace [New]
[P001] > torch/fx/_symbolic_trace.py:499:trace [New]
[P002] > sym_trace.py:44:func [New]
[P003] > torch/fx/proxy.py:383:impl [New]
[P004] > torch/fx/proxy.py:49:create_proxy
[P004] < torch/fx/proxy.py:49:create_proxy
[P003] < torch/fx/proxy.py:383:impl
[P003] > torch/fx/proxy.py:262:__bool__ [New]
[P004] > torch/fx/proxy.py:153:to_bool [New]
[P004] < torch/fx/proxy.py:153:to_bool
[P003] < torch/fx/proxy.py:262:__bool__
[P002] < sym_trace.py:44:func
[P001] < torch/fx/_symbolic_trace.py:499:trace
[P000] < torch/fx/_symbolic_trace.py:827:symbolic_trace

torch/fx/proxy.py#L153-L184 可以看出,符号跟踪目前并不支持将符号作为布尔类型(如动态控制流中的条件语句)、将符号转为迭代器(如动态控制流中的循环)、以及将符号作为字典。

虽然符号跟踪不支持动态控制流,但它 支持静态控制流。例如下面的代码,flag 是某种模型架构探索时设置的参数,在符号跟踪期间虽然后 if/else 语句,但它是静态的,不会出发异常。

flag = True
def func(x, y):
    return y if flag else x

symbolic_traced = symbolic_trace(func)

符号跟踪后对应的图如下,可以看到 符号跟踪过程中并不记录静态控制流

graph():
    %x : [#users=0] = placeholder[target=x]
    %y : [#users=1] = placeholder[target=y]
    return y

最后生成的 Python 代码如下:

def forward(self, x, y):
    return y

参考文献 Link to heading