赞
踩
本文代码根据 zhangy 的 EDA 代码改写,主要用于千言问题匹配任务的中文文本数据增强。
依赖安装(此外,eda.py 默认从 stopwords/hit_stopwords.txt 读取哈工大停用词表,需提前准备好该文件)
- pip install jieba
- pip install synonyms
eda.py
- import jieba
- import synonyms
- import random
- from random import shuffle
-
# Seed the shared module-level RNG so augmentation output is repeatable.
random.seed(2019)

# Stop-word list; defaults to the HIT (Harbin Institute of Technology) stop-word table.
# Opened with a context manager so the handle is closed. The file is read as
# UTF-8 (assumed encoding of the table -- TODO confirm). Only the trailing
# newline is stripped: slicing with [:-1] would cut a real character off the
# final entry when the file does not end with a newline.
with open('stopwords/hit_stopwords.txt', encoding='utf-8') as f:
    stop_words = [line.rstrip('\n') for line in f]
-
-
- #考虑到与英文的不同,暂时搁置
- #文本清理
- '''
- import re
- def get_only_chars(line):
- #1.清除所有的数字
- '''
-
-
- ########################################################################
- # 同义词替换
- # 替换一个语句中的n个单词为其同义词
- ########################################################################
def synonym_replacement(words, n):
    """Replace up to ``n`` distinct non-stop-words in ``words`` with synonyms.

    Candidate words are the unique tokens of ``words`` that are not in
    ``stop_words``, visited in random order. Each candidate with at least one
    synonym has every occurrence replaced by one randomly chosen synonym.

    Args:
        words: list of tokens. It is not modified.
        n: maximum number of distinct words to replace.

    Returns:
        A new token list. It is re-split on single spaces, so a synonym that
        contains spaces becomes several tokens.
    """
    new_words = words.copy()
    # dict.fromkeys de-duplicates while keeping first-occurrence order. A set
    # would also de-duplicate, but its iteration order for str depends on the
    # per-process hash seed. That would make the seeded shuffle below differ
    # from run to run.
    candidates = list(dict.fromkeys(w for w in words if w not in stop_words))
    random.shuffle(candidates)

    num_replaced = 0
    for candidate in candidates:
        # Named `options` rather than `synonyms` to avoid shadowing the synonyms module.
        options = get_synonyms(candidate)
        if options:
            replacement = random.choice(options)
            new_words = [replacement if w == candidate else w for w in new_words]
            num_replaced += 1
        if num_replaced >= n:
            break

    return ' '.join(new_words).split(' ')
-
def get_synonyms(word):
    """Return candidate synonyms of ``word`` from the ``synonyms`` package.

    ``synonyms.nearby(word)`` returns a ``(words, scores)`` pair, and only the
    word list is used. That list includes the query word itself, which is
    removed here. Otherwise "synonym replacement" could replace a word with
    itself, and "random insertion" could simply duplicate it.

    Returns:
        A list of strings, possibly empty (e.g. for out-of-vocabulary words).
    """
    candidates = synonyms.nearby(word)[0]
    return [candidate for candidate in candidates if candidate != word]
-
-
- ########################################################################
- # 随机插入
- # 随机在语句中插入n个词
- ########################################################################
def random_insertion(words, n):
    """Return a copy of ``words`` with up to ``n`` synonyms inserted.

    Each of the ``n`` rounds calls ``add_word``. A round inserts nothing when
    no synonym is found, so the result may be shorter than ``len(words) + n``.
    ``words`` itself is not modified.
    """
    new_words = words.copy()
    for _ in range(n):
        add_word(new_words)
    return new_words


