理解 transformers 库 GPT2 的 BBPE 分词算法

2023-11-30 四 20:18 2024-07-21 日 00:12

本文主要是通过 huggingface 的 transformers v4.30.2 里 GPT2Tokenizer 相关源码来理解 gpt2 模型里的 BBPE 分词算法(由于基于 gpt2 的模型有很多,有些模型是其他分词方式,本文只关注 BBPE)。

以下是初始化一个本地已经下载了的模型的分词器的常见代码:

pretrained_model_name_or_path = '/data/huggingface/transformers/gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path)

gpt2 at main 基本可以认为是 gpt2-small 的官方版本,它包括以下几个文件,前两个是网络模型所关注的,不在本文讨论范围内(若想了解 GPT2Model 可以参考 理解 huggingface transformers 中 GPT2 模型 一文),generation_config.json 和分词器也无关,最后三个则是是 tokenizer 初始化时可能要读取的文件。

config.json
pytorch_model.bin

generation_config.json

merges.txt
tokenizer.json
vocab.json

在另一些 gpt2 的实现中,分词器相关的文件可能有所差异,比如 IDEA-CCNL/Wenzhong-GPT2-110M at main 的文件如下:

config.json
pytorch_model.bin

merges.txt
special_tokens_map.json
tokenizer_config.json
tokenizer.json
vocab.json

后五个是 GPT2Tokenizer (可能)需要用到的,也是本文所关注的。

调用分词器后结果:

tokenized_text = tokenizer.tokenize("你好 ma")
print(tokenized_text)
['ä½', 'ł', 'å¥', '½', 'Ġma']

由于原始的 GPT2 用的是 byte-level byte pair encoding(BBPE), 因此分词(或组合)是发生在 byte 表示层面的,英文部分被编码成了 'Ġma', 中文部分则都是“乱码”,本文主要是想回答以下几个问题:

  • 下载的 huggingface 模型里 merges.txt 等文件的作用是什么?
  • 分词后结果里这些乱码是什么?
  • GPT2 默认词表大小是 50257, 它为什么等于 256+1+50000?

本文不包括:

  • 如何训练一个 bbpe 分词器
  • bbpe 与其他分词器如 wordpiece unigram 的原理对比

1. GPT2Tokenizer 初始化和继承关系

1.1. 继承关系和例行代码

注: 本小节都是些流程性代码的记录,可以快速略过。

GPT2Tokenizer 继承关系如下:

GPT2Tokenizer
     ↓
PreTrainedTokenizer
     ↓
PreTrainedTokenizerBase
     ↓
(SpecialTokensMixin, PushToHubMixin)

核心只需要关注 PreTrainedTokenizer 和 PreTrainedTokenizerBase 即可。 from_pretrained 函数就定义在 PreTrainedTokenizerBase 类中。这是一个 @classmethod 修饰的类函数。这种函数能够通过函数第一个参数 cls 访问类属性,GPT2Tokenizer 的部分类属性定义如下:

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}

class GPT2Tokenizer
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

vocab.json 是从 token 到 id 的映射字典, merges.txt 是 gpt2 分词器使用的 bpe 算法中合并规则需要的数据, 这些在之后会详细介绍。

PreTrainedTokenizerBase.from_pretrained 先做了一些参数提取和参数检查性工作,然后构造一个变量名为 additional_files_names 的字典:

{'added_tokens_file': 'added_tokens.json',
 'special_tokens_map_file': 'special_tokens_map.json',
 'tokenizer_config_file': 'tokenizer_config.json'}

把这个字典与 cls.vocab_files_names (VOCAB_FILES_NAMES)合并扩充成 vocab_files 变量:

{'added_tokens_file': 'added_tokens.json',
 'merges_file': 'merges.txt',
 'special_tokens_map_file': 'special_tokens_map.json',
 'tokenizer_config_file': 'tokenizer_config.json',
 'vocab_file': 'vocab.json'}

以上字典保存了分词器所需要的各类文件名, huggingface 一般支持两种加载模型或分词器的方式,一种是如本文例子里的,给定一个本地路径,另一种是直接给定 huggingface 里模型的 id ,比如 bert-base-uncased 。后一种情况下会自动到在线模型库里去下载并保存在本地缓存路径下,因此接下来的代码就是对以上两种情况做区分,返回完整的文件路径名:

  • 如果给定的是本地路径,那么将 pretrained_model_name_or_path 和文件名进行拼接后更新字典
  • 如果给定的是网络模型名,那么从 cache 中检查是否存在,如果不存在则先下载,然后用完整的 cache 文件路径更新字典
  • 如果文件路径不存在则更新为 None

cached_file 就是对以上功能实现的描述:

resolved_vocab_files[file_id] = cached_file(
    pretrained_model_name_or_path,
    file_path,
    ...
)

以上函数定义在 utils/hub.py 文件里,可以直接 import 出来用于下载模型并查看缓存的文件路径:

from transformers.utils.hub import cached_file
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")

将 vocab_files 里路径都拼接完成后,得到 resolved_vocab_files 字典内容如下:

