The dataset lives in the data folder: training images go in "imgs" and the labels in "masks". The data are medical-image cell-segmentation samples, so the training set contains both the original images and their corresponding label images.
Sample training image (train_x):

Sample label image (train_y):
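Before training, it is worth confirming that the images and masks actually pair up. Below is a minimal sanity-check sketch; it assumes each mask shares the base filename of its image, so adjust the matching rule to whatever naming convention your BasicDataset expects:

```python
import os

dir_img = 'data/imgs/'
dir_mask = 'data/masks/'

# Collect base filenames (without extension) on both sides
imgs = {os.path.splitext(f)[0] for f in os.listdir(dir_img) if not f.startswith('.')}
masks = {os.path.splitext(f)[0] for f in os.listdir(dir_mask) if not f.startswith('.')}

print(f'{len(imgs)} images, {len(masks)} masks')
missing_masks = imgs - masks
missing_imgs = masks - imgs
if missing_masks:
    print('Images without a mask:', sorted(missing_masks))
if missing_imgs:
    print('Masks without an image:', sorted(missing_imgs))
```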
train.py:

```python
import argparse
import logging
import os
import sys

import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from eval import eval_net
from unet import UNet

from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split

dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'checkpoints/'


def train_net(net,
              device,
              epochs=2,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                # Validate roughly 10 times per epoch; the division_step guard
                # avoids a ZeroDivisionError on very small datasets
                division_step = n_train // (10 * batch_size)
                if division_step > 0 and global_step % division_step == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=100,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?',
                        default=0.00001, help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')  # originally 0.5
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
```
(1) Once the training-set paths (images and labels) are set up, start training by entering "python train.py". Note: choose the batch size according to your machine's memory and hardware; the number of epochs and the learning rate also need to be tuned to your particular dataset. An example invocation is shown below.
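For example, the following command (the values are purely illustrative; the flags come from get_args() in train.py) trains for 40 epochs with batch size 2, the default learning rate, half-resolution images, and a 10% validation split:

```
python train.py -e 40 -b 2 -l 0.00001 -s 0.5 -v 10.0
```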
(2) A checkpoint is saved automatically to the "checkpoints" folder after every epoch. Pick the model that scores best on validation (the TensorBoard logs written under "runs" record the validation Dice curve), rename it to "MODEL.pth", and place it in the project root.
(3) The optimizer is RMSprop. For the single-class setup used here (n_classes=1), the training loss is binary cross-entropy with logits (plain cross-entropy when n_classes > 1), and validation is scored with the Dice coefficient; see the sketch below.
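eval.py is not shown in this post. For reference, here is a minimal sketch of what an eval_net compatible with the training loop above might look like; the dice_coeff helper and the exact averaging are assumptions, not necessarily the actual implementation:

```python
import torch
import torch.nn.functional as F


def dice_coeff(pred, target, eps=1e-6):
    # Dice = 2*|X∩Y| / (|X| + |Y|), computed on flattened binary masks
    pred = pred.float().view(-1)
    target = target.float().view(-1)
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)


def eval_net(net, loader, device):
    """Average Dice (n_classes == 1) or cross-entropy (n_classes > 1) over the validation set."""
    net.eval()
    tot = 0
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            mask_type = torch.float32 if net.n_classes == 1 else torch.long
            true_masks = batch['mask'].to(device=device, dtype=mask_type)
            masks_pred = net(imgs)
            if net.n_classes > 1:
                tot += F.cross_entropy(masks_pred, true_masks).item()
            else:
                # Threshold the sigmoid probabilities to a hard mask, then score it
                pred = (torch.sigmoid(masks_pred) > 0.5).float()
                tot += dice_coeff(pred, true_masks).item()
            n_batches += 1
    net.train()
    return tot / max(n_batches, 1)
```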
predict.py:

```python
import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset


def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        # Resize the probability map back to the original image height
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(full_img.size[1]),
                transforms.ToTensor()
            ]
        )

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='Filenames of input images', required=True)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed",
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks",
                        default=False)
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=0.5)  # originally 0.5

    return parser.parse_args()


def get_output_filenames(args):
    in_files = args.input
    out_files = []

    if not args.output:
        for f in in_files:
            pathsplit = os.path.splitext(f)
            out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
    elif len(in_files) != len(args.output):
        logging.error("Input files and output files are not of the same length")
        raise SystemExit()
    else:
        out_files = args.output

    return out_files


def mask_to_image(mask):
    # Optional manual binarization (set the value as you like):
    # for i in range(mask.shape[0]):
    #     for j in range(mask.shape[1]):
    #         if mask[i, j] > 0:
    #             mask[i, j] = 255
    # print(mask * 255)
    return Image.fromarray((mask * 255).astype(np.uint8))


if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=1, n_classes=1)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            result = mask_to_image(mask)
            result.save(out_files[i])
            logging.info("Mask saved to {}".format(out_files[i]))

        if args.viz:
            logging.info("Visualizing results for image {}, close to continue ...".format(fn))
            plot_img_and_mask(img, mask)
```
Notes:
(1) To run prediction, enter "python predict.py -i <path to the input image> -o output.jpg" (by default the output file is written to the project root; you can also change the output path). Multiple inputs can be passed at once, with outputs paired one-to-one.
(2) By default, prediction uses the model you renamed to "MODEL.pth" in the project root. Keep tuning the hyperparameters to get better segmentation results.
(3) If you want to build your own dataset, you can annotate it yourself with a labeling tool such as Labelme; a conversion sketch follows.
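For reference, here is a minimal sketch of rasterizing one Labelme JSON annotation into a binary mask with PIL. The file paths are hypothetical, and it assumes single-class polygon annotations:

```python
import json

from PIL import Image, ImageDraw

# Hypothetical input/output paths
json_path = 'data/raw/cell_0001.json'   # Labelme annotation file
mask_path = 'data/masks/cell_0001.png'  # output mask

with open(json_path) as f:
    ann = json.load(f)

# Labelme stores the image size and a list of labelled polygons
h, w = ann['imageHeight'], ann['imageWidth']
mask = Image.new('L', (w, h), 0)
draw = ImageDraw.Draw(mask)
for shape in ann['shapes']:
    # Each polygon is a list of [x, y] points; fill it with 255 (foreground)
    points = [tuple(p) for p in shape['points']]
    draw.polygon(points, outline=255, fill=255)

mask.save(mask_path)
```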