def add_word(new_words):
    """Insert, in place, a synonym of a random token at a random position.

    Makes up to 10 random draws (with replacement) of a source token. The
    first draw that has a synonym wins. If no draw has one, or ``new_words``
    is empty, the list is left unchanged.
    """
    if not new_words:
        # No token to take a synonym from (and randint(0, -1) would raise).
        return

    for _ in range(10):
        options = get_synonyms(random.choice(new_words))
        if options:
            break
    else:
        # All 10 draws came back without a synonym; give up this round.
        return

    random_synonym = random.choice(options)
    # Positions 0..len(new_words) inclusive; the last value appends at the end.
    random_idx = random.randint(0, len(new_words))
    new_words.insert(random_idx, random_synonym)
-
-
- ########################################################################
- # Random swap
- # Randomly swap two words in the sentence n times
- ########################################################################
-
def random_swap(words, n):
    """Return a copy of ``words`` in which two random positions are swapped ``n`` times.

    ``words`` itself is not modified.
    """
    new_words = words.copy()
    for _ in range(n):
        new_words = swap_word(new_words)
    return new_words


def swap_word(new_words):
    """Swap two distinct random positions of ``new_words`` in place and return it.

    Lists with fewer than two tokens have nothing to swap and are returned
    unchanged.
    """
    if len(new_words) < 2:
        return new_words
    # sample() always yields two distinct indices, so no retry loop is needed.
    idx_a, idx_b = random.sample(range(len(new_words)), 2)
    new_words[idx_a], new_words[idx_b] = new_words[idx_b], new_words[idx_a]
    return new_words
-
- ########################################################################
- # 随机删除
- # 以概率p删除语句中的词
- ########################################################################
def random_deletion(words, p):
    """Drop each token of ``words`` independently with probability ``p``.

    Lists with at most one token (including the empty list) are returned
    as-is. If every token would be dropped, one randomly chosen token is kept,
    so a non-empty input never yields an empty result.

    Args:
        words: list of tokens. It is not modified.
        p: deletion probability; a token is kept when uniform(0, 1) > p.
    """
    if len(words) <= 1:
        return words

    kept = [word for word in words if random.uniform(0, 1) > p]
    if not kept:
        return [random.choice(words)]
    return kept
-
-
- ########################################################################
- #EDA函数
def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
    """Generate EDA-augmented variants of a sentence segmented with jieba.

    Four techniques each produce ``int(num_aug / 4) + 1`` candidates:
    synonym replacement (sr), random insertion (ri), random swap (rs) and
    random deletion (rd). The candidates are shuffled. They are then
    truncated to ``num_aug`` when ``num_aug >= 1``; otherwise each is kept
    with probability ``num_aug / total``.

    Args:
        sentence: raw input sentence.
        alpha_sr: fraction of tokens to replace with synonyms (at least 1).
        alpha_ri: fraction of tokens to insert as synonyms (at least 1).
        alpha_rs: fraction of tokens used as the swap count (at least 1).
        p_rd: per-token deletion probability.
        num_aug: number of augmented sentences wanted.

    Returns:
        A list of space-separated token strings. It holds the augmented
        sentences followed by the segmented original as the last element.
        If segmentation yields no tokens (empty or whitespace-only input),
        the list contains only the segmented original.
    """
    seg_list = " ".join(jieba.cut(sentence))
    words = seg_list.split()
    if not words:
        # Nothing to augment for an empty token list.
        return [seg_list]

    num_words = len(words)
    num_new_per_technique = int(num_aug / 4) + 1
    n_sr = max(1, int(alpha_sr * num_words))
    n_ri = max(1, int(alpha_ri * num_words))
    n_rs = max(1, int(alpha_rs * num_words))

    # Techniques run in the order sr, ri, rs, rd, each num_new_per_technique times.
    techniques = [
        (synonym_replacement, n_sr),  # synonym replacement
        (random_insertion, n_ri),     # random insertion
        (random_swap, n_rs),          # random swap
        (random_deletion, p_rd),      # random deletion
    ]
    augmented_sentences = [
        ' '.join(technique(words, amount))
        for technique, amount in techniques
        for _ in range(num_new_per_technique)
    ]
    shuffle(augmented_sentences)

    if num_aug >= 1:
        augmented_sentences = augmented_sentences[:num_aug]
    else:
        keep_prob = num_aug / len(augmented_sentences)
        augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

    augmented_sentences.append(seg_list)

    return augmented_sentences
