当前位置:   article > 正文

【实测】小白一看就学会的python的AI模型torch(中)_从零开始使用pytorch训练一个简单的ai预测模型

从零开始使用pytorch训练一个简单的ai预测模型

上节文章我们搞定了一个最简单的预测,那么本节我们来提高难度:

看下我假设的这个场景:

预测一对夫妻结婚多久才能生娃,生的什么娃?结婚多久才能买房,买的什么房?

这个问题中,明显要比之前的复杂很多,上节就是一个人来预测结婚年龄等。这节课,是俩个人的信息来预测,而且预测的事呢,虽然还是俩件,但是每件里都要有更详细的信息。

来看我准备的一个测试输入数据:

这是一对夫妻的个人数据,俩个人在一起就是一个二维列表了:

[[1, 174, 33, 900, 1], [2, 164, 21, 100, 3]]   


然后是他们的结果输出数据:

[[2,1,5],[4,1,120]]  

含义:

[[结婚两年,男孩,5斤],[结婚4年,楼房,120平]]


好,注意看,这只是一对夫妻的数据,就已经是一个二维列表了,如果我们要拿出很多对夫妻的话,那么这个输入和输出数据,都会变成三维的!

在python的AI库torch的官方教程中是明确可以支持多维度数据的。所以应该可以继续使用。那么代码中get_test_data就要变成如下的:

图片

如图,in_data和out_data都变成三维的了!

如果我们不改变make_AI函数,直接训练那么结果如下:

图片

原因很简单,这些测试数据和咱们的模型是不匹配的。

经过调整,代码改成如下:

import torchimport torch.nn as nnimport torch.optim as optim

# 自定义的WQRF模型class WQRF(nn.Module):    def __init__(self, input_features=5, hidden_features=16, output_features=3):        super(WQRF, self).__init__()        # 输入层,假设每个样本有两个5维特征向量        self.fc1 = nn.Linear(input_features * 2, hidden_features)        self.relu = nn.ReLU()        # 输出层,假设每个样本有两个3维输出向量        self.fc2 = nn.Linear(hidden_features, output_features * 2)
    def forward(self, x):        # 展平输入,每个样本的两个5维特征向量变为10        x = x.view(x.size(0), -1)        x = self.fc1(x)        x = self.relu(x)        x = self.fc2(x)        # 重塑输出,每个样本的6维输出变为两个3维向量        x = x.view(x.size(0), 2, -1)        return x
    # 获取测试数据的函数

def get_test_data():    in_data = torch.tensor([        [[1, 174, 33, 900, 1], [2, 164, 21, 100, 3]],        [[2, 144, 23, 100, 2], [1, 154, 31, 150, 1]]    ], dtype=torch.float32)    out_data = torch.tensor([        [[2, 1, 5], [4, 1, 120]],        [[5, 2, 6], [10, 1, 50]]    ], dtype=torch.float32)    return in_data, out_data

# 示例的make_AI函数,用于训练和保存模型def make_AI():    # 创建WQRF模型实例    wqrf = WQRF()
    # 定义损失函数和优化器    criterion = nn.MSELoss()    optimizer = optim.SGD(wqrf.parameters(), lr=0.01)
    # 获取测试数据    in_data, out_data = get_test_data()
    # 训练过程    for i in range(10000):        optimizer.zero_grad()        outputs = wqrf(in_data)  # 前向传播        loss = criterion(outputs.view(-1, 3), out_data.view(-1, 3))  # 计算损失,需要展平输出和真实值        loss.backward()  # 反向传播        optimizer.step()  # 更新权重
        # 打印损失        print('Loss:', loss.item())
    # 保存模型(如果需要)    torch.save(wqrf.state_dict(), 'WQRF_AI.pt')

其中在WQRF模型基本结构中,对输入和输出数据进行了重构降维和展平,让其的形状能互相匹配上。这一步是有规则的,如果匹配不上就是保错。当然下面的make_AI函数要进行相应的更改。然后还是要训练10000次!

训练结果的损失最终稳定在了206

图片

如果我们想看看更详细的:

那就在循环里加上这句:

图片

print('【预期输出】:\n%s\n【实际最后输出】:\n%s\n【指标误差】:\n%s' % (out_data.numpy(), outputs.detach().numpy(), loss.item()))

最后一次的结果如下:

图片

因为中间节点才16个,加上测试数据只有3个,所以也就能预测到这个级别了。想更精准,就要继续加数据!加节点!

然后就是使用这个模型来预测新的数据喽!

图片

这次的play_AI函数,要写的更专业更复杂一点了哦~ 因为中间有涉及到要给结果进行重塑的过程,还是有点烧脑的。

def play_AI(model_path, new_data):    # 实例化模型    wqrf = WQRF()
    # 加载模型参数    wqrf.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    # 评估模式(关闭dropout和batch normalization的训练时行为)    wqrf.eval()
    # 假设new_data是一个包含新数据的列表,每个元素是两个特征向量的列表    new_data_tensor = torch.tensor(new_data, dtype=torch.float32)
    # 如果在GPU上训练,则需要将模型和数据移到GPU上    # 这里假设我们在CPU上运行    device = torch.device('cpu')    wqrf.to(device)    new_data_tensor = new_data_tensor.to(device)
    # 添加一个batch维度(如果需要的话)    if len(new_data_tensor.shape) == 2:        new_data_tensor = new_data_tensor.unsqueeze(0)
        # 重塑输入数据以匹配模型的输入要求(假设每个样本有两个特征向量)    new_data_tensor = new_data_tensor.view(new_data_tensor.size(0), -1)
    # 进行预测    with torch.no_grad():  # 关闭梯度计算        predictions = wqrf(new_data_tensor)
        # 重塑预测结果以匹配每个样本有两个3维向量的输出    predictions = predictions.view(predictions.size(0), 2, -1)
    # 打印预测结果    print('预测结果:', predictions.detach().cpu().numpy())

# 示例新数据new_data = [ [[1, 180, 35, 850, 0.5], [2, 170, 25, 95, 2.8]]]
# 假设模型参数保存在'WQRF_AI.pt'文件中play_AI('WQRF_AI.pt', new_data)

我选了一个测试数据:

 
 [ [[1, 180, 35, 850, 1], [2, 170, 32, 555, 2]]]

结果为:

图片

结果翻译过来就是:

结婚3年,生个男孩,5斤5两。结婚7年,买了楼房,85平米。

好,本节到此结束,欢迎继续期待下一节,点赞过百更新!

咨询加:qingwanjianhua

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/762906
推荐阅读
相关标签
  

闽ICP备14008679号