赞
踩
该代码需要 Python>=3.8,以及 PyTorch>=1.7 和 TorchVision>=0.8。请按照 PyTorch 官方网站的安装说明安装 PyTorch 和 TorchVision 依赖项,强烈建议安装支持 CUDA 的版本。
安装 Segment Anything:
pip install git+https://github.com/facebookresearch/segment-anything.git
或在本地克隆存储库并安装
- git clone git@github.com:facebookresearch/segment-anything.git
- cd segment-anything; pip install -e .
以下可选依赖项对于掩模后处理、以 COCO 格式保存掩模、示例笔记本以及以 ONNX 格式导出模型是必需的。 运行示例笔记本还需要 jupyter。
pip install opencv-python pycocotools matplotlib onnxruntime onnx
相关代码如下
- import numpy as np
- import torch
- import matplotlib.pyplot as plt
- import cv2
# Overlay masks
def show_anns(anns, start=10, stop=200):
    """Draw a randomly colored, semi-transparent overlay for a range of masks.

    The annotations are sorted by ``'area'`` (largest first) and the slice
    ``[start:stop]`` of that ordering is painted onto an RGBA layer, which is
    then shown on the current matplotlib axes.

    Args:
        anns: Sequence of annotation dicts. Each needs an ``'area'`` value and
            a ``'segmentation'`` array used to index the overlay (presumably a
            boolean H x W mask -- confirm against the mask generator).
        start: Index of the first mask to draw after sorting. The default of
            10 skips the ten largest masks.
        stop: Index one past the last mask to draw; ``None`` draws to the end.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Fully transparent white RGBA layer with the masks' height and width.
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns[start:stop]:
        # Random RGB color with a fixed alpha of 0.35.
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = color_mask
    ax.imshow(overlay)
-
# Load the input image. cv2.imread returns None instead of raising when the
# file is missing or unreadable, so fail early with a clear message.
image = cv2.imread('./images/05.jpg')
if image is None:
    raise FileNotFoundError("cannot read image: ./images/05.jpg")
# OpenCV decodes to BGR; convert to RGB so matplotlib shows correct colors.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print(image.shape)

plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
-
import sys
sys.path.append("..")
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint file and registry key of the model variant to load.
sam_checkpoint = "../weights/mobile_sam.pt"
model_type = "vit_t"

# Use the GPU when one is available, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint, move it to the device, switch to eval mode.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)
sam.eval()

# Generate masks for the whole image.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

# Show the image with the mask overlay drawn on top.
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(list(masks))
plt.axis('off')
plt.show()
-
# Save masks
def save_mask(anns, out_path='res.jpg'):
    """Write the union of all masks to disk as a single-channel 0/255 image.

    Every pixel covered by at least one mask is set to 255; all others stay 0.
    The result does not depend on the order of ``anns``, so no sorting is done.

    Args:
        anns: Sequence of annotation dicts, each with a ``'segmentation'``
            array used to index the canvas (presumably a boolean H x W mask --
            confirm against the mask generator).
        out_path: Destination file passed to ``cv2.imwrite``. Defaults to
            ``'res.jpg'``.

    Returns:
        None. Nothing is written when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    height, width = anns[0]['segmentation'].shape[:2]
    # uint8 canvas so the image is written as-is, without a float conversion.
    union = np.zeros((height, width), dtype=np.uint8)
    for ann in anns:
        union[ann['segmentation']] = 255

    cv2.imwrite(out_path, union)
-
-
# Order the masks by area, largest first, and write their union to disk.
sorted_anns = sorted(masks, key=lambda ann: ann['area'], reverse=True)
save_mask(list(sorted_anns))
接上面的蒙版代码(沿用前文生成的 masks 变量),继续进行边缘提取与后处理:
# Extract mask edges
import cv2
import numpy as np

# The source image is only read to size the output canvas. cv2.imread returns
# None on failure, which would otherwise surface as an obscure error later.
image = cv2.imread('./images/05.jpg')
if image is None:
    raise FileNotFoundError("cannot read image: ./images/05.jpg")
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
# Black canvas that accumulates the contours of every mask.
img_ = np.zeros_like(image)
# NOTE(review): mask_show is not defined in this listing; presumably it
# returns one 0-255 grayscale image per mask -- confirm before running.
gray_images = mask_show(masks[:])
for img in gray_images:
    gray_image = np.uint8(img)
    # Morphological opening removes small specks before edge detection.
    gray_image = cv2.morphologyEx(gray_image, cv2.MORPH_OPEN, kernel)
    edges = cv2.Canny(gray_image, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img_, contours, -1, (255, 255, 255), 2)
