Guard Link to heading

随着函数调用栈返回到 convert_frame.py#L327TorchDynamo 编译的最后需要为 Guard 生成 Python 代码,从而在后续执行编译好的函数时检查函数的输入信息是否发生了变化,从而决定是否需要重新编译:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/convert_frame.py#L366
check_fn = CheckFunctionManager(
    output,
    locals,
    globals,
    hooks.guard_fail_fn if hooks else None,
)

guarded_code = GuardedCode(out_code, check_fn.check_fn)

CheckFunctionManagerGuard 生成 Python 代码,OutputGraph 中的 Guard 是 TorchDynamo 在做字节码翻译的过程中逐步从输入收集并传播到输出节点的。在编译好的子图 __compiled_fn_0() 中收集到了 3 个 Guard,分别是张量 ab、库 torch:

Guard(name='a', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f0dd813c820>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None),
Guard(name='b', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f0dd813c820>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None),
Guard(name='torch', source=<GuardSource.GLOBAL: 1>, create_fn=<function GuardBuilder.FUNCTION_MATCH at 0x7f0dd813c160>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None)

compile_check_fn() 负责从 Guard 中生成可执行的 Python 代码,为了降低运行时检测输入是否发生变化的函数的开销,TorchDynamo 把 Guard 的检测功能实现在 C++ 中。以 TensorGuards 为例,它初始化在 guards.py#L678,后端的 TensorGuards_init() 实现在 guards.cpp#L179:

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/guards.py#L678
tensor_guards = TensorGuards(
    *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
)
check_tensors_fn = tensor_guards.check
check_tensors_verbose_fn = tensor_guards.check_verbose
code_parts.append(f"___check_tensors({', '.join(tensor_check_names)})")

此处的函数调用栈为:

这里的 check_tensors_fntensor_guards.check,它是在 TorchDynamo 的 C++ 扩展中实现的,C++ 函数是 TensorGuards_check:

