基于cross-attention算法关联文本和图像、图像和动作
问题:
基于cross-attention算法关联文本和图像,可以举一个可以运行的例子吗?
基本实现框架:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
class TextImageCrossAttention(nn.Module):
    """Single-head cross-attention in which text tokens attend over image features.

    Text features provide the queries; image features provide the keys and the
    values, so every text token receives a weighted summary of the image positions.

    Args:
        text_dim: Feature size of the text input (last dim of ``text_feats``).
        image_dim: Feature size of the image input (last dim of ``image_feats``).
        att_dim: Size of the shared query/key/value projection space.
    """

    def __init__(self, text_dim=512, image_dim=768, att_dim=64):
        super().__init__()
        # Stored because forward() needs it for the 1/sqrt(att_dim) score scaling.
        self.att_dim = att_dim
        self.Wq = nn.Linear(text_dim, att_dim)   # text -> queries
        self.Wk = nn.Linear(image_dim, att_dim)  # image -> keys
        self.Wv = nn.Linear(image_dim, att_dim)  # image -> values

    def forward(self, text_feats, image_feats):
        """Attend from every text token to every image position.

        Args:
            text_feats: ``[batch, seq_len, text_dim]`` tensor.
            image_feats: ``[batch, h*w, image_dim]`` tensor (flattened spatial grid).

        Returns:
            ``(context, attn_weights)``: ``context`` is ``[batch, seq_len, att_dim]``;
            ``attn_weights`` is ``[batch, seq_len, h*w]`` and each row sums to 1
            over the image positions.
        """
        Q = self.Wq(text_feats)   # [batch, seq_len, att_dim]
        K = self.Wk(image_feats)  # [batch, h*w, att_dim]
        V = self.Wv(image_feats)  # [batch, h*w, att_dim]
        # Scaled dot-product scores, normalised over the image positions.
        attn_scores = torch.bmm(Q, K.transpose(1, 2)) / (self.att_dim ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context = torch.bmm(attn_weights, V)  # [batch, seq_len, att_dim]
        return context, attn_weights
图像描述生成场景示例(用随机张量模拟CLIP特征,演示特征对齐、上下文建模、多模态融合):
准备输入数据:
# Stand-ins for CLIP-style encoder outputs (random values; no real encoder is run).
text_feats = torch.randn(1, 5, 512)          # 1 caption, 5 token positions: "A cat sits on grass"
image_feats = torch.randn(1, 14 * 14, 768)   # 1 image, 14x14 patch grid flattened to 196 positions

# Text tokens act as queries over the image keys/values.
cross_attn = TextImageCrossAttention(text_dim=512, image_dim=768, att_dim=64)
context, attn_weights = cross_attn(text_feats, image_feats)
注意力可视化:
def visualize_attention(attn, text_tokens, img_size=14):
    """Plot one attention heatmap per text token for the first sample of a batch.

    Args:
        attn: Attention weights, either ``[batch, seq_len, h*w]`` (the layout
            returned by ``TextImageCrossAttention``) or multi-head
            ``[batch, heads, seq_len, h*w]``, in which case heads are averaged.
        text_tokens: Non-empty list of labels for the first ``len(text_tokens)``
            query rows; must not be longer than ``seq_len``.
        img_size: Side length of the square image grid; ``img_size ** 2`` must
            equal ``h*w``.

    Raises:
        ValueError: If ``attn`` is not 3-D or 4-D, or ``text_tokens`` is empty or
            longer than the number of query rows.
    """
    # Work on a detached CPU copy so .numpy() works for GPU / grad-tracking tensors.
    attn = attn.detach().cpu()
    if attn.dim() == 4:
        # Multi-head weights: average over the head axis of the first sample.
        attn_map = attn[0].mean(dim=0)  # [seq_len, h*w]
    elif attn.dim() == 3:
        # Single-head weights are already [seq_len, h*w] per sample; averaging
        # here would collapse the token axis and break per-token indexing.
        attn_map = attn[0]  # [seq_len, h*w]
    else:
        raise ValueError(
            f"expected a 3-D or 4-D attention tensor, got shape {tuple(attn.shape)}"
        )

    n_tokens = len(text_tokens)
    if n_tokens == 0 or n_tokens > attn_map.size(0):
        raise ValueError(
            f"need between 1 and {attn_map.size(0)} token labels, got {n_tokens}"
        )

    # One shared colour scale so the single colorbar is valid for every panel.
    shown = attn_map[:n_tokens]
    vmin, vmax = shown.min().item(), shown.max().item()

    # squeeze=False keeps `axes` 2-D even for a single token, so indexing is uniform.
    fig, axes = plt.subplots(1, n_tokens, figsize=(15, 3), squeeze=False)
    axes = axes[0]
    for i, token in enumerate(text_tokens):
        # Fold the flattened attention row back into the image grid.
        heatmap = shown[i].reshape(img_size, img_size).numpy()
        axes[i].imshow(heatmap, cmap='viridis', vmin=vmin, vmax=vmax)
        axes[i].set_title(f'"{token}"')
        axes[i].axis('off')
    plt.colorbar(axes[-1].images[0], ax=axes, shrink=0.5)
    plt.show()
# Labels for the 5 simulated text positions. They must line up with the sentence
# the simulated text features stand for ("A cat sits on grass"), one word per position.
text_tokens = ["A", "cat", "sits", "on", "grass"]
# Draw one attention heatmap per token.
visualize_attention(attn_weights, text_tokens)
个性化图像生成场景:
# Example snippet; the idea is attributed to an external source ("web page 2")
# that is not included in these notes.
def nested_attention(orig_attn, theme_attn, alpha=0.5):
    """Blend original attention with theme (subject) attention by linear interpolation.

    Returns ``alpha * orig_attn + (1 - alpha) * theme_attn``. When both inputs are
    distributions summing to 1 along the same axis and ``0 <= alpha <= 1``, the
    result also sums to 1 along that axis.

    Args:
        orig_attn: Original attention weights (tensor, array, or scalar).
        theme_attn: Theme attention weights, broadcast-compatible with ``orig_attn``.
        alpha: Weight given to ``orig_attn``; may be a scalar or a broadcastable
            tensor. The default 0.5 is an even mix, not a tuned value.
    """
    return alpha * orig_attn + (1 - alpha) * theme_attn
实时图像编辑场景:
# Example edit: amplify the attention row of the "grass" token by 1.5x.
# Look the position up instead of hard-coding it: "grass" is the last of the
# 5 tokens (index 4); index 3 is a different token.
grass_idx = text_tokens.index("grass")
# Build the per-token scale out of place: an in-place `*=` on the softmax output
# would break autograd if a backward pass through `attn_weights` is run later.
token_scale = torch.ones(
    attn_weights.size(1), 1, dtype=attn_weights.dtype, device=attn_weights.device
)
token_scale[grass_idx] = 1.5
attn_weights = attn_weights * token_scale  # broadcasts over batch and image positions
# No softmax afterwards: attn_weights already holds probabilities, and softmax of
# probabilities flattens every row toward uniform. Renormalising the row along
# dim=-1 would also cancel a uniform 1.5x scale, so the amplified row is
# deliberately left unnormalised.
问题:
基于cross-attention算法关联动作策略和图像,可以举一个可以运行的例子吗?
模拟机器人视觉导航场景的核心代码实现:
import torch
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
class ActionPolicyCrossAttention(nn.Module):
    """Cross-attention policy: a policy-state query attends over CNN image features.

    A ResNet-18 backbone with its pooling and classification layers removed turns
    the camera image into a spatial feature map. The policy state is projected to a
    single query that attends over every spatial position, and the attended
    context is mapped to 4 action logits.

    Args:
        policy_dim: Size of the policy-state vector.
        image_dim: Channel count of the backbone output (512 for ResNet-18).
        att_dim: Size of the shared query/key/value projection space.
        pretrained: Load ImageNet weights for the backbone (downloaded on first
            use). Pass ``False`` for offline / randomly initialised runs.
    """

    def __init__(self, policy_dim=256, image_dim=512, att_dim=64, pretrained=True):
        super().__init__()
        # Stored because forward() needs it for the 1/sqrt(att_dim) score scaling.
        self.att_dim = att_dim

        # Image feature extractor: pretrained ResNet-18.
        if hasattr(models, "ResNet18_Weights"):  # torchvision >= 0.13 weights API
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            backbone = models.resnet18(weights=weights)
        else:  # older torchvision only understands the legacy flag
            backbone = models.resnet18(pretrained=pretrained)
        # Drop avgpool and fc and keep the conv trunk. Its overall stride is 32, so
        # a 224x224 input yields a 7x7 map (49 positions), not 14x14.
        self.cnn = nn.Sequential(*list(backbone.children())[:-2])

        # Cross-attention projections.
        self.Wq = nn.Linear(policy_dim, att_dim)  # policy state -> query
        self.Wk = nn.Linear(image_dim, att_dim)   # image -> keys
        self.Wv = nn.Linear(image_dim, att_dim)   # image -> values

        # Policy head: 4 action logits (forward / backward / left / right).
        self.policy_head = nn.Linear(att_dim, 4)

    def forward(self, image, policy_state):
        """Compute action logits and the attention map over image positions.

        Args:
            image: ``[batch, 3, H, W]`` image tensor (224x224 in the demo).
            policy_state: ``[batch, policy_dim]`` current policy-state vector.

        Returns:
            ``(action_logits, attn_weights)``: ``action_logits`` is ``[batch, 4]``
            (unnormalised); ``attn_weights`` is ``[batch, 1, H'*W']``, where
            ``H' x W'`` is the backbone's output grid (7x7 for 224x224 input).
        """
        img_feats = self.cnn(image)                         # [batch, 512, H', W']
        img_feats = img_feats.flatten(2).permute(0, 2, 1)   # [batch, H'*W', 512]

        Q = self.Wq(policy_state.unsqueeze(1))  # [batch, 1, att_dim]
        K = self.Wk(img_feats)                  # [batch, H'*W', att_dim]
        V = self.Wv(img_feats)                  # [batch, H'*W', att_dim]
        attn_scores = torch.bmm(Q, K.transpose(1, 2)) / (self.att_dim ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)   # [batch, 1, H'*W']
        context = torch.bmm(attn_weights, V)                # [batch, 1, att_dim]

        action_logits = self.policy_head(context.squeeze(1))  # [batch, 4]
        return action_logits, attn_weights
模拟避障场景(特征交互机制、动态权重分配、多模态融合架构):
初始化与模型输入:
# Simulated inputs: random tensors stand in for a camera frame and a policy state.
image_input = torch.randn(1, 3, 224, 224)  # camera image, [batch, 3, H, W]
policy_state = torch.randn(1, 256)         # current policy-state vector
# Build the model and switch to inference mode. In training mode the ResNet
# BatchNorm layers would normalise with batch statistics and overwrite their
# pretrained running statistics with values from this random input.
model = ActionPolicyCrossAttention()
model.eval()
with torch.no_grad():
    action_logits, attn = model(image_input, policy_state)
    # The model returns unnormalised logits; softmax turns them into probabilities.
    action_probs = torch.softmax(action_logits, dim=-1)
print(f"动作概率分布: {action_probs.numpy()}")
注意力可视化:
def visualize_action_attention(attn_weights, img_size=None):
    """Show, as a heatmap, where the policy query attends in the image.

    Args:
        attn_weights: Attention weights whose first sample holds ``h*w`` values,
            e.g. ``[batch, 1, h*w]`` as returned by ``ActionPolicyCrossAttention``.
            Only the first sample is plotted.
        img_size: Side length of the square feature grid. ``None`` (the default)
            infers it from ``h*w``. ResNet-18 on a 224x224 input gives 49 weights,
            i.e. a 7x7 grid, which a fixed size of 14 cannot reshape.

    Raises:
        ValueError: If ``img_size`` is ``None`` and ``h*w`` is not a perfect square.
    """
    # Detached CPU copy so .numpy() works for GPU / grad-tracking tensors.
    weights = attn_weights[0].detach().cpu().reshape(-1)
    if img_size is None:
        n = weights.numel()
        img_size = int(round(n ** 0.5))
        if img_size * img_size != n:
            raise ValueError(
                f"cannot infer a square grid from {n} attention weights; pass img_size"
            )
    heatmap = weights.reshape(img_size, img_size).numpy()
    plt.imshow(heatmap, cmap='viridis')
    # NOTE: the title is Chinese; matplotlib's default font may lack CJK glyphs
    # and render it as boxes unless a CJK-capable font is configured.
    plt.title("策略关注的关键图像区域")
    plt.colorbar()
    plt.show()


visualize_action_attention(attn)
游戏AI控制:
class GameAgent(nn.Module):
    """Game-playing agent sketch relating an action policy to game frames.

    Only the submodules are declared here (there is no ``forward``): a VGG-16
    convolutional encoder for frames, an LSTM policy network, and a
    cross-attention module linking the two.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk of an ImageNet-pretrained VGG-16; the classifier is discarded.
        vgg = models.vgg16(pretrained=True)
        self.visual_encoder = vgg.features
        # 128-d inputs, 256-d hidden state; nn.LSTM defaults to sequence-first inputs.
        self.policy_net = nn.LSTM(input_size=128, hidden_size=256)
        # NOTE(review): `CrossAttention` is not defined in the visible source (the
        # nearby class is `ActionPolicyCrossAttention`, which also builds its own CNN).
        # image_dim=512 presumably matches VGG-16's conv output channels, and
        # policy_dim=256 matches the LSTM hidden size. TODO confirm.
        self.cross_attn = CrossAttention(policy_dim=256, image_dim=512)
机器人实时决策(通过时间序列扩展实现连续决策):
class RecurrentCrossAttn(nn.Module):
    """Recurrent cross-attention sketch for step-by-step (continuous) decisions.

    An LSTM cell carries a recurrent state across time steps, and a cross-attention
    module relates that state to image features.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned; otherwise
        # the first `self.lstm = ...` raises AttributeError.
        super().__init__()
        # 64-d per-step input, 256-d hidden state (matches policy_dim=256 below).
        # NOTE(review): input_size=64 presumably matches the attention context size
        # fed back each step. TODO confirm.
        self.lstm = nn.LSTMCell(input_size=64, hidden_size=256)
        # NOTE(review): `CrossAttention` is not defined in the visible source; it
        # must be provided elsewhere before this class can be instantiated.
        self.cross_attn = CrossAttention(policy_dim=256, image_dim=512)