AI 实验中的 python 工程实践

深度学习实验管理

2023-09-11 一 16:30 2024-02-19 一 14:14

修改历史

  • [2023-12-15 五 16:03] 调整章节顺序,把数据集处理部分放代码阅读重构和文档章节之后

本文主要梳理深度学习(或者机器学习相关的)实验中除模型、算法设计之外的工程方面的经验,包括项目的文件组织、测试和调试、参数管理、代码阅读和重构、文档和注释等,不涉及多卡训练相关主题。💡本文记录的都是个人经验,请谨慎参考。

1. 项目源文件的结构

在组织 python 项目的文件结构时,有两个主要方面需要考虑:

  • python 本身对多文件交互和组合所设计的机制,包括文件路径的确定和 import 原理,这是本章第一节要理清的。
  • 实验对象的特点:本章第二节围绕深度学习实验项目的主要特点来介绍一种比较典型的项目结构。

1.1. 理解 python 运行时的 cwd 和 sys.path

在没有特别说明时,本章后文提到的文件和目录都是基于以下项目结构的:

/tmp/proj/
   ├── a.py # from subdir import b
   └── subdir/
       ├── b.py # from deepdir import c
       └── deepdir
           └── c.py #"print('c.py here')"

可以用如下 bash 脚本来创建以上目录结构

cd /tmp; rm -rf proj
mkdir -p proj/subdir/deepdir
echo "from subdir import b" > proj/a.py
echo "from deepdir import c" > proj/subdir/b.py
echo "print('c.py here')" > proj/subdir/deepdir/c.py

假设目前已经从终端进入到了 /tmp/proj/ 目录中,当运行 python 程序后,涉及两个比较重要的保存"路径"的变量,分别是通过 os.getcwd() 得到的当前工作目录和通过 sys.path 得到的系统路径列表。

1.1.1. 当前工作目录 cwd

import os; print(os.getcwd())

CWD 全称是 Current Working Directory, 一般叫做工作目录,它记录的是 python 命令触发时所在的目录,因此执行 python subdir/b.py 后的工作目录是 /tmp/proj/

当用 open 等命令打开其他文件时,文件的"相对路径"就是相对 CWD 目录而言的。比如假设 subdir/b.py 里新增了读取文件的代码:

with open("data.json", "r") as f:
     #read

那么在 /tmp/proj 里执行 python subdir/b.py 读取的是 /tmp/proj/data.json 文件;而切换到 /tmp/proj/subdir 下再执行 python b.py 读取的就是 /tmp/proj/subdir/data.json 文件,这种不一致会带来问题。

因此一般来说尽可能只在项目根目录 /tmp/proj 里执行 python 命令,而不要随意切换到其他目录去执行。这样的话,项目里所有文件存取相关的路径都可以是相对根目录的,减轻理解上的负担和不必要的麻烦。

1.1.2. sys.path

import sys; print(sys.path)

该变量和 import 有关,首先区分两种不同的 import 形式:

  1. 绝对导入方式,形如:

    import xyz
    from subdir import b
    
  1. 相对导入形式,形如:

    from . import b
    import .xyz
    from .deepdir import c
    from ..subdir import b
    

    注意在模块或包前有一个或两个句号。(模块指的是 python 文件,包则指的是 python 文件夹)

sys.path 影响的是 绝对引用 形式,它是一个保存了多个路径的列表,这些路径用于在执行绝对导入时寻找对应的目标模块或包, python 进程还会额外读取 PYTHONPATH 环境变量,把其中设置的值添加到 sys.path 的开头。这和 bash 里的 PATH 环境变量的功能是类似的。

该变量会有一些默认值,其中很重要的是第三方库所在目录,比如 '~/miniconda3/envs/basic/lib/python3.8/site-packages', python 在解释 import torchimport numpy 等语句时,会根据该目录找到了对应的包。

/tmp/proj 里执行 python subdir/b.py 时,目标文件 b.py 所在的目录 /tmp/proj/subdir 会被添加到 sys.path 列表的第一项。如果在代码中没有手动修改 sys.path, 那么该变量的内容是不会动态变化的,这个特点非常重要,因为它意味着无论执行的 python 文件引发的 import 逻辑链中涉及到多少不同的文件(模块)和子目录(包),其中任何一个绝对导入形式 "import xyz" 语句都会从相同的 sys.path 中查找对应的 xyz 。用一个具体例子来说明这句话,回到以上项目结构:

/tmp/proj/
       ├── a.py # from subdir import b
       └── subdir/
           ├── b.py # from deepdir import c
           └── deepdir
               └── c.py #"print('c.py here')"

/tmp/proj 中执行 python a.py 会得到以下错误:

python a.py
Traceback (most recent call last):
  File "a.py", line 1, in <module>
    from subdir import b
  File "/tmp/proj/subdir/b.py", line 1, in <module>
    from deepdir import c
ModuleNotFoundError: No module named 'deepdir'

这是因为 sys.path 里加入的是 a.py 所在目录 /tmp/proj ,该目录下有 subdir/ 子目录,因此 a.py 里的 from subdir import b 是能执行成功的。 但 b.py 中的 from deepdir import c 会促使解释器去 proj/ 目录下找 deepdir/ 子目录,而该目录在 proj/subdir 下, 从而导致 import 失败。

如果不清晰地理解 sys.path 不变的性质, 这里可能带来困惑,因为在 b.py 里导入与文件自身在同一目录下的 deepdir/ 目录(或者叫做 package)看上去是很自然的,直接执行 python subdir/b.py 也是可以成功的,原因是此时 sys.path 第一项变成了 subdir/ 。这是编写 python 项目和编写简单脚本集合的区别, 编写脚本集合时更多是以当前文件为参考,所有文件都在同一个目录下,因此相对引用和绝对引用是一样的,但编写项目的时候,目录可能是嵌套的,绝对引用要以执行入口文件为参考。

如果要使得执行 python a.py 不会报错, b.py 中可以采用相对引用: from .deepdir import c, 其中 .deepdir 是相对当前文件所在包而言的(而不是相对 sys.path 里任何路径, 也不是相对当前文件所在路径)。但它的副作用是导致执行 python subdir/b.py 失败,因为直接执行时解释器不会去预设当前 b.py 在包的路径(python 可以自动检测 b.py 的目录,但是出于某些权衡,并不会把 b.py 所在目录看作是 b.py 所在的 package,而是弹出错误):

from .deepdir import c
ImportError: attempted relative import with no known parent package

解决办法是改用 python -m subdir.b 来启动 b.py,它明确告诉解释器 b.py 是在 subdir 这个 package 之中的,并且把 subdir 所在的目录 /tmp/proj 加入到 sys.path ,然后再运行 subdir/b.py ,这个时候,解释器根据 b.py 所在包的路径(/temp/proj/subdir)找到相对路径 .deepdir ,从而成功执行 from .deepdir import c

另外,这种方式会触发 if __name__ == "__main__" 语句, 因此除了影响 package 的识别, 它和一般执行脚本是没什么区别的。

还有一种做法是在 b.py 中写相对根目录的绝对引用: from subdir.deepdir import c, 但这同样会导致 python subdir/b.py 执行失败(因为 /tmp/proj/subdir 加入到 sys.path, 而其中没有另一个 sudir 目录)以及 python -m subdir.b 执行成功(因为 /tmp/proj 加入到了 sys.path),因此它和相对引用的最终效果是类似的,不过该写法更体现项目的整体性,明确告诉代码阅读者,当前的 b.py 是属于 subdir 所在的项目中的一员,因此最好不要直接执行 b.py, 而是通过项目入口来调用 b.py. 这也是本文推荐的做法。

