当前位置:   article > 正文

【PyTorch】多对象分割项目

【PyTorch】多对象分割项目

 【PyTorch】单对象分割项目

对象分割任务的目标是找到图像中目标对象的边界,其实际应用包括自动驾驶汽车和医学成像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务。多对象分割的主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义:在分割掩码中,属于同一目标对象的所有像素都被标记为相同的预定义标签。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

  1. from torchvision.datasets import VOCSegmentation
  2. from PIL import Image
  3. from torchvision.transforms.functional import to_tensor, to_pil_image
  class myVOCSegmentation(VOCSegmentation):
      """VOC segmentation dataset that yields ``(image, mask)`` tensor pairs.

      ``self.transforms`` is called with ``image=`` / ``mask=`` keyword
      arguments and must return a mapping with ``'image'`` and ``'mask'``
      entries (the calling convention used by the albumentations pipelines
      built elsewhere in this file).
      """

      def __getitem__(self, index):
          """Load sample ``index`` and return it as ``(img, target)``.

          Args:
              index: Position of the sample in ``self.images`` / ``self.masks``.

          Returns:
              A tuple ``(img, target)`` where ``img`` is the output of
              ``to_tensor`` and ``target`` is a ``torch.long`` tensor of
              class labels with every value above 20 replaced by 0.
          """
          # Convert to arrays up front so the no-transform path also works:
          # previously the mask stayed a PIL image when ``self.transforms``
          # was None and ``torch.from_numpy`` failed on it.
          # np.array() makes a writable copy, which the in-place label
          # remapping below relies on.
          img = np.array(Image.open(self.images[index]).convert('RGB'))
          target = np.array(Image.open(self.masks[index]))

          if self.transforms is not None:
              augmented = self.transforms(image=img, mask=target)
              img = augmented['image']
              target = augmented['mask']

          # Fold labels above 20 into class 0 regardless of whether a
          # transform ran (presumably the "void"/border label of VOC masks;
          # the model elsewhere in this file is built with 21 classes).
          target[target > 20] = 0

          img = to_tensor(img)
          target = torch.from_numpy(target).type(torch.long)
          return img, target
  16. from albumentations import (
  17. HorizontalFlip,
  18. Compose,
  19. Resize,
  20. Normalize)
  # Per-channel statistics handed to Normalize (re_normalize uses the same
  # values as its defaults to undo the normalization for display).
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]

  # Height and width passed to Resize in both pipelines.
  h, w = 520, 520

  # Training pipeline: resize, horizontal flip with probability 0.5, normalize.
  transform_train = Compose(
      [
          Resize(h, w),
          HorizontalFlip(p=0.5),
          Normalize(mean=mean, std=std),
      ]
  )

  # Validation pipeline: resize and normalize only.
  transform_val = Compose(
      [
          Resize(h, w),
          Normalize(mean=mean, std=std),
      ]
  )

  # Root folder of the dataset; download=False means the data must
  # already be present there.
  path2data = "./data/"

  train_ds = myVOCSegmentation(
      path2data,
      year='2012',
      image_set='train',
      download=False,
      transforms=transform_train,
  )
  print(len(train_ds))

  val_ds = myVOCSegmentation(
      path2data,
      year='2012',
      image_set='val',
      download=False,
      transforms=transform_val,
  )
  print(len(val_ds))
  1. import torch
  2. import numpy as np
  3. from skimage.segmentation import mark_boundaries
  4. import matplotlib.pylab as plt
  5. %matplotlib inline
  # Seed NumPy's global RNG so the color table below is the same on every run.
  np.random.seed(0)

  # Number of segmentation classes used by the display helper.
  num_classes = 21

  # Color table with num_classes + 1 rows of three channels; every channel
  # value is drawn from {0, 1}.
  COLORS = np.random.randint(
      0,
      2,
      size=(num_classes + 1, 3),
      dtype="uint8",
  )
  def show_img_target(img, target):
      """Show ``img`` with the boundaries of each class region in ``target``.

      Args:
          img: Image to draw on. A tensor is converted with ``to_pil_image``;
              anything else is passed to ``np.array`` as is.
          target: Label mask compared against each class index in
              ``range(num_classes)``. A tensor is converted to a NumPy array.

      The overlay is drawn with ``mark_boundaries`` using ``COLORS[label]``
      for each label and displayed on the current axes via ``plt.imshow``.
      """
      if torch.is_tensor(img):
          img = to_pil_image(img)
      # Convert the mask on its own merits: previously it was only converted
      # when the *image* was a tensor, so a PIL image paired with a tensor
      # mask (or a tensor image paired with an ndarray mask) was mishandled.
      if torch.is_tensor(target):
          target = target.cpu().numpy()

      # Convert once; mark_boundaries returns a new array on every call, so
      # the result can be fed straight back in.
      overlay = np.array(img)
      for ll in range(num_classes):
          mask = (target == ll)
          overlay = mark_boundaries(
              overlay,
              mask,
              outline_color=COLORS[ll],
              color=COLORS[ll],
          )
      plt.imshow(overlay)
  def re_normalize(x, mean=mean, std=std):
      """Return a copy of ``x`` with per-channel normalization undone.

      Args:
          x: Tensor indexed by channel along its first dimension.
          mean: Per-channel means; defaults to the module-level ``mean``.
          std: Per-channel standard deviations; defaults to the module-level
              ``std``.

      Returns:
          A clone of ``x`` in which channel ``c`` has been multiplied by
          ``std[c]`` and then shifted by ``mean[c]``. ``x`` is not modified.
      """
      restored = x.clone()
      for idx, (mu, sigma) in enumerate(zip(mean, std)):
          # restored[idx] is a view, so the in-place ops write into ``restored``.
          restored[idx].mul_(sigma).add_(mu)
      return restored

 展示训练数据集示例图像

  # Pull one sample from the training set and report shape, type and maximum
  # of both the image and its mask.
  img, mask = train_ds[10]
  for tensor in (img, mask):
      print(tensor.shape, tensor.type(), torch.max(tensor))

  plt.figure(figsize=(20, 20))

  # Undo the normalization so the image can be displayed.
  img_r = re_normalize(img)

  # Left: image, middle: raw mask, right: image with class boundaries.
  plt.subplot(131)
  plt.imshow(to_pil_image(img_r))

  plt.subplot(132)
  plt.imshow(mask)

  plt.subplot(133)
  show_img_target(img_r, mask)

