# 赞 / 踩 ("like" / "dislike" buttons captured from the source web page)
# Topic: torch.nn.RNN() (stray line, apparently the scraped page's title).
# The original bare call `torch.nn.RNN()` was removed: it ran before
# `import torch` (NameError), and torch.nn.RNN also requires at least
# input_size and hidden_size, so calling it with no arguments raises TypeError.
import os

import torch
import torchvision
from torch.autograd import Variable  # Variable wrapper (deprecated since PyTorch 0.4; plain tensors suffice)
from torch.utils import data  # data-loading utilities
from torchvision.datasets import mnist  # MNIST dataset
# Preprocessing: ToTensor converts each image to a float tensor in [0, 1];
# Normalize then applies (x - 0.5) / 0.5 to the single channel, so pixel
# values end up in [-1, 1].
data_tf = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5]),
    ]
)

# Load the data.
# The dataset root can be overridden with the MNIST_DATA_PATH environment
# variable. When it is unset, the author's original local directory is used,
# so existing behavior is unchanged.
data_path = os.environ.get(
    "MNIST_DATA_PATH",
    r'C:\Users\liev\Desktop\myproject\yin_test\MNIST_DATA_PyTorch',
)
# download=False: the MNIST files must already exist under data_path;
# loading fails if they are missing.
train_data = mnist.MNIST(data_path, train=True, transform=data_tf, download=False)
test_data = mnist.MNIST(data_path, train=False, transform=data_tf, download=False)
# 定义网络结构
class RNNnet(torch.nn.Module):
def __init__(self):
super(RNNnet, self).__init__()
self.rnn1 = torch.nn.RNN(784,100,3,nonlinearity='relu')
self.rnn2 = torch.nn.RNN(100,10,1,nonlinearity='relu')
def forward(self, x):
x = self.rnn1(x)
x = torch.Tensor(x[1])
x = self.rnn2(
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。 (All rights reserved; scraped page footer.)