赞
踩
上一篇关于TransUnet的GitHub复现,大家反映效果不好,调参也不好调,我把模型单独拿出来,放到另外一个框架,供大家参考学习(上一篇链接:https://blog.csdn.net/qq_20373723/article/details/115548900)
我这里训练了20个epoch,下面先给出效果正常的情况:
原图
预测结果
整体代码结构:
注意一下代码结构和文件名字记得保持一样,没有的文件手动新建一下
1.数据准备,文件名字请务必保持一致,不过你也可以去代码里改
一级目录,红线的三个,其它不用管
二级目录
三级目录就分别是图像、标签了,二者名字保持一致,标签值为0和255,代码里改也行
2.数据加载代码data.py
"""
Based on https://github.com/asanakoy/kaggle_carvana_segmentation

Dataset and augmentation helpers for binary segmentation training.
"""
import math
import os

import cv2
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torch.autograd import Variable as V

try:
    import scipy.misc as misc
except ImportError:
    # scipy.misc is deprecated upstream and nothing below uses it, so a
    # missing module must not prevent the data pipeline from importing.
    misc = None


def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    """Randomly jitter hue, saturation and value of a BGR image.

    With probability ``u`` the image is converted to HSV, each channel is
    shifted by a random amount drawn from the matching ``*_shift_limit``
    range, and the result is converted back to BGR. Otherwise the image is
    returned unchanged.

    NOTE(review): assumes ``image`` is an 8-bit BGR array as returned by
    ``cv2.imread`` -- confirm if other dtypes are ever passed.
    """
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
        # Casting a negative shift to uint8 overflows (an error on NumPy 2).
        # Shift in a signed type and wrap into OpenCV's 8-bit hue range.
        h = ((h.astype(np.int16) + hue_shift) % 180).astype(np.uint8)
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image


def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply the same random shift/scale/rotate warp to an image and its mask.

    With probability ``u`` a perspective transform is built from a random
    rotation angle (degrees), scale, aspect ratio and translation (fractions
    of width/height) and applied to both arrays with bilinear interpolation
    and a zero border. Returns the ``(image, mask)`` pair.
    """
    if np.random.random() < u:
        height, width, channel = image.shape

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # np.math was an alias of the stdlib math module and no longer exists
        # in NumPy 2, so use math directly.
        cc = math.cos(angle / 180 * math.pi) * sx
        ss = math.sin(angle / 180 * math.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Map the four image corners through the rotation and translation.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                    borderMode=borderMode, borderValue=(0, 0, 0,))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR,
                                   borderMode=borderMode, borderValue=(0, 0, 0,))

    return image, mask


def randomHorizontalFlip(image, mask, u=0.5):
    """Flip image and mask left-right together with probability ``u``."""
    if np.random.random() < u:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)

    return image, mask


def randomVerticleFlip(image, mask, u=0.5):
    """Flip image and mask upside-down together with probability ``u``."""
    if np.random.random() < u:
        image = cv2.flip(image, 0)
        mask = cv2.flip(mask, 0)

    return image, mask


def randomRotate90(image, mask, u=0.5):
    """Rotate image and mask by 90 degrees together with probability ``u``."""
    if np.random.random() < u:
        image = np.rot90(image)
        mask = np.rot90(mask)

    return image, mask


def _read_pair(img_path, mask_path):
    """Read a colour image and its grayscale mask, failing loudly if missing."""
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(img_path))
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError('cannot read mask: {}'.format(mask_path))
    return img, mask


def _augment(img, mask):
    """Run the fixed training augmentation chain on an image/mask pair."""
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)
    return img, mask


def _to_network_arrays(img, mask):
    """Convert an HWC image and HW mask to normalised float32 CHW arrays.

    The image is scaled from [0, 255] to [-1.6, 1.6]; the mask is scaled to
    [0, 1] and binarised at 0.5.
    """
    mask = np.expand_dims(mask, axis=2)
    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0
    return img, mask


def default_loader(img_path, mask_path, size=(448, 448)):
    """Load, resize, augment and normalise one training pair.

    Both arrays are resized to ``size`` and the mask is inverted
    (``255 - mask``) before augmentation. Returns ``(img, mask)`` as float32
    arrays in channel-first layout.
    """
    img, mask = _read_pair(img_path, mask_path)
    img = cv2.resize(img, size)
    mask = 255. - cv2.resize(mask, size)
    img, mask = _augment(img, mask)
    return _to_network_arrays(img, mask)


def read_own_data(root_path, mode='train'):
    """Collect image/label path pairs from ``<root>/<mode>/{images,labels}``.

    The label directory drives the listing; the image is expected under the
    same file name in the images directory. Names are sorted so the sample
    order is reproducible across runs and platforms.
    """
    images = []
    masks = []

    image_root = os.path.join(root_path, mode + '/images')
    gt_root = os.path.join(root_path, mode + '/labels')

    for image_name in sorted(os.listdir(gt_root)):
        images.append(os.path.join(image_root, image_name))
        masks.append(os.path.join(gt_root, image_name))

    return images, masks


def own_data_loader(img_path, mask_path):
    """Load, augment and normalise one training/validation pair."""
    img, mask = _read_pair(img_path, mask_path)
    img, mask = _augment(img, mask)
    return _to_network_arrays(img, mask)


def own_data_test_loader(img_path, mask_path):
    """Load one test pair unchanged (raw HWC image and HW mask)."""
    return _read_pair(img_path, mask_path)


class ImageFolder(data.Dataset):
    """Dataset over ``<root>/<mode>/images`` and ``<root>/<mode>/labels``.

    In ``'test'`` mode samples are returned raw; in every other mode they are
    augmented and normalised by ``own_data_loader``.
    """

    def __init__(self, root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        if self.mode == 'test':
            img, mask = own_data_test_loader(self.images[index], self.labels[index])
        else:
            img, mask = own_data_loader(self.images[index], self.labels[index])
        img = torch.Tensor(img)
        mask = torch.Tensor(mask)
        return img, mask

    def __len__(self):
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
3.训练代码train_normal.py
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd import Variable as V

import cv2
import os
import math
import warnings
from tqdm import tqdm
import numpy as np

from time import time
from shutil import copyfile, move
from models.networks.TransUnet import get_transNet
from framework import MyFrame
from loss.dice_bce_loss import Dice_bce_loss
from loss.diceloss import DiceLoss
from metrics.iou import iou_pytorch
from eval import eval_func, eval_new
from data import ImageFolder
from inference import TTAFrame
from tensorboardX import SummaryWriter

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["KMP_DUPLICATE_LIB_OK"] = 'True'


def _read_config(config_file):
    """Return the first whitespace-separated token of each non-blank line."""
    with open(config_file) as f:
        return [line.split()[0] for line in f if line.strip()]


def _current_lr(solver):
    """Read the learning rate actually in use from the optimizer.

    This avoids scheduler.get_lr(), which is not meant to be called outside
    of scheduler.step() and can report a value that is off by one decay step.
    """
    return solver.optimizer.param_groups[0]['lr']


def train(Model=None, config_file='train_normal_config.txt'):
    """Train ``Model``, track the best checkpoints, then evaluate on the test set.

    The config file holds one value per line, in this order: data root,
    pretrained weights path (loaded only if it ends in ``.th``), batch size,
    learning rate, number of epochs, model name (used for log and checkpoint
    file names).

    Checkpoints are written to ``weights/`` and the text log to ``logs/``;
    both directories are created when missing.
    """
    dirs = _read_config(config_file)

    data_root = dirs[0].replace('\\', '/')
    pre_model = dirs[1].replace('\\', '/')
    batch_size = int(dirs[2])
    lr = float(dirs[3])
    total_epoch = int(dirs[4])
    model_name = dirs[5].replace('\\', '/')

    warnings.filterwarnings("ignore")

    solver = MyFrame(Model, Dice_bce_loss, lr)
    if pre_model.endswith('.th'):
        solver.load(pre_model)

    train_dataset = ImageFolder(data_root, mode='train')
    val_dataset = ImageFolder(data_root, mode='val')
    test_dataset = ImageFolder(data_root, mode='test')

    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=True, num_workers=0)

    train_best_path = 'weights/' + model_name + '_train_loss_best.th'
    iou_best_path = 'weights/' + model_name + '_iou_best.th'
    val_best_path = 'weights/' + model_name + '_val_loss_best.th'
    log_path = 'logs/' + model_name + '.log'
    # The model name may contain sub-directories, so derive the folders from
    # the final paths instead of assuming flat 'logs' and 'weights' dirs.
    for path in (train_best_path, log_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    writer = SummaryWriter('./record')
    tic = time()

    device = torch.device('cuda:0')
    no_optim = 0
    train_epoch_best_loss = 100.
    val_epoch_best_loss = 100.
    val_best_iou = 0.3
    criteon = DiceLoss()
    scheduler = solver.lr_strategy()

    with open(log_path, 'w') as mylog:
        for epoch in range(1, total_epoch + 1):
            print('---------- Epoch:' + str(epoch) + ' ----------')
            train_epoch_loss = 0
            print('Train:')
            for img, mask in tqdm(data_loader, ncols=20, total=len(data_loader)):
                solver.set_input(img, mask)
                train_epoch_loss += solver.optimize()
            train_epoch_loss /= len(data_loader)

            val_loss_sum = 0.0
            val_iou_sum = 0.0
            print('Validation:')
            # Variable(volatile=True) no longer disables autograd, so switch
            # gradient tracking off explicitly for the whole validation pass.
            with torch.no_grad():
                for val_img, val_mask in tqdm(val_data_loader, ncols=20, total=len(val_data_loader)):
                    val_img = val_img.to(device)
                    target = (val_mask.cpu() > 0).float()
                    n = target.size(0)
                    # Bring both tensors to (batch, H, W) so batch size 1 and
                    # larger batches follow the same code path.
                    target = target.reshape(n, *target.shape[-2:])
                    predict = torch.from_numpy(solver.test_one_img(val_img)).float()
                    predict = predict.reshape(n, *predict.shape[-2:])

                    val_loss_sum += criteon(predict, target).item()
                    pred_bin = (predict >= 0.5).long().unsqueeze(1)
                    val_iou_sum += iou_pytorch(pred_bin, target.long()).item()

            batch_iou = val_iou_sum / len(val_data_loader)
            val_loss = val_loss_sum / len(val_data_loader)

            lr_now = _current_lr(solver)
            writer.add_scalar('lr', lr_now, epoch)
            writer.add_scalar('train_loss', train_epoch_loss, epoch)
            writer.add_scalar('val_loss', val_loss, epoch)
            writer.add_scalar('iou', batch_iou, epoch)

            mylog.write('********** ' + 'lr={:.10f}'.format(lr_now) + ' **********' + '\n')
            mylog.write('--epoch:' + str(epoch) + ' --time:' + str(int(time() - tic)) + ' --train_loss:' + str(train_epoch_loss) + ' --val_loss:' + str(val_loss) + ' --val_iou:' + str(batch_iou) + '\n')
            print('--epoch:', epoch, ' --time:', int(time() - tic), ' --train_loss:', train_epoch_loss, ' --val_loss:', val_loss, ' --val_iou:', batch_iou)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                solver.save(train_best_path)

            if batch_iou >= val_best_iou:
                val_best_iou = batch_iou
                solver.save(iou_best_path)

            if val_loss <= val_epoch_best_loss:
                val_epoch_best_loss = val_loss
                solver.save(val_best_path)

            # After more than 10 epochs without a better training loss, either
            # stop (learning rate already tiny) or roll back to the best weights.
            if no_optim > 10:
                if solver.old_lr < 5e-8:
                    break
                solver.load(train_best_path)
                no_optim = 0

            scheduler.step()
            print('lr={:.10f}'.format(_current_lr(solver)))

            mylog.flush()

        print('Train Finish !')

    writer.close()

    # Evaluation on the test split with the best-training-loss checkpoint.
    model_path = './weights/' + model_name + '_train_loss_best.th'
    solver = TTAFrame(Model)
    solver.load(model_path)

    label_list = []
    pre_list = []
    for img, mask in tqdm(test_data_loader, ncols=20, total=len(test_data_loader)):
        # np.int was removed from NumPy; the builtin int is the replacement.
        label_list.append((torch.squeeze(mask).numpy() > 0).astype(int))

        pre = solver.test_one_img_from_path_8(torch.squeeze(img).numpy())
        # The TTA result is a sum over 8 views, so 4.0 is a mean score of 0.5.
        pre_list.append((pre >= 4.0).astype(int))

    eval_new(label_list, pre_list)


if __name__ == '__main__':
    net = get_transNet(1)
    train(net)
配置文件内容
配置文件每行写一个参数,按顺序依次为:参数1:数据路径;参数2:预训练模型路径(以.th结尾的权重文件),没有就填None;参数3:batchsize;参数4:学习率;参数5:训练的epoch数;参数6:模型名字(用于日志和权重文件的命名)
4.模型加载、训练策略等相关代码framework.py
import cv2
import math
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable as V
from torch.optim import lr_scheduler


class MyFrame():
    """Bundle a network with its optimizer, loss and train/predict helpers.

    The network is moved to CUDA and wrapped in ``DataParallel`` over all
    visible GPUs. ``loss`` is a class (not an instance) that is instantiated
    here and later called as ``loss(mask, prediction)``.
    """

    def __init__(self, net, loss, lr=2e-4, evalmode=False):
        self.net = net.cuda()
        self.net = torch.nn.DataParallel(
            self.net, device_ids=list(range(torch.cuda.device_count())))
        self.optimizer = torch.optim.RMSprop(params=self.net.parameters(), lr=lr)
        self.loss = loss()
        self.old_lr = lr
        if evalmode:
            # Freeze only the BatchNorm2d layers; everything else keeps
            # its current mode.
            for i in self.net.modules():
                if isinstance(i, nn.BatchNorm2d):
                    i.eval()

    def set_input(self, img_batch, mask_batch=None, img_id=None):
        """Store the batch that the next forward()/optimize() call will use."""
        self.img = img_batch
        self.mask = mask_batch
        self.img_id = img_id

    def test_one_img(self, img):
        """Run the network on ``img`` and return the squeezed output as numpy.

        NOTE(review): the network's train/eval mode is left untouched here --
        confirm whether validation is meant to run in eval mode.
        """
        with torch.no_grad():
            pred = self.net.forward(img)
        return pred.squeeze().cpu().data.numpy()

    def test_batch(self):
        """Predict the stored batch and return ``(binary_mask, img_id)``."""
        self.forward(volatile=True)
        with torch.no_grad():
            mask = self.net.forward(self.img).cpu().data.numpy().squeeze(1)
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0

        return mask, self.img_id

    def test_one_img_from_path(self, path):
        """Read an image file, predict it and return a binary mask."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(path))
        # The network is fed channel-first batches everywhere else in this
        # class, so convert HWC -> 1xCxHxW before normalising.
        img = np.array(img, np.float32).transpose(2, 0, 1)[None] / 255.0 * 3.2 - 1.6
        img = torch.Tensor(img).cuda()

        with torch.no_grad():
            mask = self.net.forward(img).squeeze().cpu().data.numpy()
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0

        return mask

    def val_pre(self, img):
        """Predict an HWC image with 8-view test-time augmentation.

        The views are the image and its 90-degree rotation, each also flipped
        vertically, horizontally and both ways. Every prediction is mapped
        back to the original orientation and the eight are summed.
        """
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None], img90[None]])
        img2 = np.array(img1)[:, ::-1]
        img3 = np.concatenate([img1, img2])
        img4 = np.array(img3)[:, :, ::-1]
        img5 = np.concatenate([img3, img4]).transpose(0, 3, 1, 2)
        img5 = np.array(img5, np.float32) / 255.0 * 3.2 - 1.6
        img5 = torch.Tensor(img5).cuda()

        with torch.no_grad():
            mask = self.net.forward(img5).squeeze().cpu().data.numpy()
        mask1 = mask[:4] + mask[4:, :, ::-1]
        mask2 = mask1[:2] + mask1[2:, ::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]

        return mask3

    def forward(self, volatile=False):
        """Move the stored image (and mask, if any) to the GPU.

        ``volatile`` is accepted for backward compatibility only; the flag no
        longer has any effect in PyTorch, so callers that need gradient-free
        execution must use ``torch.no_grad()``.
        """
        self.img = self.img.cuda()
        if self.mask is not None:
            self.mask = self.mask.cuda()

    def optimize(self):
        """Run one optimisation step on the stored batch; return the loss value."""
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net.forward(self.img)
        loss = self.loss(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save(self, path):
        """Write the wrapped network's state dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        """Load a state dict written by ``save``.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, new_lr, mylog, factor=False):
        """Set a new learning rate on every parameter group and log the change.

        If ``factor`` is true, ``new_lr`` is treated as a divisor of the
        current rate. ``mylog`` is an open, writable text file.
        """
        if factor:
            new_lr = self.old_lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        message = 'update learning rate: %f -> %f' % (self.old_lr, new_lr)
        # print(mylog, ...) only printed the file object's repr to stdout;
        # the message has to be directed into the log via file=.
        print(message, file=mylog)
        print(message)
        self.old_lr = new_lr

    def lr_strategy(self):
        """Return an exponential-decay scheduler (gamma=0.9) for the optimizer."""
        scheduler = lr_scheduler.ExponentialLR(self.optimizer, gamma=0.9)
        return scheduler
