
Getting Started with PyTorch: Instance Segmentation with Mask R-CNN via Transfer Learning (PyTorch Official Tutorial)


     This post fine-tunes a pre-trained Mask R-CNN model for object detection and instance segmentation, using the Penn-Fudan Database for Pedestrian Detection and Segmentation. The dataset contains 170 images with a total of 345 pedestrians and is similar in structure to PASCAL VOC. The main topic is how to train an instance segmentation model on your own dataset with the PyTorch framework.

Key points:
     1. How to define your own dataset
     2. Two approaches to transfer learning (fine-tune only the last layers of the model / replace the model's backbone)

Link to the official PyTorch tutorial

1. Preparing the dataset

     To define your own dataset, subclass torch.utils.data.Dataset and implement the __len__ and __getitem__ methods, where __getitem__ should return:

image: a PIL Image of size (H, W)
target: a dict containing the following fields:
        1. boxes (FloatTensor[N, 4]): the coordinates of the N bounding boxes in [x0, y0, x1, y1] format, ranging from 0 to W and 0 to H
        2. labels (Int64Tensor[N]): the label for each bounding box
        3. image_id (Int64Tensor[1]): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation
        4. area (Tensor[N]): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.
        5. iscrowd (UInt8Tensor[N]): instances with iscrowd=True will be ignored during evaluation.
     Note that if you want to use aspect ratio grouping during training (so that each batch only contains images with similar aspect ratios), it is also recommended to implement a get_height_and_width method that returns the height and width of an image; a minimal sketch follows below.
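
     A minimal sketch of such a method, assuming it is added to the PennFudanDataset class defined later (this helper is illustrative and not part of the original tutorial code); PIL only parses the file header here, so the full image is not decoded:

    def get_height_and_width(self, idx):
        # open the image lazily: PIL.Image.open only reads the header
        img_path = os.path.join(self.root, 'PNGImages', self.imgs[idx])
        with Image.open(img_path) as img:
            width, height = img.size  # PIL reports (width, height)
        return height, width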

     After downloading the Penn-Fudan dataset, take a look at its file structure. First, display one of the images together with its corresponding mask:

from PIL import Image
import os
img = Image.open('PennFudanPed/PNGImages/FudanPed00012.png')
img.show()
mask = Image.open('PennFudanPed/PedMasks/FudanPed00012_mask.png')
mask.putpalette([
    0, 0, 0, # black background
    255, 0, 0, # index 1 is red
    255, 255, 0, # index 2 is yellow
    255, 153, 0, # index 3 is orange
])
mask.show()

     Each image in the dataset has a corresponding segmentation mask, in which each color corresponds to one instance (a pedestrian). The code above displays one image and its recolored mask.

For the putpalette function used in the code above:
Image.putpalette(data, rawmode='RGB')

Parameters:
data – A palette sequence (either a list or a string).
rawmode – The raw mode of the palette.

     The image mode defines the type and pixel depth of an image. The most common standard modes are:

1 (1-bit pixels, black and white, stored with one pixel per byte)
L (8-bit pixels, black and white)
P (8-bit pixels, mapped to any other mode using a color palette)
RGB (3x8-bit pixels, true color)
RGBA (4x8-bit pixels, true color with transparency mask)

     Palette: the palette mode (P) uses a color palette to define the actual color for each pixel.
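
     As a quick check (an illustrative snippet, not from the tutorial), the FudanPed00012 mask used above is stored as a single-channel indexed image: each pixel holds an instance index rather than a color, which is why the raw file looks almost entirely black when opened directly.

import numpy as np
from PIL import Image

mask = Image.open('PennFudanPed/PedMasks/FudanPed00012_mask.png')
print(mask.mode)                   # a single-channel mode such as 'L' or 'P'
print(np.unique(np.array(mask)))   # [0 1 2]: background plus two pedestrian instances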

     The PennFudanDataset class below wraps the dataset:

from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # load all image and mask file names, sorted so that images and masks line up
        self.imgs = list(sorted(os.listdir(os.path.join(root, 'PNGImages'))))
        self.masks = list(sorted(os.listdir(os.path.join(root, 'PedMasks'))))

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, 'PNGImages', self.imgs[idx])
        mask_path = os.path.join(self.root, 'PedMasks', self.masks[idx])
        # Convert the image to RGB. The mask is NOT converted, because its background
        # is 0 and every other pixel value encodes one instance.
        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)
        # Convert the PIL mask to a numpy array, collect the instance ids and drop the background (0).
        mask = np.array(mask)
        obj_id = np.unique(mask)
        obj_id = obj_id[1:]

        # Split the color-encoded mask into a set of binary masks.
        # Taking FudanPed00012 as an example: it contains two pedestrians, and in
        # FudanPed00012_mask pixel value 0 is background, 1 is the first instance and
        # 2 is the second. These values are labels, not display colors, which is why
        # the mask file looks entirely black when viewed directly.
        # Here mask is a 559x536 array and np.unique gives [0, 1, 2]; after dropping
        # the background, obj_id = [1, 2].
        # obj_id[:, None, None] adds two new axes (None is np.newaxis), so broadcasting
        # yields masks of shape (2, 559, 536): in the first mask the pixels of
        # instance 1 are True and everything else is False, and likewise for instance 2.
        masks = mask == obj_id[:, None, None]

        # compute the bounding box coordinates for each binary mask
        num_objs = len(obj_id)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # the dataset has only one class (pedestrian)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

# quick sanity check
# dataset = PennFudanDataset('PennFudanPed', None)
# print(dataset[0])
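
     Expanding the commented sanity check above, this is one way to inspect the shape of a single sample, assuming the dataset has been extracted to PennFudanPed/ under the working directory (an illustrative check, not part of the tutorial):

dataset = PennFudanDataset('PennFudanPed', None)
img, target = dataset[0]

print(img.size)                       # PIL image size, (width, height)
for key, value in target.items():
    print(key, tuple(value.shape))    # e.g. boxes (N, 4), masks (N, H, W), labels (N,)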

2. Fine-tuning the Mask R-CNN model

     There are two ways to fine-tune a model: fine-tune only the final output layers of a pre-trained model, or swap in a different backbone. The code below illustrates both. Since the goal of this post is an instance segmentation model, we use Mask R-CNN, and because the dataset is small we take the first approach and replace only the model's final prediction heads.

# 1. start from a pre-trained model, and just finetune the last layer.
import torchvision
import torch
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


# load a pre-trained model and replace its box predictor head
# model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True)
# numclasses = 2
# in_features = model.roi_heads.box_predictor.cls_score.in_features
# model.roi_heads.box_predictor = FastRCNNPredictor(in_features, numclasses)


# 2. Modifying the model to add a different backbone
# import torchvision
# from torchvision.models.detection import FasterRCNN
# from torchvision.models.detection.rpn import AnchorGenerator
#
# backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# backbone.out_channels = 1280
# anchor_generator = AnchorGenerator(sizes=((32,64,128,256,512),), aspect_ratios=((0.5,1.0,2.0),))
#
# roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], output_size=7, sampling_ratio=2)
# model = FasterRCNN(
#         backbone,
#         num_classes=2,
#         rpn_anchor_generator=anchor_generator,
#         box_roi_pool=roi_pooler
# )


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model
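
     A quick smoke test of the resulting model (an illustrative sketch reusing the imports above, not part of the tutorial): in eval mode, torchvision detection models take a list of image tensors and return one dict per image containing boxes, labels, scores and masks.

model = get_model_instance_segmentation(num_classes=2)
model.eval()

# two dummy RGB images of different sizes; detection models accept a list of tensors
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    outputs = model(images)

print(outputs[0].keys())  # dict_keys(['boxes', 'labels', 'scores', 'masks'])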

3. Training and evaluating the model

     PyTorch provides a number of helper scripts for training and evaluating detection models. Here we use engine.py, utils.py and transforms.py from references/detection in the torchvision repository (download link).

     First, define the image transform function:

import transforms as T

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

Next, define the main training script:

from Penn_Fudan_dataset import PennFudanDataset
from Mask_rcnn_Model import get_model_instance_segmentation
import torch
import utils
from engine import train_one_epoch, evaluate
import transforms as T  # transforms.py from references/detection: its transforms take (img, target) pairs

# data augmentation / conversion to tensors

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 2
    dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
    dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))

    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices[:-50])
    dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

    # collate_fn from references/detection/utils.py keeps each batch as a tuple of samples
    # instead of trying to stack images of different sizes into a single tensor
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)

    model = get_model_instance_segmentation(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.0005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    num_epochs = 10
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        lr_scheduler.step()
        evaluate(model, data_loader_test, device)
 
if __name__ == '__main__':
    main()
    print("That's it!")

4. Open issue (to be deleted once resolved)

      While following this tutorial, the official recommendation is to run the code in Colab. After setting the project up in PyCharm I hit the error below and could not resolve it for a long time; it looked like a conflict between my code and the downloaded engine.py / utils.py, and any pointers are welcome. Judging from the final TypeError ("__call__() takes 2 positional arguments but 3 were given", raised when the dataset calls its transforms with both the image and the target), the most likely cause is importing torchvision.transforms as T instead of the transforms.py from references/detection: torchvision's Compose accepts only the image, whereas the reference transforms accept (img, target) pairs, which is what the import in the training script above now uses. The original traceback follows:

Traceback (most recent call last):
  File "E:/Coding/pycharm/3 Penn-Fudan Database for Pedestrian Detection and Segmentation/train_model.py", line 74, in <module>
    main()
  File "E:/Coding/pycharm/3 Penn-Fudan Database for Pedestrian Detection and Segmentation/train_model.py", line 49, in main
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  File "E:\Coding\pycharm\3 Penn-Fudan Database for Pedestrian Detection and Segmentation\engine.py", line 26, in train_one_epoch
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
  File "E:\Coding\pycharm\3 Penn-Fudan Database for Pedestrian Detection and Segmentation\utils.py", line 209, in log_every
    for obj in iterable:
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 363, in __next__
    data = self._next_data()
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 989, in _next_data
    return self._process_data(data)
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 1014, in _process_data
    data.reraise()
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\_utils.py", line 395, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "F:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "E:\Coding\pycharm\3 Penn-Fudan Database for Pedestrian Detection and Segmentation\Penn_Fudan_dataset.py", line 79, in __getitem__
    img, target = self.transfroms(img, target)
TypeError: __call__() takes 2 positional arguments but 3 were given
