# 赞 (Like)
# 踩 (Dislike)
# imports
import snntorch as snn
from snntorch import surrogate
# from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
# from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# import torch.nn.functional as F

import matplotlib.pyplot as plt
# import numpy as np
# import itertools

# dataloader arguments
batch_size = 64  # 128 did not fit in memory on this machine
data_path = 'data'

dtype = torch.float
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")  # this machine's CUDA version is too old

# Define a transform
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),  # mean 0 / std 1: effectively a no-op
])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders.
# The test loader is not shuffled. With drop_last=True and shuffling, every
# evaluation would silently drop a *different* random remainder of the test set,
# so successive test-accuracy points would not be measured on the same samples.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=True)

# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)  # surrogate gradient used for backpropagation
beta = 0.5  # decay rate of the neuron membrane potential
num_steps = 50  # number of simulation time steps (SNN-specific)

# Equivalent explicit formulation of the nn.Sequential network below, kept for
# reference. It needs `import torch.nn.functional as F`. Its
# `spk2.view(batch_size, -1)` assumes every batch has exactly `batch_size`
# samples (the loaders above use drop_last=True).
#
# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         # Initialize layers
#         self.conv1 = nn.Conv2d(1, 12, 5)
#         self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#         self.conv2 = nn.Conv2d(12, 64, 5)
#         self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#         self.fc1 = nn.Linear(64*4*4, 10)
#         self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#
#     def forward(self, x):
#
#         # Initialize hidden states and outputs at t=0
#         mem1 = self.lif1.init_leaky()
#         mem2 = self.lif2.init_leaky()
#         mem3 = self.lif3.init_leaky()
#
#         cur1 = F.max_pool2d(self.conv1(x), 2)
#         spk1, mem1 = self.lif1(cur1, mem1)
#
#         cur2 = F.max_pool2d(self.conv2(spk1), 2)
#         spk2, mem2 = self.lif2(cur2, mem2)
#
#         cur3 = self.fc1(spk2.view(batch_size, -1))
#         spk3, mem3 = self.lif3(cur3, mem3)
#
#         return spk3, mem3

# Initialize Network.
# Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, hence 64 * 4 * 4.
net = nn.Sequential(
    nn.Conv2d(1, 12, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Conv2d(12, 64, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 10),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)


def forward_pass(net, num_steps, data):
    """Feed the same ``data`` into ``net`` for ``num_steps`` consecutive time steps.

    The hidden states of all LIF neurons are reset once, before the first step.
    They are NOT re-initialised between steps: the Leaky layers were built with
    ``init_hidden=True`` and keep their membrane potential from one call to the
    next, so each step continues from the previous one.

    Returns:
        ``(spk_rec, mem_rec)``: the output spikes and output membrane potentials
        of every step, stacked along a new leading (time) dimension of length
        ``num_steps``.
    """
    mem_rec = []
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    for _ in range(num_steps):
        # The last Leaky layer has output=True, so the net returns (spikes, membrane).
        spk_out, mem_out = net(data)
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)

    return torch.stack(spk_rec), torch.stack(mem_rec)


loss_fn = SF.ce_rate_loss()


def batch_accuracy(train_loader, net, num_steps):
    """Return the spike-rate classification accuracy of ``net`` over a whole loader.

    Despite the function and parameter names (kept for backward compatibility),
    this evaluates EVERY batch the loader yields, not a single batch. The
    training loop below passes ``test_loader``. Each batch's accuracy is
    weighted by its number of samples.

    Gradients are not tracked. The network runs in eval mode and its previous
    train/eval mode is restored afterwards, even if an exception occurs.

    Raises:
        ValueError: if the loader yields no samples (avoids a ZeroDivisionError).
    """
    was_training = net.training
    net.eval()
    total = 0
    acc = 0
    try:
        with torch.no_grad():
            for data, targets in train_loader:
                data = data.to(device)
                targets = targets.to(device)
                spk_rec, _ = forward_pass(net, num_steps, data)

                # dim 0 of the stacked record is time, dim 1 is the batch
                n_samples = spk_rec.size(1)
                acc += SF.accuracy_rate(spk_rec, targets) * n_samples
                total += n_samples
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("batch_accuracy: the loader yielded no samples")
    return acc / total


optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1  # only a single pass over the training set
eval_every = 50  # evaluate on the test set every this many iterations
loss_hist = []
test_acc_hist = []
test_iter_hist = []  # iteration index at which each test accuracy was measured
counter = 0

# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass: simulate num_steps time steps for this batch
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        # cross-entropy spike-rate loss against the ground-truth labels
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()  # compute gradients from the loss
        optimizer.step()  # update the weights

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set (batch_accuracy handles no_grad and eval mode itself)
        if counter % eval_every == 0:
            test_acc = batch_accuracy(test_loader, net, num_steps)
            print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
            test_acc_hist.append(float(test_acc))
            test_iter_hist.append(counter)

        counter += 1

# Plot training loss and test accuracy against the training iteration.
# The test-accuracy points are sampled every `eval_every` iterations, not per epoch.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, facecolor="w", figsize=(10, 4))

ax_loss.plot(loss_hist)
ax_loss.set_title("Training Loss")
ax_loss.set_xlabel("Iteration")
ax_loss.set_ylabel("Loss")

ax_acc.plot(test_iter_hist, test_acc_hist)
ax_acc.set_title("Test Set Accuracy")
ax_acc.set_xlabel("Iteration")
ax_acc.set_ylabel("Accuracy")

plt.tight_layout()
plt.show()
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。 (All rights reserved.)