当前位置: 首页 > article >正文

GAN原理及代码实现

GAN原理及代码实现

文章目录

    • GAN原理及代码实现
      • 基本介绍
        • 原理
        • 损失函数
      • 数学推导
        • 基本推导
        • 交叉熵
        • 损失求解
      • 代码实现
        • 判别网络
        • 生成网络
        • 损失函数及优化器
        • 数据集准备
        • 开始训练
        • 综合代码
      • 参考文章

基本介绍

原理

GAN中文为生成对抗网络。主要由两个基础网络构成,生成网络G和对抗网络D。其中,生成网络主要用于生成新数据,其生成数据的基础往往是一组噪音或者随机数,而判别网络用于判断生成网络生成的数据和真实数据哪个才是真的。其原理图如下,

在这里插入图片描述

生成网络的目标是生成尽量真实的数据,最好能够以假乱真、让判别网络判断不出来,因此生成网络的学习目标是让判别网络上的判断准确性越来越低;相反,判别网络的目标是尽量判别出生成网络生成数据的真伪,因此判别网络的学习目标是让自己的判断准确性越来越高。

当生成网络生成的数据越来越真时,判别网络为维持住自己的准确性,就必须向判别能力越来越强的方向迭代。当判别网络越来越强大时,生成网络为了降低判别网络的判断准确性,就必须生成越来越真的数据。在这个奇妙的关系中,判别网络与生成网络同时训练、相互内卷,对损失函数的影响此消彼长,彼此博弈。

损失函数

GAN损失函数的创建有两个基本构造原则,

  1. 判别网络要尽可能区别开数据是来自生成网络还是判别网络,也就是数据到底是真实数据还是人造数据,可知,当两个数据之间的分布差异越大,判别网络越容易区分,当生成网络生成数据之后,生成数据的分布已经定下来了,此时判别网络要做的就是尽可能找出真实数据分布和生成数据分布之间的不同点,以此来区分数据的真伪。
  2. 生成网络要尽量生成真实的数据,这样才能骗过判别网络,若生成数据的分布与真实数据完全一样,那么就能够完全骗过判别网络了。

即,

  1. 对于判别网络来说,尽可能找出生成网络生成的数据与真实数据分布之间的差异。
  2. 对于生成网络来说,让生成网络生成的数据分布接近真实数据分布。

