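Fine-Tune FLUX.1 LoRA with Your Own Images and Deploy It Using Azure Machine Learning
Table of contents
Introduction
Understanding the FLUX.1 model family
What is Dreambooth?
Prerequisites
Steps to fine-tune FLUX with Dreambooth
Step 1: Set up the environment
Step 2: Load the libraries
Step 3: Prepare the dataset
3.1 Upload the images to a datastore through an AML data asset (URI folder)
Step 4: Create the training environment
Step 5: Create the compute
Step 6: Submit the fine-tuning job
Step 7: Download and register the fine-tuned model
Step 8: Deploy to a managed online endpoint
Step 9: Create the inference environment for the online endpoint
Step 10: Create the deployment for the managed online endpoint
Step 11: Test the deployment
Conclusion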
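Introduction
The field of AI and machine learning continues to evolve rapidly, and generative AI models have made significant progress. One notable advance is the FLUX.1 suite of models from Black Forest Labs. These models push the boundaries of text-to-image synthesis with strong image detail, prompt adherence, and style diversity. In this blog, we walk through fine-tuning a FLUX model with Dreambooth, a method that has become popular because it produces high-quality, customized AI-generated content from a small set of images.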
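Understanding the FLUX.1 model family
Black Forest Labs has released three variants of the FLUX.1 model:
- FLUX.1 [pro]: The top-tier offering with the strongest image generation quality, available through the Black Forest Labs API rather than as downloadable weights.
- FLUX.1 [dev]: An open-weight, guidance-distilled model for non-commercial use that offers efficient performance.
- FLUX.1 [schnell]: Designed for local development and personal use, and released under the Apache 2.0 license.
For more information, see the official announcement from Black Forest Labs.
These models are built on a hybrid architecture of multimodal and parallel diffusion transformer blocks, scaled to 12 billion parameters. According to Black Forest Labs, they deliver state-of-the-art performance and surpass other leading models.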
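To get a feel for what 12 billion parameters means for hardware, the weights-only memory footprint can be estimated with simple arithmetic. The sketch below is an illustration, not a measurement: the helper name is hypothetical, it counts only the transformer weights (no text encoders, VAE, activations, or optimizer state), and it assumes 1 GB = 10^9 bytes.

```python
def estimate_weight_memory_gb(num_params: int, bytes_per_param: int) -> float:
    """Return the approximate size of the raw weights in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


FLUX_PARAMS = 12_000_000_000  # "12 billion parameters", as stated above

# bf16 stores each parameter in 2 bytes, fp32 in 4 bytes.
bf16_gb = estimate_weight_memory_gb(FLUX_PARAMS, 2)
fp32_gb = estimate_weight_memory_gb(FLUX_PARAMS, 4)
print(f"bf16 weights: ~{bf16_gb:.0f} GB, fp32 weights: ~{fp32_gb:.0f} GB")
```

Roughly 24 GB for the bf16 weights alone is one reason this walkthrough trains only small LoRA adapters and uses a GPU with plenty of memory.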
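What is Dreambooth?
Dreambooth is a technique for fine-tuning a generative model on a small dataset so that it produces highly customized output. It builds on the existing capabilities of a pretrained model and extends them with the specific details, styles, or subjects found in the fine-tuning dataset. This makes it particularly useful for applications that need personalized content generation.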
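Prerequisites
Before fine-tuning the FLUX.1 [schnell] model with Dreambooth, make sure you have the following:
- Access to the FLUX.1 [schnell] model, which is available on Hugging Face.
- A dataset containing the images, and the corresponding text descriptions, to use for fine-tuning.
- A compute environment with enough resources (for example, a GPU) to handle the training process.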
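Steps to fine-tune FLUX with Dreambooth
In this blog, we use Azure Machine Learning to fine-tune a text-to-image model so that it generates pictures of a specific dog from text input.
Before you start, make sure you have the following:
- An Azure account with access to Azure Machine Learning.
- A basic understanding of Python and Jupyter notebooks.
- Familiarity with the Hugging Face Diffusers library.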
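Step 1: Set up the environment
First, set up your environment by installing the required libraries, including the Azure ML SDK and MLflow packages imported in the next step. The PyTorch wheel index below targets CUDA 12.1 because the FLUX code in Diffusers needs PyTorch 2.x; adjust it to match your driver. You can use the following commands:
<span style="color:#3e3e3e"><span style="background-color:#f5f5f5"><code class="language-bash">pip install transformers diffusers accelerate
pip install azure-ai-ml azure-identity mlflow azureml-mlflow matplotlib
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
</code></span></span>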
Step 2: Load the libraries
Import the libraries used throughout this walkthrough.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">import sys
sys.path.insert(0, '..')
import os
import shutil
import random
from azure.ai.ml import automl, Input, Output, MLClient, command, load_job
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml.entities import Data, Environment, AmlCompute
from azure.ai.ml.constants import AssetTypes
from azure.core.exceptions import ResourceNotFoundError
import matplotlib.pyplot as plt
import mlflow
from mlflow.tracking.client import MlflowClient</code></span></span>
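Before diving into the code, you need to connect to your workspace. The workspace is the top-level resource for Azure Machine Learning and provides a central place to work with all the artifacts you create when you use Azure Machine Learning.
We use DefaultAzureCredential to get access to the workspace. DefaultAzureCredential should handle most scenarios. To learn about the other available credentials, see the set up authentication documentation and the azure-identity reference documentation.
Replace AML_WORKSPACE_NAME, RESOURCE_GROUP, and SUBSCRIPTION_ID in the cell below with their respective values.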
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

credential = DefaultAzureCredential()
workspace_ml_client = None
try:
    # Read the workspace details from a local config.json, if one exists
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "SUBSCRIPTION_ID"
    resource_group = "RESOURCE_GROUP"
    workspace_name = "AML_WORKSPACE_NAME"
    workspace_ml_client = MLClient(
        credential, subscription_id, resource_group, workspace_name
    )

# Client for the shared "azureml" registry
registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name="azureml",
)</code></span></span>
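`MLClient.from_config` looks for a `config.json` file that describes the workspace; when the file is missing, the call raises and the placeholder values are used instead. A minimal sketch of writing that file with the standard library follows. The helper name and the placeholder values are assumptions for illustration; the three keys are the ones `config.json` is expected to contain. The demo writes to a temporary folder, whereas in practice you would write next to your notebook.

```python
import json
import tempfile
from pathlib import Path


def write_workspace_config(folder: str, subscription_id: str,
                           resource_group: str, workspace_name: str) -> Path:
    """Write a config.json with the workspace coordinates and return its path."""
    config = {
        "subscription_id": subscription_id,
        "resource_group": resource_group,
        "workspace_name": workspace_name,
    }
    path = Path(folder) / "config.json"
    path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return path


# Placeholder values: replace them with your own before connecting.
config_path = write_workspace_config(
    tempfile.mkdtemp(), "SUBSCRIPTION_ID", "RESOURCE_GROUP", "AML_WORKSPACE_NAME"
)
print(config_path.read_text(encoding="utf-8"))
```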
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">workspace = workspace_ml_client.workspace_name
subscription_id = workspace_ml_client.workspaces.get(workspace).id.split("/")[2]
resource_group = workspace_ml_client.workspaces.get(workspace).resource_group
local_train_data = './train-data/monu/'  # Azure ML dataset will be created for training on this content
generated_images = './results/monu'
azureml_dataset_name = 'monu'  # Name of the dataset
train_target = 'gpu-cluster-big'
experiment_name = 'dreambooth-finetuning'
training_env_name = 'dreambooth-flux-train-envn'
inference_env_name = 'flux-inference-envn'</code></span></span>
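Step 3: Prepare the dataset
Prepare the dataset by organizing your images and their descriptions. Make sure the data is in a format that is compatible with Dreambooth. Here is an example structure:
train-data/monu/
    image_1.jpg
    image_2.jpg
    ...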
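Before uploading, it can help to confirm that the folder really contains the images you expect. The following is a minimal sketch under stated assumptions: the helper name and the set of accepted file extensions are hypothetical, and the demo builds a throwaway folder instead of reading your real `train-data/monu/` directory.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}  # assumed set of formats


def list_training_images(folder: str) -> list[Path]:
    """Return the image files directly inside `folder`, sorted by path."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Training folder not found: {root}")
    return sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


# Demo on a temporary folder that mimics the structure above.
demo_dir = Path(tempfile.mkdtemp()) / "train-data" / "monu"
demo_dir.mkdir(parents=True)
for name in ("image_2.jpg", "image_1.jpg", "notes.txt"):
    (demo_dir / name).write_bytes(b"")

images = list_training_images(str(demo_dir))
print([p.name for p in images])  # the .txt file is ignored
```

In the notebook you would call `list_training_images(local_train_data)` and check that the count matches the number of photos you collected.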
3.1 Upload the images to a datastore through an AML data asset (URI folder)
To use the data for training in Azure ML, we upload it to the default Azure Blob Storage of the Azure ML workspace.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"># Register dataset
my_data = Data(
    path=local_train_data,
    type=AssetTypes.URI_FOLDER,
    description="Training images for Dreambooth finetuning",
    name=azureml_dataset_name,
)
workspace_ml_client.data.create_or_update(my_data)</code></span></span>
Step 4: Create the training environment
We need a dreambooth-conda.yaml file to create our custom environment.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-yaml">name: dreambooth-flux-env
channels:
  - conda-forge
dependencies:
  - python=3.10
  - pip
  - pip:
      - 'git+https://github.com/huggingface/diffusers.git'
      - azureml-acft-accelerator==0.0.59
      - azureml_acft_common_components==0.0.59
      - azureml-acft-contrib-hf-nlp==0.0.59
      - azureml-evaluate-mlflow==0.0.59
      - azureml-metrics[text]==0.0.59
      - mltable==1.6.1
      - mpi4py==3.1.5
      - sentencepiece==0.1.99
      - transformers==4.44.0
      - datasets==2.17.1
      - optimum==1.17.1
      - accelerate>=0.31.0
      - onnxruntime==1.17.3
      - rouge-score==0.1.2
      - sacrebleu==2.4.0
      - bitsandbytes==0.43.3
      - einops==0.7.0
      - aiohttp==3.10.5
      - peft==0.8.2
      - deepspeed==0.15.0
      - trl==0.8.1
      - tiktoken==0.6.0
      - scipy==1.14.0
</code></span></span>
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">environment = Environment(
    image="mcr.microsoft.com/azureml/curated/acft-hf-nlp-gpu:67",
    conda_file="environment/dreambooth-conda.yaml",
    name=training_env_name,
    description="Dreambooth training environment",
)
workspace_ml_client.environments.create_or_update(environment)</code></span></span>
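A dependency list this long is easy to get subtly wrong, for example by pinning the same package twice with different specifiers. The sketch below shows one way to catch that before building the environment. Both helper names are hypothetical, the sample list is illustrative, and the name parsing is a simplification of pip's requirement syntax rather than a full parser.

```python
import re
from collections import Counter


def package_name(requirement: str) -> str:
    """Extract a normalized package name from a pip requirement line."""
    # Cut at the first version operator, extras bracket, space, or marker.
    name = re.split(r"[<>=!~\[ ;]", requirement.strip(), maxsplit=1)[0]
    # Treat '-', '_' and '.' as equivalent, as pip does for project names.
    return re.sub(r"[-_.]+", "-", name).lower()


def find_duplicates(requirements: list[str]) -> list[str]:
    """Return package names that occur more than once, in sorted order."""
    counts = Counter(package_name(r) for r in requirements if r.strip())
    return sorted(name for name, n in counts.items() if n > 1)


sample = [
    "transformers>=4.41.2",
    "accelerate>=0.31.0",
    "transformers==4.44.0",
    "azureml-metrics[text]==0.0.59",
]
print(find_duplicates(sample))
```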
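Step 5: Create the compute
To fine-tune the model in Azure Machine Learning, you first need to create a compute resource. Creating the compute takes about 3-4 minutes.
For more details, see the Azure Machine Learning in a day tutorial.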
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">try:
    # Reuse the cluster if it already exists
    _ = workspace_ml_client.compute.get(train_target)
    print("Found existing compute target.")
except ResourceNotFoundError:
    print("Creating a new compute target...")
    compute_config = AmlCompute(
        name=train_target,
        type="amlcompute",
        size="Standard_NC24ads_A100_v4",  # 1 x A100, 80 GB GPU memory each
        tier="low_priority",
        idle_time_before_scale_down=600,  # seconds
        min_instances=0,
        max_instances=2,
    )
    # Wait for provisioning to finish before submitting jobs
    workspace_ml_client.begin_create_or_update(compute_config).result()</code></span></span>
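Step 6: Submit the fine-tuning job
We use the black-forest-labs/FLUX.1-schnell model in this notebook. By the end of this guide, you will have fine-tuned a text-to-image model with Diffusers and Dreambooth on Azure, and the resulting model should generate high-quality dog images from text descriptions. Feel free to experiment with different prompts and fine-tuning parameters to explore what the model can do.
First, let's build the command string.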
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">command_str = '''python prepare.py && accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-schnell" \
  --instance_data_dir=${{inputs.input_data}} \
  --output_dir="outputs/models" \
  --mixed_precision="bf16" \
  --instance_prompt="photo of sks dog" \
  --class_prompt="photo of a dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2500 \
  --seed="0"'''</code></span></span>
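The batch-related flags in this command interact: with gradient accumulation, each optimizer update averages gradients over several small batches. The arithmetic can be sketched as follows. The helper name is hypothetical, and `num_processes=1` assumes the job runs on a single GPU.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_processes: int = 1) -> int:
    """Number of images that contribute to each optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the command above; num_processes=1 assumes a single GPU.
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
max_train_steps = 2500
images_seen = batch * max_train_steps
print(f"effective batch size: {batch}, images processed: {images_seen}")
```

With only a handful of training photos, 10,000 processed images means each photo is revisited many times, so `--max_train_steps` is a natural parameter to lower if the outputs start to look overfitted.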
As you can see, the command above needs two files: train_dreambooth_lora_flux.py and prepare.py. You can download train_dreambooth_lora_flux.py from the official Diffusers repository.
Here is the code for prepare.py:
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">import os
# Note: this setting applies to this Python process only; it is not
# inherited by the separate `accelerate launch` process that runs next.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:100"
from accelerate.utils import write_basic_config
# Write a default Accelerate config file so `accelerate launch` can run
write_basic_config()</code></span></span>
Your folder structure should look like this:
src/
    prepare.py
    train_dreambooth_lora_flux.py
Now let's initialize some variables.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"># Retrieve latest version of dataset
latest_version = [
    dataset.latest_version
    for dataset in workspace_ml_client.data.list()
    if dataset.name == azureml_dataset_name
][0]
dataset_asset = workspace_ml_client.data.get(name=azureml_dataset_name, version=latest_version)
print(f'Latest version of {azureml_dataset_name}: {latest_version}')
inputs = {"input_data": Input(type=AssetTypes.URI_FOLDER, path=f'azureml:{azureml_dataset_name}:{latest_version}')}
outputs = {"output_dir": Output(type=AssetTypes.URI_FOLDER)}</code></span></span>
Now we submit a job that combines the code, the compute, and the environment created above.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">job = command(
    inputs=inputs,
    outputs=outputs,
    code="./src",
    command=command_str,
    environment=f"{training_env_name}@latest",  # "@latest" selects the newest version
    compute=train_target,
    experiment_name=experiment_name,
    display_name="flux-finetune-batchsize-1",
    environment_variables={'HF_TOKEN': 'Place Your HF Token Here'},
)
returned_job = workspace_ml_client.jobs.create_or_update(job)
returned_job</code></span></span>
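Submitting the job returns immediately, while the training itself runs for a while on the cluster. One simple way to block until it finishes is to poll the job status, as in the sketch below. The helper name, the polling interval, and the set of terminal status strings are assumptions; the only SDK call it relies on is `jobs.get(name)`, whose result exposes a `status` attribute.

```python
import time

TERMINAL_STATES = {"Completed", "Failed", "Canceled"}  # assumed terminal statuses


def wait_for_job(ml_client, job_name: str, poll_seconds: float = 30.0) -> str:
    """Poll `ml_client.jobs.get` until the job reaches a terminal status."""
    while True:
        status = ml_client.jobs.get(job_name).status
        print(f"Job {job_name}: {status}")
        if status in TERMINAL_STATES:
            return status
        time.sleep(poll_seconds)


# Intended use in this notebook (requires a live workspace):
# final_status = wait_for_job(workspace_ml_client, returned_job.name)
```

If you would rather follow the logs as they are produced, `workspace_ml_client.jobs.stream(returned_job.name)` is an alternative that also blocks until the job ends.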
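Step 7: Download and register the fine-tuned model
After fine-tuning, evaluate the model to make sure it meets your requirements.
We register the model from the output of the fine-tuning job. This tracks the lineage between the fine-tuned model and the fine-tuning job. The fine-tuning job, in turn, tracks the lineage to the base model, the data, and the training code.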
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"># Obtain the tracking URL from MLClient
MLFLOW_TRACKING_URI = workspace_ml_client.workspaces.get(
    name=workspace_ml_client.workspace_name
).mlflow_tracking_uri
# Set the MLFLOW TRACKING URI
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
# Initialize MLFlow client
mlflow_client = MlflowClient()
mlflow_run = mlflow_client.get_run(returned_job.name)
mlflow.artifacts.download_artifacts(
    run_id=mlflow_run.info.run_id,
    artifact_path="outputs/models/",  # Azure ML job output
    dst_path="./train-artifacts",  # local folder
)</code></span></span>
Now let's download the LoRA weights file.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">weights_path = "./train-artifacts/outputs/models/pytorch_lora_weights.safetensors"
# Remove a leftover directory with the same name so the file can be downloaded
if os.path.isdir(weights_path):
    shutil.rmtree(weights_path)
mlflow.artifacts.download_artifacts(
    run_id=mlflow_run.info.run_id,
    artifact_path="outputs/models/pytorch_lora_weights.safetensors",  # Azure ML job output
    dst_path="./train-artifacts",  # local folder
)</code></span></span>
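A quick sanity check on the downloaded file is to read its header, which lists every stored tensor without loading any weights. A `.safetensors` file starts with an 8-byte little-endian length, followed by a JSON header of that length. The sketch below relies only on that layout; both helper names are hypothetical, and the commented-out call assumes the download path used above.

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    """Return the JSON header of a .safetensors file without loading tensors."""
    with open(path, "rb") as f:
        # The file starts with an unsigned 64-bit little-endian header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len).decode("utf-8"))


def tensor_names(header: dict) -> list[str]:
    """List tensor names, skipping the optional "__metadata__" entry."""
    return sorted(k for k in header if k != "__metadata__")


# Intended use after the download above:
# header = read_safetensors_header(
#     "./train-artifacts/outputs/models/pytorch_lora_weights.safetensors")
# print(len(tensor_names(header)), "tensors, e.g.", tensor_names(header)[:3])
```

An empty list, or a file that fails to parse, usually points to an interrupted download or a training run that stopped before saving.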
Finally, let's register the model.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes

# Point the model at the LoRA weights stored in the job's outputs
run_model = Model(
    path=f"azureml://jobs/{returned_job.name}/outputs/artifacts/paths/outputs/models/pytorch_lora_weights.safetensors",
    name="mano-dreambooth-flux-finetuned",
    description="Model created from run.",
    type=AssetTypes.CUSTOM_MODEL,
)
model = workspace_ml_client.models.create_or_update(run_model)</code></span></span>
Step 8: Deploy to a managed online endpoint
Now let's deploy the fine-tuned model as a managed online endpoint on AML. First, we define a few constants to use later in the deployment.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">endpoint_name = 'flux-endpoint-finetuned-a100'
deployment_name = 'flux'
instance_type = 'Standard_NC24ads_A100_v4'
score_file = 'score.py'</code></span></span>
Let's create the managed online endpoint.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">from azure.ai.ml.entities import ManagedOnlineEndpoint

# create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=endpoint_name,
    description="this is the flux inference online endpoint",
    auth_mode="key",
)
# Wait until the endpoint exists before adding a deployment to it
workspace_ml_client.online_endpoints.begin_create_or_update(endpoint).result()</code></span></span>
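Endpoint names must be unique within an Azure region, so a fixed name like the one above can collide with an endpoint that already exists. A common workaround is to append a timestamp suffix, as in this sketch. The helper names are hypothetical, and the naming rule encoded in the regular expression (3 to 32 characters, starting with a letter, then letters, digits, or hyphens) is an assumption that should be checked against the current Azure ML documentation.

```python
import re
from datetime import datetime, timezone

# Assumed naming rule: 3-32 characters, starts with a letter,
# then letters, digits, or hyphens.
ENDPOINT_NAME_PATTERN = re.compile(r"[A-Za-z][A-Za-z0-9-]{2,31}")


def is_valid_endpoint_name(name: str) -> bool:
    """Check a candidate name against the assumed naming rule."""
    return ENDPOINT_NAME_PATTERN.fullmatch(name) is not None


def unique_endpoint_name(prefix: str) -> str:
    """Append a UTC timestamp suffix (mmddHHMMSS) to make collisions unlikely."""
    suffix = datetime.now(timezone.utc).strftime("%m%d%H%M%S")
    name = f"{prefix}-{suffix}"
    if not is_valid_endpoint_name(name):
        raise ValueError(f"Endpoint name is not valid: {name!r}")
    return name


print(is_valid_endpoint_name("flux-endpoint-finetuned-a100"))
print(unique_endpoint_name("flux-endpoint"))
```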
Step 9: Create the inference environment for the online endpoint
First, we create a Dockerfile that is used when building the environment.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-bash">FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu121-py310-torch22x:biweekly.202408.3
# Install pip dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt --no-cache-dir
# Inference requirements
COPY --from=mcr.microsoft.com/azureml/o16n-base/python-assets:20230419.v1 /artifacts /var/
RUN /var/requirements/install_system_requirements.sh && \
    cp /var/configuration/rsyslog.conf /etc/rsyslog.conf && \
    cp /var/configuration/nginx.conf /etc/nginx/sites-available/app && \
    ln -sf /etc/nginx/sites-available/app /etc/nginx/sites-enabled/app && \
    rm -f /etc/nginx/sites-enabled/default
ENV SVDIR=/var/runit
ENV WORKER_TIMEOUT=400
EXPOSE 5001 8883 8888
# support Deepspeed launcher requirement of passwordless ssh login
RUN apt-get update && apt-get install -y openssh-server openssh-client
</code></span></span>
The requirements.txt file used by this Dockerfile is shown below.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript">azureml<span style="color:#00e0e0">-</span>core<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>dataset<span style="color:#00e0e0">-</span>runtime<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>defaults<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azure<span style="color:#00e0e0">-</span>ml<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.0</span><span style="color:#00e0e0">.1</span>
azure<span style="color:#00e0e0">-</span>ml<span style="color:#00e0e0">-</span>component<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.9</span><span style="color:#00e0e0">.18</span>.post2
azureml<span style="color:#00e0e0">-</span>mlflow<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>contrib<span style="color:#00e0e0">-</span>services<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
torch<span style="color:#00e0e0">-</span>tb<span style="color:#00e0e0">-</span>profiler~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.4</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>inference<span style="color:#00e0e0">-</span>server<span style="color:#00e0e0">-</span>http
inference<span style="color:#00e0e0">-</span>schema
MarkupSafe<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">2.1</span><span style="color:#00e0e0">.2</span>
regex
pybind11
urllib3<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">1.26</span><span style="color:#00e0e0">.18</span>
cryptography<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">42.0</span><span style="color:#00e0e0">.4</span>
aiohttp<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">3.8</span><span style="color:#00e0e0">.5</span>
py<span style="color:#00e0e0">-</span>spy<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.3</span><span style="color:#00e0e0">.12</span>
debugpy~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.6</span><span style="color:#00e0e0">.3</span>
ipykernel~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">6.0</span>
tensorboard
psutil~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">5.8</span><span style="color:#00e0e0">.0</span>
matplotlib~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">3.5</span><span style="color:#00e0e0">.0</span>
tqdm~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">4.66</span><span style="color:#00e0e0">.3</span>
py<span style="color:#00e0e0">-</span>cpuinfo<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">5.0</span><span style="color:#00e0e0">.0</span>
transformers<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">4.44</span><span style="color:#00e0e0">.2</span>
diffusers<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.30</span><span style="color:#00e0e0">.1</span>
accelerate<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">0.31</span><span style="color:#00e0e0">.0</span>
sentencepiece
peft
bitsandbytes</code></span></span>
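A long, hand-maintained requirements.txt can end up listing the same package more than once, which is harmless when the pins agree and confusing when they do not. The sketch below shows one way to spot repeated entries before building the image. `find_duplicate_requirements` is a hypothetical helper written for this post, not part of pip or Azure ML, and the sample lines are illustrative.

```python
import re
from collections import Counter


def find_duplicate_requirements(lines):
    """Return package names that appear more than once in requirements lines."""
    names = []
    for line in lines:
        # Drop trailing comments and surrounding whitespace.
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        # The package name is everything before the first version operator,
        # extras bracket, environment marker, or space.
        name = re.split(r"[=<>~!\[; ]", line, maxsplit=1)[0]
        # pip treats '-', '_' and '.' in names as equivalent and ignores case.
        names.append(re.sub(r"[-_.]+", "-", name).lower())
    counts = Counter(names)
    return sorted(name for name, n in counts.items() if n > 1)


sample = [
    "azureml-contrib-services==1.57.0",
    "azureml-contrib-services==1.57.0",
    "torch-tb-profiler~=0.4.0",
    "peft",
]
print(find_duplicate_requirements(sample))
```

To check a real file, pass `open("requirements.txt").read().splitlines()` instead of `sample`.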
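Make sure the build-context folder has the layout below. The folder must match the path passed to BuildContext in the next snippet, which uses docker-contexts/python-and-pip.
docker-contexts/python-and-pip/
    Dockerfile
    requirements.txt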
Finally, run the code below to create the inference environment for the FLUX LoRA model.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript">env_docker_context <span style="color:#00e0e0">=</span> Environment<span style="color:#fefefe">(</span>
build<span style="color:#00e0e0">=</span>BuildContext<span style="color:#fefefe">(</span>path<span style="color:#00e0e0">=</span><span style="color:#abe338">"docker-contexts/python-and-pip"</span><span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
name<span style="color:#00e0e0">=</span>inference_env_name<span style="color:#fefefe">,</span>
description<span style="color:#00e0e0">=</span><span style="color:#abe338">"Environment created from a Docker context."</span><span style="color:#fefefe">,</span>
<span style="color:#fefefe">)</span>
ml_client.environments.create_or_update<span style="color:#fefefe">(</span>env_docker_context<span style="color:#fefefe">)</span></code></span></span>
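An incomplete Docker context only surfaces as an error once the environment build is under way, so a quick local check before calling `create_or_update` can save a round trip. This is a minimal sketch: `missing_context_files` is a hypothetical helper, and the folder path is assumed to be the one passed to `BuildContext` above.

```python
from pathlib import Path


def missing_context_files(context_dir, required=("Dockerfile", "requirements.txt")):
    """Return the required file names that are absent from the build-context folder."""
    root = Path(context_dir)
    return [name for name in required if not (root / name).is_file()]


# Assumed path: the same folder that is passed to BuildContext(path=...).
missing = missing_context_files("docker-contexts/python-and-pip")
if missing:
    print("Build context is incomplete, missing:", ", ".join(missing))
else:
    print("Build context looks complete.")
```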
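Step 10: Create a deployment for the managed online endpoint
Now let's deploy the model to the endpoint we created. Create a file named score.py and place it in a folder named assets.
assets/
    score.py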
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python">import io
import json
import logging
import os
from base64 import b64encode

import requests
import torch
from azureml.contrib.services.aml_response import AMLResponse
from diffusers import FluxPipeline, StableDiffusionXLImg2ImgPipeline
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def init():
    """
    This function is called when the container is initialized/started, typically after create/update of the deployment.
    You can write the logic here to perform init operations like caching the model in memory
    """
    global pipe, refiner
    weights_path = os.path.join(
        os.getenv("AZUREML_MODEL_DIR"), "pytorch_lora_weights.safetensors"
    )
    print("weights_path:", weights_path)
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(weights_path, use_safetensors=True)
    # Note: diffusers advises against moving a pipeline to the GPU after enabling
    # model CPU offload, because the memory savings are lost. Keep only one of the
    # two calls, depending on how much GPU memory the instance has.
    pipe.to(device)
    # Optional refiner stage (disabled):
    # refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    #     "stabilityai/stable-diffusion-xl-refiner-1.0",
    #     torch_dtype=torch.float16,
    #     use_safetensors=True,
    #     variant="fp16"
    # )
    # refiner.to(device)
    logging.info("Init complete")


def get_image_object(image_url):
    """
    This function takes an image URL and returns an Image object.
    """
    response = requests.get(image_url)
    response.raise_for_status()
    # convert() must be called on the opened image, not on the BytesIO buffer.
    init_image = Image.open(io.BytesIO(response.content)).convert("RGB")
    return init_image


def prepare_response(images):
    """
    This function takes a list of images and converts them to a dictionary of base64 encoded strings.
    """
    ENCODING = 'utf-8'
    dic_response = {}
    for i, image in enumerate(images):
        output = io.BytesIO()
        image.save(output, format="JPEG")
        base64_bytes = b64encode(output.getvalue())
        base64_string = base64_bytes.decode(ENCODING)
        dic_response[f'image_{i}'] = base64_string
    return dic_response


def design(prompt, image=None, num_images_per_prompt=4, negative_prompt=None, strength=0.65, guidance_scale=7.5, num_inference_steps=50, seed=None, design_type='TXT_TO_IMG', mask=None, other_args=None):
    """
    This function takes various parameters like prompt, image, seed, design_type, etc., and generates images based on the specified design type. It returns a list of generated images.
    Only prompt, guidance_scale, num_inference_steps and seed are used by the text-to-image call below.
    """
    if seed:
        generator = torch.manual_seed(seed)
    else:
        # Fall back to a fixed seed when none is supplied.
        generator = torch.manual_seed(0)
    print('other_args', other_args)
    image = pipe(prompt=prompt,
                 height=512,
                 width=768,
                 guidance_scale=guidance_scale,
                 num_inference_steps=num_inference_steps,
                 # "pil" returns PIL images that prepare_response() can save as JPEG.
                 # Use "latent" only when passing the result on to the refiner below.
                 output_type="pil",
                 generator=generator).images[0]
    # image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
    return [image]


def run(raw_data):
    """
    This function takes raw data as input, processes it, and calls the design function to generate images.
    It then prepares the response and returns it.
    """
    logging.info("Request received")
    print(f'raw data: {raw_data}')
    data = json.loads(raw_data)["data"]
    print(f'data: {data}')
    prompt = data['prompt']
    negative_prompt = data['negative_prompt']
    seed = data['seed']
    num_images_per_prompt = data['num_images_per_prompt']
    guidance_scale = data['guidance_scale']
    num_inference_steps = data['num_inference_steps']
    design_type = data['design_type']
    # strength is optional; fall back to the default used by design().
    strength = data.get('strength', 0.65)
    image_url = None
    mask_url = None
    mask = None
    other_args = None
    image = None
    if 'mask_image' in data:
        mask_url = data['mask_image']
        mask = get_image_object(mask_url)
    if 'other_args' in data:
        other_args = data['other_args']
    if 'image_url' in data:
        image_url = data['image_url']
        image = get_image_object(image_url)
    with torch.inference_mode():
        images = design(prompt=prompt, image=image,
                        num_images_per_prompt=num_images_per_prompt,
                        negative_prompt=negative_prompt, strength=strength,
                        guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
                        seed=seed, design_type=design_type, mask=mask, other_args=other_args)
    preped_response = prepare_response(images)
    resp = AMLResponse(message=preped_response, status_code=200, json_str=True)
    return resp
</code></span></span>
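The `run` function in score.py reads its fields with `data['key']`, so a request that leaves one out fails with a `KeyError`. One possible way to make the contract more forgiving is to merge the request with a table of defaults before calling `design`. The sketch below is an assumption about how that could look, not part of the original script: `DEFAULTS` and `parse_request` are hypothetical names, and the default values are copied from the keyword defaults of `design`.

```python
import json

# Hypothetical defaults, mirroring the keyword defaults of design() in score.py.
DEFAULTS = {
    "negative_prompt": None,
    "seed": None,
    "num_images_per_prompt": 4,
    "guidance_scale": 7.5,
    "num_inference_steps": 50,
    "design_type": "TXT_TO_IMG",
    "strength": 0.65,
}


def parse_request(raw_data):
    """Parse the request body and fill in defaults for optional fields."""
    body = json.loads(raw_data)
    if "data" not in body or "prompt" not in body["data"]:
        raise ValueError('request must look like {"data": {"prompt": "..."}}')
    # Values sent by the caller override the defaults.
    return {**DEFAULTS, **body["data"]}


params = parse_request('{"data": {"prompt": "a photo of sks dog in a bucket", "seed": 7}}')
print(params["seed"], params["num_inference_steps"])
```

With a helper like this, only `prompt` is mandatory and every other field falls back to the value `design` would use anyway.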
Now we can go ahead and deploy it.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript">deployment <span style="color:#00e0e0">=</span> ManagedOnlineDeployment<span style="color:#fefefe">(</span>
name<span style="color:#00e0e0">=</span>deployment_name<span style="color:#fefefe">,</span>
endpoint_name<span style="color:#00e0e0">=</span>endpoint_name<span style="color:#fefefe">,</span>
model<span style="color:#00e0e0">=</span>model<span style="color:#fefefe">,</span>
environment<span style="color:#00e0e0">=</span>env_docker_context<span style="color:#fefefe">,</span>
code_configuration<span style="color:#00e0e0">=</span>CodeConfiguration<span style="color:#fefefe">(</span>
code<span style="color:#00e0e0">=</span><span style="color:#abe338">"assets"</span><span style="color:#fefefe">,</span> scoring_script<span style="color:#00e0e0">=</span>score_file
<span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
instance_type<span style="color:#00e0e0">=</span>instance_type<span style="color:#fefefe">,</span>
instance_count<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
request_settings<span style="color:#00e0e0">=</span>OnlineRequestSettings<span style="color:#fefefe">(</span>request_timeout_ms<span style="color:#00e0e0">=</span><span style="color:#00e0e0">90000</span><span style="color:#fefefe">,</span> max_queue_wait_ms<span style="color:#00e0e0">=</span><span style="color:#00e0e0">900000</span><span style="color:#fefefe">,</span> max_concurrent_requests_per_instance<span style="color:#00e0e0">=</span><span style="color:#00e0e0">5</span><span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
liveness_probe<span style="color:#00e0e0">=</span>ProbeSettings<span style="color:#fefefe">(</span>
failure_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">30</span><span style="color:#fefefe">,</span>
success_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
<span style="color:#00e0e0">timeout</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">2</span><span style="color:#fefefe">,</span>
period<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
initial_delay<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1000</span><span style="color:#fefefe">,</span>
<span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
readiness_probe<span style="color:#00e0e0">=</span>ProbeSettings<span style="color:#fefefe">(</span>
failure_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
success_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
<span style="color:#00e0e0">timeout</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
period<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
initial_delay<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1000</span><span style="color:#fefefe">,</span>
<span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
environment_variables <span style="color:#00e0e0">=</span> <span style="color:#fefefe">{</span>'HF_TOKEN'<span style="color:#fefefe">:</span> 'YOUR_HF_TOKEN'<span style="color:#fefefe">}</span><span style="color:#fefefe">,</span> <span style="color:#d4d0ab"># placeholder: supply your own Hugging Face token and never publish a real one</span>
<span style="color:#fefefe">)</span>
workspace_ml_client.online_deployments.begin_create_or_update<span style="color:#fefefe">(</span>deployment<span style="color:#fefefe">)</span>.result<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span></code></span></span>
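The probe and timeout values above are easier to reason about once converted into wall-clock time. The sketch below is a rough estimate only: it assumes the `ProbeSettings` fields are expressed in seconds, that a probe is declared failed after `failure_threshold` consecutive misses spaced `period` apart, and it ignores the per-probe timeout. `worst_case_seconds` is a hypothetical helper.

```python
def worst_case_seconds(initial_delay, period, failure_threshold):
    """Rough estimate of seconds from container start until a probe is declared failed."""
    return initial_delay + period * failure_threshold


# Values copied from the deployment definition above.
liveness = worst_case_seconds(initial_delay=1000, period=10, failure_threshold=30)
readiness = worst_case_seconds(initial_delay=1000, period=10, failure_threshold=10)
request_timeout_s = 90000 / 1000  # request_timeout_ms converted to seconds

print(liveness, readiness, request_timeout_s)
```

Under these assumptions the container gets roughly 18 minutes to download the base model and load the LoRA weights before it is marked not ready, and each scoring request has 90 seconds to finish.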
Step 11: Test the deployment
Finally, we can test the endpoint.
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"># Create request json
import base64
import json
from io import BytesIO

from PIL import Image

# The payload must match what run() in score.py reads: a top-level "data" object
# holding the prompt and the generation settings.
request_json = {
    "data": {
        "prompt": "a photo of sks dog in a bucket",
        "negative_prompt": "blurry; three legs",
        "seed": 0,  # a falsy seed makes score.py fall back to torch.manual_seed(0)
        "num_images_per_prompt": 2,
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
        "design_type": "TXT_TO_IMG",
        "strength": 0.65,
    }
}
request_file_name = "sample_request_data.json"
with open(request_file_name, "w") as request_file:
    json.dump(request_json, request_file)
responses = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=online_endpoint_name,
    deployment_name=deployment_name,
    request_file=request_file_name,
)
# score.py returns a JSON object that maps "image_0", "image_1", ... to base64-encoded JPEGs.
responses = json.loads(responses)
for name, base64_string in responses.items():
    image_stream = BytesIO(base64.b64decode(base64_string))
    image = Image.open(image_stream)
    display(image)</code></span></span>
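`display` is only available inside a notebook. When calling the endpoint from a plain script, the decoded images can be written to disk instead. The sketch below assumes the response is the dictionary built by `prepare_response` in score.py, with keys such as `image_0` mapped to base64 strings. `save_images` and the `outputs` folder are hypothetical names chosen for illustration.

```python
import base64
from pathlib import Path


def save_images(response, out_dir="outputs"):
    """Decode each base64 string in the response dict and write it as a .jpg file."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    saved = []
    for name, base64_string in sorted(response.items()):
        target = out_path / f"{name}.jpg"
        target.write_bytes(base64.b64decode(base64_string))
        saved.append(target)
    return saved


# Stand-in payload; a real response would carry JPEG bytes produced by score.py.
fake_response = {"image_0": base64.b64encode(b"not a real jpeg").decode("utf-8")}
print(save_images(fake_response))
```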
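Conclusion
Fine-tuning FLUX models with Dreambooth is an effective way to customize generative AI models for specific applications. By following the steps outlined in this blog, you can build on the strengths of the FLUX.1 [dev] model and extend it with your own dataset to produce high-quality, personalized output. Whether you are working on a creative project, research, or a commercial application, this approach offers a practical way to extend your AI capabilities.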