以下两种也是解决方案,但不推荐:

  • 在 b.py 中明确把该文件所在目录加入到 sys.path 中,但这种做法可能会降低可读性, 读者因为必须确定这些语句会被执行才知道哪些引用是对的.
  • 更"暴力"做法是用以下语句,同时考虑被单独执行和被 import 的场景:
try:
    from deepdir import c
except:
    from .deepdir import c

1.1.3. 小结

结合 cwd 和 sys.path 的特点,可以提出以下几点关于组织 python 项目的原则:

  • 项目的入口最好统一到单独的主文件中 (比如 main.py) ,通过绝对的 import 方式把其他函数导入到主文件并根据命令行参数来分配不同的执行过程。例如只执行 python main.py --option ... 就可以执行到所有核心任务。

    这样做的好处是,所有的 import 语句可以统一遵从以根目录为基准的绝对引用,并且代码中涉及的文件路径也都可以是相对根目录而言的。

    如果 main.py 是在项目的子目录中,比如以下结构,那么执行的时候最好用 python -m src.main.py ,这样 main.py 还是可以用 from src.utls import 的形式导入其他模块。

    proj
    ...
    ├── src/
    │   ├── __init__.py
    │   ├── main.py
    │   ├── utils.py
    ...
    

    执行和 main.py 在同一个目录下的其他文件时,cwd 和 sys.path 与执行 main.py 的情况下是相同的,因此也不会带来额外的问题,比如可以直接执行

    python -m src.utils
    
  • 由于相对引用是相对包所在目录的,因此最好在每个子目录加入 __init__.py , 明确告诉解释器该目录是一个包(虽然没有该文件目前解释器也能自动识别)。另外,不需要在根目录加入 __init__.py

1.2. 深度学习实验项目结构

本节针对的是学习某些深度学习算法或科研实验的场景,实验性代码面对的问题和一般软件开发很不一样,其独特点体现在:

  • 能快速灵活地添加或删除某些参数,比如添加一个模型选项以对比不同激活函数对学习效果的影响,这对代码设计提出的要求包括:
    • 添加参数的整条通路要比较清晰,从 main.py 里加入一个选项后,容易看清楚这个选项是如何流通到最终的 model 构建的。
    • 新增加了参数路径后不会影响现有代码,因此最好能够有充分的测试,每次修改都能跑一次完整测试,以快速发现问题。
    • 测试要用小数据集或者人造的少量数据,否则跑一遍很耗时间。
  • 整个数据处理过程基本是线性的,因此有明确的阶段性,比如数据预处理、构建数据集、训练、验证、预测评估, 每个步骤有不同的参数,可以组合出非常多的选项。
    • 预处理和构建数据集阶段包括不同训练数据集选择、预处理手段选择
    • 训练阶段则有不同的训练超参数选择
    • 预测评估阶段有不同的 benchmark 和测试指标选择
  • 以上不同阶段都要有明确的函数对应,同时又能组合起不同的阶段执行流,可以写多个脚本,或者用 Makefile 来组织。
  • 可复现性:记录好各个核心包的版本和安装方法,一般可以用 requirements.txt 或者把环境搭建步骤写在 README.md 里,代码内要固定随机种子。
  • 能够留出部分接口通过 jupyter 来 import 部分函数进行交互式分析和可视化。

以下是一种项目文件的划分方式:

dl_imports.svg
  • 点状箭头表示可能的 import 依赖, 线状箭头则是参数的保存和读取路径
  • main.py 和 config.py 都是用来做参数管理的,这是为了满足参数修改的灵活性
  • 留出专门的测试目录 tests/
  • 用 Makefile 来管理不同的阶段的代码执行入口(也可以用一个或多个 bash 脚本)
  • 根据需求还可以拆分更多小的数据处理文件,比如 preprocess.py, tokenize.py 等。
  • 如果是比较小的项目, train.py 和 eval.py 可以直接合并成一个 train.py 或 train_eval.py.

加上数据和运行结果保存后,一个相对完整的目录结构如下:

proj
├── Makefile
├── .gitignore
├── requirements.txt
├── sample/
│   ├── __init__.py
│   ├── train.py
│   ├── eval.py
│   ├── utils.py
│   ├── models.py
│   ├── config.py
│   ├── dataprocess.py
│   └── main.py
├── notebooks/
│   └── exploratory_data_analysis.ipynb
├── data/ 
│   ├── raw/
│   └── processed/
├── checkpoints/ 
├── logs/ # 日志文件
├── docs/ # 结果记录
└── tests/
    ├── __init__.py
    ├── test_basic.py
    ├── test_advanced.py
    ├── test_args.py
    └── test_datprocess.py

补充说明:

实际研究中很多时候是基于其他人公开的代码库,在其基础上进行修改,不需要自己重头搭建一个项目,但以上内容也可以作为一个参考框架,有助于理解项目的整体结构、关键组件,也有助于对现有代码库的定制和扩展。

2. 测试友好和调试

测试和调试不是实验的核心但却是保证编程活动能稳定前进的关键。

2.1. 测试友好

本文以 pytest 作为测试框架,因为它使得写测试和写一般的 python 函数一样,先通过 pip install pytest 安装,测试子目录的结构如下:

proj
...
├── sample/
│   ├── __init__.py
│   ├── train.py
...
└── tests/
    ├── __init__.py
    ├── test_basic.py
    ├── test_advanced.py
    ├── test_args.py
    └── test_datprocess.py

可以将测试视为一种“可执行的草稿文档”,对流畅可持续的编程非常有帮助:

  • 编程时如果对某个 numpy 或 pytorch 的函数不清楚,一般会去查找文档或他人的解读,也可以直接问 GPT 等大语言模型服务,总之你会得到一些解释和函数使用的样例,我们可以参考这些样例自己构造一个与当前实验有关的输入和输出的例子,将其写在测试文档里,比如 tests/test_basic.py 中,保证运行通过,这样下次不熟悉时可以马上在实验相关数据的场景下阅读并重新回忆起来。
  • 测试中包括函数的输入输出,因此编写正式调用该函数的代码时,只要确保输入输出和测试用例里的形式保持一致即可。
  • 测试中手动构造一些样例的实践会让编程者更懂得如何处理困难问题,比如可以先用手动构造的简单输出数据代替上一轮的处理结果(手动 mock), 保证整个流程通畅后,再来解决困难的部分。
  • 可以在测试样例中写许多非正式的给自己看的代码说明。
  • 在一次编程之前,运行所有测试,以建立本次编程的信心,确保是在稳固的代码上继续改进,而不是在断壁颓垣上盖房子,这样出了错误只会更加狼藉。
  • 在一次编程结束后,运行所有测试,以及时发现问题。
  • 在编程被打断时,编写一个期望的测试函数,回到编程状态时执行该测试会失败,从而提示你该从哪里继续编程。
  • 从测试很容易跳转到被测试对象,因为测试函数中就引用了类和函数,因此可以在函数的输入输出样例和函数实现上很方便地跳转

