当前位置:   article > 正文

计算机视觉———— 训练数据获取与处理_数据训练

数据训练

一、数据集的获取

  通常,我们的数据来源于各个比赛平台。首先是AIStudio中的数据集,大部分经典数据集例如百度AI Studio ,Kaggle、天池、讯飞等平台(通过关键词搜索获取需要的数据集),或者是Github。还有一些小的平台,需要大家自己去看。通常来说,数据集用于学术目的,有些数据需要申请才能获得链接。

1.1 Kaggle有趣比较火热的数据集

url :https://www.kaggle.com/c/house-prices-advanced-regression-techniques/data?select=test.csv

House Prices-Advanced Regression Techniques      预测销售价格

Cat and Dog                         猫狗分类

Machine Learning from Disaster      预测泰坦尼克号的生存情况并熟悉机器学习基础知识

1.2 天池

阿里数据:https://tianchi.aliyun.com/dataset/dataDetail?dataId=74952

Barley Remote Sensing Dataset大麦遥感检测数据集     遥感影像分割

耶鲁人脸数据库                   目标检测任务(人脸检测)

1.3 DataFountain

URL:https://www.datafountain.cn/datasets/6070

花卉分类数据集                       图像分类

1.4 其他常用的数据集官网

科大讯飞官网

COCO数据集

url:https://cocodataset.org/#download

1.5 完整流程概述

1.5.1 图像处理完整流程

    1. 图片数据获取
    1. 图片数据清洗

  ----初步了解数据,筛选掉不合适的图片

    1. 图片数据标注
    1. 图片数据预处理data preprocessing。

  ----标准化 standardlization

    一 中心化 = 去均值 mean normallization

      一 将各个维度中心化到0

      一 目的是加快收敛速度,在某些激活函数上表现更好

     一 归一化 = 除以标准差

      一 将各个维度的方差标准化处于[-1,1]之间

      一 目的是提高收敛效率,统一不同输入范围的数据对于模型学习的影响,映射到激活函数有效梯度的值域

    1. 图片数据准备data preparation(训练+测试阶段)

  ----划分训练集,验证集,以及测试集

    1. 图片数据增强data augjmentation(训练阶段 )

  ----CV常见的数据增强

       · 随机旋转

       · 随机水平或者重直翻转

       · 缩放

       · 剪裁

       · 平移

       · 调整亮度、对比度、饱和度、色差等等

       · 注入噪声

       · 基于生成对抗网络GAN做数搪增强AutoAugment等

1.5.2 纯数据处理完整流程

  • 数据预处理与特征工程

  • 1.感知数据

  ----初步了解数据

  ----记录和特征的数量特征的名称

  ----抽样了解记录中的数值特点描述性统计结果

  ----特征类型

  ----与相关知识领域数据结合,特征融合

  • 2.数据清理

  ----转换数据类型

  ----处理缺失数据

  ----处理离群数据

  • 3.特征变换

  ----特征数值化

  ----特征二值化

  ----OneHot编码

  ----特征离散化特征

  ----规范化

    区间变换

    标准化

    归一化

  • 4.特征选择

  ----封装器法

    循序特征选择

    穷举特征选择

    递归特征选择

  ----过滤器法

  ----嵌入法

  • 5.特征抽取

  ----无监督特征抽取

    主成分分析

    因子分析

  ----有监督特征抽取

拓展小知识:

   皮尔森相关系数是用来反应俩变量之间相似程度的统计量,在机器学习中可以用来计算特征与类别间的相似度,即可判断所提取到的特征和类别是正相关、负相关还是没有相关程度。 Pearson系数的取值范围为[-1,1],当值为负时,为负相关,当值为正时,为正相关,绝对值越大,则正/负相关的程度越大。若数据无重复值,且两个变量完全单调相关时,spearman相关系数为+1或-1。当两个变量独立时相关系统为0,但反之不成立。

当两个变量的标准差都不为零时,相关系数才有定义,Pearson相关系数适用于:

(1)、两个变量之间是线性关系,都是连续数据。

