当前位置:   article > 正文

mask rcnn训练自己的数据_mask rcnn训练自己数据

mask rcnn训练自己数据

主流的Mask RCNN主要有两个,一个是matterport的keras版本,另一个是facebookresearch的pytorch版本。

pytorch版本的编译报错问题较多,以下采用tensorflow-gpu版本。

检查gpu算力,低算力下高版本tf可能只调用cpu

参考:tensorflow对应CUDA版本

验证gpu

  1. import tensorflow as tf
  2. tf.test.is_built_with_cuda()#验证tf是否调用GPU
  3. tf.config.list_physical_devices('GPU')#查看gpu列表

下载requirements,tensorflow,cuda,mask_rcnn_coco.h5权重

python setup.py install

划分数据集

data

——cv2_mask(掩膜)

——jason(jason)

——labelme_json(json文件夹)

——pic(原图)

其中labelme新版本在json文件夹生成info.yaml的方法

~\envs\xx\Lib\site-packages\labelme\cli\json_to_dataset.py

  1. f.write(lbl_name + '\n')
  2. # 添加部分
  3. logger.warning('info.yaml is being replaced by label_names.txt')
  4. info = dict(label_names=label_names)
  5. with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
  6. yaml.safe_dump(info, f, default_flow_style=False)
  7. logger.info('Saved to: {}'.format(out_dir))

训练代码参考:修改train_shapes.ipynb

其中

ROOT_DIR = os.path.abspath("../../")#根目录地址
  1. #显存设置
  2. GPU_COUNT = 2
  3. IMAGES_PER_GPU = 1
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")#权重文件地址
NUM_CLASSES = 1 + 80  # 背景加类别
  1. #图片像素大小(用自己数据中最低的分辨率)
  2. IMAGE_MIN_DIM = 768
  3. IMAGE_MAX_DIM = 1024
  1. #训练次数
  2. model.train(dataset_train, dataset_val,
  3. learning_rate=config.LEARNING_RATE,
  4. epochs=300,
  5. layers='heads')
self.add_class("shapes", 1, "box")#逐个添加每个类别
  1. #逐个添加每个类别
  2. if labels[i].find("box")!=-1:
  3. labels_form.append("box")
dataset_root_path=os.path.join(ROOT_DIR, "C:\\Users\\LARA\\Desktop\\Mask_RCNN-master\\train_data_test1")#数据集目录
  1. #图片尺寸
  2. width = 1024
  3. height = 768
  4. dataset_train.load_shapes(count, 768, 1024, img_floder, mask_floder, imglist,dataset_root_path)
  5. dataset_val.load_shapes(count, 768, 1024, img_floder, mask_floder, imglist,dataset_root_path)

改写train.py中的init_with=““last””,下次训练时会接着上一次的epoch继续训练。 

运行问题

1>No module named ‘xxx‘

下载对应安装包

2>load() missing 1 required positional argument: ‘Loader‘

降低yaml版本到5.4.1

3>‘str‘ object has no attribute ‘decode‘

报错位置修改decode('utf-8')为encode('utf-8').decode('utf-8')

降级h5py到3以下版本

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/447566
推荐阅读
相关标签
  

闽ICP备14008679号