torch.einsum是个好东西,就是输入数据多于2个,就有点看不懂了。(改成了使用torch.matmul主要是为了将代码和论文公式对应上,也验证了计算的结果应该是一致的)
源码来源:https://github.com/jnhwkim/ban-vqa
以下代码位于此处,其中:
1)forward函数用来计算Bilinear Attention Map(输入分别是视觉编码v和问题编码q),也就是注意力权重;
2)forward_with_weights函数基于注意力权重w来进行视觉编码v和问题编码q的融合。
计算权重和融合分别是两个独立的层中的操作,两者之间是不共享v_net和q_net的参数的。
其中,相关数据维度如下:
对应论文中的公式(8)softmax函数中内容(这个mathbbm{1}打不出来,-_-),其中h_mat即公式中的 p mathrm{p} p(公式中未考虑偏置bias),两个Linear层v_net和q_net对应了公式中的权重矩阵 U mathrm{U} U和 V mathrm{V} V,v和q即公式中的 X mathrm{X} X和 Y mathrm{Y} Y。
----------------------------------------------分割线---------------------------------------------
1、【公式(8)softmax内表达式与公式(9)等价,公式(8)直接求出权重矩阵,公式(9)是权重矩阵中每一个单项数据原始计算形式。个人感觉转换成公式(8)更方便使用代码实现,因为X与Y的长度如果都多于1的话,其Hadamard积计算就不是那么方便(pytorch中反正好像没有长度不等的矩阵求Hadamard积的函数实现)】
2、【X和Y可对比为现在注意力机制中常说的Query和Key,如果Query/X长度始终是1,那么也可以使用Linear层实现公式(9),而不需要转换为公式(8)进行实现】
----------------------------------------------分割线完-------------------------------------------
如果不用torch.einsum()函数,按照公式,代码可如下:
对应论文中的公式(5)和公式(6),下标k表示在数据维度上(即维度D)的遍历,完整 f ′ ∈ R K mathrm{f'}in mathbb{R}^K f′∈RK。其中v和q和w分别代表 X mathrm{X} X和 Y mathrm{Y} Y和 A mathcal{A} A(这里面的K本质上应该就是数据维度D)
如果不用torch.einsum()函数,按照公式,代码可如下:
def forward_with_weights(self, v, q, w):"""v: [B, M, D']q: [B, L, D']w: [B, M, L]D = self.k * D', self.k = 3"""v_ = self.v_net(v) # [B, M, D]q_ = self.q_net(q) # [B, L, D]# logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))# 运算过程与维度D无关,因此需要交换维度logits = torch.matmul(torch.matmul(v_.permute(0,2,1).unsqueeze(-2), w.unsqueeze(1)), q_.permute(0, 2, 1).unsqueeze(-1)).squeeze()if 1 < self.k:logits = logits.unsqueeze(1) # [B, 1, D]logits = self.p_net(logits).squeeze(1) * self.k # sum-poolingreturn logits # [B, D / self.k]