// https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/csrc/dynamo/guards.cpp#L308
static PyMethodDef TensorGuards_methods[] = {
    {"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {NULL} /* Sentinel */
};

Guard 的 Python 前端函数由 guards.py#L729 生成,对于 __compiled_fn_0()lambda 函数体为 ___guarded_code.valid and ___check_tensors(a, b)set_guard_fail_hook() 设置了 Guard 检测失败时执行 Python 函数 guard_fail_hook(),其中检测并记录了失败原因。

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/guards.py#L729
        py_code = f"""\
def ___make_guard_fn({','.join(closure_vars.keys())}):
    return lambda {args}: {code}
"""
        if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
            print("GUARDS", code)
        set_guard_fail_hook(guard_fail_hook)
        out: Dict[str, Any] = dict()
        # print("RUNNING PY CODE", py_code)
        exec(py_code, global_builder.scope, out)
        guard_fn = out["___make_guard_fn"](*closure_vars.values())

exec(py_code, global_builder.scope, out) 创建了 Python 函数 ___make_guard_fn(),和 out["___make_guard_fn"](*closure_vars.values()) 调用 ___make_guard_fn() 并生成可执行的 guard_fn,也就是 ___guarded_code.valid and ___check_tensors(a, b)

回到 convert_frame.py#L366check_fn 即是刚创建的 guard_fnGuardedCode 保存了编译好的子图 __compiled_fn_0()check_fn

# https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/_dynamo/convert_frame.py#L366
check_fn = CheckFunctionManager(
    output,
    locals,
    globals,
    hooks.guard_fail_fn if hooks else None,
)

guarded_code = GuardedCode(out_code, check_fn.check_fn)

一次完整的子图编译到此结束,依次从 _compile() 一直返回到 _custom_eval_frame():

Cache 与运行子图 Link to heading

回到 _custom_eval_frame(),我们得到了编译好的 GuardedCodecreate_cache_entry() 往当前用户函数的 frame->f_code 中写入一条 CacheEntry,记录了 check_fn 和编译好的 codecache 以单向链表组织,新创建的 CacheEntry 放在链表头部

// https://github.com/pytorch/pytorch/blob/fe05266fda4f908130dea7cbac37e9264c0429a2/torch/csrc/dynamo/eval_frame.c#L712
PyObject* result =
    call_callback(callback, frame, cache_size(extra));
if (result == NULL) {
  // internal exception, returning here will leak the exception into user code
  // this is useful for debugging -- but we dont want it to happen outside of
  // testing
  return NULL;
} else if (result != Py_None) {
  DEBUG_TRACE("create cache %s", name(frame));
  extra = create_cache_entry(extra, result);
  Py_DECREF(result);
  set_extra(frame->f_code, extra);
  // Re-enable custom behavior
  eval_frame_callback_set(callback);
  return eval_custom_code(tstate, frame, extra->code, throw_flag);
} else {
  DEBUG_TRACE("create skip %s", name(frame));
  Py_DECREF(result);
  destroy_cache_entry(extra);
  set_extra(frame->f_code, SKIP_CODE);
  // Re-enable custom behavior
  eval_frame_callback_set(callback);
  return eval_frame_default(tstate, frame, throw_flag);
}

TorchDynamo 在开始捕获计算图时,清除了 Frame Evaluation 的回调函数,确保只在最外层用户函数调用回调函数 catch_errors()。此时子图已经抓取并编译,eval_frame_callback_set(callback) 将 Frame Evaluation 的回调函数重置为 catch_errors()eval_custom_code(tstate, frame, extra->code, throw_flag) 创建了一个新的 Python Frame,并运行编译好的函数,实现于 eval_custom_code()。注意,frame->f_code 包含的是用户函数的字节码,extra->code 包含的是 TorchDynamo 编译过后的字节码。对于编译过后的 toy_example(),其内容是:

      0 LOAD_GLOBAL              3 (__compiled_fn_0)
      2 LOAD_FAST                0 (a)
      4 LOAD_FAST                1 (b)
      6 CALL_FUNCTION            2
      8 UNPACK_SEQUENCE          2
     10 STORE_FAST               2 (x)
     12 POP_JUMP_IF_FALSE       24
     14 LOAD_GLOBAL              4 (__resume_at_30_1)
     16 LOAD_FAST                1 (b)
     18 LOAD_FAST                2 (x)
     20 CALL_FUNCTION            2
     22 RETURN_VALUE
>>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
     26 LOAD_FAST                1 (b)
     28 LOAD_FAST                2 (x)
     30 CALL_FUNCTION            2
     32 RETURN_VALUE

eval_custom_code() 中直接调用了 eval_frame_default() 来执行上面的字节码,所以此处不会再次触发 TorchDynamo 定制的 Frame Evaluation 函数 custom_eval_frame_shim。执行上面经过修改的 toy_example() 字节码,其中首先调用 TorchDynamo 编译好的函数 __compiled_fn_0()。此前在编译 __compiled_fn_0() 时,output_graph.py#L616 通过 compiled_fn = disable(compiled_fn) 禁止在已编译的子图上再次启用 TorchDynamodisable() 通过 DisableContext 返回 _fn()。因此,执行 __compiled_fn_0() 会再次来到 _fn(),只不过此时的 callbackNone 而不是 catch_errors(),所以 fn(*args, **kwargs) 会在没有 TorchDynamo 的上下文下中执行 __compiled_fn_0(),这个过程与 PyTorch 以 eager 模式执行一个函数相同。对于 __compiled_fn_0(),它是经过编译的 GraphModule.forward(),编译好的 Python 代码如下,在 eager 模式中依次执行其中的代码:

def forward(self, a : torch.Tensor, b : torch.Tensor):
    abs_1 = torch.abs(a)
    add = abs_1 + 1;  abs_1 = None
    truediv = a / add;  a = add = None
    sum_1 = b.sum();  b = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)

此时的函数调用栈为:

执行完编译过的子图 __compiled_fn_0(),程序返回到 Python 解释器,下一条字节码是 POP_JUMP_IF_FALSE,对应原来 toy_example() 中的 if。在 Python 解释器中执行 POP_JUMP_IF_FALSE 24 会从 Python 虚拟机栈顶取出一个元素,将其转化为 bool 类型,根据结果判断是否跳转。此时的栈顶为 Tensor b.sum() < 0,Python 虚拟机调用 THPVariable_bool_scalar() 将其转为 bool 值,此次求值结果为 False,因此跳转到 offset 24 处 开始继续执行。