augment.py
- from eda import *
-
- import argparse
ap = argparse.ArgumentParser()
ap.add_argument("--input", required=True, type=str, help="原始数据的输入文件目录")
ap.add_argument("--output", required=False, type=str, help="增强数据后的输出文件目录")
# Defaults live in argparse so an explicit 0 is honoured. A truthiness check
# such as `if args.num_aug:` would silently replace 0 with the default.
ap.add_argument("--num_aug", required=False, type=int, default=9, help="每条原始语句增强的语句数")
ap.add_argument("--alpha", required=False, type=float, default=0.1, help="每条语句中将会被改变的单词数占比")
args = ap.parse_args()

# Output file: defaults to eda_<input file name> in the input file's directory.
if args.output:
    output = args.output
else:
    from os.path import dirname, basename, join
    output = join(dirname(args.input), 'eda_' + basename(args.input))

# Number of augmented sentences generated per original sentence (default 9).
num_aug = args.num_aug

# Fraction of words in each sentence that will be changed (default 0.1).
alpha = args.alpha
-
def gen_eda(train_orig, output_file, alpha, num_aug=9):
    """Write EDA-augmented sentence pairs for a tab-separated matching dataset.

    Each non-blank input line must be ``sentence1<TAB>sentence2<TAB>label``.
    Both sentences are augmented independently with ``eda``. The k-th variant
    of sentence1 is paired with the k-th variant of sentence2 under the
    original label. The spaces ``eda`` puts between tokens are removed before
    writing.

    Args:
        train_orig: path of the input file (read as UTF-8).
        output_file: path of the output file (written as UTF-8; overwritten).
        alpha: value passed to ``eda`` as alpha_sr, alpha_ri, alpha_rs and p_rd.
        num_aug: number of augmentations requested per sentence.

    Raises:
        ValueError: if a non-blank line has fewer than three tab-separated fields.
    """
    print("正在使用EDA生成增强语句...")
    # The input is opened first, so a missing input file does not truncate the
    # output. Both handles are closed even if an error is raised mid-way.
    with open(train_orig, 'r', encoding='utf-8') as reader, \
            open(output_file, 'w', encoding='utf-8') as writer:
        for line_no, line in enumerate(reader, start=1):
            # rstrip rather than [:-1]: the last line may have no trailing newline,
            # and [:-1] would then eat the final character of the label.
            line = line.rstrip('\r\n')
            if not line.strip():
                continue  # skip blank lines, e.g. a trailing empty line
            parts = line.split('\t')
            if len(parts) < 3:
                raise ValueError(
                    f"{train_orig}:{line_no}: expected 3 tab-separated fields, got {len(parts)}"
                )
            sentence1, sentence2, label = parts[0], parts[1], parts[2]
            aug_sentences1 = eda(sentence1, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
            aug_sentences2 = eda(sentence2, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
            # zip: with num_aug < 1 the two lists are filtered independently and can
            # differ in length, so indexing one by the other's position could overflow.
            for aug_sentence1, aug_sentence2 in zip(aug_sentences1, aug_sentences2):
                writer.write(aug_sentence1.replace(' ', '') + "\t" + aug_sentence2.replace(' ', '') + "\t" + label + '\n')
    print("已生成增强语句!")
    print(output_file)
-
if __name__ == "__main__":
    # Script entry point: augment the input file with the CLI-resolved settings.
    gen_eda(train_orig=args.input, output_file=output, alpha=alpha, num_aug=num_aug)
增强训练数据集(未指定 --output 时,结果写入输入文件同目录下的 eda_train.txt)
!python augment.py --input train.txt
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。