首页 > 编程知识 正文

生成对抗网络 pytorch

时间:2023-05-06 13:20:51 阅读:184061 作者:4482

简介

上文说到生成对抗网络GAN能够通过训练学习到数据分布,进而生成新的样本。可是GAN的缺点是生成的图像是随机的,不能控制生成图像属于何种类别。比如数据集包含飞机、汽车和房屋等类别,原始GAN并不能在测试阶段控制输出属于哪一类。

为此,研究人员提出了Conditional Generative Adversarial Network(简称CGAN),CGAN的图像生成过程是可控的。

本文包含以下3个方面:

(1)CGAN原理分析
(2)pytorch实现CGAN
(3)视觉结果和损失函数曲线

CGAN的思想是非常简单的,这也验证了那句话,越简单的想法越伟大!

1、CGAN原理分析 1.1 网络结构

CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息,实现条件生成模型。CGAN原文中作者说额外的条件信息可以是类别标签或者其它的辅助信息,本文使用条件信息(记为y)作为例子。

CGAN的核心操作是将条件信息加入到G和D中,下面分别进行讨论:

(1)原始GAN生成器输入是噪声信号,类别标签可以和噪声信号组合作为隐空间表示;
(2)原始GAN判别器输入是图像数据(真实图像和生成图像),同样需要将类别标签和图像数据进行拼接作为判别器输入。


从上图(来自CGAN论文)中可以看出,CGAN的网络相对于原始GAN网络并没有变化,改变的仅仅是生成器G和判别器D的输入数据,这就使得CGAN可以作为一种通用策略嵌入到其它的GAN网络中。

2.2 损失函数

原始GAN包含一个生成器和一个判别器,其中生成器G和判别器D进行极大极小博弈,损失函数如下:

CGAN添加的额外信息y只需要和x与z进行合并,作为G和D的输入即可,由此得到了CGAN的损失函数如下:

1.3 训练策略与实验结果

CGAN在mnist数据集上进行了实验,对于生成器:使用数字的类别y作为标签,并进行了one-hot编码,噪声z来自均均匀分布;噪声z映射到200维的隐层,类别标签映射到1000维的隐层,然后进行拼接作为下一层的输入,激活函数使用ReLU;最后一层使用Sigmoid函数,生成的样本为784维(使用的mnist长宽为28x28=784)。得到的实验结果如下:

上图中每行是由相同的标签生成的,说明CGAN的确可以通过给生成器特定的标签,实现特定模式(类别)的生成。CGAN还做了其它的实验,都证明了CGAN的模式控制能力。

2、pytorch实现 2.1 生成器实现

CGAN的生成器输入为噪声z和类别标签y的联合输入,所以这里我直接在对DCGAN的生成器进行改动(DCGAN的代码和分析参见我之前的文章):

class Generator(nn.Module): def __init__(self, z_dim, num_classes): super().__init__() self.z_dim = z_dim self.num_classes = num_classes net = [] # 1:设定每次反卷积的输入和输出通道数 # 卷积核尺寸固定为3,反卷积输出为“SAME”模式 channels_in = [self.z_dim+self.num_classes, 512, 256, 128, 64] channels_out = [512, 256, 128, 64, 3] active = ["R", "R", "R", "R", "tanh"] stride = [1, 2, 2, 2, 2] padding = [0, 1, 1, 1, 1] for i in range(len(channels_in)): net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i], kernel_size=4, stride=stride[i], padding=padding[i], bias=False)) if active[i] == "R": net.append(nn.BatchNorm2d(num_features=channels_out[i])) net.append(nn.ReLU()) elif active[i] == "tanh": net.append(nn.Tanh()) self.generator = nn.Sequential(*net) def forward(self, x, label): x = x.unsqueeze(2).unsqueeze(3) label = label.unsqueeze(2).unsqueeze(3) data = torch.cat(tensors=(x, label), dim=1) out = self.generator(data) return out 2.2 判别器的实现

CGAN的判别器需要使用图像(生成的和真实的)和类别标签y联合输入,所以这里也是对DCGAN的判别器第一层进行改动:

class Discriminator(nn.Module): def __init__(self, num_classes): super().__init__() self.num_classes = num_classes net = [] # 1:预先定义 channels_in = [3+self.num_classes, 64, 128, 256, 512] channels_out = [64, 128, 256, 512, 1] padding = [1, 1, 1, 1, 0] active = ["LR", "LR", "LR", "LR", "sigmoid"] for i in range(len(channels_in)): net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i], kernel_size=4, stride=2, padding=padding[i], bias=False)) if i == 0: net.append(nn.LeakyReLU(0.2)) elif active[i] == "LR": net.append(nn.BatchNorm2d(num_features=channels_out[i])) net.append(nn.LeakyReLU(0.2)) elif active[i] == "sigmoid": net.append(nn.Sigmoid()) self.discriminator = nn.Sequential(*net) def forward(self, x, label): label = label.unsqueeze(2).unsqueeze(3) label = label.repeat(1, 1, x.size(2), x.size(3)) data = torch.cat(tensors=(x, label), dim=1) out = self.discriminator(data) out = out.view(data.size(0), -1) return out 3、视觉结果和损失函数曲线

自己的数据包含3类:动漫脸、人脸、鞋。其实当时还选择了其它数据,但是最后发现,在数据集质量不够高时,生成的样本明显不够好,最后筛选才确定了使用这三个数据集。当然,自己的实验结果也非常差!迭代的总体次数为6000次左右,生成了下面的样本:

上面这个动漫脸完全看不清,人脸中也看不见嘴,下面这个结果更好些:

实际上,结果比较差的主要原因还是在于生成器的结构(不够深,拟合能力不够强),如果换成是近两年的生成器结构,生成的效果肯定会好很多。当然,调参数而是很重要的一个方面,自己也没有进行细致的调参。下面这张图显示了迭代过程中生成的图像的变化:

损失函数没有展示出收敛的趋势,尤其是生成器的损失似乎还在增加:

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。