5.训练时的iou计算代码iou.py
import torch
import numpy as np


def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor, SMOOTH=1e-6):
    """Return the batch-mean thresholded IoU of binary integer masks.

    ``outputs`` is BATCH x 1 x H x W (the singleton channel is squeezed) and
    ``labels`` is BATCH x H x W; both must be integer/bool tensors because
    they are combined with bitwise ``&`` and ``|``.

    The per-image IoU is mapped onto a 0.0-1.0 score in steps of 0.1, which
    equals the fraction of thresholds 0.5, 0.55, ..., 0.95 that the IoU
    exceeds; the mean of these scores over the batch is returned.
    """
    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W

    intersection = (outputs & labels).float().sum((1, 2))  # zero if truth or prediction is empty
    union = (outputs | labels).float().sum((1, 2))  # zero only if both are empty

    iou = (intersection + SMOOTH) / (union + SMOOTH)  # smoothing avoids 0/0

    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10

    return thresholded.mean()


def iou_numpy(outputs: np.ndarray, labels: np.ndarray, SMOOTH=1e-6):
    """NumPy counterpart of ``iou_pytorch`` returning per-image scores.

    ``SMOOTH`` is now a parameter: the original body referenced it without
    defining it, so every call raised NameError.
    """
    outputs = outputs.squeeze(1)

    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))

    iou = (intersection + SMOOTH) / (union + SMOOTH)

    thresholded = np.ceil(np.clip(20 * (iou - 0.5), 0, 10)) / 10

    return thresholded  # call .mean() on the result for a batch average
