深入理解BPE、WordPiece、Unigram分词算法

1. 前言

NLP任务中最重要的一个环节就是分词。分词器(Tokenizer)在整个任务流程中扮演的角色如下

即给定一段文本,分词器会将其分割成一个个token,这些token会根据vocab转化成对应的ID以作为模型的输入。

完整的分词流程包含以下四个步骤

  • 标准化阶段(Normalization):先将原始文本(raw text)做一个预处理,例如去掉Unicode字符的重音(é 变成 e),将所有字母全部转化成小写字母(E 变成 e)等。这一步可以理解为数据清洗;
  • 预分词阶段(Pre-tokenization):进行一遍粗略的分词,例如基于空格&标点的分词,即word-level。注意预分词结果的粒度必须大于最终分词结果的粒度;
  • 模型阶段(Model):在预分词的语料上进行训练(注意分词器的训练和模型的训练是两个概念!模型的训练是通过梯度下降来降低loss,具有随机性,而分词器的训练是一个统计过程,最终的结果是确定性的);
  • 后处理阶段(Postprocessor):添加一些特殊的token,例如BERT中的 [CLS][SEP]

分词的粒度主要有三种:word-level、char-level和subword-level,对于前两种粒度,则不需要经历第三阶段(模型阶段)。

本文将主要聚焦于第三阶段,但在此之前,先让我们回顾一下word-level和char-level分词。

1.1 word-level分词

word-level分词的一个直观出发点是基于空格分词,即对字符串调用 split() 方法

text = "Don't you love 🤗 Transformers? We sure do."
print(text.split())
# ["Don't", 'you', 'love', '🤗', 'Transformers?', 'We', 'sure', 'do.']

但这个结果不是最优的,因为 Transformers?do. 中都含有标点符号。如果允许标点符号紧跟在单词后面,那么当语料库足够大的时候,Transformers?Transformers!Transformers. 等都会纳入到词表当中,这显然不是我们希望看到的。

我们可以调用 tokenizers 库中的 Whitespace() 来实现基于空格&标点的分词

from tokenizers.pre_tokenizers import Whitespace

tokenizer = Whitespace()
print(tokenizer.pre_tokenize_str(text))
# [('Don', (0, 3)), ("'", (3, 4)), ('t', (4, 5)), ('you', (6, 9)), ('love', (10, 14)), ('🤗', (15, 16)), ('Transformers', (17, 29)), ('?', (29, 30)), ('We', (31, 33)), ('sure', (34, 38)), ('do', (39, 41)), ('.', (41, 42))]

📝 Whitespace() 内部使用的是正则表达式 \w+|[^\w\s]+

