当前位置:   article > 正文

Deeplab训练自己的数据集_tf_initial_checkpoint

tf_initial_checkpoint

1.制作自己的数据集

1.1 用labelme生成json文件

lebelme安装:

# Ubuntu 14.04 / Ubuntu 16.04
# Python2
# sudo apt-get install python-qt4  # PyQt4
sudo apt-get install python-pyqt5  # PyQt5
sudo pip install labelme
# Python3
sudo apt-get install python3-pyqt5  # PyQt5
sudo pip3 install labelme
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

1.2 把json文件生成.png文件:

  1. 调整label.png为灰度图
  2. 批量转换成_gt.png
  3. 提取出所有的_gt.png文件

参考:https://note.youdao.com/ynoteshare1/index.html?id=032620eac64634508cd4f9e65be4617c&type=note#/

要把下面这两句话注释掉,因为会报错,生成的可视化分割图片也没什么软用

lbl_viz = utils.draw.draw_label(lbl, img, captions)
 ...
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir,'{}_viz.png'.format(filename)))
  • 1
  • 2
  • 3

1.3 数据集目录

 + image
 + mask
 + index 
    - train.txt 
    - trainval.txt 
    - val.txt 
 + tfrecord 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • iamge中存放所有的输入图片,包括训练、测试、验证集的图片
  • mask中存放所有的labeled图片,,和输入图片(即iamge)是一一对应的

