当前位置:   article > 正文

(六)处理多维特征的输入(下)+pytorch实现糖尿病数据集的逻辑回归

(六)处理多维特征的输入(下)+pytorch实现糖尿病数据集的逻辑回归

糖尿病数据集的逻辑回归

数据集

pima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabeteshttps://gitee.com/biabianm/pima-indians-diabetespima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabetes

 解压后里面有一个pima-indians-diabetes.csv文件 复制粘贴到python代码统一目录下

 m_features.py如下

np.loadtxt()用于从文本加载数据。

loadtxt(fname, dtype=<class 'float'>, comments='#', delimiter=None, converters=None, skiprows=0, usecols=None, unpack=False, ndmin=0)

 xy = np.loadtxt('pima-indians-diabetes.csv', delimiter=',', dtype=np.float32,skiprows = 1)

fname要读取的文件、文件名、或生成器。

dtype数据类型,默认float。

comments注释。

delimiter分隔符,默认是空格。

skiprows跳过前几行读取,默认是0,必须是int整型。

usecols要读取哪些列,0是第一列。例如,usecols = (1,4,5)将提取第2,第5和第6列。默认读取所有列。

unpack如果为True,将分列读取。
 

  1. import torch
  2. import torch.nn.functional as F
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. # 加载数据集
  6. xy = np.loadtxt('pima-indians-diabetes.csv', delimiter=',', dtype=np.float32,skiprows = 1)
  7. # 数据预处理,包括从数据集里区分输入输出,最后把输入输出数据封装成Pytorch期望的Variable格式
  8. X_train= torch.from_numpy(xy[:,:-1]) # 特征信息
  9. y_train= torch.from_numpy(xy[:,[-1]]) # 目标分类
  10. class Model(torch.nn.Module):
  11. def __init__(self):
  12. super(Model, self).__init__()
  13. self.linear1=torch.nn.Linear(8,6)
  14. self.linear2= torch.nn.Linear(6,4)
  15. self.linear3= torch.nn.Linear(4,1)
  16. self.sigmoid=torch.nn.Sigmoid()
  17. def forward(self,x):
  18. x=self.sigmoid(self.linear1(x))
  19. x = self.sigmoid(self.linear2(x))
  20. x = self.sigmoid(self.linear3(x))
  21. return x
  22. model=Model()
  23. criterion=torch.nn.BCELoss(size_average=True)
  24. optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
  25. px,py = [],[] # 记录要绘制的数据
  26. for epoch in range(1000):
  27. y_pred=model(X_train)
  28. loss=criterion(y_pred,y_train)
  29. print(epoch,loss.item())
  30. optimizer.zero_grad()
  31. loss.backward()
  32. optimizer.step()
  33. px.append(epoch)
  34. py.append(loss.item())
  35. plt.plot(px, py)
  36. plt.ylabel('loss')
  37. plt.xlabel('epoch')
  38. plt.show()
  39. # # 每十次迭代绘制训练动态
  40. # if epoch% 10 == 0:
  41. # plt.cla()
  42. # plt.plot(px, py, 'r-', lw=1)
  43. # plt.text(0, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color': 'red'})
  44. # plt.pause(0.1)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/671197
推荐阅读
相关标签
  

闽ICP备14008679号