当前位置:   article > 正文

8 多输出预测与多标签分类pytorch网络搭建

8 多输出预测与多标签分类pytorch网络搭建


前言

前面我们搭建的无论是分类还是回归都只能预测一个标签,这显然效果很局限。下面我们想做到下面这两种效果:

  • 多输出预测(回归):例如训练网络拟合北东天坐标转机体坐标的关系,输入是三坐标,输出也是三坐标
  • 多标签分类:例如,输入图像数据,训练网络判断图片里面有猫,有狗,还是只有其中一种这样

【注】:在介绍pytorch的内置损失函数博客中已经介绍了pytorch的损失函数是支持这个功能的。

一、多输出预测(回归)

1 坐标数据生成

# 本示例演示如何使用 PyTorch 实现多标签回归模型。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 构建数据集
# 假设您有一些经纬高度和对应的地心地固坐标的数据
# 这里只是一个示例,您需要根据实际情况准备您自己的数据集
X = np.random.rand(100, 3)  # 100个样本,每个样本有3个特征(经度、纬度、高度)
y = np.random.rand(100, 3)  # 每个样本有3个目标值(地心地固坐标)
print('y:\n',y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

在这里插入图片描述

2 网络搭建训练预测

# 转换数据为 PyTorch 的 Tensor 类型
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# 定义模型
class MultiLabelRegressionModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(MultiLabelRegressionModel, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        
    def forward(self, x):
        out = self.fc(x)
        return out

# 初始化模型
input_size = 3   # 输入特征的数量
output_size = 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/酷酷是懒虫/article/detail/860463
推荐阅读
相关标签
  

闽ICP备14008679号