当前位置:   article > 正文

pytorch 训练模型很慢,卡在数据读取,卡I/O的有效解决方案_pytorch 每次训练前 加载数据 会卡一下

pytorch 每次训练前 加载数据 会卡一下

多线程加载

  • 在 datalaoder中指定num_works > 0,多线程加载数据集,最大可设置为 cpu 核数
  • 设置 pin_memory = True, 固定内存访问单元,节约内存调度时间
  • 示例如下:
loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
        num_workers=2,
        pin_memory=True,
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

预加载数据集

说别的都没大用,还得是预加载

  • 原理:将整个数据集预先 load 到内存单元中,读取则直接访问内存,不存在与磁盘的I/O问题
  • 构建自己的dataset类
  • 示例如下:
class My_Dataset(Dataset):
    def __init__(
        self, filename, preprocess_config, train_config, sort=False, drop_last=False
    ):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last
        # add
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        for idx in range(len(self.text)):
            basename = self.basename[idx]
            speaker = self.speaker[idx]
            mel_path = os.path.join(
            self.preprocessed_path,
            "mel",
            "{}-mel-{}.npy".format(speaker, basename),
            )
            mel = np.load(mel_path)
            pitch_path = os.path.join(
                self.preprocessed_path,
                "pitch",
                "{}-pitch-{}.npy".format(speaker, basename),
            )
            pitch = np.load(pitch_path)
            energy_path = os.path.join(
                self.preprocessed_path,
                "energy",
                "{}-energy-{}.npy".format(speaker, basename),
            )
            energy = np.load(energy_path)
            duration_path = os.path.join(
                self.preprocessed_path,
                "duration",
                "{}-duration-{}.npy".format(speaker, basename),
            )
            duration = np.load(duration_path)
            self.mel_list.append(mel)
            self.pitch_list.append(pitch)
            self.energy_list.append(energy)
            self.duration_list.append(duration)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
        
        mel = self.mel_list[idx]
        pitch = self.pitch_list[idx]       
        energy = self.energy_list[idx]        
        duration = self.duration_list[idx]

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }

        return sample
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • __init__函数里,即将所有数据load进内存
  • __getitem__(self, idx):函数,则直接通过列表idx访问每一条数据
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/610274
推荐阅读
相关标签
  

闽ICP备14008679号