当前位置: 首页 > article >正文

huggingface之tokenization基础结构Trie-代码解读

目的

基于transformers查看字典树是怎么生成的呢,输入字符串text是怎么在字典树中进行分割的,一起来看一下

参考链接

wikipedia

代码

def traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove):
    # trie_pointer存在最后终止符,需要重置和存储结果在offsets中
    # lookahead要匹配最长的
    # 比如extra_id_1 vs extra_id_100这种,先找到extra_id_1,还是会继续匹配的
    # "[CLS]", "L" 匹配CLS
    for lookstart, looktrie_pointer in states.items():
    	# if lookstart in to_remove:# 额外处理一下的,处理那个bug
        #      continue
        if lookstart > start:
            # This partial match is later, we can stop looking
            break# 这个匹配是后面的结果,可以停止看了
            """
            trie = Trie()
            trie.add("喜欢你")
            trie.add("欢你")
            trie.split("我喜欢你一起玩")
            # ['我', '喜欢你', '一起玩']
            # 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1}), (2, {'': 1})])
            # start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true
            # 开始重新遍历states
            # lookstart=2时,lookstart>start,这里对应的token "欢你" 已经是后面的匹配了,停止
            """
        elif lookstart < start:
            # This partial match is earlier, the trie pointer
            # was already updated, so index is + 1
            # 这是前面的匹配,trie_pointer已经更新了,所以index=当前+1
            lookahead_index = current + 1
            end = current + 1
            """
            trie = Trie()
            trie.add("喜欢你")
            trie.add("欢")
            trie.split("我喜欢你一起玩")
            # ['我', '喜欢你', '一起玩']
            # 当匹配到char“你”的时候,states=OrderedDict([(1, {'你': {'': 1}}), (2, {'': 1})])
            # start=2时,trie_pointer = {'': 1} 则'' in trie_pointer为true
            # 开始重新遍历states
            # lookstart=1时,lookstart<start
            # 虽然是在"欢"时找到了对应的终止符,但是前面的匹配"喜欢"也不能不管了,所以要看一下下面的char是否匹配
            """
        else:
            # Here lookstart == start and
            #      looktrie_pointer == trie_pointer
            # It wasn't updated yet so indices are current ones
            lookahead_index = current# 当前token对应的位置
            end = current# 
            """
            trie = Trie()
            trie.add("喜欢你")
            trie.split("我喜欢你一起玩")
            # ['我', '喜欢你', '一起玩']
            # 当匹配到char“一”的时候,states=OrderedDict([(1, {'': 1})])
            # start=1时,trie_pointer = {'': 1} 则'' in trie_pointer为true
            # 开始重新遍历states
            # lookstart=1时,lookstart=start
            # 当前token并未验证是否在trie_pointer中,所以需要记录当前char对应的位置
            """
        next_char = text[lookahead_index] if lookahead_index < len(text) else None# 下一个字符,避免extra_id_1 vs extra_id_100这种情况
        if "" in looktrie_pointer:# 找到终止符了,对应了lookstart == start这种情况
            start = lookstart
            end = lookahead_index
            skip = lookahead_index
        while next_char in looktrie_pointer:# 对应lookstart<start这种情况,要把后面的token相同的都找到
        # 这里start、end值都是有了,可能不更新,也可能更新多轮
        # 这里我没仔细找例子,先把整个流程跑通,重新找demo
            looktrie_pointer = looktrie_pointer[next_char]
            lookahead_index += 1# 后面的char位置
            if "" in looktrie_pointer:# 找到终止符了
                start = lookstart
                end = lookahead_index
                skip = lookahead_index
            if lookahead_index == len(text):# 我咋没把这种情况测试出来呢
                # End of string
                print ('End of string')
                break
            next_char = text[lookahead_index]
        # End lookahead
    # Storing and resetting
    offsets.append(start)
    offsets.append(end)
    reset = True
    return offsets, skip, reset, to_remove


