当前位置:   article > 正文

deap dataset的不同分类模型的实现(3)-遍历文件_torcheeg

torcheeg

学习一下这个里面函数,文件和数据的操作。
除了知道输入是什么,还要知道输出是什么,什么类型,能进行什么操作。

class Dataset(torch.utils.data.Dataset):
    
    def __init__(self, path, stim):
        _, _, filenames = next(os.walk(path))
        
        filenames = sorted(filenames)
        all_data = []
        all_label = []
        for dat in filenames:
            
            temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1')
            
            all_data.append(temp['data'])
            
            if stim == "Valence":
                all_label.append(temp['labels'][:,:1])   #the first index is valence
            elif stim == "Arousal":
                all_label.append(temp['labels'][:,1:2]) # Arousal  #the second index is arousal
                
        self.data = np.vstack(all_data)[:, :32, ]   #shape: (1280, 32, 8064) --> take only the first 32 channels
        
        shape = self.data.shape
        
        #perform segmentation=====
        segments = 12
        
        self.data = self.data.reshape(shape[0], shape[1], int(shape[2]/segments), segments)
        #data shape: (1280, 32, 672, 12)

        self.data = self.data.transpose(0, 3, 1, 2)
        #data shape: (1280, 12, 32, 672)

        self.data = self.data.reshape(shape[0] * segments, shape[1], -1)
        #data shape: (1280*12, 32, 672)
        #==========================
        
        self.label = np.vstack(all_label) #(1280, 1)  ==> 1280 samples, 
        self.label = np.repeat(self.label, 12)[:, np.newaxis]  #the dimension 1 is lost after repeat, so need to unsqueeze (1280*12, 1)
        
        del temp, all_data, all_label

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        single_data  = self.data[idx]
        single_label = (self.label[idx] > 5).astype(float)   #convert the scale to either 0 or 1 (to classification problem)
        
        batch = {
            'data': torch.Tensor(single_data),
            'label': torch.Tensor(single_label)
        }
        
        return batch
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

文件遍历学习

第一句值得深入研究

_, _, filenames = next(os.walk(path))
  • 1

os.walk(路径)遍历文件,返回路径,路径下的文件夹,路径下的文件

path='E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python'
for root, dirs, files in os.walk(path):
    print("root:"+root)
    print(dirs)
    print(files)
  • 1
  • 2
  • 3
  • 4
  • 5

它的返回值是一个生成器,只能遍历打印或者用next(),会不断遍历路径下的所有文件夹。

生成器理解
遍历目录

next(可迭代对象,最后默认值)不停迭代,输出下一个对象。
可迭代对象,iterable对象。生成器也是一个可迭代对象。

如果我这样写呢,filenames = next(os.walk(path)),返回的是一个三元的元组。生成器经过next()变为了元组。

可迭代对象与可迭代器,生成器的分析

可以直接作用于for循环的数据类型有以下几种:

  • 一类是集合数据类型,如list、tuple、dict、set、str等;
  • 一类是generator,包括生成器和带yield的generator function。

这些可以直接作用于for循环的对象统称为可迭代对象:Iterable。

可以使用isinstance()判断一个对象是否是Iterable对象:

>>> from collections.abc import Iterable
>>> isinstance([], Iterable)
  • 1
  • 2

可以被next()函数调用并不断返回下一个值的对象称为迭代器:Iterator。

生成器(generator)不但可以作用于for循环,还可以被next()函数不断调用并返回下一个值,直到最后抛出StopIteration错误表示无法继续返回下一个值了。与可迭代器的区别。

list、dict、str虽然是Iterable,却不是Iterator
为什么呢?

这是因为Python的Iterator对象表示的是一个数据流,Iterator对象可以被next()函数调用并不断返回下一个数据,直到没有数据时抛出StopIteration错误。可以把这个数据流看做是一个有序序列,但我们却不能提前知道序列的长度,只能不断通过next()函数实现按需计算下一个数据,所以Iterator的计算是惰性的,只有在需要返回下一个数据时它才会计算。

Iterator甚至可以表示一个无限大的数据流,例如全体自然数。而使用list是永远不可能存储全体自然数的。

把list、dict、str等Iterable变成Iterator可以使用iter()函数

总结

凡是可作用于for循环的对象都是Iterable类型,它们是有限的有规律的,确定的;

凡是可作用于next()函数的对象都是Iterator类型,它们表示一个惰性计算的序列,用到才给你;

