当前位置:   article > 正文

使用Slowfast模型训练自己的数据_slowfast算法需要预处理吗

slowfast算法需要预处理吗


  往期的博客介绍了视频数据、视频数据的信息获取、加载和处理,这期博客我们主要介绍Slowfast模型、写一个数据加载器,将预处理好的数据加载进Slowfast模型做训练。

Slowfast模型介绍

  Slowfast模型是何凯明团队于19年的工作,不同以往的双流(光流和视频),Slowfast的双流(都是视频数据)卷积输入,一个慢通道,一个快通道,分别提取空域信息以及时域信息。Slowfast模型论文地址,模型的效果如下图:
Slowfast模型与其他模型效果对比
  上图的对比结果我们可以看出,Slowfast模型在19年以较低的计算量和较高的准确度取得不错的排名,我手上的一些训练任务也是使用的Slowfast模型做训练,因此我想先引入Slowfast模型,循序渐进学习更多的新模型。

编写数据加载器

  数据的训练需要不断地加载数据,pytorch提供了一个数据加载接口,我们只需小小修改就可以使用了。有一篇博文讲解的很不错,博友们可以去看看:迄今为止最细致的DataSet和Dataloader加载步骤(肝)

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, txt, load_type="test"):
        self.load_type = load_type
        with open(txt, "r") as file:
            self.data = file.readlines()

    def __getitem__(self, idx):
        video_path = self.data[idx].strip()
        data = np.load(video_path)
        # 数据的保存样式[20, 224*223*3 + label]
        # 加载20张图片数据
        imgs = data[:, :-1].reshape(20, 224, 224, 3)
        label = data[:, -1:].reshape(-1)
        return imgs, label[0]

    def __len__(self):
        return len(self.data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

加载数据做训练

  我们知道数据的加载后就可以做训练了 (底下的代码可能需要有一定的知识储备,如果你有看不懂的地方可以在评论区提出,我会尽可能解惑) 。深度学习三要素:处理数据、训练数据、评估部署模型,下面的代码就是模型训练代码了。

"""
训练SlowFast模型
"""
from torch.utils.data import Dataset, DataLoader
# 加载模型
from models.model import COMP_F
import torch
import numpy as np
from tqdm import tqdm
from time import sleep
import os
torch.backends.cudnn.benchmark = True


def save_matrix(file, epoch, train_matrix, test_matrix, mode="a+"):
	"""
	保存混淆矩阵
	"""
    with open(file, mode) as f:
        f.writelines(f"epoch: {epoch}\n")
        f.writelines(f"train_matrix: {train_matrix}\n")
        f.writelines(f"test_matrix: {test_matrix}\n")


def save_acc(file, epoch, acc, loss, lr, mode="a+"):
	"""
	保存训练准确度、学习率等信息
	"""
    with open(file, mode) as f:
        f.writelines(f"{epoch}\n")
        f.writelines(f"{acc}\n")
        f.writelines(f"{loss}\n")
        f.writelines(f"{lr}\n")


def calculate_matrix(predicted, label, matrix):
	"""
	计算混淆矩阵
	"""
    for pre_y, real_y in zip(predicted, label):
        if not(f"{pre_y}_{real_y}" in matrix.keys()):
            matrix[f"{pre_y}_{real_y}"] = 0
        matrix[f"{pre_y}_{real_y}"] += 1


class MyDataset(Dataset):
    def __init__(self, txt, load_type="test"):
        self.load_type = load_type
        with open(txt, "r") as file:
            self.data = file.readlines()

    def __getitem__(self, idx):
        video_path = self.data[idx].strip()
        data = np.load(video_path)
        # 数据的保存样式[20, 224*223*3 + label]
        # 加载20张图片数据
        imgs = data[:, :-1].reshape(20, 224, 224, 3)
        label = data[:, -1:].reshape(-1)
        return imgs, label[0]

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":
	# 将需要训练的数据的路径写入txt
    train_path = "../txt/train.txt"
    test_path = "../txt/test.txt"
    # 加载数据
    dataset = MyDataset(train_path, load_type="train")
    train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
    dataset = MyDataset(test_path, load_type="test")
    test_dataloader = DataLoader(dataset, batch_size=4, shuffle=False, pin_memory=True)

    Epochs = 500
    criterion = torch.nn.CrossEntropyLoss()
    device = "cuda"
	# 模型的大小
    model_type = "resnet18"
    os.makedirs(f"{model_type}", exist_ok=True)
    os.makedirs(f"{model_type}/log", exist_ok=True)

    model = COMP_F(model_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-8)
    last_acc = 0

    save_matrix(file=f"{model_type}/log/matrix.txt", epoch=" ", train_matrix=" ", test_matrix=" ", mode="w")
    save_acc(file=f"{model_type}/log/acc.txt", epoch=" ", acc=" ", loss=" ", lr=" ", mode="w")

    for epoch in range(Epochs):
        model.train()
        train_loss = 0
        batch_num = 0
        train_correct = 0
        sample_num = 0
        train_matrix = {}
        with open(f"{model_type}/log/异常数据.txt", "a+") as error_file:
            for sample, label in tqdm(train_dataloader):
                sample, label = torch.permute(sample.to(device) / 255, (0, 4, 1, 2, 3)), label.to(device)
                # print(sample.shape, label.shape)
                optimizer.zero_grad()
                output = model(sample)
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()

                batch_num += 1
                train_loss += loss.item()
                _, predicted = torch.max(output, 1)
                sample_num += label.size(0)
                train_correct += (predicted == label).sum()
                # 计算混淆矩阵
                calculate_matrix(predicted.cpu().numpy(), label.cpu().numpy(), train_matrix)

            train_acc = train_correct.item() / sample_num
            train_loss = train_loss / batch_num

            # 测试
            model.eval()
            test_loss = 0
            batch_num = 0
            test_correct = 0
            sample_num = 0
            test_matrix = {}
            with torch.no_grad():
                try:
                    for sample, label in tqdm(test_dataloader):
                        sample, label = torch.permute(sample.to(device) / 255, (0, 4, 1, 2, 3)), label.to(device)
                        optimizer.zero_grad()
                        output = model(sample)
                        loss = criterion(output, label)

                        batch_num += 1
                        test_loss += loss.item()
                        _, predicted = torch.max(output, 1)
                        calculate_matrix(predicted.cpu().numpy(), label.cpu().numpy(), test_matrix)
                        sample_num += label.size(0)
                        test_correct += (predicted == label).sum()

                except:
                    pass

            test_acc = test_correct.item() / sample_num
            torch.save(model.state_dict(), f'./{model_type}/train_{epoch}_{test_acc:.4f}.pth')
            test_loss = test_loss / batch_num

            save_matrix(file=f"{model_type}/log/matrix.txt", epoch=epoch, train_matrix=train_matrix, test_matrix=test_matrix)
            save_acc(file=f"{model_type}/log/acc.txt", epoch=epoch, acc=f"train_acc: {train_acc:.4f}  test_acc: {test_acc:.4f}"
                     , loss=f"training_loss:{train_loss:.5f} test_loss:{test_loss:.5f}", lr=f"Lr:{scheduler.get_last_lr()}")
            print(f"Epoch: {epoch} training_loss:{train_loss:.5f} test_loss:{test_loss:.5f}")
            print(f"train_acc: {train_acc:.4f}  test_acc: {test_acc:.4f}")

        scheduler.step()
        adjusted_lr = scheduler.get_last_lr()
        print(f"Epoch Lr:{adjusted_lr}\n")
        sleep(1)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158

结尾

  在这篇blog中我们了解了Slowfast模型、编写了一个数据加载器、并使用数据加载器加载数据做训练,下期我们对pytorch的开源视频模型做一下测试 (博友们有什么不懂得或者想学的也可以私信我或者在评论区写出来,我会的话就去写一写博客的)。感谢您的观看,觉得写的还可以请帮忙点个赞和收藏!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/325390
推荐阅读
相关标签
  

闽ICP备14008679号