位置
6.损失函数代码dice_bce_loss.py和diceloss.py
dice_bce_loss.py
import torch
import torch.nn as nn
from torch.autograd import Variable as V

import cv2
import numpy as np
import torch.nn.functional as F

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


class Dice_bce_loss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss.

    Called as ``loss(y_true, y_pred)``. ``y_pred`` is passed to
    ``nn.BCELoss`` and therefore has to contain probabilities.
    """

    def __init__(self, batch=True):
        super(Dice_bce_loss, self).__init__()
        self.batch = batch
        self.bce_loss = nn.BCELoss()

    def soft_dice_coeff(self, y_true, y_pred):
        """Return the smoothed Dice coefficient.

        With ``batch=True`` one coefficient is computed over the whole batch;
        otherwise one per sample (summing three trailing dimensions) and the
        mean is returned.
        """
        smooth = 1.0  # may change
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            i = y_true.sum(1).sum(1).sum(1)
            j = y_pred.sum(1).sum(1).sum(1)
            intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
        score = (2. * intersection + smooth) / (i + j + smooth)
        return score.mean()

    def soft_dice_loss(self, y_true, y_pred):
        """Return ``1 - soft_dice_coeff``."""
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def forward(self, y_true, y_pred):
        a = self.bce_loss(y_pred, y_true)
        b = self.soft_dice_loss(y_true, y_pred)
        return a + b


class lovasz(nn.Module):
    """Symmetric binary Lovasz hinge loss plus binary cross-entropy.

    NOTE(review): ``y_pred`` is given to ``nn.BCELoss`` (which expects
    probabilities) and to the hinge (documented for logits) -- confirm which
    one the network actually outputs.
    """

    def __init__(self, batch=True):
        super(lovasz, self).__init__()
        self.bce_loss = nn.BCELoss()

    def isnan(self, x):
        """Return True where ``x`` is NaN (NaN is the only value != itself)."""
        return x != x

    def mean(self, l, ignore_nan=False, empty=0):
        """Mean of an iterable that also works for generators.

        Returns ``empty`` for an empty iterable, or raises ValueError when
        ``empty == 'raise'``. NaN items are skipped if ``ignore_nan`` is set.
        """
        l = iter(l)
        if ignore_nan:
            l = ifilterfalse(self.isnan, l)
        try:
            n = 1
            acc = next(l)
        except StopIteration:
            if empty == 'raise':
                raise ValueError('Empty mean')
            return empty
        for n, v in enumerate(l, 2):
            acc += v
        if n == 1:
            return acc
        return acc / n

    def flatten_binary_scores(self, scores, labels, ignore=None):
        """Flatten predictions and labels, dropping labels equal to ``ignore``."""
        scores = scores.view(-1)
        labels = labels.view(-1)
        if ignore is None:
            return scores, labels
        valid = (labels != ignore)
        vscores = scores[valid]
        vlabels = labels[valid]
        return vscores, vlabels

    def lovasz_grad(self, gt_sorted):
        """Gradient of the Lovasz extension w.r.t. sorted errors (Alg. 1 in paper)."""
        p = len(gt_sorted)
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.float().cumsum(0)
        union = gts + (1 - gt_sorted).float().cumsum(0)
        jaccard = 1. - intersection / union
        if p > 1:  # cover 1-pixel case
            jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
        return jaccard

    def lovasz_hinge_flat(self, logits, labels):
        r"""Binary Lovasz hinge loss on flattened inputs.

        logits: [P] logits at each prediction (between -\infty and +\infty)
        labels: [P] binary ground truth labels (0 or 1)
        """
        if len(labels) == 0:
            # only void pixels, the gradients should be 0
            return logits.sum() * 0.
        signs = 2. * labels.float() - 1.
        errors = (1. - logits * signs)
        errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
        perm = perm.data
        gt_sorted = labels[perm]
        grad = self.lovasz_grad(gt_sorted)
        loss = torch.dot(F.relu(errors_sorted), grad)
        return loss

    def lovasz_hinge(self, logits, labels, per_image=False, ignore=None):
        r"""Binary Lovasz hinge loss.

        logits: [B, H, W] logits at each pixel (between -\infty and +\infty)
        labels: [B, H, W] binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
        """
        if per_image:
            loss = self.mean(
                self.lovasz_hinge_flat(
                    *self.flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                for log, lab in zip(logits, labels))
        else:
            loss = self.lovasz_hinge_flat(*self.flatten_binary_scores(logits, labels, ignore))
        return loss

    def forward(self, y_true, y_pred):
        a = (self.lovasz_hinge(y_pred, y_true) + self.lovasz_hinge(-y_pred, 1 - y_true)) / 2
        b = self.bce_loss(y_pred, y_true)
        # A third hinge pass used to be computed here and thrown away.
        return a + b


class multi_loss(nn.Module):
    """Negative log-likelihood loss called as ``loss(y_true, y_pred)``."""

    def __init__(self, batch=True):
        super(multi_loss, self).__init__()
        self.batch = batch
        self.multi_loss = nn.NLLLoss()

    def forward(self, y_true, y_pred):
        # nn.NLLLoss takes (input, target). The other losses in this module
        # are called as (y_true, y_pred), so the prediction must go first.
        return self.multi_loss(y_pred, y_true)
diceloss.py
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the batch.

    ``input`` and ``target`` must have the same number of elements per
    sample; the batch size is taken from ``target``.

    The smoothing term is added once to numerator and denominator:
    ``(2 * I + smooth) / (sum(input) + sum(target) + smooth)``. The previous
    form ``2 * (I + smooth) / (...)`` let the coefficient exceed 1, so the
    loss was negative for good predictions (-1 for two empty masks).
    """

    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, target):
        N = target.size(0)

        # reshape (unlike view) also accepts non-contiguous tensors.
        input_flat = input.reshape(N, -1)
        target_flat = target.reshape(N, -1)

        intersection = (input_flat * target_flat).sum(1)
        dice = (2 * intersection + self.smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + self.smooth)

        return 1 - dice.sum() / N


