
Using BertPreTrainedModel in transformers

I have recently been learning the transformers library from Hugging Face, which offers a large number of pretrained models that are very convenient to use. Sometimes, when you load a pretrained model with transformers for a downstream task, the model you save afterwards cannot be loaded back through the transformers API; the only way to restore it is model.load_state_dict(torch.load(...)). I recently found that the pretrained-model base classes in transformers solve this problem. The following uses BertModel and BertPreTrainedModel as an example; the code implements a simple text classifier.
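
In short: a model that only subclasses nn.Module has to be rebuilt and restored by hand, while a subclass of BertPreTrainedModel inherits save_pretrained() and from_pretrained(). A minimal sketch of the two loading styles, assuming a checkpoint already saved in ./output/ and the ClassifierModel class defined in the next section:

import torch
from transformers import BertConfig

config = BertConfig.from_pretrained("./output/")

# Manual route: rebuild the module, then restore the weights by hand.
model = ClassifierModel(config, num_class=3)
model.load_state_dict(torch.load("./output/pytorch_model.bin", map_location="cpu"))

# transformers route: available because ClassifierModel subclasses BertPreTrainedModel.
model = ClassifierModel.from_pretrained("./output/", config=config, num_class=3)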

Model

Text classification on top of a pretrained BERT model; the code is as follows:

import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class ClassifierModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(ClassifierModel, self).__init__(config)
        self.config = config
        self.num_class = num_class
        # The backbone is built from the config; its weights are filled in later
        # when this subclass is instantiated via from_pretrained().
        self.bert = BertModel(config=self.config)
        self.hidden_size = config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, self.num_class)
        # Initialize the newly added classifier head the way transformers models do.
        self.init_weights()

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        # [1] is the pooled [CLS] representation, shape [batch_size, hidden_size]
        output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
                           attention_mask=attention_mask)[1]
        logits = self.classifier(output)

        if labels is not None:
            entropy_loss = nn.CrossEntropyLoss()
            loss = entropy_loss(logits, labels)
            return loss, logits
        else:
            return logits
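
Because ClassifierModel subclasses BertPreTrainedModel, it can be instantiated directly with from_pretrained(); keyword arguments that from_pretrained does not consume, such as num_class here, are forwarded to __init__. A quick sanity check with dummy inputs ("bert-base-chinese" is a placeholder for any local or hub BERT checkpoint):

import torch
from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-chinese")
model = ClassifierModel.from_pretrained("bert-base-chinese", config=config, num_class=3)

input_ids = torch.ones(2, 16, dtype=torch.long)        # dummy batch: 2 sequences of length 16
token_type_ids = torch.zeros(2, 16, dtype=torch.long)
attention_mask = torch.ones(2, 16, dtype=torch.long)
logits = model(input_ids, token_type_ids, attention_mask)
print(logits.shape)  # expected: torch.Size([2, 3])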

The data processing is as follows:

import random

import torch
from torch.utils.data import Dataset


class DataProcessor(Dataset):
    """Reads a tab-separated file (label \t sentence) and yields padded features."""

    def __init__(self, data_file, tokenizer, max_seq_len, label2id):
        self.data_file = data_file
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.label2id = label2id
        self.data = self._read_data()

    def _read_data(self):
        data = []
        with open(self.data_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                label, sent = line.strip().split("\t")
                data.append({"sent": sent, "label": self.label2id[label]})
        random.shuffle(data)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sent = self.data[index]["sent"]
        label = self.data[index]["label"]
        features = self.tokenizer(sent, truncation=True,
                                  max_length=self.max_seq_len,
                                  add_special_tokens=True)
        # Pad every field manually up to max_seq_len so the collate function
        # can stack fixed-length examples into a batch.
        padding_len = self.max_seq_len - len(features["input_ids"])
        input_ids = features["input_ids"] + [self.tokenizer.pad_token_id] * padding_len
        token_type_ids = features["token_type_ids"] + [0] * padding_len
        attention_mask = features["attention_mask"] + [0] * padding_len
        return {"input_ids": input_ids,
                "token_type_ids": token_type_ids,
                "attention_mask": attention_mask,
                "label": label}


def collate_fn(batch_data):
    # Gather the already padded lists and convert them to LongTensors.
    input_ids = [item["input_ids"] for item in batch_data]
    token_type_ids = [item["token_type_ids"] for item in batch_data]
    attention_mask = [item["attention_mask"] for item in batch_data]
    label = [item["label"] for item in batch_data]

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    label = torch.tensor(label, dtype=torch.long)

    return {"input_ids": input_ids, "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "label": label}

The main function is as follows:

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertConfig, BertTokenizer


def main():
    pretrain_model_path = "D:/Spyder/pretrain_model/transformers_torch_tf/bert-base-chinese/"
    tokenizer = BertTokenizer.from_pretrained(pretrain_model_path)
    tokenizer.save_pretrained("./output/")
    label2id = {"体育": 0, "健康": 1, "军事": 2}
    id2label = {value: key for key, value in label2id.items()}

    data_file = "./data/train_data.txt"
    batch_size = 2
    epochs = 1
    lr = 1e-5
    max_seq_len = 128
    device = "cpu"

    train_dataset = DataProcessor(data_file, tokenizer, max_seq_len, label2id)
    # The data is already shuffled in _read_data, so shuffle=False here.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=False, collate_fn=collate_fn)

    # Store the label mapping in the config so it is saved together with the model.
    config = BertConfig.from_pretrained(pretrain_model_path,
                                        num_labels=len(label2id),
                                        id2label=id2label,
                                        label2id=label2id
                                        )
    # Extra keyword arguments such as num_class are forwarded to ClassifierModel.__init__.
    model = ClassifierModel.from_pretrained(pretrain_model_path,
                                            config=config,
                                            num_class=3)
    model.to(device)
    optimizer = AdamW(params=model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        for step, batch_data in enumerate(train_dataloader):
            input_ids = batch_data["input_ids"].to(device)
            token_type_ids = batch_data["token_type_ids"].to(device)
            attention_mask = batch_data["attention_mask"].to(device)
            labels = batch_data["label"].to(device)

            loss, _ = model(input_ids, token_type_ids, attention_mask, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("step:{}, loss:{}".format(step + 1, loss.item()))

        # save_pretrained writes the weights and config in the transformers format.
        model.save_pretrained("./output/")


if __name__ == "__main__":
    main()

The result of saving the model is shown below (original figure omitted: contents of the saved ./output/ directory).
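
For reference (this describes generic transformers behavior and is not taken from the original figure): model.save_pretrained("./output/") typically writes config.json and pytorch_model.bin, and tokenizer.save_pretrained("./output/") adds vocab.txt, tokenizer_config.json and special_tokens_map.json; the exact file list depends on the transformers version.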

Loading the saved model

Here both ClassifierModel and BertModel are used to load the saved model:

model_path = "./output/"
config = BertConfig.from_pretrained(model_path)
# Load the fine-tuned checkpoint back into the classification model.
model = ClassifierModel.from_pretrained(model_path,
                                        config=config,
                                        num_class=3)
# for parms in model.named_parameters():
#     print(parms)
print(model.state_dict().keys())
# The same directory can also be loaded as a plain BertModel (backbone only).
bert_model = BertModel.from_pretrained(model_path)
print(bert_model.state_dict().keys())

The model parameter keys are printed as follows:

bert.embeddings.position_ids   # keys of ClassifierModel: backbone parameters carry the "bert." prefix
embeddings.position_ids        # keys of BertModel: no prefix

The same checkpoint can be loaded by both classes because BertPreTrainedModel sets base_model_prefix = "bert": when BertModel.from_pretrained reads a checkpoint whose keys carry that prefix, it strips the prefix and matches the backbone weights, while the extra classifier weights are simply ignored.

Feeding a sentence through the model gives the following:

	text = "拉齐奥获不利排序意甲本周末拉齐奥与帕尔马之"
    tokenizer = BertTokenizer.from_pretrained(model_path)
    features = tokenizer(text, padding=True, truncation=True,
                         max_length=32,
                         add_special_tokens=True)
    input_ids = torch.tensor([features["input_ids"]], dtype=torch.long)
    token_type_ids = torch.tensor([features["token_type_ids"]], dtype=torch.long)
    attention_mask = torch.tensor([features["attention_mask"]], dtype=torch.long)
    bert_output = bert_model(input_ids=input_ids, token_type_ids=token_type_ids,
                             attention_mask=attention_mask)[1]
    print(bert_output.shape)
    output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                             attention_mask=attention_mask)
    print(output.shape)
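
To turn the logits into a readable label, take the argmax and map it back through the id2label mapping that was stored in the config when the model was trained and saved. A small sketch, continuing from the snippet above:

pred_id = torch.argmax(output, dim=-1).item()
# id2label was written into config.json; depending on the transformers version its keys may be ints or strings
label = config.id2label.get(pred_id, config.id2label.get(str(pred_id)))
print(label)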

(2) Reworking the saved weights so the model can be uploaded to a Hugging Face repository

import os

import torch

state_dict = torch.load(os.path.join("./output/", "pytorch_model.bin"),
                        map_location=torch.device("cpu"))
new_state_dict = {}
for key, param in state_dict.items():
    # Drop the "bert." or "roberta." prefix so the keys match a plain backbone model.
    if "bert." in key:
        key = key.replace("bert.", "")
    if "roberta." in key:
        key = key.replace("roberta.", "")

    new_state_dict[key] = param

os.makedirs("./output_v2/", exist_ok=True)
torch.save(new_state_dict, os.path.join("./output_v2/", "pytorch_model.bin"))
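
After the "bert." prefix has been stripped, the keys look exactly like those of a plain BertModel, so the new checkpoint can be loaded without the custom class. This assumes config.json (and the vocabulary files, if the tokenizer is needed) are also copied into ./output_v2/, which the snippet above does not do:

bert_model_v2 = BertModel.from_pretrained("./output_v2/")
print(bert_model_v2.state_dict().keys())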

The above are my notes from learning the transformers library from Hugging Face. If you have any questions, feel free to leave a comment and discuss.
