WMT is the main event for machine translation and machine translation research. It is held every year alongside a major NLP conference. The first Workshop on Machine Translation took place in 2006 at the annual meeting of the North American chapter of the Association for Computational Linguistics. In 2016, with the rise of neural machine translation, WMT became a conference in its own right, although the Conference on Machine Translation is still mostly referred to simply as WMT [1].
Some machine translation papers use the WMT datasets released over the years as their experimental data [2], as shown in the figure below:
When the author wanted to reproduce such results, the first step was to collect these datasets. Take WMT14 [3] as an example: as shown in the figure below, every sub-dataset listed on the page has to be downloaded by hand and then merged into the full WMT14 training, validation and test sets. Since the sub-datasets come in different formats and there are quite a few of them, the whole process is rather tedious.
The author then noticed that Hugging Face [4] already hosts the WMT data for some of the years and provides a download interface. Taking all of the hi-en data of wmt14 as an example, the download result looks like the figure below:
(The author only belatedly realized that simply finding a way to open the .arrow files would already have yielded the data... oh well.)
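For the record, that simpler route is only a few lines: a cached .arrow file can be memory-mapped directly with Dataset.from_file. A minimal sketch; the path below is hypothetical and the hash directory will differ on your machine:

from datasets import Dataset

# Hypothetical cache path; substitute the real hash directory from your cache_dir.
arrow_path = r"D:\dataset\mt\wmt14\hi-en\1.0.0\<hash>\wmt14-train.arrow"

ds = Dataset.from_file(arrow_path)
print(ds[0])  # e.g. {'translation': {'en': '...', 'hi': '...'}}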
This post summarizes a preliminary solution for fetching all of the WMT data in batch, implemented by modifying the source code of the huggingface datasets library.
Step 1: install the datasets library with pip install datasets.
Step 2: clone the library with git clone https://github.com/huggingface/datasets. The datasets/datasets path contains the code for every dataset shipped with the library:
Step 3: create the main script (run.py) with the code below, where py_file_path is the datasets/datasets path mentioned above and save_dir is the local directory the data should be saved to:
from datasets import load_dataset
import os

wmt_dict = {
    "wmt14": [(lang, "en") for lang in ["cs", "de", "fr", "hi", "ru"]],
    "wmt15": [(lang, "en") for lang in ["cs", "de", "fi", "fr", "ru"]],
    "wmt16": [(lang, "en") for lang in ["cs", "de", "fi", "ro", "ru", "tr"]],
    "wmt17": [(lang, "en") for lang in ["cs", "de", "fi", "lv", "ru", "tr", "zh"]],
    "wmt18": [(lang, "en") for lang in ["cs", "de", "et", "fi", "kk", "ru", "tr", "zh"]],
    "wmt19": [(lang, "en") for lang in ["cs", "de", "fi", "gu", "kk", "lt", "ru", "zh"]] + [("fr", "de")],
}

py_file_path = r"C:\Users\13359\PycharmProjects\for_fun\other\wmt_datasets\datasets\datasets"
save_dir = r"D:\dataset\mt"

for wmt in wmt_dict:
    for lang_tuple in wmt_dict[wmt]:
        lang_pair = "-".join(lang_tuple)
        print(f"wmt: {wmt} | lang_pair: {lang_pair}")
        load_dataset(os.path.join(py_file_path, wmt), name=lang_pair, cache_dir=save_dir)
Step 4: under the datasets/datasets path above, pick any one of the wmt folders, e.g. wmt14, and copy its wmt_utils.py into the same directory as run.py. (The author does not yet know why, but in practice this avoids errors.) The directory structure should then look as sketched below.
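A rough sketch of the layout, reconstructed from the paths used in run.py; the exact parent directories will differ:

wmt_datasets/
├── run.py
├── wmt_utils.py        (copied from datasets/datasets/wmt14/)
└── datasets/           (the cloned huggingface/datasets repository)
    └── datasets/
        ├── wmt14/
        ├── wmt15/
        ├── ...
        └── wmt19/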
Step 5: running run.py at this point already yields the data of every language pair for every WMT year, as described in the preface, but in arrow format. In a lapse of judgment, the author did not simply look for a way to convert the .arrow files into a friendlier format, but instead modified the pip-installed datasets source to change how the data is saved. Concretely, by tracing the execution of load_dataset from run.py (Ctrl+B in PyCharm), the code that writes the data was located: load_dataset (run.py) -> builder_instance.download_and_prepare (load.py, line 1738) -> self._download_and_prepare (builder.py, line 638) -> self._prepare_split (builder.py, line 723). Since self here is a wmtxx object, and the concrete wmtxx.py (e.g. wmt14.py at datasets/datasets/wmt14/wmt14.py) shows the inheritance chain wmtxx -> Wmt (wmt_utils.py) -> GeneratorBasedBuilder (builder.py) -> DatasetBuilder (builder.py), the call self._prepare_split ultimately resolves to GeneratorBasedBuilder._prepare_split.
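As a side note, this inheritance chain can be confirmed without reading the source, by inspecting the builder's method resolution order. A small sketch, assuming the same local setup as run.py (including the wmt_utils.py copy from step 4) and that load_dataset_builder is available in the installed datasets version:

import os
from datasets import load_dataset_builder

# Same local script path as in run.py.
py_file_path = r"C:\Users\13359\PycharmProjects\for_fun\other\wmt_datasets\datasets\datasets"

builder = load_dataset_builder(os.path.join(py_file_path, "wmt14"), name="hi-en")
# Expected to print something like:
# ['Wmt14', 'Wmt', 'GeneratorBasedBuilder', 'DatasetBuilder', 'object']
print([cls.__name__ for cls in type(builder).__mro__])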
This method is where the .arrow data is created; the relevant code is shown below:
with ArrowWriter(
    features=self.info.features,
    path=fpath,
    writer_batch_size=self._writer_batch_size,
    hash_salt=split_info.name,
    check_duplicates=check_duplicate_keys,
) as writer:
    try:
        for key, record in logging.tqdm(
            generator,
            unit=" examples",
            total=split_info.num_examples,
            leave=False,
            disable=not logging.is_progress_bar_enabled(),
            desc=f"Generating {split_info.name} split",
        ):
            example = self.info.features.encode_example(record)
            writer.write(example, key)
    finally:
        num_examples, num_bytes = writer.finalize()

split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
Here, generator is the generator that yields all of the examples. The author therefore modified the code above to change how the data is saved:
# ... other code ...
generator = self._generate_examples(**split_generator.gen_kwargs)

# --- modified code below ---
user_name = "xushaoyang"
str_lst = fpath.split("\\")
index = str_lst.index(self.name)
lang_pair = str_lst[index + 1]
source, target = lang_pair.split("-")
path_lst = str_lst[:index + 2]
path_lst[index] += f"_{user_name}"
dir_path = os.path.join(*path_lst)
os.makedirs(dir_path, exist_ok=True)
source_file_name = f"{self.name}.{lang_pair}-{split_generator.name}.{source}"
source_path = os.path.join(dir_path, source_file_name)
target_file_name = f"{self.name}.{lang_pair}-{split_generator.name}.{target}"
target_path = os.path.join(dir_path, target_file_name)
# flag = "population of the US, or of the combined current population of India and China"
with open(source_path, mode="w", encoding="utf-8") as source_f:
    with open(target_path, mode="w", encoding="utf-8") as target_f:
        with ArrowWriter(
            features=self.info.features,
            path=fpath,
            writer_batch_size=self._writer_batch_size,
            hash_salt=split_info.name,
            check_duplicates=check_duplicate_keys,
        ) as writer:
            try:
                for key, record in logging.tqdm(
                    generator,
                    unit=" examples",
                    total=split_info.num_examples,
                    leave=False,
                    disable=not logging.is_progress_bar_enabled(),
                    desc=f"Generating {split_info.name} split",
                ):
                    assert list(record.keys()) == ["translation"]
                    lang_keys = list(record["translation"].keys())
                    if source not in lang_keys:  # Issue 1: in some sub-datasets the key for zh is "ch"
                        target_idx = lang_keys.index(target)
                        source_idx = (target_idx + 1) % 2
                        error_source_key = lang_keys[source_idx]
                        new_record = {"translation": {
                            source: record["translation"][error_source_key],
                            target: record["translation"][target]
                        }}
                        record = new_record
                        del new_record
                    # arrow write
                    example = self.info.features.encode_example(record)
                    writer.write(example, key)
                    # file write
                    source_sentence = record['translation'][source]
                    target_sentence = record['translation'][target]
                    source_sentence = source_sentence.replace("\r", "").replace("\n", "")  # Issue 2: strip stray line breaks
                    target_sentence = target_sentence.replace("\r", "").replace("\n", "")
                    source_f.write(source_sentence + "\n")
                    target_f.write(target_sentence + "\n")
            finally:
                num_examples, num_bytes = writer.finalize()

split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
After the modification, run run.py again. Taking the hi-en data of wmt14 as an example, the resulting files are shown in the figure below; the file naming imitates OPUS-100 [5]:
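For comparison, similar OPUS-100-style text files can also be produced without touching the library source, by iterating over a built dataset through the public API. A rough sketch, assuming the wmt14 hi-en data is already cached under the save_dir from run.py; the output directory name here is purely illustrative:

import os
from datasets import load_dataset

save_dir = r"D:\dataset\mt"                              # cache_dir used in run.py
out_dir = os.path.join(save_dir, "wmt14_txt", "hi-en")   # illustrative output location
os.makedirs(out_dir, exist_ok=True)

dataset = load_dataset("wmt14", "hi-en", cache_dir=save_dir)

for split in dataset:  # "train", "validation", "test"
    # File names mimic the OPUS-100 convention used above, e.g. wmt14.hi-en-train.hi
    with open(os.path.join(out_dir, f"wmt14.hi-en-{split}.hi"), "w", encoding="utf-8") as src_f, \
         open(os.path.join(out_dir, f"wmt14.hi-en-{split}.en"), "w", encoding="utf-8") as tgt_f:
        for record in dataset[split]:
            pair = record["translation"]
            src_f.write(pair["hi"].replace("\r", "").replace("\n", "") + "\n")
            tgt_f.write(pair["en"].replace("\r", "").replace("\n", "") + "\n")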
As mentioned in point 3 of the "shortcomings" section, some files do not provide a URL for automatic download. The author's workaround is to record, during the download, which datasets could not be downloaded automatically, and to fill them in manually once everything else has finished. Concretely, in the _split_generators function of wmt_utils.py, right under the statement if dataset.get_manual_dl_files(source):, the following lines were added:
with open(f"{self.name}_error_log", mode="a", encoding="utf-8") as file:
    file.write(f"lang: {'-'.join(self.config.language_pair)} | data_name: {dataset.name} | url: {str(dataset.get_manual_dl_files(source))}" + "\n")
Whenever a dataset turns out to be missing while run.py is running, a record like this is appended to the wmtxxx_error_log file, as shown in the figure below:
One more fix concerns wmt19: open the wmt19 folder under datasets/datasets and modify the wmt19.py and wmt_utils.py inside it, so that the newstest2019 test set (which the stock script does not include) is added:
# wmt19.py: line 79
datasets.Split.TEST: ["newstest2019", "newstest2019_csen", "newstest2019_frde"],

# wmt_utils.py: line 617
SubDataset(
    name="newstest2019",
    target="en",
    sources={"de", "fi", "gu", "kk", "lt", "ru", "zh"},
    url="http://data.statmt.org/wmt19/translation-task/test.tgz",
    path=("sgm/newstest2019-{src}en-src.{src}.sgm", "sgm/newstest2019-{src}en-ref.en.sgm"),
),
SubDataset(
    name="newstest2019_csen",
    target="en",
    sources={"cs"},
    url="http://data.statmt.org/wmt19/translation-task/test.tgz",
    path=("sgm/newstest2019-en{src}-src.en.sgm", "sgm/newstest2019-en{src}-ref.{src}.sgm"),
),
SubDataset(
    name="newstest2019_frde",
    target="de",
    sources={"fr"},
    url="http://data.statmt.org/wmt19/translation-task/test.tgz",
    path=("sgm/newstest2019-frde-ref.de.sgm", "sgm/newstest2019-frde-src.fr.sgm"),
),
Running run.py directly after this change will raise an error, because the contents of dataset_infos.json have not been updated; the program reads preset information from dataset_infos.json and checks it against the downloaded results (verify_checksums, verify_splits). Editing dataset_infos.json by hand is fairly tedious, so the author chose to disable the verification instead:
load_dataset(os.path.join(py_file_path, wmt), name=lang_pair, cache_dir=save_dir, save_infos=True)
# ignore_verifications=True would also work
In addition, one small change to dataset_infos.json is still required: add a "test" entry under splits (the comments below are explanatory; actual JSON does not allow comments):

"test": {
    "name": "test",
    "num_bytes": 3000,       # arbitrary value; does not affect the download
    "num_examples": 3000,    # likewise arbitrary
    "dataset_name": "wmt19"
}
Line counts of the extracted text files per year and split (the source and target files always have identical counts):

wmt14:
  cs-en: train 953621 | validation 3000 | test 3003
  de-en: train 4508785 | validation 3000 | test 3003
  fr-en: train 40836715 | validation 3000 | test 3003
  hi-en: train 32863 | validation 520 | test 2507
  ru-en: train 1486965 | validation 3000 | test 3003

wmt15:
  cs-en: train 959768 | validation 3003 | test 2656
  de-en: train 4522998 | validation 3003 | test 2169
  fi-en: train 2073394 | validation 1500 | test 1370
  fr-en: train 40853137 | validation 4503 | test 1500
  ru-en: train 1495081 | validation 3003 | test 2818

wmt16:
  cs-en: train 997240 | validation 2656 | test 2999
  de-en: train 4548885 | validation 2169 | test 2999
  fi-en: train 2073394 | validation 1370 | test 6000
  ro-en: train 610320 | validation 1999 | test 1999
  ru-en: train 1516162 | validation 2818 | test 2998
  tr-en: train 205756 | validation 1001 | test 3000

wmt17:
  cs-en: train 1018291 | validation 2999 | test 3005
  de-en: train 5906184 | validation 2999 | test 3004
  fi-en: train 2656542 | validation 6000 | test 6004
  lv-en: train 3567528 | validation 2003 | test 2001
  ru-en: train 24782720 | validation 2998 | test 3001
  tr-en: train 205756 | validation 3000 | test 3007
  zh-en: train 25134743 | validation 2002 | test 2001

wmt18:
  cs-en: train 11046024 | validation 3005 | test 2983
  de-en: train 42271874 | validation 3004 | test 2998
  et-en: train 2175873 | validation 2000 | test 2000
  fi-en: train 3280600 | validation 6004 | test 3000
  kk-en: train 0 | validation 0 | test 0
  ru-en: train 36858512 | validation 3001 | test 3000
  tr-en: train 205756 | validation 3007 | test 3000
  zh-en: train 25160346 | validation 2001 | test 3981

wmt19:
  cs-en: train 7270695 | validation 2983 | test 1997
  de-en: train 38690334 | validation 2998 | test 2000
  fi-en: train 6587448 | validation 3000 | test 1996
  gu-en: train 11670 | validation 1998 | test 1016
  kk-en: train 126583 | validation 2066 | test 1000
  lt-en: train 2344893 | validation 2000 | test 1000
  ru-en: train 37492126 | validation 3000 | test 2000
  zh-en: train 25984574 | validation 3981 | test 2000
  fr-de: train 9824476 | validation 1512 | test 1707
As the numbers show, there is a great deal of duplication: consecutive years largely reuse the same training corpora, and each year's test set tends to reappear as the next year's validation set (compare, for instance, the 3003-line wmt14 test sets with the wmt15 validation sets). The download links of the corpora above were also tallied: essentially all of them require submitting an application.
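To quantify the duplication noted above, a quick set-overlap check between two of the extracted files is enough. A minimal sketch; both paths are hypothetical and simply follow the naming scheme produced earlier:

def line_overlap(path_a, path_b):
    """Jaccard overlap between the line sets of two text files."""
    with open(path_a, encoding="utf-8") as fa, open(path_b, encoding="utf-8") as fb:
        lines_a = {line.strip() for line in fa}
        lines_b = {line.strip() for line in fb}
    return len(lines_a & lines_b) / max(len(lines_a | lines_b), 1)

# e.g. the wmt14 cs-en test set vs. the wmt15 cs-en validation set (both 3003 lines)
print(line_overlap(
    r"D:\dataset\mt\wmt14_xushaoyang\cs-en\wmt14.cs-en-test.cs",
    r"D:\dataset\mt\wmt15_xushaoyang\cs-en\wmt15.cs-en-validation.cs",
))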
[1] https://machinetranslate.org/wmt
[2] https://arxiv.org/pdf/2105.09259v1.pdf
[3] https://www.statmt.org/wmt14/translation-task.html
[4] https://github.com/huggingface/datasets
[5] https://github.com/EdinburghNLP/opus-100-corpus