import sys
import os
import copy
import json
import logging
import argparse

import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm, trange
from torchcrf import CRF
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from seqeval.metrics import precision_score, recall_score, f1_score
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup
from transformers import (
    BertModel,
    BertTokenizer,
    BertPreTrainedModel,
)

# Running inside a notebook: clear sys.argv so argparse does not see Jupyter's own arguments
sys.argv = ['']
del sys

logger = logging.getLogger(__name__)
Define a function to load the tokenizer (BERT's own WordPiece tokenizer: it splits words into subword tokens and maps them to vocabulary ids).
def load_tokenizer(args):
    # Use the same pretrained checkpoint as the model (bert-base-uncased below)
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    return tokenizer
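As a quick check of what the tokenizer produces, a minimal sketch (assuming the bert-base-uncased checkpoint can be downloaded; the example sentence is made up):

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# WordPiece token strings; rare words may be split into '##'-prefixed subword pieces
tokens = tokenizer.tokenize("cheapest flights from denver to pittsburgh")
print(tokens)
# The corresponding vocabulary ids that will be fed to BERT
print(tokenizer.convert_tokens_to_ids(tokens))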
Define the evaluation metric helpers.
def get_intent_acc(preds, labels):
    acc = (preds == labels).mean()
    return {
        "intent_acc": acc
    }


def get_slot_metrics(preds, labels):
    assert len(preds) == len(labels)
    return {
        "slot_precision": precision_score(labels, preds),
        "slot_recall": recall_score(labels, preds),
        "slot_f1": f1_score(labels, preds)
    }


def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
    """For the cases where the intent and all the slots are correct (in one sentence)"""
    # Get the intent comparison result
    intent_result = (intent_preds == intent_labels)

    # Get the slot comparison result
    slot_result = []
    for preds, labels in zip(slot_preds, slot_labels):
        assert len(preds) == len(labels)
        one_sent_result = True
        for p, l in zip(preds, labels):
            if p != l:
                one_sent_result = False
                break
        slot_result.append(one_sent_result)
    slot_result = np.array(slot_result)

    sementic_acc = np.multiply(intent_result, slot_result).mean()
    return {
        "sementic_frame_acc": sementic_acc
    }


def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
    assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
    results = {}
    intent_result = get_intent_acc(intent_preds, intent_labels)
    slot_result = get_slot_metrics(slot_preds, slot_labels)
    sementic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)

    results.update(intent_result)
    results.update(slot_result)
    results.update(sementic_result)

    return results
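A quick sanity check of these helpers on made-up predictions (the tag names are hypothetical ATIS-style labels): intents are numpy arrays of label ids, slots are lists of tag-name lists, which is exactly what the evaluation loop below passes in.

intent_preds  = np.array([0, 1, 1])
intent_labels = np.array([0, 1, 0])

slot_preds  = [["O", "B-fromloc.city_name"], ["O", "O"], ["B-toloc.city_name"]]
slot_labels = [["O", "B-fromloc.city_name"], ["O", "O"], ["B-toloc.city_name"]]

# All slots match, so slot_f1 = 1.0; the intent is wrong for the last sentence,
# so intent_acc and sementic_frame_acc are 2/3
print(compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels))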
The argparse module makes it easy to write user-friendly command-line interfaces. The program defines the arguments it requires, and argparse figures out how to parse them out of sys.argv. The argparse module also automatically generates help and usage messages, and reports an error when the user passes invalid arguments to the program.
https://docs.python.org/zh-cn/3/library/argparse.html
parser = argparse.ArgumentParser()
parser.add_argument("--task", default='atis', required=False, type=str, help="The name of the task to train")
parser.add_argument("--model_dir", default="./save_model", required=False, type=str, help="Path to save, load model") parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir") parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file") parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file") parser.add_argument("--model_type", default="bert", type=str, help=" Bert is the Model") parser.add_argument('--seed', type=int, default=1234, help="random seed for initialization") parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.") parser.add_argument("--eval_batch_size", default=64, type=int, help="Batch size for evaluation.") parser.add_argument("--max_seq_len", default=50, type=int, help="The maximum total input sequence length after tokenization.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=2.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers") parser.add_argument('--logging_steps', type=int, default=200, help="Log every X updates steps.") parser.add_argument('--save_steps', type=int, default=200, help="Save checkpoint every X updates steps.") parser.add_argument("--do_train", default=True, action="store_true", help="Whether to run training.") parser.add_argument("--do_eval", default=True, action="store_true", help="Whether to run eval on the test set.") parser.add_argument("--no_cuda", default=True, action="store_true", help="Avoid using CUDA when available") parser.add_argument("--ignore_index", default=0, type=int, help='Specifies a target value that is ignored and does not contribute to the input gradient') parser.add_argument('--slot_loss_coef', type=float, default=1.0, help='Coefficient for the slot loss.') # CRF option parser.add_argument("--use_crf", default=True, action="store_true", help="Whether to use CRF") parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
args = parser.parse_args()
args.model_name_or_path = 'bert-base-uncased'
args
Namespace(adam_epsilon=1e-08, data_dir='./data', do_eval=True, do_train=True, dropout_rate=0.1, eval_batch_size=64, gradient_accumulation_steps=1, ignore_index=0, intent_label_file='intent_label.txt', learning_rate=5e-05, logging_steps=200, max_grad_norm=1.0, max_seq_len=50, max_steps=-1, model_dir='./save_model', model_name_or_path='bert-base-uncased', model_type='bert', no_cuda=True, num_train_epochs=2.0, save_steps=200, seed=1234, slot_label_file='slot_label.txt', slot_loss_coef=1.0, slot_pad_label='PAD', task='atis', train_batch_size=32, use_crf=True, warmup_steps=0, weight_decay=0.0)
Official documentation: https://huggingface.co/transformers/main_classes/processors.html?highlight=inputexample#transformers.data.processors.utils.InputExample
class InputExample(object):
    """
    A single training/test example for simple sequence classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        intent_label: (Optional) string. The intent label of the example.
        slot_labels: (Optional) list. The slot labels of the example.
    """

    def __init__(self, guid, words, intent_label=None, slot_labels=None):
        self.guid = guid
        self.words = words
        self.intent_label = intent_label
        self.slot_labels = slot_labels

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, attention_mask, token_type_ids, intent_label_id, slot_labels_ids):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.intent_label_id = intent_label_id
        self.slot_labels_ids = slot_labels_ids

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class JointProcessor(object):
    """Processor for the JointBERT data set"""

    def __init__(self, args):
        self.args = args
        self.intent_labels = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
        self.slot_labels = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]

        self.input_text_file = 'seq.in'
        self.intent_label_file = 'label'
        self.slot_labels_file = 'seq.out'

    @classmethod
    def _read_file(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8") as f:
            lines = []
            for line in f:
                lines.append(line.strip())
            return lines

    def _create_examples(self, texts, intents, slots, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for i, (text, intent, slot) in enumerate(zip(texts, intents, slots)):
            guid = "%s-%s" % (set_type, i)
            # 1. input_text
            words = text.split()  # Some are spaced twice
            # 2. intent
            intent_label = self.intent_labels.index(intent) if intent in self.intent_labels else self.intent_labels.index("UNK")
            # 3. slot
            slot_labels = []
            for s in slot.split():
                slot_labels.append(self.slot_labels.index(s) if s in self.slot_labels else self.slot_labels.index("UNK"))

            assert len(words) == len(slot_labels)
            examples.append(InputExample(guid=guid, words=words, intent_label=intent_label, slot_labels=slot_labels))
        return examples

    def get_examples(self, mode):
        """
        Args:
            mode: train, dev, test
        """
        data_path = os.path.join(self.args.data_dir, self.args.task, mode)
        logger.info("LOOKING AT {}".format(data_path))
        return self._create_examples(texts=self._read_file(os.path.join(data_path, self.input_text_file)),
                                     intents=self._read_file(os.path.join(data_path, self.intent_label_file)),
                                     slots=self._read_file(os.path.join(data_path, self.slot_labels_file)),
                                     set_type=mode)
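To make the expected data layout concrete, a hedged sketch of one parallel line of seq.in, label and seq.out under data/atis/train/ (the exact utterances and tag names depend on your copy of the ATIS data):

# data/atis/train/seq.in  : one whitespace-tokenized utterance per line, e.g.
#   i want to fly from baltimore to dallas round trip
# data/atis/train/label   : one intent name per line (must appear in intent_label.txt, else mapped to UNK), e.g.
#   atis_flight
# data/atis/train/seq.out : one BIO slot tag per word, same length as the utterance, e.g.
#   O O O O O B-fromloc.city_name O B-toloc.city_name B-round_trip I-round_trip

processor = JointProcessor(args)
train_examples = processor.get_examples("train")
print(train_examples[0])  # InputExample serialized as JSON: guid, words, intent_label, slot_labels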
Convert the examples into the input format BERT expects.
def convert_examples_to_features(examples, max_seq_len, tokenizer,
                                 pad_token_label_id=-100,
                                 cls_token_segment_id=0,
                                 pad_token_segment_id=0,
                                 sequence_a_segment_id=0,
                                 mask_padding_with_zero=True):
    # Setting based on the current model type
    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    unk_token = tokenizer.unk_token
    pad_token_id = tokenizer.pad_token_id

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 5000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        # Tokenize word by word (for NER)
        tokens = []
        slot_labels_ids = []
        for word, slot_label in zip(example.words, example.slot_labels):
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                word_tokens = [unk_token]  # For handling the bad-encoded word
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            slot_labels_ids.extend([int(slot_label)] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        if len(tokens) > max_seq_len - special_tokens_count:
            tokens = tokens[:(max_seq_len - special_tokens_count)]
            slot_labels_ids = slot_labels_ids[:(max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]
        slot_labels_ids += [pad_token_label_id]
        token_type_ids = [sequence_a_segment_id] * len(tokens)

        # Add [CLS] token
        tokens = [cls_token] + tokens
        slot_labels_ids = [pad_token_label_id] + slot_labels_ids
        token_type_ids = [cls_token_segment_id] + token_type_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
        slot_labels_ids = slot_labels_ids + ([pad_token_label_id] * padding_length)

        assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
        assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
        assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
        assert len(slot_labels_ids) == max_seq_len, "Error with slot labels length {} vs {}".format(len(slot_labels_ids), max_seq_len)

        intent_label_id = int(example.intent_label)

        features.append(
            InputFeatures(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids,
                          intent_label_id=intent_label_id,
                          slot_labels_ids=slot_labels_ids
                          ))

    return features
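To see what a single converted feature looks like, a minimal sketch with a hypothetical two-word example (the slot label ids 0 and 1 are made up; all four lists end up padded to max_seq_len):

demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

example = InputExample(guid="demo-0",
                       words=["to", "pittsburgh"],
                       intent_label=0,
                       slot_labels=[0, 1])  # hypothetical slot label ids

f = convert_examples_to_features([example], max_seq_len=10, tokenizer=demo_tokenizer,
                                 pad_token_label_id=0)[0]
print(f.input_ids)        # ids for [CLS] to pittsburgh [SEP] plus pad ids, length 10
print(f.attention_mask)   # 1 for real tokens, 0 for padding
print(f.slot_labels_ids)  # pad label id on [CLS]/[SEP]/padding and on non-first word pieces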
Wrap the features into a TensorDataset.
def load_and_cache_examples(tokenizer, mode):
    processor = JointProcessor(args)

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        'cached_{}_{}_{}_{}'.format(
            mode,
            "atis",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            args.max_seq_len
        )
    )

    if os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        # Load data features from dataset file
        logger.info("Creating features from dataset file at %s", args.data_dir)
        if mode == "train":
            examples = processor.get_examples("train")
        elif mode == "dev":
            examples = processor.get_examples("dev")
        elif mode == "test":
            examples = processor.get_examples("test")
        else:
            raise Exception("For mode, only train, dev, test is available")

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        pad_token_label_id = 0
        features = convert_examples_to_features(examples, args.max_seq_len, tokenizer,
                                                pad_token_label_id=pad_token_label_id)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_intent_label_ids = torch.tensor([f.intent_label_id for f in features], dtype=torch.long)
    all_slot_labels_ids = torch.tensor([f.slot_labels_ids for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                            all_intent_label_ids, all_slot_labels_ids)
    return dataset
Build the train/dev/test datasets.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = load_and_cache_examples(tokenizer, mode="train")
dev_dataset = load_and_cache_examples(tokenizer, mode="dev")
test_dataset = load_and_cache_examples(tokenizer, mode="test")
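As a quick check that the datasets were built as expected, a small sketch that inspects the first batch (the printed shapes assume the defaults above: batch size 32, max_seq_len 50):

print(len(train_dataset))  # number of training sentences

loader = DataLoader(train_dataset, batch_size=args.train_batch_size)
input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids = next(iter(loader))
print(input_ids.shape)         # torch.Size([32, 50])
print(intent_label_ids.shape)  # torch.Size([32])
print(slot_labels_ids.shape)   # torch.Size([32, 50])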
Intent classification head.
class IntentClassifier(nn.Module):
    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(IntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
Slot filling head.
class SlotClassifier(nn.Module):
    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
        super(SlotClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)
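Both heads are plain dropout + linear layers; they only differ in what they are applied to. A small sketch with dummy tensors (hidden size 768 as in bert-base; the label counts here are hypothetical):

batch_size, seq_len, hidden = 4, 50, 768
pooled_output = torch.randn(batch_size, hidden)              # one [CLS] vector per sentence
sequence_output = torch.randn(batch_size, seq_len, hidden)   # one vector per token

intent_head = IntentClassifier(hidden, num_intent_labels=22, dropout_rate=0.1)
slot_head = SlotClassifier(hidden, num_slot_labels=120, dropout_rate=0.1)

print(intent_head(pooled_output).shape)  # torch.Size([4, 22])      one intent distribution per sentence
print(slot_head(sequence_output).shape)  # torch.Size([4, 50, 120]) one slot distribution per token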
JointBERT model definition.
class JointBERT(BertPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
            total_loss += intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += self.args.slot_loss_coef * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)  # Logits is a tuple of intent and slot logits
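With use_crf enabled, the slot loss comes from torchcrf.CRF: calling the CRF returns the log-likelihood of the gold tag sequence (hence the sign flip above), while decode() returns the best tag sequence under the learned transition scores. A minimal standalone sketch with random emissions (sizes are hypothetical):

num_tags, batch_size, seq_len = 5, 2, 7
crf = CRF(num_tags=num_tags, batch_first=True)

emissions = torch.randn(batch_size, seq_len, num_tags)      # slot logits from the classifier
tags = torch.randint(0, num_tags, (batch_size, seq_len))    # gold slot label ids
mask = torch.ones(batch_size, seq_len, dtype=torch.uint8)   # 1 for real tokens

log_likelihood = crf(emissions, tags, mask=mask, reduction='mean')
slot_loss = -log_likelihood                                  # the loss term used by JointBERT
best_paths = crf.decode(emissions, mask=mask)                # list of best tag-id sequences
print(slot_loss, best_paths)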
Trainer: training and evaluation loop.
class Trainer(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        self.intent_label_lst = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
        self.slot_label_lst = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]
        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        self.config = BertConfig.from_pretrained(self.args.model_name_or_path, finetuning_task=self.args.task)
        self.model = JointBERT.from_pretrained(self.args.model_name_or_path,
                                               config=self.config,
                                               args=self.args,
                                               intent_label_lst=self.intent_label_lst,
                                               slot_label_lst=self.slot_label_lst)

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    def train(self):
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.args.warmup_steps,
                                                    num_training_steps=t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                          'slot_labels_ids': batch[4]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)
                loss = outputs[0]

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                        self.evaluate("dev")

                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        self.save_model()

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step

    def evaluate(self, mode):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        eval_loss = 0.0
        nb_eval_steps = 0
        intent_preds = None
        slot_preds = None
        out_intent_label_ids = None
        out_slot_labels_ids = None

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                          'slot_labels_ids': batch[4]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)
                tmp_eval_loss, (intent_logits, slot_logits) = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            # Intent prediction
            if intent_preds is None:
                intent_preds = intent_logits.detach().cpu().numpy()
                out_intent_label_ids = inputs['intent_label_ids'].detach().cpu().numpy()
            else:
                intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
                out_intent_label_ids = np.append(
                    out_intent_label_ids, inputs['intent_label_ids'].detach().cpu().numpy(), axis=0)

            # Slot prediction
            if slot_preds is None:
                if self.args.use_crf:
                    # decode() in `torchcrf` returns list with best index directly
                    slot_preds = np.array(self.model.crf.decode(slot_logits))
                else:
                    slot_preds = slot_logits.detach().cpu().numpy()
                out_slot_labels_ids = inputs["slot_labels_ids"].detach().cpu().numpy()
            else:
                if self.args.use_crf:
                    slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
                else:
                    slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
                out_slot_labels_ids = np.append(out_slot_labels_ids,
                                                inputs["slot_labels_ids"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {
            "loss": eval_loss
        }

        # Intent result
        intent_preds = np.argmax(intent_preds, axis=1)

        # Slot result
        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
        slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]

        for i in range(out_slot_labels_ids.shape[0]):
            for j in range(out_slot_labels_ids.shape[1]):
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])

        total_result = compute_metrics(intent_preds, out_intent_label_ids, slot_preds_list, out_slot_label_list)
        results.update(total_result)

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            # logger.info("  %s = %s", key, str(results[key]))
            print("  %s = %s" % (key, str(results[key])))

        return results

    def save_model(self):
        # Save model checkpoint (Overwrite)
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.args.model_dir)

        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
        print("Saving model checkpoint to %s" % self.args.model_dir)

    def load_model(self):
        # Check whether model exists
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exist! Train first!")

        try:
            self.model = JointBERT.from_pretrained(self.args.model_dir,
                                                   args=self.args,
                                                   intent_label_lst=self.intent_label_lst,
                                                   slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            print("***** Model Loaded *****")
        except Exception:
            raise Exception("Some model files might be missing...")
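The learning-rate schedule used above is a plain linear decay, since warmup_steps defaults to 0. A tiny sketch of how the rate evolves over a run, using a dummy parameter (the step count here is hypothetical):

dummy = torch.nn.Linear(2, 2)
opt = AdamW(dummy.parameters(), lr=5e-5)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=0, num_training_steps=10)

for step in range(10):
    opt.step()
    sched.step()
    print(step, sched.get_last_lr())  # decays linearly from 5e-5 towards 0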
Start training.
trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
trainer.train()
Epoch 1/2: Iteration 140/140, ~1.5 s/it (about 3.5 min per epoch)
Epoch 2/2: dev evaluation and checkpointing at global step 200 (logging_steps / save_steps = 200):
/home/frank/miniconda3/envs/bio-bert-bilstm-crf/lib/python3.7/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: UNK seems not to be NE tag.
  intent_acc = 0.85
  loss = 2.148504465818405
  sementic_frame_acc = 0.708
  slot_f1 = 0.9299883313885647
  slot_precision = 0.9278230500582072
  slot_recall = 0.9321637426900585
Saving model checkpoint to ./save_model
Epoch: 100%|██████████| 2/2 [07:11<00:00, 215.63s/it]
(280, 5.539662108251027)
Evaluate the trained model on the test set.
trainer.evaluate("test")
Evaluating: 100%|██████████| 14/14 [00:11<00:00, 1.17it/s]
  intent_acc = 0.8443449048152296
  loss = 2.3734985419682095
  sementic_frame_acc = 0.7077267637178052
  slot_f1 = 0.916652105539053
  slot_precision = 0.909500693481276
  slot_recall = 0.9239168721380768
{'loss': 2.3734985419682095, 'intent_acc': 0.8443449048152296, 'slot_precision': 0.909500693481276, 'slot_recall': 0.9239168721380768, 'slot_f1': 0.916652105539053, 'sementic_frame_acc': 0.7077267637178052}
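To run the trained model on a new utterance, a hedged sketch that reuses the classes defined above (CPU only, matching no_cuda=True in this run; the example sentence is made up, and the decoded path also covers the [CLS]/[SEP] positions and any extra word pieces, so mapping tags back to words would need the same first-subtoken bookkeeping as in training):

sentence = "show me flights from boston to denver"
words = sentence.split()

# Build one unlabeled example and convert it exactly like the training data (labels are dummies)
example = InputExample(guid="predict-0", words=words,
                       intent_label=0, slot_labels=[0] * len(words))
feats = convert_examples_to_features([example], args.max_seq_len, tokenizer, pad_token_label_id=0)[0]

input_ids = torch.tensor([feats.input_ids], dtype=torch.long)
attention_mask = torch.tensor([feats.attention_mask], dtype=torch.long)
token_type_ids = torch.tensor([feats.token_type_ids], dtype=torch.long)

trainer.model.eval()
with torch.no_grad():
    _, (intent_logits, slot_logits) = trainer.model(input_ids, attention_mask, token_type_ids,
                                                     intent_label_ids=torch.tensor([0]),
                                                     slot_labels_ids=torch.tensor([feats.slot_labels_ids]))[:2]

intent = trainer.intent_label_lst[intent_logits.argmax(dim=-1).item()]
slot_ids = trainer.model.crf.decode(slot_logits, mask=attention_mask.byte())[0]
print(intent)
print([trainer.slot_label_lst[i] for i in slot_ids])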