Hugging Face tokenization basics: the Trie, a code walkthrough
Purpose
Based on the transformers source, let's look at how the trie (prefix tree) is built and how an input string `text` gets split against it.
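Before diving into the internals, here is the end-to-end behaviour we are explaining, taken from the split() docstring further down (it assumes the Trie class defined later in this post):
trie = Trie()
trie.add("[CLS]")
trie.add("extra_id_1")
trie.add("extra_id_100")
trie.split("[CLS] This is a extra_id_100")
# ['[CLS]', ' This is a ', 'extra_id_100']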
References
Wikipedia: https://en.wikipedia.org/wiki/Trie
Code
def traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove):
    # trie_pointer has reached the termination char: reset and store the result in offsets
    # the lookahead has to find the longest match:
    # e.g. extra_id_1 vs extra_id_100: after finding extra_id_1 we keep on matching
    # e.g. with tokens "[CLS]" and "L", the full [CLS] is matched rather than the inner "L"
    for lookstart, looktrie_pointer in states.items():
        # if lookstart in to_remove:  # extra guard that fixes the bug discussed at the end
        #     continue
        if lookstart > start:
            # This partial match is later, we can stop looking
            break
"""
trie = Trie()
trie.add("喜欢你")
trie.add("欢你")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你', '一起玩']
# 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1}), (2, {'': 1})])
# start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true
# 开始重新遍历states
# lookstart=2时,lookstart>start,这里对应的token "欢你" 已经是后面的匹配了,停止
"""
        elif lookstart < start:
            # This partial match is earlier, the trie pointer
            # was already updated, so index is + 1
            lookahead_index = current + 1
            end = current + 1
"""
trie = Trie()
trie.add("喜欢你")
trie.add("欢")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你', '一起玩']
# 当匹配到char“你”的时候,states=OrderedDict([(1, {'你': {'': 1}}), (2, {'': 1})])
# start=2时,trie_pointer = {'': 1} 则'' in trie_pointer为true
# 开始重新遍历states
# lookstart=1时,lookstart<start
# 虽然是在"欢"时找到了对应的终止符,但是前面的匹配"喜欢"也不能不管了,所以要看一下下面的char是否匹配
"""
        else:
            # Here lookstart == start and
            # looktrie_pointer == trie_pointer
            # It wasn't updated yet so indices are current ones
            lookahead_index = current  # position of the current char
            end = current
"""
trie = Trie()
trie.add("喜欢你")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你', '一起玩']
# 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1})])
# start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true
# 开始重新遍历states
# lookstart=1时,lookstart=start
# 当前token并未验证是否在trie_pointer中,所以需要记录当前char对应的位置
"""
        # the next char; guarded so the lookahead cannot run past the end of text
        # (needed for the extra_id_1 vs extra_id_100 case)
        next_char = text[lookahead_index] if lookahead_index < len(text) else None
        if "" in looktrie_pointer:  # terminator found, the lookstart == start case
            start = lookstart
            end = lookahead_index
            skip = lookahead_index
        # the lookstart < start case: keep walking while later chars extend the match.
        # start and end already hold values here; they may stay as they are or be
        # updated several more times (I had not pinned down a concrete example for
        # this loop yet; first get the whole flow working, then look for a demo)
        while next_char in looktrie_pointer:
            looktrie_pointer = looktrie_pointer[next_char]
            lookahead_index += 1  # position of the next char
            if "" in looktrie_pointer:  # terminator found
                start = lookstart
                end = lookahead_index
                skip = lookahead_index
            if lookahead_index == len(text):
                # End of string (I never managed to trigger this branch in my tests)
                print("End of string")
                break
            next_char = text[lookahead_index]
    # End lookahead
    # Storing and resetting
    offsets.append(start)
    offsets.append(end)
    reset = True
    return offsets, skip, reset, to_remove
def traversal_states(states, current_char, offsets, current, text, skip, reset, to_remove):
    for start, trie_pointer in states.items():  # walk every tracked state
        if "" in trie_pointer:  # this pointer has reached a terminator
            offsets, skip, reset, to_remove = traversal_states_second(
                states, offsets, start, current, text, skip, reset, to_remove
            )
            break
        elif current_char in trie_pointer:
            # the current char continues this partial match: advance the pointer
            # one char (e.g. 喜 -> 欢 -> 和 -> 你) and store it back into states,
            # consuming the remaining chars of the word step by step
            trie_pointer = trie_pointer[current_char]
            states[start] = trie_pointer
        else:
            # the current char does not match this trie_pointer: stop tracking this state.
            # Because of how Python iterators work, we cannot delete entries from states
            # while looping over it, so they are collected in to_remove instead
            to_remove.add(start)  # this deferred removal causes the ['ab c', 'd'] bug shown below
    return offsets, skip, reset, to_remove
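To see what the helpers return, here is a minimal hand-driven step. The states dict mimics the moment split() reaches the char "一" of "我喜欢你一起玩" with the single token "喜欢你" already fully matched (the same situation as the lookstart == start trace above):
from collections import OrderedDict
states = OrderedDict([(1, {'': 1})])  # partial match of "喜欢你" that just hit its terminator
offsets, skip, reset, to_remove = traversal_states(
    states, "一", [0], 4, "我喜欢你一起玩", 0, False, set()
)
print(offsets, skip, reset)  # [0, 1, 4] 4 True -> cut points [1, 4) for "喜欢你"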
from collections import OrderedDict
from typing import List
class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
    Loose reference https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self, *args):
        self.data = {}
        self._tokens = set()
        self._termination_char = ""
        self.update(*args)

    def update(self, *args):
        """
        Updates the Trie with new tokens provided as arguments.

        demo:
        trie = Trie(("hello", "you"))
        trie.data
        # {'h': {'e': {'l': {'l': {'o': {'': 1}}}}}, 'y': {'o': {'u': {'': 1}}}}
        trie._tokens
        # {'hello', 'you'}

        Args:
            *args: Variable number of words to be added to the Trie.
        """
        for token in tuple(*args):
            self.add(token)
    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
        The special key `""` in `self._termination_char` is used to represent termination.

        This function is idempotent, adding twice the same word will leave the trie unchanged

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:  # '' or None
            # Prevent empty string
            return

        self._tokens.add(word)  # record the full word in self._tokens
        ref = self.data
        for char in word:
            ref[char] = ref.setdefault(char, {})  # this step mutates the nested dict inside self.data
            ref = ref[char]  # this step only moves the local reference deeper; self.data itself is untouched
        ref[self._termination_char] = 1
    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. Output is the original string split along the
        boundaries of the words found.

        This trie will match the longest possible word first !

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".

        # States are going to capture every possible start (indexes as above)
        # as keys, and have as values, a pointer to the position in the trie
        # where we're at. This is a partial match for now.

        # This enables to keep track of multiple matches while we're iterating
        # the string
        # If the trie contains, "blowing", and "lower" and we encounter the
        # string "blower", we need to split into ["b", "lower"].

        # This is where we need to keep track of multiple possible starts.
        states = OrderedDict()

        # offsets contains every cut point; cuts are forced at position 0 and at len(text)
        offsets = [0]

        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        skip = 0
        # Main loop, Giving this algorithm O(n) complexity
        # current is the index, current_char the character at that index
        for current, current_char in enumerate(text):
            if skip and current < skip:
                continue

            # states that stop matching are collected in to_remove so tracking can stop
            to_remove = set()
            # Whenever we find a match, we drop everything and reset: this is a greedy
            # algorithm that commits to the first token found
            reset = False
            # walk every tracked state
            offsets, skip, reset, to_remove = traversal_states(
                states, current_char, offsets, current, text, skip, reset, to_remove
            )
            if reset:  # a match was found: drop everything and reset states
                states = {}
            else:  # no match was found: stop tracking the states that stopped matching
                for start in to_remove:
                    del states[start]

            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            # skip points right after the last committed match; the lookahead also walks
            # char positions, so this pointer lets the main loop jump over text that a
            # match already consumed.
            # Greedy: once an earlier token has used this char, a token starting on the
            # same char is ignored, e.g.:
            """
            trie = Trie()
            trie.add("喜欢你一")
            trie.add("一起玩")
            trie.split("我喜欢你一起玩")
            # ['我', '喜欢你一', '起玩']
            """
            # with tokens 喜欢和你, 喜爱 and 欢喜 the trie would be
            # self.data = {'喜': {'欢': {'和': {'你': {'': 1}}}, '爱': {'': 1}}, '欢': {'喜': {'': 1}}}
            # and '喜' in self.data == True
            if current >= skip and current_char in self.data:
                # current_char is a key of self.data; its value is the subtree of
                # possible continuations
                states[current] = self.data[current_char]
        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # Longest cut is always the one with lower start so the first
                # item so we need to break.
                # (states is ordered by start, and a final match with a smaller
                # start is a longer cut, so the first one found wins)
                break
        return self.cut_text(text, offsets)
    def cut_text(self, text, offsets):
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))  # force a final cut at the end of the string
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                print(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens
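cut_text can also be exercised on its own. A small sketch with hand-made offsets (the values stand in for what split() would have collected; they are not from a real match):
trie = Trie()
trie.cut_text("hello", [0, 2])  # pretend split() found a match over [0, 2)
# ['he', 'llo']
# cut_text appends len(text)=5, so the cut points become [0, 2, 5]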
Test cases
Construction
trie = Trie(("sea", "see"))
trie._tokens  # {'sea', 'see'}
trie.data  # {'s': {'e': {'a': {'': 1}, 'e': {'': 1}}}}
# -> s -> e -> a
# -> s -> e -> e
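The shared prefix is easier to see when the words are added one at a time:
trie = Trie()
trie.add("sea")
trie.data  # {'s': {'e': {'a': {'': 1}}}}
trie.add("see")  # shares the "se" prefix, so only a new branch under 'e' is added
trie.data  # {'s': {'e': {'a': {'': 1}, 'e': {'': 1}}}}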
Adding
trie = Trie()
trie.add("Hello 友達")
trie.data
# {'H': {'e': {'l': {'l': {'o': {' ': {'友': {'達': {'': 1}}}}}}}}}
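A quick check of the idempotence claim in the add() docstring:
trie = Trie()
trie.add("Hello 友達")
before = repr(trie.data)
trie.add("Hello 友達")  # adding the same word a second time
assert repr(trie.data) == before  # the trie is unchanged, as promised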
Splitting
trie = Trie()
trie.split("[CLS] This is a extra_id_100")
# ["[CLS] This is a extra_id_100"]  (nothing was added, so nothing is cut)

trie = Trie()
trie.add("喜欢你")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你', '一起玩']

trie = Trie()
trie.add("喜欢你一")
trie.add("一起玩")
trie.add("你一起玩")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你一', '起玩']
# greedy: "喜欢你一" consumes the chars "你" and "一", so neither "你一起玩" nor
# "一起玩" can match afterwards
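The while loop in traversal_states_second still lacked a concrete example above; here is one I constructed where an earlier partial match really is extended char by char inside that loop:
trie = Trie()
trie.add("喜欢你一起")
trie.add("欢你")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你一起', '玩']
# "欢你" hits its terminator while "喜欢你一起" is only partially matched; the
# lookahead then extends that earlier match through "起" in the while loop,
# so the longer token wins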
trie = Trie()
trie.add("abc")
trie.add("b")
trie.split("ab cd")
# ['ab c', 'd']
# This is a bug: the state for "abc" dies on the char ' ', but the lookahead in
# traversal_states_second still walks it. It corresponds to the commented-out
# guard in that function:
# if lookstart in to_remove:
#     continue
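Uncommenting that guard is enough to fix this case. A minimal way to try it without editing the function body is to wrap the helper so that dead states are filtered out before the lookahead runs; this is a sketch of my proposed fix, not transformers code, the wrapper is my own, and it assumes everything above lives in one module so traversal_states picks up the rebound name:
from collections import OrderedDict
_original_second = traversal_states_second

def traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove):
    # drop the states this character already invalidated, which is exactly
    # what the commented-out `if lookstart in to_remove: continue` guard does
    live = OrderedDict((s, p) for s, p in states.items() if s not in to_remove)
    return _original_second(live, offsets, start, current, text, skip, reset, to_remove)

trie = Trie()
trie.add("abc")
trie.add("b")
trie.split("ab cd")
# ['a', 'b', ' cd'] with the guard applied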
A few closing remarks
While going through this code I found what looks like a bug and filed it on GitHub; there has been no reply yet, so I may simply have it wrong.
The logic here is scattered across many special cases; I hope to draw a flow chart for it eventually.