原文与代码链接: https://github.com/AndrewHYC/COPNER
# Command-line options for the COPNER training / evaluation script.
# `parser` (an argparse.ArgumentParser) is created outside this excerpt.

# --- runtime ---------------------------------------------------------------
parser.add_argument('--gpu', default='0',
                    help='the gpu number for training')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')

# --- task / episode shape --------------------------------------------------
parser.add_argument('--mode', default='inter',
                    help='training mode, must be in [inter, intra, supervised, i2b2, conll, wnut, mit-movie]')
parser.add_argument('--task', default='cross-label-space',
                    help='training task, must be in [cross-label-space, domain-transfer, in-label-space]')
parser.add_argument('--trainN', default=5, type=int,
                    help='N in train')
parser.add_argument('--N', default=5, type=int,
                    help='N way')
parser.add_argument('--K', default=1, type=int,
                    help='K shot')
parser.add_argument('--Q', default=1, type=int,
                    help='Num of query per class')
parser.add_argument('--support_num', default=0, type=int,
                    help='the id number of support set')
parser.add_argument('--zero_shot', action='store_true',
                    help='')
parser.add_argument('--only_test', action='store_true',
                    help='only test')
parser.add_argument('--load_ckpt', default=None,
                    help='load ckpt')
parser.add_argument('--ckpt_name', type=str, default='',
                    help='checkpoint name.')

# --- model -----------------------------------------------------------------
parser.add_argument('--pretrain_ckpt', default='./premodel/roberta-wwm-ext-base',
                    help='bert pre-trained checkpoint: bert-base-uncased / bert-base-cased')
# The help text used to rely on backslash line continuations inside the string
# literal; once the lines were joined these became invalid "\ " escapes.
# Adjacent string literals give the same one-line help without any escapes.
parser.add_argument('--prompt', default=1, type=int, choices=[0,1,2],
                    help='choice in [0,1,2]: '
                         '0: Continue Prompt; '
                         '1: Partition Prompt; '
                         '2: Queue Prompt')
parser.add_argument('--pseudo_token', default='[S]', type=str,
                    help='pseudo_token')
parser.add_argument('--max_length', default=64, type=int,
                    help='max length')
parser.add_argument('--ignore_index', type=int, default=-1,
                    help='label index to ignore when calculating loss and metrics')
parser.add_argument('--struct', action='store_true',
                    help='StructShot parameter to re-normalizes the transition probabilities')
parser.add_argument('--tau', default=1, type=float,
                    help='the temperature rate for contrastive learning')
parser.add_argument('--struct_tau', default=0.32, type=float,
                    help='the tau in the viterbi decode')

# --- optimisation / schedule -----------------------------------------------
parser.add_argument('--batch_size', default=16, type=int,
                    help='batch size')
parser.add_argument('--test_bz', default=1, type=int,
                    help='test or val batch size')
parser.add_argument('--train_iter', default=10000, type=int,
                    help='num of iters in training')
parser.add_argument('--val_iter', default=200, type=int,
                    help='num of iters in validation')
parser.add_argument('--test_iter', default=5000, type=int,
                    help='num of iters in testing')
parser.add_argument('--val_step', default=200, type=int,
                    help='val after training how many iters')
parser.add_argument('--adapt_step', default=5, type=int,
                    help='adapting how many iters in validing or testing')
# NOTE(review): this help text is identical to --adapt_step's; it looks like a
# copy-paste, but the flag's real meaning is not visible in this excerpt.
parser.add_argument('--adapt_auto', action='store_true',
                    help='adapting how many iters in validing or testing')
parser.add_argument('--threshold_alpha', default=0.1, type=float,
                    help='Gradient descent change threshold for early stopping')
parser.add_argument('--threshold_beta', default=0.5, type=float,
                    help='loss threshold for early stopping')
parser.add_argument('--lr', default=1e-4, type=float,
                    help='learning rate of Training')
parser.add_argument('--adapt_lr', default=None, type=float,
                    help='learning rate of Adapting')
parser.add_argument('--grad_iter', default=1, type=int,
                    help='accumulate gradient every x iterations')
parser.add_argument('--early_stopping', type=int, default=3000,
                    help='iteration numbers to stop without performance increasing')
parser.add_argument('--use_sgd_for_lm', action='store_true',
                    help='use SGD instead of AdamW for BERT.')
