
PyTorch implementation of an FCN for semantic segmentation (training code + prediction code)


1. The FCN network

An FCN roughly has the structure shown in the figure below:

The input image is shrunk step by step by an "encoder" network and then enlarged step by step again by a "decoder" network. The result is a set of coloured regions (called a mask), where each colour represents a different class.

A very important part of the FCN: transposed convolution (deconvolution)

Convolution layers lower the resolution of the image while extracting features; transposed convolution is the structure that enlarges the feature maps back to a higher resolution.

In semantic segmentation, the kernels of the transposed convolutions must be explicitly initialised (this is very important); the usual method is bilinear interpolation.
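To make the idea concrete, here is a small sketch (illustrative only, not part of the repo) that computes the same stride-2, 3x3 bilinear kernel that the conv() helper in the model code below builds; each weight is the product of two triangular 1-D profiles (in the model code this kernel is additionally divided by the number of input channels and copied across every channel pair):

import torch

kernel_size, stride = 3, 2
centre = stride - 1  # centre position used for odd kernel sizes
w = torch.zeros(kernel_size, kernel_size)
for y in range(kernel_size):
    for x in range(kernel_size):
        w[y, x] = (1 - abs((x - centre) / stride)) * (1 - abs((y - centre) / stride))
print(w)
# tensor([[0.2500, 0.5000, 0.2500],
#         [0.5000, 1.0000, 0.5000],
#         [0.2500, 0.5000, 0.2500]])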

Notes on PyTorch's transposed-convolution module (nn.ConvTranspose2d):

The relationship between the input size and the output size of a transposed convolution is:
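As given in the PyTorch documentation for nn.ConvTranspose2d:

H_out = (H_in - 1) × stride - 2 × padding + dilation × (kernel_size - 1) + output_padding + 1

For example, with kernel_size=3, stride=2, padding=1, output_padding=1 and dilation=1 (the settings used by conv(..., transposed=True) in the model code below), an input of height 16 gives (16 - 1) × 2 - 2 + 2 + 1 + 1 = 32, i.e. the spatial size is exactly doubled.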

 

2. Dataset used by the code: Cityscapes

The Cityscapes dataset comes as several packages; I use the two shown below (gtFine contains the labels, the other package contains the original images):

The number of classes is 20.

3. Training code

3.1 Data loading code:

import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset

# Labels: -1 license plate, 0 unlabeled, 1 ego vehicle, 2 rectification border, 3 out of roi, 4 static, 5 dynamic, 6 ground, 7 road, 8 sidewalk, 9 parking, 10 rail track, 11 building, 12 wall, 13 fence, 14 guard rail, 15 bridge, 16 tunnel, 17 pole, 18 polegroup, 19 traffic light, 20 traffic sign, 21 vegetation, 22 terrain, 23 sky, 24 person, 25 rider, 26 car, 27 truck, 28 bus, 29 caravan, 30 trailer, 31 train, 32 motorcycle, 33 bicycle
num_classes = 20
full_to_train = {-1: 19, 0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 0, 8: 1, 9: 19, 10: 19, 11: 2, 12: 3, 13: 4, 14: 19, 15: 19, 16: 19, 17: 5, 18: 19, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 19, 30: 19, 31: 16, 32: 17, 33: 18}
train_to_full = {0: 7, 1: 8, 2: 11, 3: 12, 4: 13, 5: 17, 6: 19, 7: 20, 8: 21, 9: 22, 10: 23, 11: 24, 12: 25, 13: 26, 14: 27, 15: 28, 16: 31, 17: 32, 18: 33, 19: 0}
full_to_colour = {0: (0, 0, 0), 7: (128, 64, 128), 8: (244, 35, 232), 11: (70, 70, 70), 12: (102, 102, 156), 13: (190, 153, 153), 17: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0), 21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180), 24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70), 28: (0, 60, 100), 31: (0, 80, 100), 32: (0, 0, 230), 33: (119, 11, 32)}


class CityscapesDataset(Dataset):
    def __init__(self, split='train', crop=None, flip=False):
        super().__init__()
        self.crop = crop
        self.flip = flip
        self.inputs = []
        self.targets = []
        for root, _, filenames in os.walk(os.path.join('/home/home_data/zjw/cityspaces', 'leftImg8bit', split)):
            for filename in filenames:
                if os.path.splitext(filename)[1] == '.png':
                    filename_base = '_'.join(filename.split('_')[:-1])
                    target_root = os.path.join('/home/home_data/zjw/cityspaces', 'gtFine', split, os.path.basename(root))
                    self.inputs.append(os.path.join(root, filename_base + '_leftImg8bit.png'))
                    self.targets.append(os.path.join(target_root, filename_base + '_gtFine_labelIds.png'))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, i):
        # Load images and perform augmentations with PIL
        input, target = Image.open(self.inputs[i]), Image.open(self.targets[i])
        # Random uniform crop
        if self.crop is not None:
            w, h = input.size
            x1, y1 = random.randint(0, w - self.crop), random.randint(0, h - self.crop)
            input, target = input.crop((x1, y1, x1 + self.crop, y1 + self.crop)), target.crop((x1, y1, x1 + self.crop, y1 + self.crop))
        # Random horizontal flip
        if self.flip:
            if random.random() < 0.5:
                input, target = input.transpose(Image.FLIP_LEFT_RIGHT), target.transpose(Image.FLIP_LEFT_RIGHT)
        # Convert to tensors
        w, h = input.size
        input = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())).view(h, w, 3).permute(2, 0, 1).float().div(255)
        target = torch.ByteTensor(torch.ByteStorage.from_buffer(target.tobytes())).view(h, w).long()
        # Normalise input with the ImageNet mean and standard deviation
        input[0].add_(-0.485).div_(0.229)
        input[1].add_(-0.456).div_(0.224)
        input[2].add_(-0.406).div_(0.225)
        # Convert to training labels
        remapped_target = target.clone()
        for k, v in full_to_train.items():
            remapped_target[target == k] = v
        # Create one-hot encoding: reshape the target to num_classes x H x W, one plane per class
        target = torch.zeros(num_classes, h, w)
        for c in range(num_classes):
            target[c][remapped_target == c] = 1  # in plane c, set (i, j) to 1 wherever pixel (i, j) belongs to class c
        return input, target, remapped_target  # Return x, y (one-hot), y (index)
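A minimal usage sketch (assuming the code above is saved as data.py, as the training script's imports suggest, and that the Cityscapes data sits under the hard-coded path) to check what the dataset returns:

from torch.utils.data import DataLoader
from data import CityscapesDataset

dataset = CityscapesDataset(split='train', crop=512, flip=True)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
x, y_onehot, y_index = next(iter(loader))
print(x.shape)         # torch.Size([2, 3, 512, 512])  -- normalised RGB input
print(y_onehot.shape)  # torch.Size([2, 20, 512, 512]) -- one-hot target
print(y_index.shape)   # torch.Size([2, 512, 512])     -- per-pixel class indices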

At the top of this code there are three dictionaries:

full_to_train, train_to_full and full_to_colour.

The full Cityscapes label list is actually quite long (34 labels in total):

# Labels: -1 license plate, 0 unlabeled, 1 ego vehicle, 2 rectification border, 3 out of roi, 4 static, 5 dynamic, 6 ground, 7 road, 8 sidewalk, 9 parking, 10 rail track, 11 building, 12 wall, 13 fence, 14 guard rail, 15 bridge, 16 tunnel, 17 pole, 18 polegroup, 19 traffic light, 20 traffic sign, 21 vegetation, 22 terrain, 23 sky, 24 person, 25 rider, 26 car, 27 truck, 28 bus, 29 caravan, 30 trailer, 31 train, 32 motorcycle, 33 bicycle

This means a label image can contain up to 34 different values, but the network only handles 20 classes, so the 34 classes have to be mapped down to 20.

As the code shows, the labels are the images whose filenames end in _gtFine_labelIds.png (the ones in the red box below):

Why do they look so dark? And why is the label the one in the red box rather than the first image?

1. The four images above are all labels for the same original image; it is simply a matter of which one you choose to use.

2. The image in the red box looks dark because all of its pixel values are small integers between -1 and 33.

The original input image corresponding to these four labels is shown below:

Back to full_to_train, train_to_full and full_to_colour:

"full" refers to the label IDs as read from the label image; "train" refers to those IDs after remapping the values -1 to 33 into 0 to 19, which collapses the 34 classes into the 20 training classes.

full_to_colour maps each (full) class ID to its RGB colour.
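As a small illustration (not part of the training pipeline; run it in the same scope as the dictionaries above), the mappings chain together like this -- raw label ID to training ID, back to a full ID, and finally to a colour:

raw_id = 26                        # 'car' in the full Cityscapes label set
train_id = full_to_train[raw_id]   # 13 -> class index used by the network
full_id = train_to_full[train_id]  # 26 -> back to the full label ID
colour = full_to_colour[full_id]   # (0, 0, 142) -> RGB used when saving predictions
print(raw_id, train_id, full_id, colour)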

 

3.2 Model code:

import torch
from torch import nn
from torch.nn import init
from torchvision.models.resnet import BasicBlock, ResNet


# Returns 2D convolutional layer with space-preserving padding
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, bias=False, transposed=False):
    if transposed:
        layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1, dilation=dilation, bias=bias)
        # Bilinear interpolation init: initialise the transposed-convolution kernel with bilinear weights
        w = torch.Tensor(kernel_size, kernel_size)
        centre = kernel_size % 2 == 1 and stride - 1 or stride - 0.5  # kernel centre for odd/even kernel sizes
        for y in range(kernel_size):
            for x in range(kernel_size):
                w[y, x] = (1 - abs((x - centre) / stride)) * (1 - abs((y - centre) / stride))
        layer.weight.data.copy_(w.div(in_planes).repeat(in_planes, out_planes, 1, 1))
    else:
        padding = (kernel_size + 2 * (dilation - 1)) // 2
        layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
    if bias:
        init.constant_(layer.bias, 0)
    return layer


# Returns 2D batch normalisation layer
def bn(planes):
    layer = nn.BatchNorm2d(planes)
    # Use mean 0, standard deviation 1 init
    init.constant_(layer.weight, 1)
    init.constant_(layer.bias, 0)
    return layer


class FeatureResNet(ResNet):
    def __init__(self):
        super().__init__(BasicBlock, [3, 4, 6, 3], 1000)  # ResNet-34 backbone used for feature extraction

    def forward(self, x):
        x1 = self.conv1(x)
        x = self.bn1(x1)
        x = self.relu(x)
        x2 = self.maxpool(x)
        x = self.layer1(x2)
        x3 = self.layer2(x)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        return x1, x2, x3, x4, x5


class SegResNet(nn.Module):
    def __init__(self, num_classes, pretrained_net):
        super().__init__()
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.conv5 = conv(512, 256, stride=2, transposed=True)
        self.bn5 = bn(256)
        self.conv6 = conv(256, 128, stride=2, transposed=True)
        self.bn6 = bn(128)
        self.conv7 = conv(128, 64, stride=2, transposed=True)
        self.bn7 = bn(64)
        self.conv8 = conv(64, 64, stride=2, transposed=True)
        self.bn8 = bn(64)
        self.conv9 = conv(64, 32, stride=2, transposed=True)
        self.bn9 = bn(32)
        self.conv10 = conv(32, num_classes, kernel_size=7)
        init.constant_(self.conv10.weight, 0)  # Zero init

    def forward(self, x):
        x1, x2, x3, x4, x5 = self.pretrained_net(x)
        x = self.relu(self.bn5(self.conv5(x5)))
        x = self.relu(self.bn6(self.conv6(x + x4)))
        x = self.relu(self.bn7(self.conv7(x + x3)))
        x = self.relu(self.bn8(self.conv8(x + x2)))
        x = self.relu(self.bn9(self.conv9(x + x1)))
        x = self.conv10(x)
        return x
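A quick, illustrative sanity check (assuming the code above is saved as model.py, as the training script's imports suggest) that the decoder brings the feature maps back to the input resolution:

import torch
from model import FeatureResNet, SegResNet

pretrained_net = FeatureResNet()  # ImageNet weights are not needed for a pure shape check
net = SegResNet(num_classes=20, pretrained_net=pretrained_net)
with torch.no_grad():
    out = net(torch.randn(1, 3, 512, 512))
print(out.shape)  # expected: torch.Size([1, 20, 512, 512]) -- one score map per class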

3.3 Training code:

from argparse import ArgumentParser
import os
import random
from matplotlib import pyplot as plt
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.utils import save_image
from data import CityscapesDataset, num_classes, full_to_colour, train_to_full
from model import FeatureResNet, SegResNet

# Setup
parser = ArgumentParser(description='Semantic segmentation')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
parser.add_argument('--workers', type=int, default=8, help='Data loader workers')
parser.add_argument('--epochs', type=int, default=100, help='Training epochs')
parser.add_argument('--crop-size', type=int, default=512, help='Training crop size')
parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate')
parser.add_argument('--momentum', type=float, default=0, help='Momentum')
parser.add_argument('--weight-decay', type=float, default=2e-4, help='Weight decay')
parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
args = parser.parse_args()

random.seed(args.seed)
torch.manual_seed(args.seed)
if not os.path.exists('results'):
    os.makedirs('results')
plt.switch_backend('agg')  # Allow plotting when running remotely

# Data
train_dataset = CityscapesDataset(split='train', crop=args.crop_size, flip=True)
val_dataset = CityscapesDataset(split='val')
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=args.workers, pin_memory=True)

# Training/Testing
pretrained_net = FeatureResNet()
pretrained_net.load_state_dict(models.resnet34(pretrained=True).state_dict())  # Initialise the encoder with ImageNet weights
net = SegResNet(num_classes, pretrained_net).cuda()
crit = nn.BCELoss().cuda()

# Construct optimiser
params_dict = dict(net.named_parameters())
params = []
for key, value in params_dict.items():
    if 'bn' in key:
        # No weight decay on batch norm
        params += [{'params': [value], 'weight_decay': 0}]
    elif '.bias' in key:
        # No weight decay plus double learning rate on biases
        params += [{'params': [value], 'lr': 2 * args.lr, 'weight_decay': 0}]
    else:
        params += [{'params': [value]}]
optimiser = optim.RMSprop(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
scores, mean_scores = [], []


def train(e):
    net.train()
    for i, (input, target, _) in enumerate(train_loader):
        optimiser.zero_grad()
        input, target = input.cuda(non_blocking=True), target.cuda(non_blocking=True)
        output = torch.sigmoid(net(input))  # BCELoss expects probabilities, hence the sigmoid
        loss = crit(output, target)
        print(e, i, loss.item())
        loss.backward()
        optimiser.step()


# Calculates class intersections over unions
def iou(pred, target):
    ious = []
    # Ignore IoU for background class
    for cls in range(num_classes - 1):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds[target_inds]).long().sum().cpu().item()  # Cast to long to prevent overflows
        union = pred_inds.long().sum().cpu().item() + target_inds.long().sum().cpu().item() - intersection
        if union == 0:
            ious.append(float('nan'))  # If there is no ground truth, do not include in evaluation
        else:
            ious.append(intersection / max(union, 1))
    return ious


def test(e):
    net.eval()
    total_ious = []
    with torch.no_grad():  # No gradients are needed during evaluation
        for i, (input, _, target) in enumerate(val_loader):
            input, target = input.cuda(non_blocking=True), target.cuda(non_blocking=True)
            output = F.log_softmax(net(input), dim=1)
            b, _, h, w = output.size()
            pred = output.permute(0, 2, 3, 1).contiguous().view(-1, num_classes).max(1)[1].view(b, h, w)
            total_ious.append(iou(pred, target))

            # Save images
            if i % 25 == 0:
                pred = pred.cpu()
                pred_remapped = pred.clone()
                # Convert to full labels
                for k, v in train_to_full.items():
                    pred_remapped[pred == k] = v
                # Convert to colour image
                pred = pred_remapped
                pred_colour = torch.zeros(b, 3, h, w)
                for k, v in full_to_colour.items():
                    pred_r = torch.zeros(b, 1, h, w)
                    pred = pred.reshape(1, 1, h, -1)  # match the (b, 1, h, w) colour planes; assumes the validation batch size is 1
                    pred_r[(pred == k)] = v[0]
                    pred_g = torch.zeros(b, 1, h, w)
                    pred_g[(pred == k)] = v[1]
                    pred_b = torch.zeros(b, 1, h, w)
                    pred_b[(pred == k)] = v[2]
                    pred_colour.add_(torch.cat((pred_r, pred_g, pred_b), 1))
                save_image(pred_colour[0].float().div(255), os.path.join('results', str(e) + '_' + str(i) + '.png'))

    # Calculate average IoU
    total_ious = torch.Tensor(total_ious).transpose(0, 1)
    ious = torch.Tensor(num_classes - 1)
    for i, class_iou in enumerate(total_ious):
        ious[i] = class_iou[class_iou == class_iou].mean()  # Calculate mean, ignoring NaNs
    print(ious, ious.mean())
    scores.append(ious)
    # Save weights and scores
    torch.save(net, os.path.join('results', str(e) + '_net.pth'))
    torch.save(scores, os.path.join('results', 'scores.pth'))
    # Plot scores
    mean_scores.append(ious.mean())
    es = list(range(len(mean_scores)))
    plt.plot(es, mean_scores, 'b-')
    plt.xlabel('Epoch')
    plt.ylabel('Mean IoU')
    plt.savefig(os.path.join('results', 'ious.png'))
    plt.close()


test(0)
for e in range(1, args.epochs + 1):
    train(e)
    test(e)
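Assuming the three files are saved as data.py, model.py and, say, main.py (the first two names are implied by the imports above; the training script's filename is my assumption), training is launched with something like:

python main.py --batch-size 8 --crop-size 512 --epochs 100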

Results after running:

The files in the red box are the saved models; every call to test() also saves its results.

Let me briefly go through one line inside test():

pred = output.permute(0, 2, 3, 1).contiguous().view(-1, num_classes).max(1)[1].view(b, h, w)

Inside the network the data are torch tensors whose dimensions are ordered (batch size, channels, height, width), whereas downstream processing (e.g. with NumPy) generally uses (batch size, height, width, channels).

So permute(0, 2, 3, 1) rearranges the dimensions into that order.

Next, contiguous() makes the tensor contiguous in memory, which is required before view() can be applied.

view() reshapes the tensor: -1 means the number of rows is inferred, and the number of columns is num_classes (the total number of elements divided by num_classes gives the row count, which PyTorch works out itself). Suppose the matrix after this first view() has shape (A, num_classes):

A is then the total number of pixels, and each row holds the scores of one pixel for every class. What does max(1) do? It picks, for each pixel, the class with the highest score; taking [1] gives the index of that class, which becomes the predicted label of the pixel:

Finally, view() reshapes the result back to b × H × W (b is the batch size, H the height, W the width).
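A tiny sketch with made-up sizes (batch size 1, 20 classes, a 2×3 "image") that mirrors that line step by step:

import torch

num_classes = 20
output = torch.randn(1, num_classes, 2, 3)  # (b, C, H, W) network output
flat = output.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)
print(flat.shape)          # torch.Size([6, 20]) -- one row of class scores per pixel
pred = flat.max(1)[1]      # argmax over the class dimension
print(pred.shape)          # torch.Size([6])
pred = pred.view(1, 2, 3)  # back to (b, H, W)
print(pred.shape)          # torch.Size([1, 2, 3])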

 

3.4 Prediction code:

import torch
from PIL import Image
import random
from torchvision.utils import save_image
from torch.nn import functional as F

full_to_train = {-1: 19, 0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 0, 8: 1, 9: 19, 10: 19, 11: 2, 12: 3, 13: 4, 14: 19, 15: 19, 16: 19, 17: 5, 18: 19, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 19, 30: 19, 31: 16, 32: 17, 33: 18}
train_to_full = {0: 7, 1: 8, 2: 11, 3: 12, 4: 13, 5: 17, 6: 19, 7: 20, 8: 21, 9: 22, 10: 23, 11: 24, 12: 25, 13: 26, 14: 27, 15: 28, 16: 31, 17: 32, 18: 33, 19: 0}
full_to_colour = {0: (0, 0, 0), 7: (128, 64, 128), 8: (244, 35, 232), 11: (70, 70, 70), 12: (102, 102, 156), 13: (190, 153, 153), 17: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0), 21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180), 24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70), 28: (0, 60, 100), 31: (0, 80, 100), 32: (0, 0, 230), 33: (119, 11, 32)}

path = r'/home/home_data/zjw/FCN-semantic-segmentation-master/s2.jpeg'
model = r'/home/home_data/zjw/FCN-semantic-segmentation-master/results/100_net.pth'
crop_size = 512
num_classes = 20


def test():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The checkpoint stores the whole pickled module, so model.py (FeatureResNet/SegResNet) must be importable when loading
    net = torch.load(model)
    net = net.to(device)
    net.eval()
    input = Image.open(path)
    # Take a random crop of crop_size x crop_size, as in training
    w, h = input.size
    x1, y1 = random.randint(0, w - crop_size), random.randint(0, h - crop_size)
    input = input.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    w, h = input.size
    input = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())).view(h, w, 3).permute(2, 0, 1).float().div(255)
    # Normalise with the ImageNet mean and standard deviation
    input[0].add_(-0.485).div_(0.229)
    input[1].add_(-0.456).div_(0.224)
    input[2].add_(-0.406).div_(0.225)
    input = input.to(device)
    input = input.unsqueeze(0)
    with torch.no_grad():
        output = F.log_softmax(net(input), dim=1)
    b, _, h, w = output.size()
    pred = output.permute(0, 2, 3, 1).contiguous().view(-1, num_classes).max(1)[1].view(b, h, w)
    pred = pred.data.cpu()
    pred_remapped = pred.clone()
    # Convert training labels back to full labels
    for k, v in train_to_full.items():
        pred_remapped[pred == k] = v
    pred = pred_remapped
    # Convert to a colour image
    pred_colour = torch.zeros(b, 3, h, w)
    for k, v in full_to_colour.items():
        pred_r = torch.zeros(b, 1, h, w)
        pred = pred.reshape(1, 1, h, -1)  # match the (b, 1, h, w) colour planes; assumes a single image
        pred_r[(pred == k)] = v[0]
        pred_g = torch.zeros(b, 1, h, w)
        pred_g[(pred == k)] = v[1]
        pred_b = torch.zeros(b, 1, h, w)
        pred_b[(pred == k)] = v[2]
        pred_colour.add_(torch.cat((pred_r, pred_g, pred_b), 1))
    print(pred_colour[0].float())
    print('-----------------')
    pred = pred_colour[0].float().div(255)
    print(pred)
    save_image(pred, r'./test_street2.png')
    # save_image(pred_colour, r'./test_street2.png')


test()

In the code above there is this line:

pred = pred_colour[0].float().div(255)

Why divide by 255? Because pred_colour holds RGB values in the range 0–255, while save_image expects a float tensor with values in [0, 1] (it rescales back to 0–255 itself when writing the file). For background on RGB normalisation see: https://blog.csdn.net/sdlyjzh/article/details/8245145

Now let's see how it performs.

First, find a street-scene photo on the web:

Feed it into the prediction code and run it:

It does work, although honestly the result does not look particularly good.

Code: https://github.com/Andy-zhujunwen/pytoch-FCN-train-inference-
