
Knowledge Points from “动手学强化学习” (Hands-on Reinforcement Learning), Part 2: Chapter 15 Imitation Learning (gym version >= 0.26)


Abstract

This series of notes gives a detailed analysis of the tricky points in 动手学强化学习 (Hands-on Reinforcement Learning). For the full material, please read the book itself.


Corresponds to the Imitation Learning chapter of Hands-on Reinforcement Learning.


# -*- coding: utf-8 -*-


import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    ''' PPO with the clipped surrogate objective (PPO-Clip). '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # number of passes over one trajectory's data per update
        self.eps = eps  # clipping range parameter in PPO
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        '''Build a discrete Categorical distribution from the probability vector
        and sample an action from it.
        '''
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        
        processed_state = []
        for s in transition_dict['states']:
            if isinstance(s, tuple):
                # if the element is a tuple, take its first element (the observation)
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        # states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        states = torch.tensor(processed_state, dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)
        '''
        Compute the TD target (the regression target for the critic):
            td_target = r + gamma * V(s') * (1 - done)
        '''
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        '''TD residual (the basis of the advantage estimate): TD target minus the critic's current value estimate.'''
        td_delta = td_target - self.critic(states)
        '''Helper from the rl_utils module that computes the advantage, typically with
        Generalized Advantage Estimation (GAE); a sketch of such a helper appears right after this class.'''
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        '''
        Feed the states to the actor to get the action distribution (shape (batch_size, action_dim)),
        use .gather(1, actions) to pick out the probability of the action actually taken
        (the shape of actions must match), take the log, and detach() to freeze the
        old policy's log-probabilities.
        '''
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        for _ in range(self.epochs):
            '''Recompute the log-probabilities of the sampled actions under the current policy.'''
            log_probs = torch.log(self.actor(states).gather(1, actions))
            '''Probability ratio between the new and old policy, used in PPO's clipped loss.'''
            ratio = torch.exp(log_probs - old_log_probs)
            '''Unclipped surrogate objective: ratio times advantage.'''
            surr1 = ratio * advantage
            '''Clip the ratio to [1 - eps, 1 + eps] (here [0.8, 1.2]), then multiply by the advantage.
            A small numeric sketch of the clipped objective appears after the training call below.'''
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # clipping
            '''PPO maximizes the minimum of the two surrogates, so take the elementwise min,
            negate it to get a loss, and average over the batch.'''
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO loss
            '''Critic MSE loss: error between the critic's estimate and the (detached) TD target, averaged over the batch.'''
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()
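

The advantage above is computed by rl_utils.compute_advantage, which this post does not reproduce. As a reference, here is a minimal sketch of what a GAE-based version of that helper typically looks like; the authoritative implementation lives in the book's rl_utils module, so treat the function name and details below as illustrative.

import torch

def compute_advantage_sketch(gamma, lmbda, td_delta):
    # Accumulate A_t = delta_t + gamma * lambda * A_{t+1} backwards over the
    # TD residuals of one trajectory (Generalized Advantage Estimation).
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:          # iterate from the last step backwards
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)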


actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
# gym >= 0.26 removed env.seed(); patch a compatible stub so the book's original
# env.seed(0) call (commented out below) would still work if re-enabled.
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)

return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)
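

To make the clipped surrogate in PPO.update concrete, here is a small numeric sketch with made-up values (using eps = 0.2 as above): clipping caps how much a larger ratio can increase the objective, while torch.min keeps the pessimistic branch when the advantage is negative.

import torch

ratio = torch.tensor([0.5, 1.0, 1.5])          # hypothetical probability ratios
advantage = torch.tensor([1.0, 1.0, -1.0])     # hypothetical advantages
surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1 - 0.2, 1 + 0.2) * advantage
print(torch.min(surr1, surr2))   # tensor([ 0.5000,  1.0000, -1.5000])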



def sample_expert_data(n_episode):
    states = []
    actions = []
    for episode in range(n_episode):
        state = env.reset()
        done = False
        while not done:
            action = ppo_agent.take_action(state)
            states.append(state)
            actions.append(action)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags (gym >= 0.26; see the short API sketch after this function)
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
    processed_states = []
    for s in states:
        if isinstance(s, tuple):
            # if the element is a tuple, take its first element (the observation)
            processed_states.append(s[0])
        else:
            processed_states.append(s)
    return np.array(processed_states), np.array(actions)
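

The len(result) == 5 branch above exists because the Gym API changed in version 0.26: reset() now returns (observation, info) and step() returns five values with separate terminated and truncated flags. A minimal sketch of the new signatures, assuming gym >= 0.26 (demo_env is only an illustrative name):

import gym

demo_env = gym.make('CartPole-v0')
obs, info = demo_env.reset(seed=0)                    # new API: (observation, info)
obs, reward, terminated, truncated, info = demo_env.step(demo_env.action_space.sample())
done = terminated or truncated                        # an episode ends on either flag
demo_env.close()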


if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)

n_samples = 30  # subsample 30 expert transitions
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]


class BehaviorClone:
    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        """解释:定义一个学习函数,接收一批专家数据中的状态和动作,用于更新策略网络。"""
        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(actions).view(-1, 1).to(device)
        '''
        - Feed the states to the policy network to get the action distribution for
          each state, with output shape (batch_size, action_dim);
        - use .gather(1, actions.long()) to pick out the probability of each expert
          action (the actions must be cast to long to serve as indices);
        - take the log to obtain the log-likelihood.
        '''
        log_probs = torch.log(self.policy(states).gather(1, actions.long()))
        # log_probs = torch.log(self.policy(states).gather(1, actions))
        '''Behavior-cloning loss: the negative log-likelihood, averaged over the batch
        (an equivalence check with F.nll_loss is sketched right after this class).'''
        bc_loss = torch.mean(-log_probs)  # maximum likelihood estimation

        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        probs = self.policy(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
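

The behavior-cloning loss above is a plain negative log-likelihood. The following small sketch, with made-up tensors, checks that the gather-based expression matches F.nll_loss applied to log-probabilities; it is an equivalence check only, not part of the original code.

import torch
import torch.nn.functional as F

probs = torch.tensor([[0.7, 0.3],
                      [0.2, 0.8]])         # policy output, shape (batch, action_dim)
actions = torch.tensor([[0], [1]])         # expert actions, shape (batch, 1)

loss_gather = torch.mean(-torch.log(probs.gather(1, actions)))
loss_nll = F.nll_loss(torch.log(probs), actions.view(-1))
assert torch.allclose(loss_gather, loss_nll)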


def test_agent(agent, env, n_episode):
    return_list = []
    for episode in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags (gym >= 0.26)
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
    return np.mean(return_list)


if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)

lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []

with tqdm(total=n_iterations, desc="Progress") as pbar:
    for i in range(n_iterations):
        sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        current_return = test_agent(bc_agent, env, 5)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)

iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()






class Discriminator(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Discriminator, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return torch.sigmoid(self.fc2(x))
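

The discriminator scores a state-action pair, so the discrete action has to be turned into a vector before it can be concatenated with the state. A minimal usage sketch with CartPole-like shapes (state_dim = 4, action_dim = 2; the tensors below are made up):

import torch
import torch.nn.functional as F

states = torch.randn(3, 4)                                    # batch of 3 CartPole states
actions = F.one_hot(torch.tensor([0, 1, 1]), num_classes=2)   # [[1, 0], [0, 1], [0, 1]]
d = Discriminator(state_dim=4, hidden_dim=128, action_dim=2)
scores = d(states, actions.float())                           # shape (3, 1), values in (0, 1)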
    
class GAIL:
    def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):
        print(state_dim, action_dim, hidden_dim)
        self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
        self.agent = agent

    def learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):

        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        expert_actions = torch.tensor(expert_a).to(device)

        processed_state = []
        for s in agent_s:
            if isinstance(s, tuple):
                # if the element is a tuple, take its first element (the observation)
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)
        agent_actions = torch.tensor(agent_a).to(device)
        '''Convert the discrete actions to one-hot float vectors so they can be concatenated with the states; num_classes=2 matches CartPole's action space.'''
        expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()
        agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()

        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)
        '''
        Discriminator training with binary cross-entropy (BCE):
        - agent data gets target label 1, contributing BCE(agent_prob, 1);
        - expert data gets target label 0, contributing BCE(expert_prob, 0);
        - the two terms are summed.
        The discriminator output can thus be read as the probability that a
        (state, action) pair was generated by the agent rather than the expert.
        '''
        discriminator_loss = nn.BCELoss()(
            agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(
                expert_prob, torch.zeros_like(expert_prob))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        '''
        Turn the discriminator output on agent data into a reward for PPO:
        - use -log(agent_prob) as the reward (the smaller agent_prob, i.e. the more
          expert-like the pair looks, the larger the reward), which pushes the agent
          to imitate the expert; a tiny numeric sketch follows right after this class;
        - detach() to block gradients, move to CPU and convert to numpy so the values
          can be passed to agent.update.
        '''
        rewards = -torch.log(agent_prob).detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones
        }
        self.agent.update(transition_dict)
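

A tiny numeric sketch of the reward shaping above, with made-up discriminator outputs: since agent data is labeled 1 and expert data 0, a small D(s, a) means the pair looks expert-like and yields a large reward -log D(s, a). The mean of these rewards is exactly the agent-side BCE term in the discriminator loss.

import torch
import torch.nn as nn

agent_prob = torch.tensor([[0.9], [0.5], [0.1]])    # hypothetical discriminator outputs
rewards = -torch.log(agent_prob)                    # ~[[0.105], [0.693], [2.303]]
bce_agent = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob))
assert torch.allclose(bce_agent, rewards.mean())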



if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []

with tqdm(total=n_episode, desc="Progress") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # merge the terminated and truncated flags (gym >= 0.26)
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)

iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()