具体操作上,当在 proj/ 目录下执行 pytest 命令时,pytest 会自动寻找 tests/ 下 test_ 前缀的的文件执行,假设找到 tests/test_basic.py 后,先识别出该文件所在的目录 tests/ 中有 __init__.py 文件,于是该目录被认为是一个包,pytest 会用类似 python -m tests.test_basic 的方式来执行测试文件,由上一章 import 的原理可以知道, tests 包的父目录 proj/ 会被加入到 sys.path 中,因此在测试文件中可以用 from sample.train import training 的方式来导入项目里其他任何 python 文件。

如果没有 init 文件的后果

如果 tests/ 下没有 __init__.py 文件,那么 pytest 发现 test_basic.py 是一个独立的模块,于是它会把该文件所在的目录 tests/ 加入到 sys.path 中,这容易导致测试文件里以项目绝对路径方式进行 import 失败。 可以用 python -m pytest 来解决,它会把执行 python 的当前目录加入到 sys.path, 也就是 proj/, 因此在 test_basic.py 中类似 from sample.train import training 的绝对引用形式又可以生效,选择任何一种都可以,但个人还是加上 init 文件,这样执行命令更简单点。

注意: 由于 pytest 会执行 test_ 开头的函数,所以普通函数最好不用 test_ 开头,比如 test_on_dataset 函数可以改为 predict_on_dataset 。

测试常用的命令:

# 自动发现并运行所有测试
pytest 

# 自动发现并运行所有测试, 显示出 print 语句结果(默认不打印)
pytest -s 

# 执行单个测试文件
pytest tests/test_dataset.py

# 执行单个测试函数
pytest tests/test_dataset.py::test_get_raw_dataset

此外, 在以上所有命令后加入 --pdb 选项, 会在执行到异常时进入调试模式,当某个测试的 assertion 失败后再执行一次 --pdb 版本的 pytest 命令就会使得代码停在失败的位置,方便进行调试。但不要在一开始测试时就加上 --pdb ,因为很可能异常是出现在某个第三方库的复杂的边界检查函数里,这时候即便停在那也难以进行调试。

如果是主动希望检查某些函数运行的内部状态,则直接添加 breakpoint() 语句, 然后正常运行代码, 执行到该语句就会进入 pdb 交互界面中,接着打印或修改相关的参数进行调试即可。

其他补充:

2.2. 命令行 pdb 调试

上节最后一段引出了 pdb 调试方式,这是 python 自带的功能,不需要额外安装任何工具。如果已经熟悉 IDE 的交互调试功能,那么用自己熟悉的方式即可。但个人觉得 pdb 比较直观并且灵活,例如由于是基于命令行的形式,因此可以在各种环境下立刻开始调试,比如远程服务器上,尽管有时候需要及时检查哪些地方多加了 breakpoint 没有及时删除(用全局搜索)。

  • 注意:在 python 3.7 版本后 pdb 本身就是 ipdb ,因此不再区分这两个工具,以下方式启动:

    在需要停止的地方添加:

    import pdb; pdb.set_trace()
    # 或, >=3.7 版本用以下
    breakpoint()
    

    以下方式则会停在脚本第一句执行之前:

    python -m pdb myscript.py
    
  • 常用 pdb 中的命令如下:

    next #运行到下一行
    until 49 #运行到第 49 行
    return # 执行到函数的返回行,也就是最后一行,一般再按一个 next 就回到执行该函数语句的下一句
    continue # 运行直到下一个 breakpoint
    jump
    pp # 打印出形式美观的结果,比如 pp value
    p # 普通打印,比如 p value 等价于 print(value)
    where # 打印出当要执行的代码行以及文件路径和行号,当迷失在源码里的时候很有用,因为可以点击文件路径跳转
    list # 显示出当执行到的代码的上下文
    break # break 8 表示在第 8 行打一个新的断点
    tbreak # tbreak 8 表示在第 8 行打一个新的断点, 但执行过一次第 8 行就会自动删除该断电
    step # 进入函数体
    whatis # 打印对象, whatis value 类似 type(value)
    

    这些命令可以直接在 pdb 里用 help 命令查看说明:

    (Pdb) h return
    r(eturn)
          Continue execution until the current function returns.
    
  • 在循环里某个 step 停住: 对于训练或预测阶段,调试时,可以加入以下语句快速跑几个 batch 的样例看结果的形式是否符合预期。

    if step == 5: # debug
        breakpoint()
    

    在 pdb 交互界面里按某个变量的条件添加一个临时断点:

    tbreak 13 , i==3
    

    这时候如果输入 c 或者 continue,当 i 是 3 的时候,会停在第 13 行(一般是某个循环里面),并且删除这个断点。

  • 对异常代码的调试: 对于出现异常的代码,希望查看异常前的执行环境,可以用 try except 包裹住,当异常时进入到 pdb:

    try:
        some code
    except:
        breakpoint()
    
  • 多行函数调用时的 next 和 step

    有些函数会以多行的形式调用,例如:

    11 train(model = model,
    12       dataset = dataset,
    13       optimizer = get_optimizer(),
    14       ..
    15       )
    

    这种情况下,如果断点停在 11 行,想要跳转进入到 train 的执行过程中,在 dubug 交互窗口(命令行)里输入 step 或者 s, pdb 不会立刻跳转到 train 函数的定义位置,因为对于 pdb 来说,每一行都是一个语句,程序需要执行完函数调用里所有参数的赋值语句才会进入真正的函数体。 这是合理的,比如以上第 13 行是动态调用 get_optimizer() 来获得优化器,因此必须先执行完 get_optimizer 才能执行 train, 因此在 11 行执行 s 会运行到 12 行,继续按 s 则进入到 13 行,再继续 step 就进入到 get_optimizer() 函数了,这时候得从 get_optimizer 里返回,然后再持续按 step 到第 15 行,再按 step 才真正进入到 train 函数的定义。

    想更快跳入 train 函数的定义, 可以先输入 unt 14 执行到 14 行(整个函数调用的倒数第二行),再输入 step 会跳转回第 11 行,即真正准备进入 train 函数体,此时再 step 就可以进入到 train 。

  • 进入某个函数体理解了执行过程后,想快速跳出来

    比如以下例子,执行程序后,停在第 6 行,如果按 step 就会进入 trap 的实现,即第 2 行 for 循环处,继续跳转进入 increase 函数,并不断 step 进入到很深的某个核心函数,比如 step 了 5 次(相当于往 调用栈 push 了 5 个函数),这个时候如果想回到以下代码的第 7 行,就要输入很多的 return, 逐步从调用栈里跳出来(相当于不断 pop)。

    1 def trap():
    2     for i in range(10):
    3         increase(i)
    4
    5 breakpoint()
    6 trap()
    7 ...
    

    以上方式比较繁琐,一种改进的做法是,在第一次 step 进入 trap 前,先对未来要回来的位置做一个 mark 操作,比如在 pdb 交互里执行 tbreak 7 , 然后再不断 step, 当想要跳出来的时候,输入 continue 或者 c 就直接到了第 7 行,这个过程就像是在某个地洞里探险前在用绳子绑住自己,并让同伴在洞外拉住另外一头,当需要从洞里快速返回的时候就可以被迅速拉起。

    确保在 pdb 里探索代码的深层逻辑时不会迷失在复杂的函数调用中,在需要的时候迅速回到调用栈的上层。

3. 参数管理和函数接口选择

3.1. argparse 和 main

