赞
踩
"""Evaluate a fine-tuned DeiT gaze-regression model on the GazeCapture test set."""
from dataset import GazeCaptureDataset
from torch import nn
from transformers import (
    DeiTConfig,
    DeiTForImageClassification,
    Trainer,
    TrainingArguments,
)

# Dataset root path
root_path = r"D:\datasets\GazeCapture_new"

# 1. Build the test Dataset
test_dataset = GazeCaptureDataset(root_path, data_type='test')

# 2. Load the DeiT image model.
# num_labels=2, problem_type="regression": the model regresses the two
# (Xcam, Ycam) gaze coordinates instead of classifying.
configuration = DeiTConfig(num_labels=2, problem_type="regression")
# BUG FIX: the original code did
#   DeiTForImageClassification(configuration).from_pretrained(ckpt)
# which instantiates a randomly-initialized model only to call the
# *classmethod* from_pretrained on it, discarding the instance and the
# config. Call from_pretrained on the class and pass the config explicitly.
model = DeiTForImageClassification.from_pretrained(
    'gaze_trainer/checkpoint-500', config=configuration
)

# 3. Testing
## 3.1 Prediction arguments
testing_args = TrainingArguments(output_dir="pred_trainer")

## 3.2 Custom Trainer with a regression loss
class CustomTester(Trainer):
    # Override the loss computation used during prediction.
    def compute_loss(self, model, inputs, return_outputs=False):
        """Smooth-L1 loss between the model's logits and the gaze labels."""
        # Ground-truth (Xcam, Ycam) coordinates
        labels = inputs.get("labels")
        # Input face images
        x = inputs.get("pixel_values")
        # Forward pass
        outputs = model(x)
        logits = outputs.get('logits')
        # Smooth L1 is robust to outliers in the regression targets.
        loss_fct = nn.SmoothL1Loss()
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

## 3.3 Trainer object
tester = CustomTester(
    model=model,
    args=testing_args,
)

## 3.4 Run prediction over the whole test set
output = tester.predict(test_dataset=test_dataset)

# 4. Results (predictions, label_ids, metrics incl. test_loss)
print(output)
dataset.py
代码如下:
import os.path

import numpy as np
from torch.utils.data import Dataset

from transform import transform


def get_label_list(label_path):
    """Read every label file under *label_path* and return all data lines.

    Each file's first line is a header
    ('Face Left Right Grid Xcam, Ycam Xdot, Ydot Device') and is dropped.

    NOTE(review): the original comment claimed training data is shuffled
    here, but no shuffling is performed anywhere in this module — presumably
    the DataLoader is expected to shuffle; confirm with the training script.
    """
    # All data lines collected across every label file.
    full_lines = []
    # Label file names, e.g. 00002.label, 00003.label, ...
    label_names = os.listdir(label_path)
    for label_name in label_names:
        # Absolute path, e.g. D:\datasets\GazeCapture_new\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        with open(label_abs_path) as flist:
            file_lines = [line.strip() for line in flist]
        # Drop the header row (slicing instead of the original O(n) pop(0)).
        full_lines.extend(file_lines[1:])
    return full_lines


class GazeCaptureDataset(Dataset):
    """GazeCapture samples as {'labels': (Xcam, Ycam), 'pixel_values': image}."""

    def __init__(self, root_path, data_type):
        # Dataset root, e.g. D:\datasets\GazeCapture_new
        self.data_dir = root_path
        # Label root, e.g. D:\datasets\GazeCapture_new\Label\train
        # (os.path.join instead of the original string concatenation).
        label_root_path = os.path.join(root_path, 'Label', data_type)
        # Every data line from every label file under that directory.
        self.full_lines = get_label_list(label_root_path)
        # Fields within one label line are space-separated.
        self.delimiter = ' '
        # One line == one image == one sample.
        self.num_samples = len(self.full_lines)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # One label line corresponds to one training instance.
        line = self.full_lines[idx]
        # Split the line into its seven fields.
        Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)
        # Network input: the face crop image.
        face_path = os.path.join(self.data_dir, 'Image', Face)
        with open(face_path, 'rb') as f:
            img = f.read()
        # Resize / crop / normalize the raw bytes into model input.
        pixel_values = transform(img)
        # Labels: the 'Xcam,Ycam' gaze coordinates as float32.
        labels = np.array(XYcam.split(","), np.float32)
        # NOTE: the Trainer expects exactly these dict keys.
        result = {"labels": labels}
        result["pixel_values"] = pixel_values
        return result
输出结果如下:
***** Running Prediction ***** Num examples = 1716 Batch size = 8 100%|██████████| 215/215 [01:52<00:00, 1.90it/s] PredictionOutput(predictions=array([[-2.309026 , -2.752627 ], [-2.0178156, -3.0546618], [-1.8222798, -3.309564 ], ..., [-2.6463585, -2.3462727], [-2.2149038, -2.7406967], [-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975], [ 0.969375, -7.525975], [ 0.969375, -7.525975], ..., [ 5.5845 , 1.93875 ], [ 5.5845 , 1.93875 ], [ 5.5845 , 1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})
可以看到该模型在测试集的损失值是 2.8067691326141357。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。