
SAGAN (Self-Attention GAN) Code Reading Notes


Paper: Self-Attention Generative Adversarial Networks
Link: https://arxiv.org/abs/1805.08318
Code: https://github.com/heykeetae/Self-Attention-GAN

These notes follow the flow of the code.

Default parameter settings:

adv_loss = 'hinge'
attn_path = './attn'
batch_size = 64
beta1 = 0.0
beta2 = 0.9
d_conv_dim = 64
d_iters = 5
d_lr = 0.0004
dataset = 'celeb'
g_conv_dim = 64
g_lr = 0.0001
g_num = 5
image_path = './data'
imsize = 64
lambda_gp = 10
log_path = './logs'
log_step = 10
lr_decay = 0.95
model = 'sagan'
model_save_path = './models'
model_save_step = 1.0
num_workers = 2
parallel = False
pretrained_model = None
sample_path = './samples'
sample_step = 100
total_step = 1000000
train = True
use_tensorboard = False
version = 'sagan_celeb'
z_dim = 128
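For reference, here is a sketch of how a few of these defaults could be declared with argparse; the flag names and defaults are copied from the list above, but the repo's actual parameter parsing may differ.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--adv_loss', type=str, default='hinge')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--d_lr', type=float, default=0.0004)
parser.add_argument('--g_lr', type=float, default=0.0001)
parser.add_argument('--imsize', type=int, default=64)
parser.add_argument('--z_dim', type=int, default=128)
config = parser.parse_args([])  # parse defaults only; the remaining flags follow the same pattern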

Discriminator network architecture

The discriminator is configured with batch_size=64, image_size=64, conv_dim=64.

Assume the input has shape torch.Size([64, 3, 64, 64]).

# layer1
Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The output becomes torch.Size([64, 64, 32, 32]).

# layer2
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The output becomes torch.Size([64, 128, 16, 16]).

# layer3
Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The output becomes torch.Size([64, 256, 8, 8]).

The first three layers share the same structure: the channel count doubles at each layer while the spatial size halves.

After the first three layers comes a self-attention layer. The feature map shape is unchanged, still torch.Size([64, 256, 8, 8]); the attention map has shape torch.Size([64, 64, 64]), i.e. one 64 × 64 attention matrix per sample, since 8 × 8 = 64 spatial positions. A sketch of this block follows.
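Here is a minimal sketch of the self-attention block, following the formulation in the paper; the repo's Self_Attn module is structured the same way, but the class and variable names below are my own.

import torch
import torch.nn as nn

class SelfAttnSketch(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convolutions project the feature map into query/key/value spaces;
        # query/key are reduced to in_dim // 8 channels as in the paper.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # gamma starts at 0, so the block is initially an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.size()
        n = h * w  # number of spatial positions
        q = self.query_conv(x).view(b, -1, n).permute(0, 2, 1)  # B x N x C/8
        k = self.key_conv(x).view(b, -1, n)                     # B x C/8 x N
        attention = torch.softmax(torch.bmm(q, k), dim=-1)      # B x N x N
        v = self.value_conv(x).view(b, -1, n)                   # B x C x N
        out = torch.bmm(v, attention.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x, attention

x = torch.randn(64, 256, 8, 8)
out, attn = SelfAttnSketch(256)(x)
print(out.shape, attn.shape)  # torch.Size([64, 256, 8, 8]) torch.Size([64, 64, 64])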

When the input image size is 64, there is also a layer4, with the same structure as the first three layers:

# layer4
Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
LeakyReLU(negative_slope=0.1)

The output becomes torch.Size([64, 512, 4, 4]).

After layer4 comes a second self-attention pass, whose attention map has shape torch.Size([64, 16, 16]) (4 × 4 = 16 spatial positions).

# last
Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))

The output becomes torch.Size([64, 1, 1, 1]).

Finally, squeeze() removes the singleton dimensions (every dimension of size 1 in the shape), so the discriminator's output has shape torch.Size([64]).
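To double-check the shape bookkeeping, here is a condensed sketch of the convolutional path above; spectral norm and the attention blocks are omitted since they do not change shapes, so this is an illustrative stand-in rather than the repo's Discriminator class.

import torch
import torch.nn as nn

d = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.1),     # -> [64, 64, 32, 32]
    nn.Conv2d(64, 128, 4, 2, 1), nn.LeakyReLU(0.1),   # -> [64, 128, 16, 16]
    nn.Conv2d(128, 256, 4, 2, 1), nn.LeakyReLU(0.1),  # -> [64, 256, 8, 8]
    nn.Conv2d(256, 512, 4, 2, 1), nn.LeakyReLU(0.1),  # -> [64, 512, 4, 4]
    nn.Conv2d(512, 1, 4),                             # -> [64, 1, 1, 1]
)
x = torch.randn(64, 3, 64, 64)
print(d(x).squeeze().shape)  # torch.Size([64])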

Generator network architecture

The generator is configured with batch_size=64, image_size=64, z_dim=128, conv_dim=64.

First, random noise is sampled: each image gets a z_dim-dimensional noise vector, so the input is assumed to have shape torch.Size([64, 128]).

The input is first reshaped to torch.Size([64, 128, 1, 1]).

repeat_num = int(np.log2(self.imsize)) - 3
mult = 2 ** repeat_num  # 8

This gives mult = 8: log2(64) = 6, so repeat_num = 3 and 2^3 = 8.

# layer1
ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))
SpectralNorm()
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The output becomes torch.Size([64, 512, 4, 4]).

# layer2
ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The output becomes torch.Size([64, 256, 8, 8]).

# layer3
ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The output becomes torch.Size([64, 128, 16, 16]).

After layer3, self-attention is computed; its attention map (map1) has shape torch.Size([64, 256, 256]) (16 × 16 = 256 spatial positions).

# layer4
ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
SpectralNorm()
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()

The output becomes torch.Size([64, 64, 32, 32]).

After layer4 there is another attention layer; its map (map2) has shape torch.Size([64, 1024, 1024]) (32 × 32 = 1024 spatial positions).

# last
ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
Tanh()

The output becomes torch.Size([64, 3, 64, 64]), i.e. the generated images.
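The same kind of shape check works for the generator path; the sketch below again omits spectral norm, BatchNorm2d, and the attention blocks, none of which change shapes, so it is illustrative only.

import torch
import torch.nn as nn

g = nn.Sequential(
    nn.ConvTranspose2d(128, 512, 4), nn.ReLU(),        # -> [64, 512, 4, 4]
    nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.ReLU(),  # -> [64, 256, 8, 8]
    nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(),  # -> [64, 128, 16, 16]
    nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),   # -> [64, 64, 32, 32]
    nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh(),     # -> [64, 3, 64, 64]
)
z = torch.randn(64, 128).view(64, 128, 1, 1)  # reshape the noise as described above
print(g(z).shape)  # torch.Size([64, 3, 64, 64])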

Loss computation

Discriminator

The overall discriminator loss (hinge loss) is
L_D = -\mathbb{E}_{(x,y)\sim p_{\mathrm{data}}}\left[\min(0,\, -1 + D(x,y))\right] - \mathbb{E}_{z\sim p_z,\, y\sim p_{\mathrm{data}}}\left[\min(0,\, -1 - D(G(z),y))\right]

For real images, note that -min(0, -1 + D(x)) = max(0, 1 - D(x)) = ReLU(1 - D(x)), which is exactly what the code computes:

d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

For generated images:

d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
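A toy check of the two hinge terms, with made-up scores, shows what each one penalizes:

import torch

d_out_real = torch.tensor([1.5, 0.2, -0.3])        # real scores should be >= +1
d_out_fake = torch.tensor([-1.2, 0.5, 2.0])        # fake scores should be <= -1
d_loss_real = torch.relu(1.0 - d_out_real).mean()  # (0 + 0.8 + 1.3) / 3 = 0.7
d_loss_fake = torch.relu(1.0 + d_out_fake).mean()  # (0 + 1.5 + 3.0) / 3 = 1.5
print(d_loss_real.item(), d_loss_fake.item())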

Generator

The overall generator loss is
L_G = -\mathbb{E}_{z\sim p_z,\, y\sim p_{\mathrm{data}}}\left[D(G(z),y)\right]

fake_images, _, _ = self.G(z)
g_out_fake, _, _ = self.D(fake_images)  # batch x n
g_loss_fake = - g_out_fake.mean()

In other words, the generator's loss is the negated mean of the discriminator's scores on the generated images.

Summary

Both the generator and the discriminator use two self-attention layers.

In the generator, a BatchNorm2d layer is added right after spectral normalization; I did not fully understand this choice (one possible reading: spectral normalization constrains the layer's weights while BatchNorm normalizes its activations, so the two are not redundant).

The learning rates differ (d_lr = 0.0004 vs. g_lr = 0.0001), but the discriminator and generator are updated at a 1:1 ratio, as sketched below.
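A minimal sketch of the optimizer setup implied by the defaults above; G and D here are trivial stand-ins for the real networks.

import torch
import torch.nn as nn

G, D = nn.Linear(128, 64), nn.Linear(64, 1)  # stand-in modules

# Two-timescale update rule: D's learning rate is 4x G's, with the Adam betas
# from the defaults (beta1 = 0.0, beta2 = 0.9); updates alternate 1:1.
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))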