集合数据类型如list、dict、str等是Iterable,但不是Iterator,不过可以通过iter()函数获得一个Iterator对象。

Python的for循环本质上就是通过不断调用next()函数实现的。

_, _, filenames = next(os.walk(path))这句代码,就是为了得到遍历数据文件名。

1、sorted()

sorted()可以对所有可迭代类型进行排序,并且返回新的已排序的列表。语法如下:

sorted(iterable, cmp=None, key=None, reverse=False)
一共可接受4个参数,含义分别如下:
1.可迭代类型,例如字典、列表、
2.比较函数
3.可迭代类型中某个属性,对给定元素的每一项进行排序
4.降序或升序

加载数据

定义了两个列表,用于存放标签与数据

all_data = []#存数据
all_label = []#存标签

        for dat in filenames:
            
            temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1')
            #加载数据,字典类型
            all_data.append(temp['data'])
            #取出数据,取出字典中data对应的值
            
            if stim == "Valence":
                all_label.append(temp['labels'][:,:1])   
                #取第一列,为Valence,L[0:3]表示,从索引0开始取,直到索引3为止,但不包括索引3。即索引0,1,2,正好是3个元素。
            elif stim == "Arousal":
                all_label.append(temp['labels'][:,1:2]) 
                #取第二列, Arousal  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

我自己写了几句,分析理解了一下

import pprint as pp
import numpy as np
#pp = pprint.PrettyPrinter(indent=4)
path="E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python"
all_data = []
all_label = []

#遍历文件
_, _, filenames = next(os.walk(path))
filenames = sorted(filenames)
print(filenames)

output:
['s01.dat', 's02.dat', 's03.dat', 's04.dat', 's05.dat', 's06.dat', 's07.dat', 's08.dat', 's09.dat', 's10.dat', 's11.dat', 's12.dat', 's13.dat', 's14.dat', 's15.dat', 's16.dat', 's17.dat', 's18.dat', 's19.dat', 's20.dat', 's21.dat', 's22.dat', 's23.dat', 's24.dat', 's25.dat', 's26.dat', 's27.dat', 's28.dat', 's29.dat', 's30.dat', 's31.dat', 's32.dat']


#分析数据文件
temp = pickle.load(open(os.path.join(path,filenames[0]), 'rb'), encoding='latin1')
print(type(temp))#字典类型
print(temp.keys())#查看包括哪些keys,分别取出

output:
<class 'dict'>
dict_keys(['labels', 'data'])

#取出对应的数据
all_label=temp['labels']
all_data=temp['data']

print(np.shape(all_label))
print(np.shape(all_data))

output:
(40, 4)标签,40条视频
(40, 40, 8064)数据,40条视频,40个通道,每个通道8064个数据点

pp.pprint(all_labels)
output:
array([[7.71, 7.6 , 6.9 , 7.83],
       [8.1 , 7.31, 7.28, 8.47],
       [8.58, 7.54, 9.  , 7.08],
       [4.94, 6.01, 6.12, 8.06],
       [6.96, 3.92, 7.19, 6.05],
       [8.27, 3.92, 7.  , 8.03],
pp.pprint(all_data)
output:
array([[[ 9.48231681e-01,  1.65333533e+00,  3.01372577e+00, ...,
         -2.82648937e+00, -4.47722969e+00, -3.67692812e+00],
        [ 1.24706590e-01,  1.39008270e+00,  1.83509881e+00, ...,
         -2.98702069e+00, -6.28780884e+00, -4.47429041e+00],
        [-2.21651099e+00,  2.29201682e+00,  2.74636923e+00, ...,
         -2.63707760e+00, -7.40651010e+00, -6.75590441e+00],
        ...,
        [ 2.30779684e+02,  6.96716323e+02,  1.19512165e+03, ...,
          1.01080949e+03,  1.28312149e+03,  1.51996480e+03],
        [-1.54180981e+03, -1.61798052e+03, -1.69268642e+03, ...,
         -1.57842691e+04, -1.57823160e+04, -1.57808512e+04],
        [ 6.39054310e-03,  6.39054310e-03,  6.39054310e-03, ...,
         -9.76081241e-02, -9.76081241e-02, -9.76081241e-02]],
        

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
self.data = np.vstack(all_data)[:, :32, ]   #沿着竖直方向将矩阵堆叠起来
#shape: (1280, 32, 8064) --> take only the first 32 channels
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/72874
推荐阅读
相关标签
  

闽ICP备14008679号