{'added_tokens_file': None,
 'merges_file': '/data/huggingface/transformers/GPT2-110M/merges.txt',
 'special_tokens_map_file': '/data/huggingface/transformers/GPT2-110M/special_tokens_map.json',
 'tokenizer_config_file': '/data/huggingface/transformers/GPT2-110M/tokenizer_config.json',
 'vocab_file': '/data/huggingface/transformers/GPT2-110M/vocab.json'}

注意这里没有用到 tokenizer.json (一般对应变量名 tokenizer_files), 它一般是用在 Fast Tokenizer 里的,所有 tokenizer 相关的数据都可以放在这一个文件中,并且底层对分词等操作的实现更加高效,这部分是纯工程优化,本文不展开。

接着调用以下函数(代码中没有列出的其他参数都是默认的,基本都是空或者 None):

@classmethod
def _from_pretrained(
    cls,
    resolved_vocab_files,
    pretrained_model_name_or_path,
    ...
)

该函数刚开始会考虑 fast 版本的分词器的读取,在此略过。

接着读取 resolved_vocab_files 字典里的 tokenizer_config_file 文件,内容如下:

{"unk_token": "<|endoftext|>",
 "bos_token": "<|endoftext|>",
 "eos_token": "<|endoftext|>",
 "add_prefix_space": false,
 "model_max_length": 1024,
 "special_tokens_map_file": null,
 "name_or_path": "gpt2",
 "tokenizer_class": "GPT2Tokenizer"}

它会被赋值到 init_kwargs 变量,同时 pop 掉 tokenizer_class, 这里主要是检查 json 文件里写的 tokenizer_class 和当前调用的 tokenizer 类是否一致,以上例子中都是 GPT2Tokenizer, 因此是一致的。

接着对 init_kwargs 调用 convert_added_tokens 函数,但在当前分词器下不会有任何副作用。后读取 added_tokens_file, 当前 gpt2 模型目录里没有该文件,都可以跳过。

把 resolved_vocab_files 字典合并到 init_kwargs 中去,把 name_or_path 覆盖成本地路径,init_kwargs 变成了:

{'add_prefix_space': False,
 'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'merges_file': '/data/huggingface/transformers/IDEA-CCNL/Wenzhong-GPT2-110M/merges.txt',
 'model_max_length': 1024,
 'name_or_path': '/data/huggingface/transformers/IDEA-CCNL/Wenzhong-GPT2-110M',
 'special_tokens_map_file': None,
 'unk_token': '<|endoftext|>',
 'vocab_file': '/data/huggingface/transformers/IDEA-CCNL/Wenzhong-GPT2-110M/vocab.json'}

接着调用以下函数初始化出了一个 tokenzier 对象

tokenizer = cls(*init_inputs, **init_kwargs)

注意这个 cls 是 GPT2Tokenizer, 因此进入到 GPT2Tokenizer.__init__ ,该函数先用 AddedToken 对各种控制 token 进行包装,这里引出了 transformers 库和 huggingface 的另一个 tokenizers 库的联系:

在 transformers/tokenization_utils_base.py 里:

if is_tokenizers_available():
    from tokenizers import AddedToken
    from tokenizers import Encoding as EncodingFast
else:

    @dataclass(frozen=True, eq=True)
    class AddedToken:
        """
        AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
        way it should behave.
        """

        content: str = field(default_factory=str)
        single_word: bool = False
        lstrip: bool = False
        rstrip: bool = False
        normalized: bool = True

        def __getstate__(self):
            return self.__dict__

也就是说,在 import AddedToken 的时候,如果当前环境中安装了 tokenizers 库, 则从 tokenziers 里引入,否则自己定义一个。

然后调用 super().__init__, 又进入到 PreTrainedTokenizer 的 init 里,它的第一句是 super().__init__(**kwargs) 于是继续进入到 PreTrainedTokenizerBase.__init__ 也就是和 from_tokenizer 定义同一级别的类的初始函数中。该函数基本都是在构建各种实例属性,比如 self.model_max_length, 因此基本跳过。回到 PreTrainedTokenizer.__init__ ,内容如下:

super().__init__(**kwargs)
self.added_tokens_encoder: Dict[str, int] = {}
self.added_tokens_decoder: Dict[int, str] = {}
self.unique_no_split_tokens: List[str] = []
self.tokens_trie = Trie()

self._decode_use_source_tokenizer = False

super() 初始化后构造了一个 trie 结构的变量,但之后发现该结构对 GPT2 分词基本没有什么影响,因此也跳过。

回顾一下到目前为止初始化的核心:

  • GPT2Tokenizer.from_pretrained 是继承自基类 PreTrainedTokenizerBase 的。该函数会调用 PreTrainedTokenizerBase._from_pretrained
  • _from_pretrained 触发 tokenizer = cls(*init_inputs, **init_kwargs) 真正初始化 GPT2Tokenizer 实例
  • GPT2Tokenizer.__init__ 里通过 super().__init__(),沿着 GPT2Tokenizer -> PreTrainedTokenizer -> PreTrainedTokenizerBase 抽象链向上执行,但都是一些初始变量的工作,基本跳过
  • super().__init__() 结束之后的代码是关键,接下来分多个小节介绍

