
Training EfficientDet on Your Own Dataset


https://github.com/toandaominh1997/EfficientDet.Pytorch

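Problem 1: torch.nn.modules.module.ModuleAttributeError: 'EfficientDet' object has no attribute 'module'

The training script reaches the model through .module, an attribute that only exists when the model is wrapped in nn.DataParallel (or DistributedDataParallel). If your model is not wrapped, change

    model.module.is_training = True
    model.module.freeze_bn()

to:

    model.is_training = True
    model.freeze_bn()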
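If the same script has to run both with and without a DataParallel wrapper, a small helper saves editing the code each time. This is a sketch: unwrap and ToyDet are hypothetical names, and ToyDet only stands in for the repository's EfficientDet class, whose freeze_bn is assumed here to put the BatchNorm layers into eval mode.

```python
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying model whether or not it is wrapped for multi-GPU training."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


class ToyDet(nn.Module):
    """Hypothetical stand-in for EfficientDet, with the two members the script touches."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(4)
        self.is_training = False

    def freeze_bn(self):
        # Assumed behavior: BatchNorm layers switch to eval mode so running stats stop updating.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


model = ToyDet()          # could equally be nn.DataParallel(ToyDet())
core = unwrap(model)      # works in both cases, no .module edits needed
core.is_training = True
core.freeze_bn()
```

Note that a later call to model.train() puts the BatchNorm layers back into training mode, so freeze_bn has to be called after it.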

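Problem 2: No boxes to NMS

I read the discussion of this problem in the repository's issues: it is widespread, and the usual advice is to switch to a different implementation.

After repeated attempts I can confirm it does not work. Still, I describe below the changes I made: with them the code trains normally, although the predictions are useless. At least the data loading is adapted to a custom dataset. A brief summary follows.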

1. Data preparation

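The project supports two data formats, VOC and COCO, so all we need to do is convert our own data into one of them and lay out the directories the way that format's loader expects.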
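For the COCO route, the heart of the conversion is an annotation JSON with images, annotations and categories. A minimal sketch, assuming your labels are already parsed into per-image lists of (xmin, ymin, xmax, ymax, class_index) boxes; the record layout and the to_coco name are hypothetical, so check the repository's dataset loader for the exact file names and folders it reads.

```python
import json


def to_coco(records, class_names):
    """Build a minimal COCO-style annotation dict.

    records (hypothetical layout): [{"file_name", "width", "height",
    "boxes": [(xmin, ymin, xmax, ymax, class_index), ...]}, ...]
    """
    images, annotations = [], []
    ann_id = 1
    for img_id, rec in enumerate(records, start=1):
        images.append({"id": img_id, "file_name": rec["file_name"],
                       "width": rec["width"], "height": rec["height"]})
        for xmin, ymin, xmax, ymax, cls in rec["boxes"]:
            w, h = xmax - xmin, ymax - ymin
            annotations.append({
                "id": ann_id,
                "image_id": img_id,
                "category_id": cls + 1,      # COCO category ids conventionally start at 1
                "bbox": [xmin, ymin, w, h],  # COCO boxes are [x, y, width, height]
                "area": w * h,
                "iscrowd": 0,
            })
            ann_id += 1
    categories = [{"id": i + 1, "name": n} for i, n in enumerate(class_names)]
    return {"images": images, "annotations": annotations, "categories": categories}


if __name__ == "__main__":
    demo = [{"file_name": "000000.png", "width": 100, "height": 50,
             "boxes": [(10, 5, 30, 45, 1), (40, 5, 60, 45, 9)]}]
    print(json.dumps(to_coco(demo, [str(d) for d in range(10)]), indent=2))
```

Converting to [x, y, width, height] is the step most often gotten wrong: VOC-style corner coordinates written directly into "bbox" produce boxes that are silently too large.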

 

2. Training

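--dataset: the data format; I chose COCO

--dataset_root: root directory of the dataset

--network: the network architecture; there are eight variants, d0 through d7, to choose from

--num_class: the number of classes; don't forget to set it to match your own dataset

--device: list of GPU devices; --gpu: the GPU to use

--workers: number of data-loading worker processes; on Windows, values greater than 0 often cause errors, so use 0 there if training fails to start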
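Background on the Windows issue: PyTorch's data-loading documentation notes that on Windows worker processes are started with spawn, which re-imports the main script, so any code that creates a DataLoader with num_workers > 0 must sit under an if __name__ == "__main__": guard. The sketch below is a workaround, not the repository's code: make_loader is a hypothetical helper that simply falls back to single-process loading on Windows.

```python
import platform

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size, workers):
    """Build a DataLoader, forcing num_workers=0 on Windows as a conservative workaround."""
    if platform.system() == "Windows":
        workers = 0
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)


if __name__ == "__main__":  # entry-point guard: required on Windows whenever num_workers > 0
    toy = TensorDataset(torch.randn(10, 3), torch.arange(10))
    loader = make_loader(toy, batch_size=4, workers=0)
    print(sum(len(y) for _, y in loader))
```

If you would rather keep parallel loading on Windows, drop the fallback and make sure the training entry point is wrapped in the guard.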

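With these changes training runs, but when I tested the trained model it detected nothing, so next I tried a different implementation: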

https://github.com/signatrix/efficientdet

1. Prepare your own data in the same format as the COCO dataset, as described above.

2. Make a few simple edits to the training script.

Mostly this means adjusting the basic training parameters; model checkpoints and the like are saved automatically.

For the data path, just pass the dataset root directory.

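3. Results: mediocre in practice. I didn't put much effort into it, so it still needs careful tuning on your part. I tested it on an Alibaba Tianchi competition, where the score was only about 0.6; the data is available from the competition link.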

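Zero-Based Introduction to CV: Street View Character Recognition (零基础入门CV - 街景字符编码识别), Tianchi, Alibaba Cloud. This newcomer competition is the second in the zero-basics series jointly organized by Datawhale and Tianchi: an introductory CV competition on street-view character recognition. Participants must recognize characters in real-world scenes, a typical character recognition problem.
https://tianchi.aliyun.com/competition/entrance/531795/introduction?spm=5176.12281925.0.0.26087137F5M0lm

4. I modified the prediction code; predict.py is below (it produces the submission format required by that Alibaba competition):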

    # coding: utf-8
    import os
    import argparse

    import cv2
    import numpy as np
    import pandas as pd
    import torch

    # COCO_CLASSES in src/config.py should list your own class names
    # (for this competition, the digit labels) in training order.
    from src.config import COCO_CLASSES


    def get_args():
        parser = argparse.ArgumentParser(
            "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH")
        # Must match the common_size used by the Resizer in src/dataset.py during training
        parser.add_argument("--image_size", type=int, default=256, help="The common width and height for all images")
        parser.add_argument("--data_path", type=str, default="D:/csdn/tc/work2/data/", help="the root folder of dataset")
        parser.add_argument("--cls_threshold", type=float, default=0.4)
        parser.add_argument("--nms_threshold", type=float, default=0.5)
        parser.add_argument("--pretrained_model", type=str, default="trained_models/signatrix_efficientdet_coco.pth")
        parser.add_argument("--output", type=str, default="predictions")
        args = parser.parse_args()
        return args


    class Resizer():
        """Scale the longer side to common_size, zero-pad to a square, return (tensor, scale)."""
        def __call__(self, image, common_size=512):
            height, width, _ = image.shape
            if height > width:
                scale = common_size / height
                resized_height = common_size
                resized_width = int(width * scale)
            else:
                scale = common_size / width
                resized_height = int(height * scale)
                resized_width = common_size
            image = cv2.resize(image, (resized_width, resized_height))
            # Padding goes on the right and bottom, so boxes map back by dividing by scale
            new_image = np.zeros((common_size, common_size, 3))
            new_image[0:resized_height, 0:resized_width] = image
            return torch.from_numpy(new_image), scale


    class Normalizer():
        """Normalize an RGB image in [0, 1] with the ImageNet mean and std."""
        def __init__(self):
            self.mean = np.array([[[0.485, 0.456, 0.406]]])
            self.std = np.array([[[0.229, 0.224, 0.225]]])

        def __call__(self, image):
            return ((image.astype(np.float32) - self.mean) / self.std)


    if __name__ == "__main__":
        opt = get_args()
        # The checkpoint is a pickled nn.DataParallel model; on PyTorch >= 2.6,
        # pass weights_only=False to torch.load so it can be unpickled.
        model = torch.load(opt.pretrained_model)
        model = model.module if hasattr(model, "module") else model
        model.cuda()
        model.eval()  # inference mode: BatchNorm uses its running statistics

        image_path = os.path.join(opt.data_path, "mchar_test_a")
        piclist = sorted(os.listdir(image_path))
        common_size = opt.image_size
        normalizer, resizer = Normalizer(), Resizer()
        rows = []
        for index, pic_name in enumerate(piclist, start=1):
            if index % 1000 == 0:
                print(f"{index}/{len(piclist)}")
            img = cv2.imread(os.path.join(image_path, pic_name))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.
            img = normalizer(img)
            img, scale = resizer(img, common_size=common_size)
            with torch.no_grad():
                scores, labels, boxes = model(img.cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
            boxes /= scale  # back to original-image coordinates
            dts = []
            for box_id in range(boxes.shape[0]):
                if float(scores[box_id]) < opt.cls_threshold:
                    continue  # skip low-confidence detections
                pred_label = int(labels[box_id])
                dts.append({'class': COCO_CLASSES[pred_label], 'xmin': boxes[box_id, 0].item()})
            # Read the digits from left to right by sorting on xmin
            ss = ''.join(e['class'] for e in sorted(dts, key=lambda e: e['xmin']))
            rows.append({"file_name": pic_name, "file_code": ss})
        # DataFrame.append was removed in pandas 2.0, so build the frame once at the end
        df = pd.DataFrame(rows, columns=['file_name', 'file_code'])
        df.to_csv("sub2.csv", index=False)
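The competition-specific part of predict.py is the post-processing: undo the Resizer's scale, keep detections at or above cls_threshold, and read the digit classes from left to right by xmin. Here it is pulled out into a hypothetical standalone function, decode_code, so it can be checked without a GPU or checkpoint. Because the Resizer pads only on the right and bottom, dividing by scale is enough to map boxes back.

```python
import torch


def decode_code(boxes, scores, labels, scale, class_names, cls_threshold=0.4):
    """Turn one image's detections into the left-to-right character string.

    boxes: (N, 4) tensor in resized-image coordinates (x1, y1, x2, y2);
    scores, labels: (N,) tensors; scale: the factor returned by the Resizer.
    """
    boxes = boxes / scale                  # undo resizing; padding needs no offset
    keep = scores >= cls_threshold         # same cut-off as the loop in predict.py
    xmins = boxes[keep, 0].tolist()
    names = [class_names[int(l)] for l in labels[keep]]
    ordered = sorted(zip(xmins, names), key=lambda p: p[0])
    return "".join(name for _, name in ordered)
```

Sorting by xmin assumes the characters are laid out horizontally, which holds for house-number images; vertical or multi-line text would need a different ordering rule.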

Note: I didn't have time to tune this network properly. Below is the code I used (data included):

Link: https://pan.baidu.com/s/1vR8vnCT2yQH7unwYREVn8Q
Extraction code: kz3d

This article was contributed by a user; if you repost it, please credit the source: https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/262438