当前位置:   article > 正文

基于BERT的文本分类——附-简单的示例代码_基于bert的诗歌分类

基于bert的诗歌分类

**BERT(Bidirectional Encoder Representations from Transformers)**是一种预训练的自然语言处理模型,由Google于2018年提出。BERT通过在大规模文本语料上进行预训练,学习了深层次的语言表示,然后可以通过微调用于各种下游任务,包括文本分类。

文本分类是一个常见的自然语言处理任务,它涉及将文本分为不同的类别或标签。BERT在文本分类任务中表现出色,其双向(bidirectional)的注意力机制允许模型捕捉上下文信息,同时避免了传统模型中的单向限制。以下是BERT文本分类的一般步骤:

  1. 预训练BERT模型: 在大规模文本语料上,BERT模型进行了预训练。这个过程使得BERT能够学到深层次的语言表示,理解单词和短语之间的复杂关系。

  2. 微调: 针对特定的文本分类任务,使用预训练好的BERT模型,通过在带有标签的任务数据上进行微调。这意味着使用任务特定的标签数据,调整模型参数以适应具体的分类任务。

  3. 输入表示: 对于文本分类任务,输入文本需要经过一些处理,以适应BERT的输入格式。通常,文本被分成标记(tokens)并加上特殊的标记,如[CLS](用于表示分类任务的开始)。

  4. 获取输出: BERT模型的输出通常是每个输入标记的隐藏表示。为了进行文本分类,可以采用一些策略,例如取[CLS]标记的输出作为整个句子的表示,然后将其输入到分类器中。

  5. 分类器: 在BERT的输出之上,添加一个分类器(如全连接层),以将句子映射到相应的类别。这个分类器是在微调过程中学到的。

  6. 训练和评估: 使用标记的训练数据对整个模型进行训练,并使用验证集或测试集来评估模型的性能。常见的损失函数包括交叉熵损失。

BERT的优势在于其能够捕捉上下文信息,使得在文本分类等任务中能够更好地理解语境。然而,由于BERT模型参数较多,微调可能需要更多的标签数据和计算资源。最近,出现了一些基于BERT的轻量化模型,以降低计算成本并适应资源有限的环境。

BERT及其变体

  1. bert-base-uncased:

编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在小写的英文文本上进行训练而得到.

  1. bert-large-uncased:

编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在小写的英文文本上进行训练而得到.

  1. bert-base-cased:

编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在不区分大小写的英文文本上进行训练而得到.

  1. bert-large-cased:

编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在不区分大小写的英文文本上进行训练而得到.

  1. bert-base-multilingual-uncased:

编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在小写的102种语言文本上进行训练而得到.

  1. bert-large-multilingual-uncased:

编码器具有24个隐层, 输出1024维张量, 16个自注意力头, 共340M参数量, 在小写的102种语言文本上进行训练而得到.

  1. bert-base-chinese:

编码器具有12个隐层, 输出768维张量, 12个自注意力头, 共110M参数量, 在简体和繁体中文文本上进行训练而得到.

BERT的base模型可以在这里下载:

BERT基础模型 https://github.com/google-research/bert?tab=readme-ov-file

以下是一段我写的示例代码:

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW,SGD,Adam
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')#如果自动下载失败,请手动访问https://github.com/google-research/bert?tab=readme-ov-file下载到本地,并修改目录路径
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)

## 读取Excel文件
#df = pd.read_excel('train_data.xlsx')
## 提取标签和文本数据
#labels = df.iloc[1:, 0].values
#texts = df.iloc[1:, 1].values
## 创建一个字典,其中键是标签,值是一个唯一的数字
#label_dict = {label: i for i, label in enumerate(set(labels))}


# print(label_dict)
## 使用字典将文本标签转换为数字
# labels = [label_dict[label] for label in labels]

texts = ["你好,世界!", "机器学习太棒了!"]
labels = [0, 1]

## 将数据集分为训练集和测试集
#train_texts, test_texts, train_labels, test_labels = train_test_split(texts, labels, #test_size=0.2, random_state=13)

## 使用训练集创建训练数据加载器
#train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_len=128)
#train_data_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

## 使用测试集创建测试数据加载器
#test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, max_len=128)
#test_data_loader = DataLoader(test_dataset, batch_size=128)

# 使用提取的标签和文本数据创建数据集
all_dataset = TextClassificationDataset(texts, labels, tokenizer, max_len=128)
all_data_loader = DataLoader(all_dataset , batch_size=2, shuffle=True)

# 检查是否有可用的CUDA设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载保存的状态字典
# model.load_state_dict(torch.load('my_model.pth'))

# 将模型移动到CUDA设备上
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-4)
epochs = 10  # 定义训练的轮数

for epoch in range(epochs):
    total_loss = 0
    model.train()
    for batch in all_data_loader :
        # 将输入数据移动到CUDA设备上
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()


    avg_train_loss = total_loss / len(train_data_loader)
    print(f"Epoch {epoch+1} / {epochs}, Training Loss: {avg_train_loss}")

    # 评估模型
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in all_data_loader :
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)

            predicted_class = torch.argmax(outputs.logits, dim=1)
            correct_predictions += (predicted_class == labels).sum().item()
            total_predictions += labels.size(0)

    # 计算准确率
    accuracy = correct_predictions / total_predictions
    print(f"Accuracy: {accuracy}")



print("Training complete.")
# 将模型移动到CPU设备上
model = model.to('cpu')
# 在训练结束后保存模型的状态字典
torch.save(model.state_dict(), 'my_model.pth')

print("Model saved to model.pth.")

#----------------------------------------------------
# 初始化一个新的模型实例
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)

# 加载保存的状态字典
model.load_state_dict(torch.load('my_model.pth'))

# 将模型移动到CUDA设备上
model = model.to(device)

print("Model loaded from model.pth.")
model.eval()
new_text = "你好,世界!"

encoding = tokenizer.encode_plus(
    new_text,
    add_special_tokens=True,
    max_length=128,
    return_token_type_ids=False,
    padding='max_length',
    return_attention_mask=True,
    return_tensors='pt',
)

input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

with torch.no_grad():
    # 将输入数据移动到CUDA设备上
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    outputs = model(input_ids, attention_mask=attention_mask)

predicted_class = torch.argmax(outputs.logits, dim=1)
# 创建一个反向字典,其中键是数字,值是文本标签
# reverse_label_dict = {v: k for k, v in label_dict.items()}

# 使用反向字典将预测的类别转换回文本标签
# predicted_class = predicted_class.item()
# predicted_label = reverse_label_dict[predicted_class]

#print(f"The predicted class for the text '{new_text}' is {predicted_label}.")
print(f"The predicted class for the text '{new_text}' is {predicted_class.item()}.")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号