当前位置: 首页 > article >正文

本地训练controlnet网络详解——以官方fill50k数据集为例

      • 写在前
      • controlnet是什么?
      • 环境说明
      • 相关资源
      • stable-diffusion-v1-5下载
      • controlnet代码下载
      • Pycharm运行和调试远程服务器代码
        • pycharm加载源码
        • 服务器配置
        • 解析器配置
      • fill50k数据集加载
        • fill50k数据集简单介绍
        • 数据集下载
        • 数据集加载
        • 本地加载数据集的注意事项
      • 参数配置和参数说明
        • parser简单说明
        • 参数说明
      • 代码运行
        • 运行配置
        • 官方参数设置
        • 启动效果
      • 测试过程

写在前

本文基于diffusers官方开源的代码训练自己的controlnet模型,数据集例子也是官方的fill50k。如果想要训练自己的例子,只需要根据fill50k的数据集内容组织自己的数据即可

controlnet是什么?

controlnet相当于一种插件,是用来微调现有的stable diffusion。这是因为stable diffusion本身训练比较耗费资源,为了某些目的而重新训练stable diffusion是不理想的。controlnet通过增加额外的训练参数来控制生成的效果。如下是多种开源controlnet的一种。也就是用户给定一张草图以及一个提示(bird)作为controlnet的输入,通过controlnet就可以得到你想要的图像。在这里插入图片描述

环境说明

本文主要是在本地Ubuntu服务器进行训练的,代码在显存8G16G20G以及38G都可以进行训练自己的

  • 系统版本:Ubuntu 22.04.2 LTS
  • python版本:python3.9
  • torch相关包版本如下在这里插入图片描述

相关资源

controlnet相关论文
stable-diffusion-v1-5基础模型权重
diffusers官方整理的controlnet源码
fill50k数据集

stable-diffusion-v1-5下载

controlnet是相当于一种用来微调预训练stable-diffusion的一种插件,所以首先需要下载stable-diffusion的相关权重。本文主要基于v1.5,相对于而言v1.5目前较为稳定且泛化性较好