cv2.imwrite("counte2.png", img_)
# Mask minus edges
im = cv2.imread('images/05.jpg', cv2.IMREAD_GRAYSCALE)
image1 = cv2.imread('res.jpg', cv2.IMREAD_GRAYSCALE)
image2 = cv2.imread('counte2.png', cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when a file cannot be read.
for path, loaded in (('images/05.jpg', im), ('res.jpg', image1), ('counte2.png', image2)):
    if loaded is None:
        raise FileNotFoundError(f"cannot read image: {path}")

# res.jpg is JPEG-compressed, so the decoded mask is no longer strictly 0/255;
# re-binarize it before subtracting the edge image.
_, image1 = cv2.threshold(image1, 127, 255, cv2.THRESH_BINARY)
img = cv2.subtract(image1, image2)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
# Opening removes the thin fragments left after subtracting the edges.
dst2 = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)

# Blend of the result with the grayscale source; it is computed here but not
# saved or displayed below.
re_img = cv2.addWeighted(dst2, 0.2, im, 0.8, 0)
cv2.imwrite("res3.jpg", dst2)

plt.figure(figsize=(20, 20))
plt.imshow(dst2, cmap='gray')
plt.axis('off')
plt.show()
-
# Store in COCO format
import json

orig_img = cv2.imread('./images/05.jpg')
image = cv2.imread('res3.jpg')
# cv2.imread returns None instead of raising when a file cannot be read.
if orig_img is None:
    raise FileNotFoundError("cannot read image: ./images/05.jpg")
if image is None:
    raise FileNotFoundError("cannot read image: res3.jpg")

# Outer contours of the cleaned mask become the polygon annotations.
edges = cv2.Canny(image, 50, 150)
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

images = [
    {
        'file_name': '05.jpg',
        'height': int(orig_img.shape[0]),
        'width': int(orig_img.shape[1]),
        'id': 1,
    },
]

categories = [
    {
        'id': 1,
        'name': 'qituan',
    },
]

annotations = []
for contour in contours:
    # A polygon needs at least three points; skip degenerate contours.
    if len(contour) < 3:
        continue
    x, y, w, h = cv2.boundingRect(contour)
    annotations.append({
        # Flat [x1, y1, x2, y2, ...] list of plain Python ints.
        'segmentation': [contour.reshape(-1).tolist()],
        'bbox': [x, y, w, h],
        'area': cv2.contourArea(contour),
        'iscrowd': 0,
        'image_id': 1,
        'category_id': 1,
        # Sequential ids starting at 1.
        'id': len(annotations) + 1,
    })

coco_data = {
    'images': images,
    'annotations': annotations,
    'categories': categories,
}

print(coco_data)

output_file_path = 'coco_data.json'

# Serialize the data and write to a JSON file
with open(output_file_path, 'w') as f:
    json.dump(coco_data, f, indent=4)
-
4. 验证COCO格式数据
-
- import cv2
- import random
- import json, os
- from pycocotools.coco import COCO
- from skimage import io
- from matplotlib import pyplot as plt
- import numpy as np
-
train_json = 'coco_data.json'
train_path = './images/'
coco = COCO(train_json)

list_imgIds = coco.getImgIds(catIds=1)

img = coco.loadImgs(list_imgIds[0])[0]
image = cv2.imread(train_path + img['file_name'])  # read the image (BGR)
if image is None:
    raise FileNotFoundError(f"cannot read image: {train_path + img['file_name']}")
# Separate copy for the filled polygons; `image` itself stays unmodified here.
img1 = image.copy()

img_annIds = coco.getAnnIds(imgIds=1, catIds=1, iscrowd=None)
anns = coco.loadAnns(img_annIds)
# Segmentation
for ann in anns:
    # Rebuild the (N, 1, 2) contour from the flat coordinate list. OpenCV
    # needs int32 points; the platform default integer type may be int64.
    contour_restored = np.asarray(ann['segmentation'][0], dtype=np.int32).reshape(-1, 1, 2)
    color = np.random.randint(0, 255, 3).tolist()
    cv2.drawContours(img1, [contour_restored], -1, tuple(color), thickness=cv2.FILLED)

plt.rcParams['figure.figsize'] = (20.0, 20.0)
# 20.0 was chosen because the author's image is 2000x2000; sizing the figure
# automatically from the image has not been worked out.
# OpenCV images are BGR; convert so matplotlib shows the true colors.
plt.imshow(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
plt.show()
-
-
img_annIds = coco.getAnnIds(imgIds=1, catIds=1, iscrowd=None)
# Object detection
for ann in coco.loadAnns(img_annIds):
    x, y, w, h = ann['bbox']
    # cv2.rectangle draws in place on `image`.
    cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)
# Bind the result outside the loop so it exists even with no annotations.
image1 = image

plt.rcParams['figure.figsize'] = (20.0, 20.0)
# 20.0 was chosen because the author's image is 2000x2000; sizing the figure
# automatically from the image has not been worked out.
# OpenCV images are BGR; convert so matplotlib shows the true colors.
plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。