>>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
     26 LOAD_FAST                1 (b)
     28 LOAD_FAST                2 (x)
     30 CALL_FUNCTION            2
     32 RETURN_VALUE

Offset 24 处是此前 TorchDynamo 由于 graph break 而创建的函数 __resume_at_38_2()执行该 Python 函数会触发 TorchDynamo 设置的 Frame Evaluation 函数 custom_eval_frame_shim(),在 _custom_eval_frame() 检查 __resume_at_38_2() 的缓存中是否存在已经编译好的结果,此处发生 cache miss,所以通过 call_callback() 调用设置好的回调函数 catch_errors针对 __resume_at_38_2() 开启一次全新的子图编译过程

逐条翻译 __resume_at_38_2() 的字节码的流程与此前的 toy_example() 基本一致,此时的函数调用栈为:

此前已经讲过,__resume_at_38_2() 的字节码为:

      0 JUMP_ABSOLUTE           40
      2 LOAD_FAST                2 (a)
      4 LOAD_GLOBAL              0 (torch)
      6 LOAD_ATTR                1 (abs)
      8 LOAD_FAST                2 (a)
     10 CALL_FUNCTION            1
     12 LOAD_CONST               1 (1)
     14 BINARY_ADD
     16 BINARY_TRUE_DIVIDE
     18 STORE_FAST               1 (x)
     20 LOAD_FAST                0 (b)
     22 LOAD_ATTR                2 (sum)
     24 CALL_FUNCTION            0
     26 LOAD_CONST               2 (0)
     28 COMPARE_OP               0 (<)
     30 POP_JUMP_IF_FALSE       40
     32 LOAD_FAST                0 (b)
     34 LOAD_CONST               3 (-1)
     36 BINARY_MULTIPLY
     38 STORE_FAST               0 (b)

>>   40 LOAD_FAST                1 (x)
     42 LOAD_FAST                0 (b)
     44 BINARY_MULTIPLY
     46 RETURN_VALUE

step() 中翻译 __resume_at_38_2() 字节码的流程如下:

Offset 0, JUMP_ABSOLUTE: 直接通过 jump 修改当前字节码指针 instruction_pointer 为 40,从而跳转到 offset 40。

Offset 40, LOAD_FAST: 与此前所讲 LOAD_FAST 过程一致,从 symbolic_locals 取出 变量 xTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(x)]

Offset 42, LOAD_FAST: 从 symbolic_locals 取出 变量 bTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(x), TensorVariable(b)]

Offset 44, BINARY_MULTIPLY: 真正实现位于 impl,与此前所讲 BINARY_ADD 类似,BINARY_MULTIPLYInstructionTranslatorBase 中被设为 stack_op(operator.mul),随后到代码路径与 BINARY_ADD 一致。出栈 TensorVariable(x)TensorVariable(b),为输出创建新的 Proxy,类型是 call_function,目标是 operator.mul,参数是 Proxy(x)Proxy(b)。再为其创建了新的 TensorVariable 来跟踪收集到的 Guard,以 FakeTensor 运行节点,最后把结果压栈,栈上的内容为 [TensorVariable(mul)]

BINARY_MULTIPLY = stack_op(operator.mul)

Offset 46, RETURN_VALUE: 调用 compile_subgraph() 来编译计算图,原因是 return_value,代表一个 Python 函数的计算图被完整的捕获下来,在函数返回处的 RETURN_VALUE 字节码处开始子图编译。此时捕获的计算图为:

opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}

这里因为函数返回而触发的子图编译流程与此前的因为 graph break 而触发的子图编译流程相同,此处不再赘述。原函数 __resume_at_38_2() 对应的字节码:

      0 JUMP_ABSOLUTE           40
      2 LOAD_FAST                2 (a)
      4 LOAD_GLOBAL              0 (torch)
      6 LOAD_ATTR                1 (abs)
      8 LOAD_FAST                2 (a)
     10 CALL_FUNCTION            1
     12 LOAD_CONST               1 (1)
     14 BINARY_ADD
     16 BINARY_TRUE_DIVIDE
     18 STORE_FAST               1 (x)
     20 LOAD_FAST                0 (b)
     22 LOAD_ATTR                2 (sum)
     24 CALL_FUNCTION            0
     26 LOAD_CONST               2 (0)
     28 COMPARE_OP               0 (<)
     30 POP_JUMP_IF_FALSE       40
     32 LOAD_FAST                0 (b)
     34 LOAD_CONST               3 (-1)
     36 BINARY_MULTIPLY
     38 STORE_FAST               0 (b)