第一章提到过,应该尽量把参数都集中在一个入口文件里,比如 main.py 。 其他函数被 import 到改该文件中调用。

argparser 参数在传递过程中不断被"消解" 。假设 argparse 里有 10 个参数,包括模型类型、学习率、batch 大小等,经过执行准备阶段后,"吸收"掉了许多参数,比如模型类型, batch 大小,返回出来的是具体的 model 和 dataloader 或 dataset, 接着训练函数 do_train 只需要继续吸收 args 中的一小部分参数如学习率等。

3.1.1. argparse 中的 bool 选项说明

argparse 里有一点让人迷惑的,那就是如何设置 bool 值的情况,比如如果有以下选项

parser.add_argument("--cache", default=True, type=bool, help="use dataset cache or not")

那么执行 python main.py –cache=False 时 False 会解释成一个字符串 "False", 由于这不是空串,因此它的值就解释为 True, 只有 –cache="" 的时候才能改成 False, 这导致整体看上去很不自然。

一种做法是手动转换它:

args = parser.parse_args()
args.cache = bool(str(args.cache).lower() == "true")

这里要转变一种习惯,bool 类型选项是一个只有两个值的开关,因此可以用该选项是否存在来表示 True 或者 False

比如以下 –small 选项,这里 action 是 store_true, 它意味着,如果执行命令的时候给定 –small 选项,例如 python main.py --small, 那么 args.small 就会是 True, 如果没有该选项,如执行 python main.py, 那么 args.small 就是 False

parser.add_argument("--small", default=False, action="store_true")

以上写法等价于:

parser.add_argument("--small", action="store_true")

这更加精简,不过如果在 store_ture 的情况下 default=True, 那么 args.small 则恒为 True, 因此以下写法没什么意义。

parser.add_argument("--small", default=True, action="store_true")

与 store_true 对应的 action 是 store_false, 如下表示 –small 选项存在那么 args.small 为 False.

parser.add_argument("--small", default=True, action="store_false")

不过这种写法是不自然的,因为当给定某个否定选项时,最好在选项名称里就表达出来,因此如果某个选项默认是 True, 给定选项之后是 False, 那么更推荐如下写法:

parser.add_argument("--no_small", default=False, action="store_true")

args = parser.parse_args()
args.small = not args.no_small

这样执行 python main.py --no-small 后, args.no-small 就是 True, 而 args.small 则为 False.

在 python3.9 后可以用以下代码,使得否定和肯定都可以明确表达出来(也保留了默认值)

parser.add_argument('--cache', default=True, action=argparse.BooleanOptionalAction)

这样可以用 –no-cache 来将此变量置为 False.

3.1.2. click 和 argparse 的比较

click 是一个第三方库,需要用 pip 安装,click 比较简洁,不需要太多样板代码,可以很快执行一些命令,比如以下例子

@click.command()
@click.option("--opt", default="len", help="data process cmdline")
@click.option("--small", default=True, help="using small dataset?")
def main(opt, small):
    #skip

这样就可以执行 python main.py --opt="xyz" --samll 命令了, 但由于装饰器不太好封装,当参数很多的时候会写很多行 @click.opition 语句,IDE 不支持折叠,而且没法单独测试参数默认值。因此对于少量参数的可执行脚本的可以用它来修饰(比如预处理文件里,与 main.py 的命令行选项区分开),对于 main.py 里解析大量参数的场景,用 argparse 更好。

因此总的来说 click 并不是必要的。

3.2. Makefile

makefile 是对 main.py 或者其他文件里的命令选项的最高层汇总,也可用 bash 脚本来包装。

以下展示了一些例子,其中包括生成数据,训练,预测,在不同 GPU 上预测或训练, 对某个参数进行 grid 搜索,显示所后日志文件(实验结果表格),运行所有测试,在测试失败后进入 debug 等例子:

.PHONY: all preprocess build clean test help

gen-train-dataset:  ## Generate train datasets caceh
	python -m src.dataprocess --opt=train

train-all: train-ours-full train-ours-initial

train-ours-full: ## train ours-model on full dataset
	python -m src.main --stage=train

train-ours-initial: ## train ours-model on initial dataset
	CUDA_VISIBLE_DEVICES=2 python -m src.main --stage=train --initial

predict-ours-full: ## predict on sample0 based on our-model
	python -m src.main --stage=predict --base_model=ours --epoches=2

predict-ours-full-weights: ## predict on sample0 based on our-model with different weight
	for pin_weight in 0.1 0.3 0.5 0.7 0.9 1.1 1.3 1.5 1.7 1.9; do \
		python -m src.main --stage=predict --base_model=ours --epoches=2 --pin_weight=$$pin_weight; \
	done