(2)、两个变量的总体是正态分布,或接近正态的单峰分布。

(3)、两个变量的观测值是成对的,每对观测值之间相互独立。

二、数据处理

2.1 官方数据COCO 处理为 VOC

  1. # 创建索引
  2. from pycocotools.coco import COCO
  3. import os
  4. import shutil
  5. from tqdm import tqdm
  6. import skimage.io as io
  7. import matplotlib.pyplot as plt
  8. import cv2
  9. from PIL import Image, ImageDraw
  10. from shutil import move
  11. import xml.etree.ElementTree as ET
  12. from random import shuffle
  13. # 保存路径
  14. savepath = "VOCData/"
  15. img_dir = savepath + 'images/' #images 存取所有照片
  16. anno_dir = savepath + 'Annotations/' #Annotations存取xml文件信息
  17. datasets_list=['train2017', 'val2017']
  18. classes_names = ['person']
  19. # 读取COCO数据集地址 Store annotations and train2017/val2017/... in this folder
  20. dataDir = './'
  21. #写好模板,里面的%s与%d 后面文件输入输出流改变 -------转数据集阶段--------
  22. headstr = """
  23. <annotation>
  24. <folder>VOC</folder>
  25. <filename>%s</filename>
  26. <source>
  27. <database>My Database</database>
  28. <annotation>COCO</annotation>
  29. <image>flickr</image>
  30. <flickrid>NULL</flickrid>
  31. </source>
  32. <owner>
  33. <flickrid>NULL</flickrid>
  34. <name>company</name>
  35. </owner>
  36. <size>
  37. <width>%d</width>
  38. <height>%d</height>
  39. <depth>%d</depth>
  40. </size>
  41. <segmented>0</segmented>
  42. """
  43. objstr = """
  44. <object>
  45. <name>%s</name>
  46. <pose>Unspecified</pose>
  47. <truncated>0</truncated>
  48. <difficult>0</difficult>
  49. <bndbox>
  50. <xmin>%d</xmin>
  51. <ymin>%d</ymin>
  52. <xmax>%d</xmax>
  53. <ymax>%d</ymax>
  54. </bndbox>
  55. </object>
  56. """
  57. tailstr = '''
  58. </annotation>
  59. '''
  60. # if the dir is not exists,make it,else delete it
  61. def mkr(path):
  62. if os.path.exists(path):
  63. shutil.rmtree(path)
  64. os.mkdir(path)
  65. else:
  66. os.mkdir(path)
  67. mkr(img_dir)
  68. mkr(anno_dir)
  69. def id2name(coco): # 生成字典 提取数据中的id,name标签的值 ---------处理数据阶段---------
  70. classes = dict()
  71. for cls in coco.dataset['categories']:
  72. classes[cls['id']] = cls['name']
  73. return classes
  74. def write_xml(anno_path, head, objs, tail): #把提取的数据写入到相应模板的地方
  75. f = open(anno_path, "w")
  76. f.write(head)
  77. for obj in objs:
  78. f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
  79. f.write(tail)
  80. def save_annotations_and_imgs(coco, dataset, filename, objs):
  81. # eg:COCO_train2014_000000196610.jpg-->COCO_train2014_000000196610.xml
  82. anno_path = anno_dir + filename[:-3] + 'xml'
  83. img_path = dataDir + dataset + '/' + filename
  84. dst_imgpath = img_dir + filename
  85. img = cv2.imread(img_path)
  86. if (img.shape[2] == 1):
  87. print(filename + " not a RGB image")
  88. return
  89. shutil.copy(img_path, dst_imgpath)
  90. head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
  91. tail = tailstr
  92. write_xml(anno_path, head, objs, tail)
  93. def showimg(coco, dataset, img, classes, cls_id, show=True):
  94. global dataDir
  95. I = Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name']))# 通过id,得到注释的信息
  96. annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
  97. anns = coco.loadAnns(annIds)
  98. # coco.showAnns(anns)
  99. objs = []
  100. for ann in anns:
  101. class_name = classes[ann['category_id']]
  102. if class_name in classes_names:
  103. if 'bbox' in ann:
  104. bbox = ann['bbox']
  105. xmin = int(bbox[0])
  106. ymin = int(bbox[1])
  107. xmax = int(bbox[2] + bbox[0])
  108. ymax = int(bbox[3] + bbox[1])
  109. obj = [class_name, xmin, ymin, xmax, ymax]
  110. objs.append(obj)
  111. return objs
  112. for dataset in datasets_list:
  113. # ./COCO/annotations/instances_train2014.json
  114. annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)
  115. # COCO API for initializing annotated data
  116. coco = COCO(annFile)
  117. '''
  118. COCO 对象创建完毕后会输出如下信息:
  119. loading annotations into memory...
  120. Done (t=0.81s)
  121. creating index...
  122. index created!
  123. 至此, json 脚本解析完毕, 并且将图片和对应的标注数据关联起来.
  124. '''
  125. # show all classes in coco
  126. classes = id2name(coco)
  127. print(classes)
  128. # [1, 2, 3, 4, 6, 8]
  129. classes_ids = coco.getCatIds(catNms=classes_names)
  130. print(classes_ids)
  131. for cls in classes_names:
  132. # Get ID number of this class
  133. cls_id = coco.getCatIds(catNms=[cls])
  134. img_ids = coco.getImgIds(catIds=cls_id)
  135. # imgIds=img_ids[0:10]
  136. for imgId in tqdm(img_ids):
  137. img = coco.loadImgs(imgId)[0]
  138. filename = img['file_name']
  139. objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
  140. save_annotations_and_imgs(coco, dataset, filename, objs)
  141. out_img_base = 'VOCData/images'
  142. out_xml_base = 'VOCData/Annotations'
  143. img_base = 'VOCData/images/'
  144. xml_base = 'VOCData/Annotations/'
  145. if not os.path.exists(out_img_base):
  146. os.mkdir(out_img_base)
  147. if not os.path.exists(out_xml_base):
  148. os.mkdir(out_xml_base)
  149. for img in tqdm(os.listdir(img_base)):
  150. xml = img.replace('.jpg', '.xml')
  151. src_img = os.path.join(img_base, img)
  152. src_xml = os.path.join(xml_base, xml)
  153. dst_img = os.path.join(out_img_base, img)
  154. dst_xml = os.path.join(out_xml_base, xml)
  155. if os.path.exists(src_img) and os.path.exists(src_xml):
  156. move(src_img, dst_img)
  157. move(src_xml, dst_xml)
  158. def extract_xml(infile):
  159. with open(infile,'r') as f: #解析xml中的name标签
  160. xml_text = f.read()
  161. root = ET.fromstring(xml_text)
  162. classes = []
  163. for obj in root.iter('object'):
  164. cls_ = obj.find('name').text
  165. classes.append(cls_)
  166. return classes
  167. if __name__ == '__main__':
  168. base = 'VOCData/Annotations/'
  169. Xmls=[]
  170. # Xmls = sorted([v for v in os.listdir(base) if v.endswith('.xml')])
  171. for v in os.listdir(base):
  172. if v.endswith('.xml'):
  173. Xmls.append(str(v))
  174. # iterable -- 可迭代对象。key -- 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。reverse -- 排序规则,reverse = True 降序 , reverse = False 升序(默认)。
  175. print('-[INFO] total:', len(Xmls))
  176. # print(Xmls)
  177. labels = {'person': 0}
  178. for xml in Xmls:
  179. infile = os.path.join(base, xml)
  180. # print(infile)
  181. cls_ = extract_xml(infile)
  182. for c in cls_:
  183. if not c in labels:
  184. print(infile, c)
  185. raise
  186. labels[c] += 1
  187. for k, v in labels.items():
  188. print('-[Count] {} total:{} per:{}'.format(k, v, v/len(Xmls)))