>>   40 LOAD_FAST                1 (x)
     42 LOAD_FAST                0 (b)
     44 BINARY_MULTIPLY
     46 RETURN_VALUE

被 TorchDynamo 编译过后 __resume_at_38_2() 的字节码为:

 0 LOAD_GLOBAL              3 (__compiled_fn_3)
 2 LOAD_FAST                0 (b)
 4 LOAD_FAST                1 (x)
 6 CALL_FUNCTION            2
 8 UNPACK_SEQUENCE          1
10 RETURN_VALUE

其中捕获的子图 __compiled_fn_3(),对应的 GraphModule 的 Python 代码为:

def forward(self, b : torch.Tensor, x : torch.Tensor):
    mul = x * b;  x = b = None
    return (mul,)

随后返回到 convert_frame.py#L366,为编译过后的代码添加 Guard: ___guarded_code.valid and ___check_tensors(b, x),确保张量 bx 的信息在下次执行时没有发生变化。到此 __resume_at_38_2() 的子图编译完成,回到 _custom_eval_frame()将编译好的代码保存在缓存中,调用 eval_custom_code() 执行编译过的函数 __resume_at_38_2(),其中直接调用编译好的子图 __compiled_fn_3()。执行 __compiled_fn_3() 前再次以 DisableContext 模式进入 _fn,以 eager 模式执行编译好的子图 __compiled_fn_3(),此处的函数调用栈为:

__compiled_fn_3()__resume_at_38_2()toy_example() 一路返回到 test()for 循环的第 1 次 toy_example() 调用完毕。

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def test(a, b):
    for i in range(4):
        toy_example(a, b * (-1) ** i)

第 2 次调用 toy_example() 时,eval_frame.c#L694 通过 lookup(extra, frame, NULL) 检查缓存中是否存在已经编译好的函数,lookup() 依次遍历缓存中的条目,通过 PyObject_Call(e->check_fn, noargs, f_locals) 调用此前为 Guard 生成的 check_fn 执行检查。TorchDynamo 以 LRU 方式管理缓存,在缓存命中 (cache hit) 时把命中的缓存条目挪到链表头部,在缓存缺失 (cache miss) 时返回 NULL。对于案例中的 toy_example(),调用 TensorGuards_check 检查输入的张量信息与子图编译时的张量信息是否匹配。因为此时张量信息没有改变,所以缓存命中,执行编译好的 toy_example()。此时的函数调用栈为:

在编译好的 toy_example() 中首先以 eager 模式执行此前编译好的子图 __compiled_fn_0(),涉及的函数调用栈为:

执行完成后返回到 Python 解释器,把返回的张量 b.sum() < 0 转为 Python 中的 bool 值并判定结果,这次结果为 True,调用函数 __resume_at_30_1()由此再次进入 TorchDynamo 定制的 Frame Evaluation 函数 custom_eval_frame_shim(),因为 if 分支此前没有编译过,所以缓存中不存在对应的条目,开始抓取新的子图并编译__resume_at_30_1() 的字节码为:

      0 JUMP_ABSOLUTE           32
      2 LOAD_FAST                2 (a)
      4 LOAD_GLOBAL              0 (torch)
      6 LOAD_ATTR                1 (abs)
      8 LOAD_FAST                2 (a)
     10 CALL_FUNCTION            1
     12 LOAD_CONST               1 (1)
     14 BINARY_ADD
     16 BINARY_TRUE_DIVIDE
     18 STORE_FAST               1 (x)
     20 LOAD_FAST                0 (b)
     22 LOAD_ATTR                2 (sum)
     24 CALL_FUNCTION            0
     26 LOAD_CONST               2 (0)
     28 COMPARE_OP               0 (<)
     30 POP_JUMP_IF_FALSE       40

>>   32 LOAD_FAST                0 (b)
     34 LOAD_CONST               3 (-1)
     36 BINARY_MULTIPLY
     38 STORE_FAST               0 (b)

>>   40 LOAD_FAST                1 (x)
     42 LOAD_FAST                0 (b)
     44 BINARY_MULTIPLY
     46 RETURN_VALUE