展示验证数据集示例图像

  # Pull one sample from the validation set and report shape, type and
  # maximum of both the image and its mask.
  img, mask = val_ds[10]
  for tensor in (img, mask):
      print(tensor.shape, tensor.type(), torch.max(tensor))

  plt.figure(figsize=(20, 20))

  # Undo the normalization so the image can be displayed.
  img_r = re_normalize(img)

  # Left: image, middle: raw mask, right: image with class boundaries.
  plt.subplot(131)
  plt.imshow(to_pil_image(img_r))

  plt.subplot(132)
  plt.imshow(mask)

  plt.subplot(133)
  show_img_target(img_r, mask)

创建数据加载器

 通过torch.utils.data针对训练和验证集分别创建Dataloader,打印示例观察效果

  1. from torch.utils.data import DataLoader
  # Training batches are shuffled; validation batches keep dataset order.
  train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
  val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)

  # Look at the first batch of each loader (training first, then
  # validation) and print the shape and dtype of images and masks.
  for loader in (train_dl, val_dl):
      for img_b, mask_b in loader:
          print(img_b.shape, img_b.dtype)
          print(mask_b.shape, mask_b.dtype)
          break

创建模型

创建并打印deeplab_resnet模型结构,使用预训练权重

  1. from torchvision.models.segmentation import deeplabv3_resnet101
  2. import torch
  # DeepLabV3 with a ResNet-101 backbone, requested with pretrained weights
  # and 21 output classes.
  model = deeplabv3_resnet101(pretrained=True, num_classes=21)

  # Use the first CUDA device when one is available, otherwise the CPU.
  if torch.cuda.is_available():
      device = torch.device('cuda:0')
  else:
      device = torch.device('cpu')

  model = model.to(device)
  print(model)

部署模型