def main():
    """Set up a run: episode settings, seed, encoder/tokenizer and data loaders.

    Reads the module-level ``opt`` namespace. Only the part of the function
    shown in this excerpt is reproduced here; the model / framework
    construction that follows is quoted separately further down.
    """
    # Episode shape (CLI defaults: N=5, K=1, Q=1, trainN=5, max_length=64).
    N, K, Q = opt.N, opt.K, opt.Q
    trainN = opt.N if opt.trainN is None else opt.trainN
    max_length = opt.max_length

    # Adaptation falls back to the training learning rate when none is given.
    if opt.adapt_lr is None and opt.lr:
        opt.adapt_lr = opt.lr

    print("{}-way-{}-shot Few-Shot NER".format(N, K))
    shown_batch_size = opt.test_bz if opt.only_test else opt.batch_size
    for label, value in (
        ('task', opt.task),
        ('mode', opt.mode),
        ('prompt', opt.prompt),
        ('support', opt.support_num),
        ('max_length', max_length),
        ('batch_size', shown_batch_size),
    ):
        print('{}: {}'.format(label, value))

    set_seed(opt.seed)

    print('loading model and tokenizer...')
    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    config = BertConfig.from_pretrained(pretrain_ckpt)
    tokenizer = BertTokenizer.from_pretrained(pretrain_ckpt)
    opt.tokenizer = tokenizer  # made available to the model through ``opt``
    word_encoder = BERTWordEncoder.from_pretrained(pretrain_ckpt, config=config, args=opt)

    if opt.task == 'cross-label-space':
        # Few-NERD: one file per split, selected by the training mode.
        opt.train = f'data/few-nerd/{opt.mode}/train.txt'
        opt.dev = f'data/few-nerd/{opt.mode}/dev.txt'
        opt.test = f'data/few-nerd/{opt.mode}/test.txt'

        opt.train_word_map = opt.dev_word_map = opt.test_word_map = FEWNERD_WORD_MAP

        print(f'loading train data: {opt.train}')
        train_data_loader = get_loader(
            opt.train, tokenizer, word_map=opt.train_word_map,
            N=trainN, K=1, Q=Q,  # K=1 for training
            batch_size=opt.batch_size, max_length=max_length,
            ignore_index=opt.ignore_index, args=opt, train=True)

        print(f'loading eval data: {opt.dev}')
        val_data_loader = get_loader(
            opt.dev, tokenizer, word_map=opt.dev_word_map,
            N=N, K=K, Q=Q,
            batch_size=opt.test_bz, max_length=max_length,
            ignore_index=opt.ignore_index, args=opt)

        print(f'loading test data: {opt.test}')
        test_data_loader = get_loader(
            opt.test, tokenizer, word_map=opt.test_word_map,
            N=N, K=K, Q=Q,
            batch_size=opt.test_bz, max_length=max_length,
            ignore_index=opt.ignore_index, args=opt)
训练阶段的 episode 设置:N=5(trainN),K=1(训练时固定为 1),Q=1。
# Train, dev and test all use the same label-to-word mapping (FEWNERD_WORD_MAP).
opt.train_word_map = opt.dev_word_map = opt.test_word_map = FEWNERD_WORD_MAP
from collections import OrderedDict