Offset 0, JUMP_ABSOLUTE: 直接通过 jump 修改当前字节码指针 instruction_pointer 为 32,从而跳转到 offset 32。

Offset 32, LOAD_FAST: 与此前所讲 LOAD_FAST 过程一致,从 symbolic_locals 取出变量 bTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(b)]

Offset 34, LOAD_CONST: 加载常量 -1,创建 ConstantVariable 并压栈,此时栈上的内容为 [TensorVariable(b), ConstantVariable(-1)]

Offset 36, BINARY_MULTIPLY: 真正实现位于 implBINARY_MULTIPLYInstructionTranslatorBase 中被设为 stack_op(operator.mul),出栈 TensorVariable(b)TensorVariable(-1),为输出创建新的 Proxy,类型是 call_function,目标是 operator.mul,参数是 Proxy(b)Proxy(-1)。再为其创建了新的 TensorVariable 来跟踪收集到的 Guard,以 FakeTensor 运行节点,最后把结果压栈,栈上的内容为 [TensorVariable(mul)]

Offset 38, STORE_FAST: 出栈最后一个元素 TensorVariable(mul),并把它存在 self.symbolic_locals[inst.argval] 中,这里 self.symbolic_locals 跟踪当前 frame 的局部变量,变量名由 inst.argval 指定,即 b,此后栈上内容为空。

Offset 40, LOAD_FAST: 从 symbolic_locals 取出变量 xTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(x)]

Offset 42, LOAD_FAST: 从 symbolic_locals 取出 变量 bTensorVariable,然后把它压到栈上,栈上的内容为 [TensorVariable(x), TensorVariable(b)]

Offset 44, BINARY_MULTIPLY: 出栈 TensorVariable(x)TensorVariable(b),为输出创建新的 Proxy,类型是 call_function,目标是 operator.mul,参数是 Proxy(x)Proxy(b)。再为其创建了新的 TensorVariable 来跟踪收集到的 Guard,以 FakeTensor 运行节点,最后把结果压栈,栈上的内容为 [TensorVariable(mul)]

Offset 46, RETURN_VALUE: 调用 compile_subgraph() 来编译计算图,原因是 return_value,代表一个 Python 函数的计算图被完整的捕获下来,在函数返回处的 RETURN_VALUE 字节码处开始子图编译。此时捕获的计算图为:

opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}

这里因为函数返回而触发的子图编译流程与此前的因为 graph break 而触发的子图编译流程相同,此处不再赘述。__resume_at_30_1() 经 TorchDynamo 编译过后的字节码为:

 0 LOAD_GLOBAL              3 (__compiled_fn_4)
 2 LOAD_FAST                0 (b)
 4 LOAD_FAST                1 (x)
 6 CALL_FUNCTION            2
 8 UNPACK_SEQUENCE          1
10 RETURN_VALUE

其中捕获的子图 __compiled_fn_4(),对应的 GraphModule 的 Python 代码为:

def forward(self, b : torch.Tensor, x : torch.Tensor):
    mul = b * -1;  b = None
    mul_1 = x * mul;  x = mul = None
    return (mul_1,)

同样,返回到 convert_frame.py#L366,为编译过后的代码添加 Guard: ___guarded_code.valid and ___check_tensors(b, x),确保张量 bx 的信息在下次执行时没有发生变化。到此 __resume_at_30_1() 的子图编译完成,回到 _custom_eval_frame()将编译好的代码保存在缓存中,调用 eval_custom_code() 执行编译过的函数 __resume_at_30_1(),其中直接调用编译好的子图 __compiled_fn_4()。执行 __compiled_fn_4() 前再次以 DisableContext 模式进入 _fn,以 eager 模式执行编译好的子图 __compiled_fn_4()

__compiled_fn_4()__resume_at_30_1()toy_example() 一路返回到 test()for 循环的第 2 次 toy_example() 调用完毕。此后的两次循环,因为 toy_example() 中涉及的 3 个函数 __compiled_fn_0()__resume_at_30_1() (编译为 __compiled_fn_3())、__resume_at_38_2() (编译为 __compiled_fn_4()) 都存在被编译的版本,所以以后在运行时都直接调用编译后的版本,整个 toy_example() 的编译、运行过程到此结束。