当前位置:   article > 正文

小地物小样本U-Net语义分割示例代码

小地物小样本U-Net语义分割示例代码

仅需要准备数据,适当修改格式等(下面也有相应代码)

应该不挑环境,这里我用的是python3.8+cuda12.1+pytorch2.1.0,显卡是RTX3080

工作文件夹下建立data文件夹、model文件夹、utils文件夹、train.py、predict.py、其他一些代码(可选择性使用)

------------------------------------

data文件夹下建立test文件夹、train文件夹

test文件夹内直接放入要预测的图片,预测后的结果也在此文件夹中

train文件夹下建立image文件夹、label文件夹(都是图片)

要求/注意:图片均为PNG,512*512,三通道。image的位深可以不用管;label一般为8位位深,也可以把dataset.py中读取image时加的参数(flags=1)同样加到读取label的语句中,这样label也不必再考虑位深。推荐label为二值化图像,这样读取不会出错。image如果是由遥感影像tif格式转为png,且遥感影像有四个波段(如近红外波段),要注意查看是否出现透明度的变化,因为png支持RGBA四通道;如果出现了变化,可以在ArcGIS中对原始tif图层右键导出数据,选择“使用渲染器”,强制输出RGB,格式选png导出。(但尚未发现在ArcGIS中用此方法批量导出的办法)

-----------------------------------

model文件夹内有三个.py文件:

__init__.py

unet_model.py

  1. """ Full assembly of the parts to form the complete network """
  2. """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
  3. import torch.nn.functional as F
  4. from .unet_parts import *
  5. class UNet(nn.Module):
  6. def __init__(self, n_channels, n_classes, bilinear=True):
  7. super(UNet, self).__init__()
  8. self.n_channels = n_channels
  9. self.n_classes = n_classes
  10. self.bilinear = bilinear
  11. self.inc = DoubleConv(n_channels, 64)
  12. self.down1 = Down(64, 128)
  13. self.down2 = Down(128, 256)
  14. self.down3 = Down(256, 512)
  15. self.down4 = Down(512, 512)
  16. self.up1 = Up(1024, 256, bilinear)
  17. self.up2 = Up(512, 128, bilinear)
  18. self.up3 = Up(256, 64, bilinear)
  19. self.up4 = Up(128, 64, bilinear)
  20. self.outc = OutConv(64, n_classes)
  21. def forward(self, x):
  22. x1 = self.inc(x)
  23. x2 = self.down1(x1)
  24. x3 = self.down2(x2)
  25. x4 = self.down3(x3)
  26. x5 = self.down4(x4)
  27. x = self.up1(x5, x4)
  28. x = self.up2(x, x3)
  29. x = self.up3(x, x2)
  30. x = self.up4(x, x1)
  31. logits = self.outc(x)
  32. return logits
  33. if __name__ == '__main__':
  34. net = UNet(n_channels=3, n_classes=1)
  35. print(net)

unet_parts.py

  1. """ Parts of the U-Net model """
  2. """https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class DoubleConv(nn.Module):
  7. """(convolution => [BN] => ReLU) * 2"""
  8. def __init__(self, in_channels, out_channels):
  9. super().__init__()
  10. self.double_conv = nn.Sequential(
  11. nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  12. nn.BatchNorm2d(out_channels),
  13. nn.ReLU(inplace=True),
  14. nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
  15. nn.BatchNorm2d(out_channels),
  16. nn.ReLU(inplace=True)
  17. )
  18. def forward(self, x):
  19. return self.double_conv(x)
  20. class Down(nn.Module):
  21. """Downscaling with maxpool then double conv"""
  22. def __init__(self, in_channels, out_channels):
  23. super().__init__()
  24. self.maxpool_conv = nn.Sequential(
  25. nn.MaxPool2d(2),
  26. DoubleConv(in_channels, out_channels)
  27. )
  28. def forward(self, x):
  29. return self.maxpool_conv(x)
  30. class Up(nn.Module):
  31. """Upscaling then double conv"""
  32. def __init__(self, in_channels, out_channels, bilinear=True):
  33. super().__init__()
  34. # if bilinear, use the normal convolutions to reduce the number of channels
  35. if bilinear:
  36. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  37. else:
  38. self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
  39. self.conv = DoubleConv(in_channels, out_channels)
  40. def forward(self, x1, x2):
  41. x1 = self.up(x1)
  42. # input is CHW
  43. diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
  44. diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
  45. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
  46. diffY // 2, diffY - diffY // 2])
  47. x = torch.cat([x2, x1], dim=1)
  48. return self.conv(x)
  49. class OutConv(nn.Module):
  50. def __init__(self, in_channels, out_channels):
  51. super(OutConv, self).__init__()
  52. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
  53. def forward(self, x):
  54. return self.conv(x)