# Few-NERD: maps every fine-grained label (plus the outside tag 'O') to a
# single word. Insertion order is kept, with 'O' first.
FEWNERD_WORD_MAP = OrderedDict([
    ('O', 'none'),

    # location
    ('location-GPE', 'nation'),
    ('location-bodiesofwater', 'water'),
    ('location-island', 'island'),
    ('location-mountain', 'mountain'),
    ('location-park', 'parks'),
    ('location-road/railway/highway/transit', 'road'),
    ('location-other', 'location'),

    # person
    ('person-actor', 'actor'),
    ('person-artist/author', 'artist'),
    ('person-athlete', 'athlete'),
    ('person-director', 'director'),
    ('person-politician', 'politician'),
    ('person-scholar', 'scholar'),
    ('person-soldier', 'soldier'),
    ('person-other', 'person'),

    # organization
    ('organization-company', 'company'),
    ('organization-education', 'education'),
    ('organization-government/governmentagency', 'government'),
    ('organization-media/newspaper', 'media'),
    ('organization-politicalparty', 'parties'),
    ('organization-religion', 'religion'),
    ('organization-showorganization', 'show'),
    ('organization-sportsleague', 'league'),
    ('organization-sportsteam', 'team'),
    ('organization-other', 'organization'),

    # building
    ('building-airport', 'airport'),
    ('building-hospital', 'hospital'),
    ('building-hotel', 'hotel'),
    ('building-library', 'library'),
    ('building-restaurant', 'restaurant'),
    ('building-sportsfacility', 'facility'),
    ('building-theater', 'theater'),
    ('building-other', 'building'),

    # art
    ('art-broadcastprogram', 'broadcast'),
    ('art-film', 'film'),
    ('art-music', 'music'),
    ('art-painting', 'painting'),
    ('art-writtenart', 'writing'),
    ('art-other', 'art'),

    # product
    ('product-airplane', 'airplane'),
    ('product-car', 'car'),
    ('product-food', 'food'),
    ('product-game', 'game'),
    ('product-ship', 'ship'),
    ('product-software', 'software'),
    ('product-train', 'train'),
    ('product-weapon', 'weapon'),
    ('product-other', 'product'),

    # event
    ('event-attack/battle/war/militaryconflict', 'war'),
    ('event-disaster', 'disaster'),
    ('event-election', 'election'),
    ('event-protest', 'protest'),
    ('event-sportsevent', 'sport'),
    ('event-other', 'event'),

    # other
    ('other-astronomything', 'astronomy'),
    ('other-award', 'award'),
    ('other-biologything', 'biology'),
    ('other-chemicalthing', 'chemistry'),
    ('other-currency', 'currency'),
    ('other-disease', 'disease'),
    ('other-educationaldegree', 'degree'),
    ('other-god', 'god'),
    ('other-language', 'language'),
    ('other-law', 'law'),
    ('other-livingthing', 'organism'),
    ('other-medical', 'medical'),
])
def get_loader(filepath, tokenizer, N, K, Q, batch_size, max_length, word_map,
               ignore_index=-1, args=None, num_workers=4, support_file_path=None, train=False):
    """Build the ``DataLoader`` for one data split.

    Args:
        filepath: path of the data file for this split.
        tokenizer: tokenizer handed to the dataset.
        N, K, Q: episode shape. For ``train=True`` the dataset is always built
            with K=1, and Q is not used.
        batch_size: batch size of the returned loader (see note below for the
            cross-label-space evaluation loader).
        max_length: max sequence length handed to the dataset.
        word_map: label -> word mapping handed to the dataset.
        ignore_index: label id the dataset should treat as "ignore".
        args: option namespace; ``args.task`` selects the evaluation dataset.
        num_workers: number of DataLoader worker processes.
        support_file_path: support-set file, used only for 'domain-transfer'.
        train: build the training loader instead of an evaluation loader.

    Returns:
        A ``data.DataLoader``.

    Raises:
        ValueError: for an evaluation loader when ``args.task`` is missing or
            is not one of the supported tasks. Previously this case either
            failed with an AttributeError (``args is None``) or silently
            returned ``None``.
    """
    if train:
        dataset = SingleDatasetwithEpisodeSample(N, 1, filepath, tokenizer, max_length,
                                                 ignore_label_id=ignore_index, args=args, word_map=word_map)
        return data.DataLoader(dataset=dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=num_workers,
                               collate_fn=single_collate_fn)

    task = getattr(args, 'task', None)

    if task == 'cross-label-space':
        dataset = PairDatasetwithEpisodeSample(N, K, Q, filepath, tokenizer, max_length,
                                               ignore_label_id=ignore_index, args=args, word_map=word_map)
        # NOTE(review): batch size is fixed to 1 here regardless of
        # ``batch_size`` -- kept as in the original; confirm it is intended.
        return data.DataLoader(dataset=dataset,
                               batch_size=1,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=num_workers,
                               collate_fn=pair_collate_fn)

    if task == 'domain-transfer':
        dataset = PairDatasetwithFixedSupport(N, filepath, support_file_path, tokenizer, max_length,
                                              ignore_label_id=ignore_index, args=args, word_map=word_map)
        return data.DataLoader(dataset=dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=num_workers,
                               collate_fn=pair_collate_fn)

    if task == 'in-label-space':
        dataset = SingleDatasetwithRamdonSample(filepath, tokenizer, max_length,
                                                ignore_label_id=ignore_index, args=args, word_map=word_map)
        return data.DataLoader(dataset=dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               pin_memory=True,
                               num_workers=num_workers,
                               collate_fn=single_collate_fn)

    raise ValueError(
        "unsupported task {!r}; expected one of "
        "['cross-label-space', 'domain-transfer', 'in-label-space']".format(task))
