相比读博客、看视频学习一个开源项目,读源代码是比较痛苦的。然而 阅读源代码是成为优秀程序员的起点,不读源代码往往是知其然而不知其所以然。

PyTorch 的代码对初学者来说并不好读,原因在于:

  • 有很多 C++ 代码是在编译过程中生成的;
  • PyTorch 前端充分利用了 Python 的动态特性,比如 Python 函数可能经过 N 个装饰器修饰;
  • 混合 Python/C++ 调用,比如从 Python 调 C++ 调 Python 再调 C++ 在 PyTorch 中是比较常见的事情;

我从 PyTorch 0.3 开始接触其源代码,期间读过不少 PyTorch 源代码,也修改过其中一些,时至今日依然在坚持读 PyTorch 最新的代码。以下是我读PyTorch 源代码的一些经验,仅供参考:

  1. 明确目标: PyTorch代码很多,建议每次只读一个专题或者从一个特定的问题出发,比如 PyTorch AMP 是怎么实现的;

  2. 把握全局: 一上来直接读代码会有很多障碍,很多术语不明所以,先通读这个专题下官方的教程、文档,以及一些写的好的第三方博客,确保自己对这个专题下的内容有初步的认知。以下是一些初步了解 PyTorch 特定专题内容的比较好的资源:

  3. Debug Build: 一定要 build debug 版的 PyTorch,并保留在编译过程中生成的源代码,否则很难找到 PyTorch 函数调用栈中一些函数的来源;

    以现有的 PyTorch v2.0.0 分支为例,build debug 版 PyTorch 的方法:

    export DEBUG=1
    python setup.py bdist_wheel
    pip install dist/torch*.whl
    

    下面这些文件是 PyTorch 在编译过程中生成的,可以拷贝到 VSCode 中补全 PyTorch 源代码:

    # Compress the generated sources during compilation
    tar -zcf generated.tar.gz \
        torch/csrc/autograd/generated \
        torch/csrc/lazy/generated \
        build/aten/src/ATen
    

    如果不想编译 PyTorch 源代码,但又想获取 PyTorch 在编译过程中生成的源代码,可以执行以下命令:

    cd ${PYTORCH_ROOT}
    python tools/setup_helpers/generate_code.py
        --native-functions-path "aten/src/ATen/native/native_functions.yaml"
        --tags-path "aten/src/ATen/native/tags.yaml"
        --force_schema_registration
        --gen_lazy_ts_backend
    
  4. 静态读代码: 有了完整的 PyTorch 源代码之后就可以开始读了,网上有很多 VSCode 教程,设置好 VSCode 的 Python 和 C++ 插件,方便在函数之间跳转,可以解决一大部分的函数跳转;

  5. 动态读代码: 静态读代码的问题是常常搞不清函数的执行流程,此时在运行过程中动态读执行过的代码就很有帮助,善用 gdbpdb 可以有效辅助读源代码。

    对于 PyTorch 来说,在安装 debug build 版 PyTorch 的基础上,推荐安装 python3-dbg:

    sudo apt install gdb python3 python3-dbg
    

    然后针对我们想要了解的专题,写一个最简单的例子 (minimal example),比如我们想知道一个 CUDA kernel 是怎么在 PyTorch 中调用起来的,可以用下面的代码:

    import torch
    
    def main():
        x = torch.ones(4, device="cuda")
        x * 2
    
    if __name__ == '__main__':
        main()
    

    我们可以给 cudaLaunchKernel 设置断点,然后通过 btpy-bt 查看函数调用栈:

    $ gdb python3
    (gdb) b cudaLaunchKernel
    Breakpoint 2 at 0x7fffb21bcb70
    (gdb) run test.py
    
    Thread 1 "python" hit Breakpoint 2, 0x00007fffb21bcb70 in cudaLaunchKernel ()
       from /usr/local/cuda/lib64/libcudart.so.12
    (gdb) bt
    #0  0x00007fffb21bcb70 in cudaLaunchKernel () from /usr/local/cuda/lib64/libcudart.so.12
    #1  0x00007fff6859be82 in cudaLaunchKernel<char> (
        func=0x7fff685a83a1 <at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<float>, at:
    :detail::Array<char*, 1> >(int, at::native::FillFunctor<float>, at::detail::Array<char*, 1>)> ...,
    gridDim=..., blockDim=..., args=0x7fffffffb7c0, sharedMem=0, stream=0x0)
        at /usr/local/cuda/include/cuda_runtime.h:216
    #2  0x00007fff68595ed7 in __device_stub__ZN2at6native29vectorized_elementwise_kernelILi4ENS0_11FillFuncto
    rIfEENS_6detail5ArrayIPcLi1EEEEEviT0_T1_ (__par0=4, __par1=..., __par2=...)
        at /tmp/tmpxft_00006e3d_00000000-6_FillKernel.compute_90.cudafe1.stub.c:280
    #3  0x00007fff68595f2f in at::native::__wrapper__device_stub_vectorized_elementwise_kernel<4, at::native:
    :FillFunctor<float>, at::detail::Array<char*, 1> > (__cuda_0=@0x7fffffffb82c: 4, __cuda_1=...,
        __cuda_2=...) at /tmp/tmpxft_00006e3d_00000000-6_FillKernel.compute_90.cudafe1.stub.c:283
    #4  0x00007fff685a83d0 in at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<float>, at
    ::detail::Array<char*, 1> > (N=4, f=..., data=...)
        at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:59
    #5  0x00007fff685ac57b in at::native::launch_vectorized_kernel<at::native::FillFunctor<float>, at::detail
    ::Array<char*, 1> > (N=4, f=..., data=...)
        at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:98
    #6  0x00007fff685a50bf in at::native::gpu_kernel_impl<at::native::FillFunctor<float> > (iter=..., f=...)
        at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:214
    ...
    
    (gdb) py-bt
    Traceback (most recent call first):
      <built-in method ones of type object at remote 0x7fffafddf720>
      File "test.py", line 6, in main
        x = torch.ones(4, device="cuda")
      File "test.py", line 10, in <module>
        main()
    
  6. 充分利用源代码中的日志、debug 选项、测试用例: 很多 PyTorch 模块都包含了丰富的日志和 debug 开关,这些日志和用于调试的消息可以帮助我们理解 PyTorch 的执行流程。

    比如对于 TorchDynamo 来说,下面的日志和调试开关就能帮我们更好地熟悉 TorchDynamo 的执行流程:

    import os
    import logging
    import torch._dynamo
    torch._dynamo.config.log_level = logging.DEBUG
    torch._dynamo.config.verbose = True
    torch._dynamo.config.output_code = True
    os.environ["TORCHDYNAMO_PRINT_GUARDS"] = "1"
    

    除此之外,PyTorch 中包含了大量的测试用例,如果单纯看源代码无法理解程序的逻辑,看看其对应的测试用例可以帮助我们理解程序在做什么。

  7. 及时求助: 如果经过上面的流程还无法了解某些代码的逻辑,要及时向社区求助,避免浪费过多时间,包括但不限于:

    • 在知乎、PyTorch 论坛上提问;
    • 在 Github 上发 issue;
  8. 学什么: 明确源代码中哪些东西值得我们学习和借鉴,读源代码时要特别注意这些方面,比如:

    • 特定模块/功能的实现原理;
    • 用到的算法;
    • 一些 coding 技巧;
  9. 知行合一:

    You can’t understand it until you change it.

    读源代码不是最终目的,充分利用从代码中获取的认知才是。有效的 输出 可以加深我们对代码的理解,一些可以参考的输出方式:

    • 写一篇源码剖析的博客;
    • 简化自己对源代码的认识,分享给其他人;
    • 修改源代码,改进或添加一些功能,给 PyTorch 提交 PR;
    • 亲手实现相同功能,写精简版的代码复现核心逻辑;

每一个读源代码的人都是不甘平凡的人,祝大家在这个“痛并快乐着”的过程中成长得更快、更多。