赞
踩
学习一下这个里面函数,文件和数据的操作。
除了知道输入是什么,还要知道输出是什么,什么类型,能进行什么操作。
class Dataset(torch.utils.data.Dataset): def __init__(self, path, stim): _, _, filenames = next(os.walk(path)) filenames = sorted(filenames) all_data = [] all_label = [] for dat in filenames: temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1') all_data.append(temp['data']) if stim == "Valence": all_label.append(temp['labels'][:,:1]) #the first index is valence elif stim == "Arousal": all_label.append(temp['labels'][:,1:2]) # Arousal #the second index is arousal self.data = np.vstack(all_data)[:, :32, ] #shape: (1280, 32, 8064) --> take only the first 32 channels shape = self.data.shape #perform segmentation===== segments = 12 self.data = self.data.reshape(shape[0], shape[1], int(shape[2]/segments), segments) #data shape: (1280, 32, 672, 12) self.data = self.data.transpose(0, 3, 1, 2) #data shape: (1280, 12, 32, 672) self.data = self.data.reshape(shape[0] * segments, shape[1], -1) #data shape: (1280*12, 32, 672) #========================== self.label = np.vstack(all_label) #(1280, 1) ==> 1280 samples, self.label = np.repeat(self.label, 12)[:, np.newaxis] #the dimension 1 is lost after repeat, so need to unsqueeze (1280*12, 1) del temp, all_data, all_label def __len__(self): return self.data.shape[0] def __getitem__(self, idx): single_data = self.data[idx] single_label = (self.label[idx] > 5).astype(float) #convert the scale to either 0 or 1 (to classification problem) batch = { 'data': torch.Tensor(single_data), 'label': torch.Tensor(single_label) } return batch
第一句值得深入研究
_, _, filenames = next(os.walk(path))
os.walk(路径)遍历文件,返回路径,路径下的文件夹,路径下的文件
path='E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python'
for root, dirs, files in os.walk(path):
print("root:"+root)
print(dirs)
print(files)
它的返回值是一个生成器,只能遍历打印或者用next(),会不断遍历路径下的所有文件夹。
next(可迭代对象,最后默认值)不停迭代,输出下一个对象。
可迭代对象,iterable对象。生成器也是一个可迭代对象。
如果我这样写呢,filenames = next(os.walk(path))
,返回的是一个三元的元组。生成器经过next()变为了元组。
可以直接作用于for循环的数据类型有以下几种:
这些可以直接作用于for循环的对象统称为可迭代对象:Iterable。
可以使用isinstance()判断一个对象是否是Iterable对象:
>>> from collections.abc import Iterable
>>> isinstance([], Iterable)
可以被next()函数调用并不断返回下一个值的对象称为迭代器:Iterator。
生成器(generator)不但可以作用于for循环,还可以被next()函数不断调用并返回下一个值,直到最后抛出StopIteration错误表示无法继续返回下一个值了。与可迭代器的区别。
list、dict、str虽然是Iterable,却不是Iterator
为什么呢?
这是因为Python的Iterator对象表示的是一个数据流,Iterator对象可以被next()函数调用并不断返回下一个数据,直到没有数据时抛出StopIteration错误。可以把这个数据流看做是一个有序序列,但我们却不能提前知道序列的长度,只能不断通过next()函数实现按需计算下一个数据,所以Iterator的计算是惰性的,只有在需要返回下一个数据时它才会计算。
Iterator甚至可以表示一个无限大的数据流,例如全体自然数。而使用list是永远不可能存储全体自然数的。
把list、dict、str等Iterable变成Iterator可以使用iter()函数
总结
凡是可作用于for循环的对象都是Iterable类型,它们是有限的有规律的,确定的;
凡是可作用于next()函数的对象都是Iterator类型,它们表示一个惰性计算的序列,用到才给你;
集合数据类型如list、dict、str等是Iterable,但不是Iterator,不过可以通过iter()函数获得一个Iterator对象。
Python的for循环本质上就是通过不断调用next()函数实现的。
_, _, filenames = next(os.walk(path))
这句代码,就是为了得到遍历数据文件名。
sorted()可以对所有可迭代类型进行排序,并且返回新的已排序的列表。语法如下:
sorted(iterable, cmp=None, key=None, reverse=False)
一共可接受4个参数,含义分别如下:
1.可迭代类型,例如字典、列表、
2.比较函数
3.可迭代类型中某个属性,对给定元素的每一项进行排序
4.降序或升序
定义了两个列表,用于存放标签与数据
all_data = []#存数据
all_label = []#存标签
for dat in filenames:
temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1')
#加载数据,字典类型
all_data.append(temp['data'])
#取出数据,取出字典中data对应的值
if stim == "Valence":
all_label.append(temp['labels'][:,:1])
#取第一列,为Valence,L[0:3]表示,从索引0开始取,直到索引3为止,但不包括索引3。即索引0,1,2,正好是3个元素。
elif stim == "Arousal":
all_label.append(temp['labels'][:,1:2])
#取第二列, Arousal
我自己写了几句,分析理解了一下
import pprint as pp import numpy as np #pp = pprint.PrettyPrinter(indent=4) path="E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python" all_data = [] all_label = [] #遍历文件 _, _, filenames = next(os.walk(path)) filenames = sorted(filenames) print(filenames) output: ['s01.dat', 's02.dat', 's03.dat', 's04.dat', 's05.dat', 's06.dat', 's07.dat', 's08.dat', 's09.dat', 's10.dat', 's11.dat', 's12.dat', 's13.dat', 's14.dat', 's15.dat', 's16.dat', 's17.dat', 's18.dat', 's19.dat', 's20.dat', 's21.dat', 's22.dat', 's23.dat', 's24.dat', 's25.dat', 's26.dat', 's27.dat', 's28.dat', 's29.dat', 's30.dat', 's31.dat', 's32.dat'] #分析数据文件 temp = pickle.load(open(os.path.join(path,filenames[0]), 'rb'), encoding='latin1') print(type(temp))#字典类型 print(temp.keys())#查看包括哪些keys,分别取出 output: <class 'dict'> dict_keys(['labels', 'data']) #取出对应的数据 all_label=temp['labels'] all_data=temp['data'] print(np.shape(all_label)) print(np.shape(all_data)) output: (40, 4)标签,40条视频 (40, 40, 8064)数据,40条视频,40个通道,每个通道8064个数据点 pp.pprint(all_labels) output: array([[7.71, 7.6 , 6.9 , 7.83], [8.1 , 7.31, 7.28, 8.47], [8.58, 7.54, 9. , 7.08], [4.94, 6.01, 6.12, 8.06], [6.96, 3.92, 7.19, 6.05], [8.27, 3.92, 7. , 8.03], pp.pprint(all_data) output: array([[[ 9.48231681e-01, 1.65333533e+00, 3.01372577e+00, ..., -2.82648937e+00, -4.47722969e+00, -3.67692812e+00], [ 1.24706590e-01, 1.39008270e+00, 1.83509881e+00, ..., -2.98702069e+00, -6.28780884e+00, -4.47429041e+00], [-2.21651099e+00, 2.29201682e+00, 2.74636923e+00, ..., -2.63707760e+00, -7.40651010e+00, -6.75590441e+00], ..., [ 2.30779684e+02, 6.96716323e+02, 1.19512165e+03, ..., 1.01080949e+03, 1.28312149e+03, 1.51996480e+03], [-1.54180981e+03, -1.61798052e+03, -1.69268642e+03, ..., -1.57842691e+04, -1.57823160e+04, -1.57808512e+04], [ 6.39054310e-03, 6.39054310e-03, 6.39054310e-03, ..., -9.76081241e-02, -9.76081241e-02, -9.76081241e-02]],
self.data = np.vstack(all_data)[:, :32, ] #沿着竖直方向将矩阵堆叠起来
#shape: (1280, 32, 8064) --> take only the first 32 channels
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。