
A detailed walkthrough of implementing a masked language model

The complete training script, built on PyTorch and Hugging Face transformers (BertForMaskedLM + BertTokenizerFast), is given below.
import os
import json
import copy
from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast
class Config:
    def __init__(self):
        pass

    def mlm_config(
        self,
        mlm_probability=0.15,
        special_tokens_mask=None,
        prob_replace_mask=0.8,
        prob_replace_rand=0.1,
        prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: probability that a token is selected for masking
        :param special_tokens_mask: precomputed mask marking special tokens (optional)
        :param prob_replace_mask: fraction of selected tokens replaced with [MASK]
        :param prob_replace_rand: fraction of selected tokens replaced with a random token
        :param prob_keep_ori: fraction of selected tokens kept unchanged
        """
        assert sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) == 1, ValueError(
            "Sum of the probs must equal to 1.")
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori

    def training_config(
        self,
        batch_size,
        epochs,
        learning_rate,
        weight_decay,
        device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device

    def io_config(
        self,
        from_path,
        save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path
class TrainDataset(Dataset):
    """
    Note: since no data_collator is used, batching is done inside the dataset itself,
    so each item coming out of the DataLoader carries an extra batch dimension;
    squeeze it off before passing the tensors to the model.
    """
    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)

    def __len__(self):
        return len(self.input_texts) // self.config.batch_size

    def __getitem__(self, idx):
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        # inputs: ids after the [MASK]/random replacements; labels: the original ids at the
        # selected positions, with -100 everywhere else
        inputs, labels = self.mask_tokens(features['input_ids'])
        batch = {"inputs": inputs, "labels": labels}
        self.input_texts = self.input_texts[self.config.batch_size:]
        if not len(self):
            self.input_texts = self.ori_inputs
        return batch

    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)  # shape [batch, seq_len], every entry 0.15
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        # each entry is drawn to 1 with its own probability, then cast to True/False
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # we only compute loss on masked tokens; unselected positions become -100
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(
            torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)  # [MASK] -> its vocab id (103 for BERT)
        # 10% of the time, we replace masked input tokens with a random word;
        # among the positions not already replaced by [MASK], 0.1 / (1 - 0.8) = 0.5 are
        # replaced here, which works out to 10% of all masked positions overall
        current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
        indices_random = torch.bernoulli(
            torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
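# Side note (not part of the original script): the hand-rolled 80/10/10 masking in
# mask_tokens mirrors what transformers' built-in DataCollatorForLanguageModeling does.
# A roughly equivalent setup, assuming the same `bert_tokenizer` and a dataset of
# already-tokenized examples, would look like:
#
#     from transformers import DataCollatorForLanguageModeling
#     collator = DataCollatorForLanguageModeling(
#         tokenizer=bert_tokenizer, mlm=True, mlm_probability=0.15)
#     # DataLoader(tokenized_dataset, batch_size=4, collate_fn=collator) then yields
#     # batches whose "input_ids" are already masked and whose "labels" carry -100
#     # at the unselected positions, just like mask_tokens above.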
def train(model, train_dataloader, config):
    """
    Training loop
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    assert config.device.startswith('cuda') or config.device == 'cpu', ValueError("Invalid device.")
    device = torch.device(config.device)
    model.to(device)
    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    # per-group weight_decay (0.01 / 0.0) overrides the optimizer-level default below
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)
    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc + 1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            # [batch, seq_len]: token ids after the [MASK]/random replacements
            input_ids = batch['inputs'].squeeze(0).to(device)
            # [batch, seq_len]: original ids at masked positions, -100 elsewhere
            labels = batch['labels'].squeeze(0).to(device)
            # a single forward pass returns the MLM loss (cross-entropy over positions whose
            # label is not -100) and logits of shape [batch, seq_len, vocab_size]
            result = model(input_ids=input_ids, labels=labels)
            loss = result.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        print("Training loss: ", training_loss)
if __name__ == '__main__':
    config = Config()
    config.mlm_config()
    config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
    config.io_config(from_path='/root/autodl-tmp/bert-base-chinese',
                     save_path='mlm')
    bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
    bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)
    training_texts = [
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
        "这是一条文本",
        "这是另一条文本",
    ]
    train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
    train_dataloader = DataLoader(train_dataset)
    train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)
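
The script defines config.save_path but never writes to it. As a minimal follow-up sketch (not part of the original post), the fine-tuned weights could be persisted with save_pretrained and sanity-checked with the fill-mask pipeline; the snippet below assumes it is appended to the end of the __main__ block above, after train() has finished:

    # Persist the fine-tuned model and tokenizer to config.save_path ("mlm")
    bert_mlm_model.save_pretrained(config.save_path)
    bert_tokenizer.save_pretrained(config.save_path)

    # Quick sanity check: let the MLM head fill in a masked character
    from transformers import pipeline
    fill_mask = pipeline("fill-mask", model=config.save_path, tokenizer=config.save_path)
    print(fill_mask("这是一条[MASK]本"))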
