首页 > 编程知识 正文

Bilinear Attention Networks 代码记录,苹果重启记录代码是什么

时间:2023-05-06 16:02:24 阅读:218227 作者:3815

torch.einsum是个好东西,就是输入数据多于2个,就有点看不懂了。(改成了使用torch.matmul主要是为了将代码和论文公式对应上,也验证了计算的结果应该是一致的)
源码来源:https://github.com/jnhwkim/ban-vqa
以下代码位于此处,其中:
1)forward函数用来计算Bilinear Attention Map(输入分别是视觉编码v和问题编码q),也就是注意力权重;
2)forward_with_weights函数基于注意力权重w来进行视觉编码v和问题编码q的融合。
计算权重和融合分别是两个独立的层中的操作,两者之间是不共享v_net和q_net的参数的。

其中,相关数据维度如下:

# 1 forward函数:v_ [B, M, D]q_ [B, L, D]# 2 forward_with_weights函数:v_ [B, M, D]q_ [B, L, D]w [B, M, L] 1 forward函数 # low-rank bilinear pooling using einsumdef forward(self, v, q):...elif self.h_out <= self.c:v_ = self.dropout(self.v_net(v))q_ = self.q_net(q)logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_biasreturn logits # b x h_out x v x q...

对应论文中的公式(8)softmax函数中内容(这个mathbbm{1}打不出来,-_-),其中h_mat即公式中的 p mathrm{p} p(公式中未考虑偏置bias),两个Linear层v_net和q_net对应了公式中的权重矩阵 U mathrm{U} U和 V mathrm{V} V,v和q即公式中的 X mathrm{X} X和 Y mathrm{Y} Y。

----------------------------------------------分割线---------------------------------------------
1、【公式(8)softmax内表达式与公式(9)等价,公式(8)直接求出权重矩阵,公式(9)是权重矩阵中每一个单项数据原始计算形式。个人感觉转换成公式(8)更方便使用代码实现,因为X与Y的长度如果都多于1的话,其Hadamard积计算就不是那么方便(pytorch中反正好像没有长度不等的矩阵求Hadamard积的函数实现)】
2、【X和Y可对比为现在注意力机制中常说的Query和Key,如果Query/X长度始终是1,那么也可以使用Linear层实现公式(9),而不需要转换为公式(8)进行实现】
----------------------------------------------分割线完-------------------------------------------
如果不用torch.einsum()函数,按照公式,代码可如下:

# low-rank bilinear pooling using einsumdef forward(self, v, q):...elif self.h_out <= self.c:"""v: [B, M, D']q: [B, L, D']h_mat: [1, h_out, 1, D], h_out默认等于8h_bias: [1, h_out, 1, 1]D = self.k * D', self.k = 3""" v_ = self.dropout(self.v_net(v)) # [B, M, D]q_ = self.q_net(q) # [B, L, D]# logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias logits = torch.matmul((h_mat * v_.unsqueeze(1)), q_.unsqueeze(1).transpose(-1, -2)) + h_biasreturn logits # b x h_out x M x L... 2 forward_with_weights函数 def forward_with_weights(self, v, q, w):v_ = self.v_net(v) # b x v x dq_ = self.q_net(q) # b x q x dlogits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))if 1 < self.k:logits = logits.unsqueeze(1) # b x 1 x dlogits = self.p_net(logits).squeeze(1) * self.k # sum-poolingreturn logits

对应论文中的公式(5)和公式(6),下标k表示在数据维度上(即维度D)的遍历,完整 f ′ ∈ R K mathrm{f'}in mathbb{R}^K f′∈RK。其中v和q和w分别代表 X mathrm{X} X和 Y mathrm{Y} Y和 A mathcal{A} A(这里面的K本质上应该就是数据维度D)

如果不用torch.einsum()函数,按照公式,代码可如下:

def forward_with_weights(self, v, q, w):"""v: [B, M, D']q: [B, L, D']w: [B, M, L]D = self.k * D', self.k = 3"""v_ = self.v_net(v) # [B, M, D]q_ = self.q_net(q) # [B, L, D]# logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))# 运算过程与维度D无关,因此需要交换维度logits = torch.matmul(torch.matmul(v_.permute(0,2,1).unsqueeze(-2), w.unsqueeze(1)), q_.permute(0, 2, 1).unsqueeze(-1)).squeeze()if 1 < self.k:logits = logits.unsqueeze(1) # [B, 1, D]logits = self.p_net(logits).squeeze(1) * self.k # sum-poolingreturn logits # [B, D / self.k]

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。