def traversal_states(states, current_char, offsets, current, text, skip, reset, to_remove):
    for start, trie_pointer in states.items():# 遍历states
        if "" in trie_pointer:# 这个指针到结束位置了
            offsets, skip, reset, to_remove = traversal_states_second(states, offsets, start, current, text, skip, reset, to_remove)
            break
        elif current_char in trie_pointer:# 当前字符在trie_pointer中,则更新指针,将states更新
            trie_pointer = trie_pointer[current_char]# 喜 欢 和 你
            states[start] = trie_pointer# 慢慢的找到这个单词的剩余字符
        else:
            # 当前字符不能匹配到trie_pointer,需要停止追踪这个匹配
            # 因为python迭代器的工作方式,不能直接在这个循环中执行这个操作
            to_remove.add(start)# 因为这里的原因,所以会造成['ab c', 'd']的情况
    return offsets, skip, reset, to_remove


import bisect
import itertools
import re
import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload

class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
    Loose reference https://en.wikipedia.org/wiki/Trie
    """
    def __init__(self, *args):
        self.data = {}
        self._tokens = set()
        self._termination_char = ""
        self.update(*args)
    def update(self, *args):
        """
        Updates the Trie with new tokens provided as arguments.
        新单词作为参数更新Trie
        demo:
            trie = Trie(("hello","you"))
            trie.data
            # {'h': {'e': {'l': {'l': {'o': {'': 1}}}}}, 'y': {'o': {'u': {'': 1}}}}
            trie._tokens
            # {'hello', 'you'}
        Args:
            *args: Variable number of words to be added to the Trie.
            添加到Trie的单词
        """
        for token in tuple(*args):
            self.add(token)
    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
        The special key `""` in `self._termination_char` is used to represent termination.
        This function is idempotent, adding twice the same word will leave the trie unchanged
        Example:
        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:# ''或者为None
            # Prevent empty string
            return
        self._tokens.add(word)# 要把word添加到self._tokens中
        ref = self.data
        for char in word:
            ref[char] = ref.setdefault(char, {})# 这一步会更新self.data
            ref = ref[char]# 这一步不会更新self.data
        ref[self._termination_char] = 1
    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. Output is the original string splitted along the
        boundaries of the words found.
        This trie will match the longest possible word first !
        Example:
        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".
        # States are going to capture every possible start (indexes as above)
        # as keys, and have as values, a pointer to the position in the trie
        # where we're at. This is a partial match for now.
        # This enables to keep track of multiple matches while we're iterating
        # the string
        # If the trie contains, "blowing", and "lower" and we encounter the
        # string "blower", we need to split into ["b", "lower"].
        # This is where we need to keep track of multiple possible starts.
        # indexes是字符index的左计数
        states = OrderedDict()
        # offsets包含了每一个需要分割的分块,强制在位置0和位置len(text)进行分割
        offsets = [0]
        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        # 这是由lookahead使用的,它需要跳过一些完全匹配超出初始 for 循环中位置的文本
        skip = 0
        # Main loop, Giving this algorithm O(n) complexity
        # 主循环,O(n) 复杂的
        # current是当前位置,current_char是当前字符
        for current, current_char in enumerate(text):
            if skip and current < skip:
                continue
            to_remove = set()# 将停止匹配的状态,都放到to_remove,停止追踪
            reset = False# 当找到一个匹配,就丢掉一切,这是一个贪心算法,会匹配到第一个发现的token
            offsets, skip,reset, to_remove = traversal_states(states, current_char, offsets, current, text, skip, reset, to_remove)# 遍历states,这一步会遍历每个states
            if reset:# 找到匹配时,就丢掉一切,即重置states
                states = {}
            else:# 没有找到匹配,将停止匹配的状态从states中删除,停止追踪
                for start in to_remove:
                    del states[start]
            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            # 当前字符在self.data中,要开始追踪了
            # skip代表了最后一个找到的char对应的位置,因为在off_1中也会遍历char位置,所以设置了这样一个位置指针
            # 贪心算法,前面token已经把这个char使用上了,后面即使有一样的,也不会管了
            """
            trie = Trie()
            trie.add("喜欢你一")
            trie.add("一起玩")
            trie.split("我喜欢你一起玩")
            # ['我', '喜欢你一', '起玩']
            """
            # self.data = {'喜': {'欢': {'和': {'你': {'': 1}}}, '爱': {'': 1}}, '欢': {'喜': {'': 1}}}
            # '喜' in self.data == True
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]# 根据current_char在self.data中找到对应的value,即剩余的字符串        
                # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                # 最后的匹配了,将结果保存在offsets中
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # Longest cut is always the one with lower start so the first
                # item so we need to break.
                break# 这里没有看懂
        return self.cut_text(text, offsets)
    def cut_text(self, text, offsets):
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))# 把最开始和最后的部分都要加上
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                print (
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end
        return tokens


测试用例

新建

trie = Trie(("sea","see"))
trie._tokens# {'sea', 'see'}
trie.data# {'s': {'e': {'a': {'': 1}, 'e': {'': 1}}}}
# ->s->e->a
# ->s->e->e

添加

trie = Trie()
trie.add("Hello 友達")
trie.data

分割

trie = Trie()
trie.split("[CLS] This is a extra_id_100")
# ["[CLS] This is a extra_id_100"]


trie = Trie()
trie.add("喜欢你")
trie.split("我喜欢你一起玩")

trie = Trie()
trie.add("喜欢你一")
trie.add("一起玩")
trie.add("你一起玩")
trie.split("我喜欢你一起玩")
# ['我', '喜欢你一', '起玩']


trie = Trie()
trie.add("abc")
trie.add("b")
trie.split("ab cd")
# ['ab c', 'd']
# 这里有个bug,对应了traversal_states_second中的这里
    	# if lookstart in to_remove:# 额外处理一下的,处理那个bug
        #      continue

额外说几句

查看这个代码的过程中,发现了一个bug,在git上提交了,还没有得到回复,可能是我自己想错了。
这里的代码逻辑比较散,需要对应多种特殊情况,希望能画出流程图来。


http://www.kler.cn/a/372632.html

相关文章:

  • ffmpeg视频滤镜:添加边框-drawbox
  • 基于SpringBoot的汽车票网上预订系统
  • 使用 Mermaid 语言描述 AGI 系统架构图
  • PHP数据类型
  • Spring Boot 经典九设计模式全览
  • 宠物空气净化器哪个牌子好?有没有噪音低的宠物空气净化器推荐?
  • 【缓存与加速技术实践】Redis 主从复制
  • 银河麒麟v10安装Anaconda(python大蟒蛇)+pycharm安装
  • AJAX和JSON
  • K8S 容器可视化管理工具-kuboard 监控管理工具搭建
  • 操作数据表
  • 【蓝桥杯选拔赛真题81】python矩形数量 第十五届青少年组蓝桥杯python选拔赛真题 算法思维真题解析
  • C++ 中回调函数的实现方式-函数指针
  • ICT网络赛道安全考点知识总结1
  • 笔记整理—linux驱动开发部分(2)模块信息与编译
  • 记录一次查询优化
  • 关于Mac打包ipa的配置小结
  • Hyperledger Fabric有那些核心技术,和其他区块链对比Hyperledger Fabric有那些优势
  • Spring Boot 实现文件分片上传和下载
  • 运维端口号详解(Detailed Explanation of Operation and Maintenance Port Numbers)
  • 高效MySQL缓存策略
  • C++(运算符重载)
  • iQOO手机怎样将屏幕投射到MacBook?可以同步音频吗?
  • 【Searxng】Searxng docker 安装
  • 《IMM交互式多模型滤波MATLAB实践》专栏目录,持续更新……
  • 基于Django+python的车牌识别系统设计与实现(带文档)