当前位置:   article > 正文

Hugging(transformers)读取自定义 checkpoint、使用 Trainer 进行测试回归任务_transformers 数据读取

transformers 数据读取

需求说明

  1. 训练过程在 “HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务” 中已经描述过。
  2. 训练结束后,会生成如下的 checkpoints 文件:
    在这里插入图片描述
  3. 现在想用 checkpoint-500 中保存的模型进行预测,看它在测试集上的效果怎么样,即损失值是多少。

需求解决

关键代码

from dataset import GazeCaptureDataset
from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig

# 数据集根路径
root_path = r"D:\datasets\GazeCapture_new"
# 1. 定义 Dataset
test_dataset = GazeCaptureDataset(root_path, data_type='test')

# 2. 定义 DeiT 图像模型
configuration = DeiTConfig(num_labels=2, problem_type="regression")
model = DeiTForImageClassification(configuration).from_pretrained('gaze_trainer/checkpoint-500')

# 3. 测试
## 3.1 定义测试参数
testing_args = TrainingArguments(output_dir="pred_trainer")

## 3.2 自定义 Trainer
class CustomTester(Trainer):
    # 重写计算 loss 的函数
    def compute_loss(self, model, inputs, return_outputs=False):
        # 获取标签值
        labels = inputs.get("labels")
        # 获取输入值
        x = inputs.get("pixel_values")
        # 模型输出值
        outputs = model(x)
        logits = outputs.get('logits')
        # 定义损失函数为平滑 L1 损失
        loss_fct = nn.SmoothL1Loss()
        # 计算输出值和标签的损失
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

## 3.3 定义 Trainer 对象
tester = CustomTester(
    model=model,
    args=testing_args,
)

## 3.4 调用 predict 方法,开始测试
output = tester.predict(test_dataset=test_dataset)

# 4. 测试结果
print(output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

Dataset

dataset.py 代码如下:

import os.path

from torch.utils.data import Dataset
from transform import transform
import numpy as np

# 读取数据,如果是训练数据,随即打乱数据顺序
def get_label_list(label_path):
    # 存储所有标签文件中的所有内容
    full_lines = []
    # 获取所有标签文件的名称,如 00002.label, 00003.label, ......
    label_names = os.listdir(label_path)
    # 遍历每一个标签文件,并读取其中内容
    for label_name in label_names:
        # 标签文件全路径,如 D:\datasets\GazeCapture_new\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        # 读取每一个标签文件中的内容
        with open(label_abs_path) as flist:
            # 存储该标签文件中的所有内容
            full_line = []
            for line in flist:
                full_line.append(line.strip())
            # 移除首行表头 'Face Left Right Grid Xcam, Ycam Xdot, Ydot Device'
            full_line.pop(0)
            full_lines.extend(full_line)
    return full_lines


class GazeCaptureDataset(Dataset):
    def __init__(self, root_path, data_type):
        self.data_dir = root_path
        # 标签文件的根路径,如 D:\datasets\GazeCapture_new\Label\train
        label_root_path = os.path.join(root_path + '/Label', data_type)
        # 获取所有标签文件中的所有内容
        self.full_lines = get_label_list(label_root_path)
        # 每一行内容的分隔符
        self.delimiter = ' '
        # 数据集长度,也就是一共有多少个图片
        self.num_samples = len(self.full_lines)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 标签文件的一行,对应一个训练实例
        line = self.full_lines[idx]
        # 将标签文件中的一行内容按照分隔符进行分割
        Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)
        # 获取网络的输入:人脸图片
        face_path = os.path.join(self.data_dir + '/Image/', Face)
        # 读取人脸图像
        with open(face_path, 'rb') as f:
            img = f.read()
        # 将人脸图像进行格式转化:缩放、裁剪、标准化
        pixel_values = transform(img)
        # 获取标签值
        labels = np.array(XYcam.split(","), np.float32)
        # 注意返回值的形式一定要是 {"labels": xxx, "pixel_values": xxx}
        result = {"labels": labels}
        result["pixel_values"] = pixel_values
        return result

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

输出结果如下:

***** Running Prediction *****
  Num examples = 1716
  Batch size = 8
100%|██████████| 215/215 [01:52<00:00,  1.90it/s]
PredictionOutput(predictions=array([[-2.309026 , -2.752627 ],
       [-2.0178156, -3.0546618],
       [-1.8222798, -3.309564 ],
       ...,
       [-2.6463585, -2.3462727],
       [-2.2149038, -2.7406967],
       [-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975],
       [ 0.969375, -7.525975],
       [ 0.969375, -7.525975],
       ...,
       [ 5.5845  ,  1.93875 ],
       [ 5.5845  ,  1.93875 ],
       [ 5.5845  ,  1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

可以看到该模型在测试集的损失值是 2.8067691326141357

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/349840
推荐阅读
相关标签
  

闽ICP备14008679号