【python运行Janus-Pro-1B文生图功能】
前言
体验了一把本地部署Janus-Pro-1B实现文生图功能。
1、开源项目下载
官方开源项目代码直接从Github上下载。
2、模型下载
模型官方下载需要魔法
Janus-Pro-1B模型文件:Janus-Pro-1B模型文件
百度网盘:
https://pan.baidu.com/s/16t4H4z-QZe2UDAg4EF5g3w?pwd=6666
3、将项目代码和模型添加到同一个文件
1)创建目录存放模型文件
将模型及配置权重文件存放在同一目录下,如Janus-Pro-1B
2)将模型目录放到官方开源项目中。
3)创建虚拟环境
使用如下命令创建虚拟环境:
myenv\Scripts\activate
4)安装所需依赖
pip install -r requirements.txt
5) 项目内容如下
6) 最后运行效果:运行 app_januspro.py
运行代码
# 导入必要的库
# 用于处理文件和目录路径
# 用于图像处理
# PyTorch库,用于深度学习模型的构建和训练
# 用于数值计算
# Hugging Face的Transformers库,用于加载预训练的语言模型
# 自定义的多模态因果语言模型和处理器
import os
import PIL.Image
import time
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from datetime import datetime
# 打印PyTorch版本和CUDA可用性
print(torch.__version__)
# 打印当前使用的PyTorch版本
print(torch.cuda.is_available())
# 检查CUDA是否可用(即是否可以使用GPU)# 打印当前GPU的显存占pip install Pillow用情况
print("初始显存占用:", torch.cuda.memory_allocated() / 1024**3, "GB")
# 指定模型路径
# model_path = "deepseek-ai/Janus-Pro-1B" # 禁用此远程路径
# 使用本地路径
model_path = "./Janus-Pro-1B"
# 检查模型路径是否存在
assert os.path.exists(model_path), "模型路径不存在!"
# 打印模型路径
print(f"模型路径:{model_path}")
# 加载VLChatProcessor和tokenizer # 加载VLChatProcessor
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
# 获取tokenizer,用于将文本编码为模型可以理解的输入格式
tokenizer = vl_chat_processor.tokenizer
# 使用AutoModelForCausalLM加载预训练的多模态因果语言模型
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True
)
# 将模型转换为float32精度,并将其移动到GPU上,设置为评估模式
vl_gpt = vl_gpt.to(torch.float32).cuda().eval()
# 打印加载模型后的显存占用情况
print("加载模型后显存占用:", torch.cuda.memory_allocated() / 1024**3, "GB")
# 定义对话内容
conversation = [
{
# 用户角色
"role": "<|User|>",
# 用户输入的文本描述
#"content": "long hair,pink lace dress, off-shoulder, high-waist, narrow waist",
# "content": "这是一张极简主义的照片,一个橙色的橘子,有绿色的茎和叶,象征着繁荣,在中国新年期间坐在红色的丝绸布上。捕捉一个充满活力的向日葵盛开的特写镜头,一只蜜蜂栖息在它的花瓣上,它精致的翅膀捕捉阳光。",
#"content": "这是一张古风照片,一个疑似少年的人站在山顶悬崖边上眺望远处的群山,负手而立,少年周身好似仙人气息,山下烟雾缭绕,隐隐约约有几座宫殿。"
#"content":"瓜子脸 + 高额头 + 浓眉 + 大眼睛 + 高鼻梁 + 樱桃小嘴 + 尖下巴 + 白皙皮肤"
#"content":"圆脸 + 宽额头 + 细眉 + 小眼睛 + 塌鼻子 + 厚嘴唇 + 圆下巴 + 黝黑皮肤"
#"content":"方脸 + 窄额头 + 剑眉 + 丹凤眼 + 鹰钩鼻 + 薄嘴唇 + 方下巴 + 粗糙皮肤"
"content":"这是一张单独人物的照片、性别男、青年、面无表情、短发刘海、脸部干净、穿着西装、打着领带、脸型圆润、动漫化"
},
# 助手角色,内容为空
{"role": "<|Assistant|>", "content": ""},
]
# 应用SFT模板生成多轮对话的prompt
# 使用VLChatProcessor的apply_sft_template_for_multi_turn_prompts方法将对话内容转换为模型可以理解的格式
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
# 对话内容
conversations=conversation,
# SFT模板格式
sft_format=vl_chat_processor.sft_format,
# 系统提示,此处为空
system_prompt="",
)
# 在生成的prompt后添加图像生成的起始标记
prompt = sft_format + vl_chat_processor.image_start_tag
img_size = 384 # 修改图片尺寸384: (1B支持生成最好的尺寸,超过这个会出稀奇古怪的东西)
patch_size = 16 # 假设 patch_size 是 16
image_token_num_per_image = (img_size // patch_size) ** 2 # 计算 token 数量(** 2平方)
# 定义生成函数,用于生成图片
# 使用torch.inference_mode()装饰器,以减少内存占用并加速推理过程
@torch.inference_mode()
def generate(
mmgpt: MultiModalityCausalLM, # 多模态因果语言模型
vl_chat_processor: VLChatProcessor, # 用于处理输入的处理器
prompt: str, # 用户输入的文本描述
temperature: float = 1, # 控制生成过程的随机性
parallel_size: int = 1, # 一次生成的图片数量
cfg_weight: float = 5, # 分类器自由引导(Classifier-Free Guidance)的权重
image_token_num_per_image=image_token_num_per_image, # 每张图片生成的token数量 576
img_size=img_size, # 生成图片的尺寸384
patch_size= patch_size, # 图片的patch大小
):
# 使用tokenizer将prompt编码为input_ids
input_ids = vl_chat_processor.tokenizer.encode(prompt)
# 将input_ids转换为LongTensor
input_ids = torch.LongTensor(input_ids)
# 初始化tokens,用于生成图片
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
for i in range(parallel_size * 2):
# 将input_ids复制到tokens中
tokens[i, :] = input_ids
if i % 2 != 0:
# 对于非条件tokens,填充pad_id
tokens[i, 1:-1] = vl_chat_processor.pad_id
# 获取输入的embeddings # 将tokens转换为embeddings,作为模型的输入
inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
# 初始化生成的tokens # 初始化用于存储生成图片的tokens
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
# 记录耗时:开始
start_time = time.perf_counter()
# 逐步生成图片
for i in range(image_token_num_per_image):
outputs = mmgpt.language_model.model(
# 输入的embeddings
inputs_embeds=inputs_embeds,
# 使用缓存
use_cache=True,
# 使用过去的key-values
past_key_values=outputs.past_key_values if i != 0 else None,
)
hidden_states = outputs.last_hidden_state # 获取隐藏状态
# 计算logits
logits = mmgpt.gen_head(hidden_states[:, -1, :])
# 条件logits
logit_cond = logits[0::2, :]
# 非条件logits
logit_uncond = logits[1::2, :]
# # 应用分类器自由引导(CFG)权重
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
# 计算概率分布
probs = torch.softmax(logits / temperature, dim=-1)
# 使用torch.multinomial采样下一个token
next_token = torch.multinomial(probs, num_samples=1)
# 将采样的token存储到generated_tokens中
generated_tokens[:, i] = next_token.squeeze(dim=-1)
# 准备下一个输入的embeddings
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
# 生成图片的embeddings
img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
# 更新输入的embeddings
inputs_embeds = img_embeds.unsqueeze(dim=1)
# 打印每一步的显存占用情况
print(f"Step {i}/{image_token_num_per_image} 显存占用:", torch.cuda.memory_allocated() / 1024**3, "GB")
# 记录循环结束时间
end_time = time.perf_counter()
# 计算耗时
elapsed_time = end_time - start_time
print(f"循环耗时: {elapsed_time:.6f} 秒")
# 解码生成的tokens为图片
dec = mmgpt.gen_vision_model.decode_code(
# 生成的tokens
generated_tokens.to(dtype=torch.int),
# 图片的形状
shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
)
# 将生成的图片转换为numpy数组
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
# 对图片进行后处理,将其转换为8位无符号整数格式
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
# 初始化用于存储最终图片的数组
visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
# 将生成的图片存储到visual_img中
visual_img[:, :, :] = dec
# 创建目录generated_samples
os.makedirs('generated_samples', exist_ok=True)
for i in range(parallel_size):
# 获取当前本地时间
now = datetime.now()
# 格式化输出
formatted_time = now.strftime("%Y%m%d%H%M%S%f")
# 定义保存路径:生成日期格式
save_path = os.path.join('generated_samples', "img_{}.jpg".format(formatted_time))
# 保存生成的图片
PIL.Image.fromarray(visual_img[i]).save(save_path)
# 调用生成函数生成图片
generate(
# 多模态因果语言模型
vl_gpt,
# 用于处理输入的处理器
vl_chat_processor,
# 用户输入的文本描述
prompt,
)