我们先给出损失函数的表达式,
m i n G m a x D V ( G , D ) = m i n G m a x D [ E x ∼ p d a t a ( x ) l o g ( D ( x ) ) + E z ∼ p ( z ) l o g ( 1 − D ( G ( z ) ) ] \underset{G}{min}\underset{D}{max}V(G,D)=\underset{G}{min}\underset{D}{max}[E_{x\sim p_{data}(x)}log(D(x))+E_{z\sim p(z)}log(1-D(G(z))] GminDmaxV(G,D)=GminDmax[Expdata(x)log(D(x))+Ezp(z)log(1D(G(z))]
其中, V V V 表示交叉熵, x x x表示任意真实数据, z z z 表示与真实数据相同结构的任意随机数据, G ( z ) G(z) G(z)表示在生成网络中基于 z z z生成的假数据,而 D ( x ) D(x) D(x) 表示判别网络在真实数据 x x x上判断出的结果, D ( G ( z ) ) D(G(z)) D(G(z))表示判别网络在假数据 G ( z ) G(z) G(z)上判断出的结果,其中 D ( x ) D(x) D(x) D ( G ( z ) ) D(G(z)) D(G(z))都是样本为“真”的概率,即标签为1的概率。

可以看出,在求解最优解的过程中存在两个过程:

  1. 固定G,求解令损失函数最大的D
  2. 固定D,求解令损失函数最小的G

数学推导

基本推导

固定G,求解令损失函数最大的D。这对应了第一个原则——对于判别网络来说,尽可能找出生成网络生成的数据与真实数据分布之间的差异。

由于真实数据为 x x x,假设真实数据的分布为 P d a t a ( x ) P_{data}(x) Pdata(x),生成网络接受的数据 z z z服从分布 P ( z ) P(z) P(z),生成网络生成数据的分布为 P G ( x ) P_G(x) PG(x),

D ∗ = a r g m a x D V ( D , G ) D^*=\underset{D}{argmax}V(D,G) D=DargmaxV(D,G)

= a r g m a x D V ( D ) =\underset{D}{argmax}V(D) =DargmaxV(D)

= a r g m a x D { E x ∼ P d a t a [ l o g D ( x ) ] + E z ∼ P ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] } =\underset{D}{argmax}\{E_{x\sim P_{data}}[logD(x)]+E_{z\sim P(z)}[log(1-D(G(z))) ]\} =Dargmax{ExPdata[logD(x)]+EzP(z)[log(1D(G(z)))]}

= a r g m a x D { E x ∼ P d a t a [ l o g D ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − D ( x ) ) ] } =\underset{D}{argmax}\{E_{x\sim P_{data}}[logD(x)]+E_{x\sim P_G(x)}[log(1-D(x)) ]\} =Dargmax{ExPdata[logD(x)]+ExPG(x)[log(1D(x))]} 这一步是将 G ( z ) G(z) G(z)替换为 P G ( x ) P_G(x) PG(x)

= a r g m a x D { ∫ P d a t a ( x ) l o g D ( x ) d x + ∫ P G ( x ) l o g ( 1 − D ( x ) ) d x } =\underset{D}{argmax}\{\int P_{data}(x)logD(x)dx+\int P_G(x)log(1-D(x))dx \} =Dargmax{Pdata(x)logD(x)dx+PG(x)log(1D(x))dx} 这一步采用了非连续变量的期望公式

= a r g m a x D { ∫ [ P d a t a ( x ) l o g D ( x ) + P G ( x ) l o g ( 1 − D ( x ) ) ] d x } =\underset{D}{argmax}\{\int [P_{data}(x)logD(x)+P_G(x)log(1-D(x))]dx \} =Dargmax{[Pdata(x)logD(x)+PG(x)log(1D(x))]dx}

= a r g m a x D [ P d a t a ( x ) l o g D ( x ) + P G ( x ) l o g ( 1 − D ( x ) ) ] =\underset{D}{argmax} [P_{data}(x)logD(x)+P_G(x)log(1-D(x))] =Dargmax[Pdata(x)logD(x)+PG(x)log(1D(x))]

求取上面积分的最大值,也就是求解要积分的函数的最大值,因为积分可看作是点的累加,每一个累加项的最大值也就是累加和的最大值。

由于上面最终式子是关于D的一元函数,要求最优的D值,对D求导得,
F = P d a t a ( x ) l o g D ( x ) + P G ( x ) l o g ( 1 − D ( x ) ) F=P_{data}(x)logD(x)+P_G(x)log(1-D(x)) F=Pdata(x)logD(x)+PG(x)log(1D(x))

d F d D ( x ) = P d a t a ( x ) D ( x ) − P G ( x ) 1 − D ( x ) ​ \frac{dF}{dD(x)}=\frac{P_{data}(x)}{D(x)}-\frac{P_G(x)}{1-D(x)}​ dD(x)dF=D(x)Pdata(x)1D(x)PG(x)

令上面式子结果为0得,
d F d D ( x ) = P d a t a ( x ) D ( x ) − P G ( x ) 1 − D ( x ) = 0 D ∗ ( x ) = P d a t a ( x ) P d a t a ( x ) + P G ( x ) \frac{dF}{dD(x)}=\frac{P_{data}(x)}{D(x)}-\frac{P_G(x)}{1-D(x)}=0\\ D^*(x)=\frac{P_{data}(x)}{P_{data}(x)+P_G(x)} dD(x)dF=D(x)Pdata(x)1D(x)PG(x)=0D(x)=Pdata(x)+PG(x)Pdata(x)
将所求的极值点 D ∗ ( x ) D*(x) D(x)代入原式,
D ∗ = a r g m a x D V ( D , G ) = V ( D ∗ , G ) = E x ∼ P d a t a ( x ) [ l o g D ∗ ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − D ∗ ( x ) ] = E x ∼ P d a t a ( x ) [ l o g P d a t a ( x ) P d a t a ( x ) + P G ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − P d a t a ( x ) P d a t a ( x ) + P G ( x ) ] = ∫ P d a t a ( x ) l o g P d a t a ( x ) P d a t a ( x ) + P G ( x ) d x + ∫ P G ( x ) l o g ( 1 − P d a t a ( x ) P d a t a ( x ) + P G ( x ) ) d x = ∫ P d a t a ( x ) l o g P d a t a ( x ) P d a t a ( x ) + P G ( x ) 2 d x + ∫ P G ( x ) l o g ( 1 − P d a t a ( x ) P d a t a ( x ) + P G ( x ) 2 ) d x − 2 l o g 2 = K L ( P d a t a ( x ) ∣ ∣ P d a t a ( x ) + P G ( x ) 2 ) + K L ( P G ( x ) ∣ ∣ P d a t a ( x ) + P G ( x ) 2 ) − 2 l o g 2 = 2 J S ( P d a t a ( x ) ∣ ∣ P G ( x ) ) − 2 l o g 2 \begin{aligned} D^*&=\underset{D}{argmax}V(D,G)\\ &=V(D*,G)\\ &= E_{x\sim P_{data}(x)}[logD^*(x)]+E_{x\sim P_G(x)}[log(1-D^*(x) ]\\ &= E_{x\sim P_{data(x)}}[log\frac{P_{data}(x)}{P_{data}(x)+P_G(x)}]+E_{x\sim P_G(x)}[log(1-\frac{P_{data}(x)}{P_{data}(x)+P_G(x)} ]\\ &=\int P_{data}(x)log\frac{P_{data}(x)}{P_{data}(x)+P_G(x)}dx+\int P_G(x)log(1-\frac{P_{data}(x)}{P_{data}(x)+P_G(x)})dx \\ &=\int P_{data}(x)log\frac{P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2}}dx+\int P_G(x)log(1-\frac{P_{data}(x)}{\frac{P_{data}(x)+P_G(x)}{2}})dx-2log2 \\ &=KL(P_{data}(x)||\frac{P_{data}(x)+P_G(x)}{2})+KL(P_G(x)||\frac{P_{data}(x)+P_G(x)}{2})-2log2\\ &=2JS(P_{data}(x)||P_G(x))-2log2\\ \end{aligned} D=DargmaxV(D,G)=V(D,G)=ExPdata(x)[logD(x)]+ExPG(x)[log(1D(x)]=ExPdata(x)[logPdata(x)+PG(x)Pdata(x)]+ExPG(x)[log(1Pdata(x)+PG(x)Pdata(x)]=Pdata(x)logPdata(x)+PG(x)Pdata(x)dx+PG(x)log(1Pdata(x)+PG(x)Pdata(x))dx=Pdata(x)log2Pdata(x)+PG(x)Pdata(x)dx+PG(x)log(12Pdata(x)+PG(x)Pdata(x))dx2log2=KL(Pdata(x)∣∣2Pdata(x)+PG(x))+KL(PG(x)∣∣2Pdata(x)+PG(x))2log2=2JS(Pdata(x)∣∣PG(x))2log2
上面是固定G,求解令损失函数最大的D,我们接着固定D,求令损失函数最小的G,

可以看出,固定 G G G,将最优的 D D D带入后,此时 m a x ( D , G ) max(D,G) max(D,G),也就是 V ( D ∗ , G ) V(D∗,G) V(D,G),实际上是在度量 P d a t a ( x ) Pdata(x) Pdata(x) P G ( x ) PG(x) PG(x)之间的 J S JS JS散度,同 K L KL KL散度一样,他们之间的分布差异越大, J S JS JS散度值也越大。换句话说:保持 G G G不变,最大化 V ( G , D ) V(G,D) V(G,D)就等价于计算 J S JS JS散度!,现在回过头看交叉熵构造原则中的第一条是不是就更加理解了,对于判别网络来说,尽可能找出生成网络生成的数据与真实数据分布之间的差异,这个差异就是JS散度

上面是固定G,求解令损失函数最大的D。接着我们固定D,求解令损失函数最小的G。这对应了第二个原则——对于生成网络来说,让生成网络生成的数据分布接近真实数据分布。

将上面的结果继续代入损失函数,
m i n G m a x D V ( G , D ) = m i n G V ( G , D ∗ ) = m i n G [ 2 J S ( P d a t a ( x ) ∣ ∣ P G ( x ) ) − 2 l o g 2 ] = m i n G 2 J S ( P d a t a ( x ) ∣ ∣ P G ( x ) ) \begin{aligned} \underset{G}{min}\underset{D}{max}V(G,D)&=\underset{G}{min}V(G,D^*)\\ &=\underset{G}{min}[2JS(P_{data}(x)||P_G(x))-2log2]\\ &=\underset{G}{min}2JS(P_{data}(x)||P_G(x))\\ \end{aligned} GminDmaxV(G,D)=GminV(G,D)=Gmin[2JS(Pdata(x)∣∣PG(x))2log2]=Gmin2JS(Pdata(x)∣∣PG(x))
可以看出,这一步就是在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

交叉熵

现在我们回到最初的损失函数是怎么来的问题上面。首先我们需要了解交叉熵的含义。

交叉熵:一般用来求目标与预测值之间的差距。

信息量:越不可能的事发生之后,获取到的信息量就越大;越可能的事发生,获取到的信息量就越小。

假设 X X X是一个离散型随机变量,其取值集合为 χ χ χ,概率分布函数为 p ( x ) = P r ( X = x ) , x ∈ χ p(x)=Pr(X=x),x∈χ p(x)=Pr(X=x)xχ,则定义事件 X = x 0 X=x_0 X=x0的信息量为
I ( x 0 ) = − l o g ( p ( x 0 ) ) ​ I(x_0)=-log(p(x_0))​ I(x0)=log(p(x0))
由于是概率所以 p ( x 0 ) p(x_0) p(x0)的取值范围为 [ 0 , 1 ] [0,1] [0,1],-log(x)的函数图像如下:

在这里插入图片描述

从图像可以看出,概率越小,函数值越大,也就对应越不可能的事情,信息量越大。

熵用来表示所有信息量的期望,
H ( X ) = − ∑ i = 1 n p ( x i ) l o g ( p ( x i ) ) ​ H(X)=−\sum_{i=1}^n {p(x_i)log(p(x_i))}​ H(X)=i=1np(xi)log(p(xi))
其中n代表有n中可能性。

相对熵又称KL散度,如果我们对于同一个随机变量 x x x 有两个单独的概率分布 P ( x ) P(x) P(x) Q ( x ) Q(x) Q(x),我们可以使用 KL 散度(Kullback-Leibler (KL) divergence)来衡量这两个分布的差异。

在机器学习中,P往往用来表示样本的真实分布,比如 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]表示当前样本属于第一类。Q用来表示模型所预测的分布,比如 [ 0.7 , 0.2 , 0.1 ] [0.7,0.2,0.1] [0.7,0.2,0.1]。直观的理解就是如果用P来描述样本,那么就非常完美。而用Q来描述样本,虽然可以大致描述,但是不是那么的完美,信息量不足,需要额外的一些“信息增量”才能达到和P一样完美的描述。如果我们的Q通过反复训练,也能完美的描述样本,那么就不再需要额外的“信息增量”,Q等价于P。
KL散度的计算公式:
D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) ​ D_{KL}(p||q)=\sum_{i=1}^n p(x_i)log(\frac{p(x_i)}{q(x_i)})​ DKL(p∣∣q)=i=1np(xi)log(q(xi)p(xi))
n为事件的所有可能性。
DKL的值越小,表示q分布和p分布越接近。

对上面的式子变形可以得到:
KaTeX parse error: Can't use function '$' in math mode at position 135: …(x_i)log(q(x_i)$̲ \end{aligned}
式子的前一部分恰巧就是p的熵,等式的后一部分,就是交叉熵
H ( p , q ) = − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H(p,q)=-\sum_{i=1}^n p(x_i)log(q(x_i)) H(p,q)=i=1np(xi)log(q(xi))
在机器学习中,我们需要评估label和predicts之间的差距,使用KL散度刚刚好,即 D K L ( y ∣ ∣ y ^ ) D_{KL}(y||\hat{y}) DKL(y∣∣y^),由于KL散度中的前一部分 − H ( y ) -H(y) H(y)不变,故在优化过程中,只需要关注交叉熵就可以了。所以一般在机器学习中直接用用交叉熵做loss,评估模型。

二分类交叉熵相当于只有两类,上面式子可以化解为,
H ( p , q ) = − [ p ( x ) l o g ( q ( x ) ) + ( 1 − p ( x ) ) l o g ( 1 − q ( x ) ) ] H(p,q)=-[ p(x)log(q(x))+(1-p(x))log(1-q(x))] H(p,q)=[p(x)log(q(x))+(1p(x))log(1q(x))]
其中 p ( x ) p(x) p(x)是实际标签(0 或 1), q ( x ) q(x) q(x)是模型预测后的概率。

损失求解

我们已知,
V ( G , D ) = E x ∼ P d a t a [ l o g D ( x ) ] + E z ∼ P ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] = E x ∼ P d a t a [ l o g D ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − D ( x ) ) ] \begin{aligned} V(G,D)&=E_{x\sim P_{data}}[logD(x)]+E_{z\sim P(z)}[log(1-D(G(z))) ]\\ &=E_{x\sim P_{data}}[logD(x)]+E_{x\sim P_G(x)}[log(1-D(x)) ] \end{aligned} V(G,D)=ExPdata[logD(x)]+EzP(z)[log(1D(G(z)))]=ExPdata[logD(x)]+ExPG(x)[log(1D(x))]
由于我们采用神经网络去拟合概率分布,生成的是具体的样本点,因此可以将上式的期望替换为均值,
V ( G , D ) = E x ∼ P d a t a [ l o g D ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − D ( x ) ) ] = 1 n r e a l ∑ l o g D ( x i ) + 1 n f a k e ∑ l o g ( 1 − D ( x i ) ) \begin{aligned} V(G,D)&=E_{x\sim P_{data}}[logD(x)]+E_{x\sim P_G(x)}[log(1-D(x)) ]\\ &=\frac{1}{n_{real}}\sum logD(x_i)+\frac{1}{n_{fake}}\sum log(1-D(x_i)) \end{aligned} V(G,D)=ExPdata[logD(x)]+ExPG(x)[log(1D(x))]=nreal1logD(xi)+nfake1log(1D(xi))
其中, n r e a l n_{real} nreal代表真实样本数据的个数, n f a k e n_{fake} nfake代表生成样本数据的个数。

