赞
踩
torch.save()有两种方法
1)仅保存网络参数
torch.save(net.state_dict(), 'net_params.pkl')
2)保存整个网络结构
torch.save(net, 'net.pkl')
1)仅加载参数
model_object.load_state_dict(torch.load('net_params.pkl'))
2)加载整个模型
model = torch.load('net.pkl')
两种方法在载入模型时都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module
- #加载整个网络
-
- class AlexNet(nn.Module):
- def __init__(self):
- super(AlexNet,self).__init__()
- self.conv1 = nn.Conv2d(3, 64, 5)
- self.pool1 = nn.MaxPool2d(3, 2)
- self.conv2 = nn.Conv2d(64, 64, 5)
- self.pool2 = nn.MaxPool2d(3, 2)
- self.fc1 = nn.Linear(1024, 384)
- self.fc2 = nn.Linear(384, 192)
- self.fc3 = nn.Linear(192, 10)
-
- def forward(self, x):
- x = self.pool1(F.relu(self.conv1(x)))
- x = self.pool2(F.relu(self.conv2(x)))
- x = x.view(x.shape[0], -1)
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = F.softmax(self.fc3(x))
- return x
-
- net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
- #只加载网络参数
-
- class AlexNet(nn.Module):
- def __init__(self):
- super(AlexNet,self).__init__()
- self.conv1 = nn.Conv2d(3, 64, 5)
- self.pool1 = nn.MaxPool2d(3, 2)
- self.conv2 = nn.Conv2d(64, 64, 5)
- self.pool2 = nn.MaxPool2d(3, 2)
- self.fc1 = nn.Linear(1024, 384)
- self.fc2 = nn.Linear(384, 192)
- self.fc3 = nn.Linear(192, 10)
-
- def forward(self, x):
- x = self.pool1(F.relu(self.conv1(x)))
- x = self.pool2(F.relu(self.conv2(x)))
- x = x.view(x.shape[0], -1)
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = F.softmax(self.fc3(x))
- return x
-
- net = AlexNet()#只加载网络参数的时候需要自行实例化网络
- net.cuda()#并设置网络运行在cpu还是gpu上
-
- net.load_state_dict(torch.load('net_params.pkl'))#再加载网络的参数
注意:
1.只加载网络参数的速度比加载整个网络快得多
2.pth、pkl格式效果相同,ckpt是tensorflow的格式
参考链接:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。