Code link: https://github.com/hydxqing/ReNet-pytorch-keras-chapter3
The paper proposes a deep neural network architecture for image recognition built on recurrent neural networks. The proposed network, called ReNet, replaces the convolution+pooling layers that are ubiquitous in deep convolutional neural networks with four RNNs that sweep across the image horizontally and vertically, each in both directions.
The basic idea behind the ReNet architecture is that four RNNs scan the lower-layer features in different directions:
(1) bottom to top, (2) top to bottom, (3) left to right, (4) right to left.
The recurrent layers ensure that each feature activation in their output is an activation at a specific location with respect to the whole image.
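To make the four sweep directions concrete, here is a minimal sketch (my own illustration, not code from the repo) that enumerates the order in which patch coordinates (row, col) are visited on a 2x2 patch grid. The column-major orders below match how the script's get_vertical_rnn_inputs walks the grid.

# Visiting orders of the four ReNet sweeps on a 2x2 patch grid.
top_to_bottom = [(r, c) for c in range(2) for r in range(2)]  # down each column
bottom_to_top = list(reversed(top_to_bottom))                 # up each column
left_to_right = [(r, c) for r in range(2) for c in range(2)]  # across each row
right_to_left = list(reversed(left_to_right))

print(top_to_bottom)  # [(0, 0), (1, 0), (0, 1), (1, 1)]
print(left_to_right)  # [(0, 0), (0, 1), (1, 0), (1, 1)]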
The network processes an image in the following steps: (1) split the input image into non-overlapping patches of size receptive_filter_size x receptive_filter_size; (2) sweep the patch grid vertically with a forward/reverse pair of recurrent layers and concatenate their outputs into a feature map; (3) sweep that feature map horizontally with a second forward/reverse pair; (4) feed the final feature map to the classification head. A shape walk-through of these steps appears below.
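To make the tensor shapes concrete, here is a small sketch (my own walk-through, not code from the repo) tracing the dimensions these steps produce under the hyper-parameters used in the script below (receptive_filter_size=4, hidden_size=320, 32x32 RGB input).

receptive_filter_size = 4
hidden_size = 320
H = W = 32

patches_h = H // receptive_filter_size      # 8 patch rows
patches_w = W // receptive_filter_size      # 8 patch columns
patch_dim = receptive_filter_size ** 2 * 3  # 48: a flattened 4x4 RGB patch

vertical_dim = hidden_size * 2              # 640: forward + reverse vertical sweep
horizontal_dim = hidden_size * 2            # 640: forward + reverse horizontal sweep

print(patches_h, patches_w, patch_dim, vertical_dim)  # 8 8 48 640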
The code uses LSTMs in place of plain RNNs.
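As a quick illustration (my own hedged sketch, not repo code) of why this is a drop-in swap: nn.LSTM and nn.RNN both consume a (seq_len, batch, input_size) tensor and return the per-step outputs; the LSTM just carries an extra cell state.

import torch
import torch.nn as nn

seq = torch.randn(64, 1, 48)  # 64 patches, batch 1, 48 features per patch

rnn = nn.RNN(input_size=48, hidden_size=320)
lstm = nn.LSTM(input_size=48, hidden_size=320)

out_rnn, h_n = rnn(seq)           # out_rnn: (64, 1, 320)
out_lstm, (h_n, c_n) = lstm(seq)  # out_lstm: (64, 1, 320)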
# coding:utf-8
import math
import time
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

# repo-local helpers (see the GitHub link above)
from dataset import train, test
from transform import Relabel, ToLabel, Colorize

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                    help='input batch size for training (default: 1)')
parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                    help='input batch size for testing (default: 1)')
parser.add_argument('--epochs', type=int, default=3, metavar='N',
                    help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()

# args.cuda = not args.no_cuda and torch.cuda.is_available()
args.cuda = False
if args.cuda:
    torch.cuda.manual_seed(args.seed)

receptive_filter_size = 4
hidden_size = 320
image_size_w = 32
image_size_h = 32

input_transform = Compose([
    Resize((32, 32)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
target_transform = Compose([
    Resize((32, 32)),
    ToLabel(),
])

# trainloader = DataLoader(train(input_transform, target_transform),
#                          num_workers=1, batch_size=1, shuffle=True)
# testloader = DataLoader(test(input_transform, target_transform),
#                         num_workers=1, batch_size=1, shuffle=True)


# ReNet with one layer
class ReNet(nn.Module):

    def __init__(self, receptive_filter_size, hidden_size, batch_size,
                 image_patches_height, image_patches_width):
        super(ReNet, self).__init__()

        self.batch_size = batch_size
        self.receptive_filter_size = receptive_filter_size
        self.input_size1 = receptive_filter_size * receptive_filter_size * 3
        self.input_size2 = hidden_size * 2
        self.hidden_size = hidden_size

        # vertical rnns (dropout has no effect with a single layer, but is kept from the original)
        self.rnn1 = nn.LSTM(self.input_size1, self.hidden_size, dropout=0.2)
        self.rnn2 = nn.LSTM(self.input_size1, self.hidden_size, dropout=0.2)

        # horizontal rnns
        self.rnn3 = nn.LSTM(self.input_size2, self.hidden_size, dropout=0.2)
        self.rnn4 = nn.LSTM(self.input_size2, self.hidden_size, dropout=0.2)

        self.initHidden()

        # feature_map_dim = int(image_patches_height * image_patches_width * hidden_size * 2)
        # classification head: [1, 640, 8, 8] -> [1, 2, 8, 8], then upsample back to 32x32
        self.conv1 = nn.Conv2d(hidden_size * 2, 2, 3, padding=1)
        self.UpsamplingBilinear2d = nn.UpsamplingBilinear2d(size=(32, 32))
        # self.dense = nn.Linear(feature_map_dim, 4096)
        # self.fc = nn.Linear(4096, 10)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def initHidden(self):
        self.hidden = (Variable(torch.zeros(1, self.batch_size, self.hidden_size)),
                       Variable(torch.zeros(1, self.batch_size, self.hidden_size)))

    def get_image_patches(self, X, receptive_filter_size):
        """Creates image patches based on the dimension of a receptive filter."""
        image_patches = []
        _, X_channel, X_height, X_width = X.size()

        for i in range(0, X_height, receptive_filter_size):
            for j in range(0, X_width, receptive_filter_size):
                X_patch = X[:, :, i:i + receptive_filter_size, j:j + receptive_filter_size]
                image_patches.append(X_patch)

        image_patches_height = X_height // receptive_filter_size
        image_patches_width = X_width // receptive_filter_size

        image_patches = torch.stack(image_patches)
        image_patches = image_patches.permute(1, 0, 2, 3, 4)
        # flatten each patch: (batch, patches_h, patches_w, filter * filter * channels)
        image_patches = image_patches.contiguous().view(
            -1, image_patches_height, image_patches_width,
            receptive_filter_size * receptive_filter_size * X_channel)

        return image_patches

    def get_vertical_rnn_inputs(self, image_patches, forward):
        """Creates vertical rnn inputs in dimensions
        (num_patches, batch_size, rnn_input_feature_dim),
        where num_patches = image_patches_height * image_patches_width.
        """
        vertical_rnn_inputs = []
        _, image_patches_height, image_patches_width, feature_dim = image_patches.size()

        # note: indexes (j, i), i.e. column-major order; this assumes a square patch grid
        if forward:
            for i in range(image_patches_height):
                for j in range(image_patches_width):
                    vertical_rnn_inputs.append(image_patches[:, j, i, :])
        else:  # reverse scan order
            for i in range(image_patches_height - 1, -1, -1):
                for j in range(image_patches_width - 1, -1, -1):
                    vertical_rnn_inputs.append(image_patches[:, j, i, :])

        vertical_rnn_inputs = torch.stack(vertical_rnn_inputs)
        return vertical_rnn_inputs

    def get_horizontal_rnn_inputs(self, vertical_feature_map,
                                  image_patches_height, image_patches_width, forward):
        """Creates horizontal rnn inputs in dimensions
        (num_patches, batch_size, rnn_input_feature_dim).
        """
        horizontal_rnn_inputs = []

        if forward:
            for i in range(image_patches_height):
                for j in range(image_patches_width):
                    horizontal_rnn_inputs.append(vertical_feature_map[:, i, j, :])
        else:
            for i in range(image_patches_height - 1, -1, -1):
                for j in range(image_patches_width - 1, -1, -1):
                    horizontal_rnn_inputs.append(vertical_feature_map[:, i, j, :])

        horizontal_rnn_inputs = torch.stack(horizontal_rnn_inputs)
        return horizontal_rnn_inputs

    def forward(self, X):
        """ReNet forward pass."""
        # divide the input image into image patches
        image_patches = self.get_image_patches(X, self.receptive_filter_size)
        _, image_patches_height, image_patches_width, feature_dim = image_patches.size()

        # process vertical rnn inputs
        vertical_rnn_inputs_fw = self.get_vertical_rnn_inputs(image_patches, forward=True)
        vertical_rnn_inputs_rev = self.get_vertical_rnn_inputs(image_patches, forward=False)

        # extract vertical hidden states
        vertical_forward_hidden, vertical_forward_cell = self.rnn1(vertical_rnn_inputs_fw, self.hidden)
        vertical_reverse_hidden, vertical_reverse_cell = self.rnn2(vertical_rnn_inputs_rev, self.hidden)

        # create vertical feature map
        vertical_feature_map = torch.cat((vertical_forward_hidden, vertical_reverse_hidden), 2)
        vertical_feature_map = vertical_feature_map.permute(1, 0, 2)

        # reshape to (batch_size, image_patches_height, image_patches_width, hidden_size * 2);
        # permute is not in-place, so the result must be assigned (unassigned in the original)
        vertical_feature_map = vertical_feature_map.contiguous().view(
            -1, image_patches_width, image_patches_height, self.hidden_size * 2)
        vertical_feature_map = vertical_feature_map.permute(0, 2, 1, 3)

        # process horizontal rnn inputs
        horizontal_rnn_inputs_fw = self.get_horizontal_rnn_inputs(
            vertical_feature_map, image_patches_height, image_patches_width, forward=True)
        horizontal_rnn_inputs_rev = self.get_horizontal_rnn_inputs(
            vertical_feature_map, image_patches_height, image_patches_width, forward=False)

        # extract horizontal hidden states
        horizontal_forward_hidden, horizontal_forward_cell = self.rnn3(horizontal_rnn_inputs_fw, self.hidden)
        horizontal_reverse_hidden, horizontal_reverse_cell = self.rnn4(horizontal_rnn_inputs_rev, self.hidden)

        # create horizontal feature map: [64, 1, 640]
        horizontal_feature_map = torch.cat((horizontal_forward_hidden, horizontal_reverse_hidden), 2)
        horizontal_feature_map = horizontal_feature_map.permute(1, 0, 2)

        # reshape to a spatial map: [1, 640, 8, 8]
        output = horizontal_feature_map.contiguous().view(
            -1, image_patches_height, image_patches_width, self.hidden_size * 2)
        output = output.permute(0, 3, 1, 2)

        conv1 = self.conv1(output)
        Upsampling = self.UpsamplingBilinear2d(conv1)

        # dense layer
        # output = F.relu(self.dense(output))
        # fully connected layer
        # logits = self.fc(output)

        # log softmax over the class channel
        logits = self.log_softmax(Upsampling)

        return logits


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since):
    now = time.time()
    return asMinutes(now - since)


if __name__ == "__main__":
    renet = ReNet(receptive_filter_size, hidden_size, args.batch_size,
                  image_size_h // receptive_filter_size,
                  image_size_w // receptive_filter_size)
    x = torch.ones((1, 3, 32, 32))
    out = renet(x)
    print(out)
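The script parses training hyper-parameters (--lr, --momentum, --epochs) but never uses them. A minimal training loop might look like the sketch below. This is my own assumption, not repo code: trainloader stands for the commented-out repo DataLoader and is assumed to yield (images, labels) batches where labels is an (N, 32, 32) tensor of class indices, so that nn.NLLLoss pairs with the model's log_softmax output.

# Hedged training-loop sketch; reuses renet, args, nn, and optim from the script above.
criterion = nn.NLLLoss()  # pairs with the model's log-probability output
optimizer = optim.SGD(renet.parameters(), lr=args.lr, momentum=args.momentum)

for epoch in range(args.epochs):
    for images, labels in trainloader:          # assumed DataLoader, see lead-in
        optimizer.zero_grad()
        logits = renet(images)                  # (N, 2, 32, 32) log-probabilities
        loss = criterion(logits, labels)        # labels: (N, 32, 32) class indices
        loss.backward()
        optimizer.step()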