1.2. 读取 50257 个入口的 vocab.json

with open(vocab_file, encoding="utf-8") as vocab_handle:
    self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}

读取了 vocab.json 文件,这是一个 token 到 id 的映射表,self.encode 是从 token 到 id 编码, 部分内容如下:

{"!":0,"\"":1,"#":2,"$":3,"%":4,"&":5,"'":6,
 "(":7,")":8,"*":9,"+":10,",":11,"-":12,".":13,"/":14,"0":1

self.decoder 是从 id 到 token 编码。

该字典里有 50257 个元素,接下来会解释其含义

1.3. 构造大小为 256 的基础词表

接下来体现 BBPE 中第一个 Byte 意义的核心代码:

self.byte_encoder = bytes_to_unicode()

但在理解 bytes_to_unicode 之前需要介绍关于 utf-8 和 unicode 相关背景

1.3.1. ord 函数、 unicode 和 UTF-8 编码

在 python 中 ord 是直接获得某个字符的 unicode 码点,所谓码点就是对字符的编号, 比如对于经典的 ascii 字符(基本上是标准英文键盘能够直接打出来的字符以及不可见的控制字符)占据了 unicode 的前 128 个编号。

print(ord("1"),ord("c"),ord(" "))
49 99 32

unicode 涵盖范围很大,包括许多表情:

ord("😄")
128516

对于中文字符,unicode 编码一般也都是比较大的编号

ord("你")
20320

以上都是 10 进制表示,更常见的是用 16 进制并且加上 U+ 符号来说明这是一个 unicode 收录的字符:

f'U+00{ord("你"):X}'
U+004F60

搜索 - SYMBL (◕‿◕) 里也可以找到 "你" 对应的 unicode 的十六进制和十进制写法。

注意 unicode 的 16 进制表示和 utf-8 是无关的,unicode 只是负责对几百万字符贴上一个数字标签,其数值是 抽象层面 的,并不是说计算机里 "你" 的表示就是以上的 4F60 。

utf-8 可以看作是 unicode 的一种具体实现,它是另外一套编号,只不过这个编号是 实现层面 (被称为编码), 比如中文 "你" 的 utf-8 编码的 16 进制表示如下:

print("你".encode("utf-8"))
"你".encode("utf-8").hex()
b'\xe4\xbd\xa0'
e4bda0

结果里第二行的 16 进制和上文 unicode 码点的 16 进制表示 4F60 是完全不同的(至少这里有三个字节)。"你" 字所对应的 utf-8 的 10 进制表示如下:

def utf8int(c):
    return int(c.encode("utf-8").hex(), 16)

utf8int("你")
14990752

注:讨论 utf-8 的十进制没什么意义,因为它不是连续编号,而 unicode 一般是连续的,但这里是和 unicode 比较,“你” 的 unicode 16 进制是 4F60 ,十进制是 20320 。

utf-8 编码再解码:

print("你".encode("utf-8").decode("utf-8"))

不过对于 ascii 字符是最初定义的,过于经典,因此它们对应的 unicode 和 utf-8 的十进制表示是一样的:

print(ord("1"),ord("q"),ord(" "))

print(utf8int("1"), utf8int("q"), utf8int(" "))

print("1".encode("utf-8"), "q".encode("utf-8"), " ".encode("utf-8"))
49 113 32
49 113 32
b'1' b'q' b' '

注意最后一行里,对于这些 ascii 可见字符的 utf-8 编码结果,python 直接显示出字符表示而不是像中文 utf-8 里出现的 '\xe4' 样式的十六进制形式。

ascii 中也有不可见字符的,比如 unicode 码点小于 33 的字符,如果打印出来,它们也是用十六进制显示:

for i in range(33):
    char = chr(i)
    encoded_char = char.encode('utf-8')
    print(i, encoded_char) 
0 b'\x00'
1 b'\x01'
2 b'\x02'
3 b'\x03'
4 b'\x04'
5 b'\x05'
6 b'\x06'
7 b'\x07'
8 b'\x08'
9 b'\t'
10 b'\n'
11 b'\x0b'
12 b'\x0c'
13 b'\r'
14 b'\x0e'
15 b'\x0f'
16 b'\x10'
17 b'\x11'
18 b'\x12'
19 b'\x13'
20 b'\x14'
21 b'\x15'
22 b'\x16'
23 b'\x17'
24 b'\x18'
25 b'\x19'
26 b'\x1a'
27 b'\x1b'
28 b'\x1c'
29 b'\x1d'
30 b'\x1e'
31 b'\x1f'
32 b' '

以上除了 \t, \n, \r 经过转义字符而变得"可见" 的字符以及空格,其他都还是保持了 16 进制形式,BBPE 需要对这些 byte 以及 byte 的拼接结果进行保存(比如 "你" 是 '\xe4\xbd\xa0' 三个字节表示),如果以这种非字符形式保存会不太方便,于是才有 bytes_to_unicode 函数,使得在一个 byte 所能表示的 0 到 255 范围内,每个数都能对应一个可见的 unicode 字符,便于字典的保存。

1.3.2. byte 到 unicode 的转换

在 GPT2Tokenizer 初始化函数里,定义以下两个变量

self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

bytes_to_unicode 函数实现为:

def bytes_to_unicode():
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

该函数没有任何依赖,可以拷贝出来直接执行测试:

byte_encoder = bytes_to_unicode()
print(byte_encoder)
{33: '!', 34: '"', 35: '#', 36: '$', 37: '%', 38: '&', 39: "'", 40: '(', 41: ')', 42: '*', 43: '+', 44: ',', 45: '-', 46: '.', 47: '/', 48: '0', 49: '1', 50: '2', 51: '3', 52: '4', 53: '5', 54: '6', 55: '7', 56: '8', 57: '9', 58: ':', 59: ';', 60: '<', 61: '=', 62: '>', 63: '?', 64: '@', 65: 'A', 66: 'B', 67: 'C', 68: 'D', 69: 'E', 70: 'F', 71: 'G', 72: 'H', 73: 'I', 74: 'J', 75: 'K', 76: 'L', 77: 'M', 78: 'N', 79: 'O', 80: 'P', 81: 'Q', 82: 'R', 83: 'S', 84: 'T', 85: 'U', 86: 'V', 87: 'W', 88: 'X', 89: 'Y', 90: 'Z', 91: '[', 92: '\\', 93: ']', 94: '^', 95: '_', 96: '`', 97: 'a', 98: 'b', 99: 'c', 100: 'd', 101: 'e', 102: 'f', 103: 'g', 104: 'h', 105: 'i', 106: 'j', 107: 'k', 108: 'l', 109: 'm', 110: 'n', 111: 'o', 112: 'p', 113: 'q', 114: 'r', 115: 's', 116: 't', 117: 'u', 118: 'v', 119: 'w', 120: 'x', 121: 'y', 122: 'z', 123: '{', 124: '|', 125: '}', 126: '~', 161: '¡', 162: '¢', 163: '£', 164: '¤', 165: '¥', 166: '¦', 167: '§', 168: '¨', 169: '©', 170: 'ª', 171: '«', 172: '¬', 174: '®', 175: '¯', 176: '°', 177: '±', 178: '²', 179: '³', 180: '´', 181: 'µ', 182: '¶', 183: '·', 184: '¸', 185: '¹', 186: 'º', 187: '»', 188: '¼', 189: '½', 190: '¾', 191: '¿', 192: 'À', 193: 'Á', 194: 'Â', 195: 'Ã', 196: 'Ä', 197: 'Å', 198: 'Æ', 199: 'Ç', 200: 'È', 201: 'É', 202: 'Ê', 203: 'Ë', 204: 'Ì', 205: 'Í', 206: 'Î', 207: 'Ï', 208: 'Ð', 209: 'Ñ', 210: 'Ò', 211: 'Ó', 212: 'Ô', 213: 'Õ', 214: 'Ö', 215: '×', 216: 'Ø', 217: 'Ù', 218: 'Ú', 219: 'Û', 220: 'Ü', 221: 'Ý', 222: 'Þ', 223: 'ß', 224: 'à', 225: 'á', 226: 'â', 227: 'ã', 228: 'ä', 229: 'å', 230: 'æ', 231: 'ç', 232: 'è', 233: 'é', 234: 'ê', 235: 'ë', 236: 'ì', 237: 'í', 238: 'î', 239: 'ï', 240: 'ð', 241: 'ñ', 242: 'ò', 243: 'ó', 244: 'ô', 245: 'õ', 246: 'ö', 247: '÷', 248: 'ø', 249: 'ù', 250: 'ú', 251: 'û', 252: 'ü', 253: 'ý', 254: 'þ', 255: 'ÿ', 0: 'Ā', 1: 'ā', 2: 'Ă', 3: 'ă', 4: 'Ą', 5: 'ą', 6: 'Ć', 7: 'ć', 8: 'Ĉ', 9: 'ĉ', 10: 'Ċ', 11: 'ċ', 12: 'Č', 13: 'č', 14: 'Ď', 15: 'ď', 16: 'Đ', 17: 'đ', 18: 'Ē', 19: 'ē', 20: 'Ĕ', 21: 'ĕ', 22: 'Ė', 23: 'ė', 24: 'Ę', 25: 'ę', 26: 'Ě', 27: 'ě', 28: 'Ĝ', 29: 'ĝ', 30: 'Ğ', 31: 'ğ', 32: 'Ġ', 127: 'ġ', 128: 'Ģ', 129: 'ģ', 130: 'Ĥ', 131: 'ĥ', 132: 'Ħ', 133: 'ħ', 134: 'Ĩ', 135: 'ĩ', 136: 'Ī', 137: 'ī', 138: 'Ĭ', 139: 'ĭ', 140: 'Į', 141: 'į', 142: 'İ', 143: 'ı', 144: 'IJ', 145: 'ij', 146: 'Ĵ', 147: 'ĵ', 148: 'Ķ', 149: 'ķ', 150: 'ĸ', 151: 'Ĺ', 152: 'ĺ', 153: 'Ļ', 154: 'ļ', 155: 'Ľ', 156: 'ľ', 157: 'Ŀ', 158: 'ŀ', 159: 'Ł', 160: 'ł', 173: 'Ń'}

以上函数第一句构造 bs 实际有两个主要部分:

  1. 最常用可见字符集(Basic Latin):

    print(list(range(ord("!"), ord("~") + 1)))
    
    [33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126]
    

    这些字符基本在标准的 PC 键盘上都能打出来:

    print([chr(x) for x in list(range(ord("!"), ord("~") + 1))])
    
    ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
    

    以上 '\\' 表示的是就是 反斜杠 \ 字符,第一个 \ 是转义

    如上节所说,将这些字符对应的编号转成 16 进制再加上 U+ 就表示这些字符的 unicode 码点

    print([f"U+00{i:X}" for i in list(range(ord("!"), ord("~") + 1))])
    
    ['U+0021', 'U+0022', 'U+0023', 'U+0024', 'U+0025', 'U+0026', 'U+0027', 'U+0028', 'U+0029', 'U+002A', 'U+002B', 'U+002C', 'U+002D', 'U+002E', 'U+002F', 'U+0030', 'U+0031', 'U+0032', 'U+0033', 'U+0034', 'U+0035', 'U+0036', 'U+0037', 'U+0038', 'U+0039', 'U+003A', 'U+003B', 'U+003C', 'U+003D', 'U+003E', 'U+003F', 'U+0040', 'U+0041', 'U+0042', 'U+0043', 'U+0044', 'U+0045', 'U+0046', 'U+0047', 'U+0048', 'U+0049', 'U+004A', 'U+004B', 'U+004C', 'U+004D', 'U+004E', 'U+004F', 'U+0050', 'U+0051', 'U+0052', 'U+0053', 'U+0054', 'U+0055', 'U+0056', 'U+0057', 'U+0058', 'U+0059', 'U+005A', 'U+005B', 'U+005C', 'U+005D', 'U+005E', 'U+005F', 'U+0060', 'U+0061', 'U+0062', 'U+0063', 'U+0064', 'U+0065', 'U+0066', 'U+0067', 'U+0068', 'U+0069', 'U+006A', 'U+006B', 'U+006C', 'U+006D', 'U+006E', 'U+006F', 'U+0070', 'U+0071', 'U+0072', 'U+0073', 'U+0074', 'U+0075', 'U+0076', 'U+0077', 'U+0078', 'U+0079', 'U+007A', 'U+007B', 'U+007C', 'U+007D', 'U+007E']
    

    注意标准的 Basic Latin 字符集有 95 个字符,包括空格,而以上只有 94 个,不包括不可见的空格(编号是 32 )

    print(len(list(range(ord("!"), ord("~") + 1))))
    
    94
    
  2. Latin-1 Supplement

    unicode 从 161 到 255 是被称为 Latin-1 Supplement 字符集,有 96 个字符,其中有两个是不可见的,分别是 160 和 173, 以下代码就是避开了这两个符号:

    print(list(range(ord("¡"), ord("¬") + 1)))
    print(list(range(ord("®"), ord("ÿ") + 1)))
    
    [161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172]
    [174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
    
    len(range(ord("¡"), ord("¬") + 1)) + len(range(ord("®"), ord("ÿ") + 1))
    
    94
    

因此目前 bs 的值实际就是 unicode 中前 256 个字符里的可见字符集(的编号),共 188 个。但剩下的 256-188 个 byte 值还没有对应的可见字符,于是作者把不在 bs 中的小于 256 的值 idx 都加入到 bs 中,并且将这些值对应的字符设置为 unicode 编号为 idx+256 的字符(放在变量 cs 相应的位置),这是因为 unicode 编号大于 255 的基本都是可见字符了。

最后返回的字典 dict(zip(bs, cs)) 中,原本就可见的 188 个字符的编码到字符的映射是正确的,比如以下几个例子:

for dec in range(33, 40):
    char = byte_encoder[dec]
    print(f"{dec} -> {char}")
33 -> !
34 -> "
35 -> #
36 -> $
37 -> %
38 -> &
39 -> '

但其余的是不正确的,比如以下例子

for dec in [0,1,2,3,4,5,32,33]:
    char = byte_encoder[dec]
    print(f"{dec} -> {char}")
0 -> Ā
1 -> ā
2 -> Ă
3 -> ă
4 -> Ą
5 -> ą
32 -> Ġ
33 -> !

'Ā' 的 unicode 编码应该是 256 ,但这里对应的是 0, 而 32 原本对应的是空格,这里用 'Ġ' 表示。 不过编码对错在这里是不重要的,因为当前的目的只是找到 256 个可见字符来表示一个 byte (8 bit)内的所有可能组合,并以此作为基础词表,任何 utf-8 字符都可以拆分成这些 byte 的拼接。

1.4. 读取 50000 行的 merges.txt

merges.txt 定义了合并规则,文件前几行如下:

Ġ t
Ġ a
h e
i n
r e
o n
Ġt he

上一节最后说到 Ġ 表示的是空格,因此这里空格和 t 以及空格和 a 是最可能组合在一起的,结合第三个 h 和 e 的组合以及 Ġt 和 he 的组合,这侧面体现了是词库里 the 一词是非常多的,此外还有 in, on 。这是合理的,训练数据集里一般英文语料占比最大,而 the ,a 又是因为了最为常见的几个词。

接着看读取 merges.txt 的代码:

with open(merges_file, encoding="utf-8") as merges_handle:
    bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

这里目的就是构建一个 bpe_ranks 字典,key 是 (h e) 和 (i n) 形式的 tuple, value 是合并规则的优先级,值越小优先级越高。

除去注释说明行,merage.txt 长度就是 50000, 这是人为设定的,因为在训练时只选出频率最高的 50000 组 pair 进行合并。

加上基础的 256 个单字节词,再加上额外的:

"<|endoftext|>": 50256

词库大小是 50257 = 256 + 50000 + 1,

另外如果没有上一节的 bytes_to_unicode 函数转换,该文件前几行就变成了:

  t
  a
h e
i n
r e
o n
 t he

这容易给人产生困惑,另外读取文件后用 split 函数来切分也会把空格删除导致错误,更别说还有 \n \t 甚至删除控制符带来的问题了。

1.5. pre-tokenizer 准备

读取 merges.txt 后, GPT2tokenizer.__init__ 中还设置了以下正则表达式变量,它是用来做预分词处理的(pre-tokenizer),后文会详细介绍

self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

1.6. cls(*init_inputs, **init_kwargs) 之后的处理

以上每个小节实际都是发生在 PreTrainedTokenizerBase._from_pretrained 调用 cls(*init_inputs,..) 内部,这会返回一个 tokenizer 对象,在这之后,还读取了 special_tokens_map.json 文件(对应 special_token_file 变量),该文件内容如下:

{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

可以看到,这和 tokenizer_config_file 的内容是有重叠的,因此不会改变 tokenizer 属性(但如果定义了额外的字符,可以加入到此处)。

此外,如果存在 added_tokens_file 文件,会对其进行读取, gpt2 默认没有该文件,因此不考虑。

至此,整个初始化过程结束了,总结来看:

  • 读取了模型文件里的 tokenizer_config.json 文件,这是对分词器的基本设置,包括对 bos_token, eos_token 等控制类字符的定义, "model_max_length": 1024 参数等
  • 读取大小为 50257 的 vocab.json 字典
  • 构造了一个 256 大小的 byte 到 unicode 映射字典以及逆向字典。
  • 读取包括 50000 条合并规则的 merges.txt 文件
  • 读取 special_tokens_map.json, 定义其他的特殊字符(默认就是 1 个 <|endoftext|>,因此 vocab.json 里词的数量就是 256 + 50000 + 1)。
  • 没有读取 tokenizer.json, 这是给 FastTokenizer 用的(可能更新的模型会使用该文件)

2. 分词函数: GPT2Tokenizer.tokenize

先构造一个例子:

demo_sentence = "朋友,it's a good day."

分词结果:

tokenized_text = tokenizer.tokenize(demo_sentence)
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į', 'it', "'s", 'Ġa', 'Ġgood', 'Ġday', '.']

以下又是流程性记录,可以快速略过:

tokenize 是定义在 PreTrainedTokenizerBase 中的接口函数,PreTrainedTokenizer 负责具体实现,在其中:

  • 先调用了 GPT2Tokenizer.prepare_for_tokenization:

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if is_split_into_words or add_prefix_space:
            text = " " + text
        return (text, kwargs)
    

    但由于 add_prefix_space 为 False 且传入的 kwargs 是空,因此返回结果就是 (text, {})

  • 接着执行 tokens = self.tokens_trie.split(text), 用 trie 树结果做了一次切分,但据我观察,基本只是包装成了 list, 变成 [text],因此基本不考虑它(猜测:trie 树应该是处理特殊字符的,而这里特殊字符只有 '<|endoftext|>')
  • tokenize 因此真正核心还是以下代码:
for token in tokens: # 本例中, tokens 是 ["朋友,it's a good day."]
    # Need to skip eventual empty (fully stripped) tokens
    if not token:
        continue
    if token in no_split_token:
        tokenized_text.append(token)
    else:
        tokenized_text.extend(self._tokenize(token))

而 _tokenize 的实现是在 GPT2Tokenizer 中

2.1. GPT2Tokenizer._tokenize

该函数非常简短

def _tokenize(self, text):
    """Tokenize a string."""
    bpe_tokens = []
    for token in re.findall(self.pat, text):
        token = "".join(
            self.byte_encoder[b] for b in token.encode("utf-8")
        )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
    return bpe_tokens

不过循环部分每一行都值得特别说明:

2.1.1. pre tokenize

前文提到过, self.pat 定义如下

import regex as re
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

注意这个是 re 是 regex 库而不是 re 库,regex 支持 ?\p 这类更复杂的正则表达式。

AI 对该表达式的解释是:

  • `'s|'t|'re|'ve|'m|'ll|'d`:这部分匹配常见的英文缩写和所有格形式,如 `'s`, `'t`, `'re`, `'ve`, `'m`, `'ll`, `'d`(例如,matches `it's`, `don't`, `they're`, `I've`, `I'm`, `we'll`, `I'd`)。
  • `?\p{L}+`:这部分使用 Unicode 属性匹配一个或多个任意语言的字母。`\p{L}` 匹配任何语言的字母字符。可选的前导空格由 `?` 表示。
  • `?\p{N}+`:类似地,这部分匹配一个或多个数字。`\p{N}` 匹配任何数字字符。数字前面的空格是可选的。
  • `?[^\s\p{L}\p{N}]+`:这部分匹配任何不是空格、字母或数字的字符序列。这可能包括标点符号、特殊字符等。字符序列前的空格是可选的。
  • `\s+(?!§)`:这个部分稍微复杂一些。`\s+` 匹配一个或多个空白字符,`(?!§)` 是一个负向前瞻断言,确保后面不跟着非空白字符。这样的组合意味着它匹配字符串末尾的空白字符。
  • `\s+`:匹配一个或多个空白字符。

总的来说,这个正则表达式设计用于匹配包括缩写、单词、数字、特殊符号和某些空白字符在内的多种模式。它似乎用于某种形式的文本处理或分词任务,可能是在自然语言处理的上下文中。使用 `re.findall(self.pat, text)` 将返回给定文本中所有匹配这些模式的子串的列表。

用几个例子来说明

pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
print(re.findall(pat, "this"))
['this']

空格会在单词之前:

print(re.findall(pat, "it is a dog"))
['it', ' is', ' a', ' dog', '1']

额外的空格被单独抽取出来,数字和英文分开:

print(re.findall(pat, "this   is gpt3"))
['this', '  ', ' is', ' gpt', '3']

所有格被切分:

print(re.findall(pat, "he's"))
['he', "'s"]

中文标点会被切分,但不含标点的连续中文不会被切分:

print(re.findall(pat, "你好,你在用 AI 吗"))
['你好', ',', '你在用', ' AI', '  ', ' 吗']

还有我们最初的例子:

print(re.findall(pat, "朋友,it's a good day."))
['朋友', ',', 'it', "'s", ' a', ' good', ' day', '.']

这么看 re.findall(self.pat, text) 就是一个基础的分词算法,它比 text.split() 考虑了更多的英语语法特点以及通用的标点符号。此外,相比与 text.split() 或者其他会丢弃掉空格的预分词器,这种方式不会丢弃掉空格(或许也不会丢掉任何其他字符),空格本身是一个需要被训练的词。

2.1.2. byte 转换

对于 re.findall 返回的所有分词的结果,先转换成 utf-8 编码,然后从中进行查表

token = "".join(
    self.byte_encoder[b] for b in token.encode("utf-8")
)  

直接打印出 demo 例子的中间过程:

def show_demo_byte(token):
    utf8_rep = token.encode("utf-8")
    print(f"UTF8 Hex of {token} is {utf8_rep}")
    token_list = [byte_encoder[b] for b in utf8_rep]
    _token = "".join(token_list)
    print(f"{token}->{token_list}->{_token}")

demo_tokens =  re.findall(pat, demo_sentence)
for i, x in enumerate(demo_tokens):
    print((i, x))
(0, '朋友')
(1, ',')
(2, 'it')
(3, "'s")
(4, ' a')
(5, ' good')
(6, ' day')
(7, '.')

先看所有的英文部分(键盘可以直接打出的词汇),注意 a, good, day 之前都有空格:

show_demo_byte(demo_tokens[2])
show_demo_byte(demo_tokens[3])
show_demo_byte(demo_tokens[5])
show_demo_byte(demo_tokens[6])
show_demo_byte(demo_tokens[7])
UTF8 Hex of it is b'it'
it->['i', 't']->it
UTF8 Hex of 's is b"'s"
's->["'", 's']->'s
UTF8 Hex of  good is b' good'
 good->['Ġ', 'g', 'o', 'o', 'd']->Ġgood
UTF8 Hex of  day is b' day'
 day->['Ġ', 'd', 'a', 'y']->Ġday
UTF8 Hex of . is b'.'
.->['.']->.

首先看到,这些 token 转换成 utf-8 后的结果是不变的,因为这些字符的 utf-8 和 ascii 码是兼容的。 但经过 byte_encoder 转换,由于空格不可见,作者用 Ġ 替代了,于是得到了以上的结果。

接着查看中文部分:

show_demo_byte(demo_tokens[0])
show_demo_byte(demo_tokens[1])
UTF8 Hex of 朋友 is b'\xe6\x9c\x8b\xe5\x8f\x8b'
朋友->['æ', 'ľ', 'ĭ', 'å', 'ı', 'ĭ']->æľĭåıĭ
UTF8 Hex of , is b'\xef\xbc\x8c'
,->['ï', '¼', 'Į']->ï¼Į

中文字符的 utf-8 基本是由 3 个 byte 表示。

2.1.3. BPE 合并

接着进入到核心的 bpe 函数:

bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))

这里 token 是经过 byte 转换的结果,比如 "朋友" 对应的是 æľĭåıĭ, 后文将以这个例子来说明,

demo_token = "æľĭåıĭ"

完整 bpe 函数如下:

def bpe(self, token):
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                new_word.extend(word[i:])
                break
            else:
                new_word.extend(word[i:j])
                i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    self.cache[token] = word
    return word

前两行检查是否已经有缓存,有的话直接返回。接着将 token 转为 tuple 类型:

word = tuple(demo_token)
print(word)
('æ', 'ľ', 'ĭ', 'å', 'ı', 'ĭ')

调用以下的 get_pairs 函数,目的是构造 bigram 集合,也就是相邻字符对的集合,这是为了确定哪些相邻 token 之间需要合并

def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

print(get_pairs(word))
{('å', 'ı'), ('ı', 'ĭ'), ('æ', 'ľ'), ('ĭ', 'å'), ('ľ', 'ĭ')}

接着进入 while 循环,先根据 bpe_ranks 找出各个 bigram 的排序,并找到最小的排序对

bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
    break

第一次循环找到的是:

('æ', 'ľ')

接着又一个 while 循环,目的是用合并后的 'æľ' 取代 word 里原本的这两个相邻 token, 注意,有可能 word 里有多个相同的 bigram, 因此用 while 循环依次替换。

first, second = bigram
new_word = []
i = 0
while i < len(word):
    try:
        j = word.index(first, i)
    except ValueError:
        new_word.extend(word[i:])
        break
    else:
        new_word.extend(word[i:j])
        i = j

    if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
        new_word.append(first + second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1
new_word = tuple(new_word)
word = new_word

因此,第一次循环后, word 应该是

('æľ', 'ĭ', 'å', 'ı', 'ĭ')

最后用空格 concate 后返回(可能是为了 cache 查找更高效?以 tuple 形式返回也可以)

word = " ".join(word)
self.cache[token] = word
return word

如此这般,直到合并的 word 长度为 1 或者找出来的 bigram 不在 bpe_ranks 里,也就无法合并为止。比如最后 "朋友" 这个词的 bbpe 表示为: ['æľ', 'ĭ', 'åı', 'ĭ']

对每个预分词后的 token 都这样处理,然后统一加入到 bpe_tokens 中

bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))

