当前位置: 首页 > article >正文

VLM(视觉语言模型)与DeepSeek R1(奖励机制)如何结合

VLM(视觉语言模型)与DeepSeek R1(奖励机制)如何结合

flyfish

VLM的传统训练依赖于监督学习(直接拟合问答对),而规则奖励函数通常用于强化学习(通过试错和奖励反馈优化策略)。这两种方式如何结合?

源码来自
VLM-R1/src/open-r1-multimodal/src/open_r1/grpo_rec.py

# 导入 debugpy 库,用于调试,当前代码中被注释掉,若需要调试可取消注释
# import debugpy
# try:
#     # 5678 是 VS Code 调试配置中的默认附加端口。除非指定主机和端口,否则主机默认为 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

# 导入操作系统相关功能的库
import os
# 导入正则表达式库,用于字符串匹配和处理
import re
# 导入日期时间处理库
from datetime import datetime
# 导入数据类装饰器和字段定义类,用于定义数据类
from dataclasses import dataclass, field
# 导入可选类型注解,用于表示某个参数可以为 None
from typing import Optional

# 导入 Pillow 库中的 Image 类,用于处理图像
from PIL import Image
# 导入 PyTorch 中的数据集基类
from torch.utils.data import Dataset
# 导入 Qwen2VL 条件生成模型
from transformers import Qwen2VLForConditionalGeneration

# 导入自定义的数学验证模块中的解析和验证函数
from math_verify import parse, verify
# 导入自定义的 Qwen2VLGRPOTrainer 类
from open_r1.trainer import Qwen2VLGRPOTrainer
# 导入 TRL 库中的 GRPO 配置、训练器、模型配置、脚本参数、解析器和 PEFT 配置获取函数
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
# 导入 Transformers 库中的训练参数类
from transformers import TrainingArguments
# 导入 YAML 文件处理库
import yaml
# 导入 JSON 文件处理库
import json
# 导入随机数生成库
import random
# 导入数学计算库
import math

# ----------------------- 修复当前版本 transformers 中的 flash attention 错误 -----------------------
# 导入 Qwen2_5_VL 模型中的相关类和函数
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
# 导入 PyTorch 库
import torch
# 导入元组类型注解
from typing import Tuple

# 自定义 Qwen2_5_VLVisionFlashAttention2 类的前向传播函数
def custom_forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
    # 获取隐藏状态的序列长度
    seq_length = hidden_states.shape[0]
    # 通过 qkv 层得到查询、键、值张量,并进行形状调整和维度置换
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    # 如果没有提供位置嵌入,则根据旋转位置嵌入计算余弦和正弦值
    if position_embeddings is None:
        # 打印一次警告信息,提示 RoPE 嵌入计算方式的变化
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        # 拼接旋转位置嵌入
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        # 计算余弦值
        cos = emb.cos().float()
        # 计算正弦值
        sin = emb.sin().float()
    else:
        # 从位置嵌入中获取余弦和正弦值
        cos, sin = position_embeddings
        # 将余弦值转换为浮点类型
        cos = cos.to(torch.float)
        # 将正弦值转换为浮点类型
        sin = sin.to(torch.float)
    # 应用旋转位置嵌入到查询和键张量
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    # 去除查询张量的额外维度
    q = q.squeeze(0)
    # 去除键张量的额外维度
    k = k.squeeze(0)

    # 计算最大序列长度
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    # 调用 flash 注意力函数计算注意力输出
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    # 通过投影层得到最终的注意力输出
    attn_output = self.proj(attn_output)
    return attn_output

# 将自定义的前向传播函数赋值给 Qwen2_5_VLVisionFlashAttention2 类的 forward 方法
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward


