
Implementing the Co-Attention Mechanism


An implementation of the image-text co-attention mechanism from "Hierarchical Question-Image Co-Attention for Visual Question Answering".

References:

https://github.com/SkyOL5/VQA-CoAttention/blob/master/coatt/coattention_net.py

https://github.com/Zhangtd/Models-reproducing/blob/master/NIPS2016/selfDef.py

The paper's co-attention schematic is not reproduced here. It illustrates the two variants: the parallel co-attention mechanism and the alternating co-attention mechanism (a minimal sketch of the alternating variant is given at the end of this post).

A PyTorch implementation of the parallel co-attention mechanism follows:

```python
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def create_src_lengths_mask(
    batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):
    """
    Generate a boolean mask to prevent attention beyond the end of source.

    Inputs:
        batch_size : int
        src_lengths : [batch_size] of sentence lengths
        max_src_len : optionally override max_src_len for the mask

    Outputs:
        [batch_size, max_src_len]
    """
    if max_src_len is None:
        max_src_len = int(src_lengths.max())
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)
    # returns [batch_size, max_src_len]
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking=True):
    """Apply source length masking then softmax.

    Input and output have shape bsz x src_len.
    """
    if src_length_masking:
        bsz, max_src_len = scores.size()
        src_mask = create_src_lengths_mask(bsz, src_lengths)
        # Fill pad positions with -inf so they receive zero attention
        scores = scores.masked_fill(src_mask == 0, -np.inf)
    # Cast to float and then back again to prevent loss explosion under fp16
    return F.softmax(scores.float(), dim=-1).type_as(scores)


class ParallelCoAttentionNetwork(nn.Module):

    def __init__(self, hidden_dim, co_attention_dim, src_length_masking=True):
        super(ParallelCoAttentionNetwork, self).__init__()

        self.hidden_dim = hidden_dim
        self.co_attention_dim = co_attention_dim
        self.src_length_masking = src_length_masking

        self.W_b = nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim))
        self.W_v = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))
        self.W_q = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))
        self.w_hv = nn.Parameter(torch.randn(self.co_attention_dim, 1))
        self.w_hq = nn.Parameter(torch.randn(self.co_attention_dim, 1))

    def forward(self, V, Q, Q_lengths):
        """
        :param V: batch_size * hidden_dim * region_num, e.g. B x 512 x 196
        :param Q: batch_size * seq_len * hidden_dim, e.g. B x L x 512
        :param Q_lengths: batch_size
        :return: batch_size * 1 * region_num, batch_size * 1 * seq_len,
                 batch_size * hidden_dim, batch_size * hidden_dim
        """
        # Affinity matrix: (batch_size, seq_len, region_num)
        C = torch.matmul(Q, torch.matmul(self.W_b, V))
        # (batch_size, co_attention_dim, region_num)
        H_v = nn.Tanh()(
            torch.matmul(self.W_v, V)
            + torch.matmul(torch.matmul(self.W_q, Q.permute(0, 2, 1)), C)
        )
        # (batch_size, co_attention_dim, seq_len)
        H_q = nn.Tanh()(
            torch.matmul(self.W_q, Q.permute(0, 2, 1))
            + torch.matmul(torch.matmul(self.W_v, V), C.permute(0, 2, 1))
        )
        # Visual attention weights: (batch_size, 1, region_num)
        a_v = F.softmax(torch.matmul(torch.t(self.w_hv), H_v), dim=2)
        # Question attention weights: (batch_size, 1, seq_len)
        a_q = F.softmax(torch.matmul(torch.t(self.w_hq), H_q), dim=2)
        # Re-normalize the question weights over the unpadded positions only;
        # note masked_softmax applies a second softmax to the already-softmaxed
        # weights, with pad positions forced to zero: (batch_size, 1, seq_len)
        masked_a_q = masked_softmax(
            a_q.squeeze(1), Q_lengths, self.src_length_masking
        ).unsqueeze(1)
        # Attended visual feature: (batch_size, hidden_dim)
        v = torch.squeeze(torch.matmul(a_v, V.permute(0, 2, 1)))
        # Attended question feature: (batch_size, hidden_dim)
        q = torch.squeeze(torch.matmul(masked_a_q, Q))

        return a_v, masked_a_q, v, q
```
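
For reference, these are the parallel co-attention equations from the paper, with question features $Q \in \mathbb{R}^{d \times T}$ and image features $V \in \mathbb{R}^{d \times N}$ stored column-wise (the code keeps Q batch-first, hence the `Q.permute(0, 2, 1)` calls). Note one deviation: the code above omits the tanh when computing the affinity matrix C.

```latex
C   = \tanh\left(Q^{\top} W_b V\right) \in \mathbb{R}^{T \times N}
H^v = \tanh\left(W_v V + (W_q Q)\, C\right), \qquad
H^q = \tanh\left(W_q Q + (W_v V)\, C^{\top}\right)
a^v = \operatorname{softmax}\left(w_{hv}^{\top} H^v\right), \qquad
a^q = \operatorname{softmax}\left(w_{hq}^{\top} H^q\right)
\hat{v} = \sum_{n=1}^{N} a^v_n\, v_n, \qquad
\hat{q} = \sum_{t=1}^{T} a^q_t\, q_t
```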

The test code is as follows:

```python
pcan = ParallelCoAttentionNetwork(6, 5)
v = torch.randn((5, 6, 10))
q = torch.randn(5, 8, 6)
q_lens = torch.LongTensor([3, 4, 5, 8, 2])
a_v, a_q, v, q = pcan(v, q, q_lens)
print(a_v)
print(a_v.shape)
print(a_q)
print(a_q.shape)
print(v)
print(v.shape)
print(q)
print(q.shape)
```

The output is as follows:

```
tensor([[[9.2527e-04, 1.1542e-03, 1.1542e-03, 1.1542e-03, 2.0009e-02,
          9.2527e-04, 4.0845e-02, 8.8328e-01, 1.1958e-03, 4.9358e-02]],

        [[4.5501e-01, 8.6522e-02, 8.6522e-02, 1.7235e-05, 3.8831e-03,
          2.5070e-04, 9.0637e-05, 4.0010e-03, 2.0196e-03, 3.6169e-01]],

        [[8.8455e-03, 7.2149e-04, 1.7595e-04, 2.1307e-04, 7.0610e-01,
          1.3427e-01, 4.3360e-04, 4.0731e-02, 4.0731e-02, 6.7774e-02]],

        [[4.0013e-01, 2.3081e-02, 3.8406e-02, 4.3583e-03, 9.9425e-05,
          3.8398e-02, 9.9425e-05, 9.4912e-02, 4.0013e-01, 3.9162e-04]],

        [[3.1121e-02, 8.0567e-05, 4.0445e-01, 1.4391e-03, 8.0567e-05,
          4.0445e-01, 7.6909e-02, 2.4837e-04, 4.3044e-03, 7.6909e-02]]],
       grad_fn=<SoftmaxBackward>)
torch.Size([5, 1, 10])
tensor([[[0.3466, 0.3267, 0.3267, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.2256, 0.3276, 0.2237, 0.2232, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1761, 0.2254, 0.2254, 0.1823, 0.1908, 0.0000, 0.0000, 0.0000]],

        [[0.1292, 0.1411, 0.1411, 0.1100, 0.1292, 0.1100, 0.1101, 0.1292]],

        [[0.5284, 0.4716, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
       grad_fn=<UnsqueezeBackward0>)
torch.Size([5, 1, 8])
tensor([[-0.7862,  1.0180,  0.1585,  0.4961, -1.5916, -0.3553],
        [ 0.3624, -0.2036,  0.2993, -0.4440,  0.2494,  1.4896],
        [ 0.1695, -0.2286,  0.4431,  0.6027, -1.6116,  0.0566],
        [ 0.2004,  0.8219, -0.2115, -0.6428,  0.3486,  1.3802],
        [ 1.4024, -0.1860,  0.1685,  0.2352, -0.4956,  1.0010]],
       grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])
tensor([[ 0.3757,  0.1662,  0.2181,  0.0787,  0.0110, -0.5938],
        [-0.6106,  0.4000,  0.6068, -0.4054,  0.0193, -0.1147],
        [ 0.3877, -0.1800,  1.2430, -0.4881, -0.3598, -0.3592],
        [-0.3799, -0.3262,  0.0745, -0.2856,  0.0221, -0.1749],
        [ 0.1159, -0.4949, -0.5576, -0.6870, -1.2895,  0.0421]],
       grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])
```
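
For completeness, here is a minimal sketch of the alternating co-attention mechanism, the paper's second variant. This is my own illustrative code, not taken from the referenced repos: a single attention operation A(X; g), attending over features X guided by a vector g, is applied three times, alternating between question and image. For brevity it shares one set of weights across the three steps (the paper uses separate parameters per step) and omits length masking.

```python
class AlternatingCoAttention(nn.Module):
    """Illustrative sketch of alternating co-attention (reuses the imports above)."""

    def __init__(self, hidden_dim, att_dim):
        super().__init__()
        self.W_x = nn.Parameter(torch.randn(att_dim, hidden_dim))
        self.W_g = nn.Parameter(torch.randn(att_dim, hidden_dim))
        self.w_hx = nn.Parameter(torch.randn(att_dim, 1))

    def attend(self, X, g=None):
        # A(X; g) -- X: (B, hidden_dim, N); g: (B, hidden_dim) guidance or None
        H = torch.matmul(self.W_x, X)                          # (B, att_dim, N)
        if g is not None:
            H = H + torch.matmul(self.W_g, g.unsqueeze(-1))    # broadcast over N
        H = torch.tanh(H)
        a = F.softmax(torch.matmul(self.w_hx.t(), H), dim=2)   # (B, 1, N)
        return torch.matmul(a, X.permute(0, 2, 1)).squeeze(1)  # (B, hidden_dim)

    def forward(self, V, Q):
        # V: (B, hidden_dim, region_num); Q: (B, seq_len, hidden_dim)
        Qt = Q.permute(0, 2, 1)   # (B, hidden_dim, seq_len)
        s = self.attend(Qt)       # 1) summarize the question (g = 0)
        v = self.attend(V, g=s)   # 2) attend the image guided by s
        q = self.attend(Qt, g=v)  # 3) re-attend the question guided by v
        return v, q


acan = AlternatingCoAttention(6, 5)
v_feat, q_feat = acan(torch.randn(5, 6, 10), torch.randn(5, 8, 6))
print(v_feat.shape, q_feat.shape)  # torch.Size([5, 6]) torch.Size([5, 6])
```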