-----------------------------------

utils文件夹内有一个dataset.py文件:

  1. import torch
  2. import cv2
  3. import os
  4. import glob
  5. from torch.utils.data import Dataset
  6. import random
  7. class ISBI_Loader(Dataset):
  8. def __init__(self, data_path):
  9. # 初始化函数,读取所有data_path下的图片
  10. self.data_path = data_path
  11. self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
  12. def augment(self, image, flipCode):
  13. # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
  14. flip = cv2.flip(image, flipCode)
  15. return flip
  16. def __getitem__(self, index):
  17. # 根据index读取图片
  18. image_path = self.imgs_path[index]
  19. # 根据image_path生成label_path
  20. label_path = image_path.replace('image', 'label')
  21. # 读取训练图片和标签图片
  22. image = cv2.imread(image_path,flags=1)
  23. label = cv2.imread(label_path)
  24. # 将数据转为单通道的图片
  25. image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  26. label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
  27. image = image.reshape(1, image.shape[0], image.shape[1])
  28. label = label.reshape(1, label.shape[0], label.shape[1])
  29. # 处理标签,将像素值为255的改为1
  30. if label.max() > 1:
  31. label = label / 255
  32. # 随机进行数据增强,为2时不做处理
  33. flipCode = random.choice([-1, 0, 1, 2])
  34. if flipCode != 2:
  35. image = self.augment(image, flipCode)
  36. label = self.augment(label, flipCode)
  37. return image, label
  38. def __len__(self):
  39. # 返回训练集大小
  40. return len(self.imgs_path)
  41. if __name__ == "__main__":
  42. isbi_dataset = ISBI_Loader("data/train/")
  43. print("数据个数:", len(isbi_dataset))
  44. train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
  45. batch_size=2,
  46. shuffle=True)
  47. for image, label in train_loader:
  48. print(image.shape)

-----------------------------------

train.py

  1. from model.unet_model import UNet
  2. from utils.dataset import ISBI_Loader
  3. from torch import optim
  4. import torch.nn as nn
  5. import torch
  6. def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
  7. # 加载训练集
  8. isbi_dataset = ISBI_Loader(data_path)
  9. train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
  10. batch_size=batch_size,
  11. shuffle=True)
  12. # 定义RMSprop算法
  13. optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
  14. # 定义Loss算法
  15. criterion = nn.BCEWithLogitsLoss()
  16. # best_loss统计,初始化为正无穷
  17. best_loss = float('inf')
  18. # 训练epochs次
  19. for epoch in range(epochs):
  20. # 训练模式
  21. net.train()
  22. # 按照batch_size开始训练
  23. for image, label in train_loader:
  24. optimizer.zero_grad()
  25. # 将数据拷贝到device中
  26. image = image.to(device=device, dtype=torch.float32)
  27. label = label.to(device=device, dtype=torch.float32)
  28. # 使用网络参数,输出预测结果
  29. pred = net(image)
  30. # 计算loss
  31. loss = criterion(pred, label)
  32. print('Loss/train', loss.item())
  33. # 保存loss值最小的网络参数
  34. if loss < best_loss:
  35. best_loss = loss
  36. torch.save(net.state_dict(), 'best_model.pth')
  37. # 更新参数
  38. loss.backward()
  39. optimizer.step()
  40. if __name__ == "__main__":
  41. # 选择设备,有cuda用cuda,没有就用cpu
  42. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  43. # 加载网络,图片单通道1,分类为1。
  44. net = UNet(n_channels=1, n_classes=1)
  45. # 将网络拷贝到deivce中
  46. net.to(device=device)
  47. # 指定训练集地址,开始训练
  48. data_path = './data/train/'
  49. train_net(net, device, data_path)

-----------------------------------

