当前位置:   article > 正文

生物大分子平台(8)_learn_bpe()

learn_bpe()

生物大分子平台(8)

2021SC@SDUSC

0 本周工作

本周继续阅读本周继续阅读了transformer的代码实现。完成对learn_bpe代码的详细解读

1 BPE分词算法

1.1 简介

BPE算法的分词粒度处于单词级别和字符级别之间。比如说单词"played"和"playing"会被划分为"play","ed”,“ing”,这样在降低词表大小的同时也能学到词的语意信息。
BPE算法的核心主要分成三个部分:词表构建、语料编码和语料解码。

1.2 算法流程

算法将一组词按照字母进行分割,然后统计个数。再在接下来的步骤中,寻找组合出现在这组词的片段,组合片段,删除组合片段中不单独出现的字母。直到词表达到我们设置的期望或者剩下的字节对出现的频率为1。
BPE算法就是利用子词来编码数据,目前算法已经成为机器翻译领域标准的预处理方式。

2 代码分析 LEARN_BPE

2.1 导入库

  • 其中新接触到的库有copy库,copy库是python基础库中的一部分,copy模块提供了通用的深层次复制和浅层次复制。
import os
import sys
import inspect
import codecs
import re
import copy
import warnings
from collections import defaultdict, Counter
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2.2 具体代码分析

  • 读取文本并返回编码词汇的字典
def update_vocabulary(vocab, file_name, is_dict=False):
    with codecs.open(file_name, encoding='utf-8') as fobj:
        for i, line in enumerate(fobj):
            if is_dict:
                try:
                    word, count = line.strip('\r\n ').split(' ')
                except:
                    print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
                    sys.exit(1)
                vocab[word] += int(count)
            else:
                for word in line.strip('\r\n ').split(' '):
                    if word:
                        vocab[word] += 1
    return vocab
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 更新词库中词的个数,如果我们合并一对符号,则最少更新符号对的索引和频率,只有与该对出现重叠的对才会受到影响,并且需要更新。
  • 在for循环中通过统计所有的文本找到pair的所有实例,并更新它词表中的频率索引
  • 在接下来的循环中,将一起出现的词在词表中的频率增加,减少原先未组合词的频率
  • 方法是在接下来的循环中,遍历词表,在词表中,确定前一个词,如果后一个词与前一个词组合出现,则将这个组合词的频率增加,并减少未组合词在词表中的数量
  • 举个例子。假设一个符号序列“A B C”,如果合并了“B C”,则降低“A B”的频率,假设一个符号序列“A B C B”,如果“B C”被合并,则降低“C B”的频率。但是,如果序列是 A B C B C,则跳过此步骤,因为“C B”的频率将被前一个代码块降低。假设一个符号序列“A BC D”,如果“B C”被合并,则增加“A BC”的频率。假设一个符号序列“A BC B”,如果“BC”被合并,增加“BC B”的频率,但是如果序列是A BC BC,跳过这一步,因为“BC BC”的计数将增加上一个代码块
def update_pair_statistics(pair, changed, stats, indices):
    stats[pair] = 0
    indices[pair] = defaultdict(int)
    first, second = pair
    new_pair = first+second
    for j, word, old_word, freq in changed:
        i = 0
        while True:
           
            try:
                i = old_word.index(first, i)
            except ValueError:
                break
            if i < len(old_word)-1 and old_word[i+1] == second:

                if i:
                    prev = old_word[i-1:i+1]
                    stats[prev] -= freq
                    indices[prev][j] -= 1
                if i < len(old_word)-2:

                    if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second:
                        nex = old_word[i+1:i+3]
                        stats[nex] -= freq
                        indices[nex][j] -= 1
                i += 2
            else:
                i += 1

        i = 0
        while True:
            try:
                # find new pair
                i = word.index(new_pair, i)
            except ValueError:
                break

            if i:
                prev = word[i-1:i+1]
                stats[prev] += freq
                indices[prev][j] += 1
            if i < len(word)-1 and word[i+1] != new_pair:
                nex = word[i:i+2]
                stats[nex] += freq
                indices[nex][j] += 1
            i += 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 此方法通过遍历计算所有符号对的频率,并将其添加在词表中。
