当前位置: 首页 > article >正文

AI学习记录 - 对抗性神经网络

有用点赞哦

学习机器学习到一定程度之后,一般会先看他的损失函数是什么,看他的训练集是什么,训练集是什么,代表我使用模型的时候,输入是什么类型的数据。

对抗神经网络其实可以这样子理解,网上一直说生成器和判别器的概念,没有触及到本质。

我有一种看法:假如当前场景是输入模糊图片,然后输出高质量图片。当判别器和生成器本来就是一个模型,在不把判别器生成器拆开的时候,我输入一张图片,这个模型输出的是0和1,那这个整体模型的作用就是判断这个图片是不是高清图片,训练集是【模糊图片,0】,【高清图片,1】,通过这种方式进行反向传播。但是现在的目的不是判断他是不是高质量图片,而是要他给我生成高质量图片,所以要在中间切开变成两份,前一份模型修改一下输出,让他输出的是一张图片的格式矩阵,后一份模型修改一下输入,让他输入的是一张图片的格式矩阵,就像个协议一样前一半后一半规格一样就可以接起来,你可能会说我不改变他的输入输出本来的规格也一样啊,但是我说了规格变成一张图片,那人不就是看懂了吗,原来中间切开的输出,人看不懂啊,也不符合我要实现输入模糊图片,输出高清图片的目的。

判别器:判别器是单独训练

训练集就是【真实图片, 1】,【虚假图片, 0】,一个批次有2种训练集,大概如下

第一种:【【真实图片, 1】,【真实图片, 1】,【真实图片, 1】,【真实图片, 1】】
第二种:【【虚假图片, 0】,【虚假图片, 0】,【虚假图片, 0】,【虚假图片, 0】】

所以会产生两种损失:

真实图片的损失
	discriminator.train()
	d_loss1 = discriminator(images, real_targets)
利用生成器生成图片,然后丢到判别器
	 with torch.no_grad():
	     generated_images = generator(batch_size)
	
	 # Loss with generated image inputs and fake_targets as labels
	 d_loss2 = discriminator(generated_images, fake_targets)
两种损失加起来,然后反向传播
	d_loss = d_loss1 + d_loss2
	d_optimizer.zero_grad()
	d_loss.backward()
	d_optimizer.step()

生成器:生成器的训练是将 生成器 和 判别器 整合一起训练,我还没完全理解,后续再看看

反向传播的过程,使用损失值对生成器的参数进行反向传播,更新生成器的权重。

单独创建两个优化器,用哪个优化器,就更新哪个模型的权重,这里有两个模型,判别器的权重在这一步保持不变。
# Optimizers
	g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)
	d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
单独更新
    # Generate images in train mode
    generator.train()
    generated_images = generator(batch_size)

    # Loss with generated image inputs and real_targets as labels
    g_loss = discriminator(generated_images, real_targets)

    # Optimizer updates the generator parameters
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()  # 更新生成器权重

http://www.kler.cn/news/289975.html

相关文章:

  • 企业架构的概念及发展历程简述(附TOGAF架构理论学习资料下载链接)
  • macos安装ArgoCD
  • linux nc
  • 【Altium Designer脚本开发】——PCB平面绕组线圈 V1.4
  • 河南建筑智能化设计专项资质延期流程说明
  • 力扣2503.矩阵查询可获得的最大分数
  • 超声波清洗机有没有平价又好用的推荐?高性价比的眼镜清洗机推荐
  • 百度飞将 paddle ,实现贝叶斯神经网络 bayesue neure network bnn,aistudio公开项目 复现效果不好
  • 语言质量评估对欧洲游戏推广的重要性
  • 阿凡达2.0直播模式来了,数字人直播行业迎来浴火重生!
  • django企业开发实战-学习小结
  • 活动系统开发之采用设计模式与非设计模式的区别-数据库设计及代码设计
  • LeetCode37 解数独
  • 【Steam游戏星露谷物语添加Mod步骤】
  • css中calc
  • 【陪诊系统-H5客户端】订单状态进度条
  • 如果已经提交,重新添加gitignore文件,会忽略么
  • 【QT】学习笔记:枚举桌面窗口句柄
  • 代码随想录算法训练营第35天|背包问题基础、46. 携带研究材料(01背包二维解法)(01背包一维解法)(acm)、416. 分割等和子集
  • 解决Vue npm 淘宝镜像证书过期问题
  • Blazor项目中建立WebApi
  • C++使用MyStack和MyQueue封装栈和队列
  • Chrome 浏览器插件获取网页 window 对象(方案一)
  • pip切换清华源
  • 数据结构二叉树——堆
  • Scott Brinker:Martech中的AI会让买家体验更好还是更糟?这取决于…….
  • Unity版本升级2022 Gradle 升级7.x版本调整
  • 代码随想录 刷题记录-27 图论 (4)拓扑排序
  • Rides实现分布式锁,保障数据一致性,Redisson分布式事务处理
  • python学习之路 - PySpark快速入门