赞
踩
代码地址: https://github.com/luca-medeiros/lang-segment-anything
https://github.com/IDEA-Research/GroundingDINO
lang-segment-anything是一个基于语言文本提示对图像中的物体进行分割的算法,结合了GroundingDINO和segment-anything两大算法,在半自动标注上有很好的应该场景,它可以对没有训练的物体进行分割,也就是可以完成文本到图像物体的匹配。
说明:经试验,lang-segment-anything在ubuntu18.04系统上成功部署,在windows系统上失败,建议在ubuntu系统上部署。
除了要下载lang-segment-anything的代码外,还需要下载GroundingDINO和segment-anything的一些配置文件和权重。
首先进入github官网下载lang-segment-anything的所有文件,可以git clone或者下载zip压缩包。
然后下载segment-anything需要的权重文件,有三种选择:l,b,h这里只下载了效果最好的h。
除此以外还要下载GroundingDINO的权重文件和对应的配置文件:有swinb和swint两种选择。
GroundingDINO用到了bert模型,因为在线下载模型很容易连接失败,所以最好也下载到本地,下载以下5个文件即可。
以上的文件都可以从百度网盘的链接获取:
https://pan.baidu.com/s/1iqFjmTdJrja1ilSoxnWw6w?pwd=ek6i
下载完成后将三个文件拷贝到lang-segmengt-anything-main下即可。
基本上按照github上操作即可,使用conda从yml文件创建虚拟环境:conda env create -f environment.yml,主要是安装torch、segment-anything和groundingdino,注意需要将environment.yml文件中groundingdino要修改为groundingdino-py,还有lang-sam包找不到,可以不安装。torch的安装可以按照以下代码执行:
pip install torch torchvision torchmetrics --index-url https://download.pytorch.org/whl/cu118
首先修改lang-sam下面的lang-sam.py文件,主要是添加了bulid_groundingdino_local和load_model_loacl方法用来加载本地groundingdino模型,注意权重文件和配置文件要对应。
- class LangSAM():
-
- def __init__(self, sam_type="vit_h", ckpt_path="sam_weight/sam_vit_h_4b8939.pth"):
- self.sam_type = sam_type
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.bulid_groundingdino_local()
- self.build_sam(ckpt_path)
- def bulid_groundingdino_local(self):
- ckpt_filename = "grounding_weight/groundingdino_swinb_cogcoor.pth"
- ckpt_config_filename = "grounding_weight/config/GroundingDINO_SwinB_cfg.py"
- self.groundingdino = load_model_loacl(ckpt_config_filename, ckpt_filename)
- def load_model_loacl(model_config_path, model_checkpoint_path, device="cuda"):
- try:
- args = SLConfig.fromfile(model_config_path)
- args.device = device
-
- model = build_model(args)
- checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
- model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
- model.eval()
- model.to(device)
-
- except Exception as e:
- print(str(e))
- return model
除此之外还需要进入到pip安装的groundingdino.util.get_tokenlizer里面将代码做以下修改:主要修改了tokenizer和BertModel的加载方式,修改为加载离线加载。可以在代码中输入groundingdino.util.get_tokenlizer,然后按住ctrl键,鼠标左键点击get_tokenlizer快速跳转到代码。
- from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
- import os
-
- def get_tokenlizer(text_encoder_type):
- if not isinstance(text_encoder_type, str):
- # print("text_encoder_type is not a str")
- if hasattr(text_encoder_type, "text_encoder_type"):
- text_encoder_type = text_encoder_type.text_encoder_type
- elif text_encoder_type.get("text_encoder_type", False):
- text_encoder_type = text_encoder_type.get("text_encoder_type")
- elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
- pass
- else:
- raise ValueError(
- "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
- )
- print("final text_encoder_type: {}".format(text_encoder_type))
-
- # tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
- tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
- return tokenizer
-
-
- def get_pretrained_language_model(text_encoder_type):
- if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)):
- # return BertModel.from_pretrained(text_encoder_type)
- return BertModel.from_pretrained('bert-base-uncased')
-
- if text_encoder_type == "roberta-base":
- return RobertaModel.from_pretrained(text_encoder_type)
-
- raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
运行下列代码,首先会提示是否更新lightning,要选择否,更新的版本不行。
lightning run app app.py
运行代码之后,会在网页弹出以下窗口,上传图片之后,就能对图片进行文本提示的分割了,下面文本提示输入person,在图像中将person对应的人给分割出来了。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。