但这个结果仍不是最优的,因为 Don't 的意思是 Do not,所以应当具有 [Do, n't] 的分词结果,而不是 [Don', t]

要想达到合理的效果,就要使用基于规则的分词器了,例如 spaCy 或者 Moses

1.2 char-level分词

char-level分词非常简单,即把字符串中的每一个字符看作一个token

text = "Don't you love 🤗 Transformers? We sure do."
print(list(text))
# ['D', 'o', 'n', "'", 't', ' ', 'y', 'o', 'u', ' ', 'l', 'o', 'v', 'e', ' ', '🤗', ' ', 'T', 'r', 'a', 'n', 's', 'f', 'o', 'r', 'm', 'e', 'r', 's', '?', ' ', 'W', 'e', ' ', 's', 'u', 'r', 'e', ' ', 'd', 'o', '.']

1.3 为什么要有subword-level分词?

  • word-level分词很容易造成词表过大,且容易出现OOV问题(Transformer-XL使用word-level分词,词表大小为267735,而基于subword-level的词表大小通常不会超过50000);
  • 在word-level意义下,由look衍生出的looks、looking、looked等都会被添加进词表当中,但这些词的区别仅仅在于时态,所以我们更应当保存前缀 look 和时态 singed(假设 n n n 个词每个都有三种时态,那么后者相比前者能够节约 3 n − ( n + 3 ) = 2 n − 3 3n-(n+3)=2n-3 3n(n+3)=2n3 个token的空间);
  • 基于word-level的模型学到的old、older、oldest之间的关系无法泛化到smart、smarter、smartest之间;
  • char-level能够很好地解决OOV问题,但一个char能够表示的语义远没有一个word能够表示的语义丰富;
  • 按照char-level分词很容易导致sequence length过长。

subword-level分词粒度介于两者之间,能够较好地平衡OOV问题。

2. 子词分词

子词分词(Subword Tokenization)主要有以下三种:

  • Byte-Pair Encoding:GPT、GPT-2、RoBERTa、DeBERTa、BART等;
  • WordPiece:BERT、DistilBERT、MobileBERT等;
  • Unigram:T5、XLNet、ALBERT、mBART等。

2.1 BPE

2.1.1 学习流程

BPE既然属于subword-level分词,那自然就要经历预分词阶段,即需要先对原始文本进行一遍word-level分词。

为叙述简便起见,假设在经历了标准化、预分词阶段后我们的语料库只含以下 5 5 5 种单词(每个单词右侧的数字代表这个单词出现的频数)

("hug", 10), ("pug", 5), ("pun", 12), ("bun", 4), ("hugs", 5)

BPE首先建立一个基础词表(base vocabulary),它由构成上述所有单词的字符组成,即

["b", "g", "h", "n", "p", "s", "u"]

接下来,BPE将不断学习merge rule以从基础词表中选取两个元素并合并成一个新的元素然后添加到词表中(原先的两个元素不会删除),直到词表达到预先指定的大小为止

📝 每一条merge rule都对应了一个要添加到词表中的新元素,merge rules通常又简称为merges,所以最终词表大小=基础词表大小+merges的数量。基础词表的大小通常是确定的,所以唯一的超参数就是merges的数量。例如:

  • GPT-1:base vocab size:478,merges:40k,total vocab size:40478
  • GPT-2:base vocab size:256,merges:50k,special tokens:1,total vocab size:50257

基础词表是char-level的,说明BPE也是从char-level开始学习merges,即一开始合并的都是char,随着学习的进行,合并的就是subword了。所以,我们需要将之前的每个单词都划分成字符

("h" "u" "g", 10), ("p" "u" "g", 5), ("p" "u" "n", 12), ("b" "u" "n", 4), ("h" "u" "g" "s", 5)

现统计每个单词内的所有pair,并找到出现频数最高的那一个pair(这里的pair指的是相邻元素组成的对,例如 [a, bc, d, e] 含有 3 3 3 个pair:[(a, bc), (bc, d), (d, e)]),然后合并这个pair。

在上面的例子中,很明显 (u, g) 这个pair出现的频数最高,它一共出现了 20 20 20 次,所以我们学到的第一条merge rule就是 (u, g) -> ug,我们将 ug 添加到词表当中,然后将语料库中的所有 (u, g) 合并成 ug

Vocabulary: ["b", "g", "h", "n", "p", "s", "u", "ug"]
Merges: ("u" "g")
Corpus: ("h" "ug", 10), ("p" "ug", 5), ("p" "u" "n", 12), ("b" "u" "n", 4), ("h" "ug" "s", 5)

继续重复上述步骤,我们可以发现 (u, n) 这个pair出现的频率最高,高达 16 16 16 次,所以BPE学到的第二条merge rule是 (u, n) -> un,此时有

Vocabulary: ["b", "g", "h", "n", "p", "s", "u", "ug", "un"]
Merges: ("u" "g"), ("u" "n")
Corpus: ("h" "ug", 10), ("p" "ug", 5), ("p" "un", 12), ("b" "un", 4), ("h" "ug" "s", 5)

继续重复上述步骤,我们可以发现 (h, ug) 这个pair出现的频率最高,共出现 15 15 15 次,所以BPE学到的第三条merge rule是 (h, ug) -> hug,此时有

Vocabulary: ["b", "g", "h", "n", "p", "s", "u", "ug", "un", "hug"]
Merges: ("u" "g"), ("u" "n"), ("h" "ug")
Corpus: ("hug", 10), ("p" "ug", 5), ("p" "un", 12), ("b" "un", 4), ("hug" "s", 5)

假设我们预先希望学得 3 3 3 个merges,那么BPE的流程就到此结束了。

2.1.2 分词流程

我们已经知道BPE是怎样学习的了,那么给定一段文本,我们如何使用BPE对它进行分词呢?

假设文本是 str 型,自然地,我们需要先对它应用前两个流程:标准化和预分词,此时文本变成了 List[str] 型,其中的每一个元素都是一个word。之后,对于每一个word,我们先将它全部拆为字符,然后从前往后扫描之前学得的merges,能够应用一条就应用一条。当所有word都过完后,我们就得到了最终的分词结果。

这里仅以单个word为例。假设有一个word是 bug,我们共有三个merges

("u", "g") -> "ug"
("u", "n") -> "un"
("h", "ug") -> "hug"

首先将 bug 拆为字符:[b, u, g],然后从前往后扫描merges,发现第一条可以应用,应用之后得到 [b, ug],而第二、三条都不能应用,于是对于该word的分词结束。

对于 mug 而言,同样是先拆成字符 [m, u, g],注意到由于 m 不在词表当中,所以实际拆完后是 [[UNK], u, g],合并完后是 [[UNK], ug]

同理可得,对于 thug 而言,其分词结果是 [[UNK], hug]。对于 unhug,其分词结果是 [un, hug]

📝 可以看出,只要单词中有一个字符没有在基础词表中出现,那么关于该单词的分词结果就会出现 [UNK]。所以,如果要想保证对于一段文本的分词不会出现 [UNK],那么就需要保证该文本的预分词结果中的每一个字符都被包含在基础词表当中。

2.1.3 从零实现BPE

由于标准化和预分词这两个阶段并不在我们的考虑范围内,所以假设现在已经得到了预分词的结果 pre_tokenized_res,其类型是 List[str],其中的每一个元素都可以视为一个word。

根据2.1.1中的内容,我们首先需要统计出每个word出现的频数,然后对每个word进行char-level的划分以方便后续进行merge

def get_word_freqs(pre_tokenized_res: List[str]) -> Dict[str, int]:
    word_freqs = defaultdict(int)
    for word in pre_tokenized_res:
        word_freqs[word] += 1
    return word_freqs


def get_word_splits(word_freqs: Dict[str, int]) -> Dict[str, List[str]]:
    word_splits = {}
    for word in word_freqs.keys():
        word_splits[word] = list(word)
    return word_splits

我们还需要基础词表以及计算每个pair出现的频数

def get_base_vocab(word_splits: Dict[str, List[str]]) -> List[str]:
    vocab = []
    for split in word_splits.values():
        vocab.extend(split)
    vocab = sorted(list(set(vocab)))
    return vocab


def get_pair_freqs(
    word_freqs: Dict[str, int],
    word_splits: Dict[str, List[str]],
) -> Dict[Tuple[str, str], int]:
    pair_freqs = defaultdict(int)
    for freq, split in zip(word_freqs.values(), word_splits.values()):
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair_freqs[(split[i], split[i + 1])] += freq
    return pair_freqs

假设我们已经从 pair_freqs 找到了出现频数最多的那个 pair,除了将这个 pair 添加到 merges 之外,还需要对 word_splits 中的每一个这样的 pair 进行合并,同时再计算一次 pair_freqs

def merge_pair(pair, word_splits):
    for word in word_splits.keys():
        split = word_splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == pair[0] and split[i + 1] == pair[1]:
                split = split[:i] + [pair[0] + pair[1]] + split[i + 2:]
            i += 1
        word_splits[word] = split
    return word_splits

⚠️ 至于为什么用 while 而不用 for 请参考 Python中关于可变循环的一些坑

万事俱备,现在可以定义BPE的训练函数了

def bpe_train(num_merges, word_freqs, word_splits):
    merges = []
    vocab = get_base_vocab(word_splits)
    for _ in range(num_merges):
        pair_freqs = get_pair_freqs(word_freqs, word_splits)
        if not pair_freqs:
            break
        max_freq_pair = max(pair_freqs, key=pair_freqs.get)
        word_splits = merge_pair(max_freq_pair, word_splits)
        merges.append(max_freq_pair)
        vocab.append(''.join(max_freq_pair))
    return vocab, merges

我们可以拿WikiText-2的验证集来看一下BPE算法到底学到了什么,设置merges的数量为50,输出结果如下

from datasets import load_dataset
from tokenizers.pre_tokenizers import Whitespace

pre_tokenizer = Whitespace()

corpus = []
data = load_dataset("wikitext", "wikitext-2-v1")["validation"]
for line in data:
    tokenized_line = pre_tokenizer.pre_tokenize_str(line['text'])
    corpus.extend(list(map(lambda x: x[0], tokenized_line)))

word_freqs = get_word_freqs(corpus)
word_splits = get_word_splits(word_freqs)
vocab, merges = bpe_train(50, word_freqs, word_splits)
print(vocab)
# ['!', '"', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '~', '£', '°', '²', '½', 'É', 'Î', 'Ú', 'á', 'ç', 'é', 'ë', 'í', 'ü', 'ā', 'ō', 'š', 'α', 'β', 'γ', 'μ', '‑', '–', '—', '‘', '’', '“', '”', '′', '″', '⁄', '₤', '−', '♭', '♯', 'th', 'in', 'the', 'un', 'an', 'er', 'unk', 'on', 'ed', 'at', 're', 'en', 'or', 'st', 'and', 'of', 'al', 'ar', 'as', 'to', 'ing', 'es', 'it', 'is', 'ro', 'ic', 'he', 'ion', 'ou', 'il', 'le', 'ent', 'ac', 'ad', 'se', 'was', 'ur', 'for', 'The', 'be', 'ly', 'om', 'am', 'id', 'ig', 've', 'ch', 'lo', '@-', '@-@']
print(merges)
# [('t', 'h'), ('i', 'n'), ('th', 'e'), ('u', 'n'), ('a', 'n'), ('e', 'r'), ('un', 'k'), ('o', 'n'), ('e', 'd'), ('a', 't'), ('r', 'e'), ('e', 'n'), ('o', 'r'), ('s', 't'), ('an', 'd'), ('o', 'f'), ('a', 'l'), ('a', 'r'), ('a', 's'), ('t', 'o'), ('in', 'g'), ('e', 's'), ('i', 't'), ('i', 's'), ('r', 'o'), ('i', 'c'), ('h', 'e'), ('i', 'on'), ('o', 'u'), ('i', 'l'), ('l', 'e'), ('en', 't'), ('a', 'c'), ('a', 'd'), ('s', 'e'), ('w', 'as'), ('u', 'r'), ('f', 'or'), ('T', 'he'), ('b', 'e'), ('l', 'y'), ('o', 'm'), ('a', 'm'), ('i', 'd'), ('i', 'g'), ('v', 'e'), ('c', 'h'), ('l', 'o'), ('@', '-'), ('@-', '@')]

假设BPE已经训练完毕,给定一段文本,我们如何使用merges对它进行分词呢?

依然只考虑预分词后的文本,即 pre_tokenized_text,其中的每个元素都是一个 word,我们需要先对 word 划分成char-level,然后从前往后遍历merges并应用到 word 上。

先只考虑仅对一个 word 分词

def tokenize_word(word: str, vocab: List[str], merges: List[Tuple[str, str]]):
    get_pairs = lambda word: [(word[i], word[i + 1]) for i in range(len(word) - 1)]
    bpe_ranks = dict(zip(merges, range(len(merges))))  # 给每一条merge rule编号,方便之后从前往后查找

    if len(word) == 1:
        return word

    word = ['[UNK]' if char not in vocab else char for char in word]
    pairs = get_pairs(word)

    while True:
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))  # 因为要从前往后去应用,所以要找到下标最小的那个pair
        if bigram not in bpe_ranks:
            break

        i = 0
        while i < len(word) - 1:
            if word[i] == bigram[0] and word[i + 1] == bigram[1]:
                word = word[:i] + [bigram[0] + bigram[1]] + word[i + 2:]
            i += 1

        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    return word