show-logs:
	@for file in logs/*.txt; do \
		echo "==== $$file ===="; \
		cat "$$file"; \
		echo "\n"; \
	done

test: ## run all test cases
	@echo "Starting full unit test..."			
	pytest -s

test-break: ## run all test cases, break on first failure
	@echo "Starting full unit test..."			
	pytest -s --pdb

help: ## Display this help message
	@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n  make \033[36m<target>\033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf "  \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)

3.3. 训练、验证和测试接口

这部分是比较灵活的,根据需求可以把接口区分的很细,也可以直接用一个高层 api (比如 huggingface 的 Trainer)

比 Trainer 低一层的是把所有代码写在一个函数:

def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    #for epoch loop:
       #for batch loop:
# inference
def inference(cfg, model, val_loader):

来源: L1aoXingyu/Deep-Learning-Project-Template: A best practice for deep learning project template architecture.

李沐的 dive2dl 里对单步训练也有封装,比如以下 epoch 循环里调用 train_batch_ch13:

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=try_all_gpus()[:1]):

    #...
    for epoch in range(num_epochs):
        #...
        train_batch_ch13(net, X, y, loss, trainer, devices):
        #...

这种写法是好的,因为可以细致地测量。

而 pytorch-lightning 等框架则拆分地非常细:

  • forward_step: 输入 tensor 返回 logits
  • loss_step: logits 和 label 结合计算损失
  • decode_step: 把 logits 转成整数分类, 为计算 metric 做准备
  • concretize_step: 从 id 到最终有意义的 token.
  • metric_step: 从 concretize_step 得到符号化输出后,输入到 metric 进行 update, 真正的计算则是在 logging 步骤中进行
  • backward_step: 只做反向传播和梯度更新,需要处理 optimizer 和 scheduler
  • train_step: 整合 forward_step 和 loss_step 以及 backward_step
  • train_epoch: for 循环整合 train_step, 记录损失; callback 的概念是也从这里延伸出(在不同函数调用间插入回调)。
    • 打印类:logger, tqdm
    • 收集类: metirc, loss, 时间
  • train: 整个 train 的过程
  • fit: 最高层的组织
  • evaluate_step: 和 train_step 接口一样,但没有反向传播,整合 forward_step, loss_step, decode_step
  • evaluate/evaluate_epoch eval 直接就是 eval_epoch, 因为只要计算一遍就可以。eval 关键是 model.eval() 以及 with no_grad 在 train 的间隔中被调用,需要将 id 转换回 symbol 格式,做 metric 测试。
  • evaluate :监控 loss 函数的输出值,以及更新模型和保存参数,如是否要 warmup 等。 优化策略:对那些层进行冻结,不同层的学习率和 decay 的设置 可以单独用一个文件来管理所有常用的策略
  • predict_step: 和 eval_step 类似,但不需要计算 loss, 直接输出 logits
  • test/test_epoch

对于大部分项目没必要写这么细(框架采样考虑),比如在用 huggingface 的接口 forward_step 和 loss_step 是合并在 model.forward 里的,大部分时候抽象出 batch_forward 和 train/evaluate/inference 就够了。

4. 代码阅读、编写和重构

这部分比较零散,但核心是要识别出哪些任务可以用机器来做而不是自己去做,并且尽量写出能够让机器辅助自己理解的代码。

4.1. 关注差异和大纲

  • 阅读和修改代码的时候,注意力主要放在差异部分,因此要善于用代码版本控制和比较工具:
    • 当前修改与上一次提交的差异: vscode 的 git 工具
    • 文件之间的差异: vscode 的 diff 或者 vimdiff
    • git 提交历史之间的: gitkraken 或 vscode 的 gitlens 插件
    • 局部代码的差异: 代码对比/归并, Computed Diff - Diff Checker

      局部代码比较的一种场景是,当要对某个函数进行修改,但这个修改并不简单,如果在原函数上改,很可能会彻底改变原函数逻辑和接口,导致整个流程或测试被破坏。这时候可以先复制一个新的函数,在新函数上修改,原函数和整个 workflow 都不会被影响,等新函数功能测试完之后,再把两个函数复制到以上代码对比工具里,看是否可以通过修改参数来合并这两个函数,合并之后再修改测试接口和 workflow 里的接口。

  • 多使用 IDE 的大纲模式,在 vscode 等编辑器里都会有文件的大纲模式,显示类,函数,全局变量名,很多时候我们是在大纲层级思考整体逻辑的,而不用总是在细节中

4.2. 代码搜索来辅助修改代码

  • 修改变量名或者文件名用 IDE 中 "重命名符号" 相关功能,保证对该变量的引用都同步修改。
  • 如果修改某个函数的签名,比如从三个参数变成两个,那么需要把所有对该函数的调用都修改,可以用 IDE 搜索功能,比如修改 xyz 函数,则搜 "xyz(", 把左括号加上可以排除 import 和函数定义行,然后根据搜索列表顺序逐个修改。
  • 根据某个参数选择不同的函数或者类的时候,不用担心 if 和 elif 太多,因为这样可以很清楚地进行函数引用之间的跳转,这是利用了 IDE 的定义查找搜索功能,比如:

    if stage == "train":
        train(...)
    elif stage == "eval": 
        evaluate(...)
    elif stage == "predict": 
        predict(...)
    

    用字典查询的形式会使得代码更简洁统一,但需要查找字典里 key 和 value 的变量名再肉眼搜索到函数或类的定义中, 这实际会打乱阅读流畅性。

    stage_map = {"train": train, "eval": evalate, "predict": predict}
    stage_map[stage](...)
    

4.3. type hint 和类型检查

以下函数的参数里 decoder_tokenizer 有了 type hint 注明它是 BertTokenzier, 那么在 vscode 中按住 ctrl 后点击 decoder_tokenizer.encoder 就可以跳转到函数实现上,如果没有类型则只有运行时才知道,静态检查就弱了。

def convert_to_aligned_tokens(
    target: List[str],
    encoder_tokenizer: MyTokenizer,
    decoder_tokenizer: BertTokenzier,
):
    decoder_tokenizer.encoder(target)

其他好处:

使用类型提示(Type Hinting)在Python代码中有多个好处:

  1. 可读性和文档 类型提示可以作为代码文档,帮助开发者理解函数或方法应该接收什么类型的参数,以及它们会返回什么类型的结果。 因此对协作是友好的,尤其是给以后的自己看是友好的(善待自己善待别人
  2. 开发效率 许多现代IDE和编辑器(如PyCharm、VSCode等)能够利用类型提示来提供更精确的代码补全、提示和重构支持。 能否分别举例子说明哪些场景用了 typehint 能够更好支持代码补全,提示和重构
  3. 错误检测 使用类型检查工具(如`mypy`)可以在运行代码之前识别类型错误,这有助于在早期发现和修复问题。

python3.9 之后,原生的 list,tuple 类型和 typing 里的 List 和 Tuple 都一样了,也可以写成 list[int] 来表示 int 类型的数组。

python3.7 后,通过 __future__ 模块里的 annotations 可以将没有定义的类型也作为 type hint

from __future__ import annotations
  
class MyClass:
    def method(self) -> OtherClass:
        pass

class OtherClass:
    def method(self) -> MyClass:
        pass

如果没有第一句,会报错

NameError: name 'OtherClass' is not defined
  • 编写自己的项目时打开 IDE 中的 type checking 功能, 更早发现错误。阅读别人公开的代码时,谨慎开启 type checking, 因为有可能他人的代码报太多 warning 影响阅读。

    比如 VScode 安装 pylance 语言服务后 Ctrl-, 进入设置输入 type checking mode 设置为 basic 即可(strict 会对第三方库也报 warning 有点过了)

    这里主要是说明类型检查是有可能误报的,主要是一个提醒,但不需要太执着消除所有 warning。

    静态检查误报的一个例子:

    如果在 except 块内部调用 sys.exc_info(),那么 exc_traceback 通常应该是非 None 的。在这种情境下,Pylance 的警告可能是不必要的。但静态类型检查器(如 Pylance)通常不会理解这种上下文信息,因此它可能仍然会发出警告。如果你确定 exc_traceback 在这里永远不会是 None,你可以用 # type: ignore 选择忽略这个特定的警告:

    try:
        return search_str(last)
    except Exception as e:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        line_number = exc_traceback.tb_lineno  # type: ignore
        return [f"An error occurred on line {line_number}: {str(e)}"]
    

4.4. 函数功能的修改

  • 如果发现数据处理有问题,要重写部分数据处理过程的话,可能会导致下游训练、验证、预测的接口都要随之变化,但由于这里面 train 一般是最复杂的(因为训练过程可能会调用验证函数),因此要沿着调用链逆流而上进行修改,比如先修改 generate/infer/test, 再修改 evaluate ,最后修改 train 。每修改一个函数前最好更新测试,修改完后就测试。

    当然编写前认真设计好函数接口以及预留可能的参数是更重要的,比如核心的这几个流程函数都至少预留动态的 kwargs 参数,尽量减少接口的修改。

  • 猴子补丁(monkey patch)的写法

    普通类方法补丁,函数第一个参数需要是 self:

    class A:
        pass
    
    def method(self):
        print(self.a)
    
    A.method = method
    

    静态方法补丁:

    def static_method():
        pass
    
    A.static_method = staticmethod(static_method)
    

    或者

    @staticmethod
    def static_method():
        pass
    
    A.static_method = static_method
    

    可以看到,赋值语句和普通方法是一样的,区别就在于不需要有第一个 self 形参,并用 staticmethod 装饰,该装饰器会告诉 python 不要把函数第一个参数当作类的引用。

  • partial 函数

    一般用来固定函数某些参数值并返回一个新的函数,这样可以构造出不同的函数,一般用于把函数作为参数传给其他函数或者用在猴子补丁中

    可以固定任意一个参数, 但在固定参数之后的所有参数赋值都需要明确写明参数名称,比如以下对 c,f 固定了,那么 c 以后的赋值(d,e,g)就要明确写出:

    def mul7(a, b, c, d, e, f, g):
        return a * b * c * d * e * f *g
    
    from functools import partial
    
    mul5 = partial(mul7, c=3, f=6)
    
    mul5(1,2,d=4,e=5,g=7)
    
    5040
    

    以下就会出错

    mul(1,2,4,e=5,g=7)
    

    可以嵌套 partial:

    mul2 = partial(mul5, b=2, e=5, g=7)
    mul2(1,d=4)
    
    5040
    

5. 文档、注释、日志

5.1. 打印和日志

5.1.1. pprint 和 log

使用 pprint 可以使得某些对象的打印更具有可读性,例如字典的打印,用 pprint 的话,每个 key 会对应单独的一行。使用 logging 的好处在于,可以在其中写更多的信息,比如 {%(filename)s:%(lineno)d} 会打印出文件名和行号,这样在 IDE 终端里可以方便跳转回到打印代码所在行,此外可以日志分等级、同时把日志保存到文件里等等。

以下是一种结合 pprint 和 logging 的打印(有 bug)

import logging
import pprint


_LOG_FMT = '%(asctime)s - %(levelname)s  {%(filename)s:%(lineno)d} - %(name)s -   %(message)s'
_DATE_FMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
LOGGER = logging.getLogger('my_proj')

def printlog(msg, level=logging.INFO, pretty=True):
    if pretty:
        msg = pprint.pformat(msg)
    if level == logging.INFO:
        LOGGER.info(msg)
    elif level == logging.DEBUG:
        LOGGER.debug(msg)

假设以上代码写在 utils.py 中,在其他文件里,先 from utils import printlog,将 print 替换成 printlog 即可。

但以上代码是有问题的,因为 LOGGER.info(msg) 打印出的 {%(filename)s:%(lineno)d} 是该语句所在行,而由于这个语句就是在 printlog 函数里,因此无论在什么地方调用 printlog, 打印出来的文件名和行号都是 utils.py:19 (假设 LOGGER.info(msg) 在第 19 行)

因此如果要正确打印出结果,就只能从别的文件中执行 from utils import LOGGER, 然后执行 LOGGER.info 但这样就失去了 pprint 的格式化效果。

因此个人选择的做法是,在消息之前手动加入行号信息,这需要用到 python 的 inspect 库(LOGGER 本身也使用到了该库)

import logging
import pprint
import inspect
import os

_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_DATE_FMT = '%Y-%m-%d: %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
LOGGER = logging.getLogger('my_proj')

def printlog(msg, level=logging.INFO, pretty=True):
    if pretty:
        msg = pprint.pformat(msg)

    frame = inspect.currentframe().f_back
    filepath = frame.f_code.co_filename
    lineno = frame.f_lineno
    relative_filepath = os.path.relpath(filepath)
    msg = f"{relative_filepath}:{lineno}: {msg}"
    if level == logging.INFO:
        LOGGER.info(msg)
    elif level == logging.DEBUG:
        LOGGER.debug(msg)

以上删除了 _LOG_FMT 里关于文件名和行号的信息,然后手动加入到 msg 之前。另外,这里用 os.path.relpath(filepath) 来表示 filepath 路径相对于当前工作目录(参考第一章)的相对目录,这是可选的,也可以直接打印出绝对的 filepath.

5.1.2. tqdm

应该用 enumerate(tqdm(dataset)) 这样 tqdm 才能读取到 dataset 的长度 如果用 tqdm(enumerate(dataset)) tqdm 读到的是 enumerate 这个生成器的信息,它没有长度,也就不会有进度条

5.1.3. torch.utils.tensorboard, tensorboardx 和 tensorboard

tensorboardx 和 torch.utils.tensorboard 都是用于在 PyTorch 中生成 TensorBoard 所需的日志数据的,它们只是遵循了 tensorboard 官方的数据格式协议,可以把数据保存成这种格式,而要在终端里用 tensorboard --logdir=runs 启动 TensorBoard 服务并在网页中查看可视化效果,还需要安装 tensorboard。

由于 tensorboardx 是在官方的 torch.utils.tensorboard 推出之前开发的第三方软件,并且还支持 mxnet, numpy 等其他库,因此有比较大的用户基础,一些框架性代码里还会用以下方式同时兼容 tensorboardX 和 torch.utils.tensorboard

try:
    from torch.utils import tensorboard
except ModuleNotFoundError:
    import tensorboardX as tensorboard

大部分情况都是用 SummaryWriter 类来记录

from torch.utils.tensorboard.writer import SummaryWriter
from tensorboardX import SummaryWriter

总结来看,要使用 tensorboard 可视化:

  • conda install tensorboad (或 pip install)是必要的
  • 安装完 pytorch 之后默认就可以使用 torch.utils.tensorboard, 不想安装 tensorboardx 可以不装
  • 很多时候会把图片或者散点图等也写进去,因此需要安装 matplotlib

5.1.4. git 日志

一般来说,只要有一个明确的修改或功能实现(每次一小步),就应该提交一次 git ,并且写清楚当前提交的核心修改部分的说明(目前可以用 copilot 自动生成 commit 信息,因此更应该让每次修改差异更加聚焦)

6. 数据集处理

6.1. 数据处理的流程

数据处理的核心是尽早解耦合,以便可以灵活地处理不同的数据部分:

  • 首先,训练用的和最终测试的数据一般是切分好的两个相同格式的不同文件,因此可以用不同函数独立处理(或者用同一特征函数调用两次分别作用在不同文件上,这并没有什么额外成本)。
  • 对于训练集,最好先全部读取成到一个变量,比如读取成 eorch.utils.data.Dataset 或 datasets.Dataset 对象, 然后 split 成 train 和 eval 两个 Dataset 对象,这样可以对它们分别做特征转换(例如用 datasets.map 进行 tokenizer 得到 token_ids 等 ) 这比先做特征处理后切分是更灵活的,因为很可能每个切分部分会用到不同处理手段。
  • 接着连同 collator 函数包装成 dataloader 对象,每个 dataset 可以对应不同的 collator。

6.2. load_dataset 的问题

from src.ppo_data import PPO_TRAIN_SET_JSON
ds = load_dataset("json", data_files=PPO_TRAIN_SET_JSON)

一旦报错很难排查,比如以下错误:

ArrowIndexError: array slice would exceed array length

Invalid Arrow data from JSONL · Issue #5531 · huggingface/datasets 参考以上链接改成以下读取方式又成功了,区别仅仅在于以下 ds 不区分 train split, 可以直接通过 ds[0] 获取样本

from datasets import Dataset
import pandas as pd
ds = Dataset.from_pandas(pd.read_json(PPO_TRAIN_SET_JSON, lines=True))

因此很多时候构造 Dataset 可以考虑先自己手动读取成字典,再转换,比如以下方式:

from datasets import Dataset
data = {
    'text': ['hello', 'world'],
    'label': [0, 1]
}

dataset = Dataset.from_dict(data)
print(dataset)
Dataset({
    features: ['text', 'label'],
    num_rows: 2
})

6.3. 两种 Dataset 类的区别

from datasets import Dataset 和 from torch.utils.data import Dataset 是两个不同库中的不同实现,它们有一些关键区别:

6.3.1. Hugging Face 的 `datasets.Dataset`

  • 这是 Hugging Face 的 `datasets` 库中的一个类,主要用于NLP任务。
  • 它提供了丰富的数据预处理和转换功能。
  • 支持从多种数据源(如CSV、JSON、Parquet等)加载数据。
  • 内置了缓存机制,可以更有效地处理大型数据集。
  • 提供了与Pandas DataFrame相似的API。

6.3.2. PyTorch 的 `torch.utils.data.Dataset`

  • 这是 PyTorch 库中的一个抽象类,用于自定义数据加载。
  • 你需要实现 `__len__` 和 `__getitem__` 方法来创建一个自定义的数据集。
  • 主要用于计算机视觉、NLP、时间序列分析等多种任务。
  • 与 PyTorch 的 DataLoader 配合使用,支持自动批处理、数据洗牌和多线程数据加载。

6.3.3. 相互转换

虽然这两种 Dataset 类是不同的,但它们可以相互转换。

  • 从 Hugging Face 的 `Dataset` 转换到 PyTorch 的 `Dataset`:你可以使用 `datasets.Dataset` 对象的 `with_format("torch")` 方法将其转换为 PyTorch 张量格式,然后在自定义的 PyTorch `Dataset` 类中使用这些张量。

    from datasets import load_dataset
    from torch.utils.data import Dataset as TorchDataset
    
    class CustomDataset(TorchDataset):
        def __init__(self, hf_dataset):
            self.hf_dataset = hf_dataset.with_format("torch")
    
        def __len__(self):
            return len(self.hf_dataset)
    
        def __getitem__(self, idx):
            return self.hf_dataset[idx]
    
    hf_dataset = load_dataset("squad")["train"]
    custom_dataset = CustomDataset(hf_dataset)
    
  • 从 PyTorch 的 Dataset 转换到 Hugging Face 的 Dataset:你可以先将 PyTorch Dataset 对象转换为 Python 字典或 Pandas DataFrame,然后使用 datasets.Dataset.from_dict() 或 datasets.Dataset.from_pandas() 方法。

这样,你就可以根据任务需求灵活地使用这两种 Dataset 类。

6.4. dataset.map 和 dataloader collator 的权衡

dataset.map 里的 process 函数的形参 examples 可以看作是一个字典类型的列对象集合,如以下代码所示,直接从 examples 中取出 "context" 和 "target" 两列,其长度是 datasets 默认(如果没有设置),一般是 1000, 然后对其中每一个都进行处理(或者批量处理,比如 tokenize 可以批量分词),返回一个新的字典类型的列对象

def preprocess_train_function(self, examples):
    contexts = examples["context"]
    targets = examples["target"]
    model_inputs = defaultdict(list)

    for i in range(len(contexts)):
        # tokenize
        input_ids = self.tokenizer.encode_plus(context[i])["input_ids"]
        label_ids = self.tokenizer.encode_plus(labels[i])["input_ids"]
        model_inputs["input_ids"].append(input_ids)
        model_inputs["labels"].append(label_ids)
   return model_inputs

可以用多次操作来处理,这里 batched=Fasle 使得 examples 就是一个单元素字典,比如以下例子中 x["review"] 取出来的直接是字符串而不是列表:

dataset = dataset.map(
    lambda x: {"input_ids": gpt2_tokenizer.encode(" " + x["review"], return_tensors="pt")[0, :txt_in_len]},
    batched=False,
)
dataset = dataset.map(lambda x: {"query": gpt2_tokenizer.decode(x["input_ids"])}, batched=False)
dataset = dataset[:20480]

而对于 collator 函数,它的形参 features 是一个列表类型的行对象集合,处理方式是遍历每一行,进行转换或者进行统一 操作(对齐),返回的结果则是一个小 batch 的字典列对象:

def data_collator(features: List[InputDataClass]):
    batch = {
        "input_ids": [],
        "labels": [],
    }

    for i, feature in enumerate(features):
        for k, v in feature.items():
            batch[k].append(torch.tensor(v).long())
    return batch

如下图所示:

dl_dataset_collator.svg

在 Hugging Face 的 `datasets` 库中,是在 `dataset.map` 中进行分词(tokenization)还是在 collator 函数中进行分词?

在 Hugging Face 的 `datasets` 库中,您可以选择在 `dataset.map` 方法中进行分词(tokenization)或者在数据集的 collator 函数中进行分词。两者各有优缺点:

### 在 `dataset.map` 中进行分词:

#### 优点:

  1. **预处理速度**:分词只需进行一次,然后可以缓存结果,这样每个 epoch 都可以更快地读取。
  2. **简单性**:所有预处理步骤(包括分词)都在一个地方完成,使得代码更容易理解和维护。

#### 缺点:

  1. **内存占用**:如果您的数据集非常大,预先进行分词可能会占用大量内存。
  2. **灵活性**:预先分词可能限制了您在训练过程中动态更改分词策略的能力。

### 在 collator 函数中进行分词:

#### 优点:

  1. **灵活性**:您可以更容易地在训练过程中动态更改分词策略。
  2. **内存效率**:由于分词是按需进行的,因此不会一次性占用大量内存。

#### 缺点:

  1. **运行速度**:由于分词是在每个 mini-batch 上动态进行的,这可能会降低训练速度。
  2. **复杂性**:预处理逻辑被分散在多个地方,可能会使代码更难维护。

### 总结:

  • 如果您的数据集相对较小,或者您希望能够缓存分词结果以加速多次实验,那么在 `dataset.map` 中进行分词可能是一个好选择。
  • 如果您需要更高的灵活性,或者您的数据集非常大,那么在 collator 函数中进行分词可能更合适。

根据您的具体需求和约束,您可以选择最适合您的方法。希望这能帮助您做出决定!有其他问题吗?

因此,对于小团队项目(除训练大语言模型之外的大部分深度学习项目),在 dataset.map 里分词是更好的。

是在 dataset.map 里对齐好还是在 collator 里对齐好?

对齐序列(也称为填充或 padding)的长度可以在 `dataset.map` 或 collator 函数中进行,具体取决于您的需求和场景。

### 在 `dataset.map` 中进行对齐:

#### 优点:

  1. **预处理速度**:与分词一样,对齐只需进行一次,然后可以缓存结果。
  2. **简单性**:所有预处理步骤(包括对齐)都在一个地方完成,使得代码更容易理解和维护。

#### 缺点:

  1. **内存占用**:预先对齐可能会导致数据集占用更多内存,因为所有序列都会被填充到相同的最大长度。
  2. **灵活性**:预先对齐可能限制了您在训练过程中动态更改对齐策略的能力。

### 在 collator 函数中进行对齐:

#### 优点:

  1. **灵活性**:您可以更容易地在训练过程中动态更改对齐策略。
  2. **内存效率**:由于对齐是按需进行的,因此不会一次性占用大量内存。
  3. **效率**:您可以根据每个 mini-batch 的实际需要进行对齐,而不是填充到整个数据集中的最大长度。

#### 缺点:

  1. **运行速度**:由于对齐是在每个 mini-batch 上动态进行的,这可能会稍微降低训练速度。
  2. **复杂性**:预处理逻辑被分散在多个地方,可能会使代码更难维护。

### 总结:

  • 如果您的数据集相对较小,或者您希望能够缓存预处理结果以加速多次实验,那么在 `dataset.map` 中进行对齐可能是一个好选择。
  • 如果您需要更高的灵活性,或者您的数据集非常大,那么在 collator 函数中进行对齐可能更合适。

根据您的具体需求和约束,您可以选择最适合您的方法。希望这能帮助您做出决定!有其他问题吗?

对于实验项目,在 collator 里对齐更好,因为更灵活,比如 batch 为 1 的时候,从 dataloader 里拿到的就是原始的特征。

另外,一般可以分别编写 get_datasets 和 get_dataloaders 两个函数,其中 get_datasets 返回 datasets.map 和 dataset.train_test_split 的结果,然后 get_dataloaders 根据给定的 dataset 结合 collator 返回 dataloader 类, get_dataloaders 函数的选项可以通过 partial 函数传给 collator 。

6.5. json.load 和 json.loads

json.load 是直接加载一个文件对象(load file), json.loads 则是加载字符串(load string 的缩写),可以用来加载每行是一个 json 对象的文件

读取整个文件

import json
with open('data.json', 'r') as f:
    data = json.load(f)

读取文件中每一行:

for line in tqdm(lines):
    data.append(json.loads(line))

6.6. pad_sequence 和 pack_padded_sequence

from torch.nn.utils.rnn import  pad_sequence, pack_padded_sequence

pad_sequence 是常用的对齐 batch 的函数:

import torch

# 创建一个包含三个不同长度的序列的列表
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

# 使用 pad_sequence 进行填充
padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)

print(padded_sequences)
tensor([[1, 2, 3],
        [4, 5, 0],
        [6, 0, 0]])

对于 transfomer, 它的分词器就可以自动对齐 padding 和截断:

input_ids = tokenizer("Hello, how are you?", padding='max_length', max_length=10, return_tensors='pt')['input_ids']

但对于一些需要精细处理的部分,这时候可以用以上 pad_sequence 或者手动截断和对齐:

trunc_label = label[: self.max_length]
# use -100 to mask loss
decoder_label = trunc_label + [-100] * (self.max_length - len(trunc_label))

pack_padded_sequence 是专用于 RNN 的,对于 transformer 结构的场景,基本不再使用该函数。

6.7. 其他说明

  • tokenizer 获取:

    如果项目大,则用单独一个文件,比如 tokenizer_utils.py, 如果比较小则放在模型文件里,因为模型和 tokenizer 是比较紧密的

7. 整个实验活动流程总结

假设已经确定了要做的事情,那么编码的步骤如下:

  1. 对于核心函数,先在 jupyter notebook 上打样,基本用 print 进行调试,每编写完一个函数,则把该函数写回 py 文件,并补充一个测试,测试的输入可以直接在 jupyter 中打印出来并复制到测试函数里,assert 的结果则是从 jupyter 打印的结果复制过来。

    (不是很必要在 jupyter notebook 里写的函数包括: datset.map 里的函数和 collator 函数,因为输入不太好构造,可以直接在 py 文件里写,然后用 breakpoint() 交互)

    另外这类数据生成函数的测试可以写在函数所在文件的 click 命令选择中

  2. 迁移到 py 文件并编写测试、通过测试、git commit, 尽管大部分是复制工作,但还是要一定时间(至少 20min),因为其中涉及:
    • 函数接口修正:jupyter 里很多变量是"全局"的,只有复制到 py 文件后才知道,这个函数需要传入哪些参数
    • import 引入
    • 测试函数编写
    • git 差异检查、编写 commit message 并提交
  3. 编写完后,最好删除掉 jupyter 里的函数实现,改成 import, 这样保证函数各处是一致的。因此记得在 jupyter 开始的 block 加上以下命令保证函数及时更新:

    %load_ext autoreload
    %autoreload 2
    
  4. 把该函数应用到更全局的工作流中,一般会编写一个 wrapper 函数,该函数可以直接写在 py 文件中,因为当期那已经没有技术难点,更多是环节组合,编写完后需要有对该流程的测试(如果是数据集生成,那么直接生成后查看格式即可,否则可以用很少的几个数据去测试),这个过程一般也要预留 20-30 分钟来处理。

综合性问题时间预估:

  • 一个数据集流程打通至少 4 到 5 小时(除非非常标准化的数据),如果中间遇到 bug 或不太熟需要学习库的使用,一天就过去了
  • 修改模型的某些结构类似,不是很标准的话也是一天就过去

8. 其他经验参考

  • Patterns for Research in Machine Learning | Ali Eslami
    • code 和 data 目录要独立
    • data 目录下继续区分 input, working, output 三个子目录 其中 input 是不会修改的,是原始数据集,对应我的 database/datasets 目录 working 对应的是我在各个数据集下设置的 process 目录,然而这里作者把 working 独立出来, 这样拷贝整个 dataset 的时候很方便,也可以随时删除 working, 更合理方便 output 对应的是我之前的 run 目录,作者把这个目录直接移除到代码目录外,这也是很合理的, 整个代码就只包括代码和文档,这样可以很方便版本控制以及备份代码,都不需要太关注 .ignore 文件,同时方便分享代码。 例如 output 中可能包括代码和结果,可以直接把整个目录公开出去,相当于 output 就是一个单独的项目和结果目录
    • 对每个数据集单独新建 readme 以及转换成不同输入特征的函数
    • 不使用全局变量,这方便单元测试,debug 以及并行
    • 可以方便地执行各个单独阶段的代码,这是 pipeline 的解耦和,设置可以用以下方式执行:

      >> run_experiment('dataset_1_options', '|preprocess_data|initialise_model|train_model|');
      

      类似积木了,这样可以单独测试,也可以分阶段处理,例如只处理数据,然后在处理好的数据上进行训练, 另外的好处是,可以写很多个不同的 pipeline 组合,需要跑哪个就执行哪个,而不是每次跑代码去注释掉某些选项。

    • 保存训练的中间状态,可以随时地精准地恢复
    • 编写测试和 demo, 使得功能清晰
    • 每次运行实验前有一个对实验时间的大概预估
    • 每次运行实验前,记录下为什么要运行这个实验,设计理念是什么
    • 大部分编码工作应该在小数据集上,即整个运行起来不会超过 10s
    • 能够方便地替换损失和模型
    • 阅读 The Pragmatic Programmer
  • Principles of Research Code 这篇文章类似对以上文章的评论和补充,更多是对为什么这样做的更高层解释 解释为什么 Programming for research is very different than programming for industry. 五个原则:

    • 研究者输出的不是代码而是知识,大部分代码在发完论文后就忘记了
    • 除非你解决了大问题,使得很多人 follow 你的研究,这样你才需要考虑"好代码"的问题
    • 需要有更多的测试,保证结果的正确性
    • 需要特定的脚本和工具帮助自己科研
    • 需要保证可复现,多年后看到一张论文里的图,还能够找到对应版本代码并且执行这个代码得到一样的结果

    基于这些原则,对实践经验的补充:

    • 在每个实验前都提交一次代码,这样实验出问题可以回退,更大好处是,多年后如果你看到一个日志记录了某个时间的实验, 你可以回退版本到那个时间点,然后复现这个实验结果。
    • 维持固定的随机种子,只用这些种子训练,不要随便换动
    • output 目录要精心维护,论文里每个图表都应该有对应的 output 目录里的原始数据,这个目录是要随时分享出去的。
    • 用 makefile 分阶段单独处理数据,并且中间数据用 text 格式,而不是 binary 格式,这样才可以打开来检查。

如对本文有任何疑问,欢迎通过 github issue 邮件 metaescape at foxmail dot com 进行反馈