Trie 树
本文是对过去做过的 leetcode 题里与 trie 树相关的问题的整理,加上了一些图片和文字说明。
Trie 的读音和 try 类似,取自 retrieval 单词,因此这个结构的主要作用是提供便于检索的接口,而检索的内容和字符串或者序列的前缀有密切关联。
集合性质的 Trie: 序列的存在性检查
Trie 以及基本接口
208. 实现 Trie (前缀树) - 力扣(LeetCode) 题目如下:
Trie(发音类似 "try")或者说 前缀树 是一种树形数据结构,用于高效地存储和检索字符串数据集中的键。这一数据结构有相当多的应用情景,例如自动补完和拼写检查。 请你实现 Trie 类: - Trie() 初始化前缀树对象。 - void insert(String word) 向前缀树中插入字符串 word 。 - boolean search(String word) 如果字符串 word 在前缀树中,返回 true(即,在检索之前已经插入);否则,返回 false 。 - boolean startsWith(String prefix) 如果之前已经插入的字符串 word 的前缀之一为 prefix ,返回 true ;否则,返回 false 。
首先根据题目描述,绘制出插入几个单词到 trie 树后的示意图,如下图 (a),可以看到,这是一种多叉树,每个结点的分叉数量还可能不一样,因此可以用一个字典来表示树的分叉。 其次,插入元素和搜索过程都是"垂直的", 每次插入之后都要继续向树的下一层移动,然后递归地把剩余元素插入,图(b) 突出了插入的每个单词最后一个字符,在 trie 树中,任何一个结点都可以对应一条从根结点到该结点的路径,而任何一条这样的路径都是一个前缀。
python 的初步实现
相比于递归实现,我们先考虑更容易一点的循环实现:
class Trie:
def __init__(self):
self.branch = {}
def insert(self, word: str) -> None:
cur_branch = self.branch
for char in word:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
def search(self, word: str) -> bool:
cur_branch = self.branch
for char in word:
if char not in cur_branch:
return False
cur_branch = cur_branch[char]
return cur_branch == {}
def startsWith(self, prefix: str) -> bool:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
return False
cur_branch = cur_branch[char]
return True
以上算法设计是有问题的,因为无法考虑某个前缀是否又是一个完整的词,比如先插入了 apple, 再插入 app 后,接着调用 search(app) 返回是 False, 因为以上算法把 app 看作是 prefix 而不是路径, 但实际应该返回 True 。 这里体现的是,搜索某个完整的词不等于搜索 Trie 中的某条从根结点到叶子节点的路径,中间结点也有可能是一个完整的词的终点。以上算法构造出的实际是上图 (a) 中的树,它无法区分 word 和 prefix ,也许有些场景是只有 prefix 搜索要求的,但本题中要的是上图 (b) 所示的树,任何插入的单词的结束位置都应该带有标识,一般称为哨兵( sentinel )。
因此我们在单词结束的 branch 字典里添加一个特殊的结束标识,扩展成以下代码:
class Trie:
def __init__(self):
self.branch = {}
def insert(self, word: str) -> None:
cur_branch = self.branch
for char in word:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
cur_branch["end"] = True
def search(self, word: str) -> bool:
cur_branch = self.branch
for char in word:
if char not in cur_branch:
return False
cur_branch = cur_branch[char]
return cur_branch.get("end", False)
def startsWith(self, prefix: str) -> bool:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
return False
cur_branch = cur_branch[char]
return True
以上实现的 trie 和图 (b) 也是有差异的,它更接近以下的图 (c):
经典应用场景:拼写检查
以上 Trie 树是集合性质的,这意味着可以向 trie 中添加元素,也可以判断该元素是否在 trie 中(额外可以区分是以 prefix 还是完整单词的形式存在),它的一个应用是拼写检查,比如假设 valid_words 表示所有拼写正确的英文单词,以下用几个词作为示例:
valid_words = ["this", "apple", "is", "a", "app", "apply", "big", "bath", "batman"]
trie = Trie()
for word in valid_words:
trie.insert(word)
那么当用户在编辑器里输入了一段话,就可以把这段话拆分成单词进行逐一检查,并且用简单的可视化显示出拼写错误的词:
editor_content = "this is a bid apple"
def spell_check(content):
result = []
for word in content.split():
if trie.search(word):
result.append(word)
else:
result.append(f"<{word}>")
return " ".join(result)
spell_check(editor_content)
this is a <bid> apple
这里 bid 不在合法词库里,因此被括号标注出来。
注意,该场景下完全可以用集合来替代 trie, 把所有合法单词放在一个 set 中,然后将
if trie.search(word):
改为
if word in valid_words_set:
以上 demo 的使用场景下,trie 并没有特别优势,甚至还有劣势,因为每个结点要用一个字典表示,而这仅仅是保存一个英文字母,比如一个极端的例子: 对于字符 "abc" 用 trie 来表示,是如下结构
{"a", {"b": {"c": {}}}}
用 set 则是
{"abc"}
前者比后者占据更多空间,当字典数量更大的时候,空间差距会减小,但大多数情况 trie 还是比 set 占用更多空间(参考 Trie Data Structure | Interview Cake)。
但 trie 的好处在于,如果要返回某个前缀下的所有单词,或者对错误的单词依据前缀匹配进行后选词推荐,效率要比 set 或者 hash 表高很多。对于后者,要找出拥有公共前缀单词必须对字典里的所有单词进行遍历比较,非常低效。 在下一章 "字典性质的 Trie" 中会介绍到 trie 在序列补全或推荐上的应用。
[优化]:代码结构上的抽象
以上实现的 Trie 对象中,三个类方法非常相似,尤其是 search 和 startsWith, 除了最后一行,其他代码都是一样的,因此可以对其进行一次抽象,构造一个中间函数,该函数是返回某个 prefix 在 trie 树上进行搜索后返回的子树(branch)
class Trie:
def __init__(self):
self.branch = {}
def insert(self, word: str) -> None:
cur_branch = self.branch
for char in word:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
cur_branch["end"] = True
def find_prefix(self, prefix: str) -> dict:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
return {}
cur_branch = cur_branch[char]
return cur_branch
def search(self, word: str) -> bool:
return self.find_prefix(word).get("end", False)
def startsWith(self, prefix: str) -> bool:
return self.find_prefix(prefix) != {}
以上 startsWith 的实现之所以是合理的,是因为任何在 trie 中的 prefix 返回的剩余 子树 branch,要么其中一定是有 "end" 入口的,要么是还有其他字符,因此通过是否为空就可以判断是否为 prefix.
接着继续尝试抽象出 insert 和 find_prefix 的共同代码结构,两个函数前 5 行代码中,除了当 char 不在 branch 条件下语句不同,其他都是一样的。除此之外,insert 函数只是对 find_prefix 返回的结果字典中添加 end 标签。因此抽象的核心关注点还是 if 语句。一种直接的写法就是再加入一个条件判断:
class Trie:
def __init__(self):
self.branch = {}
def find_prefix(self, prefix: str, insertion = False) -> dict:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
if insertion:
cur_branch[char] = {}
else:
return {}
cur_branch = cur_branch[char]
return cur_branch
def insert(self, word: str) -> None:
self.find_prefix(word, True)["end"] = True
def search(self, word: str) -> bool:
return self.find_prefix(word).get("end", False)
def startsWith(self, prefix: str) -> bool:
return self.find_prefix(prefix) != {}
这样通过一个 find_prefix 就统一了三个函数,但这种写法会使得可读性变差。
[优化]:用字母数组表示树节点
我们继续针对 leetcode 题目场景进行优化,由于插入的元素都是英文单词,其中不允许那些包含标点的单词,比如 don't,因此可以用一个长度为 26 的数组来表示 a 到 z, 数组中记录的是下一个节点的地址,而由于数组无法像字典那样灵活地加入 "end" 哨兵,这时就要引入一个数据结构来封装 trie 的结点,该结点中包括一个数组和一个结束标识符,大致如以下形式:
structure Node children: Node[Alphabet-Size] end: Boolean
其图示如下,这里黄色数组元素表示该节点还连接了一个哨兵 node, 其中的 end 为 true:
这种实现的扩展性很好,因为可以往 Node 中加入更多 value 值,用于存储额外信息,这在下一章会有更详细说明
python 实现为:
class Node:
def __init__(self):
self.end = False
self.children = [None for i in range(26)]
def __contains__(self, char):
if not char.isalpha() or len(char) != 1:
return False
return self.children[index(char)] is not None
def __getitem__(self, char):
if char not in self:
raise KeyError(f"{char} is not in Trie")
return self.children[index(char)]
def __setitem__(self, char, value):
if not char.isalpha() or len(char) != 1 or not isinstance(value, Node):
raise KeyError("Key must be a single alphabetic character and value must be a Node instance")
self.children[index(char)] = value
@property
def empty(self):
return not self.end and all(child is None for child in self.children)
index = lambda char: ord(char.lower()) - ord('a')
以上实现许多下划线函数来支持 in, node[x] 取值和赋值的语法糖,这样可以使得以下算法实现和之前字典表示 branch 的实现的差异变得很小。另外我们用 emtpy 来区分这是一个空的节点,这种节点没有对应的前缀序列, 相当于之前实现中的空字典。
class Trie:
def __init__(self):
self.branch = Node()
def find_prefix(self, prefix: str, insertion = False) -> dict:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
if insertion:
cur_branch[char] = Node()
else:
return Node()
cur_branch = cur_branch[char]
return cur_branch
def insert(self, word: str) -> None:
self.find_prefix(word, True).end = True
def search(self, word: str) -> bool:
return self.find_prefix(word).end
def startsWith(self, prefix: str) -> bool:
return not self.find_prefix(prefix).empty
t = Trie()
t.startsWith("a")
False
这种优化在实际应用中也有使用,比如如果输入字符中包括更多英文标点甚至控制符(比如按键序列), 可以用一个 256 大小的数组来表示 branch.
[优化]:二维数组实现
用长度为 26 的数组表示了 branch/children 之后,我们还可以更进一步,把整个 Trie 用数组来表示,什么意思呢?以下是前文中 trie 节点的数据结构描述,其中 children 数组中每个都是一个 Node 元素。这里的思路就是,如果每个 node 的核心已经是一个 children 数组了,那么我们把这些数组继续放在一个数组中,那么每个 children 就有一个数字下标了,这样 children 里的元素也就是可以是 int 了。
structure Node children: Node[Alphabet-Size] end: Boolean
注意区分在 children 中保存 Node 和保存下标的区别:
- 保存 Node 对象实际保存的是地址,但因为 python 的设计,我们不需要手动维护这个地址,只需要用 children["a"]= Node() 这样的语句即可,但 python 解释器在内存某个地址创建了 Node() 对象并且把地址值赋值给了 children["a"]
- 保存数组下标的话,相当于自己去维护一个"虚拟内存"空间,因此当我们写 children["a"] = 10 的时候,我们要清楚 10 这个位置对应的是另外的哪个 children 。由于本题场景中不考虑删除,也就是说"虚拟内存" 是不断随着 insert 调用而单调非减的,于是可以按顺序一个个扩充这个二维数组,而如果不是 python 语言,比如 C 语言,没有内置的动态 list, 就需要先预计最多会消耗多少个字符,比如 20000,然后初始一个 20000x26 的二维数组。
另外,我们为了记录当前节点是否是 end, 可以用额外的数组或者把 branch 从 26 改成 27 个元素来额外记录。 不过意下直接用一个额外的 set 来保存 end 节点下标。
index = lambda char: ord(char.lower()) - ord('a')
SIZE = 26
class Trie:
def __init__(self):
self.all_branch = [[0] * SIZE]
self.ends = set()
def find_prefix(self, prefix: str, insertion = False) -> dict:
cur_branch = self.all_branch[0]
for char in prefix:
if cur_branch[index(char)] == 0:
if insertion:
cur_branch[index(char)] = len(self.all_branch)
self.all_branch.append([0] * SIZE)
else:
return 0
idx = cur_branch[index(char)]
cur_branch = self.all_branch[idx]
return idx
def insert(self, word: str) -> None:
self.ends.add(self.find_prefix(word, True))
def search(self, word: str) -> bool:
return self.find_prefix(word) in self.ends
def startsWith(self, prefix: str) -> bool:
return self.find_prefix(prefix) != 0
这种实现的图示如下:
注意,这个结果和插入元素的顺序有关,它对应的是以下词汇插入之后的情况:
trie = Trie()
for word in ["apple", "apply", "app", "and"]:
trie.insert(word)
from pprint import pprint
pprint(trie.all_branch)
pprint(trie.ends)
[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] {8, 3, 5, 6}
字典性质的 Trie: 路径和路径属性作为查询对象
leetcode 问题引入
设计一个 map ,满足以下几点: - 字符串表示键,整数表示值 - 返回具有前缀等于给定字符串的键的值的总和 实现一个 MapSum 类: - MapSum() 初始化 MapSum 对象 - void insert(String key, int val) 插入 key-val 键值对,字符串表示键 key ,整数表示值 val 。如果键 key 已经存在,那么原来的键值对 key-value 将被替代成新的键值对。 - int sum(string prefix) 返回所有以该前缀 prefix 开头的键 key 的值的总和。
这个题目需要用 dfs 对子树进行遍历,去各个树梢上"摘取果子" 并汇总
class MapSum:
def __init__(self):
self.branch = {}
def insert(self, key: str, val: int) -> None:
cur_branch = self.branch
for char in key:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
cur_branch["val"] = val
def sum(self, prefix: str) -> int:
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
return 0
cur_branch = cur_branch[char]
return dfs(cur_branch)
def dfs(branch_or_val):
if branch_or_val == {}:
return 0
if type(branch_or_val) != dict:
return branch_or_val
return sum([dfs(b) for b in branch_or_val.values()])
[扩展场景]:词根替换
648. 单词替换 - 力扣(LeetCode) 题目如下:
在英语中,我们有一个叫做 词根(root) 的概念,可以词根后面添加其他一些词组成另一个较长的单词——我们称这个词为 继承词(successor)。例如,词根an,跟随着单词 other(其他),可以形成新的单词 another(另一个)。 现在,给定一个由许多词根组成的词典 dictionary 和一个用空格分隔单词形成的句子 sentence。你需要将句子中的所有继承词用词根替换掉。如果继承词有许多可以形成它的词根,则用最短的词根替换它。 你需要输出替换之后的句子。 示例 1: 输入:dictionary = ["cat","bat","rat"], sentence = "the cattle was rattled by the battery" 输出:"the cat was rat by the bat"
可以用 trie 树来实现,这里是要根据字符串来返回最短的前缀
class Trie:
def __init__(self):
self.branch = {}
def insert(self, prefix):
cur_branch = self.branch
for char in prefix:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
cur_branch["end"] = True
def find_shortest_prefix(self, word):
prefix = []
cur_branch = self.branch
for char in word:
if char in cur_branch:
prefix.append(char)
cur_branch = cur_branch[char]
if cur_branch.get("end", False):
return "".join(prefix)
else:
break
return word
class Solution:
def replaceWords(self, dictionary: List[str], sentence: str) -> str:
prefix_tree = Trie()
for root in dictionary:
prefix_tree.insert(root)
res = []
for word in sentence.split():
res.append(prefix_tree.find_shortest_prefix(word))
return " ".join(res)
不过由于最多 100 个词根, 数据量不是很大,因此对 dictionary 按长度进行排序, 然后逐一去调用 statswith 判断也是可行的。
这个例子和词干提取(Stemming)的差别:
词干提取(Stemming) 词干提取是自然语言处理中的一个过程,旨在从单词中移除词缀(如后缀),以得到单词的基本形式或“词干”。例如:
"fishing", "fished", "fisher" → "fish" "argue", "argued", "argues", "arguing" → "argu" 词干提取通常使用一组规则来剥离词缀。这可能导致提取的“词干”并不是实际的单词或词根。例如,使用 Porter 词干算法,“happiness”会被剥离为“happi”。
- 词干提取(Stemming):通常使用固定的规则来剥离词缀,可能得到的不是实际的单词或词根。
- 您的算法(词根替换):使用一个词根列表来替换句子中的单词,如果单词以列表中的某个词根开始。
经典场景:自动补全和排序
1268. 搜索推荐系统 - 力扣(LeetCode) 题目是一个输入补全功能的核心实现, 要求对每个输入的前缀都返回推荐的前三个完整单词,"暴力" 的做法是每匹配一个字符后马上做一次 dfs 搜索,搜集到子树中所有的路径,然后取路径里 top3 返回:
class Trie:
def __init__(self, word_list):
self.branch = {}
for word in word_list:
self.insert(word)
def insert(self, word):
cur_branch = self.branch
for char in word:
if char not in cur_branch:
cur_branch[char] = {}
cur_branch = cur_branch[char]
cur_branch["end"] = True
def find_top3(self, prefix):
cur_branch = self.branch
res = []
for i, char in enumerate(prefix):
if char not in cur_branch:
for j in range(i, len(prefix)):
res.append([])
break
else:
all_suffix = dfs(cur_branch[char])
all_words = [prefix[:i+1]+suffix for suffix in all_suffix]
res.append(sorted(all_words)[:3])
cur_branch = cur_branch[char]
return res
def dfs(branch_or_end):
if branch_or_end == True:
return [""]
res = []
for key in branch_or_end:
suffixes = dfs(branch_or_end[key])
p = key if key != "end" else ""
res.extend([p + suffix for suffix in suffixes])
return res
class Solution:
def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
t = Trie(products)
return t.find_top3(searchWord)
这种做法有大量的重复工作,比如输入 "ab" 后调用 dfs 搜索的路径实际再输入前缀 "a" 时已经搜索过了。 一种优化是,对 branch_or_end 做一个 memo cache, 但由于它是字典类型,无法作为 cache 的 key, 用 Node 对象来表示 Trie 节点的话看上去是可以的。
对于本题来说,可以在插入的时候就进行类似缓存的操作, 在 Trie 节点里直接构造一个 top3 变量,保存在该前缀下的前三个字典序最小的单词,
class TrieNode:
def __init__(self):
self.branch = {}
self.top3 = []
class Trie:
def __init__(self, words):
self.root = TrieNode()
for word in words:
self.insert(word)
def insert(self, word):
cur_branch = self.root
for char in word:
if char not in cur_branch.branch:
cur_branch.branch[char]= TrieNode()
cur_branch.top3.append(word)
cur_branch.top3.sort()
if len(cur_branch.top3) > 3:
cur_branch.top3.pop()
cur_branch = cur_branch.branch[char]
class Solution:
def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
trie = Trie(products)
res = []
fake = TrieNode()
cur_branch = trie.root
for i, char in enumerate(searchWord):
top3 = []
if char not in cur_branch.branch:
cur_branch = fake # fake Trie
else:
top3 = cur_branch.top3
cur_branch = cur_branch.branch[char]
res.append(top3)
return t.find_top3(searchWord)
注意以上是把核心搜索补全功能的接口写在 trie 树之外,也可以把它放到 Trie 数据结构之中:
class TrieNode:
def __init__(self):
self.branch = {}
self.top3 = []
class Trie:
def __init__(self, words):
self.root = TrieNode()
for word in words:
self.insert(word)
def insert(self, word):
cur_branch = self.root
for char in word:
if char not in cur_branch.branch:
cur_branch.branch[char]= TrieNode()
cur_branch = cur_branch.branch[char]
cur_branch.top3.append(word)
cur_branch.top3.sort()
if len(cur_branch.top3) > 3:
cur_branch.top3.pop()
def get_top3(self, prefix):
res = []
fake = TrieNode()
cur_branch = self.root
for i, char in enumerate(prefix):
top3 = []
if char not in cur_branch.branch:
cur_branch = fake # fake Trie
else:
cur_branch = cur_branch.branch[char]
top3 = cur_branch.top3
res.append(top3)
return res
class Solution:
def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
trie = Trie(products)
return trie.get_top3(searchWord)
以上对比是为了凸现出 力扣官方题解 中的写法,以下 Trie 对象实际是 TrieNode, 而 Trie 的功能接口是在 suggestedProducts 中实现的,这里并没有像进行上一段代码一样进行封装。
class Trie:
def __init__(self):
self.child = dict()
self.words = list()
class Solution:
def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
def addWord(root, word):
cur = root
for ch in word:
if ch not in cur.child:
cur.child[ch] = Trie()
cur = cur.child[ch]
cur.words.append(word)
cur.words.sort()
if len(cur.words) > 3:
cur.words.pop()
root = Trie()
for word in products:
addWord(root, word)
ans = list()
cur = root
flag = False
for ch in searchWord:
if flag or ch not in cur.child:
ans.append(list())
flag = True
else:
cur = cur.child[ch]
ans.append(cur.words)
return ans
其他参考
在对效率要求非常高的场景下 trie 的设计: HAT-trie, a cache-conscious trie
提到了 set 和 trie 复杂度的文章: Trie Data Structure | Interview Cake
宫水三叶 trie 相关文章 中有 10 多篇 leetcode 中 trie 相关题目的解析。