
27. ResNet50 on the STEW dataset for three-class emotion classification, with complete code


1. Dataset description

IEEE-Datasets-STEW: SIMULTANEOUS TASK EEG WORKLOAD DATASET:

The dataset consists of raw EEG data from 48 subjects who took part in a multitasking workload experiment based on the SIMKAP multitasking test. The subjects' brain activity at rest was also recorded before the test and is included. Data were acquired with an Emotiv EPOC headset (14 channels, 128 Hz sampling rate), and each case contains 2.5 minutes of EEG recording. Subjects were also asked to rate their perceived mental workload on a scale of 1 to 9 after each stage; the ratings are provided in a separate file.

Notes: each subject's data follows the naming convention subno_task.txt. For example, sub01_lo.txt is the raw EEG of subject 1 at rest, while sub23_hi.txt is the raw EEG of subject 23 during the multitasking test. In each data file, rows correspond to samples of the recording and columns correspond to the 14 channels of the EEG device: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4.

Dataset description and download:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore
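
To make the file format concrete, here is a minimal loading sketch. It assumes the .txt files are plain whitespace-separated numbers, and the file name sub01_lo.txt is only an illustration:

import numpy as np

# The 14 channels, in the column order given above.
CHANNELS = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
            'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']

# One recording: rows are samples at 128 Hz, columns are the 14 channels.
eeg = np.loadtxt(r'C:\STEW\test\sub01_lo.txt')
print(eeg.shape)  # about (19200, 14): 2.5 min * 60 s * 128 Hz = 19200 samples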

2. Code

This post uses ResNet50 to classify this emotion data. The data loading, model training, and testing code is given below; note that the training script at the end uses CIFAR-10 as a stand-in dataset to verify the pipeline, as its comments indicate.

import torch
import torchvision.datasets
from torch.utils.data import Dataset  # base class for custom datasets
import os
from PIL import Image
import numpy as np
from torchvision import transforms

# Preprocessing
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the image
    transforms.ToTensor(),  # convert to Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize
])

path = r'C:\STEW\test'
for root, dirs, files in os.walk(path):
    print('root', root)  # the directory currently being visited
    print('dirs', dirs)  # its sub-directories
    print('files', files)  # its files

def read_txt_files(path):
    # list of collected file names
    file_names = []
    # walk the given directory and all of its sub-directories
    for root, dirs, files in os.walk(path):
        # iterate over the files in each directory
        for file in files:
            # keep only .txt files
            if file.endswith('.txt'):  # endswith() returns True if the string ends with the given suffix, False otherwise
                file_names.append(os.path.join(root, file))
    # return the list of file names
    return file_names
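
A quick usage example for read_txt_files (the directory is the same illustrative path as above):

# Collect every STEW recording under the data directory.
txt_files = read_txt_files(r'C:\STEW\test')
print(len(txt_files), 'txt files found')
print(txt_files[:3])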
class DogCat(Dataset):  # custom dataset
    def __init__(self, root, transforms=None):  # set the root path and optional preprocessing
        # e.g. ['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]  # every file under root
        self.transforms = transforms if transforms is not None else data_transform  # image preprocessing
    def __getitem__(self, index):  # load one image
        img_path = self.imgs[index]
        # Derive the label from the file name: if the last path component contains 'dog' -> 1, otherwise ('cat') -> 0.
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        data = Image.open(img_path)
        if self.transforms:  # apply preprocessing
            data = self.transforms(data)
        return data, label
    def __len__(self):
        return len(self.imgs)

dataset = DogCat('./data/', transforms=data_transform)
for img, label in dataset:
    print('img:', img.size(), 'label:', label)
  50. '''
  51. img: torch.Size([3, 224, 224]) label: 0
  52. img: torch.Size([3, 224, 224]) label: 0
  53. img: torch.Size([3, 224, 224]) label: 0
  54. img: torch.Size([3, 224, 224]) label: 0
  55. img: torch.Size([3, 224, 224]) label: 1
  56. img: torch.Size([3, 224, 224]) label: 1
  57. img: torch.Size([3, 224, 224]) label: 1
  58. img: torch.Size([3, 224, 224]) label: 1
  59. '''
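
The same Dataset pattern carries over from images to the STEW .txt recordings. The following is only a sketch of one possible adaptation, not code from the original experiment: it assumes fixed-length windows cut from each recording, and the ratings dict and the rating_to_class binning (mapping the 1-9 workload ratings to three classes) are hypothetical placeholders you would fill in from the dataset's rating file. It reuses read_txt_files and the imports from the snippets above.

class STEWDataset(Dataset):
    """Serve fixed-length EEG windows with a 3-class label per recording."""
    def __init__(self, root, window=512, ratings=None):
        self.files = read_txt_files(root)  # all .txt recordings under root
        self.window = window  # samples per example (512 = 4 s at 128 Hz)
        self.ratings = ratings or {}  # hypothetical: file name -> 1..9 rating
    def rating_to_class(self, rating):
        # Hypothetical 3-way binning of the 1..9 workload ratings.
        return 0 if rating <= 3 else (1 if rating <= 6 else 2)
    def __getitem__(self, index):
        path = self.files[index]
        eeg = np.loadtxt(path)  # (n_samples, 14)
        x = torch.from_numpy(eeg[:self.window].T).float()  # (14, window)
        rating = self.ratings.get(os.path.basename(path), 1)
        return x, self.rating_to_class(rating)
    def __len__(self):
        return len(self.files)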
import os
# Read all TXT files under file_path; return their contents and file names.
def get_text_list(file_path):
    files = os.listdir(file_path)
    text_list = []
    for file in files:
        with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:
            text_list.append(f.read())
    return text_list, files
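
Usage mirrors read_txt_files (the path is illustrative):

texts, names = get_text_list(r'C:\STEW\test')
print(names[0], '->', len(texts[0]), 'characters')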
# Imports needed by the class below; torchvision's find_classes(dir) returns (classes, class_to_idx).
import pathlib
from typing import Tuple
from torchvision.datasets.folder import find_classes

class ImageFolderCustom(Dataset):
    # 2. Initialize with a targ_dir and transform (optional) parameter
    def __init__(self, targ_dir: str, transform=None) -> None:
        # 3. Create class attributes
        # Get all image paths
        self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg"))  # note: you'd have to update this if you've got .png's or .jpeg's
        # Setup transforms
        self.transform = transform
        # Create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(targ_dir)
    # 4. Make function to load images
    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        image_path = self.paths[index]
        return Image.open(image_path)
    # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)
    # 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        class_name = self.paths[index].parent.name  # expects path in data_folder/class_name/image.jpeg
        class_idx = self.class_to_idx[class_name]
        # Transform if necessary
        if self.transform:
            return self.transform(img), class_idx  # return data, label (X, y)
        else:
            return img, class_idx  # return data, label (X, y)
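
For example, paired with the data_transform defined earlier; the data/train folder layout (one sub-folder per class) is an assumption of this pattern, not a path from the original post:

from torch.utils.data import DataLoader

# Assumed layout: data/train/<class_name>/<image>.jpg
train_data = ImageFolderCustom(targ_dir='data/train', transform=data_transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
imgs, labels = next(iter(train_loader))
print(imgs.shape, labels.shape)  # torch.Size([32, 3, 224, 224]) torch.Size([32])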
import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transforms

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

# Sanity-check the pipeline on CIFAR-10
class Cutout(object):
    """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length
    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)
        mask = np.ones((h, w), np.float32)
        for n in range(self.n_holes):
            # Pick a random center, then zero out a length x length square around it.
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)
            mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
        return img
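
Because Cutout indexes the (C, H, W) dimensions of a tensor, it must sit after ToTensor in a Compose, as the comment in the next snippet notes. A tiny self-contained check:

# Apply Cutout to an all-ones tensor and verify that a patch was zeroed.
img = torch.ones(3, 224, 224)
out = Cutout(n_holes=1, length=16)(img)
print((out == 0).any().item())  # True: a square patch (clipped at the borders) is masked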
def load_data_cifar10(batch_size=128, num_workers=2):
    # Transform pipelines
    # Data augmentation
    train_transform_1 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomRotation(degrees=(-80, 80)),  # random rotation
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # (mean, std)
        ),
        Cutout(1, 16),  # must come after ToTensor
    ])
    train_transform_2 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # (mean, std)
        )
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # (mean, std)
        )
    ])
    # Training set 1 (with augmentation)
    trainset1 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_1,
    )
    # Training set 2 (without augmentation)
    trainset2 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_2,
    )
    # Test set
    testset = tv.datasets.CIFAR10(
        root='data',
        train=False,
        download=False,
        transform=test_transform,
    )
    # Training data loader 1
    trainloader1 = torch.utils.data.DataLoader(
        trainset1,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # Training data loader 2
    trainloader2 = torch.utils.data.DataLoader(
        trainset2,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # Test data loader
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    return trainloader1, trainloader2, testloader
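
If CIFAR-10 is not already under ./data, set download=True in the three dataset constructors. A one-batch sanity check of the loaders:

trainloader1, trainloader2, testloader = load_data_cifar10(batch_size=8)
x, y = next(iter(trainloader1))
print(x.shape, y.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8])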
def main():
    start = time.time()
    batch_size = 128
    cifar_train1, cifar_train2, cifar_test = load_data_cifar10(batch_size=batch_size)
    model = resnet50(num_classes=10).cuda()  # CIFAR-10 has 10 classes
    # model.load_state_dict(torch.load('_ResNet50.pth'))
    # (uncomment if a saved parameter file exists)
    # model = nn.DataParallel(model, device_ids=[0, ])  # wraps the model in one more layer
    model = nn.DataParallel(model, device_ids=[0, 1, 2])
    loss = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(50):
        model.train()  # always call this before training
        loss_ = 0.0
        num = 0.0
        # train on trainloader1 (data augmentation) and trainloader2
        for i, data in enumerate(cifar_train1, 0):
            x, label = data
            x, label = x.cuda(), label.cuda()
            p = model(x)  # output
            l = loss(p, label)  # loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num += 1
        for i, data in enumerate(cifar_train2, 0):
            x, label = data
            x, label = x.cuda(), label.cuda()
            p = model(x)
            l = loss(p, label)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num += 1
        model.eval()  # always call this before evaluation
        print("loss:", float(loss_) / num)
        # evaluate on trainloader2 and testloader
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_train2:
                # x: [b, 3, 224, 224], label: [b]
                x, label = x.cuda(), label.cuda()
                # logits: [b, 10]
                logits = model(x)
                # pred: [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc_1 = total_correct / total_num
        # Test
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 224, 224], label: [b]
                x, label = x.cuda(), label.cuda()
                # logits: [b, 10]
                logits = model(x)  # output
                # pred: [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc_2 = total_correct / total_num
        print(epoch + 1, 'train acc', acc_1, '|', 'test acc:', acc_2)
        # when wrapped in DataParallel, save only model.module
        torch.save(model.module.state_dict(), 'resnet50.pth')
    print("The interval is :", time.time() - start)

if __name__ == '__main__':
    main()
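
To point this training script at STEW rather than CIFAR-10, only the data loaders and the classifier head need to change. A heavily hedged sketch, assuming the STEWDataset class from earlier in this section; rebuilding conv1 so that ResNet50 accepts 14 EEG channels (treating each window as a 14-channel image of height 1) is one common workaround, not the method of the original post, and C:\STEW\train is an illustrative path:

from torch.utils.data import DataLoader

train_ds = STEWDataset(r'C:\STEW\train', window=512)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)

model = resnet50(num_classes=3)  # three emotion/workload classes
# Replace the RGB stem with a 14-channel one.
model.conv1 = nn.Conv2d(14, 64, kernel_size=7, stride=2, padding=3, bias=False)

x, y = next(iter(train_loader))
logits = model(x.unsqueeze(2))  # (16, 14, 1, 512) -> (16, 3)
print(logits.shape)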

3. If this post helped you, consider giving me a follow~