bpe_tokens 的变化:

['æľ', 'ĭ', 'åı', 'ĭ'] # "朋友"
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į'] "朋友,"
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į', 'it'] # "朋友,it"
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į', 'it', "'s"] # "朋友,it's"
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į', 'it', "'s", 'Ġa'] # "朋友,it's a" 
...
['æľ', 'ĭ', 'åı', 'ĭ', 'ï', '¼', 'Į', 'it', "'s", 'Ġa', 'Ġgood', 'Ġday', '.'] # "朋友,it's a good day." 

这里可以看到,原本中文的 "朋" 字是有 3 个字节的,分别是 æ ľ ĭ,但经过 bbpe 分词后,前两个字节聚在了一起,和最后一个 byte 分开了,很可能是这两个前缀也常出现在别的中文或其他非 ascii 字节的表示中。

另外, æ 字符的 unicode 由于是大于 255 的,因此它实际由两个 byte 组成:

'æ'.encode("utf-8")
b'\xc3\xa6'

不能直接去用以上方式解码 bbpe 的结果序列,应该先 self.byte_decoder 来解码,将大于 255 的字符表示转换回到小于等于 255 的 byte 范围内才行。

2.2. 逆向过程

tokeninze 的逆过程是 convert_tokens_to_string, 它在 PreTrainedTokenizerBase 里定义了接口,GPT2Tokenizer 实现