2.2按VOC格式划分数据集,train : val = 0.85 : 0.15生成标签label_list.txt

  1. """
  2. 按VOC格式划分数据集,train : val = 0.85 : 0.15
  3. 生成标签label_list.txt
  4. """
  5. import os
  6. import shutil
  7. import skimage.io as io
  8. from tqdm import tqdm
  9. from random import shuffle
  10. dataset = 'dataset/VOCData/'
  11. train_txt = os.path.join(dataset, 'train_val.txt')
  12. val_txt = os.path.join(dataset, 'val.txt')
  13. lbl_txt = os.path.join(dataset, 'label_list.txt')
  14. classes = [
  15. "person"
  16. ]
  17. with open(lbl_txt, 'w') as f:
  18. for l in classes:
  19. f.write(l+'\n')
  20. xml_base = 'Annotations'
  21. img_base = 'images'
  22. xmls = [v for v in os.listdir(os.path.join(dataset, xml_base)) if v.endswith('.xml')]
  23. shuffle(xmls)
  24. split = int(0.85 * len(xmls)) #划分训练集与验证集
  25. with open(train_txt, 'w') as f:
  26. for x in tqdm(xmls[:split]):
  27. m = x[:-4]+'.jpg'
  28. xml_path = os.path.join(xml_base, x)
  29. img_path = os.path.join(img_base, m)
  30. f.write('{} {}\n'.format(img_path, xml_path))
  31. with open(val_txt, 'w') as f:
  32. for x in tqdm(xmls[split:]):
  33. m = x[:-4]+'.jpg'
  34. xml_path = os.path.join(xml_base, x)
  35. img_path = os.path.join(img_base, m)
  36. f.write('{} {}\n'.format(img_path, xml_path))

