
Implicit Intent Detection with BERT (JointBERT: joint intent classification and slot filling)
import sys
import os
import copy
import json
import logging
import argparse
import torch
import numpy as np
from tqdm import tqdm, trange
import torch.nn as nn
from torchcrf import CRF
from torch.utils.data import TensorDataset
from seqeval.metrics import precision_score, recall_score, f1_score

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup
from transformers import (
    BertModel,
    BertTokenizer,
    BertPreTrainedModel,
)

# Clear sys.argv so that argparse does not choke on Jupyter's own command-line arguments.
sys.argv = ['']
del sys
logger = logging.getLogger(__name__)

Define a function that loads the tokenizer (BERT's own WordPiece tokenizer: it splits words into sub-tokens and maps them to vocabulary ids)

def load_tokenizer(args):
    # Load the WordPiece tokenizer for the same checkpoint as the model.
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    return tokenizer
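
A quick sanity check of what the tokenizer produces (a minimal sketch; the example sentence is made up):

# Sketch: the tokenizer splits words into WordPiece sub-tokens and maps them
# to integer vocabulary ids (not one-hot vectors).
demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
demo_tokens = demo_tokenizer.tokenize("cheapest airfare to denver")
print(demo_tokens)                                        # sub-token strings, with '##' continuation pieces where a word is split
print(demo_tokenizer.convert_tokens_to_ids(demo_tokens))  # list of integer ids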

Define the evaluation metric functions

def get_intent_acc(preds, labels):
    acc = (preds == labels).mean()
    return {
        "intent_acc": acc
    }


def get_slot_metrics(preds, labels):
    assert len(preds) == len(labels)
    return {
        "slot_precision": precision_score(labels, preds),
        "slot_recall": recall_score(labels, preds),
        "slot_f1": f1_score(labels, preds)
    }


def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
    """For the cases that intent and all the slots are correct (in one sentence)"""
    # Get the intent comparison result
    intent_result = (intent_preds == intent_labels)

    # Get the slot comparision result
    slot_result = []
    for preds, labels in zip(slot_preds, slot_labels):
        assert len(preds) == len(labels)
        one_sent_result = True
        for p, l in zip(preds, labels):
            if p != l:
                one_sent_result = False
                break
        slot_result.append(one_sent_result)
    slot_result = np.array(slot_result)

    sementic_acc = np.multiply(intent_result, slot_result).mean()
    return {
        "sementic_frame_acc": sementic_acc
    }


def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
    assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
    results = {}
    intent_result = get_intent_acc(intent_preds, intent_labels)
    slot_result = get_slot_metrics(slot_preds, slot_labels)
    sementic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)

    results.update(intent_result)
    results.update(slot_result)
    results.update(sementic_result)

    return results
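
A small toy call of these metrics, just to show the expected inputs (values and tag names are made up; intents are numpy arrays of label ids, slots are lists of tag-string lists in seqeval format):

toy_intent_preds = np.array([1, 3])
toy_intent_labels = np.array([1, 2])
toy_slot_preds = [['O', 'B-toloc.city_name'], ['O', 'O']]
toy_slot_labels = [['O', 'B-toloc.city_name'], ['O', 'B-depart_time.time']]
print(compute_metrics(toy_intent_preds, toy_intent_labels, toy_slot_preds, toy_slot_labels))
# -> one dict with intent_acc, slot_precision/recall/f1 and sementic_frame_acc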


The argparse module makes it easy to write user-friendly command-line interfaces. The program defines the arguments it needs, and argparse figures out how to parse them from sys.argv. The module also generates help and usage messages automatically and reports an error when the user passes invalid arguments.
https://docs.python.org/zh-cn/3/library/argparse.html

Create the ArgumentParser and set the task name
parser = argparse.ArgumentParser()

parser.add_argument("--task", default='atis', required=False, type=str, help="The name of the task to train")
_StoreAction(option_strings=['--task'], dest='task', nargs=None, const=None, default='atis', type=<class 'str'>, choices=None, help='The name of the task to train', metavar=None)
Add the remaining hyperparameters
parser.add_argument("--model_dir", default="./save_model", required=False, type=str, help="Path to save, load model")
parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")

parser.add_argument("--model_type", default="bert", type=str, help=" Bert is the Model")

parser.add_argument('--seed', type=int, default=1234, help="random seed for initialization")
parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int, help="Batch size for evaluation.")
parser.add_argument("--max_seq_len", default=50, type=int, help="The maximum total input sequence length after tokenization.")
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=2.0, type=float, help="Total number of training epochs to perform.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")

parser.add_argument('--logging_steps', type=int, default=200, help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=200, help="Save checkpoint every X updates steps.")

parser.add_argument("--do_train", default=True, action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", default=True, action="store_true", help="Whether to run eval on the test set.")
parser.add_argument("--no_cuda", default=True, action="store_true", help="Avoid using CUDA when available")

parser.add_argument("--ignore_index", default=0, type=int,
                    help='Specifies a target value that is ignored and does not contribute to the input gradient')

parser.add_argument('--slot_loss_coef', type=float, default=1.0, help='Coefficient for the slot loss.')

# CRF option
parser.add_argument("--use_crf", default=True, action="store_true", help="Whether to use CRF")
parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
_StoreAction(option_strings=['--slot_pad_label'], dest='slot_pad_label', nargs=None, const=None, default='PAD', type=<class 'str'>, choices=None, help='Pad token for slot labels (ignored when calculating the loss)', metavar=None)
Parse the arguments
args = parser.parse_args()

args.model_name_or_path = 'bert-base-uncased'
Inspect what ended up in args
args
Namespace(adam_epsilon=1e-08, data_dir='./data', do_eval=True, do_train=True, dropout_rate=0.1, eval_batch_size=64, gradient_accumulation_steps=1, ignore_index=0, intent_label_file='intent_label.txt', learning_rate=5e-05, logging_steps=200, max_grad_norm=1.0, max_seq_len=50, max_steps=-1, model_dir='./save_model', model_name_or_path='bert-base-uncased', model_type='bert', no_cuda=True, num_train_epochs=2.0, save_steps=200, seed=1234, slot_label_file='slot_label.txt', slot_loss_coef=1.0, slot_pad_label='PAD', task='atis', train_batch_size=32, use_crf=True, warmup_steps=0, weight_decay=0.0)

Data input module

Official documentation: https://huggingface.co/transformers/main_classes/processors.html?highlight=inputexample#transformers.data.processors.utils.InputExample

class InputExample(object):
    """
    A single training/test example for simple sequence classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        intent_label: (Optional) string. The intent label of the example.
        slot_labels: (Optional) list. The slot labels of the example.
    """

    def __init__(self, guid, words, intent_label=None, slot_labels=None):
        self.guid = guid
        self.words = words
        self.intent_label = intent_label
        self.slot_labels = slot_labels

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
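
For illustration, a single ATIS-style example could be built like this (words and label indices are made up; the real indices depend on intent_label.txt and slot_label.txt):

demo_example = InputExample(guid="train-0",
                            words=["show", "flights", "to", "denver"],
                            intent_label=3,            # hypothetical index into intent_label.txt
                            slot_labels=[2, 2, 2, 7])  # hypothetical indices into slot_label.txt
print(demo_example)  # __repr__ prints the JSON serialization
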
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, attention_mask, token_type_ids, intent_label_id, slot_labels_ids):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.intent_label_id = intent_label_id
        self.slot_labels_ids = slot_labels_ids

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class JointProcessor(object):
    """Processor for the JointBERT data set """

    def __init__(self, args):
        self.args = args
        self.intent_labels = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
        self.slot_labels = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]


        self.input_text_file = 'seq.in'
        self.intent_label_file = 'label'
        self.slot_labels_file = 'seq.out'

    @classmethod
    def _read_file(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8") as f:
            lines = []
            for line in f:
                lines.append(line.strip())
            return lines

    def _create_examples(self, texts, intents, slots, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for i, (text, intent, slot) in enumerate(zip(texts, intents, slots)):
            guid = "%s-%s" % (set_type, i)
            # 1. input_text
            words = text.split()  # Some are spaced twice
            # 2. intent
            intent_label = self.intent_labels.index(intent) if intent in self.intent_labels else self.intent_labels.index("UNK")
            # 3. slot
            slot_labels = []
            for s in slot.split():
                slot_labels.append(self.slot_labels.index(s) if s in self.slot_labels else self.slot_labels.index("UNK"))

            assert len(words) == len(slot_labels)
            examples.append(InputExample(guid=guid, words=words, intent_label=intent_label, slot_labels=slot_labels))
        return examples

    def get_examples(self, mode):
        """
        Args:
            mode: train, dev, test
        """
        data_path = os.path.join(self.args.data_dir, self.args.task, mode)
        logger.info("LOOKING AT {}".format(data_path))
        return self._create_examples(texts=self._read_file(os.path.join(data_path, self.input_text_file)),
                                     intents=self._read_file(os.path.join(data_path, self.intent_label_file)),
                                     slots=self._read_file(os.path.join(data_path, self.slot_labels_file)),
                                     set_type=mode)
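
The processor expects three parallel files under data/atis/{train,dev,test}/: seq.in (one whitespace-tokenized utterance per line), label (one intent name per line) and seq.out (one space-separated slot tag per token). A hypothetical aligned line from each file, just to illustrate the format:

seq.in : i want to fly from boston to denver
label  : atis_flight
seq.out: O O O O O B-fromloc.city_name O B-toloc.city_name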

Convert the examples into BERT input features

def convert_examples_to_features(examples, max_seq_len, tokenizer,
                                 pad_token_label_id=-100,
                                 cls_token_segment_id=0,
                                 pad_token_segment_id=0,
                                 sequence_a_segment_id=0,
                                 mask_padding_with_zero=True):
    # Setting based on the current model type
    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    unk_token = tokenizer.unk_token
    pad_token_id = tokenizer.pad_token_id

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 5000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        # Tokenize word by word (for NER)
        tokens = []
        slot_labels_ids = []
        for word, slot_label in zip(example.words, example.slot_labels):
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                word_tokens = [unk_token]  # For handling the bad-encoded word
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            slot_labels_ids.extend([int(slot_label)] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        if len(tokens) > max_seq_len - special_tokens_count:
            tokens = tokens[:(max_seq_len - special_tokens_count)]
            slot_labels_ids = slot_labels_ids[:(max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]
        slot_labels_ids += [pad_token_label_id]
        token_type_ids = [sequence_a_segment_id] * len(tokens)

        # Add [CLS] token
        tokens = [cls_token] + tokens
        slot_labels_ids = [pad_token_label_id] + slot_labels_ids
        token_type_ids = [cls_token_segment_id] + token_type_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
        slot_labels_ids = slot_labels_ids + ([pad_token_label_id] * padding_length)

        assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
        assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
        assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
        assert len(slot_labels_ids) == max_seq_len, "Error with slot labels length {} vs {}".format(len(slot_labels_ids), max_seq_len)

        intent_label_id = int(example.intent_label)

        features.append(
            InputFeatures(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids,
                          intent_label_id=intent_label_id,
                          slot_labels_ids=slot_labels_ids
                          ))

    return features
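
To make the sub-token/label alignment concrete, here is a hedged mini-run (the tokenizer is loaded the same way as later in the notebook and the label ids are made up): a word that WordPiece splits into several pieces keeps its slot label only on the first piece, and the remaining pieces get pad_token_label_id so the loss ignores them.

demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
demo_example = InputExample(guid="demo-0", words=["airfare", "to", "denver"],
                            intent_label=0, slot_labels=[5, 2, 7])  # hypothetical ids
demo_feats = convert_examples_to_features([demo_example], max_seq_len=10,
                                           tokenizer=demo_tokenizer,
                                           pad_token_label_id=0)[0]
print(demo_feats.input_ids)        # [CLS] + sub-token ids + [SEP] + padding
print(demo_feats.slot_labels_ids)  # first sub-token of each word keeps its label,
                                   # extra sub-tokens and padding positions get 0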

Wrap the features in a TensorDataset (with caching)

def load_and_cache_examples(tokenizer, mode):
    processor = JointProcessor(args)

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        'cached_{}_{}_{}_{}'.format(
            mode,
            "atis",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            args.max_seq_len
        )
    )

    if os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        # Load data features from dataset file
        logger.info("Creating features from dataset file at %s", args.data_dir)
        if mode == "train":
            examples = processor.get_examples("train")
        elif mode == "dev":
            examples = processor.get_examples("dev")
        elif mode == "test":
            examples = processor.get_examples("test")
        else:
            raise Exception("For mode, Only train, dev, test is available")

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        pad_token_label_id = 0
        features = convert_examples_to_features(examples, args.max_seq_len, tokenizer,
                                                pad_token_label_id=pad_token_label_id)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_intent_label_ids = torch.tensor([f.intent_label_id for f in features], dtype=torch.long)
    all_slot_labels_ids = torch.tensor([f.slot_labels_ids for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_attention_mask,
                            all_token_type_ids, all_intent_label_ids, all_slot_labels_ids)
    return dataset

Build the datasets

tokenizer = load_tokenizer(args)

train_dataset = load_and_cache_examples(tokenizer, mode="train")
dev_dataset = load_and_cache_examples(tokenizer, mode="dev")
test_dataset = load_and_cache_examples(tokenizer, mode="test")
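
Each item of the resulting TensorDataset is a 5-tuple of tensors; a quick inspection (shapes follow max_seq_len=50):

input_ids, attention_mask, token_type_ids, intent_label_id, slot_labels_ids = train_dataset[0]
print(input_ids.shape, slot_labels_ids.shape)  # torch.Size([50]) torch.Size([50])
print(intent_label_id)                         # 0-dim tensor holding the intent label id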

Model

Intent classifier

class IntentClassifier(nn.Module):
    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)

Slot classifier

class SlotClassifier(nn.Module):
    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
        super(SlotClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
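
Both heads are just dropout followed by a single linear layer; a quick shape check on random tensors (768 is the hidden size of bert-base, and the label counts here are only illustrative):

dummy_pooled = torch.randn(4, 768)        # [batch, hidden], the pooled [CLS] vector
dummy_sequence = torch.randn(4, 50, 768)  # [batch, seq_len, hidden]
print(IntentClassifier(768, 22)(dummy_pooled).shape)    # torch.Size([4, 22])
print(SlotClassifier(768, 120)(dummy_sequence).shape)   # torch.Size([4, 50, 120])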

JointBERT model definition

class JointBERT(BertPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
            total_loss += intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += self.args.slot_loss_coef * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
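
A rough smoke test of the forward pass with randomly initialized weights and dummy inputs (a sketch only: the label counts are illustrative, and it assumes the older transformers API used throughout this notebook, where the model returns plain tuples):

demo_config = BertConfig.from_pretrained('bert-base-uncased')
demo_model = JointBERT(demo_config, args,
                       intent_label_lst=['UNK'] * 22,   # illustrative label sets
                       slot_label_lst=['PAD'] * 120)
demo_ids = torch.randint(0, demo_config.vocab_size, (2, 50))
demo_mask = torch.ones(2, 50, dtype=torch.long)
demo_outputs = demo_model(demo_ids, demo_mask,
                          torch.zeros(2, 50, dtype=torch.long),  # token_type_ids
                          torch.zeros(2, dtype=torch.long),      # intent_label_ids
                          torch.zeros(2, 50, dtype=torch.long))  # slot_labels_ids
loss, (intent_logits, slot_logits) = demo_outputs[:2]
print(loss.item(), intent_logits.shape, slot_logits.shape)  # scalar loss, (2, 22), (2, 50, 120)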


Trainer: training and evaluation

class Trainer(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        self.intent_label_lst = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
        self.slot_label_lst = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        self.config = BertConfig.from_pretrained(self.args.model_name_or_path, finetuning_task=self.args.task)
        self.model = JointBERT.from_pretrained(self.args.model_name_or_path,
                                                      config=self.config,
                                                      args=self.args,
                                                      intent_label_lst=self.intent_label_lst,
                                                      slot_label_lst=self.slot_label_lst)


        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    def train(self):
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                          'slot_labels_ids': batch[4]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)
                loss = outputs[0]

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                        self.evaluate("dev")

                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        self.save_model()

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step

    def evaluate(self, mode):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)


        eval_loss = 0.0
        nb_eval_steps = 0
        intent_preds = None
        slot_preds = None
        out_intent_label_ids = None
        out_slot_labels_ids = None

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                          'slot_labels_ids': batch[4]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]

                outputs = self.model(**inputs)
                tmp_eval_loss, (intent_logits, slot_logits) = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            # Intent prediction
            if intent_preds is None:
                intent_preds = intent_logits.detach().cpu().numpy()
                out_intent_label_ids = inputs['intent_label_ids'].detach().cpu().numpy()
            else:
                intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
                out_intent_label_ids = np.append(
                    out_intent_label_ids, inputs['intent_label_ids'].detach().cpu().numpy(), axis=0)

            # Slot prediction
            if slot_preds is None:
                if self.args.use_crf:
                    # decode() in `torchcrf` returns list with best index directly
                    slot_preds = np.array(self.model.crf.decode(slot_logits))
                else:
                    slot_preds = slot_logits.detach().cpu().numpy()

                out_slot_labels_ids = inputs["slot_labels_ids"].detach().cpu().numpy()
            else:
                if self.args.use_crf:
                    slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
                else:
                    slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)

                out_slot_labels_ids = np.append(out_slot_labels_ids, inputs["slot_labels_ids"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {
            "loss": eval_loss
        }

        # Intent result
        intent_preds = np.argmax(intent_preds, axis=1)

        # Slot result
        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
        slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]

        for i in range(out_slot_labels_ids.shape[0]):
            for j in range(out_slot_labels_ids.shape[1]):
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])

        total_result = compute_metrics(intent_preds, out_intent_label_ids, slot_preds_list, out_slot_label_list)
        results.update(total_result)

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            # logger.info("  %s = %s", key, str(results[key]))
            print("  %s = %s" %(key, str(results[key])))


        return results

    def save_model(self):
        # Save model checkpoint (Overwrite)
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.args.model_dir)

        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
        print("Saving model checkpoint to %s" % self.args.model_dir)

    def load_model(self):
        # Check whether model exists
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exist! Train first!")

        try:
            self.model = JointBERT.from_pretrained(self.args.model_dir,
                                                   args=self.args,
                                                   intent_label_lst=self.intent_label_lst,
                                                   slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            print("***** Model Loaded *****")
        except Exception:
            raise Exception("Some model files might be missing...")

Start training

trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
trainer.train()
Training output (per-iteration tqdm progress bars abbreviated; the dev evaluation and checkpoint save fire at global step 200):

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration: 100%|██████████| 140/140 [03:31<00:00,  1.51s/it]
Epoch:  50%|█████     | 1/2 [03:31<03:31, 211.30s/it]
Iteration:  42%|████▏     | 59/140 [01:30<02:02,  1.51s/it]

Evaluating: 100%|██████████| 8/8 [00:06<00:00,  1.17it/s]
/home/frank/miniconda3/envs/bio-bert-bilstm-crf/lib/python3.7/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: UNK seems not to be NE tag.
  warnings.warn('{} seems not to be NE tag.'.format(chunk))

  intent_acc = 0.85
  loss = 2.148504465818405
  sementic_frame_acc = 0.708
  slot_f1 = 0.9299883313885647
  slot_precision = 0.9278230500582072
  slot_recall = 0.9321637426900585

Saving model checkpoint to ./save_model

Iteration: 100%|██████████| 140/140 [03:39<00:00,  1.57s/it]
Epoch: 100%|██████████| 2/2 [07:11<00:00, 215.63s/it]

(280, 5.539662108251027)

Run the trained model on the test set

trainer.evaluate("test")
Evaluating: 100%|██████████| 14/14 [00:11<00:00,  1.17it/s]

  intent_acc = 0.8443449048152296
  loss = 2.3734985419682095
  sementic_frame_acc = 0.7077267637178052
  slot_f1 = 0.916652105539053
  slot_precision = 0.909500693481276
  slot_recall = 0.9239168721380768



{'loss': 2.3734985419682095,
 'intent_acc': 0.8443449048152296,
 'slot_precision': 0.909500693481276,
 'slot_recall': 0.9239168721380768,
 'slot_f1': 0.916652105539053,
 'sementic_frame_acc': 0.7077267637178052}
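
For a single new utterance, the trained model can be queried directly; a hedged sketch follows (the helper below is hypothetical and not part of the original code, the example sentence is made up, and mapping slot tags back to words would need the same first-subword bookkeeping as in evaluate()):

def predict_sentence(text, trainer, tokenizer, max_seq_len=50):
    # Hypothetical helper: encode one sentence with dummy labels, run the model,
    # and map the argmax intent and CRF-decoded slot ids back to label strings.
    words = text.split()
    example = InputExample(guid="pred-0", words=words,
                           intent_label=0, slot_labels=[0] * len(words))
    feats = convert_examples_to_features([example], max_seq_len, tokenizer,
                                         pad_token_label_id=0)[0]
    to_tensor = lambda x: torch.tensor([x], dtype=torch.long, device=trainer.device)
    trainer.model.eval()
    with torch.no_grad():
        _, (intent_logits, slot_logits) = trainer.model(
            to_tensor(feats.input_ids), to_tensor(feats.attention_mask),
            to_tensor(feats.token_type_ids), to_tensor(0),      # dummy intent label
            to_tensor(feats.slot_labels_ids))[:2]
    intent = trainer.intent_label_lst[int(intent_logits.argmax(-1))]
    slot_ids = trainer.model.crf.decode(slot_logits)[0]
    slot_tags = [trainer.slot_label_lst[i] for i in slot_ids]  # one tag per wordpiece position
    return intent, slot_tags

# print(predict_sentence("show me flights from boston to denver", trainer, tokenizer))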
