
Using the load_dataset method (HuggingFace datasets library)

load_dataset

1. Loading datasets from local files

Use the path argument to specify the dataset format:

  • JSON: path="json"
  • CSV: path="csv"
  • plain text: path="text"
  • pandas DataFrame: path="pandas"
  • images: path="imagefolder"

Then use data_files to specify the file name(s); data_files can be a string, a list, or a dict. Use data_dir to specify the dataset directory.

  from datasets import load_dataset

  dataset = load_dataset('csv', data_files='my_file.csv')
  dataset = load_dataset('csv', data_files=['my_file_1.csv', 'my_file_2.csv', 'my_file_3.csv'])
  dataset = load_dataset('csv', data_files={'train': ['my_train_file_1.csv', 'my_train_file_2.csv'],
                                            'test': 'my_test_file.csv'})
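
Loading the other formats works the same way; only path changes. A minimal sketch (the file names below are hypothetical placeholders):

  from datasets import load_dataset

  # JSON / JSON Lines files
  dataset = load_dataset('json', data_files='my_file.jsonl')

  # plain text: each line of the file becomes one example with a "text" field
  dataset = load_dataset('text', data_files='my_file.txt')

  # without an explicit split mapping, everything lands in the "train" split
  print(dataset['train'][0])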
1.1 Loading images

As shown below, we load an image dataset by pointing load_dataset at the directory that contains the images:

  dataset = load_dataset(path="imagefolder",
                         data_dir=r"D:\Desktop\workspace\code\loaddataset\data\images")
  print(dataset)
  print(dataset["train"][0])

Pairing images with text
In many cases you need more than the images themselves: each image often comes with associated text. In image classification, for example, every image has a class label. To handle this, add a metadata.jsonl file to the image folder that maps each image to its label. The format is shown below; the file_name field is mandatory, while the other fields can be named freely.

  {"file_name": "1.jpg", "class": 1}
  {"file_name": "2.png", "class": 0}
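
Since metadata.jsonl must contain exactly one JSON object per line, it is easy to generate from code. A minimal sketch, assuming a hypothetical file-name-to-class mapping:

  import json

  # hypothetical mapping from image file name to class label
  labels = {"1.jpg": 1, "2.png": 0}

  with open(r"D:\Desktop\workspace\code\loaddataset\data\images\metadata.jsonl", "w", encoding="utf-8") as f:
      for file_name, cls in labels.items():
          # one JSON object per line, as the JSON Lines format requires
          f.write(json.dumps({"file_name": file_name, "class": cls}) + "\n")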

Now run the loading code again:

  dataset = load_dataset(path="imagefolder",
                         data_dir=r"D:\Desktop\workspace\code\loaddataset\data\images")
  print(dataset)
  print(dataset["train"][0])

The output is as follows:

  DatasetDict({
      train: Dataset({
          features: ['image', 'class'],
          num_rows: 2
      })
  })
  {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=800x320 at 0x2912172B520>, 'class': 1}
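
Note that class comes back as a plain integer column. If you want it treated as a categorical label (for example, to look up label names), you can cast it; a minimal sketch with hypothetical label names:

  from datasets import ClassLabel

  # hypothetical label names for classes 0 and 1
  dataset = dataset.cast_column("class", ClassLabel(names=["negative", "positive"]))
  print(dataset["train"].features["class"])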

2. Custom loading scripts

In some cases the loading logic is too complex for the built-in formats, and you need to define it yourself.

In the example below, each record in our dataset consists of two images and one piece of text.

  • step1: First create a JSON Lines file train.jsonl that maps each pair of images to its text. The file format is shown below:
    1. {"text": "pale golden rod circle with old lace background", "image": "images/0.png", "conditioning_image": "conditioning_images/0.png"}
    2. {"text": "light coral circle with white background", "image": "images/1.png", "conditioning_image": "conditioning_images/1.png"}
    3. {"text": "aqua circle with light pink background", "image": "images/2.png", "conditioning_image": "conditioning_images/2.png"}

    step2: Create a Python script fill50k.py that loads the images according to the mapping in the JSON file. The script below defines a Fill50k class that inherits from datasets.GeneratorBasedBuilder and overrides three methods: _info(self), _split_generators(self, dl_manager), and _generate_examples(self, ...).

    import logging
    import os

    import pandas as pd
    import datasets

    # Dataset paths
    META_DATA_PATH = r"D:\Desktop\workspace\code\loaddataset\fill50k\train.jsonl"
    IMAGE_DIR = r"D:\Desktop\workspace\code\loaddataset\fill50k"
    CONDITION_IMAGE_DIR = r"D:\Desktop\workspace\code\loaddataset\fill50k"

    # Declare the features in the dataset and their types
    _FEATURES = datasets.Features(
        {
            "image": datasets.Image(),
            "conditioning_image": datasets.Image(),
            "text": datasets.Value("string"),
        },
    )

    # Define the dataset builder
    class Fill50k(datasets.GeneratorBasedBuilder):
        BUILDER_CONFIGS = [datasets.BuilderConfig(name="default", version=datasets.Version("0.0.2"))]
        DEFAULT_CONFIG_NAME = "default"

        def _info(self):
            return datasets.DatasetInfo(
                description="None",
                features=_FEATURES,
                supervised_keys=None,
                homepage="None",
                license="None",
                citation="None",
            )

        def _split_generators(self, dl_manager):
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    # These kwargs will be passed to _generate_examples
                    gen_kwargs={
                        "metadata_path": META_DATA_PATH,
                        "images_dir": IMAGE_DIR,
                        "conditioning_images_dir": CONDITION_IMAGE_DIR,
                    },
                ),
            ]

        def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
            metadata = pd.read_json(metadata_path, lines=True)
            for _, row in metadata.iterrows():
                text = row["text"]
                image_path = os.path.join(images_dir, row["image"])
                # Skip the example if the image file cannot be read
                try:
                    with open(image_path, "rb") as f:
                        image = f.read()
                except Exception as e:
                    logging.error(e)
                    continue
                conditioning_image_path = os.path.join(
                    conditioning_images_dir, row["conditioning_image"]
                )
                # Skip the example if the conditioning image cannot be read
                try:
                    with open(conditioning_image_path, "rb") as f:
                        conditioning_image = f.read()
                except Exception as e:
                    logging.error(e)
                    continue
                # Yield (key, example) pairs; the relative image path serves as the key
                yield row["image"], {
                    "text": text,
                    "image": {
                        "path": image_path,
                        "bytes": image,
                    },
                    "conditioning_image": {
                        "path": conditioning_image_path,
                        "bytes": conditioning_image,
                    },
                }

  • step3: Load the dataset with load_dataset:
    dataset = load_dataset(path=r"D:\Desktop\workspace\code\loaddataset\fill50k\fill50k.py",
                           cache_dir=r"D:\Desktop\workspace\code\loaddataset\fill50k\cache")
    print(dataset)
    print(dataset["train"][0])
    The output is as follows:
    DatasetDict({
        train: Dataset({
            features: ['image', 'conditioning_image', 'text'],
            num_rows: 50000
        })
    })
    {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x1AEA2FF9040>, 'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x1AEA2FE2640>, 'text': 'pale golden rod circle with old lace background'}
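
Once loaded, the dataset behaves like any other datasets object; the image columns decode to PIL images on access. A minimal consumption sketch (the actual training step is omitted):

  for example in dataset["train"]:
      image = example["image"]                            # PIL.Image
      conditioning_image = example["conditioning_image"]  # PIL.Image
      text = example["text"]                              # str
      # ...preprocess and feed the example to your model here...
      break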

  • Reference: 【torch】HuggingFace的datasets库中load_dataset方法使用_orangerfun的博客-CSDN博客
