当前位置:   article > 正文

多模态学习中四种常用的跨模态特征融合方法定义与PyTorch实现

跨模态特征融合

本文共介绍四种跨模态特征融合方法,分别是 SumFusion、ConcatFusion、FiLM 以及 GatedFusion。

FiLM 参考论文:FiLM: Visual Reasoning with a General Conditioning Layer(Perez 等,AAAI 2018,arXiv:1709.07871)。

GatedFusion 参考论文:Efficient Large-Scale Multi-Modal Classification(Kiela 等,AAAI 2018,arXiv:1802.02892)。

  1. import torch
  2. import torch.nn as nn
  # ----------------------------------------------------------------------
# SumFusion: project x and y with separate linear layers, then add the
# two projections element-wise.
# ----------------------------------------------------------------------
class SumFusion(nn.Module):
    """Additive fusion of two feature tensors.

    Computes ``output = fc_x(x) + fc_y(y)``, where ``fc_x`` and ``fc_y`` are
    independent ``nn.Linear(input_dim, output_dim)`` layers.

    Args:
        input_dim: Size of the last dimension of both ``x`` and ``y``.
        output_dim: Size of the last dimension of the fused output.
    """

    def __init__(self, input_dim=512, output_dim=100):
        super().__init__()
        # One projection per input tensor (x and y).
        self.fc_x = nn.Linear(input_dim, output_dim)
        self.fc_y = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        """Fuse ``x`` and ``y`` by summing their projections.

        Returns:
            Tuple ``(x, y, output)``: the two inputs unchanged, followed by
            ``fc_x(x) + fc_y(y)``.
        """
        projected_x = self.fc_x(x)
        projected_y = self.fc_y(y)
        fused = projected_x + projected_y
        return x, y, fused
  # ----------------------------------------------------------------------
# ConcatFusion: concatenate x and y along dim 1 (not stack them), then
# apply a single linear layer to the concatenated tensor.
# ----------------------------------------------------------------------
class ConcatFusion(nn.Module):
    """Concatenation fusion of two feature tensors.

    Computes ``output = fc_out(torch.cat((x, y), dim=1))``.

    Args:
        input_dim: Input size of ``fc_out``. For 2-D ``(batch, features)``
            inputs this must equal ``x.size(1) + y.size(1)``; the default
            1024 fits two 512-wide inputs.
        output_dim: Size of the last dimension of the fused output.
    """

    def __init__(self, input_dim=1024, output_dim=100):
        super().__init__()
        # The only learnable layer: maps the concatenated features to the output.
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        """Fuse ``x`` and ``y`` by concatenation followed by ``fc_out``.

        NOTE(review): concatenation is along dim 1, which is the feature axis
        only for 2-D inputs -- presumably the intended use; confirm before
        passing higher-rank tensors.

        Returns:
            Tuple ``(x, y, output)``: the two inputs unchanged, followed by
            the projected concatenation.
        """
        joint = torch.cat([x, y], dim=1)
        return x, y, self.fc_out(joint)
  # ----------------------------------------------------------------------
# FiLM fusion: one modality predicts a per-feature scale (gamma) and shift
# (beta) that modulate the other modality; a second linear layer then maps
# the modulated features to the output.
# ----------------------------------------------------------------------
class FiLM(nn.Module):
    """Feature-wise Linear Modulation (FiLM) fusion of two feature tensors.

    FiLM: Visual Reasoning with a General Conditioning Layer,
    https://arxiv.org/pdf/1709.07871.pdf.

    The conditioning tensor is mapped by ``fc`` to ``2 * dim`` features and
    split into ``gamma`` and ``beta`` (``dim`` features each). The other
    tensor is modulated as ``gamma * to_be_film + beta``, and ``fc_out``
    produces the fused output.

    Args:
        input_dim: Size of the last dimension of the conditioning tensor
            (the input of ``fc``).
        dim: Size of the last dimension of the modulated tensor. ``gamma``
            and ``beta`` have this size and ``fc_out`` expects it.
        output_dim: Size of the last dimension of the fused output.
        x_film: If True, ``x`` conditions ``y``; otherwise ``y`` conditions ``x``.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
        super().__init__()
        # Width of each of gamma/beta, i.e. the chunk size used to split fc's
        # 2 * dim output. It must be ``dim``: using ``input_dim`` here only
        # yields exactly two chunks when input_dim == dim.
        self.dim = dim
        self.fc = nn.Linear(input_dim, 2 * dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_film = x_film

    def forward(self, x, y):
        """Modulate one input with parameters predicted from the other.

        Returns:
            Tuple ``(x, y, output)``: the two inputs unchanged, followed by
            ``fc_out(gamma * to_be_film + beta)``.
        """
        if self.x_film:
            film, to_be_film = x, y
        else:
            film, to_be_film = y, x
        # Split along the last axis, which is the axis nn.Linear writes. For
        # 2-D inputs this is axis 1; inputs with extra leading dimensions
        # also work.
        gamma, beta = torch.split(self.fc(film), self.dim, dim=-1)
        output = gamma * to_be_film + beta
        output = self.fc_out(output)
        return x, y, output
  # ----------------------------------------------------------------------
# GatedFusion: a sigmoid gate computed from one modality's projection
# scales the other modality's projection element-wise.
# ----------------------------------------------------------------------
class GatedFusion(nn.Module):
    """Gated fusion of two feature tensors.

    Efficient Large-Scale Multi-Modal Classification,
    https://arxiv.org/pdf/1802.02892.pdf.

    Both inputs are projected to ``dim`` features by ``fc_x`` and ``fc_y``.
    The fused output depends on ``x_gate``:

    - ``x_gate=True``: ``fc_out(sigmoid(fc_x(x)) * fc_y(y))``.
    - ``x_gate=False``: ``fc_out(fc_x(x) * sigmoid(fc_y(y)))``.

    Args:
        input_dim: Size of the last dimension of both ``x`` and ``y``.
        dim: Size of the projected (gated) features.
        output_dim: Size of the last dimension of the fused output.
        x_gate: If True, the gate is computed from ``x``; otherwise from ``y``.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
        super().__init__()
        self.fc_x = nn.Linear(input_dim, dim)
        self.fc_y = nn.Linear(input_dim, dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_gate = x_gate  # True: x produces the gate; False: y does
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """Fuse ``x`` and ``y`` with a sigmoid gate.

        Returns:
            Tuple ``(out_x, out_y, output)``. The first two items are the
            projections ``fc_x(x)`` and ``fc_y(y)``, not the raw inputs.
        """
        proj_x = self.fc_x(x)
        proj_y = self.fc_y(y)
        if self.x_gate:
            gated = self.sigmoid(proj_x) * proj_y
        else:
            gated = proj_x * self.sigmoid(proj_y)
        return proj_x, proj_y, self.fc_out(gated)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/924082
推荐阅读
相关标签
  

闽ICP备14008679号