tokenizer.convert_tokens_to_string(tokenizer.tokenize("朋友,it's a good day."))
"朋友,it's a good day."

逆转的过程是,先把 tokens 拼接起来再转成一个 bytearray 后再解码成 utf-8 格式。

def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string."""
    text = "".join(tokens)
    text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
    return text

3. encode 等其他常用接口

前文从 tokenize 函数作为入口理清了 bbpe 的推理过程,实际中更常用的是直接把句子直接转换成 inptut_is, 这只是在 tokneize 的结果上进行 index 查询,考虑到完整性,这里也进行补充。 包括 encode_plus, encode 和直接把 tokenizer 作为函数调用。

encode 只是返回 encode_plus 结果里 ["input_ids"] id 部分

def encode(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    #...
) -> List[int]:
    encoded_inputs = self.encode_plus(text,...)
    return encoded_inputs["input_ids"]

encode_plus 的核心是以下代码(在 PretrainedTokenizer._encode_plus 里具体实现)

tokens = self.tokenize(text, **kwargs)
return self.convert_tokens_to_ids(tokens)

convert_tokens_to_ids 的核心实现是 GPT2Tokenizer._convert_token_to_id:

def _convert_token_to_id(self, token):
    """Converts a token (str) in an id using the vocab."""
    return self.encoder.get(token, self.encoder.get(self.unk_token))

这就是一个查询 vocab.json 字典的过程

最后,如果直接调用 tokenizer 实例,效果和调用 encode_plus 类似,会返回 attention_mask ,但这些本文不再解析

tokenizer("朋友,it's a good day.")
{'input_ids': [17312, 233, 20998, 233, 171, 120, 234, 270, 338, 257, 922, 1110, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
radioLinkPopups

如对本文有任何疑问,欢迎通过 github issue 邮件 metaescape at foxmail dot com 进行反馈