当前位置:   article > 正文

万字逐行解析与实现Transformer,并进行德译英实战(三)_de_core_news_sm

de_core_news_sm

本文由于长度限制,共分为三篇:

  1. 万字逐行解析与实现Transformer,并进行德译英实战(一)
  2. 万字逐行解析与实现Transformer,并进行德译英实战(二)
  3. 万字逐行解析与实现Transformer,并进行德译英实战(三)

你也可以在该项目找到本文的源码。

Part 3: 实战:德译英

现在我们来进行一个案例实战,我们使用Multi30k German-English 翻译任务。虽然这个任务远小于论文中的WMT任务,但也足以阐明整个系统。

数据加载

我们将使用torchtext进行数据加载,并使用spacy进行分词。spacy可以参考这篇文章

加载数据集一定要使用这两个版本torchdata==0.3.0, torchtext==0.12,否则会加载失败。

加载分词模型,如果你还没有下载,请使用如下代码进行下载(代码中也会有):

python -m spacy download de_core_news_sm
python -m spacy download en_core_web_sm
  • 1
  • 2

若在国内使用命令下载失败,请使用离线下载。(注意版本需是3.2.0)de_core_news_sm下载链接en_core_web_sm下载链接

def load_tokenizers():
    """
    加载spacy分词模型
    :return: 返回德语分词模型和英语分词模型
    """

    try:
        spacy_de = spacy.load("de_core_news_sm")
    except IOError:
        # 如果报错,说明还未安装分词模型,进行安装后重新加载
        os.system("python -m spacy download de_core_news_sm")
        spacy_de = spacy.load("de_core_news_sm")

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    # 返回德语分词模型和英语分词模型
    return spacy_de, spacy_en
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
def tokenize(text, tokenizer):
    """
    对text文本进行分词
    :param text: 要分词的文本,例如“I love you”
    :param tokenizer: 分词模型,例如:spacy_en
    :return: 分词结果,例如 ["I", "love", "you"]
    """
    return [tok.text for tok in tokenizer.tokenizer(text)]


def yield_tokens(data_iter, tokenizer, index):
    """
    yield一个token list
    :param data_iter: 包含句子对儿的可迭代对象。例如:
                      [("I love you", "我爱你"), ...]
    :param tokenizer: 分词模型。例如spacy_en
    :param index: 要对句子对儿的哪个语言进行分词,
                  例如0表示对上例的英文进行分词
    :return: yield本轮的分词结果,例如['I', 'love', 'you']
    """
    for from_to_tuple in data_iter:
        yield tokenizer(from_to_tuple[index])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
