当前位置:   article > 正文

文本数据增强:同义词替换、随机交换、随机插入、随机删除 —— eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9)

eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9)

本文代码根据 zhangy 的代码改写,主要用于针对千言问题匹配任务进行文本数据增强。

依赖安装

  1. pip install jieba
  2. pip install synonyms

eda.py

  1. import jieba
  2. import synonyms
  3. import random
  4. from random import shuffle
  5. random.seed(2019)
  # Stop-word list; defaults to the HIT (Harbin Institute of Technology) stop-word table.
  # The context manager guarantees the file handle is closed, and only the line
  # terminator is stripped, so a final line without a trailing newline keeps its
  # last character (the previous `[:-1]` slice would have cut it off).
  # NOTE(review): assumes the stop-word file is UTF-8 encoded - confirm.
  with open('stopwords/hit_stopwords.txt', encoding='utf-8') as f:
      stop_words = [stop_word.rstrip('\n') for stop_word in f]
  11. #考虑到与英文的不同,暂时搁置
  12. #文本清理
  13. '''
  14. import re
  15. def get_only_chars(line):
  16. #1.清除所有的数字
  17. '''
  18. ########################################################################
  19. # 同义词替换
  20. # 替换一个语句中的n个单词为其同义词
  21. ########################################################################
  def synonym_replacement(words, n):
      """Replace distinct non-stop-words of ``words`` with randomly chosen synonyms.

      Args:
          words: list of tokens (must support ``.copy()``); it is not modified.
          n: number of distinct words to replace. The limit is checked after each
              candidate, so at least one candidate is always attempted.

      Returns:
          A new token list in which every occurrence of each chosen word has
          been replaced by the same synonym.
      """
      new_words = words.copy()
      # De-duplicate with dict.fromkeys instead of set(): set iteration order of
      # strings changes between interpreter runs (hash randomization), which
      # would make the result irreproducible even though `random` is seeded.
      candidates = list(dict.fromkeys(word for word in words if word not in stop_words))
      random.shuffle(candidates)
      num_replaced = 0
      for random_word in candidates:
          # Named `replacements` so the imported `synonyms` module is not shadowed.
          replacements = get_synonyms(random_word)
          if replacements:
              synonym = random.choice(replacements)
              new_words = [synonym if word == random_word else word for word in new_words]
              num_replaced += 1
          if num_replaced >= n:
              break
      # Join and re-split on single spaces so that a replacement which itself
      # contains spaces ends up as separate tokens.
      sentence = ' '.join(new_words)
      new_words = sentence.split(' ')
      return new_words
  def get_synonyms(word):
      """Return candidate synonyms for ``word``, excluding ``word`` itself.

      Uses the first element of ``synonyms.nearby(word)`` (the list of nearby
      words). The query word is filtered out so that a "synonym replacement"
      can never silently put the original token back and a "random insertion"
      never just duplicates an existing token.
      """
      # NOTE(review): nearby() appears to list the query word among its own
      # results - confirm against the installed `synonyms` version.
      return [candidate for candidate in synonyms.nearby(word)[0] if candidate != word]
  40. ########################################################################
  41. # 随机插入
  42. # 随机在语句中插入n个词
  43. ########################################################################
  def random_insertion(words, n):
      """Return a copy of ``words`` after ``n`` insertion attempts via ``add_word``.

      The input list is left untouched; ``add_word`` mutates the copy in place.
      """
      augmented = words.copy()
      for _attempt in range(n):
          add_word(augmented)
      return augmented
  def add_word(new_words):
      """Insert a synonym of a randomly picked token into ``new_words`` in place.

      Up to 10 random tokens are tried; if none of the first nine tries yields a
      synonym, or the tenth try is reached, the list is left unchanged. An empty
      list is also left unchanged (previously this raised ``ValueError`` from
      ``random.randint(0, -1)``).

      Returns:
          None. The list is modified in place.
      """
      if not new_words:
          return
      # Named `candidates` so the imported `synonyms` module is not shadowed.
      candidates = []
      counter = 0
      while len(candidates) < 1:
          random_word = new_words[random.randint(0, len(new_words) - 1)]
          candidates = get_synonyms(random_word)
          counter += 1
          if counter >= 10:
              return
      random_synonym = random.choice(candidates)
      # The index is drawn from existing positions, so the synonym is inserted
      # before some current token and is never appended after the last one.
      random_idx = random.randint(0, len(new_words) - 1)
      new_words.insert(random_idx, random_synonym)
  61. ########################################################################
  62. # Random swap
  63. # Randomly swap two words in the sentence n times
  64. ########################################################################
  def random_swap(words, n):
      """Return a copy of ``words`` after applying ``swap_word`` ``n`` times.

      The caller's list is not modified; all swaps happen on the copy.
      """
      shuffled = words.copy()
      for _round in range(n):
          shuffled = swap_word(shuffled)
      return shuffled
  def swap_word(new_words):
      """Swap two randomly chosen, distinct positions of ``new_words`` in place.

      A second index different from the first is drawn up to four times; if all
      draws collide (always the case for a single-element list) the list is
      returned unchanged. An empty list is returned as-is (previously this
      raised ``ValueError`` from ``random.randint(0, -1)``).

      Returns:
          The same list object that was passed in.
      """
      if not new_words:
          return new_words
      last_idx = len(new_words) - 1
      random_idx_1 = random.randint(0, last_idx)
      random_idx_2 = random_idx_1
      counter = 0
      while random_idx_2 == random_idx_1:
          random_idx_2 = random.randint(0, last_idx)
          counter += 1
          if counter > 3:
              return new_words
      new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
      return new_words
  81. ########################################################################
  82. # 随机删除
  83. # 以概率p删除语句中的词
  84. ########################################################################
  def random_deletion(words, p):
      """Drop each token of ``words`` independently with probability ``p``.

      Args:
          words: list of tokens; it is not modified.
          p: deletion probability; a token is kept when a uniform draw from
              [0, 1] is strictly greater than ``p``.

      Returns:
          ``words`` itself when it has at most one token (an empty list used to
          raise ``ValueError`` from ``random.randint(0, -1)``); otherwise a new
          list of the surviving tokens, or a single randomly chosen token if
          every token was deleted.
      """
      # Nothing sensible to delete from an empty or single-token sentence.
      if len(words) <= 1:
          return words
      new_words = [word for word in words if random.uniform(0, 1) > p]
      # Never return an empty sentence: fall back to one random original token.
      if not new_words:
          rand_int = random.randint(0, len(words) - 1)
          return [words[rand_int]]
      return new_words
  97. ########################################################################
  98. #EDA函数
  def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
      """Generate augmented variants of ``sentence`` with the four EDA operations.

      The sentence is segmented with ``jieba.cut``; each of synonym replacement,
      random insertion, random swap and random deletion is then applied
      ``int(num_aug / 4) + 1`` times to the token list.

      Args:
          sentence: raw input text.
          alpha_sr: fraction of tokens to replace by synonyms (at least 1 token).
          alpha_ri: fraction of tokens used as the insertion count (at least 1).
          alpha_rs: fraction of tokens used as the swap count (at least 1).
          p_rd: per-token deletion probability.
          num_aug: number of augmented sentences to keep. When ``>= 1`` the
              shuffled candidates are truncated to this many; otherwise each
              candidate is kept with probability ``num_aug / len(candidates)``.

      Returns:
          A list of space-joined token strings; the segmented original sentence
          is always appended as the last element. If segmentation yields no
          tokens (empty or whitespace-only input) only that original is
          returned, instead of crashing inside the random helpers.
      """
      seg_list = " ".join(jieba.cut(sentence))
      words = seg_list.split()
      num_words = len(words)
      if num_words == 0:
          # Nothing to augment; the helpers cannot pick a random index here.
          return [seg_list]

      num_new_per_technique = int(num_aug / 4) + 1
      n_sr = max(1, int(alpha_sr * num_words))
      n_ri = max(1, int(alpha_ri * num_words))
      n_rs = max(1, int(alpha_rs * num_words))

      # (operation, strength) pairs, applied in this fixed order so the sequence
      # of random draws is the same as with four separate loops.
      techniques = (
          (synonym_replacement, n_sr),  # sr: synonym replacement
          (random_insertion, n_ri),     # ri: random insertion
          (random_swap, n_rs),          # rs: random swap
          (random_deletion, p_rd),      # rd: random deletion
      )
      augmented_sentences = []
      for technique, strength in techniques:
          for _ in range(num_new_per_technique):
              a_words = technique(words, strength)
              augmented_sentences.append(' '.join(a_words))

      shuffle(augmented_sentences)
      if num_aug >= 1:
          augmented_sentences = augmented_sentences[:num_aug]
      else:
          # Fractional num_aug: keep each candidate with that overall probability.
          keep_prob = num_aug / len(augmented_sentences)
          augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
      augmented_sentences.append(seg_list)
      return augmented_sentences