2.2 自定义数据集进行训练

2.2.1 常见标注工具

  对于图像分类任务,我们只要将对应的图片是哪个类别划分好即可。对于检测任务和分割任务,目前比较流行的数据标注工具是labelimg、labelme,分别用于检测任务与分割任务的标注。

标注工具Github地址:

labelimg :https://github.com/tzutalin/labelImg

 labelme :https://github.com/wkentaro/labelme

PPOCRLabelhttps://github.com/PaddlePaddle/PaddleOCR

三、数据处理方法

3.1 图像的本质

  我们常见的图片其实分为两种,一种叫位图,另一种叫做矢量图。如下图所示:

位图的特点:

  •   由像素点定义一放大会糊

  •   文件体积较大

  •   色彩表现丰富逼真

矢量图的特点:

  •   超矢量定义

  •   放太不模糊

  •   文件体积较小

  •   表现力差

 3.2 数据增强手段

  1. import paddle
  2. import paddlex as pdx
  3. import numpy as np
  4. import paddle.nn as nn
  5. import paddle.nn.functional as F
  6. import PIL.Image as Image
  7. import cv2
  8. import os
  9. from random import shuffle
  10. from paddlex.det import transforms as T
  11. from PIL import Image, ImageFilter, ImageEnhance
  12. import matplotlib.pyplot as plt # plt 用于显示图片
  13. path='dataset/MaskCOCOData/JPEGImages/maksssksksss195.png'
  14. img = Image.open(path)
  15. plt.imshow(img) #根据数组绘制图像
  16. plt.show() #显示图像
  17. # 灰度图
  18. img = np.array(Image.open(path).convert('L'), 'f')
  19. plt.imshow(img,cmap="gray") #根据数组绘制图像
  20. plt.show() #显示图像
  21. #小Tips:jupyter notebook中plt显示灰度图异常,需要使用plt.imshow(gray,cmap="gray")方法正常显示灰度图。
  22. img = cv2.imread(path)
  23. plt.subplot(221)
  24. plt.imshow(img,cmap="gray")
  25. # matplotlib 按照RGB顺序展示原图
  26. plt.imshow(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
  27. plt.subplot(222)
  28. # cv2默认的GBR显示图
  29. def preprocess(dataType="train"):
  30. if dataType == "train":
  31. transform = T.Compose([
  32. T.MixupImage(mixup_epoch=10), #对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform
  33. # T.RandomExpand(), #随机扩张图像
  34. # T.RandomDistort(brightness_range=1.2, brightness_prob=0.3), #以一定的概率对图像进行随机像素内容变换
  35. # T.RandomCrop(), #随机裁剪图像
  36. # T.ResizeByShort(), #根据图像的短边调整图像大小
  37. T.Resize(target_size=608, interp='RANDOM'), #调整图像大小,[’NEAREST’, ‘LINEAR’, ‘CUBIC’, ‘AREA’, ‘LANCZOS4’, ‘RANDOM’]
  38. # T.RandomHorizontalFlip(), #以一定的概率对图像进行随机水平翻转
  39. T.Normalize() #对图像进行标准化
  40. ])
  41. return transform
  42. else:
  43. transform = T.Compose([
  44. T.Resize(target_size=608, interp='CUBIC'),
  45. T.Normalize()
  46. ])
  47. return transform
  48. train_transforms = preprocess(dataType="train")
  49. eval_transforms = preprocess(dataType="eval")
  50. # 定义训练和验证所用的数据集
  51. # API地址:https://paddlex.readthedocs.io/zh_CN/develop/data/format/detection.html?highlight=paddlex.det
  52. train_dataset = pdx.datasets.VOCDetection(
  53. data_dir='./dataset/MaskVOCData',
  54. file_list='./dataset/MaskVOCData/train_list.txt',
  55. label_list='./dataset/MaskVOCData/label_list.txt',
  56. transforms=train_transforms,
  57. shuffle=True)
  58. eval_dataset = pdx.datasets.VOCDetection(
  59. data_dir='./dataset/MaskVOCData',
  60. file_list='./dataset/MaskVOCData/val_list.txt',
  61. label_list='./dataset/MaskVOCData/label_list.txt',
  62. transforms=eval_transforms)
  63. plt.imshow(img)
  64. plt.subplot(223)
  65. # 32*32的缩略图
  66. plt.imshow(cv2.resize(img, (32, 32)))
  67. #图像处理示例 目标视野里比较多重叠,或者有点模糊的适用
  68. path='dataset/MaskCOCOData/JPEGImages/maksssksksss443.png'
  69. img = Image.open(path)
  70. plt.imshow(img)
  71. plt.show()
  72. #锐化
  73. img = img.filter(ImageFilter.SHARPEN)
  74. img = img.filter(ImageFilter.SHARPEN)
  75. plt.imshow(img)
  76. plt.show()
  77. #亮度变换
  78. bright_enhancer = ImageEnhance.Brightness(img) # 传入调整系数亮度
  79. img = bright_enhancer.enhance(1.6)
  80. plt.imshow(img)
  81. plt.show()
  82. #提高对比度
  83. contrast_enhancer = ImageEnhance.Contrast(img) # 传入调整系数对比度
  84. img = contrast_enhancer.enhance(1.9)
  85. plt.imshow(img)
  86. plt.show()

 3.3 为什么做数据增强

是因为很多深度学习的模型复杂度太高了,且在数据量少的情况下,比较容易造成过拟合(通俗来说就是训练的这个模型它太沉浸在这个训练样本当中的一些特质上面了),表现为的这个模型呢受到了很多无关因素的影响。 所得出的结果就是在没有看到过的样本上对它做出预测呢就表现的不太好。

四、模型训练与评估

  1. import matplotlib
  2. matplotlib.use('Agg')
  3. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  4. %matplotlib inline
  5. import warnings
  6. warnings.filterwarnings("ignore")
  7. #num_classes有些模型需要加1 比如faster_rcnn
  8. num_classes = len(train_dataset.labels)
  9. model = pdx.det.PPYOLO(num_classes=num_classes, )
  10. model.train(
  11. num_epochs=70,
  12. train_dataset=train_dataset,
  13. train_batch_size=16,
  14. eval_dataset=eval_dataset,
  15. learning_rate=3e-5,
  16. warmup_steps=90,
  17. warmup_start_lr=0.0,
  18. save_interval_epochs=7,
  19. lr_decay_epochs=[42, 70],
  20. save_dir='output/PPYOLO',
  21. use_vdl=True)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/718435
推荐阅读
相关标签
  

闽ICP备14008679号