class MulticlassDiceLoss(nn.Module):
    """Sum of per-class Dice losses.

    Requires a one-hot encoded target. ``input`` and ``target`` are indexed
    as (N, C, ...) where N is the batch size and C the number of classes.
    ``weights``, if given, is indexed by class and scales each class term.
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        C = target.shape[1]

        dice = DiceLoss()
        totalLoss = 0

        for i in range(C):
            diceLoss = dice(input[:, i], target[:, i])
            if weights is not None:
                diceLoss = diceLoss * weights[i]
            totalLoss += diceLoss

        return totalLoss
位置
7.模型调用文件,TransUnet.py
import torch
import torch.nn as nn
import functools
import torch.nn.functional as F
from .vit_seg_modeling import VisionTransformer as ViT_seg
from .vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg


def get_transNet(n_classes, img_size=256, vit_patches_size=16, vit_name='R50-ViT-B_16'):
    """Build a TransUNet segmentation network.

    Args:
        n_classes: number of output channels.
        img_size: side length the network is configured for (default 256,
            the value that used to be hard-coded).
        vit_patches_size: ViT patch size used to derive the patch grid.
        vit_name: key into ``CONFIGS_ViT_seg`` selecting the backbone.

    NOTE(review): the entry in ``CONFIGS_ViT_seg`` is modified in place, as
    before; if that object is shared, two calls with different settings
    affect each other -- confirm against vit_seg_modeling.
    """
    config_vit = CONFIGS_ViT_seg[vit_name]
    config_vit.n_classes = n_classes
    config_vit.n_skip = 3
    if vit_name.find('R50') != -1:
        grid = img_size // vit_patches_size
        config_vit.patches.grid = (grid, grid)
    net = ViT_seg(config_vit, img_size=img_size, num_classes=n_classes)
    return net


if __name__ == '__main__':
    # Feed the demo an input of the size the network was built for; the old
    # demo built a 256-pixel model and then passed it 512x512 tensors.
    demo_size = 256
    net = get_transNet(2, img_size=demo_size)
    img = torch.randn((2, 3, demo_size, demo_size))
    segments = net(img)
    print(segments.size())
位置,红框里的三个文件在原作者那里下载,链接https://github.com/Beckschen/TransUNet/tree/main/networks
8.预测代码inference.py
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd import Variable as V

import cv2
import os
import math
import warnings
from tqdm import tqdm
import numpy as np
from data import ImageFolder
from models.networks.TransUnet import get_transNet

BATCHSIZE_PER_CARD = 8


class TTAFrame():
    """Inference wrapper that predicts with 8-view test-time augmentation.

    The eight views are the image and its 90-degree rotation, each also
    flipped vertically, horizontally and both ways. Every prediction is
    mapped back to the original orientation and the eight maps are summed,
    so a threshold of 4.0 on the result is a mean score of 0.5.

    All ``test_one_img_from_path_*`` methods return the same sum and differ
    only in how many views go through the network per forward pass. Each one
    accepts either a file path or an already loaded HWC image array.
    """

    def __init__(self, net):
        self.net = net.cuda()
        self.net = torch.nn.DataParallel(
            self.net, device_ids=list(range(torch.cuda.device_count())))

    @staticmethod
    def _as_image(img_or_path):
        """Return an image array, reading it from disk if given a path."""
        if isinstance(img_or_path, str):
            img = cv2.imread(img_or_path)
            if img is None:
                raise FileNotFoundError('cannot read image: {}'.format(img_or_path))
            return img
        return img_or_path

    def _predict_tta(self, img_or_path, chunk, evalmode=True):
        """Sum the de-augmented predictions of the 8 views, ``chunk`` at a time.

        The image has to be square because it is stacked with its own
        90-degree rotation.
        NOTE(review): assumes the network returns one channel
        (batch x 1 x H x W) -- confirm for multi-class models.
        """
        img = self._as_image(img_or_path)
        if evalmode:
            # Inference must not run dropout or batch-statistics
            # normalisation; none of the direct entry points set this before.
            self.net.eval()

        pair = np.concatenate([img[None], np.rot90(img)[None]])
        views = np.concatenate(
            [pair, pair[:, ::-1], pair[:, :, ::-1], pair[:, ::-1, ::-1]])
        views = np.array(views.transpose(0, 3, 1, 2), np.float32) / 255.0 * 3.2 - 1.6

        outputs = []
        with torch.no_grad():
            for start in range(0, len(views), chunk):
                batch = torch.Tensor(np.ascontiguousarray(views[start:start + chunk])).cuda()
                outputs.append(self.net.forward(batch).squeeze(1).cpu().numpy())
        pred = np.concatenate(outputs)

        # Undo the flips in the order the views were built, then the rotation.
        merged = pred[0:2] + pred[2:4][:, ::-1] + pred[4:6][:, :, ::-1] + pred[6:8][:, ::-1, ::-1]
        return merged[0] + np.rot90(merged[1])[::-1, ::-1]

    def test_one_img_from_path(self, path, evalmode=True):
        """Predict one image, sizing the forward passes to the GPU count.

        Previously the largest-batch branch handed the path to a method that
        expects an array, and nothing was returned when no branch matched.
        """
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            chunk = 8
        elif batchsize >= 4:
            chunk = 4
        else:
            chunk = 2
        return self._predict_tta(path, chunk, evalmode)

    def test_one_img_from_path_8(self, img):
        """TTA prediction using four forward passes of two views each."""
        return self._predict_tta(img, 2)

    def test_one_img_from_path_4(self, path):
        """TTA prediction using four forward passes of two views each."""
        return self._predict_tta(path, 2)

    def test_one_img_from_path_2(self, path):
        """TTA prediction using two forward passes of four views each."""
        return self._predict_tta(path, 4)

    def test_one_img_from_path_1(self, img):
        """TTA prediction using a single forward pass of all eight views."""
        return self._predict_tta(img, 8)

    def load(self, path):
        """Load a checkpoint saved from a DataParallel-wrapped network.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        self.net.load_state_dict(torch.load(path))

    def tta_use(self, img):
        """Predict an HWC image through a flip-TTA wrapper (mean merge).

        NOTE(review): ``tta`` was never imported in this module; the calls
        match the ``ttach`` package, which is imported here on demand --
        confirm that this is the intended dependency.
        """
        import ttach as tta

        tta_model = tta.SegmentationTTAWrapper(
            self.net, tta.aliases.flip_transform(), merge_mode='mean')
        # HWC -> CHW. The previous (2, 1, 0) order also swapped height and
        # width, unlike every other conversion in this file.
        img = img.transpose(2, 0, 1)
        img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        img = torch.Tensor(img).cuda()
        self.net.eval()
        with torch.no_grad():
            mask = tta_model.forward(img.unsqueeze(0)).squeeze().cpu().data.numpy()
        return mask


