赞
踩
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

__all__ = ["MultiProcessor", "taskCore"]


class taskCore(Dataset):
    """Map-style dataset that applies ``pro_func`` to each element of a list.

    Wrapping the per-item work in a ``Dataset`` lets a ``DataLoader`` execute
    it (optionally in worker processes) in place of a plain ``for`` loop over
    ``custom_list``.

    Args:
        custom_list: Indexable collection of inputs; items are fetched with
            ``custom_item``.
        *args: Extra positional arguments forwarded to every ``pro_func`` call.
        **kwargs: Extra keyword arguments forwarded to every ``pro_func`` call.
    """

    def __init__(self, custom_list, *args, **kwargs):
        self.custom_list = custom_list
        # Previously this stored ``kwargs``, so ``*self.args`` in __getitem__
        # unpacked the keyword *names* as positional arguments and the real
        # positional arguments were dropped.
        self.args = args
        self.kwargs = kwargs

    def __len__(self):
        return len(self.custom_list)

    def pro_func(self, *args, **kwargs):
        """Process one item (replaces the body of the original ``for`` loop).

        The default does nothing and returns ``None``. Override it in a
        subclass, or replace it via ``MultiProcessor.set_custom_func``.
        """

    def custom_item(self, custom_list, idx):
        """Return item ``idx`` of ``custom_list`` (replaces the loop's item access)."""
        return custom_list[idx]

    def __getitem__(self, idx):
        item = self.custom_item(self.custom_list, idx)
        return self.pro_func(item, *self.args, **self.kwargs)


def collate_fn(item):
    """Discard the batch returned by the DataLoader; always return ``0``.

    Replacing torch's default collation means arbitrary ``pro_func`` return
    values are never collated or gathered. Because results are thrown away,
    any useful output of ``pro_func`` presumably has to come from side
    effects.
    """
    return 0


class MultiProcessor:
    """Run ``taskCore.pro_func`` over every item through a ``DataLoader``.

    Args:
        custom_list: Items to process. Ignored when ``custom_taskCore`` is
            given.
        custom_taskCore: Optional pre-built ``taskCore`` (or compatible
            ``Dataset``) to use instead of building one from ``custom_list``.
        *args, **kwargs: Forwarded to ``taskCore`` when one is constructed
            here, and from there to every ``pro_func`` call.
    """

    # Number of DataLoader worker processes. Override in a subclass, for
    # example with 0 to run everything in the calling process.
    num_workers = 8

    def __init__(self, custom_list, custom_taskCore=None, *args, **kwargs):
        # Compare against None explicitly: taskCore defines __len__, so an
        # empty task core is falsy and would otherwise be silently replaced.
        if custom_taskCore is not None:
            self.taskCore = custom_taskCore
        else:
            self.taskCore = taskCore(custom_list, *args, **kwargs)
        self.handler = DataLoader(
            dataset=self.taskCore,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
        )

    def set_custom_func(self, pro_func, custom_item=None):
        """Replace the per-item processing function, and optionally the fetcher.

        The callables are stored as instance attributes, so they are NOT
        bound. ``pro_func`` is called as ``pro_func(item, *args, **kwargs)``
        and ``custom_item`` as ``custom_item(custom_list, idx)``.
        """
        self.taskCore.pro_func = pro_func
        if custom_item is not None:
            self.taskCore.custom_item = custom_item

    def __call__(self, *args, **kwargs):
        """Process every item while showing a tqdm progress bar; return ``None``.

        ``*args`` and ``**kwargs`` are accepted but unused.
        """
        # NOTE(review): with num_workers > 0 the dataset, including any
        # function installed via set_custom_func, is sent to worker
        # processes. Lambdas and local functions may fail to pickle under the
        # "spawn" start method; confirm this for the target platform.
        for _ in tqdm(self.handler):
            pass


def main():
    """Placeholder entry point."""


if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有，并保留所有权利。