赞
踩
pima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabeteshttps://gitee.com/biabianm/pima-indians-diabetespima-indians-diabetes: 一、 数据说明: Pima Indians Diabetes Data Set(皮马印第安人糖尿病数据集) 根据现有的医疗信息预测5年内皮马印第安人糖尿病发作的概率。 数据链接:https://archive.ics.uci.edu/ml/datasets/Pima+Indians+Diabetes
np.loadtxt()
用于从文本加载数据。loadtxt(fname, dtype=<class 'float'>, comments='#', delimiter=None, converters=None, skiprows=0, usecols=None, unpack=False, ndmin=0)
fname要读取的文件、文件名、或生成器。
dtype数据类型,默认float。
comments注释。
delimiter分隔符,默认是空格。
skiprows跳过前几行读取,默认是0,必须是int整型。
usecols要读取哪些列,0是第一列。例如,usecols = (1,4,5)将提取第2,第5和第6列。默认读取所有列。
unpack如果为True,将分列读取。
- import torch
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- import numpy as np
-
-
- # 加载数据集
- xy = np.loadtxt('pima-indians-diabetes.csv', delimiter=',', dtype=np.float32,skiprows = 1)
- # 数据预处理,包括从数据集里区分输入输出,最后把输入输出数据封装成Pytorch期望的Variable格式
- X_train= torch.from_numpy(xy[:,:-1]) # 特征信息
- y_train= torch.from_numpy(xy[:,[-1]]) # 目标分类
-
- class Model(torch.nn.Module):
- def __init__(self):
- super(Model, self).__init__()
- self.linear1=torch.nn.Linear(8,6)
- self.linear2= torch.nn.Linear(6,4)
- self.linear3= torch.nn.Linear(4,1)
- self.sigmoid=torch.nn.Sigmoid()
- def forward(self,x):
- x=self.sigmoid(self.linear1(x))
- x = self.sigmoid(self.linear2(x))
- x = self.sigmoid(self.linear3(x))
- return x
- model=Model()
- criterion=torch.nn.BCELoss(size_average=True)
- optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
- px,py = [],[] # 记录要绘制的数据
- for epoch in range(1000):
- y_pred=model(X_train)
- loss=criterion(y_pred,y_train)
- print(epoch,loss.item())
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- px.append(epoch)
- py.append(loss.item())
- plt.plot(px, py)
- plt.ylabel('loss')
- plt.xlabel('epoch')
- plt.show()
- # # 每十次迭代绘制训练动态
- # if epoch% 10 == 0:
- # plt.cla()
- # plt.plot(px, py, 'r-', lw=1)
- # plt.text(0, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color': 'red'})
- # plt.pause(0.1)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。