# ----------------------- 主脚本 -----------------------
# 定义 GRPOScriptArguments 数据类,继承自 ScriptArguments
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    用于 GRPO 训练脚本的脚本参数。

    参数:
        reward_funcs (`list[str]`):
            奖励函数列表。可能的值: 'accuracy', 'format'。
    """

    # 奖励函数列表,默认包含 'accuracy' 和 'format'
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}
    )
    # 图像的最大像素数,默认为 12845056
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"}
    )
    # 图像的最小像素数,默认为 3136
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"}
    )
    # 图像的根目录,默认为 None
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the image"}
    )

# 定义系统提示信息,用于指导模型的对话生成
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

# 定义 LazySupervisedDataset 类,继承自 Dataset
class LazySupervisedDataset(Dataset):
    def __init__(self, data_path: str, script_args: GRPOScriptArguments):
        # 调用父类的构造函数
        super(LazySupervisedDataset, self).__init__()
        # 保存脚本参数
        self.script_args = script_args
        # 初始化数据字典列表
        self.list_data_dict = []

        # 如果数据文件是 YAML 格式
        if data_path.endswith(".yaml"):
            # 打开 YAML 文件
            with open(data_path, "r") as file:
                # 加载 YAML 数据
                yaml_data = yaml.safe_load(file)
                # 获取数据集列表
                datasets = yaml_data.get("datasets")
                # 文件格式应为:
                # datasets:
                #   - json_path: xxxx1.json
                #     sampling_strategy: first:1000
                #   - json_path: xxxx2.json
                #     sampling_strategy: end:3000
                #   - json_path: xxxx3.json
                #     sampling_strategy: random:999

                # 遍历每个数据集
                for data in datasets:
                    # 获取 JSON 文件路径
                    json_path = data.get("json_path")
                    # 获取采样策略,默认为 'all'
                    sampling_strategy = data.get("sampling_strategy", "all")
                    # 初始化采样数量为 None
                    sampling_number = None

                    # 如果 JSON 文件是 JSONL 格式
                    if json_path.endswith(".jsonl"):
                        # 初始化当前数据字典列表
                        cur_data_dict = []
                        # 打开 JSONL 文件
                        with open(json_path, "r") as json_file:
                            # 逐行读取文件
                            for line in json_file:
                                # 解析每行 JSON 数据并添加到当前数据字典列表
                                cur_data_dict.append(json.loads(line.strip()))
                    # 如果 JSON 文件是 JSON 格式
                    elif json_path.endswith(".json"):
                        # 打开 JSON 文件
                        with open(json_path, "r") as json_file:
                            # 加载 JSON 数据到当前数据字典列表
                            cur_data_dict = json.load(json_file)
                    else:
                        # 如果文件类型不支持,抛出异常
                        raise ValueError(f"Unsupported file type: {json_path}")

                    # 如果采样策略包含冒号
                    if ":" in sampling_strategy:
                        # 分割采样策略和采样数量
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        # 如果采样数量包含百分比符号
                        if "%" in sampling_number:
                            # 计算采样数量
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            # 将采样数量转换为整数
                            sampling_number = int(sampling_number)

                    # 应用采样策略
                    if sampling_strategy == "first" and sampling_number is not None:
                        # 取前 sampling_number 个样本
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        # 取后 sampling_number 个样本
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        # 随机打乱样本
                        random.shuffle(cur_data_dict)
                        # 取前 sampling_number 个样本
                        cur_data_dict = cur_data_dict[:sampling_number]
                    # 打印从当前 JSON 文件加载的样本数量
                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    # 将当前数据字典列表添加到总数据字典列表
                    self.list_data_dict.extend(cur_data_dict)
        else:
            # 如果文件类型不支持,抛出异常
            raise ValueError(f"Unsupported file type: {data_path}")

    def __len__(self):
        # 返回数据字典列表的长度
        return len(self.list_data_dict)

    def __getitem__(self, i):
        # 定义将示例转换为对话格式的函数
        def make_conversation(example):
            return {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": example["problem"]}
                ]
            }

        # 问题模板,用于包含图像的对话
        QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

        # 定义将包含图像的示例转换为对话格式的函数
        def make_conversation_image(example):
            return {
                "prompt": [
                    # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])}
                        ]
                    }
                ]
            }

        # 获取指定索引的示例
        example = self.list_data_dict[i]
        # 获取图像根目录
        image_root = self.script_args.image_root
        # 如果示例中包含图像信息
        if 'image' in example:
            # 构建图像路径
            image_path = os.path.join(image_root, example['image'])
            # 如果图像文件不存在
            while not os.path.exists(image_path):
                # 打印警告信息
                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                # 随机选择一个新的索引
                new_index = random.randint(0, len(self.list_data_dict)-1)
                # 获取新的示例
                example = self.list_data_dict[new_index]
                # 构建新的图像路径
                image_path = os.path.join(image_root, example['image'])
            # 打开图像并转换为 RGB 格式
            image = Image.open(image_path).convert("RGB")
        else:
            # 如果示例中不包含图像信息,图像为 None
            image = None

        return {
            'image': image,
            'problem': example['problem'],
            'solution': example['solution'],
            'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt']
        }

'''
    如果模型预测的边界框与真实边界框的交并比(IoU)大于 0.5,则奖励为 1.0,否则为 0.0。
    这是一种硬奖励,未来可能使用软奖励会更好。
'''
def iou_reward(completions, solution, **kwargs):
    # 定义计算交并比的函数
    def iou(box1, box2):
        # 计算交集的左上角坐标
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        # 计算交集的右下角坐标
        inter_x2 = min(box1[2]-1, box2[2]-1)
        inter_y2 = min(box1[3]-1, box2[3]-1)
        # 如果交集存在
        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            # 计算交集面积
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            # 交集面积为 0
            inter = 0
        # 计算并集面积
        union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
        # 返回交并比
        return float(inter)/union

    # 获取完成内容列表
    contents = [completion[0]["content"] for completion in completions]
    # 初始化奖励列表
    rewards = []
    # 获取当前时间并格式化
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # 定义答案标签的正则表达式模式
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    # 定义边界框的正则表达式模式
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    # 遍历完成内容和真实解决方案
    for content, sol in zip(contents, solution):
        # 初始化奖励为 0.0
        reward = 0.0
        # 尝试进行符号验证
        try:
            # 在完成内容中查找答案标签
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                # 获取答案内容
                content_answer = content_answer_match.group(1).strip()
                # 在答案内容中查找边界框
                bbox_match = re.search(bbox_pattern, content_answer)
                if bbox_match:
                    # 获取边界框坐标
                    bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                    # 如果交并比大于 0.5
                    if iou(bbox, sol) > 0.5:
                        # 奖励为 1.0
                        reward = 1.0
        except Exception:
            # 如果验证失败,继续下一个验证方法
            pass

        # 将奖励添加到奖励列表
        rewards.append(reward)
        # 如果处于调试模式
        if os.getenv("DEBUG_MODE") == "true":
            # 获取日志路径
            log_path = os.getenv("LOG_PATH")
            # 打开日志文件并追加记录
            with open(log_path, "a") as f:
                # 记录当前时间和奖励信息
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                # 记录完成内容
                f.write(f"Content: {content}\n")
                # 记录真实解决方案
                f.write(f"Solution: {sol}\n")
    return rewards


def format_reward(completions, **kwargs):
    """奖励函数,用于检查完成内容是否符合特定格式。"""
    # 定义格式的正则表达式模式
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
    # 获取完成内容列表
    completion_contents = [completion[0]["content"] for completion in completions]
    # 检查每个完成内容是否符合格式
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    # 根据匹配结果生成奖励列表
    return [1.0 if match else 0.0 for match in matches]


# 奖励函数注册表,将奖励函数名称映射到对应的函数
reward_funcs_registry = {
    "accuracy": iou_reward,
    "format": format_reward,
}


def main(script_args, training_args, model_args):
    # 根据脚本参数中的奖励函数名称,从注册表中获取对应的奖励函数
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    # 打印奖励函数列表
    print("reward_funcs:", reward_funcs)

    # 加载数据集
    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)

    # 选择训练器类,这里使用自定义的 Qwen2VLGRPOTrainer
    trainer_cls = Qwen2VLGRPOTrainer
    # 初始化 GRPO 训练器
    trainer = trainer_cls(
        model=model_args.model_name_or_path,  # 模型名称或路径
        reward_funcs=reward_funcs,  # 奖励函数列表
        args=training_args,  # 训练参数
        train_dataset=dataset,  # 训练数据集
        eval_dataset=None,  # 评估数据集,这里设为 None
        peft_config=get_peft_config(model_args),  # PEFT 配置
        attn_implementation=model_args.attn_implementation,  # 注意力实现方式
        max_pixels=script_args.max_pixels,  # 图像最大像素数
        min_pixels=script_args.min_pixels,  # 图像最小像素数
        torch_dtype=model_args.torch_dtype,  # PyTorch 数据类型
    )

    # 开始训练模型
    trainer.train()

    # 保存模型到指定的输出目录
    trainer.save_model(training_args.output_dir)
    # 如果设置了将模型推送到 Hub
    if training_args.push_to_hub:
        # 将模型推送到 Hub,并指定数据集名称
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # 创建 TrlParser 对象,用于解析脚本参数、训练配置和模型配置
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    # 解析命令行参数和配置
    script_args, training_args, model_args = parser.parse_args_and_config()
    # 调用主函数开始训练
    main(script_args, training_args, model_args)

代码中的两个关键奖励函数 format_rewardiou_reward

1. 格式奖励函数 format_reward

函数定义和功能
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

此函数的主要功能是检查模型生成的完成内容是否符合特定的格式要求。具体来说,它期望模型的输出满足以下格式:

  • 包含 <think></think> 标签,用于包裹思考过程。
  • 包含 <answer></answer> 标签,用于包裹答案。
  • 答案部分需要是一个 JSON 格式,并且其中包含一个由四个整数组成的列表,通常可以理解为表示边界框的坐标。
实现步骤
  1. 定义正则表达式模式pattern 是一个正则表达式,用于描述期望的输出格式。
  2. 提取完成内容completion_contentscompletions 中提取出每个完成内容的文本部分。
  3. 检查格式匹配matches 使用 re.fullmatch 函数检查每个完成内容是否完全匹配正则表达式模式。
  4. 生成奖励列表:根据匹配结果,为每个完成内容生成一个奖励值,如果匹配则为 1.0,否则为 0.0。
作用

通过这个奖励函数,模型在训练过程中会被激励去生成符合特定格式的输出,有助于规范模型的回答结构,使得输出更易于解析和使用。

2. 交并比(IoU)奖励函数 iou_reward

函数定义和功能
def iou_reward(completions, solution, **kwargs):
    def iou(box1, box2):
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        inter_x2 = min(box1[2]-1, box2[2]-1)
        inter_y2 = min(box1[3]-1, box2[3]-1)
        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            inter = 0
        union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
        return float(inter)/union
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    for content, sol in zip(contents, solution):
        reward = 0.0
        try:
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()
                bbox_match = re.search(bbox_pattern, content_answer)
                if bbox_match:
                    bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                    if iou(bbox, sol) > 0.5:
                        reward = 1.0
        except Exception:
            pass
        rewards.append(reward)
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")
    return rewards

此函数的主要功能是评估模型预测的边界框与真实边界框之间的重叠程度,并根据交并比(IoU)值给予奖励。

实现步骤
  1. 定义 IoU 计算函数iou 函数用于计算两个边界框的交并比。它首先计算两个边界框的交集面积和并集面积,然后将交集面积除以并集面积得到 IoU 值。
  2. 提取完成内容contentscompletions 中提取出每个完成内容的文本部分。
  3. 查找答案和边界框:使用正则表达式 answer_tag_pattern 查找完成内容中的答案部分,再使用 bbox_pattern 查找答案中的边界框坐标。
  4. 计算 IoU 并给予奖励:对于每个完成内容,提取预测的边界框坐标,与真实边界框计算 IoU 值。如果 IoU 值大于 0.5,则给予 1.0 的奖励,否则给予 0.0 的奖励。
  5. 日志记录(可选):如果设置了调试模式(DEBUG_MODEtrue),则将每个完成内容的奖励信息记录到日志文件中。
作用

通过这个奖励函数,模型在训练过程中会被激励去预测更准确的边界框,提高目标检测的精度。同时,结合格式奖励函数,可以让模型不仅准确预测边界框,还能以规定的格式输出结果。

监督学习与规则奖励函数强化学习的结合方式

1. 数据层面的结合
  • 利用监督数据初始化模型:在开始强化学习训练之前,使用监督学习的方式对视觉语言模型(VLM)进行预训练。通过直接拟合问答对数据,让模型学习到基本的语言和视觉特征表示以及问题回答的模式。例如,在代码中使用 LazySupervisedDataset 类加载数据集,这些数据可以作为监督学习阶段的训练数据,让模型初步学习到如何根据问题和图像生成答案。
  • 监督数据作为强化学习的参考:在强化学习的过程中,监督学习的数据可以作为参考来评估模型的输出。例如,在 iou_reward 函数中,通过比较模型预测的边界框与真实边界框的交并比(IoU)来给予奖励,这里的真实边界框就是监督学习中的标签信息。
2. 训练过程的结合
  • 分阶段训练:先进行监督学习训练,让模型收敛到一个较好的初始状态。然后再切换到强化学习阶段,使用规则奖励函数来进一步优化模型的策略。在代码中,虽然没有明确体现分阶段训练的逻辑,但可以在实际应用中先使用监督学习的方法对 Qwen2VLForConditionalGeneration 模型进行训练,然后再使用 Qwen2VLGRPOTrainer 进行强化学习训练。
  • 混合训练:在每个训练步骤中,既使用监督学习的损失函数,又使用强化学习的奖励函数。例如,可以将监督学习的交叉熵损失和强化学习的奖励损失加权求和,作为总的损失函数来更新模型参数。这样可以让模型在学习过程中既考虑到直接拟合标签的准确性,又考虑到长期的奖励优化。
3. 奖励函数设计结合监督信息
  • 准确性奖励:如 iou_reward 函数,将模型输出与监督学习中的标签进行比较,根据比较结果给予奖励。这种奖励函数可以促使模型在强化学习过程中输出更接近真实标签的结果,从而结合了监督学习的信息。
  • 格式奖励format_reward 函数可以确保模型输出的格式符合特定要求,这可以看作是一种规则约束。同时,这种格式要求也可以是在监督学习阶段就定义好的,从而将监督学习中的格式规范融入到强化学习的奖励机制中。

http://www.kler.cn/a/555819.html

相关文章:

  • 1.13作业
  • 详解Nginx 配置
  • 关于ES中text类型时间字段范围查询的结构化解决方案
  • SprinBoot整合HTTP API:从零开始的实战指南
  • 以太网的PHY(物理层)详解
  • 适配器模式 Adapter Pattern
  • 如何设计提示词让AI以思维链方式回答问题
  • Linux:文件(二)
  • NSFNET是什么?NSFNET网络具有什么特点?
  • halcon三维点云数据处理(二十五)moments_object_model_3d
  • 【目标检测】【YOLOv4】YOLOv4:目标检测的最佳速度与精度
  • 嵌入式八股,struct结构体和union联合体的联系与区别
  • PWM(脉宽调制)技术详解:从基础到应用实践示例
  • Hive JOIN过滤条件位置玄学:ON vs WHERE的量子纠缠
  • 最新版保姆级JDK安装教程
  • 芯谷D2761:为扬声器保驾护航的音频限幅器
  • 在 JMeter 中实现多用户并发登录及操作
  • coco格式
  • CVE-2021-34527: PrintNightmare 域内提权
  • 解锁健康密码,拥抱养生生活