当前位置: 首页 > article >正文

samout llm解码 幻觉更低更稳定

这段代码定义了一个简单的对话生成系统,包括模型加载、词汇表加载、以及基于给定提示生成文本的功能。下面是对代码的解析:

  1. load_model_and_voc(device="cpu"):

    • 该函数用于加载预训练的模型和词汇表(vocabulary)。它首先从文件 total_voc.pkl 中加载词汇表,并创建一个名为 SamOut 的神经网络实例。
    • 模型参数的数量被打印出来以供参考。
    • 然后尝试加载指定路径下的预训练权重到模型中,并将模型移动到指定的设备(CPU 或 GPU)上。
    • 最后设置模型为评估模式(.eval()),并返回模型和词汇表。
  2. gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):

    • 这个函数负责根据提供的提示(prompt)生成新的文本序列。
    • 它接受多个参数,包括词汇表、模型、初始提示、最大生成长度等。
    • 函数内部实现了重复抑制、温度调整和top-k采样等技术来控制生成文本的质量。
    • 使用softmax函数对模型输出进行处理,并通过多类别抽样选择下一个token。
    • 如果生成了特殊的开始标记 <|sos|>,则停止生成过程。
    • 生成的每个token会立即打印在屏幕上,形成即时响应的效果。
  3. t_infre():

    • 此函数是交互式推理循环,允许用户输入文本,然后调用 gen_token 函数来生成回应。
    • 它是一个无限循环,持续等待用户的输入直到程序被手动终止。
  4. if __name__ == '__main__':

    • 这部分代码确保当脚本作为主程序运行时,会执行某些特定的操作或测试。
    • 注释掉的代码可能是之前用于数据预处理、训练或其他实验的部分。
    • 最终调用了 t_infre() 函数来启动交互式推理。

需要注意的是,这里使用的 SamOut 类并没有在给出的代码片段中定义,因此你可能需要确保这个类已经被正确实现并在其他地方导入。此外,为了使代码能够正常工作,你需要确保所有依赖库(如 PyTorch 和 pandas)已经安装,并且所有提及的数据文件和模型权重文件都存在于正确的路径下。

def load_model_and_voc(device="cpu"):
    voc = pd.read_pickle("total_voc.pkl")

    net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)
    # net = SamOut(len(voc["voc"]), 512, 32, 8)
    print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
        [i.shape[0] for i in net.parameters() if len(i.shape) == 1]))

    # net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))
    net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain.pth", map_location=device))
    net.to(device)
    net.eval()
    return net, voc


def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):
    print("agent:", end="", flush=True)

    for _ in range(max_len):

        prompt_list = []
        for i in prompt:
            if i not in voc["voc"]:
                prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]
            else:

                prompt_list.append(voc["voc"].index(i))
        out, _ = model(torch.Tensor([prompt_list]).to(device).long())
        out = out[:, -1:]
        # 重复抑制
        for token_id in enumerate(prompt_list):
            out[:, :, token_id] /= rp
        score = torch.softmax(out, -1)[0, 0]
        score, score_index = torch.sort(score,descending=True)
        score=score.detach().numpy()
        score_sum = np.cumsum(score)
        score_index = score_index.detach().numpy()
        score1=score[score_sum<0.8]
        if score1.size==0:
            score=score[:1]
        else:
            score=score1
        score_index=score_index[:score.size]



        out = score / temp

        v= out[:min(top_k, score.size)]



        idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)
        if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":
            break
        prompt += [voc["voc"][score_index[idx_next.item()]]]
        print(prompt[-1], end="", flush=True)


def t_infre():
    model, voc = load_model_and_voc()
    while True:
        text = input("user:")
        gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)
        print()


if __name__ == '__main__':
    # print(pd.read_pickle("loss916"))
    # gen_one_voc()
    # gen_voc()
    # for i in range(17,18):
    #     gen_pre_data_align(i, 16)

    # train()
    # gen_sft_single_data_align()
    # train_single()
    # sft 推理  一本正经的胡说八道已练成

    t_infre()

http://www.kler.cn/a/444134.html

相关文章:

  • 前端yarn工具打包时网络连接问题排查与解决
  • [LeetCode-Python版] 定长滑动窗口1(1456 / 643 / 1343 / 2090 / 2379)
  • C语言编程1.27汉诺塔
  • Python-存储数据-Thu-Fri
  • 开放词汇目标检测(Open-Vocabulary Object Detection, OVOD)综述
  • 【C语言】特殊指针汇总
  • CentOS 快捷安装 jenkins 并设置开机自启
  • vue相关的---Vuex
  • 游戏AI实现-寻路算法(DFS)
  • ESP-AT 固件:物联网智能 “引擎”
  • C语言学习day22:URLDownloadToFile函数/开发文件下载工具
  • [python]使用flask-caching缓存数据
  • QT图形/视图架构详解(二)
  • Oracle 技术精选学习
  • VScode使用教程(菜鸟版)
  • Day26下 - BERT项目实战
  • 2024 年的科技趋势
  • vue-cli 5接入模块联邦 module federation
  • 【GO环境安装】mac系统+GoLand使用
  • nginx 记录完整的 request 及 response
  • 使用JustAuth实现gittee登录
  • 中型项目下的 MySQL 挑战与应对
  • 利用Python爬虫实现数据收集与挖掘
  • 音视频入门基础:MPEG2-TS专题(18)——PES流简介
  • HTML基本标签详解
  • MySQL Explain 分析SQL语句性能