对于判别网络来说

设现在输入到判别网络的样本全部为真实样本,由于D输出的是二分类概率,我们采用交叉熵损失可以求出判别网络在真实样本上的损失,也即 , C r o s s e n t r o p y ( y p r e = D ( x ) , y t r u e = 1 ) Crossentropy(y_{pre}=D(x),y_{true}=1) Crossentropy(ypre=D(x)ytrue=1)
D l o s s r e a l = − 1 n r e a l ∑ [ y i l o g y i ^ + ( 1 − y i ) l o g ( 1 − y i ^ ) ] = − 1 n r e a l ∑ [ y i l o g D ( x i ) + ( 1 − y i ) l o g ( 1 − D ( x i ) ) ] = − 1 n r e a l ∑ [ 1 ∗ l o g D ( x i ) + ( 1 − 1 ) l o g ( 1 − D ( x i ) ) ] = − 1 n r e a l ∑ l o g D ( x i ) \begin{aligned} Dloss_{real}&=-\frac{1}{n_{real}}\sum [y_ilog\hat{y_i}+(1-y_i)log(1-\hat{y_i})]\\ &=-\frac{1}{n_{real}}\sum [y_ilogD(x_i)+(1-y_i)log(1-D(x_i))]\\ &=-\frac{1}{n_{real}}\sum [1*logD(x_i)+(1-1)log(1-D(x_i))]\\ &=-\frac{1}{n_{real}}\sum logD(x_i)\\ \end{aligned} Dlossreal=nreal1[yilogyi^+(1yi)log(1yi^)]=nreal1[yilogD(xi)+(1yi)log(1D(xi))]=nreal1[1logD(xi)+(11)log(1D(xi))]=nreal1logD(xi)
同理,判别网络在生成样本上的交叉熵损失为, C r o s s e n t r o p y ( y p r e = D ( x ) , y t r u e = 0 ) : Crossentropy(y_{pre}=D(x),y_{true}=0) : Crossentropy(ypre=D(x)ytrue=0)
D l o s s f a k e = − 1 n f a k e ∑ [ y i l o g y i ^ + ( 1 − y i ) l o g ( 1 − y i ^ ) ] = − 1 n f a k e ∑ [ y i l o g D ( x i ) + ( 1 − y i ) l o g ( 1 − D ( x i ) ) ] = − 1 n f a k e ∑ [ 0 ∗ l o g D ( x i ) + ( 1 − 0 ) l o g ( 1 − D ( x i ) ) ] = − 1 n f a k e ∑ l o g ( 1 − D ( x i ) ) \begin{aligned} Dloss_{fake}&=-\frac{1}{n_{fake}}\sum [y_ilog\hat{y_i}+(1-y_i)log(1-\hat{y_i})]\\ &=-\frac{1}{n_{fake}}\sum [y_ilogD(x_i)+(1-y_i)log(1-D(x_i))]\\ &=-\frac{1}{n_{fake}}\sum [0*logD(x_i)+(1-0)log(1-D(x_i))]\\ &=-\frac{1}{n_{fake}}\sum log(1-D(x_i)) \end{aligned} Dlossfake=nfake1[yilogyi^+(1yi)log(1yi^)]=nfake1[yilogD(xi)+(1yi)log(1D(xi))]=nfake1[0logD(xi)+(10)log(1D(xi))]=nfake1log(1D(xi))
所以判别网络的总损失由上面两种情况组成,
D l o s s = D l o s s r e a l + D l o s s f a k e = − 1 n r e a l ∑ l o g D ( x i ) − 1 n f a k e ∑ l o g ( 1 − D ( x i ) ) \begin{aligned} Dloss&=Dloss_{real}+Dloss_{fake}\\ &=-\frac{1}{n_{real}}\sum logD(x_i)-\frac{1}{n_{fake}}\sum log(1-D(x_i)) \end{aligned} Dloss=Dlossreal+Dlossfake=nreal1logD(xi)nfake1log(1D(xi))
V ( G , D ) V(G,D) V(G,D)相比较,
V ( G , D ) = − D l o s s ​ V(G,D)=-Dloss​ V(G,D)=Dloss
因此有,
m a x D V ( G , D ) = m i n D [ − V ( G , D ) ] = m i n D { − [ E x ∼ P d a t a [ l o g D ( x ) ] + E x ∼ P G ( x ) [ l o g ( 1 − D ( x ) ) ] } = m i n D − [ 1 n r e a l ∑ l o g D ( x i ) − 1 n f a k e ∑ l o g ( 1 − D ( x i ) ) ] \begin{aligned} \underset{D}{max}V(G,D)&=\underset{D}{min}[-V(G,D)]\\ &=\underset{D}{min} \{-[E_{x\sim P_{data}}[logD(x)]+E_{x\sim P_G(x)}[log(1-D(x)) ]\}\\ &=\underset{D}{min}-[\frac{1}{n_{real}}\sum logD(x_i)-\frac{1}{n_{fake}}\sum log(1-D(x_i))] \end{aligned} DmaxV(G,D)=Dmin[V(G,D)]=Dmin{[ExPdata[logD(x)]+ExPG(x)[log(1D(x))]}=Dmin[nreal1logD(xi)nfake1log(1D(xi))]
所以损失函数的第一步,固定 G G G,求解令损失函数最小的 D D D,可以通过对判别网络求交叉熵 D l o s s Dloss Dloss来表示。

对于生成网络来说

生成网络的目标为:使得生成的样本尽可能与真实样本一样,也就是使得判别网络在以生成样本作为输入的时候输出的的概率大,这时候我们可以将整个GAN网络看成一个整体,输入为 z z z,也就是和真实数据 x x x拥有相同结构的数据,我们让其实际标签为1,个人认为实际标签为0的其他数据应该是和真实数据 x x x结构差别很大的数据,这种数据判别网络一眼鉴定为假数据(预测概率非常低),其交叉熵损失为 , C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 1 ) : Crossentropy(ypre=D(G(z)),ytrue=1) : Crossentropy(ypre=D(G(z))ytrue=1)
G l o s s = − 1 n ∑ [ y i l o g y i ^ + ( 1 − y i ) l o g ( 1 − y i ^ ) ] = − 1 n ∑ [ z i l o g D ( G ( z i ) ) + ( 1 − z i ) l o g ( 1 − D ( G ( z i ) ) ) ] = − 1 n ∑ [ 1 ∗ l o g D ( G ( z i ) ) + ( 1 − 1 ) l o g ( 1 − D ( G ( z i ) ) ) ] = − 1 n ∑ l o g D ( G ( z i ) ) \begin{aligned} Gloss&=-\frac{1}{n}\sum[y_ilog\hat{y_i}+(1-y_i)log(1-\hat{y_i})] \\ &=-\frac{1}{n}\sum [z_ilogD(G(z_i))+(1-z_i)log(1-D(G(z_i)))]\\ &=-\frac{1}{n}\sum [1*logD(G(z_i))+(1-1)log(1-D(G(z_i)))]\\ &=-\frac{1}{n}\sum logD(G(z_i)) \end{aligned} Gloss=n1[yilogyi^+(1yi)log(1yi^)]=n1[zilogD(G(zi))+(1zi)log(1D(G(zi)))]=n1[1logD(G(zi))+(11)log(1D(G(zi)))]=n1logD(G(zi))
对比 V ( G , D ) V(G,D) V(G,D),其只有第二项与生成网络G有关,在固定D,求解令损失函数最小的G时,可以转化为,
m i n G V ( G , D ) = m i n G E z ∼ p ( z ) l o g ( 1 − D ( G ( z ) ) ) = m i n G 1 n ∑ i = 1 n l o g ( 1 − D ( G ( z i ) ) ) \begin{aligned} \underset{G}{min}V(G,D)&=\underset{G}{min}E_{z\sim p(z)}log(1-D(G(z)))\\ &=\underset{G}{min}\frac{1}{n}\sum_{i=1}^nlog(1-D(G(z_i))) \end{aligned} GminV(G,D)=GminEzp(z)log(1D(G(z)))=Gminn1i=1nlog(1D(G(zi)))
对比 G l o s s Gloss Gloss,可以发现, V ( G , D ) V(G,D) V(G,D) G l o s s Gloss Gloss原始公式的第二项类似,相当于只有第二项,也就是当实际标签为0的时候,即, C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) Crossentropy(y_{pre}=D(G(z)),y_{true}=0) Crossentropy(ypre=D(G(z))ytrue=0)
C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) = − 1 n ∑ [ 0 ∗ l o g D ( G ( z i ) ) + ( 1 − 0 ) l o g ( 1 − D ( G ( z i ) ) ) ] = − 1 n ∑ l o g ( 1 − D ( G ( z i ) ) ) \begin{aligned} Crossentropy(y_{pre}=D(G(z)),y_{true}=0)&=-\frac{1}{n}\sum [0*logD(G(z_i))+(1-0)log(1-D(G(z_i)))]\\ &=-\frac{1}{n}\sum log(1-D(G(z_i))) \end{aligned} Crossentropy(ypre=D(G(z))ytrue=0)=n1[0logD(G(zi))+(10)log(1D(G(zi)))]=n1log(1D(G(zi)))
可见,
m i n G V ( G , D ) = m i n G [ − C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) ] \underset{G}{min}V(G,D)=\underset{G}{min}[-Crossentropy(y_{pre}=D(G(z)),y_{true}=0)] GminV(G,D)=Gmin[Crossentropy(ypre=D(G(z))ytrue=0)]
实际上, m i n G [ − C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) ] \underset{G}{min}[-Crossentropy(y_{pre}=D(G(z)),y_{true}=0)] Gmin[Crossentropy(ypre=D(G(z))ytrue=0)] m i n G [ C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 1 ) ] \underset{G}{min}[Crossentropy(ypre=D(G(z)),ytrue=1) ] Gmin[Crossentropy(ypre=D(G(z))ytrue=1)] 所想要表达的含义是一致的,后者是想让判别网络识别生成数据的交叉熵最小,也就是识别生成数据的为真概率最大,从而反应出生成网络生成的生成数据足够真实,前者是 C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) Crossentropy(y_{pre}=D(G(z)),y_{true}=0) Crossentropy(ypre=D(G(z))ytrue=0)越大,表明生成网络对错误数据的生成越明显,相反, − C r o s s e n t r o p y ( y p r e = D ( G ( z ) ) , y t r u e = 0 ) -Crossentropy(y_{pre}=D(G(z)),y_{true}=0) Crossentropy(ypre=D(G(z))ytrue=0)越小,表明生成网络越好,含义是一致的。