1、打开网址stable-diffusion-v1-5基础模型权重(网址有可能失效,自行搜索例如stable-diffusion-v1-5关键词也可以找到模型的)(需要科学上网

2、切换到Files选项卡
在这里插入图片描述
3、有很多文件,其中有一些权重文件很大,但是并不需要全部下载省事也可以全部下载,但是因为科学上网可能比较慢,看个人需求

在这里插入图片描述

以下是我的下载的的内容,所有文件夹都要有,除了文件夹内容,以外的就只需要下载model_index.json即可

在这里插入图片描述

全部下载的话就不用看这里了

每个文件夹的内容都点进去下载内容,对于每个文件夹中有些文件夹里边有safetensors后缀的文件,还很大,**这些都不需要下载其他(.bin.txt.json)都需要下载

在这里插入图片描述

整体文件组织不要改变,最终是下边这样的
在这里插入图片描述

controlnet代码下载

本文主要是基于diffuser的开源代码进行测试

1、打开链接diffusers官方整理的controlnet源码(需要科学上网

在这里插入图片描述
在这里插入图片描述

2、然后下载目录的代码,这里下载的话可以用git下载,我就选择省事点直接下载。因为github对于这种子目录没有直接下载压缩包的方式。这里可以使用第三方网站:DownGit。把仓库地址放里边,然后点击Download就会自动下载了。然后另存到自己找得到的地方

在这里插入图片描述

在这里插入图片描述
解压以后就是跟github上是一样的

在这里插入图片描述

Pycharm运行和调试远程服务器代码

pycharm加载源码

我是使用本地的Liunx服务器进行训练代码的,通过pycharm可以方便调试

1、直接通过pycharm打开下载的代码目录,一般安装pycharm后右键下载的代码文件夹目录就可以打开了
在这里插入图片描述

或者直接通过pycharm打开项目,选中目录即可(名为controlnet目录,这个名称跟刚才的压缩包名称有关,也就是存放刚才解压出来的代码的父目录)
在这里插入图片描述

2、打开后整体项目如下(controlnet_diffusers这个是我代码的父目录,名称我修改过,正常解压出来的文件夹应该是controlnet
在这里插入图片描述
这里关键的文件是train_controlnet.py

服务器配置

主要是配置运行代码的远程服务器,因为我运行代码主机跟训练代码的服务器并不是同一个,所以需要配置,但是主机跟训练代码的服务器是同一个也是类似的配置方式。我都是在本地的,也就是我运行代码的主机跟服务器都是在同一个局域网的,同理如果是互联网上的服务器也是一样操作

1、点击Tools——Deloyments——Configuration
在这里插入图片描述
打开界面如下

在这里插入图片描述
点击左上角的加号并选择SFTP
在这里插入图片描述
然后输入一个名字,随便都行,不冲突即可

在这里插入图片描述
在这里插入图片描述
2、配置SSH,这里就是配置服务器,我是需要登录远程服务器的,这里就需要进行配置。如果之前配置过,可以通过下拉的箭头来选择之前的
在这里插入图片描述
如果没有就需要添加,主要是点击箭头右边的三个点在这里插入图片描述

然后点击左上角的加号添加一个远程服务。点击以后会出现右边的配置框。按照要求填写即可。Host就是填入服务器的IP地址,Port就是服务器配置的ssh服务的端口。username就是填入自己登录服务器的用户名。Local port一般不需要管,默认即可。Password就是的密码,这里建议勾选save password,不需要每次都输入
在这里插入图片描述
搞定以后可以点击Test connection来测试连接情况在这里插入图片描述

解析器配置

上边就是配置好了服务器,就是用来运行代码的机器配置好了,并且代码也准备好,接下来就是配置python解释器,也就是实际运行代码的工具

1、点击File——setting在这里插入图片描述

然后找到如下选项,右边初始是没有选择的(下图是我已经配置好的
在这里插入图片描述

python interpreter就是解析器配置的地方,如果之前配置了,点击下拉箭头就可以选择之前的

在这里插入图片描述
如果没有就点击右边的齿轮,然后Add进行添加
在这里插入图片描述

然后选择SSH Interpreiter,然后右边选择Existing server configuration,直接箭头下拉选择刚才配置的远程服务器的地址即可
在这里插入图片描述
然后点击next
在这里插入图片描述

Interpreter中填入一个python解释器。我这里是直接通过miniconda3在远程服务器创建了环境,并且指定了环境使用的是python3.9的,箭头选中的就是我创建的环境里边的python解析器。这里可以自行搜索minidaconda3安装教程,基本一键安装
在这里插入图片描述

Sync folders则是填入远程服务器的目录的地址(也可以暂时不处理),这里就是需要在远程服务器创建文件夹了。这里工作原理简单来说就是,本地我们有一套代码,也就是我们pycharm打开的那个目录,然后远程也需要有一套代码,在远程服务器运行的时候,实际上运行的是远程服务器的代码,而不是本地的代码,这里就需要我们本地更改完文件后上传更改的文件到远程的服务器。这里pycharm就提供了自动上传还有手动修改上传的功能

在这里插入图片描述

在远程服务器上随便创建一个的文件夹,自己能够找得到就行了,然后把完整的地址填入Sync folders即可(以下是我已经同步完毕的目录)

在这里插入图片描述

点击确认后是以下的状态,其中Path Mappings表示路径映射,也就是把本地的代码的文件夹关联到远程的文件夹
在这里插入图片描述
上述是直接在配置解释器的时候就配置了远程工作文件夹的,如果是第二次直接选择解析器的时候,首次打开项目是没有目录映射的,如下图,那就需要我们自己添加了
在这里插入图片描述
点击Path Mappings右边的文件夹图标,然后再点击加号添加映射在这里插入图片描述

左边就选择当前本地的代码目录,就是pycharm打开的那个代码的父目录,右边就是选择远程的代码即可,然后确认
在这里插入图片描述
在这里插入图片描述
然后就需要将本地代码同步到远程的目录,勾选以下配置一般会自动上传在这里插入图片描述
在这里插入图片描述
首次整个是不会自动上传,或者后续出现不会上传的情况,那么可以手动上传。通过右键项目,然后点击如下选项来上传项目。对于单个文件也是一样,直接右键,然后点击Deployment,紧接着选择upload to,其中upload to后边接的名称就是刚才创建ssh连接的时候自己起的名字

在这里插入图片描述

在这里插入图片描述
一般来说是可能有多个选项的,也就是点击upload to会出现多个,选择你要运行的代码服务器即可。或者将运行代码的服务器设置为当前环境的服务器,如下就是选中相应的服务然后再点击那个就可以了,选中状态是加粗的。
在这里插入图片描述
在这里插入图片描述
到了这里,应该是远程服务器也存在跟本地一模一样的代码在这里插入图片描述

fill50k数据集加载

代码准备好了,环境也准备好了,服务器上代码也有了,接下来就差数据集了,本文主要基于官方的fill50k数据集进行测试

fill50k数据集简单介绍

fill50k数据集就是包含三个部分,一是作为条件的图像(如下conditioning_image),二是关于条件图像的描述(如下text),三是根据条件图像以及相应的描述生成的最终效果图(如下image)。最终实现效果就是给定我们自己的条件图像以及描述,然后生成相应的效果图,其中描述的是根据轮廓生成一个带颜色的圆,并且这个圆的背景颜色也是我们通过提示指定的,如第二个例子light coral circle with white background,指定了生成圆的颜色是light coral,然后背景颜色是white
在这里插入图片描述

数据集下载

1、打开链接fill50k数据集,如下图,切换到Files and versions就是相应的文件了,把里边的下载完就行(.git...git相关的属性文件,可以不下载)。下载完毕后把压缩包的都解压出来

在这里插入图片描述
处理完毕后,如下,其中conditoning_images是参考图像,images是目标文件,就是通过参考图像以及相应描述生成的目标

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
二者并不一定是需要名字相同来关联的,二者管理是通过test.jsonl文件关联的在这里插入图片描述

这里单独拿一行出来,如下是一系列键值对,也就是字典,其中text表示的是描述,conditioning_imageimage相应的图片目录。可以看到这两个目录都是只有一部分的,例如conditioning_images/15.png,并且conditioning_images就是下载的文件夹名称。说明后续是需要手动设置一个根数据目录的路径,而这个目录就是放刚才下载的数据的,所以组织数据的时候,为了能够直接用上它的代码,那么也应该类似组织同时在test.jsonl添加一些列映射

{"text": "light golden rod yellow circle with rosy brown background", "image": "images/15.png", "conditioning_image": "conditioning_images/15.png"}

2、把除了fill50k.py文件以外的文件都上传到服务器上,这里就可以自行选择一个找得到的位置存放即可

在这里插入图片描述

数据集加载

现在数据集已经有了,就要处理一下数据集,这里主要是处理数据集加载,因为数据预处理等等操作,官方源码都写好了

1、打开train_contrlnet.py,搜索load_dataset,这是diffusers封装的加载数据集的方法(这里注意的是,它是依赖datasets库的,运行报错缺少模块的时候自行安装即可),它可以通过多种格式进行加载数据的。如csv格式。但是我们这里是需要自定义的,这个就是fill50k.py文件的作用

在这里插入图片描述
2、使用pycharm打开fill50k.py,里边定义了配合load_dataset方法的加载器,就是用来自定义数据加载方式的。这里只需要关注以下方法,metadata_path填入train.jsonl的完整路径,images_dir填入目标图像目录的上级目录(就是images目录的父目录),同理conditioning_images_dir也填入参考图像目录的上级目录就是conditioning_images目录的父目录)

在这里插入图片描述
在这里插入图片描述
为什么是这样的呢,主要跟fill50k.py加载数据的代码有关,如下row["image"],这里其实就是获取的就是train.jsonl的一行数据,也就是那个字典。如下拿到的就是images/15.png,那要拿到图片,就需要知道完整路径,如下图第二个箭头就是拼接的路径,其中images_dir就是刚才拼接的放图片的目录的路径,那么我这里完整路径就变成了/media/data_7T/cxj/datasets/fill50k/train/images/15.png如果想要自定义键值还有路径的规则,就可以改这几个地方

{"text": "light golden rod yellow circle with rosy brown background", "image": "images/15.png", "conditioning_image": "conditioning_images/15.png"}

在这里插入图片描述
3、改好以后就可以上传到服务器了,这里补充一下是它原本是有路径的,但是这个路径是Hugging Face上的,也就是刚才下载数据集的网站的。如果你的服务器有网络并且也能够连接外网,那么会自动下载的。如果不行的话就是像现在这样自己下载然后改路径

在这里插入图片描述
上传完毕如下,注意这里其实并不一定放数据集一起,随便放一个能找到的地方即可

在这里插入图片描述

本地加载数据集的注意事项

上述fill50k.py是配合load_dataset方法使用的,但是本地使用会有个报错,就是数据信任的问题,这里需要在load_dataset里边加上 trust_remote_code=True

在这里插入图片描述

参数配置和参数说明

上边基本是完整了,现在存在一个问题是fill50k.py相当于load_dataset的加载器,那么怎么关联呢?这里就是需要通过参数来配置了。这个部分就是说明整个训练的参数配置

parser简单说明

1、相关参数在代码中,如下图,在add_arguments里边的就是一个配置参数

在这里插入图片描述

简单说明下,以下是键名为--pretrained_model_name_or_path的参数配置,type指定了它的类型,default指定了它的默认值,required=True表示这个参数是必须的,help就是对这个参数的说明。

在这里插入图片描述
一般使用的时候就是--pretrained_model_name_or_path="/media/data_7T/cxj/pkg/diffusion",这就是相当于一个键值对了,键是pretrained_model_name_or_path,值是/media/data_7T/cxj/pkg/diffusion"。用的时候就是通过.去访问
在这里插入图片描述
args就是解析器对象,刚才add_arguments就相当于往这个对象里边放东西

在这里插入图片描述

参数说明

对于参数部分,可以通过刚才的help字段也可以大概看出具体的作用,这里就给出一些相对有用的

--pretrained_model_name_or_path="/media/data_7T/cxj/pkg/diffusion"
--output_dir="/media/data_7T/cxj/results/controlnet_diffusers"
--dataset_name="/media/data_7T/cxj/datasets/fill50k/train/fill50k.py"
--resolution=512
--learning_rate=1e-5
--validation_image
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_1.png"
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_2.png"
--validation_prompt
"red circle with blue background"
"cyan circle with brown floral background"
--train_batch_size=4
--num_train_epochs=5
--checkpointing_steps=1000
--cache_dir="/media/data_7T/cxj/datasets/fill50k/cache"
--validation_steps=1000
  • --pretrained_model_name_or_path:预训练stable diffusion的权重的路径,前边部分是下载了这个权重文件,那么在服务器上开一个文件夹,然后把这些文件都放里边,这里设置的就是放这些权重文件的目录在这里插入图片描述

  • --output_dir:这个就是最终输出结果存放的目录,就是模型训练的时候,可以根据你的参数保存权重的,这个权重就相当于我们训练好的模型的知识,后续需要拿出来测试的。随便一个能找到都行
    在这里插入图片描述

  • --dataset_name:这里就是用来关联fill50k.pyload_dataset的,除此以外还有个--cache_dir文件夹,这个其实原本是用来缓存从Hugging Face下载的数据的,但是这里是用自己的,可以随便设一个。--train_data_dir是不需要管了,因为用的load_datsetfill50k.py封装完了
    在这里插入图片描述

  • --resolution=512--learning_rate=1e-5保持默认即可,一个是图片的大小,就是图片输入以后会被处理成512x512,第二个参数是学习率,不用管

  • --validation_image--validation_prompt是用来测试的图像,就是训练的时候,不一定是训练完全才测试,而是隔一定的时间就可以测试了,这里就是设置测试的图像按照上边的格式,放图片完整路径和对应描述就可以了,每个例子通过空格或者换行隔开都行,描述和图片是一一对应的。那么测试的间隔就是通过--validation_steps设置的

  • --train_batch_size是训练的批次大小,这里设置是4,意思就是模型训练一次(对应上边参数里边的step)就是用了4张图片,那数据集有50000张图片,50000/4=12500,就是这50000张图片要是训练完,模型就跑了12500步,--validation_steps=1000就表示模型在跑每1000步的倍数的时候就测试一下,测试也是有内容的,下边会说到

  • --num_train_epochs表示训练的总周期,1epoch就是跑完了50000张图像(整个数据集),也就是模型训练了12500次,这里--num_train_epochs=5,那么整个训练过程的训练次数就是12500x5
    - --checkpointing_steps=1000就是每几次保存一次权重了如,如下是官方代码已经写好的输出格式,其中带的数字就是输出的训练次数的权重了,不同训练次数的模型效果是不一样的

在这里插入图片描述

代码运行

运行配置

1、代码运行可以有多种方式,这里有工具就用工具了,打开pycharm,一般来说没运行过右上角是Add...字样的,这个时候只需要右键train_controlnet.py,然后点击run或者运行,然后再点红色的方块停下代码运行可以看到下图的效果了

在这里插入图片描述

在这里插入图片描述2、出现了上图效果后,点击下箭头,然后Edit configuration
在这里插入图片描述
3、弹出的是下边的配置框,然后再点击以下右边的参数配置
在这里插入图片描述
弹出的是一个编辑框,然后把上边的配置参数粘贴进去即可(可以根据情况修改),最后就是点击下边的Apply然后Ok确认就行了

在这里插入图片描述

--pretrained_model_name_or_path="/media/data_7T/cxj/pkg/diffusion"
--output_dir="/media/data_7T/cxj/results/controlnet_diffusers"
--dataset_name="/media/data_7T/cxj/datasets/fill50k/train/fill50k.py"
--resolution=512
--learning_rate=1e-5
--validation_image
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_1.png"
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_2.png"
--validation_prompt
"red circle with blue background"
"cyan circle with brown floral background"
--train_batch_size=4
--num_train_epochs=5
--checkpointing_steps=1000
--cache_dir="/media/data_7T/cxj/datasets/fill50k/cache"
--validation_steps=1000
官方参数设置

看到这里应该是基本了解配置是怎么回事了,对于运行过程中不同参数配置,它是依赖于不同环境的,不同环境,例如显存不同,是有不同的启动参数的,这里可以参考官方的文档

在这里插入图片描述

启动效果

那么现在参数配置好,基本就可以了。直接点击pycharm的绿色按钮直接运行就可以了,或者右键train_controlnet.py然后点运行也可以。下边就是跑起来的效果,其中有个进度条,表示训练多少次还有训练的总次数

在这里插入图片描述上述是通过工具快捷配置的,如果不用工具,在命令行的话就是多加点命令,如下,事实上pycharm也是帮我们构建如下完整的命令的

python train_controlnet.py
--pretrained_model_name_or_path="/media/data_7T/cxj/pkg/diffusion"
--output_dir="/media/data_7T/cxj/results/controlnet_diffusers"
--dataset_name="/media/data_7T/cxj/datasets/fill50k/train/fill50k.py"
--resolution=512
--learning_rate=1e-5
--validation_image
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_1.png"
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_2.png"
--validation_prompt
"red circle with blue background"
"cyan circle with brown floral background"
--train_batch_size=4
--num_train_epochs=5
--checkpointing_steps=1000
--cache_dir="/media/data_7T/cxj/datasets/fill50k/cache"
--validation_steps=1000

测试过程

查看验证效果的两种方法(本地和远程查看)

默认是通过tensorboard输出结果的,点击自己设置的输出文件夹,也就是上边output_dir对应的路径,然后一直往里点,可以找打l日志文件

在这里插入图片描述

在这里插入图片描述
首先查看的一种方法是直接下载到本地查看,首先在本地主机随便创建一个文件夹在这里插入图片描述

然后将远程的日志文件传输到该文件夹
在这里插入图片描述
然后在本地的存放权重的文件夹打开命令行(我这里是log_diffusers文件夹,具体看你放哪里了),按住shift然后对着文件夹右键打开powershell。打开的初始路径就是我创建的文件夹的路径。这里注意powershell里边需要安装tensorboard,因为我本地环境也安装了miniconda3,所以看起来就是下边这样

在这里插入图片描述
然后通过一下命令tensorboard --logdir=.,这里的.就表示当前目录,就是当前目录有刚才转移过来的日志文件。如果是路径的,也可以是tensorboard --logdir=/media/logs_diffusers这样的绝对路径,注意路径不要加双引号,**Windows上要注意正斜杠的转义,也就是用E:\\logs_diffusers而不是E:\logs_diffusers**
在这里插入图片描述

开启以后就下边这样,然后按照提示,浏览器输入网址即可

在这里插入图片描述
在这里插入图片描述

如果是远程查看的话,就可以配置ssh,更加方便,这里有教程。我这就不展开了

测试代码

以下是测试代码,直接创建一个py文件并通过pycharm上传到远程服务器然后运行即可。其中需要关注的参数是

  • base_sd_model_path :这个就是上边下载的stable diffusion v1.5的权重
  • checkpoint_step 是加载的模型,就是训练过程中可以通过--checkpointing_steps参数配置保存不同的阶段的模型,这里就是指要加载哪个阶段的
  • base_controlnet_output_dir :这里就是上边训练的时候对应的out_dir,也就是训练结果的存储目录
  • 下方images_pathprompts,可以填写多张参考图像和提示进行测试,根据个人需求即可
  • grid_output_path就是输出的结果的路径,我处理结果成一张网格图片来的,自动根据自己的需求修改即可
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from torchvision import transforms



# 设置模型路径
checkpoint_step = 10000 # 要加载哪个step的模型
base_sd_model_path = "/media/data_7T/cxj/pkg/diffusion" # 基本的sdv 1-5的权重的路径
base_controlnet_output_dir = "/media/data_7T/cxj/results/controlnet_diffusers" # controlnet的路径
grid_output_path = f"/media/data_7T/cxj/results/controlnet_diffusers/test_images/grid_output_{checkpoint_step}.png"

controlnet_path = f"{base_controlnet_output_dir}/checkpoint-{checkpoint_step}/controlnet"

# 加载模型
controlnet = ControlNetModel.from_pretrained(controlnet_path)
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_sd_model_path, controlnet=controlnet)

# 优化扩散过程
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# 设置安全检查器
pipe.safety_checker = lambda images, clip_input: (images, None)

# 示例:多张图片的路径
images_path = [
    "/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_1.png",
    "/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_2.png",
    "/media/data_7T/cxj/datasets/fill50k/valid/0.png"
]

# 示例:对应每张图片的提示语
prompts = [
    "red circle with blue background",
    "silver circle with powder blue background",
    "pale golden rod circle with old lace background"
]

# 创建一个n x 3的网格,n是图像的数量
n = len(images_path)

# 创建一个n行3列的子图网格
fig, axes = plt.subplots(n, 3, figsize=(15, 5 * n))

# 确保axes是一个二维数组,即使n=1
if n == 1:
    axes = [axes]

image_transforms = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# 遍历每张源图片,生成对应的图像并显示
for i, (image_path, prompt) in enumerate(zip(images_path, prompts)):
    # 加载控制图像
    control_image = load_image(image_path)

    # 将模型移动到指定的GPU
    pipe.to("cuda:3")

    # 设置生成器的随机种子
    generator = torch.manual_seed(0)

    # 生成图像
    generated_image = pipe(
        prompt, num_inference_steps=40, generator=generator, image=control_image
    ).images[0]

    # 在第i行填充网格中的三个格子
    # 第一个格子:控制图像
    axes[i][0].imshow(control_image)
    axes[i][0].axis('off')
    axes[i][0].set_title("Control Image")

    # 第二个格子:生成的图像
    axes[i][1].imshow(generated_image)
    axes[i][1].axis('off')
    axes[i][1].set_title("Generated Image")

    # 第三个格子:显示prompt文本
    axes[i][2].text(0.5, 0.5, prompt, fontsize=12, ha='center', va='center', wrap=True)
    axes[i][2].axis('off')
    axes[i][2].set_title("Prompt")

# 调整布局,确保内容不重叠
plt.tight_layout()

# 保存网格为本地文件
plt.savefig(grid_output_path, bbox_inches='tight')
print(f"输出路径是:{grid_output_path}")
测试结果

首先我使用的参数是上文提到的参数,只不过我设置的batch_size8,需要的显存大概是26G,根据个人情况设置即可,batch_size小的话,训练轮次可以稍微大点

--pretrained_model_name_or_path="/media/data_7T/cxj/pkg/diffusion"
--output_dir="/media/data_7T/cxj/results/controlnet_diffusers"
--dataset_name="/media/data_7T/cxj/datasets/fill50k/train/fill50k.py"
--resolution=512
--learning_rate=1e-5
--validation_image
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_1.png"
"/media/data_7T/cxj/datasets/fill50k/valid/conditioning_image_2.png"
--validation_prompt
"red circle with blue background"
"cyan circle with brown floral background"
--train_batch_size=8
--num_train_epochs=5
--checkpointing_steps=1000
--cache_dir="/media/data_7T/cxj/datasets/fill50k/cache"
--validation_steps=1000
  • 训练次数是1000的效果,差异有点大

在这里插入图片描述

  • 训练次数是5000的效果,差异还是很大在这里插入图片描述

  • 训练次数是6000的效果,颜色相对而言要好一点了
    在这里插入图片描述

  • 训练次数是8000的效果,跟6000差不了多少
    在这里插入图片描述

  • 训练次数是10000的效果,效果好很多,基本是有效了
    在这里插入图片描述


http://www.kler.cn/a/447857.html

相关文章:

  • jvm栈帧中的动态链接
  • 车载网关性能 --- 缓存buffer划分要求
  • 摩尔信使MThings的逻辑控制功能范例
  • 构建高性能异步任务引擎:FastAPI + Celery + Redis
  • Pytorch | 从零构建Vgg对CIFAR10进行分类
  • linux-----常用指令
  • 数据结构与算法再探(三)树
  • 本地电脑使用命令行上传文件至远程服务器
  • Pydantic 2.0 完整指南
  • k8s 创建密钥以及证书安装
  • Jackson 的@JsonRawValue
  • Python 自带的日期日历处理大师:calendar 库
  • Paimon 是什么?Apache Paimon简介
  • 项目2路由交换
  • 米思齐图形化编程之ESP32开发指导
  • PostgreSQL表达式的类型
  • 晶闸管-直流电动机调速系统设计【MATLAB源码+Word文档】
  • 【系统移植】NFS服务器环境搭建——挂载根文件系统
  • Linux网络——网络套接字
  • java小知识点:比较器
  • 使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
  • 机器学习(四)-回归模型评估指标
  • 【LeetCode】906、超级回文数
  • vue入门教程:组件透传 Attributes
  • c++领域展开第四幕——类和对象(上篇收尾 this指针、c++和c语言的初步对比)超详细!!!!
  • 如何使用PSQL Tool还原pg数据库(sql格式)