
Fine-tuning the Llama Model

This walkthrough fine-tunes Llama 2 7B with LoRA on the samsum dialogue-summarization dataset, using the recipes from the repositories below.

GitHub: https://github.com/facebookresearch/llama-recipes
GitHub: https://github.com/facebookresearch/llama

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

#model_id="./models_hf/7B"
# The model can be downloaded from Hugging Face ("hf" denotes the Hugging Face
# format); the original Llama weights can also be converted with the
# convert_llama_weights_to_hf script from the transformers library.
model_id = "path/to/Llama-2-7b-chat-hf-local"

tokenizer = LlamaTokenizer.from_pretrained(model_id)

model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)
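
Note: recent transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of a quantization config. A minimal equivalent, assuming bitsandbytes is installed:

from transformers import BitsAndBytesConfig

# Hedged alternative to the call above; loads the same model in 8-bit.
model = LlamaForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
    torch_dtype=torch.float16,
)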
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
from llama_recipes.configs.datasets import samsum_dataset

train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')
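
To sanity-check the preprocessing, it helps to decode one training example. A quick sketch, assuming each dataset item is a dict carrying input_ids (which is what the llama-recipes samsum preprocessing produces):

# Inspect the first preprocessed training example.
print(tokenizer.decode(train_dataset[0]["input_ids"]))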
eval_prompt = """
Summarize this dialog:
A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-)
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))
---
Summary:
"""

model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))

model.train()

def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_int8_training,
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # prepare int-8 model for training
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

# create peft config
model, lora_config = create_peft_config(model)
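
print_trainable_parameters shows that LoRA freezes the base weights and trains only a small adapter. (On recent peft releases, prepare_model_for_int8_training has been renamed prepare_model_for_kbit_training.) The same figure can be computed by hand; a quick sketch:

# Count trainable vs. total parameters, mirroring print_trainable_parameters().
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total} ({100 * trainable / total:.2f}%)")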


from transformers import TrainerCallback
from contextlib import nullcontext
enable_profiler = False
output_dir = "tmp/llama-output"

config = {
    'lora_config': lora_config,
    'learning_rate': 1e-4,
    'num_train_epochs': 1,
    'gradient_accumulation_steps': 2,
    'per_device_train_batch_size': 2,
    'gradient_checkpointing': False,
}

# Set up profiler
if enable_profiler:
    wait, warmup, active, repeat = 1, 1, 2, 1
    total_steps = (wait + warmup + active) * (1 + repeat)
    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
    profiler = torch.profiler.profile(
        schedule=schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{output_dir}/logs/tensorboard"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True)
    
    class ProfilerCallback(TrainerCallback):
        def __init__(self, profiler):
            self.profiler = profiler
            
        def on_step_end(self, *args, **kwargs):
            self.profiler.step()

    profiler_callback = ProfilerCallback(profiler)
else:
    profiler = nullcontext()
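
If enable_profiler is set, the traces are written under the output directory and can be viewed with TensorBoard (assuming TensorBoard is installed):

    tensorboard --logdir tmp/llama-output/logs/tensorboard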

from transformers import default_data_collator, Trainer, TrainingArguments

# Define training args
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=True,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused",
    max_steps=total_steps if enable_profiler else -1,
    **{k:v for k,v in config.items() if k != 'lora_config'}
)
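
Note that bf16=True assumes an Ampere-or-newer GPU. On older hardware, a guard like the following (an assumption, not part of the original script) lets you pass bf16=use_bf16 and fp16=not use_bf16 to TrainingArguments instead:

# Hypothetical guard: check bfloat16 support before enabling bf16 training.
use_bf16 = torch.cuda.is_bf16_supported()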

with profiler:
    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
        callbacks=[profiler_callback] if enable_profiler else [],
    )

    # Start training
    trainer.train()

model.save_pretrained(output_dir)
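
Because the model is wrapped by peft, save_pretrained writes only the small LoRA adapter weights, not the full base model. To reuse the adapter later, reload the base model and attach the adapter; a minimal sketch:

from peft import PeftModel

# Base weights come from model_id, the LoRA adapter from output_dir.
base_model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
model = PeftModel.from_pretrained(base_model, output_dir)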

model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))