综上,我们理解清楚了损失函数的来历。

代码实现

我们采用手写数字MNIST数据集来验证,将原始维度为[batch,1,28,28],将其展平为[batch,784]。

判别网络
# 定义判别网络(Discriminator)类
class Discriminator(nn.Module):
    def __init__(self, in_features=784):
        """
        in_features: 输入数据的特征数,默认是28x28的图像展开后的维度784
        """
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),  # 全连接层,将输入特征映射到128维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数,避免梯度消失
            nn.Linear(128, 1),  # 全连接层,将128维映射到1维
            nn.Sigmoid()  # Sigmoid激活函数,将输出映射到0到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入数据
        """
        return self.disc(x)
生成网络
# 定义生成网络(Generator)类
class Generator(nn.Module):
    def __init__(self, in_features, out_features=784):
        """
        in_features: 噪声z的维度
        out_features: 生成的数据维度,默认是28x28的图像展开后的维度784
        """
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_features, 256),  # 全连接层,将输入噪声映射到256维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数
            nn.Linear(256, out_features),  # 全连接层,将256维映射到784维
            nn.Tanh()  # Tanh激活函数,将输出数据归一化到-1到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入噪声z
        """
        return self.gen(x)
损失函数及优化器
# 实例化判别网络和生成网络
z_dim = 64  # 噪声z的维度
real_data_dim = 784  # 真实数据的维度,28x28图像展开后的大小
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用GPU或CPU设备
gen = Generator(in_features=z_dim, out_features=real_data_dim).to(device)  # 初始化生成网络并移到设备上
disc = Discriminator(in_features=real_data_dim).to(device)  # 初始化判别网络并移到设备上