PS:这里需要注意一个点,image和mask的文件名应该一致,且全部小写,上一步产生的iamge后缀大写,用 rename ‘y/A-Z/a-z/’ * 修改,,mask文件名是000000_gt.png,用 rename ‘s/_gt.png/.png/’ ./* 修改,这样image和mska的文件名就能对应。对应代码如下:

rename   's/\_gt.png/.png/' ./*  #修改后缀

rename 'y/A-Z/a-z/' *   #全部小写
  • 1
  • 2
  • 3
  • index

该目录下包含三个.txt文件:

train.txt:所有训练集的文件名称
trainval.txt:所有验证集的文件名称
val.txt:所有测试集的文件名称
生成这三个数据集的代码:

# -*- coding: utf-8 -*-
import os
import random

xmlfilepath = 'image'   #注意这里是image而非mask
txtsavepath = 'index'
total_xml = os.listdir(xmlfilepath)
num=len(total_xml)
list = range(num)
trainval = random.sample(list, num)  

os.chdir('PATH/index')   

ftrainval = open('train.txt', 'w')  

for i in list :
    name =total_xml[i][:-4] + '\n'
    ftrainval.write(name)
ftrainval.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

1.4 生成tfrecord

用生成voc数据集的build_voc2012_data.py来生成

  • image_folder :数据集中原输入数据的文件目录地址
  • semantic_segmentation_folder:数据集中标签的文件目录地址
  • list_folder : 将数据集分类成训练集、验证集等的指示目录文件目录
  • image_format : 输入图片数据的格式,我的数据集是jpg格式
  • output_dir:制作的TFRecord存放的目录地址(自己创建)
python ./datasets/build_voc2012_data.py \ --image_folder="/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/image" \ 
--semantic_segmentation_folder="/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/mask" \ 
--list_folder="/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/index" \ 
--image_format="jpg" \ 
--output_dir="/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/tfrecord"
  • 1
  • 2
  • 3
  • 4
  • 5

2.训练前代码准备

2.1. 修改segmentation_dataset.py

大约在line 100左右,添加如下代码(注意num_classes=num( label+1 ),+1是background类别,没有用ignore label(我也不知道是啥)):

_LAB_DATASET = DatasetDescriptor(
    splits_to_sizes={
        'train': 150,   # num of samples in images/training
        #'train_aug': 10582,
        'trainval': 250,
        'val': 100,
    },
    num_classes=4,   #label+1 (not use ignore label)
    ignore_label=255,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

大约在line 112,添加对应数据集的名称:

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'lab': _LAB_DATASET,
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.2. 修改train_utils.py

大约在line 109,exclude_list的设置修改,作用是在使用预训练权重时候,不加载该logit层:

  # Variables that will not be restored.
  #exclude_list = ['global_step','logits']
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.3. 数据不平衡问题

在train_utils.py的70行修改权重。但是我的数据集比较均衡,所以没用用到。

2.4. 修改train .py

    initialize_last_layer=False
    last_layers_contain_logits_only=True

  • 1
  • 2
  • 3

3.开始训练

官方给出的指令格式:

python deeplab/train.py \
    --logtostderr \
    --training_number_of_steps=90000 \
    --train_split="train" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --train_crop_size=769 \
    --train_crop_size=769 \
    --train_batch_size=1 \
    --dataset="cityscapes" \
    --tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
    --train_logdir=${PATH_TO_TRAIN_DIR} \
    --dataset_dir=${PATH_TO_DATASET}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

training_number_of_steps: 训练迭代次数
train_crop_size:训练图片的裁剪大小,我将这个设置为513
tf_initial_checkpoint:预训练的权重,使用CityScapes的预训练权重
train_logdir:训练产生的文件存放位置
train_batch_size:训练的batchsize,这里batchsize设置为4,如果想复现paper效果,建议设置8
dataset_dir:数据集的TFRecord文件
dataset:设置为在segmentation_dataset.py文件设置的数据集名称
如果显存小的话把fine_tune_batch_norm调成False
训练集150张图片,迭代30000次,验证集100张图片

下面是我自己的训练命令:

python ./train02.py \
    --logtostderr \
    --training_number_of_steps=30000 \
    --train_split="train" \
    --model_variant="xception_65" \
    --fine_tune_batch_norm=False \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --train_crop_size=513 \
    --train_crop_size=513 \
    --train_batch_size=4 \
    --dataset="lab" \
    --tf_initial_checkpoint='/home/zyr/deeplab/models/research/deeplab/backbone/deeplabv3_cityscapes_train/model.ckpt' \
    --train_logdir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/exp/train_on_train_set/train' \
    --dataset_dir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/tfrecord'

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

4.验证并可视化

4.1 eval

eval指令分析:

前面我们训练了一些模型,下面测试一下。

官方给出的验证指令格式为:

#From tensorflow/models/research/
python deeplab/eval.py \
    --logtostderr \
    --eval_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --eval_crop_size=1025 \
    --eval_crop_size=2049 \
    --dataset="cityscapes" \
    --checkpoint_dir=${PATH_TO_CHECKPOINT} \
    --eval_logdir=${PATH_TO_EVAL_DIR} \
    --dataset_dir=${PATH_TO_DATASET}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

这里参考调试指令参考local_test.sh,其中有几个比较关键的参数设置如下:

eval_crop_size:验证图片的裁剪大小
checkpoint_dir:预训练的checkpoint,这里设置的即是前面训练模型存储的地址
eval_logdir: 保存验证结果的目录,注意在开始的创建工程目录的时候就创建了
dataset_dir:数据集的地址,前面创建的TFRecords目录

eval实际调用指令

python ./eval.py \
    --logtostderr \
    --eval_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --eval_crop_size=1080 \
    --eval_crop_size=1920 \
    --dataset="lab" \
    --checkpoint_dir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/exp/train_on_train_set/train/' \
    --eval_logdir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/exp/train_on_train_set/val/' \
    --dataset_dir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/tfrecord/'

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

结果如下:

总共val了100张图片

在这里插入图片描述

可见miou能达到0.78

4.2.vis

vis指令分析

# From tensorflow/models/research/
python deeplab/vis.py \
    --logtostderr \
    --vis_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --vis_crop_size=1025 \
    --vis_crop_size=2049 \
    --dataset="cityscapes" \
    --colormap_type="cityscapes" \
    --checkpoint_dir=${PATH_TO_CHECKPOINT} \
    --vis_logdir=${PATH_TO_VIS_DIR} \
    --dataset_dir=${PATH_TO_DATASET}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

这里参考调试指令参考local_test.sh,其中有几个比较关键的参数设置如下:

vis_crop_size:图片的裁剪大小
checkpoint_dir:预训练的checkpoint,这里设置的即是前面训练模型存储的地址

vis_logdir: 保存可视化结果的目录

dataset_dir:数据集的地址,前面创建的TFRecords目录

vis实际调用指令

#!/usr/bin/env bash

python ./vis.py \
    --logtostderr \
    --vis_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --vis_crop_size=1080 \
    --vis_crop_size=1920 \
    --dataset="lab" \
    --colormap_type="pascal" \
    --checkpoint_dir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/exp/train_on_train_set/train/model.ckpt-30000' \
    --vis_logdir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/exp/train_on_train_set/vis/' \
    --dataset_dir='/home/zyr/deeplab/models/research/deeplab/datasets/my_data/segment/tfrecord/'

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

部分结果图:

在这里插入图片描述
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/123768
推荐阅读
相关标签
  

闽ICP备14008679号