当前位置:   article > 正文

数据生成器

数据生成器

适用于比赛数据生成

一、 回归数据生成器

  1. def tensorGenReg(num_examples=1000,w=[2,-1,1],bias=True,delta=0.01,deg=1):
  2. """
  3. num_examples:创建数据集所需数据量
  4. w: 包括截距的特征系数向量
  5. bias:是否需要b
  6. delta:扰动项取值
  7. deg:方程次数
  8. return:生成的特征张量与标签张量
  9. """
  10. if bias==True:
  11. num_inputs=len(w)-1
  12. w_true = torch.tensor(w[:-1]).reshape(-1,1).float()
  13. b_true=torch.tensor(w[-1]).float()
  14. features_true=torch.randn(num_examples,num_inputs)
  15. if num_inputs==1:
  16. labels_true=torch.pow(features,deg)*w_true+b_true
  17. else:
  18. labels_true=torch.mm(torch.pow(features_true,deg),w_true)+b_true
  19. features=torch.cat(( features_true,torch.ones(len(features_true),1)),1)
  20. labels = labels_true+torch.randn(size=labels_true.shape)*delta#增加小偏差
  21. else:
  22. num_inputs=len(w)
  23. w_true = torch.tensor(w).reshape(-1,1).float()
  24. features=torch.randn(num_examples,num_inputs)
  25. if num_inputs==1:
  26. labels_true=torch.pow(features,deg)*w_true
  27. else:
  28. labels_true=torch.mm(torch.pow(features,deg),w_true)
  29. labels = labels_true+torch.randn(size=labels_true.shape)*delta#增加小偏差
  30. return features,labels

结果展示:

y = 2x_1-x_2+1假设搞这个函数的分布图

  1. f,l=tensorGenReg(delta=0.01)
  2. plt.subplot(121)
  3. plt.scatter(f[:,0],l)
  4. plt.subplot(122)
  5. plt.scatter(f[:,1],l)

二、 分类数据生成器

创建拥有两个特征的三分类数据集,每个类别500条数据,第一类别两特征值均满足均值为4,标准差为2的标准正态分布,第二类别两特征值均满足均值为-2,标准差为2的标准正态分布,第三类别两特征值均满足均值为-6,标准差为2的标准正态分布。

  1. torch.manual_seed(420)
  2. num_input=2
  3. num_examples=500
  4. data0=torch.normal(4,2,size=(num_examples,num_input))#均值=4,标准差=2,
  5. data1=torch.normal(-2,2,size=(num_examples,num_input))#均值=-2,标准差=2,
  6. data2=torch.normal(-6,2,size=(num_examples,num_input))#均值=-6,标准差=2,
  7. label0 = torch.zeros(500)
  8. label1 = torch.ones(500)
  9. label2=torch.full_like(label1,2)#仿照label1创建一个数组数值为2
  10. #合并数据集
  11. features = torch.cat((data0,data1,data2)).float()
  12. labels = torch.cat((label0,label1,label2)).long().reshape(-1,1)#长整型
  13. plt.scatter(features[:,0],features[:,1],c=labels)#数据展示

上面的方式还是比较简单的,我们可以尝试加入扰动项来加大难度(将各类的均值进行压缩,增加方差),下面提供普适性方案

  1. def tensorGencla(num_examples=500,num_inputs=2,num_class=3,deg_dispersion=[4,2],bias=False):
  2. """
  3. num_examples:创建数据集所需数据量
  4. num_inputs: 数据集特征总数
  5. num_class:数据集标签类别综述
  6. deg_dispersion:数据分布离散程度,第一个表示均值的参考值,第二个代表随机数的标准差参考值
  7. bias:建立模型逻辑回归模型是否带入截距
  8. return:生成的特征张量与标签张量
  9. """
  10. cluster_1 = torch.empty(num_examples,1) #每一类标签张量的形状
  11. mean_ = deg_dispersion[0]#均值参考值
  12. std_ =deg_dispersion[1]#标准差参考值
  13. lf = []#用于储存每一类特征张量的列表容器
  14. ll = []#用于储存每一类标签张量的列表容器
  15. k = mean_*(num_class+1)/2#每一类特征张量均值的惩罚因子
  16. for i in range(num_class):
  17. data_temp = torch.normal(i*mean_-k,std_,size=(num_examples,num_inputs))#生成每一类特征张量
  18. #i*mean_-k,(例如3,0,-3对称分布,迭代效率很快),式子为mean_*(i-(num_class+1)/2),i-最大值的一半,基本就在零附近的一组数
  19. lf.append(data_temp)#储存在lf中
  20. labels_temp=torch.full_like(cluster_1,i)
  21. ll.append(labels_temp)
  22. features = torch.cat(lf).float()
  23. labels = torch.cat(ll).float()
  24. return features,labels
  1. f,l=tensorGencla(deg_dispersion=[6,4])
  2. plt.scatter(f[:,0],f[:,1],c=l)
  3. plt.savefig('分类器.png', dpi=300)

三、蒙特卡罗模拟

(可自主构建,灵活度很高)

四、创建小批量切分函数

       之前提到过, 在深度学习建模过程中,梯度下降是最常用的求解目标函数的优化方法,而针对不同类型、拥有不同函数特性的目标函数,所使用的梯度下降算法也各有不同。目前为止,我们判断小批量梯度下降 (MBGD)与Adam是为“普适”的优化算法,在这里介绍下小批量函数需要利用的过程。

  1. def data_iter(batch_size,features,labels):
  2. """
  3. batch_size:每个子集需要多少数据
  4. features:输入特征张量
  5. labels:输入标签张量
  6. return l:包含batch_size个列表,每个列表切分后的特征和标签所组成
  7. """
  8. num_examples=len(features)
  9. indics=list(range(num_examples))
  10. random.shuffle(indices)#打乱的索引
  11. l=[]
  12. for i in range(0,num_examples,batch_size):#0到num_examples,保证每次跨越为一个batch_size
  13. j =torch.tensor(indices[i:min(i+batch_size,num_examples)])
  14. #i在indices内部的索引出来的结果张量化,最后一批可能取的不是整数,
  15. #所以为了保证最后取值大于序列范围,用min
  16. l.append([torch.index_select(features,0,j),torch.index_select(labels,0,j)])
  17. #对features进行批量索引
  18. return l

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/72618
推荐阅读
相关标签
  

闽ICP备14008679号