# 定义优化器
lr = 0.0002  # 学习率
optim_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))  # 判别网络的Adam优化器
optim_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))  # 生成网络的Adam优化器

# 定义损失函数
criterion = nn.BCELoss()  # 二分类交叉熵损失函数
数据集准备
# 数据加载器
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据到[-1, 1]
])

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),  # 下载并加载MNIST数据集
    batch_size=64, shuffle=True  # 批量大小为64,数据打乱
)
开始训练
# 训练GAN
num_epochs = 100  # 训练100个epoch

for epoch in range(num_epochs):
    for batch_idx, (real_data, _) in enumerate(dataloader):
        real_data = real_data.view(-1, 784).to(device)  # 展平输入数据并移到设备上
        batch_size = real_data.size(0)  # 获取批量大小

        # 训练判别网络
        # 真实数据的损失
        dx = disc(real_data).view(-1)  # 判别网络对真实数据的预测概率
        loss_real = criterion(dx, torch.ones_like(dx))  # 计算真实数据的损失
        loss_real.backward()  # 反向传播计算梯度

        # 生成假数据的损失
        noise = torch.randn(batch_size, z_dim).to(device)  # 生成随机噪声
        fake_data = gen(noise)  # 通过生成网络生成假数据
        dgz1 = disc(fake_data.detach()).view(-1)  # 判别网络对假数据的预测概率
        loss_fake = criterion(dgz1, torch.zeros_like(dgz1))  # 计算假数据的损失
        loss_fake.backward()  # 反向传播计算梯度

        optim_disc.step()  # 更新判别网络的权重
        disc.zero_grad()  # 清零判别网络的梯度

        # 训练生成网络
        dgz2 = disc(fake_data).view(-1)  # 判别网络对假数据的预测概率
        loss_gen = criterion(dgz2, torch.ones_like(dgz2))  # 生成网络的损失,目标标签为全1
        loss_gen.backward()  # 反向传播计算梯度

        optim_gen.step()  # 更新生成网络的权重
        gen.zero_grad()  # 清零生成网络的梯度

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss_D: {loss_real + loss_fake:.4f} Loss_G: {loss_gen:.4f}")
综合代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms


