赞
踩
利用人工智能解决音频、视觉和语言问题。音频分类、图像分类、物体检测、问答、总结、文本分类、翻译等均有大量模型进行参考。
图像分类是为整个图像分配标签或类别的任务。每张图像预计只有一个类别。图像分类模型将图像作为输入并返回有关图像所属类别的预测
借助该transformers库,可以使用image-classification管道来推断图像分类模型。在不提供模型ID时,默认使用google/vit-base-patch16-224进行初始化pipeline。 调用pipeline管道时,只需要指定路径、http链接或PIL(Python Imaging Library)中加载的图标;还可以提供一个top_k参数来确定应返回多少结果
如何像对句子标记一样对图像进行标记,以便将其传递到Transformer模型进行训练。
首先安装软件包
pip install datasets transformers Pillow
加载数据集
使用beans数据集,是健康和非健康豆叶的图片集合
from datasets import load_dataset ds = load_dataset('beans') //DatasetDict({ // train: Dataset({ // features: ['image_file_path', 'image', 'labels'], // num_rows: 1034 // }) // validation: Dataset({ // features: ['image_file_path', 'image', 'labels'], // num_rows: 133 // }) // test: Dataset({ // features: ['image_file_path', 'image', 'labels'], // num_rows: 128 // }) //})
每个数据集中每个示例都有3个特征:
{
'image': <PIL.JpegImagePlugin ...>,
'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
'labels': 1
}
ex = ds['train'][400]
image = ex['image']
由于'labels'
该数据集的特征是 datasets.features.ClassLabel
,我们可以使用它来查找本示例的标签 ID 的相应名称
labels = ds['train'].features['labels']
// ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)
使用int2str
函数来打印示例的类标签
labels.int2str(ex['labels'])
// 'bean_rust'
上面图片叶子感染了“豆锈病”,是一种豆科植物的严重疾病
编写一个函数显示每个类的示例网格:
import random from PIL import ImageDraw, ImageFont, Image def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)): w, h = size labels = ds['train'].features['labels'].names grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h)) draw = ImageDraw.Draw(grid) font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24) for label_id, label in enumerate(labels): # Filter the dataset by a single label, shuffle it, and grab a few samples ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class)) # Plot this label's examples along a row for i, example in enumerate(ds_slice): image = example['image'] idx = examples_per_class * label_id + i box = (idx % examples_per_class * w, idx // examples_per_class * h) grid.paste(image.resize(size), box=box) draw.text(box, label, (255, 255, 255), font=font) return grid show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
加载ViT特征提取器
现在知道图像是什么样子,并且更好地理解我们要解决的问题。让我们看看如何为我们的模型准备这些图像!
当训练 ViT 模型时,特定的转换将应用于输入到其中的图像。对图像使用错误的转换,模型将无法理解它所看到的内容!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。