
[Deep Learning Notes 02] Fine-tuning SAM on Your Own Dataset

Preface

SAM (Segment Anything Model) is a segmentation model developed by Meta AI. It is often regarded as the first foundation model for computer vision. SAM was trained on a huge corpus containing millions of images and billions of masks, which makes it extremely powerful: it can produce accurate segmentation masks for a wide variety of images.

SAM usually performs very well on natural images, but in specialised domains such as medical imaging or remote sensing, where its training set contains little data, the results are not ideal. Fine-tuning SAM on a domain-specific dataset is therefore well worth doing.

Preparation

(1) Install segment anything:

git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
python setup.py install

(2) Install the lightning package, a lightweight PyTorch wrapper for high-performance AI research. This post fine-tunes SAM on top of it:

pip install lightning

Dataset download link: https://han-seg2023.grand-challenge.org/. It is a multi-organ medical imaging dataset; of course, you can also use your own dataset.

Steps

1. Create the configuration file

The configuration file specifies which parts of SAM are trained, as well as dataset-related settings such as the dataset location. The full configuration (in config.py) is:

from box import Box
config = {
    "num_devices": 1,
    "batch_size": 6,
    "num_workers": 4,
    "num_epochs": 20,
    "save_interval": 2,
    "resume": None,
    "out_dir": "模型权重输出地址",
    "opt": {
        "learning_rate": 8e-4,
        "weight_decay": 1e-4,
        "decay_factor": 10,
        "steps": [60000, 86666],
        "warmup_steps": 250,
    },
    "model": {
        "type": 'vit_b',
        "checkpoint": "SAM的权重地址",
        "freeze": {
            "image_encoder": True,
            "prompt_encoder": True,
            "mask_decoder": True,
        },
    },
    "dataset": {
        "root_dir": "数据集的根目录",
        "sample_num": 4,
        "target_size": 1024
    }
}
cfg = Box(config)

The freeze section decides which parts of SAM are frozen (not trained), and dataset holds the dataset-related settings: sample_num is the number of point prompts to sample, and target_size is the size of the image fed into SAM.

The box package is used here; it can be installed with:

pip install python-box
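
Box simply wraps a nested dict so that the configuration can be read with attribute access instead of string keys. A quick illustration (the values here are just for demonstration):

from box import Box

cfg = Box({"model": {"type": "vit_b"}, "dataset": {"target_size": 1024}})
print(cfg.model.type)            # "vit_b", equivalent to cfg["model"]["type"]
print(cfg.dataset.target_size)   # 1024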

2. Build the dataset

This part decides which data are selected for training when the dataset is loaded; here I train on the mandible.

Since the data are 3D volumes, they are sliced into 2D images. The code for this part is:

import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset

class HaNDataset(Dataset):
    def __init__(self, cfg):
        super().__init__()
        self.gt_path = os.path.join(cfg.dataset.root_dir, "oar_3d")
        self.img_path = os.path.join(cfg.dataset.root_dir, "ct_3d")
        # File lists
        self.img_file_list = sorted(os.listdir(self.img_path))
        self.gt_file_list = sorted(os.listdir(self.gt_path))
        # Organ categories
        self.category = [7]
        self.cat2names = {7: "mandible"}
        # Data list containing every usable slice
        self.data_list = []
        for i in range(len(self.img_file_list)):
            img_file_path = os.path.join(self.img_path, self.img_file_list[i])
            gt_file_path = os.path.join(self.gt_path, self.gt_file_list[i])
            img_data = nib.load(img_file_path).get_fdata()
            gt_data = nib.load(gt_file_path).get_fdata()
            axial_num = img_data.shape[2]
            for a in range(axial_num):
                a_gt_data = gt_data[:, :, a]
                # Keep only slices that actually contain the target organ
                for c in self.category:
                    region = (a_gt_data == c)
                    if np.sum(region) > 0:
                        self.data_list.append([i, a, c])
        print(f"Data size is: {len(self.data_list)}")
        # SAM expects inputs of this size
        self.target_size = cfg.dataset.target_size
        # Number of positive/negative sample points
        self.sample_point_num = cfg.dataset.sample_num

    def __len__(self):
        return len(self.data_list)

The HaN data come as nii files whose intensity range is roughly 0-2000, while images use 0-255, so the intensities have to be clipped and remapped. The code below clips to the window [1024 + 50 - 200, 1024 + 50 + 200] and rescales that 400-unit window to 0-255.
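
As a quick numerical check of that windowing (a toy example, not part of the training code), the same bounds used in __getitem__ map stored intensities to 0-255 like this:

import numpy as np

hu = np.array([500.0, 874.0, 1074.0, 1274.0, 1800.0])   # example stored intensities
lo, hi = 50 + 1024 - 200, 50 + 1024 + 200                # the window used in __getitem__
mapped = (np.clip(hu, lo, hi) - lo) / 400.0 * 255.0
print(mapped)  # [  0.    0.  127.5 255.  255. ]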

The image fed into SAM should be 1024*1024, so each slice is resized to the target size.

Besides that, since HaN provides no box or point prompts, the prompts also have to be derived automatically from the mask.

These parts are implemented as follows (all inside HaNDataset):

def convert_to_three_channels(self, image):
    # Create a 3-channel image array with the same spatial size
    three_channel_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    # Copy the single-channel image into every channel
    for i in range(3):
        three_channel_image[:, :, i] = image
    return three_channel_image

def __getitem__(self, idx):
    data_id = self.data_list[idx]
    f_id = data_id[0]
    axial_id = data_id[1]
    category_id = data_id[2]
    name = self.cat2names[category_id]
    img_data_path = os.path.join(self.img_path, self.img_file_list[f_id])
    gt_data_path = os.path.join(self.gt_path, self.gt_file_list[f_id])
    # The nii files store values in roughly 0-2000, which does not match the 0-255 image range
    img_data = nib.load(img_data_path).get_fdata()
    # Clip to the CT window [1024 + 50 - 200, 1024 + 50 + 200] and rescale to 0-255
    img_data[img_data < (50 + 1024 - 200)] = (50 + 1024 - 200)
    img_data[img_data > (50 + 1024 + 200)] = (50 + 1024 + 200)
    img_data = (img_data - (50 + 1024 - 200)) / 400.0 * 255.0
    img_data = img_data[:, :, axial_id]
    img_data = self.convert_to_three_channels(img_data)
    all_gt_data = nib.load(gt_data_path).get_fdata()[:, :, axial_id]
    gt_data = np.zeros_like(all_gt_data)
    gt_data[all_gt_data == category_id] = 1
    # Resize image and ground truth to the target size
    org_size = gt_data.shape
    transforms = train_transforms(self.target_size, org_size[0], org_size[1])
    augments = transforms(image=img_data, mask=gt_data)
    img_data, gt_data = augments['image'].to(torch.float32), augments['mask'].to(torch.int64)
    # Get the box prompt (max_pixel is 0, i.e. no noise, during validation)
    bbox_data = get_boxes_from_mask(gt_data, max_pixel=0)[0]
    # Get the point prompts
    point_coords, point_labels = init_point_sampling(gt_data, self.sample_point_num)
    return {
        "org_size": torch.tensor(org_size),
        "category": name,
        "image": img_data,
        "label": gt_data,
        "bbox": bbox_data,
        "point_coords": point_coords,
        "point_labels": point_labels
    }

The code for obtaining boxes and points, and for resizing the image, is:

import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from skimage.measure import label, regionprops

def init_point_sampling(mask, get_point=1):
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()
    # Coordinates of foreground (1) and background (0) pixels, in (x, y) order
    fg_coords = np.argwhere(mask == 1)[:, ::-1]
    bg_coords = np.argwhere(mask == 0)[:, ::-1]
    fg_size = len(fg_coords)
    bg_size = len(bg_coords)
    if get_point == 1:
        if fg_size > 0:
            index = np.random.randint(fg_size)
            fg_coord = fg_coords[index]
            label = 1
        else:
            index = np.random.randint(bg_size)
            fg_coord = bg_coords[index]
            label = 0
        return torch.as_tensor([fg_coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)
    else:
        num_fg = get_point // 2
        num_bg = get_point - num_fg
        fg_indices = np.random.choice(fg_size, size=num_fg, replace=True)
        bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)
        fg_coords = fg_coords[fg_indices]
        bg_coords = bg_coords[bg_indices]
        coords = np.concatenate([fg_coords, bg_coords], axis=0)
        labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)
        indices = np.random.permutation(get_point)
        coords, labels = torch.as_tensor(coords[indices], dtype=torch.float), torch.as_tensor(labels[indices], dtype=torch.int)
        return coords, labels

def get_boxes_from_mask(mask, box_num=1, std=0.1, max_pixel=5):
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()
    label_img = label(mask)
    regions = regionprops(label_img)
    # Iterate through all regions and get the bounding box coordinates
    boxes = [tuple(region.bbox) for region in regions]
    # If the generated number of boxes is greater than the number of categories,
    # sort them by region area and select the top n regions
    if len(boxes) >= box_num:
        sorted_regions = sorted(regions, key=lambda x: x.area, reverse=True)[:box_num]
        boxes = [tuple(region.bbox) for region in sorted_regions]
    # If the generated number of boxes is less than the number of categories,
    # duplicate the existing boxes
    elif len(boxes) < box_num:
        num_duplicates = box_num - len(boxes)
        boxes += [boxes[i % len(boxes)] for i in range(num_duplicates)]
    # Perturb each bounding box with noise
    noise_boxes = []
    for box in boxes:
        y0, x0, y1, x1 = box
        width, height = abs(x1 - x0), abs(y1 - y0)
        # Calculate the standard deviation and maximum noise value
        noise_std = min(width, height) * std
        max_noise = min(max_pixel, int(noise_std * 5))
        # Add random noise to each coordinate (randint raises when max_noise is 0)
        try:
            noise_x = np.random.randint(-max_noise, max_noise)
        except ValueError:
            noise_x = 0
        try:
            noise_y = np.random.randint(-max_noise, max_noise)
        except ValueError:
            noise_y = 0
        x0, y0 = x0 + noise_x, y0 + noise_y
        x1, y1 = x1 + noise_x, y1 + noise_y
        noise_boxes.append((x0, y0, x1, y1))
    return torch.as_tensor(noise_boxes, dtype=torch.float)

def train_transforms(img_size, ori_h, ori_w):
    transforms = []
    transforms.append(A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST))
    transforms.append(ToTensorV2(p=1.0))
    return A.Compose(transforms, p=1.)
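
With the dataset class and the prompt helpers in place, a quick sanity check is to pull one sample and look at the shapes SAM will receive. This is only a sketch: it assumes cfg points at the downloaded HaN-Seg data and that HaNDataset is importable from wherever you defined it.

from config import cfg

dataset = HaNDataset(cfg)
sample = dataset[0]
print(sample["image"].shape)         # torch.Size([3, 1024, 1024])
print(sample["label"].shape)         # torch.Size([1024, 1024])
print(sample["bbox"])                # tensor([x0, y0, x1, y1]) around the mandible
print(sample["point_coords"].shape)  # torch.Size([4, 2]) with sample_num = 4
print(sample["point_labels"])        # 1 = foreground point, 0 = background point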

3. Build the SAM model

Since segment anything is already installed, we can directly call its modules and assemble them into a mask-generation pipeline. The code is:

import torch.nn as nn
import torch.nn.functional as F
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
    def setup(self):
        self.model = sam_model_registry[self.cfg.model.type](checkpoint=self.cfg.model.checkpoint)
        self.model.train()
        if self.cfg.model.freeze.image_encoder:
            for name, param in self.model.image_encoder.named_parameters():
                param.requires_grad = False
        if self.cfg.model.freeze.prompt_encoder:
            for name, param in self.model.prompt_encoder.named_parameters():
                param.requires_grad = False
        # Freeze the mask decoder parameters
        if self.cfg.model.freeze.mask_decoder:
            for name, param in self.model.mask_decoder.named_parameters():
                param.requires_grad = False
    def forward(self, images, bboxes, org_size, point_coords = None, point_labels = None):
        _, _, H, W = images.shape
        image_embeddings = self.model.image_encoder(images)
        pred_masks = []
        ious = []
        # Point prompts are passed as (coords, labels): coords has shape (B, N, 2), labels (B, N)
        # Process the batch one sample at a time
        for embedding, bbox, coord, label in zip(image_embeddings, bboxes, point_coords, point_labels):
            bbox = bbox.unsqueeze(0)
            coord = coord.unsqueeze(0)
            label = label.unsqueeze(0)
            point = (coord, label)
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=point,
                boxes=bbox,
                masks=None,
            )
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=embedding.unsqueeze(0),
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            masks = F.interpolate(
                low_res_masks,
                (H, W),
                mode="bilinear",
                align_corners=False,
            )
            pred_masks.append(masks.squeeze(1))
            ious.append(iou_predictions)
        return pred_masks, ious
    def get_predictor(self):
        return SamPredictor(self.model)

The setup method decides which parameters are trained and which are not.
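
To verify that the freeze flags behave as expected, a small sketch (assuming the config and the Model class above) that counts trainable parameters:

from config import cfg

model = Model(cfg)
model.setup()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable / 1e6:.1f}M / total: {total / 1e6:.1f}M")
# With only the mask decoder left unfrozen, on the order of 4M of the roughly 94M vit_b parameters remain trainable.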

4. Train on the data

First, configure things with lightning:

import lightning as L
from lightning.fabric.loggers import TensorBoardLogger
from config import cfg

fabric = L.Fabric(accelerator="auto",
                  devices=cfg.num_devices,
                  strategy="auto",
                  loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")])
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)

Then create the model and load the dataset:

from torch.utils.data import DataLoader

with fabric.device:
    model = Model(cfg)
    model.setup()

train_data = HaNDataset(cfg)
train_loader = DataLoader(train_data, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=True)
train_loader = fabric.setup_dataloaders(train_loader)

Next, create the optimizer and learning-rate scheduler:

def configure_opt(cfg: Box, model: Model):
    def lr_lambda(step):
        if step < cfg.opt.warmup_steps:
            return step / cfg.opt.warmup_steps
        elif step < cfg.opt.steps[0]:
            return 1.0
        elif step < cfg.opt.steps[1]:
            return 1 / cfg.opt.decay_factor
        else:
            return 1 / (cfg.opt.decay_factor**2)
    optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
optimizer, scheduler = configure_opt(cfg, model)
model, optimizer = fabric.setup(model, optimizer)

Finally, iterate over the dataset and train. The loss functions used here are focal loss, Dice loss and IoU loss:
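
The FocalLoss, DiceLoss, calc_iou and AverageMeter helpers used below are not listed in this post (they presumably follow the lightning-sam reference linked at the end). A minimal sketch that is consistent with how they are called in the training loop might look like this; the alpha/gamma values and reductions are assumptions, not the exact reference implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

def _match_shape(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # The loop pairs an (N, H, W) prediction with an (H, W) label, so lift the
    # label to the prediction's shape when needed.
    if targets.dim() == inputs.dim() - 1:
        targets = targets.unsqueeze(0)
    return targets.float()

class FocalLoss(nn.Module):
    # Focal loss on the raw mask logits; alpha/gamma are assumed defaults.
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets, num_masks):
        targets = _match_shape(inputs, targets)
        prob = inputs.sigmoid()
        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * ce * (1 - p_t) ** self.gamma
        return loss.flatten(1).mean(1).sum() / num_masks

class DiceLoss(nn.Module):
    # Soft Dice loss on sigmoid probabilities.
    def forward(self, inputs, targets, num_masks, eps: float = 1e-7):
        targets = _match_shape(inputs, targets).flatten(1)
        prob = inputs.sigmoid().flatten(1)
        numerator = 2 * (prob * targets).sum(-1)
        denominator = prob.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + eps) / (denominator + eps)
        return loss.sum() / num_masks

def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor, eps: float = 1e-7):
    # IoU between the thresholded prediction and the ground truth, shaped (N, 1)
    # so it can be compared against the decoder's iou_predictions.
    gt_mask = _match_shape(pred_mask, gt_mask)
    pred_bin = (pred_mask.sigmoid() >= 0.5).float()
    intersection = (pred_bin * gt_mask).sum(dim=(1, 2))
    union = pred_bin.sum(dim=(1, 2)) + gt_mask.sum(dim=(1, 2)) - intersection
    return (intersection / (union + eps)).unsqueeze(1)

class AverageMeter:
    # Tracks the latest value and the running average of a metric.
    def __init__(self, name: str = ""):
        self.name = name
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count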

import os
import time

import torch
import torch.nn.functional as F
import lightning as L
from box import Box
from torch.utils.data import DataLoader
from lightning.fabric.fabric import _FabricOptimizer  # private helper, used here only as a type annotation

def train_sam(
    cfg: Box,
    fabric: L.Fabric,
    model: Model,
    optimizer: _FabricOptimizer,
    scheduler: _FabricOptimizer,
    train_dataloader: DataLoader,
):
    """The SAM training loop."""
    focal_loss = FocalLoss()
    dice_loss = DiceLoss()
    # Resume training from a previous checkpoint if requested
    start_epoch = 1
    if cfg.resume:
        map_location = 'cuda:%d' % fabric.global_rank
        checkpoint = torch.load(cfg.resume, map_location={'cuda:0': map_location})
        start_epoch = checkpoint['epoch']
        network = checkpoint['network']
        opt = checkpoint['optimizer']
        sche = checkpoint['scheduler']
        model.model.load_state_dict(network)
        optimizer.load_state_dict(opt)
        scheduler.load_state_dict(sche)
        fabric.print(f"resume from:{cfg.resume}")
    for epoch in range(start_epoch, cfg.num_epochs):
        batch_time = AverageMeter(name="batch_time")
        data_time = AverageMeter(name="data_time")
        focal_losses = AverageMeter(name="focal_losses")
        dice_losses = AverageMeter(name="dice_losses")
        iou_losses = AverageMeter(name="iou_losses")
        total_losses = AverageMeter(name="total_losses")
        end = time.time()
        # Save a checkpoint
        if epoch % cfg.save_interval == 0:
            fabric.print(f"Saving checkpoint to {cfg.out_dir}")
            state_dict = model.model.state_dict()
            checkpoint = {
                'epoch': epoch,
                'network': state_dict,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            # In a multi-GPU setup, save only on the rank-0 process
            if fabric.global_rank == 0:
                torch.save(checkpoint, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-ckpt.pth"))
        for iter, data in enumerate(train_dataloader):
            data_time.update(time.time() - end)
            images = data["image"]
            gt_masks = data["label"]
            bboxes = data["bbox"]
            batch_size = images.shape[0]
            pred_masks, iou_predictions = model(images, bboxes, data["org_size"], data["point_coords"], data["point_labels"])
            num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
            loss_focal = torch.tensor(0., device=fabric.device)
            loss_dice = torch.tensor(0., device=fabric.device)
            loss_iou = torch.tensor(0., device=fabric.device)
            for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions):
                batch_iou = calc_iou(pred_mask, gt_mask)
                loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
                loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
                loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
            loss_total = 20. * loss_focal + loss_dice + loss_iou
            optimizer.zero_grad()
            fabric.backward(loss_total)
            optimizer.step()
            scheduler.step()
            batch_time.update(time.time() - end)
            end = time.time()
            focal_losses.update(loss_focal.item(), batch_size)
            dice_losses.update(loss_dice.item(), batch_size)
            iou_losses.update(loss_iou.item(), batch_size)
            total_losses.update(loss_total.item(), batch_size)
            fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
                         f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')

With the steps above you can fine-tune SAM. When only the mask decoder is fine-tuned, GPU memory usage is roughly 17 GB.
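
Putting the snippets above together, a minimal train.py entry point could look like the following. This is a sketch under the assumption that HaNDataset, Model, configure_opt and train_sam are defined in (or imported into) the same file:

import lightning as L
from lightning.fabric.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from config import cfg

def main():
    fabric = L.Fabric(accelerator="auto",
                      devices=cfg.num_devices,
                      strategy="auto",
                      loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")])
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    with fabric.device:
        model = Model(cfg)
        model.setup()

    train_data = HaNDataset(cfg)
    train_loader = DataLoader(train_data, batch_size=cfg.batch_size,
                              num_workers=cfg.num_workers, shuffle=True)
    train_loader = fabric.setup_dataloaders(train_loader)

    optimizer, scheduler = configure_opt(cfg, model)
    model, optimizer = fabric.setup(model, optimizer)

    train_sam(cfg, fabric, model, optimizer, scheduler, train_loader)

if __name__ == "__main__":
    main()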

References

lightning-sam