# 定义判别网络(Discriminator)类
class Discriminator(nn.Module):
    def __init__(self, in_features=784):
        """
        in_features: 输入数据的特征数,默认是28x28的图像展开后的维度784
        """
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),  # 全连接层,将输入特征映射到128维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数,避免梯度消失
            nn.Linear(128, 1),  # 全连接层,将128维映射到1维
            nn.Sigmoid()  # Sigmoid激活函数,将输出映射到0到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入数据
        """
        return self.disc(x)


# 定义生成网络(Generator)类
class Generator(nn.Module):
    def __init__(self, in_features, out_features=784):
        """
        in_features: 噪声z的维度
        out_features: 生成的数据维度,默认是28x28的图像展开后的维度784
        """
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_features, 256),  # 全连接层,将输入噪声映射到256维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数
            nn.Linear(256, out_features),  # 全连接层,将256维映射到784维
            nn.Tanh()  # Tanh激活函数,将输出数据归一化到-1到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入噪声z
        """
        return self.gen(x)


# 实例化判别网络和生成网络
z_dim = 64  # 噪声z的维度
real_data_dim = 784  # 真实数据的维度,28x28图像展开后的大小
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用GPU或CPU设备
gen = Generator(in_features=z_dim, out_features=real_data_dim).to(device)  # 初始化生成网络并移到设备上
disc = Discriminator(in_features=real_data_dim).to(device)  # 初始化判别网络并移到设备上

