当前位置:   article > 正文

Pytorch---使用Pytorch实现U-Net进行语义分割_u-net语义分割代码

u-net语义分割代码

一、代码中的数据集可以通过以下链接获取

百度网盘提取码:f1j7

二、代码运行环境

Pytorch-gpu==1.10.1
Python==3.8

三、数据集处理代码如下所示

import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    """Paired image/mask dataset for binary (2-class) semantic segmentation.

    Each item is an ``(image_tensor, label_tensor)`` pair: the image is an
    RGB tensor produced by ``transform``, and the label is a ``LongTensor``
    of 0/1 class indices with all size-1 dimensions squeezed out.
    """

    def __init__(self, image_paths, mask_paths, transform):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        # Load and transform the input image, forcing 3 RGB channels.
        image = Image.open(self.image_paths[index]).convert('RGB')
        img_tensor = self.transform(image)

        # Load the matte and binarize it: any non-zero value is foreground.
        # NOTE(review): the same `transform` (typically a bilinear Resize)
        # is applied to the mask, which blurs edges before binarization —
        # NEAREST interpolation for labels would be crisper; confirm intent.
        mask = Image.open(self.mask_paths[index])
        label_tensor = self.transform(mask)
        label_tensor[label_tensor > 0] = 1
        label_tensor = label_tensor.squeeze().type(torch.long)

        return img_tensor, label_tensor

    def __len__(self):
        # Length follows the mask list; the caller builds both lists in
        # lockstep, so they are expected to be the same size.
        return len(self.mask_paths)


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build train/test DataLoaders for the portrait-matting dataset.

    The dataset root is expected to contain ``training`` and ``testing``
    sub-directories, each holding source images ``<id>.png`` alongside
    their mattes ``<id>_matte.png``.

    Args:
        dataset_path: Root directory of the dataset.  Defaults to the
            original hard-coded path for backward compatibility.
        batch_size: Mini-batch size used by both loaders.

    Returns:
        (train_dl, test_dl): a shuffled training loader and a sequential
        test loader.
    """
    train_dir = os.path.join(dataset_path, 'training')
    test_dir = os.path.join(dataset_path, 'testing')

    def paired_paths(directory):
        # Keep only mattes whose source image '<id>.png' is also present,
        # then return (image_paths, label_paths) in matching order.
        # A set makes the membership test O(1) instead of scanning the
        # file list once per matte.
        names = os.listdir(directory)
        name_set = set(names)
        mattes = [n for n in names
                  if 'matte' in n and n.split('_')[0] + '.png' in name_set]
        images = [os.path.join(directory, n.split('_')[0] + '.png') for n in mattes]
        labels = [os.path.join(directory, n) for n in mattes]
        return images, labels

    train_image_paths, train_label_paths = paired_paths(train_dir)
    test_image_paths, test_label_paths = paired_paths(test_dir)

    # Images and masks share the same resize; masks are binarized later in
    # MaskDataset, so ToTensor's value scaling does not matter for labels.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl


if __name__ == '__main__':
    # Smoke-test the loaders: overlay one sample's binary mask in red.
    train_loader, test_loader = load_data()
    batch_images, batch_labels = next(iter(train_loader))

    sample = 5
    image = batch_images[sample]
    mask = batch_labels[sample].unsqueeze(dim=0)

    overlay = draw_segmentation_masks(image=(image * 255).to(torch.uint8),
                                      masks=mask.to(torch.bool),
                                      alpha=0.6, colors=['red'])
    # draw_segmentation_masks returns CHW; matplotlib wants HWC.
    plt.imshow(overlay.permute(1, 2, 0).numpy())
    plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85

四、模型的构建代码如下所示

from torch import nn
import torch


class DownSample(nn.Module):
    """U-Net encoder stage: an optional 2x max-pool followed by two
    3x3 conv + ReLU layers mapping in_channels -> out_channels."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        # Halve the spatial resolution first; skipped at the network input
        # (the first encoder stage) and in the final refinement stage.
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)


class UpSample(nn.Module):
    """U-Net decoder stage: two 3x3 convs fuse a concatenated skip
    connection (2*channels -> channels), then a stride-2 transposed conv
    doubles the spatial size and halves the channel count."""

    def __init__(self, channels):
        super(UpSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, kernel_size=3,
                               stride=2, output_padding=1, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        # Fuse the skip features, then upsample to the next decoder level.
        return self.up_conv(self.conv_relu(x))


class UnetModel(nn.Module):
    """U-Net for 2-class segmentation: a 5-stage encoder
    (3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels), a symmetric decoder
    with skip connections, and a final 1x1 conv emitting per-pixel logits
    for background/foreground."""

    def __init__(self):
        super(UnetModel, self).__init__()
        # Encoder: channels double at each level.
        self.down_1 = DownSample(in_channels=3, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)

        # Bottleneck upsampling back towards the decoder.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        # Decoder stages; each consumes a skip concatenation.
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)

        # Final refinement (no pooling) and the 1x1 classifier head.
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        self.last = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    def forward(self, x):
        # Encoder path; keep each stage's output for the skip connections.
        c1 = self.down_1(x, is_pool=False)
        c2 = self.down_2(c1)
        c3 = self.down_3(c2)
        c4 = self.down_4(c3)
        c5 = self.down_5(c4)

        # Decoder path: upsample, concatenate the matching encoder
        # features along the channel axis, then refine.
        out = self.up(c5)
        out = self.up_1(torch.cat([c4, out], dim=1))
        out = self.up_2(torch.cat([c3, out], dim=1))
        out = self.up_3(torch.cat([c2, out], dim=1))
        out = torch.cat([c1, out], dim=1)

        out = self.conv_2(out, is_pool=False)
        return self.last(out)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90

五、模型的训练代码如下所示

import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os

# Environment configuration: prefer GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loading.
train_dl, test_dl = load_data()

# Model construction.
model = UnetModel()
model = model.to(device=device)

# Training configuration: per-pixel cross-entropy on raw logits, Adam,
# and a step LR schedule (x0.7 every 5 epochs).
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)

# Training loop.
for epoch in range(100):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    # Accumulate scalar losses in a plain list instead of torch.cat-ing a
    # growing device tensor every step.
    train_losses = []
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_tqdm.set_postfix({'train loss': sum(train_losses) / len(train_losses)})
    train_tqdm.close()

    lr_scheduler.step()

    # Evaluation: no gradients, eval mode (matters if the model ever gains
    # BatchNorm/Dropout layers).
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_losses = []
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # BUG FIX: CrossEntropyLoss expects raw logits — the original
            # applied softmax first, which double-normalizes and makes the
            # test loss incomparable with the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_losses.append(test_loss.item())
            test_tqdm.set_postfix({'test loss': sum(test_losses) / len(test_losses)})
        test_tqdm.close()

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

六、模型的预测代码如下所示

import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel

# Data loading.
train_dl, test_dl = load_data()

# Model loading: restore trained weights on CPU.
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Switch to inference mode (no-op for pure Conv/ReLU nets, but correct
# practice and future-proof against BatchNorm/Dropout additions).
model.eval()

# Prediction: overlay the predicted foreground mask on one test image.
images, labels = next(iter(test_dl))
index = 1
with torch.no_grad():
    pred = model(images)
    # Per-pixel class index: 0 = background, 1 = foreground.
    pred = torch.argmax(input=pred, dim=1)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    # CHW -> HWC for matplotlib.
    plt.imshow(result.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

七、代码的运行结果如下所示

(原文此处为运行结果截图,图片在转载过程中丢失)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/87845
推荐阅读
相关标签
  

闽ICP备14008679号