if __name__ == "__main__":
    test_path = './TransUnet/dataset/build/test2/'
    save_path = './TransUnet/dataset/build/result/'
    model_path = './weights/trans_build_iou_best.th'

    os.makedirs(save_path, exist_ok=True)
    net = get_transNet(1)
    solver = TTAFrame(net)
    solver.load(model_path)

    imgs = os.listdir(test_path)
    for img in tqdm(imgs, ncols=20, total=len(imgs)):
        im = cv2.imread(os.path.join(test_path, img))
        if im is None:
            # Not an image (or unreadable): skip it instead of crashing mid-run.
            print('skip unreadable file:', img)
            continue
        pre = solver.test_one_img_from_path_8(im)
        # Sum of 8 views >= 4.0 means a mean score >= 0.5; save as 8-bit mask.
        result = np.where(pre >= 4.0, 255, 0).astype(np.uint8)
        cv2.imwrite(os.path.join(save_path, img), result)
9.精度评价eval.py
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
from osgeo import gdal
from sklearn.metrics import confusion_matrix


class IOUMetric:
    """
    Class to calculate mean-iou using fast_hist method
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        """Return the num_classes x num_classes histogram (rows: truth, cols: prediction).

        Pixels whose true label is outside ``[0, num_classes)`` are ignored.
        """
        mask = (label_true >= 0) & (label_true < self.num_classes)
        # np.bincount only accepts integers, so cast the predictions as well
        # as the labels.
        hist = np.bincount(
            self.num_classes * label_true[mask].astype(int) +
            label_pred[mask].astype(int),
            minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
        return hist

    def evaluate(self, predictions, gts):
        """Accumulate all pairs and return ``(acc, acc_cls, iou, miou, fwavacc)``.

        Classes that never occur yield NaN entries, which the nan-aware means
        skip; the resulting division warnings are suppressed.
        """
        for lp, lt in zip(predictions, gts):
            assert len(lp.flatten()) == len(lt.flatten())
            self.hist += self._fast_hist(lp.flatten(), lt.flatten())
        with np.errstate(divide='ignore', invalid='ignore'):
            # miou
            iou = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
            miou = np.nanmean(iou)
            # mean acc
            acc = np.diag(self.hist).sum() / self.hist.sum()
            acc_cls = np.nanmean(np.diag(self.hist) / self.hist.sum(axis=1))
            freq = self.hist.sum(axis=1) / self.hist.sum()
        fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
        return acc, acc_cls, iou, miou, fwavacc


def read_img(filename):
    """Read a raster with GDAL.

    Returns ``(projection, geotransform, width, height, data)``.
    """
    dataset = gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # dropping the reference closes the GDAL dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D or (bands, H, W) array as a GeoTIFF."""
    dtype_name = im_data.dtype.name
    if dtype_name in ('uint8', 'int8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data used to be written as unsigned.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    del dataset  # dropping the reference flushes and closes the file


def _binarize(arr):
    """Map every positive pixel to 1 and everything else to 0 (uint8)."""
    return (arr > 0).astype(np.uint8)


def _accumulate_confusion(labels, predicts):
    """Sum the 2x2 confusion matrices of all label/prediction pairs.

    ``labels=[0, 1]`` is required: without it sklearn returns a 1x1 matrix
    for an image holding a single class, and adding that to the 2x2
    accumulator broadcasts the count into all four cells.
    """
    total = np.zeros((2, 2))
    for lab, pre in zip(labels, predicts):
        total += confusion_matrix(lab.flatten(), pre.flatten(), labels=[0, 1])
    return total


def _scores(conf):
    """Return ``(precision, recall, accuracy, f1)`` of class 1 from a 2x2 matrix."""
    precision = conf[1][1] / (conf[0][1] + conf[1][1])
    recall = conf[1][1] / (conf[1][0] + conf[1][1])
    accuracy = (conf[0][0] + conf[1][1]) / conf.sum()
    f1_score = 2 * precision * recall / (precision + recall)
    return precision, recall, accuracy, f1_score


def eval_re(label_path, predict_path, eval_path):
    """Evaluate ``.tif`` predictions against same-named labels; append to a report.

    Only ``.tif`` files in ``predict_path`` are considered. Accuracy, recall
    and mean IoU are appended to ``eval_path``.
    """
    labels = []
    predicts = []
    for im in sorted(os.listdir(predict_path)):
        if not im.endswith('.tif'):
            continue
        _, _, _, _, label = read_img(os.path.join(label_path, im))
        _, _, _, _, pre = read_img(os.path.join(predict_path, im))
        labels.append(_binarize(label))
        predicts.append(_binarize(pre))

    el = IOUMetric(2)
    acc, acc_cls, iou, miou, fwavacc = el.evaluate(predicts, labels)

    precision, recall, accuracy, f1_score = _scores(_accumulate_confusion(labels, predicts))
    with open(eval_path, 'a') as f:
        f.write('accuracy: ' + str(accuracy) + '\n')
        f.write('recal: ' + str(recall) + '\n')
        # Terminate the last line so the next appended report starts cleanly.
        f.write('miou: ' + str(miou) + '\n')


def eval_func(label_path, predict_path):
    """Evaluate ``.png`` predictions against same-named labels and print metrics."""
    labels = []
    predicts = []
    for im in sorted(os.listdir(predict_path)):
        if not im.endswith('.png'):
            continue
        lab_path = os.path.join(label_path, im)
        pre_path = os.path.join(predict_path, im)
        label = cv2.imread(lab_path, 0)
        pre = cv2.imread(pre_path, 0)
        if label is None or pre is None:
            raise FileNotFoundError('cannot read {} or {}'.format(lab_path, pre_path))
        labels.append(_binarize(label))
        predicts.append(_binarize(pre))

    el = IOUMetric(2)
    acc, acc_cls, iou, miou, fwavacc = el.evaluate(predicts, labels)
    print('acc: ', acc)
    print('acc_cls: ', acc_cls)
    print('iou: ', iou)
    print('miou: ', miou)
    print('fwavacc: ', fwavacc)

    precision, recall, accuracy, f1_score = _scores(_accumulate_confusion(labels, predicts))
    print('class_accuracy: ', precision)
    print('class_recall: ', recall)
    print('accuracy: ', accuracy)
    print('f1_score: ', f1_score)


def eval_new(label_list, pre_list):
    """Print IoU and precision/recall/F1 for in-memory 0/1 label and prediction arrays."""
    el = IOUMetric(2)
    acc, acc_cls, iou, miou, fwavacc = el.evaluate(pre_list, label_list)
    print('acc: ', acc)
    print('iou: ', iou)
    print('miou: ', miou)
    print('fwavacc: ', fwavacc)

    precision, recall, accuracy, f1_score = _scores(_accumulate_confusion(label_list, pre_list))
    print('class_accuracy: ', precision)
    print('class_recall: ', recall)
    print('f1_score: ', f1_score)


if __name__ == "__main__":
    label_path = './data/build/test/labels/'
    predict_path = './data/build/test/re/'
    eval_func(label_path, predict_path)
我用的训练数据:
链接:https://pan.baidu.com/s/1487wODEn5bpTbmBw91Oavw
提取码:zow5
–来自百度网盘超级会员V5的分享
清理电脑文件发现原始的预模型我居然有下载,链接
链接:https://pan.baidu.com/s/1Og9eTorM6saM95uWITVqhg
提取码:29zz
–来自百度网盘超级会员V5的分享
以上二分类源码:
https://download.csdn.net/download/qq_20373723/85035195
多分类说明:改多分类只需要找到网络最后一层,把sigmoid 改成softmax就好了,数据加载的地方也要改下,别忘了训练的时候把类别改了
实在不想改了或者想要参考的话:
https://download.csdn.net/download/qq_20373723/83024925
补充:资源里inference.py的第15行请改成from models.networks.TransUnet import get_transNet,做实验时忘了把它改成和训练代码一致了
测试数据链接:
https://download.csdn.net/download/qq_20373723/83018556
有什么问题评论区或者私信都可以找我,看到了会回复的,另外,付费的资源尽量还是不下载吧,我觉得稍微懂一点的应该能独自完成的,参考博客肯定可以跑出来的
题外话:有什么新的比较好的网络可以评论推荐给我,我来复现贴出来大家一起用一用
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。