# 定义优化器
lr = 0.0002  # 学习率
optim_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))  # 判别网络的Adam优化器
optim_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))  # 生成网络的Adam优化器

# 定义损失函数
criterion = nn.BCELoss()  # 二分类交叉熵损失函数

# 数据加载器
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据到[-1, 1]
])

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),  # 下载并加载MNIST数据集
    batch_size=64, shuffle=True  # 批量大小为64,数据打乱
)

# 训练GAN
num_epochs = 100  # 训练100个epoch

for epoch in range(num_epochs):
    for batch_idx, (real_data, _) in enumerate(dataloader):
        real_data = real_data.view(-1, 784).to(device)  # 展平输入数据并移到设备上
        batch_size = real_data.size(0)  # 获取批量大小

        # 训练判别网络
        # 真实数据的损失
        dx = disc(real_data).view(-1)  # 判别网络对真实数据的预测概率
        loss_real = criterion(dx, torch.ones_like(dx))  # 计算真实数据的损失
        loss_real.backward()  # 反向传播计算梯度

        # 生成假数据的损失
        noise = torch.randn(batch_size, z_dim).to(device)  # 生成随机噪声
        fake_data = gen(noise)  # 通过生成网络生成假数据
        dgz1 = disc(fake_data.detach()).view(-1)  # 判别网络对假数据的预测概率
        loss_fake = criterion(dgz1, torch.zeros_like(dgz1))  # 计算假数据的损失
        loss_fake.backward()  # 反向传播计算梯度

        optim_disc.step()  # 更新判别网络的权重
        disc.zero_grad()  # 清零判别网络的梯度

        # 训练生成网络
        dgz2 = disc(fake_data).view(-1)  # 判别网络对假数据的预测概率
        loss_gen = criterion(dgz2, torch.ones_like(dgz2))  # 生成网络的损失,目标标签为全1
        loss_gen.backward()  # 反向传播计算梯度

        optim_gen.step()  # 更新生成网络的权重
        gen.zero_grad()  # 清零生成网络的梯度

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss_D: {loss_real + loss_fake:.4f} Loss_G: {loss_gen:.4f}")

训练100轮之后的结果,

在这里插入图片描述

接着我们可以测试一下生成网络生成图片的效果,代码如下,

import torch
import torch.nn as nn

