首页 > 编程知识 正文

SelfAttention Generative Adversarial NetworksSAGAN理解

时间:2023-05-03 09:59:56 阅读:195372 作者:4502

介绍

Self-Attention Generative Adversarial Networks(SAGAN)是Han Zhang, Ian Goodfellow等人在去年提出的一种新的GAN结构,网络主要引入了注意力机制,不仅解决了卷积结构带来的感受野大小的限制,也使得网络在生成图片的过程中能够自己学习应该关注的不同区域。从结果上来看,SAGAN相比于之前最好的结构,在 Inception score上从36.8提高到了52.52,而 Frechet Inception distance 从27.62降到了18.65,如果对这两个指标不了解,可以看一下我之前的博客GAN的几种评价指标。

GAN之前存在的问题: 对于含有较少结构约束的类别,比如海洋、天空等,得到结果较好;而对于含有较多几何或结构约束的类别则容易失败,比如合成图像中狗(四足动物)的毛看起来很真实但手脚很难辨认。这是因为复杂的几何轮廓需要long-range dependencies(长距离依赖),卷积的特点就是局部性,受到感受野大小的限制很难提取到图片中的这些长距离依赖。虽然可以通过加深网络或者扩大卷积核的尺寸一定程度解决该问题,但是这会使卷积网络丧失了其参数和计算的效率优势。

论文的主要贡献:

把self-attention机制引入到了GAN的框架中,对卷积结构进行了补充,有助于对图像区域中长距离,多层次的依赖关系进行建模,并对该机制做了可视化实验;在判别器和生成器中均使用spectral normalization,提升生成器的性能;训练中使用Two Timescale Update Rule (TTUR),对判别器使用较高学习率,从而可以保证生成器和判别器可以更新比例为1:1,加快收敛速度,减少训练时间。 实现原理 自注意力机制生成对抗网络

self-attention机制的实现主要受到之前另一篇论文的启发:Non-local Neural Networks,有兴趣的可以阅读。为了有效地对空间上分隔较远的区域建立联系,SAGAN引入了self-attention module,其结构如下图:

对于某一个卷积层之后输出的feature map(上图中的x),不考虑Batch维度,整个模块有以下步骤:

分别经过三个1x1卷积结构的分支f(x)、g(x)和h(x),特征图的尺寸均不变,f(x)和g(x)改变了通道数,h(x)的输出保持通道数也不变;对于H和W两个维度进行合并,将f(x)的输出转置后和g(x)的输出矩阵相乘,经过softmax归一化得到一个 [HxW,HxW] 的Attention Map;将Attention Map与h(x)的输出进行矩阵相乘,得到一个 [HxW,C] 的特征图,经过一个1x1的卷积结构,并把输出重新reshape为[H,W,C],得到此时的feature map(上图中的o);最终输出的特征图为:y = γo + x

简单的tensorflow版实现代码如下:(https://github.com/taki0112/Self-Attention-GAN-Tensorflow)

def attention(self, x, ch): f = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='f_conv') # [bs, h, w, c'] g = conv(x, ch // 8, kernel=1, stride=1, sn=self.sn, scope='g_conv') # [bs, h, w, c'] h = conv(x, ch, kernel=1, stride=1, sn=self.sn, scope='h_conv') # [bs, h, w, c] # N = h * w s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N] beta = tf.nn.softmax(s) # attention map o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C] gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) o = tf.reshape(o, shape=x.shape) # [bs, h, w, C] x = gamma * o + x return x

上面简单的从过程上对self-attention机制做了说明,下面说一下其中一些步骤的具体操作和自己的理解:
f(x)和g(x)的输出通道数: 在计算过程中这个通道数是不影响Attention Map的维度的,较少的通道数会减少参数量和计算量,作者在实验中分别使用C/k(k=1,2,4,8)训练了几个epoch后发现对结果影响不大,因此最终选用了C/8。