在验证数据集的数据批次上部署模型观察效果 

  1. from torch import nn
  # Run the model on the first validation batch only, without gradients.
  model.eval()
  with torch.no_grad():
      for xb, yb in val_dl:
          # The model returns a mapping; its "out" entry holds the scores.
          scores = model(xb.to(device))["out"].cpu()
          print(scores.shape)
          # Pick the highest-scoring class index along dimension 1.
          yb_pred = torch.argmax(scores, axis=1)
          break
  print(yb_pred.shape)

  plt.figure(figsize=(20, 20))

  # Index of the sample within the batch to visualize.
  n = 2
  img, mask = xb[n], yb_pred[n]
  img_r = re_normalize(img)

  # Left: image, middle: predicted mask, right: image with predicted boundaries.
  plt.subplot(131)
  plt.imshow(to_pil_image(img_r))

  plt.subplot(132)
  plt.imshow(mask)

  plt.subplot(133)
  show_img_target(img_r, mask)

可见勾勒对象方面效果很好 

定义损失函数和优化器

  1. from torch import nn
  # Cross-entropy with reduction="sum": batch losses are totals, which
  # loss_epoch later divides by a sample count.
  criterion = nn.CrossEntropyLoss(
      reduction="sum",
  )

  from torch import optim

  # Adam over all model parameters with a learning rate of 1e-6.
  opt = optim.Adam(
      model.parameters(),
      lr=1e-6,
  )
  def loss_batch(loss_func, output, target, opt=None):
      """Compute the loss for one batch and optionally take an optimizer step.

      Args:
          loss_func: Callable invoked as ``loss_func(output, target)``.
          output: Model output for the batch.
          target: Ground truth for the batch.
          opt: Optimizer. When given, gradients are zeroed, the loss is
              back-propagated and ``opt.step()`` is called.

      Returns:
          ``(loss_value, None)`` where ``loss_value`` is ``loss.item()``; the
          second slot is always ``None``.
      """
      loss = loss_func(output, target)

      # Evaluation mode: nothing to update.
      if opt is None:
          return loss.item(), None

      opt.zero_grad()
      loss.backward()
      opt.step()
      return loss.item(), None
  from torch.optim.lr_scheduler import ReduceLROnPlateau

  # Plateau scheduler on the optimizer above (mode='min', factor 0.5,
  # patience 20); train_val steps it with the validation loss.
  lr_scheduler = ReduceLROnPlateau(
      opt,
      mode='min',
      factor=0.5,
      patience=20,
      verbose=1,
  )
  def get_lr(opt):
      """Return the ``'lr'`` entry of the first parameter group of ``opt``.

      Args:
          opt: Object exposing an iterable ``param_groups`` of mappings.

      Returns:
          The learning rate of the first group, or ``None`` when
          ``param_groups`` is empty.
      """
      groups = iter(opt.param_groups)
      try:
          first_group = next(groups)
      except StopIteration:
          return None
      return first_group['lr']
  # Report the learning rate the optimizer currently uses.
  current_lr = get_lr(opt)
  message = 'current lr={}'.format(current_lr)
  print(message)