def get_pair_statistics(vocab):
    stats = defaultdict(int)
    indices = defaultdict(lambda: defaultdict(int))
    for i, (word, freq) in enumerate(vocab):
        prev_char = word[0]
        for char in word[1:]:
            stats[prev_char, char] += freq
            indices[prev_char, char][i] += 1
            prev_char = char
    return stats, indices
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 用合并的新符号替换原来分立的符号。例如,用新符号 ‘AB’ 替换所有出现的符号对 (‘A’, ‘B’)
def replace_pair(pair, vocab, indices):
    first, second = pair
    pair_str = ''.join(pair)
    pair_str = pair_str.replace('\\','\\\\')
    changes = []
    pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
    if sys.version_info < (3, 0):
        iterator = indices[pair].iteritems()
    else:
        iterator = indices[pair].items()
    for j, freq in iterator:
        if freq < 1:
            continue
        word, freq = vocab[j]
        new_word = ' '.join(word)
        new_word = pattern.sub(pair_str, new_word)
        new_word = tuple(new_word.split(' '))
        vocab[j] = (new_word, freq)
        changes.append((j, new_word, word, freq))
    return changes
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 修建词表,如果词表中某一项的数目小于某个阈值,则在词表中对小于阈值的项进行修剪。
def prune_stats(stats, big_stats, threshold):
    for item,freq in list(stats.items()):
        if freq < threshold:
            del stats[item]
            if freq < 0:
                big_stats[item] += freq
            else:
                big_stats[item] = freq
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 从文本中学习分词方法并统计数量,并将其写入outfile。
  • 阈值受到 Zipfian 分布的影响确定。
  • 最后将词输出到outfile中。
def learn_bpe(infile_names, outfile_name, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

    #vocab = get_vocabulary(infile, is_dict)
    vocab = Counter()
    for f in infile_names:
        sys.stderr.write(f'Collecting vocab from {f}\n')
        vocab = update_vocabulary(vocab, f, is_dict)

    vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
    sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)

    stats, indices = get_pair_statistics(sorted_vocab)
    big_stats = copy.deepcopy(stats)

    if total_symbols:
        uniq_char_internal = set()
        uniq_char_final = set()
        for word in vocab:
            for char in word[:-1]:
                uniq_char_internal.add(char)
            uniq_char_final.add(word[-1])
        sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal)))
        sys.stderr.write('Number of word-final characters: {0}\n'.format(len(uniq_char_final)))
        sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final)))
        num_symbols -= len(uniq_char_internal) + len(uniq_char_final)


    sys.stderr.write(f'Write vocab file to {outfile_name}')
    with codecs.open(outfile_name, 'w', encoding='utf-8') as outfile:
        outfile.write('#version: 0.2\n')
        threshold = max(stats.values()) / 10
        for i in range(num_symbols):
            if stats:
                most_frequent = max(stats, key=lambda x: (stats[x], x))

            # we probably missed the best pair because of pruning; go back to full statistics
            if not stats or (i and stats[most_frequent] < threshold):
                prune_stats(stats, big_stats, threshold)
                stats = copy.deepcopy(big_stats)
                most_frequent = max(stats, key=lambda x: (stats[x], x))
                # threshold is inspired by Zipfian assumption, but should only affect speed
                threshold = stats[most_frequent] * i/(i+10000.0)
                prune_stats(stats, big_stats, threshold)

            if stats[most_frequent] < min_frequency:
                sys.stderr.write(f'no pair has frequency >= {min_frequency}. Stopping\n')
                break

            if verbose:
                sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(
                    i, most_frequent[0], most_frequent[1], stats[most_frequent]))
            outfile.write('{0} {1}\n'.format(*most_frequent))
            changes = replace_pair(most_frequent, sorted_vocab, indices)
            update_pair_statistics(most_frequent, changes, stats, indices)
            stats[most_frequent] = 0
            if not i % 100:
                prune_stats(stats, big_stats, threshold)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/369505
推荐阅读
相关标签
  

闽ICP备14008679号