效果

from functools import partial

vocab = ["b", "g", "h", "n", "p", "s", "u", "ug", "un", "hug"]
merges = [('u', 'g'), ('u', 'n'), ('h', 'ug')]
tokenize_word = partial(tokenize_word, vocab=vocab, merges=merges)

words = ["hug", "bug", "mug", "unhug"]
for word in words:
    print(f"{word}: {tokenize_word(word)}")
# hug: ['hug']
# bug: ['b', 'ug']
# mug: ['[UNK]', 'ug']
# unhug: ['un', 'hug']

接下来我们就可以对一段文本进行分词了

def tokenize(pre_tokenized_text: List[str], vocab, merges):
    res = []
    for word in pre_tokenized_text:
        res.extend(tokenize_word(word, vocab, merges))
    return res

⚠️ 注意,截止目前,我们的实现主要是为了清晰地展示流程,并不是最高效的实现方式,下文相同。

2.1.4 解码流程

BPE的分词流程实际上就是编码流程(将单词拆分为子词),那么解码流程就是将子词组合成单词。什么场景下才会有这一需要呢?

我们知道,模型接收的输入就是子词序列(通常由子词的ID来表示),输出自然也是子词序列,但子词序列的可读性不强,我们需要将其转化成单词序列才方便人类阅读。

