SAM (Segment Anything Model) is a segmentation model developed by Meta AI and is often regarded as the first foundation model for computer vision. SAM was trained on a huge corpus of millions of images and billions of masks, which makes it extremely powerful: it can produce accurate segmentation masks for a very wide range of images.
SAM usually performs very well on natural images, but in specialized domains such as medical imaging or remote sensing its results are often unsatisfactory, because such data is largely missing from its training set. Fine-tuning SAM on a domain-specific dataset is therefore well worth doing.
(1) Install segment anything:
git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
python setup.py install
(2) Install the lightning package, a lightweight PyTorch wrapper for high-performance AI research. The fine-tuning in this post is built on it:
pip install lightning
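A SAM checkpoint is also needed for the model.checkpoint field configured below. If you do not have one yet, the official ViT-B weights can be downloaded, for example, like this (URL taken from the segment-anything README):

# Fetch the official ViT-B SAM weights; the checkpoint path in the config can point at this file.
import urllib.request

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
urllib.request.urlretrieve(url, "sam_vit_b_01ec64.pth")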
Dataset download: https://han-seg2023.grand-challenge.org/. It is a multi-organ medical imaging dataset; of course, you can also use your own dataset.
The configuration file specifies which parts of SAM are trained as well as the dataset-related settings, such as the dataset location. The full configuration (in config.py) is:
from box import Box

config = {
    "num_devices": 1,
    "batch_size": 6,
    "num_workers": 4,
    "num_epochs": 20,
    "save_interval": 2,
    "resume": None,
    "out_dir": "output directory for model weights",
    "opt": {
        "learning_rate": 8e-4,
        "weight_decay": 1e-4,
        "decay_factor": 10,
        "steps": [60000, 86666],
        "warmup_steps": 250,
    },
    "model": {
        "type": 'vit_b',
        "checkpoint": "path to the SAM checkpoint",
        "freeze": {
            "image_encoder": True,
            "prompt_encoder": True,
            "mask_decoder": False,  # set to False so the mask decoder is fine-tuned
        },
    },
    "dataset": {
        "root_dir": "root directory of the dataset",
        "sample_num": 4,
        "target_size": 1024
    }
}

cfg = Box(config)
The freeze section determines which parts of SAM are frozen and excluded from training, while the dataset section holds the dataset settings: sample_num is the number of sampled prompt points and target_size is the image size fed into SAM.
The box package is used here; it can be installed with:
pip install python-box
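box simply wraps nested dictionaries so they can be accessed with dot notation, which is why the configuration is later read as cfg.model.type instead of config["model"]["type"]. A tiny illustration:

from box import Box

cfg = Box({"model": {"type": "vit_b"}})
print(cfg.model.type)   # "vit_b", equivalent to cfg["model"]["type"]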
This part decides which data is selected for training when the dataset is loaded; here I train on the mandible organ.
Since the data is 3D, each volume is sliced along the axial direction so that the 3D data becomes a set of 2D images. The code for this part is:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset


class HaNDataset(Dataset):
    def __init__(self, cfg):
        super().__init__()
        self.gt_path = os.path.join(cfg.dataset.root_dir, "oar_3d")
        self.img_path = os.path.join(cfg.dataset.root_dir, "ct_3d")
        # file lists
        self.img_file_list = sorted(os.listdir(self.img_path))
        self.gt_file_list = sorted(os.listdir(self.gt_path))
        # organ categories
        self.category = [7]
        self.cat2names = {7: "mandible"}
        # data list containing all slices
        self.data_list = []
        for i in range(len(self.img_file_list)):
            img_file_path = os.path.join(self.img_path, self.img_file_list[i])
            gt_file_path = os.path.join(self.gt_path, self.gt_file_list[i])
            img_data = nib.load(img_file_path).get_fdata()
            gt_data = nib.load(gt_file_path).get_fdata()
            axial_num = img_data.shape[2]
            for a in range(axial_num):
                a_gt_data = gt_data[:, :, a]
                for c in self.category:
                    region = (a_gt_data == c)
                    if np.sum(region) > 0:
                        self.data_list.append([i, a, c])
        print(f"Data size is:{len(self.data_list)}")
        # SAM expects inputs of this size
        self.target_size = cfg.dataset.target_size
        # number of positive/negative sample points
        self.sample_point_num = cfg.dataset.sample_num

    def __len__(self):
        return len(self.data_list)
The HaN dataset is stored as nii files whose intensity values roughly span 0-2000, while images use the range 0-255, so the values have to be clipped and remapped. The input to SAM should be 1024*1024, so each slice also has to be resized to the target size. In addition, HaN does not provide box or point prompts, so they have to be generated automatically from the masks. These pieces are implemented as follows (all inside HaNDataset):
    def convert_to_three_channels(self, image):
        # create a 3-channel image array with the same spatial size
        three_channel_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
        # copy the original single-channel image into every channel
        for i in range(3):
            three_channel_image[:, :, i] = image
        return three_channel_image

    def __getitem__(self, idx):
        data_id = self.data_list[idx]
        f_id = data_id[0]
        axial_id = data_id[1]
        category_id = data_id[2]
        name = self.cat2names[category_id]
        img_data_path = os.path.join(self.img_path, self.img_file_list[f_id])
        gt_data_path = os.path.join(self.gt_path, self.gt_file_list[f_id])
        # nii intensities are roughly 0-2000, which does not match the image range
        img_data = nib.load(img_data_path).get_fdata()
        # clip to a CT window, then remap to 0-255
        img_data[img_data < (50 + 1024 - 200)] = (50 + 1024 - 200)
        img_data[img_data > (50 + 1024 + 200)] = (50 + 1024 + 200)
        img_data = (img_data - (50 + 1024 - 200)) / 400.0 * 255.0
        img_data = img_data[:, :, axial_id]
        img_data = self.convert_to_three_channels(img_data)
        all_gt_data = nib.load(gt_data_path).get_fdata()[:, :, axial_id]
        gt_data = np.zeros_like(all_gt_data)
        gt_data[all_gt_data == category_id] = 1
        # resize image and ground truth to the target size
        org_size = gt_data.shape
        transforms = train_transforms(self.target_size, org_size[0], org_size[1])
        augments = transforms(image=img_data, mask=gt_data)
        img_data, gt_data = augments['image'].to(torch.float32), augments['mask'].to(torch.int64)
        # get the box prompt (max_pixel is 0 at validation time)
        bbox_data = get_boxes_from_mask(gt_data, max_pixel=0)[0]
        # get the point prompts
        point_coords, point_labels = init_point_sampling(gt_data, self.sample_point_num)
        return {
            "org_size": torch.tensor(org_size),
            "category": name,
            "image": img_data,
            "label": gt_data,
            "bbox": bbox_data,
            "point_coords": point_coords,
            "point_labels": point_labels
        }
The code that computes the box and point prompts and resizes the image is:
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage.measure import label, regionprops


def init_point_sampling(mask, get_point=1):
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()
    # Get coordinates of foreground/background pixels
    fg_coords = np.argwhere(mask == 1)[:, ::-1]
    bg_coords = np.argwhere(mask == 0)[:, ::-1]
    fg_size = len(fg_coords)
    bg_size = len(bg_coords)
    if get_point == 1:
        if fg_size > 0:
            index = np.random.randint(fg_size)
            fg_coord = fg_coords[index]
            label = 1
        else:
            index = np.random.randint(bg_size)
            fg_coord = bg_coords[index]
            label = 0
        return torch.as_tensor([fg_coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)
    else:
        num_fg = get_point // 2
        num_bg = get_point - num_fg
        fg_indices = np.random.choice(fg_size, size=num_fg, replace=True)
        bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)
        fg_coords = fg_coords[fg_indices]
        bg_coords = bg_coords[bg_indices]
        coords = np.concatenate([fg_coords, bg_coords], axis=0)
        labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)
        indices = np.random.permutation(get_point)
        coords = torch.as_tensor(coords[indices], dtype=torch.float)
        labels = torch.as_tensor(labels[indices], dtype=torch.int)
        return coords, labels


def get_boxes_from_mask(mask, box_num=1, std=0.1, max_pixel=5):
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()
    label_img = label(mask)
    regions = regionprops(label_img)
    # Iterate through all regions and get the bounding box coordinates
    boxes = [tuple(region.bbox) for region in regions]
    # If the generated number of boxes is greater than the number of categories,
    # sort them by region area and select the top n regions
    if len(boxes) >= box_num:
        sorted_regions = sorted(regions, key=lambda x: x.area, reverse=True)[:box_num]
        boxes = [tuple(region.bbox) for region in sorted_regions]
    # If the generated number of boxes is less than the number of categories,
    # duplicate the existing boxes
    elif len(boxes) < box_num:
        num_duplicates = box_num - len(boxes)
        boxes += [boxes[i % len(boxes)] for i in range(num_duplicates)]
    # Perturb each bounding box with noise
    noise_boxes = []
    for box in boxes:
        y0, x0, y1, x1 = box
        width, height = abs(x1 - x0), abs(y1 - y0)
        # Calculate the standard deviation and maximum noise value
        noise_std = min(width, height) * std
        max_noise = min(max_pixel, int(noise_std * 5))
        # Add random noise to each coordinate (randint fails when max_noise is 0)
        try:
            noise_x = np.random.randint(-max_noise, max_noise)
        except ValueError:
            noise_x = 0
        try:
            noise_y = np.random.randint(-max_noise, max_noise)
        except ValueError:
            noise_y = 0
        x0, y0 = x0 + noise_x, y0 + noise_y
        x1, y1 = x1 + noise_x, y1 + noise_y
        noise_boxes.append((x0, y0, x1, y1))
    return torch.as_tensor(noise_boxes, dtype=torch.float)


def train_transforms(img_size, ori_h, ori_w):
    transforms = []
    transforms.append(A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST))
    transforms.append(ToTensorV2(p=1.0))
    return A.Compose(transforms, p=1.)
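As a quick illustration (not part of the original pipeline), the two prompt helpers can be run on a toy mask to see the shapes they produce:

# Toy mask: a single rectangular "organ" region.
mask = torch.zeros(256, 256, dtype=torch.int64)
mask[100:150, 80:160] = 1

box = get_boxes_from_mask(mask, max_pixel=0)[0]            # tensor with 4 values: x0, y0, x1, y1
coords, labels = init_point_sampling(mask, get_point=4)    # coords: (4, 2), labels: (4,)
print(box, coords.shape, labels.shape)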
Since segment anything is already installed, its modules can be called directly and assembled into a mask-generation pipeline. The code for this part is:
import torch.nn as nn
import torch.nn.functional as F
from segment_anything import sam_model_registry
from segment_anything import SamPredictor


class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def setup(self):
        self.model = sam_model_registry[self.cfg.model.type](checkpoint=self.cfg.model.checkpoint)
        self.model.train()
        if self.cfg.model.freeze.image_encoder:
            for name, param in self.model.image_encoder.named_parameters():
                param.requires_grad = False
        if self.cfg.model.freeze.prompt_encoder:
            for name, param in self.model.prompt_encoder.named_parameters():
                param.requires_grad = False
        # freeze the mask decoder parameters
        if self.cfg.model.freeze.mask_decoder:
            for name, param in self.model.mask_decoder.named_parameters():
                param.requires_grad = False

    def forward(self, images, bboxes, point_coords=None, point_labels=None):
        _, _, H, W = images.shape
        image_embeddings = self.model.image_encoder(images)
        pred_masks = []
        ious = []
        # point prompts are passed as (point coords, point labels):
        # coords: B,N,2  labels: B,N
        # process the batch one sample at a time
        for embedding, bbox, coord, label in zip(image_embeddings, bboxes, point_coords, point_labels):
            bbox = bbox.unsqueeze(0)
            coord = coord.unsqueeze(0)
            label = label.unsqueeze(0)
            point = (coord, label)
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=point,
                boxes=bbox,
                masks=None,
            )
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=embedding.unsqueeze(0),
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            masks = F.interpolate(
                low_res_masks,
                (H, W),
                mode="bilinear",
                align_corners=False,
            )
            pred_masks.append(masks.squeeze(1))
            ious.append(iou_predictions)
        return pred_masks, ious

    def get_predictor(self):
        return SamPredictor(self.model)
The setup method decides which parameters are trained and which are not.
First, configure lightning:
import os
import time

import torch
from torch.utils.data import DataLoader
from box import Box

import lightning as L
from lightning.fabric.loggers import TensorBoardLogger

from config import cfg

fabric = L.Fabric(accelerator="auto",
                  devices=cfg.num_devices,
                  strategy="auto",
                  loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")])
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)
Then create the model and load the dataset:
with fabric.device:
    model = Model(cfg)
    model.setup()

train_data = HaNDataset(cfg)
train_loader = DataLoader(train_data, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=True)
train_loader = fabric.setup_dataloaders(train_loader)
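Before creating the optimizer, it can be worth checking (an optional step, not part of the original post) that the freeze settings took effect:

# Optional sanity check: count the parameters that will actually be updated.
trainable = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.model.parameters())
print(f"trainable parameters: {trainable} / {total}")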
Next, create the optimizer:
def configure_opt(cfg: Box, model: Model):
    def lr_lambda(step):
        if step < cfg.opt.warmup_steps:
            return step / cfg.opt.warmup_steps
        elif step < cfg.opt.steps[0]:
            return 1.0
        elif step < cfg.opt.steps[1]:
            return 1 / cfg.opt.decay_factor
        else:
            return 1 / (cfg.opt.decay_factor**2)

    optimizer = torch.optim.Adam(model.model.parameters(),
                                 lr=cfg.opt.learning_rate,
                                 weight_decay=cfg.opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler


optimizer, scheduler = configure_opt(cfg, model)
model, optimizer = fabric.setup(model, optimizer)
Finally, iterate over the dataset and train. The loss functions used are Focal loss, Dice loss and IoU loss:
def train_sam(
    cfg: Box,
    fabric: L.Fabric,
    model: Model,
    optimizer: _FabricOptimizer,
    scheduler: _FabricOptimizer,
    train_dataloader: DataLoader,
):
    """The SAM training loop."""
    focal_loss = FocalLoss()
    dice_loss = DiceLoss()
    # resume training from where it was interrupted
    start_epoch = 1
    if cfg.resume:
        map_location = 'cuda:%d' % fabric.global_rank
        checkpoint = torch.load(cfg.resume, map_location={'cuda:0': map_location})
        start_epoch = checkpoint['epoch']
        network = checkpoint['network']
        opt = checkpoint['optimizer']
        sche = checkpoint['scheduler']
        model.model.load_state_dict(network)
        optimizer.load_state_dict(opt)
        scheduler.load_state_dict(sche)
        fabric.print(f"resume from:{cfg.resume}")

    for epoch in range(start_epoch, cfg.num_epochs):
        batch_time = AverageMeter(name="batch_time")
        data_time = AverageMeter(name="data_time")
        focal_losses = AverageMeter(name="focal_losses")
        dice_losses = AverageMeter(name="dice_losses")
        iou_losses = AverageMeter(name="iou_losses")
        total_losses = AverageMeter(name="total_losses")
        end = time.time()

        # save the model
        if epoch % cfg.save_interval == 0:
            fabric.print(f"Saving checkpoint to {cfg.out_dir}")
            state_dict = model.model.state_dict()
            checkpoint = {
                'epoch': epoch,
                'network': state_dict,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            # with multiple GPUs, only save on the rank-0 device
            if fabric.global_rank == 0:
                torch.save(checkpoint, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-ckpt.pth"))

        for iter, data in enumerate(train_dataloader):
            data_time.update(time.time() - end)
            images = data["image"]
            gt_masks = data["label"]
            bboxes = data["bbox"]
            batch_size = images.shape[0]
            pred_masks, iou_predictions = model(images, bboxes, data["point_coords"], data["point_labels"])
            num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
            loss_focal = torch.tensor(0., device=fabric.device)
            loss_dice = torch.tensor(0., device=fabric.device)
            loss_iou = torch.tensor(0., device=fabric.device)
            for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions):
                batch_iou = calc_iou(pred_mask, gt_mask)
                loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
                loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
                loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks

            loss_total = 20. * loss_focal + loss_dice + loss_iou
            optimizer.zero_grad()
            fabric.backward(loss_total)
            optimizer.step()
            scheduler.step()
            batch_time.update(time.time() - end)
            end = time.time()

            focal_losses.update(loss_focal.item(), batch_size)
            dice_losses.update(loss_dice.item(), batch_size)
            iou_losses.update(loss_iou.item(), batch_size)
            total_losses.update(loss_total.item(), batch_size)

            fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
                         f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
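FocalLoss, DiceLoss, calc_iou and AverageMeter are used above but never shown in this post (the lightning-sam project ships similar helpers in its losses.py and utils.py). The following is only a minimal sketch that matches the call signatures used in the training loop, not necessarily the exact implementation the author used:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss on raw logits, normalized by the number of masks in the batch."""
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, num_masks):
        inputs = inputs.reshape(-1)              # logits, e.g. (1, H, W) -> (H*W,)
        targets = targets.reshape(-1).float()    # binary mask, (H, W) -> (H*W,)
        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p = torch.sigmoid(inputs)
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * ce).mean() / num_masks


class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, normalized by the number of masks."""
    def forward(self, inputs, targets, num_masks, smooth: float = 1.0):
        inputs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1).float()
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return (1 - dice) / num_masks


def calc_iou(pred_mask, gt_mask, eps: float = 1e-7):
    # Hard IoU between the thresholded prediction and the ground truth; the shape
    # (num_masks_per_image, 1) matches the decoder's iou_predictions so it can be
    # used as the regression target for SAM's IoU head.
    pred = (torch.sigmoid(pred_mask) > 0.5).float().reshape(pred_mask.shape[0], -1)
    gt = gt_mask.reshape(1, -1).float().expand_as(pred)
    intersection = (pred * gt).sum(-1)
    union = pred.sum(-1) + gt.sum(-1) - intersection
    return ((intersection + eps) / (union + eps)).unsqueeze(1)


class AverageMeter:
    """Tracks the latest value and the running average of a metric."""
    def __init__(self, name=""):
        self.name = name
        self.val = self.avg = self.sum = self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count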
With the steps above, SAM can be fine-tuned. When only the mask decoder is fine-tuned, GPU memory usage is around 17 GB.
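After training, the fine-tuned weights can be loaded back and used for prediction through the SamPredictor returned by Model.get_predictor. The snippet below is a rough sketch with placeholder paths and inputs, assuming a 3-channel uint8 slice prepared the same way as in HaNDataset:

import numpy as np
import torch

# Rebuild the model and load a fine-tuned checkpoint (the path is a placeholder).
model = Model(cfg)
model.setup()
checkpoint = torch.load("path/to/epoch-xxxxxx-ckpt.pth", map_location="cpu")
model.model.load_state_dict(checkpoint["network"])
model.model.eval()

predictor = model.get_predictor()

# `image` should be an (H, W, 3) uint8 array, e.g. a windowed CT slice converted
# with convert_to_three_channels; `box` is a placeholder x0, y0, x1, y1 prompt.
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
box = np.array([300, 300, 700, 700])

predictor.set_image(image)
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores)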