augment.py

  1. from eda import *
  2. import argparse
  # Command-line interface: --input is mandatory, the other options fall back
  # to defaults when omitted (or given a falsy value such as 0).
  ap = argparse.ArgumentParser()
  ap.add_argument("--input", required=True, type=str, help="原始数据的输入文件目录")
  ap.add_argument("--output", required=False, type=str, help="增强数据后的输出文件目录")
  ap.add_argument("--num_aug", required=False, type=int, help="每条原始语句增强的语句数")
  ap.add_argument("--alpha", required=False, type=float, help="每条语句中将会被改变的单词数占比")
  args = ap.parse_args()

  # Output file: use --output when given, otherwise write next to the input
  # file under the same name prefixed with "eda_".
  if not args.output:
      from os.path import dirname, basename, join
      output = join(dirname(args.input), 'eda_' + basename(args.input))
  else:
      output = args.output

  # Number of augmented sentences generated per original sentence (default 9).
  num_aug = args.num_aug if args.num_aug else 9

  # Fraction of the words in each sentence that will be changed (default 0.1).
  alpha = args.alpha if args.alpha else 0.1
  def gen_eda(train_orig, output_file, alpha, num_aug=9):
      """Augment a tab-separated sentence-pair file with EDA and write the result.

      Each non-blank input line is expected to hold ``sentence1<TAB>sentence2<TAB>label``.
      Both sentences are augmented independently with ``eda`` (using ``alpha``
      for every strength parameter) and the resulting variants are paired
      position by position, each written with the original label. Spaces
      introduced by segmentation are removed before writing.

      Args:
          train_orig: path of the original data file.
          output_file: path of the augmented file to create or overwrite.
          alpha: fraction of words to change, passed to all four EDA operations.
          num_aug: number of augmented sentences per original sentence.

      Raises:
          IndexError: if a non-blank line has fewer than three tab-separated fields.
      """
      # Context managers close both files even if augmentation fails midway.
      # NOTE(review): assumes the data files are UTF-8 encoded - confirm.
      with open(train_orig, 'r', encoding='utf-8') as reader, \
              open(output_file, 'w', encoding='utf-8') as writer:
          print("正在使用EDA生成增强语句...")
          for line in reader:
              # Strip only the line terminator so a final line without a
              # trailing newline does not lose the last character of its label.
              line = line.rstrip('\n')
              if not line.strip():
                  continue  # tolerate blank lines (e.g. at the end of the file)
              parts = line.split('\t')
              sentence1 = parts[0]
              sentence2 = parts[1]
              label = parts[2]
              aug_sentences1 = eda(sentence1, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
              aug_sentences2 = eda(sentence2, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
              # zip pairs the variants and stops at the shorter list, so
              # differing lengths cannot cause an IndexError.
              for aug_sentence1, aug_sentence2 in zip(aug_sentences1, aug_sentences2):
                  writer.write(aug_sentence1.replace(' ', '') + "\t" + aug_sentence2.replace(' ', '') + "\t" + label + '\n')
      print("已生成增强语句!")
      print(output_file)
  # Script entry point: run the augmentation with the settings resolved from
  # the command line above.
  if __name__ == "__main__":
      gen_eda(
          args.input,
          output,
          alpha=alpha,
          num_aug=num_aug,
      )

增强训练数据集

!python augment.py --input train.txt

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/347187
推荐阅读
相关标签
  

闽ICP备14008679号