给定子词序列

[a, b, c, d, e, f]

这里每个字母都代表一个子词,显然 a 肯定是第一个单词的开头,那么怎么确定第一个单词的结尾呢?如果无法确定哪个子词是第一个单词的结尾,自然就无法确定哪个子词是第二个单词的开头。

一个直观的做法是,在BPE训练之前在每个单词的末尾添加一个标记,例如 #。这样一来,当BPE训练结束后,如果一个子词的末尾含有 #,说明这个子词必定是某一个单词的结尾。

按照这种方式,最终会得到如下形式的子词序列

[a, b#, c, d, e#, f#]

很容易将其恢复成单词序列

[ab, cde, f]

在BPE的原论文当中,这个标记就是 </w>,注意到我们在2.1.3中的实现并没有考虑标记,感兴趣的读者可自行实现。

2.2 WordPiece

2.2.1 学习流程

WordPiece和BPE的流程几乎一致,主要的区别在于,BPE每次按照出现频数最高这一原则来选取pair,而WordPiece则是按照能够最大限度提升语言模型概率这一原则来选取pair。

具体来讲,假设句子 S S S n n n 个子词组成 t 1 , t 2 , ⋯   , t n t_1,t_2,\cdots,t_n t1,t2,,tn,且各个子词之间是相互独立的,则

log ⁡ P ( S ) = ∑ i = 1 n log ⁡ P ( t i ) \log P(S)=\sum_{i=1}^n \log P(t_i) logP(S)=i=1nlogP(ti)

如果选取 t i , t j t_i,t_j ti,tj 进行合并,并记合并后的子词为 t k t_k tk,那么 log ⁡ P ( S ) \log P(S) logP(S) 的变化值为

log ⁡ P ( t k ) − ( log ⁡ P ( t i ) + log ⁡ P ( t j ) ) = log ⁡ P ( t k ) P ( t i ) P ( t j ) \log P(t_k)-(\log P(t_i)+\log P(t_j))=\log\frac{P(t_k)}{P(t_i)P(t_j)} logP(tk)(logP(ti)+logP(tj))=logP(ti)P(tj)P(tk)

可以看出,变化值就是两个子词之间的互信息,我们通常用 n ( t i , t j ) / ( n ( t i ) ⋅ n ( t j ) ) n(t_i,t_j)/(n(t_i)\cdot n(t_j)) n(ti,tj)/(n(ti)n(tj)) 来估计 P ( t k ) / ( P ( t i ) P ( t j ) ) P(t_k)/(P(t_i)P(t_j)) P(tk)/(P(ti)P(tj))

除此之外,WordPiece还有一点与BPE不同,那就是一开始在对单词进行char-level划分的时候,会在每个非单词开头的字符前加上 ## 前缀,这个前缀是用于解码过程中的子词合并。在合并两个子词时,后一个子词的前缀需要去掉,例如,(##h, ##ug) -> ##hug(b, ##ug) -> bug

2.2.2 分词流程

BPE训练结束后会保存 vocabmerges,而WordPiece只保存 vocab,那WordPiece是如何对一段文本进行分词的呢?

回顾BPE,它首先将一个单词拆成字符,然后从前往后遍历 merges 直到不能再应用为止。WordPiece并不会将单词拆分成字符,而是首先寻找出现在 vocab 中的最长前缀,然后将单词一分为二,用数学公式描述就是

j = arg max ⁡ i ( word [ : i ]      in     vocab ) j=\argmax_i (\text{word}[:i] \;\;\text{in\; vocab} ) j=iargmax(word[:i]in vocab)

然后将 word 拆分成 word[:j]##word[j:],前者添加到最终结果中,后者继续应用上述操作。

BPE仅仅会把那些没有出现在词表中的字符标成 [UNK],而在WordPiece的分词过程中,只要有一次在词表中找不到最长前缀,那么整个单词就会被标记为 [UNK]

2.2.3 从零实现WordPiece

由于WordPiece和BPE大体相似,所以我们只需要改写一部分函数即可。

首先是 get_word_splits 函数,我们需要给每个非单词开头的字符加上 ## 前缀

def get_word_splits(word_freqs: Dict[str, int]) -> Dict[str, List[str]]:
    word_splits = {}
    for word in word_freqs.keys():
        word_splits[word] = [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
    return word_splits

BPE只需要计算 n ( t i , t j ) n(t_i,t_j) n(ti,tj),而WordPiece除了 n ( t i , t j ) n(t_i,t_j) n(ti,tj) 还需要计算 n ( t i ) n(t_i) n(ti)。我们将 get_pair_freqs 改写成 get_pair_scores

def get_pair_scores(
    word_freqs: Dict[str, int],
    word_splits: Dict[str, List[str]],
) -> Dict[Tuple[str, str], int]:
    pair_freqs = defaultdict(int)
    subword_freqs = defaultdict(int)

    for freq, split in zip(word_freqs.values(), word_splits.values()):
        if len(split) == 1:
            subword_freqs[split[0]] += freq
            continue
        for i in range(len(split) - 1):
            pair_freqs[(split[i], split[i + 1])] += freq
            subword_freqs[split[i]] += freq
        subword_freqs[split[-1]] += freq

    pair_scores = {
        pair: freq / (subword_freqs[pair[0]] * subword_freqs[pair[1]]) \
        for pair, freq in pair_freqs.items()
    }
    return pair_scores

然后是 merge_pair 函数,唯一需要注意的细节就是,当第二个子词含有 ## 前缀时,需要将其去掉

def merge_pair(pair, word_splits):
    a, b = pair
    for word in word_splits.keys():
        split = word_splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                new_word = a + b[2:] if b.startswith('##') else a + b
                split = split[:i] + [new_word] + split[i + 2:]
            i += 1
        word_splits[word] = split
    return word_splits

最后就是WordPiece的训练函数了。注意由于不需要保存 merges,因此训练停止的条件就是词表大小达到预先设定的值

def wordpiece_train(vocab_size, word_freqs, word_splits):
    vocab = get_base_vocab(word_splits)
    while len(vocab) < vocab_size:
        pair_scores = get_pair_scores(word_freqs, word_splits)
        if not pair_scores:
            break
        max_score_pair = max(pair_scores, key=pair_scores.get)
        word_splits = merge_pair(max_score_pair, word_splits)
        a, b = max_score_pair
        new_word = a + b[2:] if b.startswith('##') else a + b
        vocab.append(new_word)
    return vocab

依旧拿WikiText-2的验证集来看一下WordPiece算法都学到了什么,设置最终词表大小为300,输出结果如下

from datasets import load_dataset
from tokenizers.pre_tokenizers import Whitespace

pre_tokenizer = Whitespace()

corpus = []
data = load_dataset("wikitext", "wikitext-2-v1")["validation"]
for line in data:
    tokenized_line = pre_tokenizer.pre_tokenize_str(line['text'])
    corpus.extend(list(map(lambda x: x[0], tokenized_line)))

word_freqs = get_word_freqs(corpus)
word_splits = get_word_splits(word_freqs)
vocab = wordpiece_train(300, word_freqs, word_splits)
print(vocab[:20])
# ['!', '"', '##,', '##-', '##.', '##0', '##1', '##2', '##3', '##4', '##5', '##6', '##7', '##8', '##9', '##@', '##A', '##B', '##C', '##D']
print(vocab[-20:])
# ['XIV', 'UP', '##PI', 'UK', 'NFL', 'NHC', 'NCAA', 'NHS', 'NME', 'WWE', 'FBI', 'FIBA', 'FIFA', 'DVD', 'kW', 'HIV', 'HMCS', 'HBO', 'NBA', 'GBA']

相比BPE,WordPiece的分词实现较为简单,我们只需要改写 tokenize_word 函数即可

def tokenize_word(word, vocab):
    new_word = []
    while len(word) > 0:
        i = len(word)
        while i > 0 and word[:i] not in vocab:
            i -= 1
        if i == 0:
            return ['[UNK]']
        new_word.append(word[:i])
        word = word[i:]
        if len(word) > 0:
            word = f"##{word}"
    return new_word

继续使用WikiText-2的验证集并设置最终词表大小为3000,尝试对不同单词进行分词,效果如下

words = ['apple', 'occupied', 'upload', 'company']
for word in words:
    print(tokenize_word(word, vocab))
# ['app', '##l', '##e']
# ['occupi', '##e', '##d']
# ['up', '##l', '##o', '##a', '##d']
# ['comp', '##a', '##n', '##y']

2.2.4 解码流程

我们可以根据 ## 来解码,以下提供了一种可能的实现方案

def decode(output):
    res, cur = [], ""

    for i in range(len(output)):
        if not output[i].startswith('##'):
            res.append(cur)
            cur = output[i]
        else:
            cur += output[i][2:]
    res.append(cur)

    return res[1:]

效果

output = ['Th', '##i', '##s', 'is', 'th', '##e', 'Hugg', '##i', '##n', '##g', 'Fac', '##e', 'c', '##o', '##u', '##r', '##s', '##e']
print(decode(output))
# ['This', 'is', 'the', 'Hugging', 'Face', 'course']

2.3 Unigram

2.3.1 学习流程

BPE和WordPiece都是从一个小的基础词表开始不断去扩充这个词表,而Unigram则与之相反,Unigram会先初始化一个大的词表,然后不断从中删去子词直至词表达到指定的大小。

Unigram初始化词表的方式有很多种,例如我们可以在预分词的结果上应用BPE算法并设置较大的merges以获得初始词表,或者计算预分词结果的所有严格子串并从中选取一些出现频率最高的子串作为初始词表。

这里以后者为例(注意初始词表还需要包含所有的基础字符以防止OOV):

def init_vocab(word_freqs: Dict[str, int], init_vocab_size: int) -> Dict[str, int]:
    char_freqs = defaultdict(int)
    subword_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        for i in range(len(word)):
            char_freqs[word[i]] += freq
            # 子词的长度至少是2,并且word[i:j]是左闭右开,因此j最多要取到len(word)
            for j in range(i + 2, len(word) + 1):
                # 必须是严格子词,即不能包含自身
                if j - i < len(word):
                    subword_freqs[word[i:j]] += freq
    char_freqs = sorted(char_freqs.items())
    subword_freqs = sorted(subword_freqs.items(), key=lambda x: x[1], reverse=True)
    assert init_vocab_size > len(char_freqs)
    vocab_with_freqs = char_freqs + subword_freqs[:init_vocab_size - len(char_freqs)]
    return dict(vocab_with_freqs)

在得到了初始词表后,我们如何对它进行剪枝呢?这里有必要先了解一下Unigram的分词流程,读者可先跳转至2.3.2节。

现在假定你已经阅读完了2.3.2节。Unigram首先会计算整个语料库(预分词结果)上的loss,具体而言,设语料库中的所有单词为 w 1 , ⋯   , w N w_1,\cdots,w_N w1,,wN(可能会有重复),那么整个语料库的loss可以表示成

l o s s ( c o r p u s ) = ∑ i = 1 N l o s s ( w i ) = ∑ i = 1 K f i ⋅ l o s s ( w i ′ ) , word_freqs [ w i ′ ] = f i , { w i ′ } i = 1 K ⊂ { w i } i = 1 N , ∑ i = 1 K f i = N loss(corpus)=\sum_{i=1}^Nloss(w_i)=\sum_{i=1}^Kf_i\cdot loss(w'_i),\quad \text{word\_freqs}[w'_i]=f_i,\quad \{w'_i\}_{i=1}^K \subset \{w_i\}_{i=1}^N,\quad \sum_{i=1}^Kf_i=N loss(corpus)=i=1Nloss(wi)=i=1Kfiloss(wi),word_freqs[wi]=fi,{wi}i=1K{wi}i=1N,i=1Kfi=N

l o s s ( w i ′ ) loss(w'_i) loss(wi) 就是 w i ′ w'_i wi 经过Unigram分词后对应的loss,即

l o s s ( c o r p u s ) = ∑ i = 1 K word_freqs [ w i ′ ] ⋅ tokenize_word ( w i ′ ,    vocab_with_loss ) [ 1 ] loss(corpus)=\sum_{i=1}^K\text{word\_freqs}[w'_i]\cdot \text{tokenize\_word}(w'_i,\;\text{vocab\_with\_loss})[1] loss(corpus)=i=1Kword_freqs[wi]tokenize_word(wi,vocab_with_loss)[1]

接下来,Unigram会为vocab中的每个子词计算一个score,这个score等于从vocab中移除该子词后 l o s s ( c o r p u s ) loss(corpus) loss(corpus) 的变化值。可以证明,这个变化值一定是非负的,因此score越小说明这个子词在vocab中越不重要,故可以移除。

⚠️ 不少教程会把这个score称为loss。

将vocab中的子词按照score从小到大进行排序,并移除前 p % p\% p% 个子词(通常取 10 , 20 10,20 10,20),然后重新计算 vocab_with_loss 。反复执行上述操作直至vocab大小符合要求。

📝 关于变化值是非负的证明:

注意到 l o s s ( c o r p u s ) loss(corpus) loss(corpus) 仅与 l o s s ( w i ′ ) loss(w'_i) loss(wi) 有关,因此只需要研究后者的变化值。当移除了一些子词后,包含了这些子词的分词方案也就不再存在,相应的分词方案空间 S ( w i ′ ) \mathcal{S}(w'_i) S(wi) 就会缩小,而 l o s s ( w i ′ ) loss(w'_i) loss(wi) 是基于 S ( w i ′ ) \mathcal{S}(w'_i) S(wi) 计算出的最小值,因此这个最小值会增大。

类似于: min ⁡ S ′ ≥ min ⁡ S \min S'\geq \min S minSminS,其中 S ′ ⊆ S S'\subseteq S SS

2.3.2 分词流程

Unigram分词,顾名思义,它假设各个子词之间相互独立,因此对于一个单词 w w w 而言,将其拆分成子词序列 w = ( w 1 , ⋯   , w m ) w=(w_1,\cdots,w_m) w=(w1,,wm) 后有

P ( w ) = ∏ i = 1 m P ( w i ) (1) P(w)=\prod_{i=1}^m P(w_i)\tag{1} P(w)=i=1mP(wi)(1)

其中 P ( w i ) P(w_i) P(wi) 可以通过 w i w_i wi 出现的频率来估计(注意不是频数)。

S ( w ) \mathcal{S}(w) S(w) 为单词 w w w 的所有分词方案,从而Unigram的分词结果为

w ∗ = arg max ⁡ w ′ ∈ S ( w ) P ( w ′ ) (2) w^*=\argmax_{w'\in\mathcal{S}(w)} P(w')\tag{2} w=wS(w)argmaxP(w)(2)

w w w 的长度为 n n n,若限定 m > 1 m>1 m>1(即必须对 w w w 进行切分),则可知 ∣ S ∣ = 2 n − 1 − 1 |\mathcal{S}|=2^{n-1}-1 S=2n11(使用隔板法,共有 n − 1 n-1 n1 个空,可放 1 , 2 , ⋯   , n − 1 1,2,\cdots,n-1 1,2,,n1 个隔板),因此暴力求解 w ∗ w^* w 的时间复杂度是 O ( 2 n ) O(2^n) O(2n),这显然是不可行的。

从数值稳定性的角度来讲,计算 P ( w i ) P(w_i) P(wi) 的乘积没有计算 log ⁡ P ( w i ) \log P(w_i) logP(wi) 的和稳定,因此我们可以对 ( 1 ) (1) (1) 式取负对数,并记 l o s s ( w i ) = − log ⁡ P ( w i ) loss(w_i)=-\log P(w_i) loss(wi)=logP(wi),于是就有

l o s s ( w ) = ∑ i = 1 m l o s s ( w i ) (3) loss(w)=\sum_{i=1}^m loss(w_i)\tag{3} loss(w)=i=1mloss(wi)(3)

从而 ( 2 ) (2) (2) 式变成 w ∗ = arg min ⁡ w ′ ∈ S ( w ) l o s s ( w ′ ) w^*=\argmin_{w'\in\mathcal{S}(w)} loss(w') w=argminwS(w)loss(w)

回到求解 w ∗ w^* w 的问题上,我们可以使用动态规划的方法将时间复杂度降低至 O ( n 3 ) O(n^3) O(n3)。具体来讲,记 d p [ j ] dp[j] dp[j] 为对 w [ : j ] w[:j] w[:j] 分词得到的最小loss,即 d p [ j ] = l o s s ( w [ : j ] ∗ ) dp[j]=loss(w[:j]^*) dp[j]=loss(w[:j]),从而 d p [ len ( w ) ] = l o s s ( w [ : len ( w ) ] ∗ ) = l o s s ( w ∗ ) dp[\text{len}(w)]=loss(w[:\text{len}(w)]^*)=loss(w^*) dp[len(w)]=loss(w[:len(w)])=loss(w) 就是最终答案。

⚠️ 这里的 w[:j] 遵循Python的切片,即左闭右开,取不到 j

接下来看转移方程。对 w [ : j ] w[:j] w[:j] 进行分词可以得到若干个子词,这里我们只关注最后一个子词,不妨记为 t o k e n token token,如下:

t o k e n token token 的左端点 i i i 的取值范围是 [ 0 , j ) [0,j) [0,j),由 ( 3 ) (3) (3) 式,我们有

d p [ j ] = min ⁡ 0 ≤ i < j ( d p [ i ] + l o s s ( t o k e n ) ) = min ⁡ 0 ≤ i < j ( d p [ i ] + l o s s ( w [ i : j ] ) ) (4) dp[j]=\min_{0\leq i<j}( dp[i]+loss(token))=\min_{0\leq i<j}(dp[i]+loss(w[i:j]))\tag{4} dp[j]=0i<jmin(dp[i]+loss(token))=0i<jmin(dp[i]+loss(w[i:j]))(4)

显然 d p [ 1 ] dp[1] dp[1] 就是 l o s s ( w [ 0 ] ) loss(w[0]) loss(w[0]),而由 ( 4 ) (4) (4) 式, d p [ 1 ] = d p [ 0 ] + l o s s ( w [ 0 ] ) dp[1]=dp[0]+loss(w[0]) dp[1]=dp[0]+loss(w[0]),故得边界条件 d p [ 0 ] = 0 dp[0]=0 dp[0]=0

注意,在更新 d p [ j ] dp[j] dp[j] 的时候,我们还需要保存相应的 i i i,这样一来,当 d p dp dp 数组计算完后,我们就可以从 j = len ( w ) j=\text{len}(w) j=len(w) 开始不断向前回溯以获得 w ∗ w^* w

为了实现该算法,我们需要先将 vocab_with_freqs 处理成 vocab_with_loss

import math

def process_vocab(vocab_with_freqs: Dict[str, int]) -> Dict[str, float]:
    total_sum = sum(vocab_with_freqs.values())
    vocab_with_loss = {token: -math.log(freq / total_sum) for token, freq in vocab_with_freqs.items()}
    return vocab_with_loss

接下来实现Unigram分词

def tokenize_word(word: str, vocab_with_loss: Dict[str, float]):
    dp = [{'start': None, 'loss': None} for _ in range(len(word) + 1)]
    dp[0] = {'start': -1, 'loss': 0}

    for i in range(len(word)):
        for j in range(i + 1, len(word) + 1):
            token = word[i:j]
            if token in vocab_with_loss and dp[i]['loss'] is not None:
                new_loss = dp[i]['loss'] + vocab_with_loss[token]
                if dp[j]['loss'] is None or new_loss < dp[j]['loss']:
                    dp[j] = {'start': i, 'loss': new_loss}

    word_loss = dp[-1]['loss']
    if word_loss is None:
        return ['[UNK]'], word_loss

    start, end = dp[-1]['start'], len(word)
    res = []
    while ~start:
        res.append(word[start:end])
        end = start
        start = dp[start]['start']

    res.reverse()

    return res, word_loss

该动态规划算法又称Viterbi算法,时间复杂度为 O ( n 3 ) O(n^3) O(n3)。具体推导如下:由于切片的时间复杂度为 O ( j − i ) O(j-i) O(ji),所以总时间复杂度为

∑ i = 0 n − 1 ∑ j = i + 1 n O ( j − i ) = ∑ i = 0 n − 1 O ( ( 1 + n − i ) ( n − i ) 2 ) = ∑ i = 1 n O ( i ( i + 1 ) 2 ) = O ( ∑ i = 1 n i 2 ) = O ( n ( n + 1 ) ( 2 n + 1 ) 6 ) = O ( n 3 ) \begin{aligned} \sum_{i=0}^{n-1}\sum_{j=i+1}^nO(j-i)&=\sum_{i=0}^{n-1}O\left(\frac{(1+n-i)(n-i)}{2}\right)=\sum_{i=1}^nO\left(\frac{i(i+1)}{2}\right) \\ &=O\left(\sum_{i=1}^n i^2\right)=O\left(\frac{n(n+1)(2n+1)}{6}\right) \\ &=O(n^3) \end{aligned} i=0n1j=i+1nO(ji)=i=0n1O(2(1+ni)(ni))=i=1nO(2i(i+1))=O(i=1ni2)=O(6n(n+1)(2n+1))=O(n3)

2.3.3 从零实现Unigram

我们在前两小节中已经实现了一些函数,接下来只需完成剩下的部分。

计算整个语料库上的损失:

def corpus_loss(word_freqs, vocab_with_loss):
    loss = 0
    for word, freq in word_freqs.items():
        loss += freq * tokenize_word(word, vocab_with_loss)[1]
    return loss

为vocab中的每个子词计算相应的score:

def compute_scores(word_freqs, vocab_with_loss) -> List[str]:
    scores = {}
    pre_loss = corpus_loss(word_freqs, vocab_with_loss)
    for token in vocab_with_loss:
        if len(token) == 1: continue
        vocab_with_loss_ = copy.deepcopy(vocab_with_loss)
        vocab_with_loss_.pop(token)
        cur_loss = corpus_loss(word_freqs, vocab_with_loss_)
        scores[token] = cur_loss - pre_loss
    return sorted(scores, key=lambda x: x[1])

定义Unigram的训练函数:

def unigram_train(word_freqs, init_vocab_size, vocab_size, p):
    vocab_with_freqs = init_vocab(word_freqs, init_vocab_size)
    vocab_with_loss = process_vocab(vocab_with_freqs)
    while len(vocab_with_loss) > vocab_size:
        scores = compute_scores(word_freqs, vocab_with_loss)
        num_to_remove = min(int(len(vocab_with_loss) * p), len(scores))
        if num_to_remove == 0:
            break
        for i in range(num_to_remove):
            vocab_with_freqs.pop(scores[i])
        vocab_with_loss = process_vocab(vocab_with_freqs)
    return vocab_with_loss

同样拿WikiText-2的验证集来看一下Unigram算法都学到了什么,设置初始词表大小为500,最终词表大小为200,丢弃率为0.1,输出结果如下

pre_tokenizer = Whitespace()

corpus = []
data = load_dataset("wikitext", "wikitext-2-v1")["validation"]
for line in data:
    tokenized_line = pre_tokenizer.pre_tokenize_str(line['text'])
    corpus.extend(list(map(lambda x: x[0], tokenized_line)))

word_freqs = get_word_freqs(corpus)
vocab_with_loss = unigram_train(word_freqs, 500, 200, 0.1)
vocab_with_loss = [(k, v) for k, v in vocab_with_loss.items()]
print(vocab_with_loss[:20])
# [('!', 9.832349524702789), ('"', 5.970450906457211), ('$', 9.517268478062894), ('%', 9.646946301371425), ('&', 10.77681113354364), ("'", 6.286570914229525), ('(', 6.786064971985966), (')', 6.787827086979365), ('*', 13.128186390707118), ('+', 12.435039210147172), (',', 4.56506426740248), ('-', 6.273304601332049), ('.', 4.7355365343654805), ('/', 9.001052005662027), ('0', 5.683060932736877), ('1', 5.607139156414499), ('2', 6.043540944928233), ('3', 6.869561406868152), ('4', 6.992621499625379), ('5', 6.809218276960683)]
print(vocab_with_loss[-20:])
# [('pt', 7.320043900726674), ('ated', 7.327579731415376), ('ound', 7.362995287922273), ('ite', 7.38678705247961), ('St', 7.388393478527883), ('ese', 7.396464547361675), ('Au', 7.414453585197748), ('oug', 7.427742817316431), ('stra', 7.432772165721433), ('qu', 7.434454251904418), ('rth', 7.444606623368436), ('ctio', 7.44801378169005), ('ction', 7.449721724035206), ('ass', 7.451432588438836), ('ave', 7.451432588438836), ('ough', 7.458305467726598), ('au', 7.475697210438467), ('nter', 7.502365457520629), ('ish', 7.511415293040546), ('oth', 7.515058284319047)]

分词效果

words = ['apple', 'ground', 'finish']
for word in words:
    print(tokenize_word(word, vocab_with_loss))
# (['a', 'p', 'p', 'l', 'e'], 17.16241864337219)
# (['g', 'r', 'ound'], 14.555457675482492)
# (['f', 'i', 'n', 'ish'], 17.145052257815873)

因为词表较小,所以分词效果并不理想。

2.3.4 解码流程

Unigram通常与SentencePiece结合在一起使用。SentencePiece会将文本中的空格替换成 (U+2581),解码时只需要执行如下操作即可

detokenized = ''.join(pieces).replace('▁', ' ')

3. 三种分词方法的比较

分词方法BPEWordPieceUnigram
训练从小的基础词表开始不断学习规则以扩充词表和BPE相同从大的词表开始不断学习规则以缩减词表
规则合并出现频率最高的pair合并互信息最大的pair移除使得语料库的loss增加最小的子词
结果词表+合并规则仅有词表带有loss的词表
编码先将单词拆成字符,然后不断应用合并规则在词表中找到单词的最长前缀,对单词的剩余部分递归执行上述操作选取损失最小的分词方案

Ref

[1] https://huggingface.co/docs/transformers/tokenizer_summary
[2] https://paddlepedia.readthedocs.io/en/latest/tutorials/pretrain_model/subword.html
[3] https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt
[4] https://zhuanlan.zhihu.com/p/86965595
[5] https://wmathor.com/index.php/archives/1517/
[6] https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/tokenization_gpt2.py#L104