predict.py

  1. import glob
  2. import numpy as np
  3. import torch
  4. import os
  5. import cv2
  6. from model.unet_model import UNet
  7. if __name__ == "__main__":
  8. # 选择设备,有cuda用cuda,没有就用cpu
  9. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  10. # 加载网络,图片单通道,分类为1。
  11. net = UNet(n_channels=1, n_classes=1)
  12. # 将网络拷贝到deivce中
  13. net.to(device=device)
  14. # 加载模型参数
  15. net.load_state_dict(torch.load('best_model.pth', map_location=device))
  16. # 测试模式
  17. net.eval()
  18. # 读取所有图片路径
  19. tests_path = glob.glob('data/test/*.png')
  20. # 遍历素有图片
  21. for test_path in tests_path:
  22. # 保存结果地址
  23. save_res_path = test_path.split('.')[0] + '_res.png'
  24. # 读取图片
  25. img = cv2.imread(test_path)
  26. # 转为灰度图
  27. img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  28. # 转为batch为1,通道为1,大小为512*512的数组
  29. img = img.reshape(1, 1, img.shape[0], img.shape[1])
  30. # 转为tensor
  31. img_tensor = torch.from_numpy(img)
  32. # 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
  33. img_tensor = img_tensor.to(device=device, dtype=torch.float32)
  34. # 预测
  35. pred = net(img_tensor)
  36. # 提取结果
  37. pred = np.array(pred.data.cpu()[0])[0]
  38. # 处理结果
  39. pred[pred >= 0.5] = 255
  40. pred[pred < 0.5] = 0
  41. # 保存图片
  42. cv2.imwrite(save_res_path, pred)

-----------------------------------

剩下是一些杂七杂八的比较有针对性功能的代码:

格式转换tif变为png,但这个不能解决有透明度的问题,解决办法见文章最上面!

  1. from PIL import Image
  2. import os
  3. def tif_to_png(input_path, output_path):
  4. for file in os.listdir(input_path):
  5. if file.endswith('.tif'):
  6. with Image.open(os.path.join(input_path, file)) as im:
  7. im.save(os.path.join(output_path, file.replace('.tif', '.png')))
  8. # 示例
  9. tif_to_png('D:/deeplearning/tif',
  10. 'D:/deeplearning/png')

格式转换jpg变为png

  1. import os
  2. from PIL import Image
  3. # 获取指定目录下的所有png图片
  4. def get_all_png_files(dir):
  5. files_list = []
  6. for root, dirs, files in os.walk(dir):
  7. for file in files:
  8. if os.path.splitext(file)[1] == '.jpg':
  9. files_list.append(os.path.join(root, file))
  10. return files_list
  11. # 批量转换png图片为jpg格式并保存到新的文件夹
  12. def png2jpg(files_list, output_dir):
  13. for file in files_list:
  14. img = Image.open(file)
  15. new_file = os.path.splitext(file)[0] + '.png'
  16. output_file = os.path.join(output_dir, os.path.basename(new_file))
  17. img.convert('RGB').save(output_file)
  18. if __name__ == '__main__':
  19. dir = r'D:\deeplearning\test' # png图片目录
  20. output_dir = r'D:\deeplearning\test' # 新的文件夹路径
  21. files_list = get_all_png_files(dir)
  22. png2jpg(files_list, output_dir)

将文件夹内图像批量重命名:

  1. #coding=gbk
  2. import os
  3. import sys
  4. def rename():
  5. path=input("请输入路径(例如D:\\\\picture):")
  6. name=input("请输入开头名:")
  7. startNumber=input("请输入开始数:")
  8. fileType=input("请输入后缀名(如 .jpg、.txt等等):")
  9. print("正在生成以"+name+startNumber+fileType+"迭代的文件名")
  10. count=0
  11. filelist=os.listdir(path)
  12. for files in filelist:
  13. Olddir=os.path.join(path,files)
  14. if os.path.isdir(Olddir):
  15. continue
  16. Newdir=os.path.join(path,name+str(count+int(startNumber))+fileType)
  17. os.rename(Olddir,Newdir)
  18. count+=1
  19. print("一共修改了"+str(count)+"个文件")
  20. rename()

参考包括:

Pytorch深度学习实战教程(三):UNet模型训练 (qq.com)

GitHub:Deep-Learning/Tutorial/lesson-2 at master · Jack-Cherish/Deep-Learning · GitHub

B站up主:Bubbliiiing

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/87800
推荐阅读
相关标签
  

闽ICP备14008679号