from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro', use_fast=True)
print(tokenizer)
MarianTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})
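The pad and end-of-sequence ids matter later when the batches are padded and the decoder inputs are built; a quick check (the same values appear in the padded batches further down):

print(tokenizer.pad_token_id, tokenizer.eos_token_id)
# 59542 0  -- <pad> fills the padded batches below, and </s> (id 0) ends every sequence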
tokenizer.batch_encode_plus([['hello, everyone today is a good day', 'It is late, please go home']])
{'input_ids': [[92, 778, 3, 1773, 879, 32, 8, 265, 431, 84, 32, 1450, 3, 709, 100, 540, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
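Note that the nested list above is treated as a single sentence pair, so both sentences are merged into one sequence with a single trailing </s> (id 0). A minimal sketch, with the same tokenizer, of encoding them as two independent examples instead:

ids = tokenizer.batch_encode_plus(['hello, everyone today is a good day',
                                   'It is late, please go home'])
print(len(ids['input_ids']))  # 2 -- one id list per sentence, each ending with </s>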
from datasets import load_dataset
dataset = load_dataset(path='wmt16', name='ro-en')
dataset
DatasetDict({
train: Dataset({
features: ['translation'],
num_rows: 610320
})
validation: Dataset({
features: ['translation'],
num_rows: 1999
})
test: Dataset({
features: ['translation'],
num_rows: 1999
})
})
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))
dataset['train'][0]
{'translation': {'en': 'For these reasons I voted in favour of the proposal for a new regulation that aims for greater clarity and transparency in the GSP system.',
'ro': 'Din aceste motive am votat în favoarea propunerii de nou regulament care își propune o mai mare claritate și transparență în sistemul SPG.'}}
def preprocess_function(data, tokenizer):
    # source (English) and target (Romanian) sides of each translation pair
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)
    # tokenize the targets with the target-language side of the tokenizer
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(ro, max_length=128, truncation=True)['input_ids']
    return data
dataset = dataset.map(preprocess_function,
batched=True,
batch_size=1000,
num_proc=1,
remove_columns=['translation'],
fn_kwargs={'tokenizer': tokenizer})
print(dataset['train'][0])
{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
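As a sanity check that the labels hold the Romanian side, the first example can be decoded back (the en-ro Marian tokenizer uses a joint vocabulary, so decode also works for target ids):

print(tokenizer.decode(dataset['train'][0]['labels']))
# the Romanian sentence from the example above, ending with </s>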
def collate_fn(data):
    # pad every label sequence in the batch to the longest one, using -100
    max_length = max([len(i['labels']) for i in data])
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt')
    # decoder_input_ids: start from <pad> (Marian's decoder start token) and shift the labels right by one
    data['decoder_input_ids'] = torch.full_like(data['labels'],
                                                tokenizer.pad_token_id,
                                                dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    # replace the -100 padding with real <pad> ids so the decoder never sees -100
    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = tokenizer.pad_token_id
    return data
import torch
loader = torch.utils.data.DataLoader(
dataset=dataset['train'],
batch_size=8,
collate_fn=collate_fn,
shuffle=True,
drop_last=True)
for data in loader:
break
data
{'input_ids': tensor([[ 12, 1107, 30, 37, 4, 2194, 476, 63, 123, 47,
116, 15, 27384, 1036, 3, 18, 66, 8, 9911, 1591,
141, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542],
...
[ 172, 2515, 1297, 3, 74, 64, 5023, 133, 23076, 18,
9000, 11, 17351, 21120, 2, 0, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
...
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[ 1939, 70, 39, 2149, 3042, 701, 19, 224, 27, 6461,
5968, 9188, 31, 29, 916, 11537, 49, 9803, 71, 2,
0, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100],
...
[ 127, 5742, 343, 3, 76, 79, 27209, 40989, 46, 24725,
181, 43, 34119, 32121, 2, 0, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100]]), 'decoder_input_ids': tensor([[34426, 1939, 70, 39, 2149, 3042, 701, 19, 224, 27,
6461, 5968, 9188, 31, 29, 916, 11537, 49, 9803, 71,
2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542],
...
[34426, 127, 5742, 343, 3, 76, 79, 27209, 40989, 46,
24725, 181, 43, 34119, 32121, 2, 0, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
59542, 59542, 59542]])}
for k, v in data.items():
print(k, v.shape)
input_ids torch.Size([8, 89])
attention_mask torch.Size([8, 89])
labels torch.Size([8, 103])
decoder_input_ids torch.Size([8, 103])
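The decoder_input_ids produced by collate_fn are just the labels shifted one position to the right, with the pad id prepended as the decoder start token, which sets up teacher forcing for the Marian decoder. A toy illustration of the shift (hypothetical label values):

toy_labels = torch.tensor([[902, 576, 2, 0, -100, -100]])      # target ids padded with -100
toy_dec = torch.full_like(toy_labels, tokenizer.pad_token_id)  # start every row with <pad>
toy_dec[:, 1:] = toy_labels[:, :-1]                            # shift right by one position
toy_dec[toy_dec == -100] = tokenizer.pad_token_id              # real <pad> ids for the decoder
print(toy_dec)  # tensor([[59542,  902,  576,    2,    0, 59542]])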
from transformers import AutoModelForSeq2SeqLM, MarianModel
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = MarianModel.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
        # register the final logits bias as a (zero-initialized) buffer
        self.register_buffer('final_logits_bias', torch.zeros(1, tokenizer.vocab_size))
        # output projection from hidden size 512 to the vocabulary
        self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)
        # initialize the projection with the pretrained lm_head weights
        parameters = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
        self.fc.load_state_dict(parameters.lm_head.state_dict())
        # CrossEntropyLoss ignores label positions set to -100 by default
        self.criterion = torch.nn.CrossEntropyLoss()
    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids)
        logits = logits.last_hidden_state
        logits = self.fc(logits) + self.final_logits_bias
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
        return {'loss': loss, 'logits': logits}
# [b,lens] -> embedding -> [b,lens,embed_size] -> pretrained[embed_size,512] -> [b,lens,512] -> fc[512,vocab_size] -> [b,lens,vocab_size]
model = Model()
print(sum(i.numel() for i in model.parameters()))
105634816
out = model(**data)
out['loss'], out['logits'].shape
(1.4804006814956665, torch.Size([8, 103, 59543]))
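Because CrossEntropyLoss ignores label positions set to -100 by default, the padded label slots do not contribute to the loss; the value above can be reproduced explicitly (a sketch reusing the same batch and output):

manual_loss = torch.nn.functional.cross_entropy(
    out['logits'].flatten(end_dim=1), data['labels'].flatten(), ignore_index=-100)
print(manual_loss)  # same value as out['loss']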
def test(model):
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True)
    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            out = model(**data)
        # teacher-forced predictions: argmax over the vocabulary at each decoder position
        pred = tokenizer.batch_decode(out['logits'].argmax(dim=2))
        label = tokenizer.batch_decode(data['decoder_input_ids'])
        predictions.append(pred)
        references.append(label)
        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])
            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])
        if i == 10:
            break
    # wrap each reference batch so every prediction can have a list of reference translations
    references = [[j] for j in references]
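The predictions/references lists are collected but never scored inside test(); a hedged sketch of lines that could be appended at the end of test() to compute BLEU (assuming the evaluate library, an extra dependency not used in the original), flattening the per-batch lists into one string per sentence:

import evaluate
bleu = evaluate.load('sacrebleu')
flat_preds = [p for batch in predictions for p in batch]
flat_refs = [[r] for batch in references for r in batch[0]]
print(bleu.compute(predictions=flat_preds, references=flat_refs)['score'])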
test(model)
0
input_ids= The▁only name that▁was not▁mentioned by▁any of the▁participants in the▁negotiations of recent▁days is that of the▁former▁head of the▁branch,▁Mayor Gheorghe Nichita.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Singurul nume care nu a fost men de niciunul dintre participanții la negocierile din ultimele zile este cel al fostului șef al filialei, primarul Gheorghe Nichita.</s> DEaa al al Nicol Nicol Nicol N N În În În În În În În În În În În În În În În În În În Singur Singur Singur Singur Singur Singur
label= pad Singurul nume care nu a fost menționat de niciunul din participanții la negocierile din ultimele zile este cel al fostului lider al filialei, primarul Gheorghe Nichita.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
...
10
input_ids= ▁That▁would▁point to a▁stock▁market▁drop▁if the Fed▁raises the rate,▁unless▁policymakers▁were to▁soften the▁blow by▁promising that▁another▁increase▁would be a▁ways▁off.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Acest ar indica o scădere a piața, federaled- rata,, cu cazul în care factorii care elaborează politici ar reduce lovitura,mițând că o exista o peste de la o creștere a</s>,, Acest În În În În În În În În În În Acest Acest Acest
label= pad Aceasta ar indica o scădere pe bursă dacă Fed crește rata dobânzii, exceptând cazul în care cei care elaborează politicile ar atenua lovitura promițând că ar trece mult timp până la următoarea creștere.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
from transformers import AdamW
from transformers.optimization import get_scheduler
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda', index=0)
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)
    model.to(device)
    model.train()
    for i, data in enumerate(loader):
        # move the whole batch to the training device
        for k in data.keys():
            data[k] = data[k].to(device)
        out = model(**data)
        loss = out['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        model.zero_grad()
        if i % 50 == 0:
            # rough token-level accuracy: argmax predictions compared against the
            # (right-shifted) decoder inputs, padding positions included
            out = out['logits'].argmax(dim=2)
            correct = (data['decoder_input_ids'] == out).sum().item()
            total = data['decoder_input_ids'].shape[1] * 8
            accuracy = correct / total
            predictions = []
            references = []
            for j in range(8):
                pred = tokenizer.decode(out[j])
                label = tokenizer.decode(data['decoder_input_ids'][j])
                predictions.append(pred)
                references.append(label)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)
train()
0 2.2159132957458496 0.0 1.9992e-05
...
2400 0.6920648217201233 0.006696428571428571 7.920000000000001e-07
2450 0.8634450435638428 0.004032258064516129 3.92e-07
torch.save(model, '../data/翻译.model')
model2 = torch.load('../data/翻译.model', map_location='cpu')
test(model2)
0
input_ids= ▁Last▁month▁saw▁lowest▁growth▁rise▁since▁records▁began in 2000</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Luna ultima lună,-a înregistrat o mai mică creşte de 2000. în prezent.</s> țiululululul -: - De De De De De De De De De De De De De De De De Luna De De De De De De Luna De Luna Luna Luna Luna De Luna Luna De Luna Luna Luna Luna De Luna De Luna Luna Luna Luna Luna Luna Luna Luna
label= pad În ultima lună s-a înregistrat cea mai lentă creștere din 2000 până în prezent.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
...
10
input_ids= Corneliu Vadim Tudor▁was▁born on▁November 28,▁1949, in▁Bucharest. He▁was a▁writer,▁politician and▁journalist.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Corneliu Vadim Tudor s-a născutscut in 28 noiembrie 1949, la Bucuresti. a scriitor, politician si jurnalist.</s> </s> al alul al al, La al A A A I La Cor A Cor A Cor Cor Cor La Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Corneliu Cor Cor Cor Cor Cor Cor Cor Cor
label= pad Corneliu Vadim Tudor s-a nascut în 28 noiembrie 1949, în Bucuresti, era scriitor, politician și jurnalist.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
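The test() output above is teacher-forced (the argmax is taken while the gold decoder inputs are fed in), so it is not a real translation. A minimal greedy-decoding sketch for free-running translation with the reloaded model (an addition, not part of the original; the custom Model has no generate(), so the decoder is run step by step, starting from <pad> as Marian's decoder start token):

def translate(text, max_new_tokens=64):
    enc = tokenizer([text], return_tensors='pt')
    dec = torch.full((1, 1), tokenizer.pad_token_id, dtype=torch.long)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            out = model2(input_ids=enc['input_ids'],
                         attention_mask=enc['attention_mask'],
                         labels=dec,  # only needed because forward() always computes a loss
                         decoder_input_ids=dec)
        next_id = out['logits'][0, -1].argmax().item()
        dec = torch.cat([dec, torch.tensor([[next_id]])], dim=1)
        if next_id == tokenizer.eos_token_id:
            break
    return tokenizer.decode(dec[0], skip_special_tokens=True)

print(translate('It is late, please go home'))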