@[TOC](A Simple BERT Implementation)
“”"
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
https://arxiv.org/abs/1810.04805
“”"
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
“”"
情感分类
pip install -qq transformers
“”"
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
def to_sentiment(rating):
    """Map a 1-5 review score to one of 3 sentiment classes."""
    rating = int(rating)
    if rating <= 2:
        return 0  # negative
    elif rating == 3:
        return 1  # neutral
    else:
        return 2  # positive
class GPReviewDataset(Dataset):
    def __init__(self, reviews, targets, tokenizer, max_len):
        self.reviews = reviews
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, item):
        review = str(self.reviews[item])
        target = self.targets[item]
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,  # add '[CLS]' and '[SEP]'
            max_length=self.max_len,
            return_token_type_ids=False,
            pad_to_max_length=True,  # newer transformers: padding='max_length', truncation=True
            return_attention_mask=True,
            return_tensors='pt',  # return PyTorch tensors
        )
        return {
            'review_text': review,
            'input_ids': encoding['input_ids'].flatten(),  # [max_len]; the DataLoader adds the batch dimension
            'attention_mask': encoding['attention_mask'].flatten(),  # [max_len]
            'targets': torch.tensor(target, dtype=torch.long)
        }
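To see what `encode_plus` returns for a single review, here is a minimal sketch (the sentence and `max_length=16` are made up; `padding='max_length'` plus `truncation=True` is the newer spelling of the `pad_to_max_length=True` used above):

```python
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-cased')  # same name as PRE_TRAINED_MODEL_NAME
enc = tok.encode_plus(
    "This app is great!",     # made-up example review
    add_special_tokens=True,  # adds [CLS] and [SEP]
    max_length=16,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)
print(enc['input_ids'].shape)       # torch.Size([1, 16])
print(enc['attention_mask'].shape)  # torch.Size([1, 16])
```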
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = GPReviewDataset(reviews=df.content.to_numpy(), targets=df.sentiment.to_numpy(),
                         tokenizer=tokenizer, max_len=max_len)
    return DataLoader(ds, batch_size=batch_size, num_workers=4)
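One way to inspect a single batch from the loader (a sketch; it assumes `df`, `tokenizer`, `MAX_LEN` and `BATCH_SIZE` are already set up as in the `__main__` block below):

```python
loader = create_data_loader(df, tokenizer, MAX_LEN, BATCH_SIZE)
batch = next(iter(loader))
print(batch['input_ids'].shape)       # [BATCH_SIZE, MAX_LEN]
print(batch['attention_mask'].shape)  # [BATCH_SIZE, MAX_LEN]
print(batch['targets'].shape)         # [BATCH_SIZE]
```

Note that the loader is built without `shuffle=True`; for training you would normally pass that to `DataLoader`.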
class SentimentClassifier(nn.Module):
    def __init__(self, n_classes):
        super(SentimentClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME, return_dict=False)  # load the pretrained model
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)  # 768 -> 3

    def forward(self, input_ids, attention_mask):
        """
        input_ids: LongTensor of shape [batch_size, sequence_length]
            token indices into the vocabulary
        attention_mask: optional torch.LongTensor of shape [batch_size, sequence_length]
            with values 0 or 1.
            Used when an input sequence is shorter than the longest sequence in the current batch,
            i.e. whenever a batch contains sentences of different lengths.
        """
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)  # pooled output of the leading [CLS] token
        output = self.drop(pooled_output)
        return self.out(output)
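With `return_dict=False`, `BertModel` returns a tuple `(last_hidden_state, pooler_output)`, and the classifier keeps only the pooled [CLS] vector. A shape check with a dummy batch (a sketch; instantiating the class downloads the pretrained weights on first run):

```python
clf = SentimentClassifier(n_classes=3)
input_ids = torch.zeros((2, 160), dtype=torch.long)      # dummy token ids, shape [batch, seq_len]
attention_mask = torch.ones((2, 160), dtype=torch.long)  # attend to every position
logits = clf(input_ids=input_ids, attention_mask=attention_mask)
print(logits.shape)  # torch.Size([2, 3]) -- one logit per sentiment class
```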
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model = model.train()
    losses = []
    correct_predictions = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["targets"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
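The original snippet only defines the training pass. A matching evaluation pass, mirroring `train_epoch` but run under `torch.no_grad()` and without optimizer/scheduler steps, could look like this (`eval_model` is not part of the original code; it reuses the imports above):

```python
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs, dim=1)
            losses.append(loss_fn(outputs, targets).item())
            correct_predictions += torch.sum(preds == targets)
    return correct_predictions.double() / n_examples, np.mean(losses)
```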
if __name__ == '__main__':
    df = pd.read_csv("reviews.csv").loc[:100, :]  # full frame is (15746, 12); keep the first rows for a quick run
    df['sentiment'] = df.score.apply(to_sentiment)
    class_names = ['negative', 'neutral', 'positive']
    tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
    BATCH_SIZE = 2
    MAX_LEN = 160
    train_data_loader = create_data_loader(df, tokenizer, MAX_LEN, BATCH_SIZE)
    model = SentimentClassifier(len(class_names))
    model = model.to(device)
    EPOCHS = 2
    optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
    total_steps = len(train_data_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    loss_fn = nn.CrossEntropyLoss().to(device)
    for epoch in range(EPOCHS):
        train_acc, train_loss = train_epoch(model, train_data_loader,
                                            loss_fn, optimizer,
                                            device, scheduler, len(df))
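After training, the fine-tuned weights can be saved and reused for a single prediction. This step is not in the original code; a minimal sketch (the file name and review text are made up, and it assumes `model`, `tokenizer`, `MAX_LEN`, `class_names` and `device` from the block above):

```python
torch.save(model.state_dict(), 'best_model_state.bin')  # hypothetical file name

model.eval()
review_text = "I really love this app"  # made-up example
encoding = tokenizer.encode_plus(
    review_text,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)
with torch.no_grad():
    output = model(
        input_ids=encoding['input_ids'].to(device),
        attention_mask=encoding['attention_mask'].to(device),
    )
    prediction = torch.argmax(output, dim=1)
print(f'{review_text} -> {class_names[prediction.item()]}')
```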