# 定义生成网络(Generator)类
class Generator(nn.Module):
    def __init__(self, in_features, out_features=784):
        """
        in_features: 噪声z的维度
        out_features: 生成的数据维度,默认是28x28的图像展开后的维度784
        """
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(in_features, 256),  # 全连接层,将输入噪声映射到256维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数
            nn.Linear(256, out_features),  # 全连接层,将256维映射到784维
            nn.Tanh()  # Tanh激活函数,将输出数据归一化到-1到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入噪声z
        """
        return self.gen(x)


# 实例化判别网络和生成网络
z_dim = 64  # 噪声z的维度
real_data_dim = 784  # 真实数据的维度,28x28图像展开后的大小
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用GPU或CPU设备

# 继续使用之前定义的 Generator 类
gen = Generator(in_features=z_dim, out_features=real_data_dim).to(device)
gen.load_state_dict(torch.load('generator.pth'))
gen.eval()  # 设置为评估模式

import matplotlib.pyplot as plt

# 生成随机噪声
noise = torch.randn(1, z_dim).to(device)  # 生成一个样本,批量大小为1
# 使用生成网络生成图片
with torch.no_grad():  # 禁用梯度计算,节省内存和计算资源
    generated_image = gen(noise).view(28, 28).cpu().numpy()  # 生成的图片

plt.savefig("generated_image.png")

# 显示生成的图片
plt.imshow(generated_image, cmap='gray')
plt.title("Generated Image")
plt.savefig("generated_image.png")
plt.show()

显示图片为,

在这里插入图片描述

然后再用生成网络生成的图片,测试一下判别网络的判别效果,代码如下,

import torch
import torch.nn as nn

# 定义判别网络(Discriminator)类
class Discriminator(nn.Module):
    def __init__(self, in_features=784):
        """
        in_features: 输入数据的特征数,默认是28x28的图像展开后的维度784
        """
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),  # 全连接层,将输入特征映射到128维
            nn.LeakyReLU(0.1),  # LeakyReLU激活函数,避免梯度消失
            nn.Linear(128, 1),  # 全连接层,将128维映射到1维
            nn.Sigmoid()  # Sigmoid激活函数,将输出映射到0到1之间
        )

    def forward(self, x):
        """
        前向传播函数
        x: 输入数据
        """
        return self.disc(x)

# 实例化判别网络和生成网络
z_dim = 64  # 噪声z的维度
real_data_dim = 784  # 真实数据的维度,28x28图像展开后的大小
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用GPU或CPU设备



# 加载判别网络模型
disc = Discriminator(in_features=real_data_dim).to(device)
disc.load_state_dict(torch.load('discriminator.pth'))
disc.eval()

# 假设你有一张图片 real_image
# 这里 real_image 是一个28x28的灰度图片,已经预处理为 tensor
# 你需要确保图片已经转换成 [1, 784] 的形状并放在正确的设备上


from PIL import Image
import torchvision.transforms as transforms

# 加载图片
image_path = 'generated_image.png'  # 替换为你的图片路径
image = Image.open(image_path).convert('L')  # 以灰度模式加载图片

# 定义预处理变换
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整大小为28x28
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1, 1]
])

# 应用预处理
real_image = transform(image)  # 预处理图片

real_image = real_image.view(-1, 784).to(device)  # 确保形状正确
with torch.no_grad():
    prediction = disc(real_image).item()

print(f"The probability that the image is real: {prediction:.4f}")

显示结果,

在这里插入图片描述

可以看到判别网络对生成网络生成的数据判定为真实数据的概率为0.999.

参考文章

[1]https://zhuanlan.zhihu.com/p/628915533
[2]Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … & Bengio, Y. (2014). Generative adversarial nets. arXiv preprint arXiv:1406.2661.)


http://www.kler.cn/a/370085.html

相关文章:

  • SpringCloud微服务Gateway网关简单集成Sentinel
  • 炸场硅谷,大模型“蒸汽机”迎来“瓦特时刻”
  • Text2SQL 智能报表方案介绍
  • 将 AzureBlob 的日志通过 Azure Event Hubs 发给 Elasticsearch(1.标准版)
  • Redis的Windows版本安装以及可视化工具
  • Chrome远程桌面无法连接怎么解决?
  • 51单片机完全学习——DS18B20温度传感器
  • 医院信息化与智能化系统(12)
  • 极狐GitLab 发布安全补丁版本17.5.1, 17.4.3, 17.3.6
  • TextHarmony:视觉文本理解与生成的新型多模态大模型
  • 唤醒车机时娱乐屏出现黑屏,卡顿的案例分享
  • 深度学习(五):语音处理领域的创新引擎(5/10)
  • 106. 平行光阴影计算
  • springmvc请求源码流程解析(二)
  • 优先算法——移动零(双指针)
  • LVGL移植教程(超详细)——基于GD32F303X系列MCU
  • 人脸美颜 API 对接说明
  • 批量剪辑视频软件源码搭建全解析,支持OEM
  • 【瑞吉外卖】-day01
  • 使用 three.js 渲染个blender模型
  • 特定机器学习问题的基准测试数据
  • 【数据价值化】数据资产变现及管理规划
  • SQL语句优化之Sql执行顺序
  • 记录如何在RK3588板子上跑通paddle的OCR模型
  • 从零开始:使用Spring Boot搭建网上摄影工作室
  • c++二级指针