通过f(x)和g(x)的输出计算Attention Map: 类似于风格迁移中计算Gram矩阵,但是矩阵相乘的顺序正好相反。Gram矩阵得到的是[C,C]的输出,因此表示的是原始feature map中通道与通道之间的相关性,而这里得到的是[HxW,HxW]的输出,因此表示的是像素点与像素点之间的相关性。当经过了softmax函数之后(注意这里是对每一行单独进行softmax),每一行就代表了一种attention机制,其实质是原始feature map中一个像素位置(C个像素点)与其它所有像素位置的关系,HxW行也就对应了原始的HxW个像素位置。

h(x)的输出与Attention Map矩阵相乘: 上面说到过,Attention Map中的每一行都代表了一种attention机制而h(x)有HxW行和C列。因此两个矩阵相乘就是计算了h(x)每个通道上的所有像素分别经过HxW种attention机制后的结果,经过这个操作,每个像素点都和整张feature map上所有的点建立了联系,实现了self-attention的目的。(取一些简单的值把整个过程的输出矩阵画出来,会更加容易理解)

最终特征图的表示: 最终的输出是y = γo + x,其中γ是一个可学习的参数,并且初始化为0。因为在网络开始训练的时候,还是希望网络先把局部信息学好,初始化为0就表示不采用self-attention模块,而随着训练的进行,网络会开始慢慢尝试使用self-attention模块学习更多长距离的特征。

最后,该网络的生成对抗损失函数采用了合页(Hinge)损失的版本:

稳定训练过程的技巧

主要有两点:
1、在生成器和判别器中均应用了Spectral normalization的方法:这个方法可以限制参数的分布(Lipschitz 条件),从而减少梯度爆炸等影响,稳定网络的训练。最初该方法只在判别器中使用,SAGAN中在生成器和判别器中均加入了Spectral normalization。
2、训练中使用Two Timescale Update Rule (TTUR):通常在GAN中G和D采用交替训练的方法,G训练一次后需要训练多次D(常见的是一次G,五次D),SAGAN使用了TTUR的方法,具体来说就是G和D使用不同的学习率(G为0.0001,D为0.0004),使得每一次生成器的训练之后判别器需要更少的训练次数。

下图是上面两个技巧的消融实验结果:

实验部分 生成效果

作者除了和残差网络对比了生成效果外,还探索了在不同层使用self-attention模块和残差模块的影响:

可以看到在中高层特征引入Attention机制能取得更好的效果,个人感觉原因是生成网络和分类网络正好相反,越是接近输出层,细节特征反而更明显,因此在高层引入Attention机制,效果更好。

Attention机制的可视化

作者为了说明Attention机制的工作原理,对最接近输出层的self-attention模块的Attention Map做了可视化,得到了非常惊人的结果:

如上图所示,根据我们之前的分析,Attention Map中有 HxW 种attention机制,生成图片中的每一个像素点都是根据某一种attention机制生成的,作者在每张生成图片中找了红、蓝、绿三个点,并可视化了计算该点时所使用的attention机制(就是Attention Map中的某一行,把它从 HxW reshape为[H,W]),可视化的结果中,高亮区域表示与对应像素有较大的关联。举例来说,左上第一张图,绿色的像素点是背景的树林,可视化生成该像素点使用的 attention 机制,可以发现高亮的区域正是树林分布的区域,而正好越过了中间存在的鸟,说明self-attention模块确实学到了跨越空间的像素之间的依赖关系。

下图是取五个点对attention机制的可视化,可以发现网络不仅能够区分前景和背景,甚至对一些物体的不同结构也能准确的进行划分,在生成图片的时候也会更合理。

总结

SAGAN通过引入self-attention模块,使得每个像素点与其它像素点产生了关联,解决了普通卷积结构存在的long-range dependencies(长距离依赖)问题,在提高感受野和减小参数量之间得到了一个更好的平衡。

参考博客:https://blog.csdn.net/HaoTheAnswer/article/details/82733234

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。