相比读博客、看视频学习一个开源项目,读源代码是比较痛苦的。然而 阅读源代码是成为优秀程序员的起点,不读源代码往往是知其然而不知其所以然。
PyTorch 的代码对初学者来说并不好读,原因在于:
- 有很多 C++ 代码是在编译过程中生成的;
- PyTorch 前端充分利用了 Python 的动态特性,比如 Python 函数可能经过 N 个装饰器修饰;
- 混合 Python/C++ 调用,比如从 Python 调 C++ 调 Python 再调 C++ 在 PyTorch 中是比较常见的事情;
我从 PyTorch 0.3 开始接触其源代码,期间读过不少 PyTorch 源代码,也修改过其中一些,时至今日依然在坚持读 PyTorch 最新的代码。以下是我读PyTorch 源代码的一些经验,仅供参考:
-
明确目标: PyTorch代码很多,建议每次只读一个专题或者从一个特定的问题出发,比如 PyTorch AMP 是怎么实现的;
-
把握全局: 一上来直接读代码会有很多障碍,很多术语不明所以,先通读这个专题下官方的教程、文档,以及一些写的好的第三方博客,确保自己对这个专题下的内容有初步的认知。以下是一些初步了解 PyTorch 特定专题内容的比较好的资源:
-
Debug Build: 一定要 build debug 版的 PyTorch,并保留在编译过程中生成的源代码,否则很难找到 PyTorch 函数调用栈中一些函数的来源;
以现有的 PyTorch
v2.0.0
分支为例,build debug 版 PyTorch 的方法:export DEBUG=1 python setup.py bdist_wheel pip install dist/torch*.whl
下面这些文件是 PyTorch 在编译过程中生成的,可以拷贝到 VSCode 中补全 PyTorch 源代码:
# Compress the generated sources during compilation tar -zcf generated.tar.gz \ torch/csrc/autograd/generated \ torch/csrc/lazy/generated \ build/aten/src/ATen
如果不想编译 PyTorch 源代码,但又想获取 PyTorch 在编译过程中生成的源代码,可以执行以下命令:
cd ${PYTORCH_ROOT} python tools/setup_helpers/generate_code.py --native-functions-path "aten/src/ATen/native/native_functions.yaml" --tags-path "aten/src/ATen/native/tags.yaml" --force_schema_registration --gen_lazy_ts_backend
-
静态读代码: 有了完整的 PyTorch 源代码之后就可以开始读了,网上有很多 VSCode 教程,设置好 VSCode 的 Python 和 C++ 插件,方便在函数之间跳转,可以解决一大部分的函数跳转;
-
动态读代码: 静态读代码的问题是常常搞不清函数的执行流程,此时在运行过程中动态读执行过的代码就很有帮助,善用
gdb
和pdb
可以有效辅助读源代码。对于 PyTorch 来说,在安装 debug build 版 PyTorch 的基础上,推荐安装
python3-dbg
:sudo apt install gdb python3 python3-dbg
然后针对我们想要了解的专题,写一个最简单的例子 (minimal example),比如我们想知道一个 CUDA kernel 是怎么在 PyTorch 中调用起来的,可以用下面的代码:
import torch def main(): x = torch.ones(4, device="cuda") x * 2 if __name__ == '__main__': main()
我们可以给
cudaLaunchKernel
设置断点,然后通过bt
和py-bt
查看函数调用栈:$ gdb python3 (gdb) b cudaLaunchKernel Breakpoint 2 at 0x7fffb21bcb70 (gdb) run test.py Thread 1 "python" hit Breakpoint 2, 0x00007fffb21bcb70 in cudaLaunchKernel () from /usr/local/cuda/lib64/libcudart.so.12 (gdb) bt #0 0x00007fffb21bcb70 in cudaLaunchKernel () from /usr/local/cuda/lib64/libcudart.so.12 #1 0x00007fff6859be82 in cudaLaunchKernel<char> ( func=0x7fff685a83a1 <at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<float>, at: :detail::Array<char*, 1> >(int, at::native::FillFunctor<float>, at::detail::Array<char*, 1>)> ..., gridDim=..., blockDim=..., args=0x7fffffffb7c0, sharedMem=0, stream=0x0) at /usr/local/cuda/include/cuda_runtime.h:216 #2 0x00007fff68595ed7 in __device_stub__ZN2at6native29vectorized_elementwise_kernelILi4ENS0_11FillFuncto rIfEENS_6detail5ArrayIPcLi1EEEEEviT0_T1_ (__par0=4, __par1=..., __par2=...) at /tmp/tmpxft_00006e3d_00000000-6_FillKernel.compute_90.cudafe1.stub.c:280 #3 0x00007fff68595f2f in at::native::__wrapper__device_stub_vectorized_elementwise_kernel<4, at::native: :FillFunctor<float>, at::detail::Array<char*, 1> > (__cuda_0=@0x7fffffffb82c: 4, __cuda_1=..., __cuda_2=...) at /tmp/tmpxft_00006e3d_00000000-6_FillKernel.compute_90.cudafe1.stub.c:283 #4 0x00007fff685a83d0 in at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<float>, at ::detail::Array<char*, 1> > (N=4, f=..., data=...) at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:59 #5 0x00007fff685ac57b in at::native::launch_vectorized_kernel<at::native::FillFunctor<float>, at::detail ::Array<char*, 1> > (N=4, f=..., data=...) at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:98 #6 0x00007fff685a50bf in at::native::gpu_kernel_impl<at::native::FillFunctor<float> > (iter=..., f=...) at /opt/pytorch/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh:214 ... (gdb) py-bt Traceback (most recent call first): <built-in method ones of type object at remote 0x7fffafddf720> File "test.py", line 6, in main x = torch.ones(4, device="cuda") File "test.py", line 10, in <module> main()
-
充分利用源代码中的日志、debug 选项、测试用例: 很多 PyTorch 模块都包含了丰富的日志和 debug 开关,这些日志和用于调试的消息可以帮助我们理解 PyTorch 的执行流程。
比如对于 TorchDynamo 来说,下面的日志和调试开关就能帮我们更好地熟悉 TorchDynamo 的执行流程:
import os import logging import torch._dynamo torch._dynamo.config.log_level = logging.DEBUG torch._dynamo.config.verbose = True torch._dynamo.config.output_code = True os.environ["TORCHDYNAMO_PRINT_GUARDS"] = "1"
除此之外,PyTorch 中包含了大量的测试用例,如果单纯看源代码无法理解程序的逻辑,看看其对应的测试用例可以帮助我们理解程序在做什么。
-
及时求助: 如果经过上面的流程还无法了解某些代码的逻辑,要及时向社区求助,避免浪费过多时间,包括但不限于:
- 在知乎、PyTorch 论坛上提问;
- 在 Github 上发 issue;
-
学什么: 明确源代码中哪些东西值得我们学习和借鉴,读源代码时要特别注意这些方面,比如:
- 特定模块/功能的实现原理;
- 用到的算法;
- 一些 coding 技巧;
-
知行合一:
You can’t understand it until you change it.
读源代码不是最终目的,充分利用从代码中获取的认知才是。有效的 输出 可以加深我们对代码的理解,一些可以参考的输出方式:
- 写一篇源码剖析的博客;
- 简化自己对源代码的认识,分享给其他人;
- 修改源代码,改进或添加一些功能,给 PyTorch 提交 PR;
- 亲手实现相同功能,写精简版的代码复现核心逻辑;
每一个读源代码的人都是不甘平凡的人,祝大家在这个“痛并快乐着”的过程中成长得更快、更多。