def build_vocabulary(spacy_de, spacy_en):
    """
    构建德语词典和英语词典
    :return: 返回德语词典和英语词典,均为:Vocab对象
             Vocab对象官方地址为:https://pytorch.org/text/stable/vocab.html#vocab
    """
    # 构建德语分词方法
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    # 构建英语分词方法
    def tokenize_en(text):
        return tokenize(text, spacy_en)

    print("Building German Vocabulary ...")

    """
    其中train, val, test都是可迭代对象。
    例如:next(iter(train)) 返回一个tuple,为:
    ('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
     'Two young, White males are outside near many bushes.')
    """
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))

    """
    build_vocab_from_iterator:根据一个可迭代对象生成一个词典。
    其返回一个Vocab对象,官方地址为:https://pytorch.org/text/stable/vocab.html#vocab

    其接收是三个参数
    1. iterator,需要传入一个可迭代对象。里面为分好词的数据,例如:
                [("I", "love", "you"), ("you", "love", "me")]
    2. min_freq,最小频率,当一个单词的出现频率达到最小频率后才会被
                 算到词典中。例如,如果min_freq=2,则上例中只有“you”
                 会被算到词典中,因为其他单词都只出现一次。
    3.specials, 特殊词汇,例如'<bos>', '<unk>'等。特殊单词会被加到
                 词典的最前面。

    假设我们调用的是:
    vocab = build_vocab_from_iterator(
        [("I", "love", "you"), ("you", "love", "me")],
        min_freq=1,
        specials=["<s>", "</s>"],
    )
    vocab对应的词典则为:{0:<s>, 1:</s>, 2:love, 3:you, 4:I, 5:me}
    """
    vocab_src = build_vocab_from_iterator(
        yield_tokens(train + val + test, tokenize_de, index=0),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )


    # 开始构建英语词典,与上面一样
    print("Building English Vocabulary ...")
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_tgt = build_vocab_from_iterator(
        yield_tokens(train + val + test, tokenize_en, index=1),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    # 设置默认index为`<unk>`,后面对于那些不认识的单词就会自动归为`<unk>`
    vocab_src.set_default_index(vocab_src["<unk>"])
    vocab_tgt.set_default_index(vocab_tgt["<unk>"])

    # 返回构建好的德语词典和英语词典
    return vocab_src, vocab_tgt


def load_vocab(spacy_de, spacy_en):
    """
    加载德语词典和英语词典。由于构建词典的过程需要花费一定时间,
    所以该方法就是对build_vocabulary的进一步封装,增加了缓存机制。
    :return: 返回德语词典和英语词典,均为Vocab对象
    """

    # 如果不存在缓存文件,说明是第一次构建词典
    if not exists("vocab.pt"):
        # 构建词典,并写入缓存文件
        vocab_src, vocab_tgt = build_vocabulary(spacy_de, spacy_en)
        torch.save((vocab_src, vocab_tgt), "vocab.pt")
    else:
        # 如果存在缓存文件,直接加载
        vocab_src, vocab_tgt = torch.load("vocab.pt")
    # 输出些日志:
    print("Finished.\nVocabulary sizes:")
    print("vocab_src size:", len(vocab_src))
    print("vocab_tgt size:", len(vocab_tgt))
    return vocab_src, vocab_tgt


# 全局参数,后续还要用
# 加载德语和英语分词器
spacy_de, spacy_en = load_tokenizers()
# 加载德语词典(源词典)和英语词典(目标词典)
vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
Building German Vocabulary ...
Building English Vocabulary ...
Finished.
Vocabulary sizes:
vocab_src size: 8315
vocab_tgt size: 6384
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

Iterators

def collate_batch(
    batch,
    src_pipeline,
    tgt_pipeline,
    src_vocab,
    tgt_vocab,
    device,
    max_padding=128,
    pad_id=2,
):
    """
    Dataloader中的collate_fn函数。该函数的作用是:将文本句子处理成数字句子,然后pad到固定长度,最终batch到一起

    :param batch: 一个batch的语句对。例如:
                  [('Ein Kleinkind ...', 'A toddler in ...'), # [(德语), (英语)
                   ....                                       # ...
                   ...]                                       # ... ]
    :param src_pipeline: 德语分词器,也就是tokenize_de方法,后面会定义
                         其实就是对spacy_de的封装
    :param tgt_pipeline: 英语分词器,也就是tokenize_en方法
    :param src_vocab: 德语词典,Vocab对象
    :param tgt_vocab: 英语词典,Vocab对象
    :param device: cpu或cuda
    :param max_padding: 句子的长度。pad长度不足的句子和裁剪长度过长的句子,
                        目的是让不同长度的句子可以组成一个tensor
    :param pad_id: '<blank>'在词典中对应的index
    :return: src和tgt。处理后并batch后的句子。例如:
             src为:[[0, 4354, 314, ..., 1, 2, 2, ..., 2],  [0, 4905, 8567, ..., 1, 2, 2, ..., 2]]
             其中0是<bos>, 1是<eos>, 2是<blank>
             src的Shape为(batch_size, max_padding)
             tgt同理。
    """

    # 定义'<bos>'的index,在词典中为0,所以这里也是0
    bs_id = torch.tensor([0], device=device)  # <s> token id
    # 定义'<eos>'的index
    eos_id = torch.tensor([1], device=device)  # </s> token id

    # 用于存储处理后的src和tgt
    src_list, tgt_list = [], []
    # 循环遍历句子对儿
    for (_src, _tgt) in batch:
        """
        _src: 德语句子,例如:Ein Junge wirft Blätter in die Luft.
        _tgt: 英语句子,例如:A boy throws leaves into the air.
        """

        """
        将句子进行分词,并将词转成对应的index。例如:
        "I love you" -> ["I", "love", "you"] ->
        [1136, 2468, 1349] -> [0, 1136, 2468, 1349, 1]
        其中0,1是<bos>和<eos>。

        Vocab对象可以将list中的词转为index,例如:
        `vocab_tgt(["I", "love", "you"])` 的输出为:
        [1136, 2468, 1349]
        """
        processed_src = torch.cat(
            # 将<bos>,句子index和<eos>拼到一块
            [
                bs_id,
                torch.tensor(
                    # 进行分词后,转换为index。
                    src_vocab(src_pipeline(_src)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        processed_tgt = torch.cat(
            [
                bs_id,
                torch.tensor(
                    tgt_vocab(tgt_pipeline(_tgt)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )

        """
        将长度不足的句子进行填充到max_padding的长度的,然后增添到list中

        pad:假设processed_src为[0, 1136, 2468, 1349, 1]
             第二个参数为: (0, 72-5)
             第三个参数为:2
        则pad的意思表示,给processed_src左边填充0个2,右边填充67个2。
        最终结果为:[0, 1136, 2468, 1349, 1, 2, 2, 2, ..., 2]
        """
        src_list.append(
            pad(
                processed_src,
                (0, max_padding - len(processed_src),),
                value=pad_id,
            )
        )
        tgt_list.append(
            pad(
                processed_tgt,
                (0, max_padding - len(processed_tgt),),
                value=pad_id,
            )
        )

    # 将多个src句子堆叠到一起
    src = torch.stack(src_list)
    tgt = torch.stack(tgt_list)

    # 返回batch后的结果
    return (src, tgt)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
def create_dataloaders(
    device,
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    batch_size=12000,
    max_padding=128
):
    """
    创建train_dataloader和valid_dataloader
    :param device: cpu或cuda
    :param vocab_src: 源词典,本例中为德语词典
    :param vocab_tgt: 目标词典,本例中为英语词典
    :param spacy_de: 德语分词器
    :param spacy_en: 英语分词器
    :param batch_size: batch_size
    :param max_padding: 句子的最大长度

    :return: train_dataloader和valid_dataloader
    """

    # 定义德语分词器
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    # 定义英语分词器
    def tokenize_en(text):
        return tokenize(text, spacy_en)

    # 创建批处理工具,即应该如何将一批数据汇总成一个Batch
    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_src,
            vocab_tgt,
            device,
            max_padding=max_padding,
            pad_id=vocab_src.get_stoi()["<blank>"],
        )

    # 加载数据集
    train_iter, valid_iter, test_iter = datasets.Multi30k(
        language_pair=("de", "en")
    )

    """
    将Iterator类型的Dataset转为Map类型的Dataset。如果你不熟悉,可以参考:
    https://blog.csdn.net/zhaohongfei_358/article/details/122742656

    经过测试,发现其实不转也可以。效果没差别
    """
    train_iter_map = to_map_style_dataset(train_iter)
    valid_iter_map = to_map_style_dataset(valid_iter)

    # 构建DataLoader,若DataLoader不熟悉,请参考文章:
    # https://blog.csdn.net/zhaohongfei_358/article/details/122742656
    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    return train_dataloader, valid_dataloader
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72

训练模型

def train_worker(
    device,
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    config,
    is_distributed=False,
):
    """
    训练模型
    :param device: cpu或cuda
    :param vocab_src: 源词典,本例中为德语词典
    :param vocab_tgt: 目标词典,本例中为英语词典
    :param spacy_de: 德语分词器
    :param spacy_en: 英语分词器
    :param config: 一个保存了配置参数的dict,例如学习率啥的
    """

    print(f"Train worker process using device: {device} for training")

    # 找出目标词典中‘<blank>’所对应的index
    pad_idx = vocab_tgt["<blank>"]
    # 设置词向量大小。
    d_model = 512
    # 构建模型,Layer数为6
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.to(device)

    # 定义损失函数
    criterion = LabelSmoothing(
        size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1
    )
    criterion.to(device)

    # 创建train_dataloader和valid_dataloader
    train_dataloader, valid_dataloader = create_dataloaders(
        device,
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=config["batch_size"],
        max_padding=config["max_padding"]
    )

    # 创建Adam优化器
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )

    # 定义Warmup学习率策略
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, d_model, factor=1, warmup=config["warmup"]
        ),
    )

    # 创建train_state,保存训练状态
    train_state = TrainState()

    # 开始训练
    for epoch in range(config["num_epochs"]):
        model.train()
        print(f"[Epoch {epoch} Training ====", flush=True)
        _, train_state = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in train_dataloader),
            model,
            SimpleLossCompute(model.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
        )

        """
        展示GPU使用情况,例如:
        | ID | GPU | MEM |
        ------------------
        |  0 | 11% |  6% |
        """
        if torch.cuda.is_available():
            GPUtil.showUtilization()

        # 每训练一个epoch保存一次模型
        file_path = "%s%.2d.pt" % (config["file_prefix"], epoch)
        torch.save(model.state_dict(), file_path)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # 在一个epoch后,进行模型验证
        print(f"[Epoch {epoch} Validation ====")
        model.eval()
        # 跑验证集中的数据,看看loss有多少
        sloss = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
            model,
            SimpleLossCompute(model.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )
        # 打印验证集的Loss
        print("Validation Loss:", sloss[0].data)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # 全部epoch训练完毕后,保存模型
    file_path = "%sfinal.pt" % config["file_prefix"]
    torch.save(model.state_dict(), file_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
def load_trained_model():
    """
    加载模型或训练模型。
    若没有找到模型,说明没有训练过,则进行训练
    :return: Transformer对象,即EncoderDecoder类对象
    """

    # 定义一些模型训练参数
    config = {
        "batch_size": 32,
        "num_epochs": 8, # epoch数量
        "accum_iter": 10, # 每10个batch更新一次模型参数
        "base_lr": 1.0,  # 基础学习率,根据这个学习率进行warmup
        "max_padding": 72, # 句子的最大长度
        "warmup": 3000,  # Warmup3000次,也就是从第3000次学习率开始下降
        "file_prefix": "multi30k_model_", # 模型文件前缀名
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_path = "multi30k_model_final.pt"

    # 如果模型不存在,则训练一个模型
    if not exists(model_path):
        train_worker(device, vocab_src, vocab_tgt, spacy_de, spacy_en, config)

    # 初始化模型实例
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    # 加载模型参数
    model.load_state_dict(torch.load("multi30k_model_final.pt"))
    return model


# 加载或训练模型
model = load_trained_model()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
Train worker process using device: cuda for training
[Epoch 0 Training ====
Epoch Step:      1 | Accumulation Step:   1 | Loss:   7.65 | Tokens / Sec:  2701.9 | Learning Rate: 5.4e-07
...略
Epoch Step:    881 | Accumulation Step:  89 | Loss:   1.03 | Tokens / Sec:  2758.8 | Learning Rate: 5.2e-04
| ID | GPU | MEM |
------------------
|  0 | 57% | 29% |
[Epoch 7 Validation ====
Validation Loss: tensor(1.4455, device='cuda:0')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

测试结果

在最后我们可以使用验证集来简单的测试一下我们的模型

# Load data and model for output checks
def check_outputs(
    valid_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
):
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(iter(valid_dataloader))
        rb = Batch(b[0], b[1], pad_idx)
        greedy_decode(model, rb.src, rb.src_mask, 64, 0)[0]

        src_tokens = [
            vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx
        ]
        tgt_tokens = [
            vocab_tgt.get_itos()[x] for x in rb.tgt[0] if x != pad_idx
        ]

        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )
        model_out = greedy_decode(model, rb.src, rb.src_mask, 72, 0)[0]
        model_txt = (
            " ".join(
                [vocab_tgt.get_itos()[x] for x in model_out if x != pad_idx]
            ).split(eos_string, 1)[0]
            + eos_string
        )
        print("Model Output               : " + model_txt.replace("\n", ""))
        results[idx] = (rb, src_tokens, tgt_tokens, model_out, model_txt)
    return results


def run_model_example(n_examples=5):
    global vocab_src, vocab_tgt, spacy_de, spacy_en

    print("Preparing Data ...")
    _, valid_dataloader = create_dataloaders(
        torch.device("cpu"),
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=1,
        is_distributed=False,
    )

    print("Loading Trained Model ...")

    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.load_state_dict(
        torch.load("multi30k_model_final.pt", map_location=torch.device("cpu"))
    )

    print("Checking Model Outputs:")
    example_data = check_outputs(
        valid_dataloader, model, vocab_src, vocab_tgt, n_examples=n_examples
    )
    return model, example_data


run_model_example()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73

-------完结,撒花--------

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/378271
推荐阅读
相关标签
  

闽ICP备14008679号