当前位置:   article > 正文

huggingface如何加载本地数据集进行大模型训练_huggingface 加载本地模型

huggingface 加载本地模型

背景:

一般情况下,我们使用huggingface都是从其网站上直接加载数据集进行训练,代码如下:

  1. from datasets import load_dataset
  2. food = load_dataset("food101")

写了上面的代码,那么load_dataset函数就会自动从huggingface上面下载数据集到本地,然后缓存起来。

可对于我们自己的数据集该如何加载呢?尤其是图片数据集,对于模型来说,他需要的数据格式如下:

  1. {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
  2. 'label': 79}

从上述代码可以看到,image是一个PIL格式的图片,而我们本地是一个图片文件,如何将本地的图片转换成PIL的图片呢?

下面我们就来讲一讲如何加载本地数据集。

数据集目录结构

  1. dataset
  2. --image
  3. --1.jsp
  4. --2.jsp
  5. --……
  6. --image.json
  7. --label.json

目录结构说明

1、image文件夹

存储所有的图片

2、image.json

存储map结构的json字符串,key是图片的名称,value是图片对应的标签分类

  1. {
  2. "1.jpg": 0,
  3. "2.jpg": 0,
  4. "3.jpg": 1,
  5. "4.jpg": 1,
  6. "5.jpg": 4,
  7. "6.jpg": 4,
  8. "7.jpg": 2,
  9. "8.jpg": 2,
  10. "9.jpg": 3,
  11. "10.jpg": 3
  12. }

3、label.json

存储map结构的json字符串,key是标签的名称,value是标签的分类

  1. {
  2. "apple": 0,
  3. "pear": 1,
  4. "strawberry": 2,
  5. "peach": 3,
  6. "chestnut": 4
  7. }

代码

如何将上述的目录结构存储的数据转换成模型需要的格式呢?

话不多说,直接上代码

  1. import json
  2. import os
  3. from PIL import Image
  4. from datasets import Dataset
  5. path = 'D:/项目/AI/数据集/image/vit_dataset'
  6. def gen(path):
  7. image_json = os.path.join(path, "image.json")
  8. with open(image_json, 'r') as f:
  9. # 读取JSON数据
  10. data = json.load(f)
  11. for key, value in data.items():
  12. imagePath = os.path.join(path, "image")
  13. imagePath = os.path.join(imagePath, key)
  14. image = Image.open(imagePath)
  15. yield {'image': image, 'label': value}
  16. ds = Dataset.from_generator(gen, gen_kwargs={"path": path})
  17. ds = ds.train_test_split(test_size=0.2)
  18. def get_label(path):
  19. label_json = os.path.join(path, "label.json")
  20. with open(label_json, 'r') as f:
  21. # 读取JSON数据
  22. data = json.load(f)
  23. label2id, id2label = dict(), dict()
  24. for key, value in data.items():
  25. label2id[key] = str(value)
  26. id2label[str(value)] = key
  27. return label2id, id2label
  28. print(ds)
  29. print(ds['train'][0])
  30. label2id, id2label = get_label(path)
  31. print(label2id)
  32. print(id2label)

运行结果:

  1. DatasetDict({
  2. train: Dataset({
  3. features: ['image', 'label'],
  4. num_rows: 8
  5. })
  6. test: Dataset({
  7. features: ['image', 'label'],
  8. num_rows: 2
  9. })
  10. })
  11. {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=332x332 at 0x1ED6A84C690>, 'label': 0}
  12. {'apple': '0', 'pear': '1', 'strawberry': '2', 'peach': '3', 'chestnut': '4'}
  13. {'0': 'apple', '1': 'pear', '2': 'strawberry', '3': 'peach', '4': 'chestnut'}

从输出结果,我们发现我们已经获得了模型需要的数据结构。主要关注运行结果的倒数第三行。

总结:

利用Dataset.from_generator()函数,通过定义一个生成器,就能根据将我们本地自定义的数据转换成大模型需要的任何的格式类型。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/412219
推荐阅读
  

闽ICP备14008679号