赞
踩
num_workers > 0：使用多个子进程并行加载数据集，一般最多设置为 CPU 核数。
pin_memory = True
# pin_memory=True: fixed (page-locked / "pinned") memory for the returned
# tensors, which saves memory-scheduling time when batches are copied to the GPU.
# num_workers=2: two worker subprocesses load batches in parallel.
loader = DataLoader(
    dataset,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=dataset.collate_fn,
    # batch_size * group_size samples are drawn per loader step.
    batch_size=batch_size * group_size,
)
其他方法作用都不大，真正有效的还是预加载：把全部数据提前读进内存（前提是内存足够容纳整个数据集）。
class My_Dataset(Dataset):
    """Dataset that preloads every utterance's acoustic features into memory.

    All mel, pitch, energy and duration arrays are read with ``np.load`` once,
    in ``__init__``, from
    ``<preprocessed_path>/<feature>/<speaker>-<feature>-<basename>.npy``.
    ``__getitem__`` then only indexes in-memory lists, so no feature file is
    opened per sample. The trade-off is that every array stays resident in
    RAM for the lifetime of the dataset object.
    """

    def __init__(
        self, filename, preprocess_config, train_config, sort=False, drop_last=False
    ):
        """Read metadata, the speaker map and all feature arrays.

        Args:
            filename: metadata file handed to ``self.process_meta`` (a method
                not shown in this snippet).
            preprocess_config: nested config dict; reads ``["dataset"]``,
                ``["path"]["preprocessed_path"]`` and
                ``["preprocessing"]["text"]["text_cleaners"]``.
            train_config: nested config dict; reads
                ``["optimizer"]["batch_size"]``.
            sort: stored as ``self.sort``; not used by the code shown here.
            drop_last: stored as ``self.drop_last``; not used by the code
                shown here.

        Raises:
            KeyError: if a required config key is missing.
            OSError: (e.g. FileNotFoundError) if ``speakers.json`` or any
                feature ``.npy`` file cannot be read.
        """
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        # process_meta must return four parallel per-utterance lists; they are
        # indexed with the same idx in __getitem__.
        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(
            os.path.join(self.preprocessed_path, "speakers.json"), encoding="utf-8"
        ) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

        # Preload every feature array now so __getitem__ never touches disk.
        # NOTE(review): with DataLoader(num_workers > 0) each worker process
        # may end up duplicating these lists in memory -- watch RAM usage.
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        # (feature name, destination list), in the order each utterance's
        # files are loaded.
        targets = (
            ("mel", self.mel_list),
            ("pitch", self.pitch_list),
            ("energy", self.energy_list),
            ("duration", self.duration_list),
        )
        for basename, speaker in zip(self.basename, self.speaker):
            for feature, bucket in targets:
                bucket.append(self._load_feature(feature, speaker, basename))

    def _load_feature(self, feature, speaker, basename):
        """Load ``<preprocessed_path>/<feature>/<speaker>-<feature>-<basename>.npy``."""
        path = os.path.join(
            self.preprocessed_path,
            feature,
            "{}-{}-{}.npy".format(speaker, feature, basename),
        )
        return np.load(path)

    def __len__(self):
        """Return the number of utterances (the length of the text list)."""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the sample dict for utterance ``idx`` from preloaded data.

        Keys: ``id`` (basename), ``speaker`` (ID looked up in the map loaded
        from speakers.json), ``text`` (np.array of the ``text_to_sequence``
        result), ``raw_text``, and the preloaded ``mel``/``pitch``/``energy``/
        ``duration`` arrays.
        """
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        # text_to_sequence (defined elsewhere) still runs on every access;
        # only the .npy features are cached.
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": self.mel_list[idx],
            "pitch": self.pitch_list[idx],
            "energy": self.energy_list[idx],
            "duration": self.duration_list[idx],
        }
        return sample
在 __init__ 函数里，把所有数据一次性 load 进内存；
在 __getitem__(self, idx) 函数里，则直接通过列表下标 idx 访问每一条数据。
Copyright © 2003-2013 www.wpsshop.cn 版权所有，并保留所有权利。