赞
踩
目录
- import torch
- from torch import optim, nn
- import torch.utils.data as Data
# Toy data: ten evenly spaced points, with y running opposite to x.
x = torch.linspace(1, 10, 10)   # x data (torch tensor), dtype float32
y = torch.linspace(10, 1, 10)   # y data (torch tensor), dtype float32

# NOTE: x is a torch.FloatTensor (32-bit float). For classification targets,
# y would normally be a LongTensor (64-bit int); for regression/prediction,
# keep it as .float() or .double() instead.

# Wrap the tensors in a Dataset that torch's loading machinery understands.
torch_dataset = Data.TensorDataset(x, y)

# Hand the dataset to a DataLoader for shuffled mini-batching.
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=3,           # mini-batch size
    shuffle=True,           # shuffle the data between epochs (usually a good idea)
    num_workers=0,          # number of worker subprocesses for data loading
)
# Network definition: a minimal two-layer fully connected net.
class Net(torch.nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        n_feature: number of input features.
        n_hidden: width of the hidden layer.
        n_output: number of output units (class scores here).
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(n_feature, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Forward pass: fc1 -> ReLU -> fc2, returning raw scores."""
        # BUG FIX: the original called F.relu, but torch.nn.functional was
        # never imported as F anywhere in this file (NameError at runtime).
        # torch.relu is equivalent and already in scope.
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Instantiate the network as `model`.
model = Net(n_feature=1, n_hidden=10, n_output=10)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# CrossEntropyLoss expects raw scores plus integer class-index targets.
loss_func = nn.CrossEntropyLoss()

# Train the model.
model.train()
for epoch in range(5):
    for step, (b_x, b_y) in enumerate(loader):
        # BUG FIX: b_x comes out of the 1-D TensorDataset with shape
        # (batch,), which Linear(1, 10) rejects; add a feature dimension
        # to make it (batch, 1).
        output = model(b_x.unsqueeze(1))
        # BUG FIX: CrossEntropyLoss needs integer class indices in
        # [0, n_output). y holds floats 10..1, so shift to 9..0 and cast.
        loss = loss_func(output, (b_y - 1).long())
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # backpropagate
        optimizer.step()        # apply the parameter update
-
# Evaluate the model.
model.eval()
with torch.no_grad():  # BUG FIX: no gradient tracking needed during evaluation
    for step, (b_x, b_y) in enumerate(loader):
        # Same input/target fixes as training: (batch,) -> (batch, 1),
        # and float labels 10..1 mapped to class indices 9..0.
        output = model(b_x.unsqueeze(1))
        target = (b_y - 1).long()
        loss = loss_func(output, target)

        # Predicted class = argmax over the score dimension.
        _, pred_y = torch.max(output.data, 1)
        correct = (pred_y == target).sum()
        total = b_y.size(0)
        # BUG FIX: the loop variable is the batch index, not the epoch, so
        # label it 'Step'; loss.item() replaces the dated loss.data.numpy().
        print('Step: ', step, '| test loss: %.4f' % loss.item(), '| test accuracy: %.2f' % (float(correct) / total))
import os

# Make sure the checkpoint directory exists before saving anything.
os.makedirs('./path', exist_ok=True)

# Save only the model parameters (the recommended approach).
# BUG FIX: the original saved the state_dict and the whole model to the
# SAME file ('./path/model.pkl'), so the second save clobbered the first;
# use distinct filenames.
torch.save(model.state_dict(), './path/model_params.pkl')
# Load the parameters back into an existing model instance.
model.load_state_dict(torch.load('./path/model_params.pkl'))

# Save the entire model object.
torch.save(model, './path/model_full.pkl')
# Load the entire model.
# NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True,
# which rejects full pickled models — pass weights_only=False there.
model = torch.load('./path/model_full.pkl')

# Save several objects in a single checkpoint.
# BUG FIX: the original dict literal contained a bare `...` (a SyntaxError
# inside a dict) and PATH was never defined.
PATH = './path/checkpoint.tar'
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, PATH)

# Load the checkpoint back.
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
在已有的数据集中增加数据:
- import torch
- import torch.utils.data as Data
-
# A one-sample dataset: x=5, y=0.
a = Data.TensorDataset(torch.tensor([5]), torch.tensor([0]))
# Two more samples to append.
b = Data.TensorDataset(torch.tensor([[85], [54]]), torch.tensor([[6], [4]]))

# Dataset.__add__ concatenates datasets, so adding b onto a yields a
# ConcatDataset exposing all three samples through one index space.
a = a + b

for sample in a:
    print(sample)
这样,我们就可以利用这种方法在训练时候增加数据:
- import torch
- import torch.utils.data as Data
- from torch.utils.data import DataLoader
-
# Start from a single-sample training dataset.
train_dataset = Data.TensorDataset(torch.tensor([5]), torch.tensor([0]))

# Create the data loader.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# NOTE(review): `model` (and its optimizer/loss) must be defined elsewhere;
# this snippet only demonstrates the grow-the-dataset pattern.
n_queries = 100
for i in range(n_queries):
    # BUG FIX: the original called model.train(train_loader), but
    # nn.Module.train(mode) only toggles training mode and takes an optional
    # bool — it does NOT consume a DataLoader. Switch to training mode,
    # then iterate the loader explicitly.
    model.train()
    for b_x, b_y in train_loader:
        pass  # forward / loss / backward / optimizer.step() would go here

    # Append new samples to the training dataset (yields a ConcatDataset).
    new_data = torch.utils.data.TensorDataset(torch.tensor([[85], [54]]), torch.tensor([[6], [4]]))
    train_dataset += new_data

    # Rebuild the loader so it sees the enlarged dataset.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
在数据集中,指定删除对应索引下的数据
- import torch
- import torch.utils.data as Data
- from torch.utils.data import Subset
- import numpy as np
-
# Original dataset: four (x, y) pairs.
b = Data.TensorDataset(torch.tensor([[85], [54], [12], [23]]),
                       torch.tensor([[6], [4], [8], [7]]))

# Indices of the samples we want to remove.
indices = torch.tensor([0, 2], dtype=torch.int)

# Every index present in b.
indices_ = torch.tensor(range(len(b)), dtype=torch.int)

# Indices that survive the removal: keep positions NOT listed in `indices`.
# NOTE: the two sides of the membership test must share a comparable dtype.
removal_mask = np.isin(indices_.numpy(), indices)
indices_new = indices_.numpy()[~removal_mask]

# Re-wrap b so only the surviving indices are visible.
# Subset accepts any integer index sequence (torch / numpy / plain list).
new_dataset = Subset(b, indices_new)

# Loader over the filtered view.
train_loader = Data.DataLoader(new_dataset, batch_size=32, shuffle=True)

for sample in new_dataset:
    print(sample)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。