继承自 PairDatasetwithEpisodeSample 类,该类用于处理单数据集的示例采样。
class SingleDatasetwithEpisodeSample(PairDatasetwithEpisodeSample):
    """Episode-sampling dataset that yields a single (support) set per item.

    Each ``__getitem__`` call draws a fresh episode from the sampler, so the
    ``index`` argument is ignored and ``__len__`` is a large constant.
    """

    def __init__(self, N, K, filepath, tokenizer, max_length, word_map, ignore_label_id=-1, args=None):
        """
        N: int, number of classes per episode
        K: int, number of instances per class
        filepath: path of the data file
        tokenizer: tokenizer; ``args.pseudo_token`` is registered on it as an
            additional special token
        max_length: int, stored as ``self.max_length``
        word_map: mapping label -> word
        ignore_label_id: int, stored as ``self.ignore_label_id``
        args: namespace providing ``prompt`` and ``pseudo_token``

        Raises FileNotFoundError if ``filepath`` does not exist.
        """
        if not os.path.exists(filepath):
            # Raised explicitly instead of ``assert(0)``: asserts are stripped
            # under ``python -O`` and carry no information about the cause.
            raise FileNotFoundError("[ERROR] Data file does not exist! ({})".format(filepath))

        self.class2sampleid = {}
        self.word_map = word_map
        # Reverse mapping: word -> original label.
        self.word2class = OrderedDict((word, label) for label, word in self.word_map.items())

        self.BOS = '[CLS]'
        self.EOS = '[SEP]'
        self.max_length = max_length
        self.ignore_label_id = ignore_label_id

        self.samples, self.classes = self.__load_data_from_file__(filepath)
        self.sampler = SingleFewshotSampler(N, K, self.samples, classes=self.classes)

        self.prompt = args.prompt
        self.tokenizer = tokenizer
        self.pseudo_token = args.pseudo_token
        self.tokenizer.add_special_tokens({'additional_special_tokens': [args.pseudo_token]})

    def __getitem__(self, index):
        """Sample one episode and return its populated support set."""
        target_classes, support_idx = next(self.sampler)

        # add 'none' and make sure 'none' is labeled 0
        distinct_tags = [self.word_map['O']] + target_classes

        # The prompt sees the same tags in a random order.
        prompt_tags = distinct_tags.copy()
        random.shuffle(prompt_tags)

        self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
        self.label2tag = {idx: self.word2class[tag] for idx, tag in enumerate(distinct_tags)}

        return self.__populate__(support_idx, distinct_tags, prompt_tags, savelabeldic=True)

    def __len__(self):
        # Episodes are sampled on the fly, so the length is nominal.
        return 1000000