训练和验证模型

  def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
      """Run one pass over ``dataset_dl`` and return the per-sample loss.

      Args:
          model: Network called as ``model(xb)``; its ``"out"`` entry is used
              as the prediction.
          loss_func: Loss callable forwarded to ``loss_batch``.
          dataset_dl: Iterable of ``(xb, yb)`` batches.
          sanity_check: When ``True``, stop after the first batch.
          opt: Optional optimizer forwarded to ``loss_batch``; when given the
              model is updated after every batch.

      Returns:
          ``(loss, None)`` where ``loss`` is the accumulated batch loss
          divided by the number of samples actually processed (``nan`` if the
          loader produced no samples). The second slot is always ``None``.
      """
      running_loss = 0.0
      # Count the samples really seen. Dividing by len(dataset) misreported
      # the loss whenever the loop stopped early (sanity_check) or the loader
      # did not cover the whole dataset.
      samples_seen = 0

      for xb, yb in dataset_dl:
          xb = xb.to(device)
          yb = yb.to(device)
          output = model(xb)["out"]
          loss_b, _ = loss_batch(loss_func, output, yb, opt)
          running_loss += loss_b
          samples_seen += xb.size(0)
          if sanity_check is True:
              break

      if samples_seen == 0:
          # Nothing was processed; avoid a division by zero.
          return float("nan"), None

      loss = running_loss / float(samples_seen)
      return loss, None
  14. import copy
  def train_val(model, params):
      """Train ``model`` and validate it after every epoch.

      Args:
          model: Network to train; updated in place.
          params: Mapping with the keys ``"num_epochs"``, ``"loss_func"``,
              ``"optimizer"``, ``"train_dl"``, ``"val_dl"``,
              ``"sanity_check"``, ``"lr_scheduler"`` and ``"path2weights"``.

      Returns:
          ``(model, loss_history, metric_history)``. ``model`` has the weights
          with the lowest validation loss loaded; both histories are dicts
          with ``"train"`` and ``"val"`` lists holding one entry per epoch.
      """
      num_epochs = params["num_epochs"]
      loss_func = params["loss_func"]
      optimizer = params["optimizer"]
      train_loader = params["train_dl"]
      val_loader = params["val_dl"]
      sanity_check = params["sanity_check"]
      scheduler = params["lr_scheduler"]
      path2weights = params["path2weights"]

      loss_history = {phase: [] for phase in ("train", "val")}
      metric_history = {phase: [] for phase in ("train", "val")}

      # Start from the incoming weights so there is always a snapshot to
      # restore, even if no epoch improves on the initial state.
      best_model_wts = copy.deepcopy(model.state_dict())
      best_loss = float('inf')

      for epoch in range(num_epochs):
          lr_at_start = get_lr(optimizer)
          print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, lr_at_start))

          # Training pass: the optimizer is passed so weights get updated.
          model.train()
          train_loss, train_metric = loss_epoch(
              model, loss_func, train_loader, sanity_check, optimizer
          )
          loss_history["train"].append(train_loss)
          metric_history["train"].append(train_metric)

          # Validation pass: no optimizer and no gradient tracking.
          model.eval()
          with torch.no_grad():
              val_loss, val_metric = loss_epoch(
                  model, loss_func, val_loader, sanity_check
              )
          loss_history["val"].append(val_loss)
          metric_history["val"].append(val_metric)

          # Keep an in-memory copy and an on-disk copy of the best weights.
          if val_loss < best_loss:
              best_loss = val_loss
              best_model_wts = copy.deepcopy(model.state_dict())
              torch.save(model.state_dict(), path2weights)
              print("Copied best model weights!")

          # If the scheduler changed the learning rate, continue from the
          # best weights seen so far.
          scheduler.step(val_loss)
          if get_lr(optimizer) != lr_at_start:
              print("Loading best model weights!")
              model.load_state_dict(best_model_wts)

          print("train loss: %.6f" %(train_loss))
          print("val loss: %.6f" %(val_loss))
          print("-"*10)

      model.load_state_dict(best_model_wts)
      return model, loss_history, metric_history
  1. import os
  # Fresh optimizer and scheduler for this run.
  opt = optim.Adam(model.parameters(), lr=1e-6)
  lr_scheduler = ReduceLROnPlateau(
      opt,
      mode='min',
      factor=0.5,
      patience=20,
      verbose=1,
  )

  # Folder that receives the saved weights; created on first use.
  path2models = "./models/"
  if not os.path.exists(path2models):
      os.mkdir(path2models)

  # sanity_check=True makes every epoch stop after a single batch.
  weights_file = path2models + "sanity_weights.pt"
  params_train = {
      "num_epochs": 10,
      "optimizer": opt,
      "loss_func": criterion,
      "train_dl": train_dl,
      "val_dl": val_dl,
      "sanity_check": True,
      "lr_scheduler": lr_scheduler,
      "path2weights": weights_file,
  }

  model, loss_hist, _ = train_val(model, params_train)

绘制训练和验证损失曲线

  # Plot the recorded training and validation loss against the epoch number
  # (epochs are numbered from 1 on the x-axis).
  num_epochs = params_train["num_epochs"]
  epochs = range(1, num_epochs + 1)

  plt.title("Train-Val Loss")
  for phase in ("train", "val"):
      plt.plot(epochs, loss_hist[phase], label=phase)
  plt.ylabel("Loss")
  plt.xlabel("Training Epochs")
  plt.legend()
  plt.show()

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号