当前位置:   article > 正文

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()

model.named_parameters()

神经网络的模型参数

model.parameters(), model.named_parameters(), model.state_dict() 这三个方法都可以查看神经网络的参数信息,用于更新参数,或者用于模型的保存。作用都类似,写法略有出入

就以Pytorch之经典神经网络(一) —— 全连接网络(MNIST) 来举例 Pytorch之经典神经网络CNN(一) —— 全连接网络 / MLP (MNIST) (trainset和Dataloader & batch training & learning_rate)_hxxjxw的博客-CSDN博客   

print(*[name for name, _ in self.model.named_parameters()], sep='\n')
print(*set([name.split('.')[0] for name, _ in self.named_parameters()]), sep='\n')
查看网络模型参数是否可训练
print(*[_.requires_grad for name, _ in model.named_parameters()], sep='\n')

model.named_parameters()

net.named_parameters()中param是len为2的tuple
param[0]是name,fc1.weight、fc1.bias等
param[1]是fc1.weight、fc1.bias等对应的值

一直是0,1,2,......, 这种序号

  1. for _,param in enumerate(net.named_parameters()):
  2. print(param[0])
  3. print(param[1])
  4. print('----------------')

model.parameters()

net.parameters()中param就是fc1.weight、fc1.bias等对应的值,没带名字

  1. for _,param in enumerate(net.parameters()):
  2. print(param)
  3. print('----------------')

model.state_dict()

net.state_dict() 中的param就只是str字符串 fc1.weight, fc1.bias等等

但它们可以作为参数来输出对应的值

  1. for _,param in enumerate(net.state_dict()):
  2. print(param)
  3. print(net.state_dict()[param])
  4. print('----------------')

神经网络的各个层

当神经网络是这么定义的时候,即没有用nn.Sequential()

此时 print(net)

  1. net = Net()
  2. print(net)

输出单个的网络层

  1. net = Net()
  2. print(net.fc1)
  3. print(net.fc2)
  4. print(net.fc3)

输出各个网络层的weight,bias参数

  1. net = Net()
  2. print(net.fc1.weight)
  3. print(net.fc1.bias)
  4. print(net.fc2.weight)
  5. print(net.fc2.bias)
  6. print(net.fc3.weight)
  7. print(net.fc3.bias)

当使用nn.Sequential定义的时候

  1. import torch
  2. import torchvision
  3. from torchvision import transforms
  4. from matplotlib import pyplot as plt
  5. from torch import nn
  6. from torch.nn import functional as F
  7. from utils import plot_image,plot_curve,one_hot
  8. # class Net(nn.Module):
  9. # def __init__(self):
  10. # super(Net, self).__init__()
  11. #
  12. # #三层全连接层
  13. # #wx+b
  14. # self.fc1 = nn.Linear(28*28, 256)
  15. # self.fc2 = nn.Linear(256,64)
  16. # self.fc3 = nn.Linear(64,10)
  17. #
  18. # def forward(self, x):
  19. # x = F.rule(self.fc1(x)) #F.relu和torch.relu,用哪个都行
  20. # x = F.relu(self.fc2(x))
  21. # x = F.relu(self.fc(3))
  22. #
  23. # return x
  24. class Net(nn.Module):
  25. def __init__(self):
  26. super(Net, self).__init__()
  27. self.fc = nn.Sequential(
  28. nn.Linear(28 * 28, 256),
  29. nn.ReLU(),
  30. nn.Linear(256, 64),
  31. nn.ReLU(),
  32. nn.Linear(64, 10)
  33. )
  34. def forward(self, x):
  35. # x: [b, 1, 28, 28]
  36. # h1 = relu(xw1+b1)
  37. x = self.fc(x)
  38. return x
  39. batch_size = 512
  40. #一次处理的图片的数量
  41. #gpu一次可以处理并行多张图片
  42. transform = transforms.Compose([
  43. torchvision.transforms.ToTensor(),
  44. torchvision.transforms.Normalize((0.1307,), (0.3081,))
  45. ])
  46. trainset = torchvision.datasets.MNIST(
  47. root='dataset/',
  48. train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
  49. download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
  50. transform=transform
  51. )
  52. #train=True表示是训练数据,train=False是测试数据
  53. train_loader = torch.utils.data.DataLoader(
  54. dataset=trainset,
  55. batch_size=batch_size,
  56. shuffle=True #在加载的时候将图片随机打散
  57. )
  58. testset = torchvision.datasets.MNIST(
  59. root='dataset/',
  60. train=False,
  61. download=True,
  62. transform=transform
  63. )
  64. train_loader = torch.utils.data.DataLoader(
  65. dataset=testset,
  66. batch_size=batch_size,
  67. shuffle=True
  68. )
  69. net = Net()
  70. print(net.fc)
  71. print(net.fc[0])
  72. print(net.fc[1])
  73. print(net.fc[2])
  74. print(net.fc[3])
  75. print(net.fc[4])
  76. print(net.fc[0].weight)
  77. print(net.fc[0].bias)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/103461
推荐阅读
相关标签
  

闽ICP备14008679号