赞
踩
本文共介绍四种特征融合方法,分别是 SumFusion、ConcatFusion、FiLM 以及 GatedFusion。
FiLM 参考论文《FiLM: Visual Reasoning with a General Conditioning Layer》。
GatedFusion 参考论文《Efficient Large-Scale Multi-Modal Classification》。
- import torch
- import torch.nn as nn
-
- #------------------------------------------#
- # SumFusion的定义,为两者过全连接层后进行直接相加
- #------------------------------------------#
class SumFusion(nn.Module):
    """Fuse two feature tensors by summing their linear projections.

    Each input is passed through its own fully connected layer and the two
    projections are added element-wise.

    Args:
        input_dim: Number of input features expected by both projections.
        output_dim: Number of features in the fused output.
    """

    def __init__(self, input_dim=512, output_dim=100):
        super().__init__()
        # One independent fully connected layer per input tensor.
        self.fc_x = nn.Linear(input_dim, output_dim)
        self.fc_y = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        """Return ``(x, y, fused)``.

        ``x`` and ``y`` are handed back untouched; ``fused`` is
        ``fc_x(x) + fc_y(y)``.
        """
        output = self.fc_x(x) + self.fc_y(y)
        return x, y, output
-
- #------------------------------------------#
- # ConcatFusion的定义,只定义一个全连接层
- # 首先将两者沿dim=1拼接,之后再将拼接后的向量送入全连接层
- #------------------------------------------#
class ConcatFusion(nn.Module):
    """Fuse two feature tensors by concatenation followed by one linear layer.

    Args:
        input_dim: Number of input features of the single fully connected
            layer, i.e. the width of the concatenated tensor. For 2-D inputs
            this must equal ``x.shape[1] + y.shape[1]``.
        output_dim: Number of features in the fused output.
    """

    def __init__(self, input_dim=1024, output_dim=100):
        super().__init__()
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        """Return ``(x, y, fused)``.

        ``x`` and ``y`` are handed back untouched; ``fused`` is the linear
        projection of the two inputs concatenated along dimension 1.
        """
        joined = torch.cat((x, y), dim=1)
        output = self.fc_out(joined)
        return x, y, output
-
- #------------------------------------------#
- # FiLM融合方法的定义,共定义两个全连接层:
- # 一个由条件特征生成gamma和beta,另一个将调制后的特征映射到输出
- #------------------------------------------#
class FiLM(nn.Module):
    """Feature-wise linear modulation (FiLM) fusion of two feature tensors.

    One input (the conditioning input) is projected to a scale ``gamma`` and
    a shift ``beta``; the other input is modulated as
    ``gamma * other + beta`` and then projected to the output size.

    Reference: FiLM: Visual Reasoning with a General Conditioning Layer,
    https://arxiv.org/pdf/1709.07871.pdf.

    Args:
        input_dim: Number of features of the conditioning input.
        dim: Number of features of the modulated input (and of ``gamma`` and
            ``beta``).
        output_dim: Number of features in the fused output.
        x_film: If True, ``x`` conditions ``y``; otherwise ``y`` conditions
            ``x``.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
        super().__init__()
        # Split size for gamma/beta. ``fc`` emits ``2 * dim`` features, so the
        # halves are ``dim`` wide. Using ``input_dim`` here (as before) broke
        # the split whenever ``input_dim != dim``.
        self.dim = dim
        self.fc = nn.Linear(input_dim, 2 * dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_film = x_film

    def forward(self, x, y):
        """Return ``(x, y, fused)``; ``x`` and ``y`` are handed back untouched."""
        if self.x_film:
            film, to_be_film = x, y
        else:
            film, to_be_film = y, x

        # First half of the projection is the scale, second half the shift.
        gamma, beta = torch.split(self.fc(film), self.dim, 1)

        output = gamma * to_be_film + beta
        output = self.fc_out(output)

        return x, y, output
-
- #------------------------------------------#
- # GatedFusion方法的定义
- #------------------------------------------#
class GatedFusion(nn.Module):
    """Gated fusion of two feature tensors.

    Both inputs are linearly projected to ``dim`` features. One projection is
    squashed with a sigmoid to form a gate, which multiplies the other
    projection element-wise before the final output layer.

    Reference: Efficient Large-Scale Multi-Modal Classification,
    https://arxiv.org/pdf/1802.02892.pdf.

    Args:
        input_dim: Number of input features expected by both projections.
        dim: Number of features of the projections and of the gate.
        output_dim: Number of features in the fused output.
        x_gate: If True, the gate is computed from ``x``; otherwise from ``y``.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
        super().__init__()
        self.fc_x = nn.Linear(input_dim, dim)
        self.fc_y = nn.Linear(input_dim, dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_gate = x_gate  # whether to choose the x to obtain the gate
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """Return ``(out_x, out_y, fused)``.

        Note that the first two items are the *projected* inputs
        (``fc_x(x)`` and ``fc_y(y)``), not the raw ``x`` and ``y``.
        """
        out_x = self.fc_x(x)
        out_y = self.fc_y(y)

        # Pick which projection drives the gate and which one is gated.
        if self.x_gate:
            gate_source, gated = out_x, out_y
        else:
            gate_source, gated = out_y, out_x

        gate = self.sigmoid(gate_source)
        output = self.fc_out(torch.mul(gate, gated))

        return out_x, out_y, output
-
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。