def __load_data_from_file__(self, filepath):
    """Read a tab-separated tagging file into samples.

    Lines that contain a tab belong to the current sentence; any other line
    (typically a blank one) closes the sentence and turns the buffered lines
    into a ``Sample``. Each sample is registered through
    ``self.__insert_sample__`` under its position in ``samples``.

    Returns:
        (samples, classes): the list of ``Sample`` objects and the list of
        distinct classes reported by ``Sample.get_tag_class()``.
    """
    samples = []      # all samples, in file order
    classes = []      # every class seen (deduplicated at the end)
    samplelines = []  # lines of the sentence currently being read

    def _close_sample(buffered):
        # Turn the buffered lines into a Sample and register its classes.
        index = len(samples)
        sample = Sample(buffered, self.word_map)
        samples.append(sample)
        sample_classes = sample.get_tag_class()
        self.__insert_sample__(index, sample_classes)
        classes.extend(sample_classes)

    with open(filepath, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if len(line.split('\t')) > 1:
                # token<TAB>tag line: part of the current sentence
                samplelines.append(line)
            else:
                # separator line: the current sentence is complete
                _close_sample(samplelines)
                samplelines = []

    # A file that does not end with a separator line still holds one pending
    # sentence; without this flush that last sentence was silently dropped.
    if samplelines:
        _close_sample(samplelines)

    classes = list(set(classes))
    return samples, classes
class SingleFewshotSampler(PairFewshotSampler):
    """Sampler that draws one N-way K-shot episode set at a time."""

    def __init__(self, N, K, samples, classes=None, random_state=0):
        '''
        N: int, how many types in each set
        K: int, how many instances for each type in data set
        samples: List[Sample], Sample class must have `get_class_count` attribute
        classes[Optional]: List[any], all unique classes in samples. If not given, the classes will be got from samples.get_class_count()
        random_state[Optional]: int, the random seed
        '''
        self.K = K
        self.N = N
        self.samples = samples
        self.__check__()  # check if samples have correct types
        self.classes = classes if classes else self.__get_all_classes__()
        # Seeds the module-level ``random`` generator.
        random.seed(random_state)

    def __next__(self):
        '''
        randomly sample one episode set
        '''
        episode_class = {'k': self.K}
        episode_idx = []

        # Redraw the target classes until at least one candidate sample exists.
        while True:
            target_classes = random.sample(self.classes, self.N)
            candidates = self.__get_candidates__(target_classes)
            if candidates:
                break

        # greedy search for episode set
        while not self.__finish__(episode_class):
            index = random.choice(candidates)
            if index in episode_idx:
                continue
            if self.__valid_sample__(self.samples[index], episode_class, target_classes):
                self.__additem__(index, episode_class)
                episode_idx.append(index)

        return target_classes, episode_idx
这段代码定义了一个名为 SingleFewshotSampler 的类,它继承自 PairFewshotSampler。SingleFewshotSampler 的目的是从一个包含多种类别(types)的数据集中采样少数样本(few-shot),以用于训练或测试。
# The model gets the word map of the split it will run on:
# test-only runs use the test map, everything else the training map.
active_word_map = opt.test_word_map if opt.only_test else opt.train_word_map
model = COPNER(word_encoder, opt, active_word_map)
class COPNER(FewShotNERModel):
    """Few-shot NER model that scores tokens against class-word representations.

    For each group of sentences, the hidden states at the positions of the
    class words are averaged into one representation per class; every valid
    token is then scored against these representations with ``__dist__``.
    """

    def __init__(self, word_encoder, args, word_map):
        """
        word_encoder: encoder called as ``word_encoder(input_ids)``
        args: namespace providing ``ignore_index``, ``tokenizer`` and ``tau``
        word_map: mapping class label -> word
        """
        FewShotNERModel.__init__(self, word_encoder, ignore_index=args.ignore_index)
        self.tokenizer = args.tokenizer
        self.tau = args.tau
        # Cross-entropy over the contrastive logits; ignored labels are skipped.
        self.loss_fct = CrossEntropyLoss(ignore_index=args.ignore_index)
        self.method = 'euclidean'

        self.class2word = word_map
        self.word2class = OrderedDict()
        for key, value in self.class2word.items():
            self.word2class[value] = key

    def __dist__(self, x, y, dim, normalize=False):
        """Similarity between ``x`` and ``y`` along ``dim``, divided by ``self.tau``.

        ``self.method`` selects the measure: 'dot', 'euclidean' (negative
        squared distance) or 'cosine'.

        Raises:
            ValueError: if ``self.method`` is none of the above (previously
                this surfaced as an UnboundLocalError).
        """
        if normalize:
            # L2-normalise both inputs along the last dimension.
            x = F.normalize(x, dim=-1)
            y = F.normalize(y, dim=-1)

        if self.method == 'dot':
            sim = (x * y).sum(dim)
        elif self.method == 'euclidean':
            sim = -(torch.pow(x - y, 2)).sum(dim)
        elif self.method == 'cosine':
            sim = F.cosine_similarity(x, y, dim=dim)
        else:
            raise ValueError('unknown distance method: {!r}'.format(self.method))

        return sim / self.tau

    def get_contrastive_logits(self, hidden_states, inputs, valid_mask, target_classes):
        """Score every valid token against every target class.

        Returns a tensor reshaped to ``(-1, len(target_classes))``.
        """
        # Look the vocabulary up once instead of once per class.
        vocab = self.tokenizer.get_vocab()
        class_indexs = [vocab[tclass] for tclass in target_classes]

        # One representation per class: mean hidden state over all positions
        # whose input id equals the class word's id.
        hidden_size = hidden_states.size(-1)
        class_rep = []
        for iclass in class_indexs:
            class_rep.append(torch.mean(hidden_states[inputs.eq(iclass), :].view(-1, hidden_size), 0))
        class_rep = torch.stack(class_rep).unsqueeze(0)

        # NOTE(review): the mask is compared against the tokenizer's pad id;
        # this presumably relies on the mask's "invalid" value being equal to
        # that id -- confirm against the code that builds valid_mask.
        token_rep = hidden_states[valid_mask != self.tokenizer.pad_token_id, :].view(-1, hidden_size).unsqueeze(1)

        logits = self.__dist__(class_rep, token_rep, -1)
        return logits.view(-1, len(target_classes))

    def forward(self,
                input_ids,
                labels,
                valid_masks,
                target_classes,
                sentence_num,
                ):
        """Run the encoder and score each sentence group.

        ``sentence_num[i]`` consecutive rows of the batch belong to group
        ``i`` and are scored against ``target_classes[i]``.

        Returns:
            (logits, preds, loss): softmax scores of all groups concatenated
            along dim 0, their argmax, and the loss averaged over the groups
            (``None`` outside training mode).
        """
        # The message used to be ``print(...)``, which evaluates to None and
        # therefore left the AssertionError without any text.
        assert input_ids.size(0) == labels.size(0) == valid_masks.size(0), \
            '[ERROR] inputs and labels must have same batch size.'
        assert len(sentence_num) == len(target_classes)

        hidden_states = self.word_encoder(input_ids)  # logits, (encoder_hs, decoder_hs)

        loss = None
        logits = []
        current_num = 0
        for i, num in enumerate(sentence_num):
            # Rows [current_num, current_num + num) form the i-th group.
            rows = slice(current_num, current_num + num)
            current_hs = hidden_states[rows]
            current_input_ids = input_ids[rows]
            current_labels = labels[rows]
            current_valid_masks = valid_masks[rows]
            current_target_classes = target_classes[i]
            current_num += num

            contrastive_logits = self.get_contrastive_logits(current_hs,
                                                             current_input_ids,
                                                             current_valid_masks,
                                                             current_target_classes)
            current_logits = F.softmax(contrastive_logits, -1)

            if self.training:
                contrastive_loss = self.loss_fct(
                    contrastive_logits,
                    current_labels[current_valid_masks != self.tokenizer.pad_token_id].view(-1))
                loss = contrastive_loss if loss is None else loss + contrastive_loss

            current_logits = current_logits.view(-1, current_logits.size(-1))
            logits.append(current_logits)

        logits = torch.cat(logits, 0)
        _, preds = torch.max(logits, 1)

        # Average over the groups. Test for None explicitly rather than
        # relying on the truth value of a tensor.
        if loss is not None:
            loss = loss / len(sentence_num)

        return logits, preds, loss
framework = FewShotNERFramework(opt, train_data_loader, val_data_loader, test_data_loader,
train_fname=opt.train if opt.struct else None,
viterbi=True if opt.struct else False)
class FewShotNERFramework:

    def __init__(self, args, train_data_loader, val_data_loader, test_data_loader, viterbi=False, train_fname=None):
        '''
        args: option namespace; ``N`` and ``struct_tau`` are read when viterbi is on.
        train_data_loader: DataLoader for training.
        val_data_loader: DataLoader for validating.
        test_data_loader: DataLoader for testing.
        viterbi: Whether to use Viterbi decoding.
        train_fname: Path of the data file to get abstract transitions.
        '''
        self.args = args
        self.train_data_loader, self.val_data_loader, self.test_data_loader = (
            train_data_loader, val_data_loader, test_data_loader)
        self.viterbi = viterbi

        if not viterbi:
            # No decoder is created (and no attribute set) without Viterbi.
            return

        # The decoder is built with N + 2 tags and the transition statistics
        # computed from the training file.
        abstract_transitions = get_abstract_transitions(train_fname, args)
        self.viterbi_decoder = ViterbiDecoder(self.args.N+2, abstract_transitions, tau=args.struct_tau)
def get_abstract_transitions(train_fname, args):
    """
    Compute abstract transitions on the training dataset for StructShot

    Counts, over all tag sequences of the training file, how often a sequence
    starts with 'O' or with another tag, and how often each kind of adjacent
    tag pair occurs (O->O, O->I, I->O, I->same I, I->different I).

    Returns:
        A list of 7 relative frequencies, in this order:
        [start->O, start->I, O->O, O->I, I->O, I->I (same tag), I->I (different tag)].

    Raises:
        ZeroDivisionError: if one of the three groups has no observations.
    """
    samples = SingleDatasetwithRamdonSample(train_fname, None, None,
                                            word_map=args.train_word_map, args=args).samples
    tag_lists = [sample.tags for sample in samples]
    # NOTE(review): tags are compared against the literal 'O'; this assumes
    # ``sample.tags`` holds raw labels rather than word_map words -- confirm.

    s_o, s_i = 0., 0.
    o_o, o_i = 0., 0.
    i_o, i_i, x_y = 0., 0., 0.
    for tags in tag_lists:
        if len(tags) == 0:
            # An empty sentence has neither a start tag nor transitions;
            # indexing tags[0] on it would raise IndexError.
            continue

        if tags[0] == 'O':
            s_o += 1
        else:
            s_i += 1

        for i in range(len(tags) - 1):
            p, n = tags[i], tags[i + 1]
            if p == 'O':
                if n == 'O':
                    o_o += 1
                else:
                    o_i += 1
            elif n == 'O':
                i_o += 1
            elif p != n:
                x_y += 1
            else:
                i_i += 1

    start_total = s_o + s_i
    from_o_total = o_o + o_i
    from_i_total = i_o + i_i + x_y

    return [
        s_o / start_total,
        s_i / start_total,
        o_o / from_o_total,
        o_i / from_o_total,
        i_o / from_i_total,
        i_i / from_i_total,
        x_y / from_i_total,
    ]
`__get_emmissions__` 将模型输出的 logits(即未归一化的得分)按 `tags_list` 中每个句子的标签数量依次切分,得到每个句子各自的 emissions(发射得分)矩阵,形状为 [句子 token 数, 类别数]。
def __get_emmissions__(self, logits, tags_list):
    """Split the flat logits into one emission block per sentence.

    Args:
        logits: tensor whose first dimension runs over all query tokens.
        tags_list: one tag sequence per sentence; only the lengths are used.

    Returns:
        A list with one slice of ``logits`` per entry of ``tags_list``, in order.

    Raises:
        AssertionError: if the sentence lengths do not add up to
            ``logits.size()[0]``.
    """
    # split [num_of_query_tokens, num_class] into [[num_of_token_in_sent, num_class], ...]
    emmissions = []
    current_idx = 0
    for tags in tags_list:
        # Take this sentence's rows; without this append the function always
        # returned an empty list.
        emmissions.append(logits[current_idx:current_idx + len(tags)])
        current_idx += len(tags)
    assert current_idx == logits.size()[0]
    return emmissions
def viterbi_decode(self, logits, query_tags):
    """Decode the label sequence of every query sentence with Viterbi.

    Args:
        logits: scores for all query tokens, concatenated over sentences.
        query_tags: one tag sequence per query sentence (used for the
            sentence boundaries).

    Returns:
        A 1-D tensor with one predicted label per query token, placed on the
        same device as ``logits`` (the original always called ``.cuda()``,
        which fails on hosts without CUDA).
    """
    emissions_list = self.__get_emmissions__(logits, query_tags)
    pred = []
    for i in range(len(query_tags)):
        sent_scores = emissions_list[i].cpu()
        sent_len, n_label = sent_scores.shape
        sent_probs = F.softmax(sent_scores, dim=1)

        # Prepend an extra column with a tiny probability in front of the
        # real labels; the decoder therefore sees n_label + 1 columns.
        start_probs = torch.zeros(sent_len) + 1e-6
        sent_probs = torch.cat((start_probs.view(sent_len, 1), sent_probs), 1)

        feats = self.viterbi_decoder.forward(torch.log(sent_probs).view(1, sent_len, n_label + 1))
        vit_labels = self.viterbi_decoder.viterbi(feats)
        vit_labels = vit_labels.view(sent_len).detach().cpu().numpy().tolist()

        # Shift back by one to undo the extra leading column.
        pred.extend(label - 1 for label in vit_labels)

    return torch.tensor(pred, device=logits.device)
使用维特比解码器来对序列标签进行解码。首先,它将 logits 分割成与查询标签对应的 emissions。然后,对于每个句子,计算发射概率,并且结合转移概率使用维特比算法找出最有可能的标签序列。最后,将解码得到的标签序列转换为张量并返回。
framework.train(model, prefix,
warmup_step=int(opt.train_iter * 0.05),
def train(self,
          model,
          model_name,
          learning_rate=1e-4,
          train_iter=30000,
          val_iter=1000,
          val_step=2000,
          load_ckpt=None,
          save_ckpt=None,
          warmup_step=300,
          grad_iter=1,
          use_sgd_for_lm=False):
    '''
    Train ``model`` on ``self.train_data_loader`` with periodic validation.

    model: a FewShotREModel instance
    model_name: Name of the model
    learning_rate: Initial learning rate
    train_iter: Num of iterations of training
    val_iter: Num of iterations of validating
    val_step: Validate every val_step steps
    load_ckpt: Path of the checkpoint to load
    save_ckpt: Path of the checkpoint to save (saving is currently commented out)
    warmup_step: Num of warmup steps
    grad_iter: Accumulate gradients for grad_iter steps
    use_sgd_for_lm: Whether to use SGD for the language model

    Raises KeyboardInterrupt to stop early when the validation F1 has not
    improved for ``self.args.early_stopping`` iterations.
    '''
    # Init optimizer
    print('Use bert optim!')
    named_params = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    parameters_to_optimize = [
        {'params': [p for n, p in named_params
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in named_params
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if use_sgd_for_lm:
        optimizer = torch.optim.SGD(parameters_to_optimize, lr=learning_rate)
    else:
        optimizer = AdamW(parameters_to_optimize, lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step,
                                                num_training_steps=train_iter)

    # load model
    if load_ckpt:
        state_dict = self.__load_model__(load_ckpt)['state_dict']
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                # Parameters unknown to this model are skipped.
                print('ignore {}'.format(name))
                continue
            print('load {} from {}'.format(name, load_ckpt))
            own_state[name].copy_(param)

    model.train()

    # Training
    iter_loss = 0.0
    best_precision = 0.0
    best_recall = 0.0
    best_f1 = 0.0
    iter_sample = 0
    pred_cnt = 1e-9   # 1e-9 keeps the running precision / recall finite
    label_cnt = 1e-9
    correct_cnt = 0
    last_step = 0     # iteration of the last validation improvement

    # Batch entries that are not moved to the GPU by the generic loop.
    keep_on_host = ('target_classes', 'sentence_num', 'labels', 'label2tag')
    use_cuda = torch.cuda.is_available()

    print("Start training...")
    with tqdm(self.train_data_loader, total=train_iter, disable=False, desc="Training") as tbar:

        for it, batch in enumerate(tbar):

            # Gold labels are needed for the metrics on every device, so they
            # are built unconditionally and only moved when CUDA exists.
            label = torch.cat(batch['labels'], 0)
            if use_cuda:
                for k in batch:
                    if k not in keep_on_host:
                        batch[k] = batch[k].cuda()
                label = label.cuda()

            logits, pred, loss = model(batch['inputs'],
                                       batch['batch_labels'],
                                       batch['valid_masks'],
                                       batch['target_classes'],
                                       batch['sentence_num'])
            loss.backward()

            # Step once every ``grad_iter`` batches. The previous test
            # ``it % grad_iter == 0`` fired on the very first batch, so the
            # first update used a single batch instead of ``grad_iter``.
            # For grad_iter == 1 the behaviour is unchanged.
            if (it + 1) % grad_iter == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Calculate metrics
            tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(pred, label)

            iter_loss += self.item(loss.data)
            pred_cnt += tmp_pred_cnt
            label_cnt += tmp_label_cnt
            correct_cnt += correct
            iter_sample += 1

            precision = correct_cnt / pred_cnt
            recall = correct_cnt / label_cnt
            f1 = 2 * precision * recall / (precision + recall + 1e-9)  # 1e-9 for error'float division by zero'

            tbar.set_postfix_str("loss: {:2.6f} | F1: {:3.4f}, P: {:3.4f}, R: {:3.4f}, Correct:{}"\
                                    .format(self.item(loss.data), f1, precision, recall, correct_cnt))

            if (it + 1) % val_step == 0:
                precision, recall, f1, _, _, _, _ = self.eval(model, val_iter, word_map=self.args.dev_word_map)

                # eval() is followed by switching the model back to training mode.
                model.train()
                if f1 > best_f1:
                    # print(f'Best checkpoint! Saving to: {save_ckpt}\n')
                    # torch.save({'state_dict': model.state_dict()}, save_ckpt)
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
                    last_step = it
                else:
                    if it - last_step >= self.args.early_stopping:
                        print('\nEarly Stop by {} steps, best f1: {:.4f}%'.format(self.args.early_stopping, best_f1))
                        raise KeyboardInterrupt

            # The running training metrics are reset every 100 iterations.
            if (it + 1) % 100 == 0:
                iter_loss = 0.
                iter_sample = 0.
                pred_cnt = 1e-9
                label_cnt = 1e-9
                correct_cnt = 0

            if (it + 1) >= train_iter:
                break

    print("\n####################\n")
    print("Finish training {}, best f1: {:.4f}%".format(model_name, best_f1))
