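# NOTE: the two labels below ("赞" = like, "踩" = dislike) are stray web-page
# text that was pasted into this file; kept as comments so the module imports.
# 赞
# 踩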
import json
import logging
import math
import os
import random
import re
import string
import subprocess
import sys
import time
import unicodedata
from tempfile import NamedTemporaryFile

import numpy as np
import scipy.signal
from scipy import spatial
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
#from torch.distributed import get_rank
#from torch.distributed import get_world_size
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
import Levenshtein as Lev
import torchaudio

#from trainer.asr.trainer import Trainer
from utils import constant
#from utils.data_loader import SpectrogramDataset, AudioDataLoader, BucketingSampler
#from utils.audio import load_audio, get_audio_length, audio_with_sox, augment_audio_with_sox, load_randomly_augmented_audio
#from utils.functions import save_model, load_model, init_transformer_model, init_optimizer
#from models.asr.transformer import Transformer, Encoder, Decoder
#from utils.optimizer import NoamOpt, AnnealingOpt
#from utils.metrics import calculate_metrics
#from utils.lstm_utils import calculate_lm_score
#from data.helper import get_word_segments_per_language, is_contain_chinese_word
#import logging
# Window functions selectable by name. They live in `scipy.signal.windows`;
# the old `scipy.signal.<name>` aliases were removed in SciPy 1.13.
windows = {'hamming': scipy.signal.windows.hamming, 'hann': scipy.signal.windows.hann,
           'blackman': scipy.signal.windows.blackman, 'bartlett': scipy.signal.windows.bartlett}
# Absolute directory that contains this module.
dir_path = os.path.dirname(os.path.realpath(__file__))
def load_stanford_core_nlp(path):
    """
    Load Stanford CoreNLP toolkit objects for Chinese and English.
    args:
        path: String, passed unchanged to StanfordCoreNLP (its install path / host)
    output:
        (zh_nlp, en_nlp): StanfordCoreNLP objects created with lang='zh' and lang='en'
    """
    # Imported lazily so the rest of this module works without the optional package.
    from stanfordcorenlp import StanfordCoreNLP
    zh_nlp = StanfordCoreNLP(path, lang='zh')
    en_nlp = StanfordCoreNLP(path, lang='en')
    return zh_nlp, en_nlp
"""
################################################
TEXT PREPROCESSING
################################################
"""
def is_chinese_char(cc):
    """
    Report whether a single character belongs to Unicode category 'Lo'.

    'Lo' (Letter, other) contains the CJK ideographs, but it also contains
    other caseless scripts (kana, Hangul, Arabic, Thai, ...). This is
    therefore a loose "Chinese" test.
    args:
        cc: char (exactly one character)
    output:
        boolean
    """
    category = unicodedata.category(cc)
    return category == 'Lo'
def is_contain_chinese_word(seq):
    """
    Report whether any character of `seq` passes is_chinese_char.
    args:
        seq: String
    output:
        boolean (False for an empty string)
    """
    return any(is_chinese_char(ch) for ch in seq)
def get_word_segments_per_language(seq):
    """
    Split a space-separated, code-switched sentence into monolingual runs.

    Consecutive words of the same language (Chinese if the word contains a
    Chinese character, otherwise English) are re-joined with spaces. A new
    segment starts at every language switch.
    args:
        seq: String
    output:
        word_segments: list of String (never empty; [""] for "")
    """
    segments = []
    current = ""
    lang = -1  # -1 = nothing seen yet, 0 = english, 1 = chinese
    for word in seq.split(" "):
        word_lang = 1 if is_contain_chinese_word(word) else 0
        if lang == -1:
            current = word
        elif word_lang != lang:
            segments.append(current)
            current = word
        else:
            if current != "":
                current += " "
            current += word
        lang = word_lang
    segments.append(current)
    return segments
def get_word_segments_per_language_with_tokenization(seq, tokenize_lang=-1, zh_nlp=None, en_nlp=None):
    """
    Get word segments and tokenize the segments of one selected language.

    Running Stanford CoreNLP on two different languages at once is very slow.
    Instead, this tokenizes only one language per call; call it once per
    language you want tokenized.
    args:
        seq: String (space-separated words)
        tokenize_lang: int (-1 means no language is selected, 0 (english), 1 (chinese))
        zh_nlp: Chinese StanfordCoreNLP object (used when tokenize_lang == 1)
        en_nlp: English StanfordCoreNLP object (used when tokenize_lang == 0)
    output:
        word_segments: list of String
    """
    def _finish_segment(segment, lang):
        # Tokenize a completed segment only if it is in the selected language.
        if lang != tokenize_lang:
            return segment
        if lang == 0:
            return ' '.join(en_nlp.word_tokenize(segment))
        # Chinese is re-segmented from scratch, so the original spaces are dropped.
        # This applies to every Chinese segment, including the last one.
        return ' '.join(zh_nlp.word_tokenize(segment.replace(" ", "")))

    cur_lang = -1  # -1 = nothing seen yet, 0 = english, 1 = chinese
    temp_words = ""
    word_segments = []
    for word in seq.split(" "):
        word_lang = 1 if is_contain_chinese_word(word) else 0
        if cur_lang == -1:
            temp_words = word
        elif word_lang != cur_lang:
            word_segments.append(_finish_segment(temp_words, cur_lang))
            temp_words = word
        else:
            if temp_words != "":
                temp_words += " "
            temp_words += word
        cur_lang = word_lang
    word_segments.append(_finish_segment(temp_words, cur_lang))
    return word_segments
def remove_emojis(seq):
    """
    Delete emoji characters, then strip surrounding whitespace.

    Only these blocks are removed: emoticons (U+1F600-1F64F), symbols &
    pictographs (U+1F300-1F5FF), transport & map symbols (U+1F680-1F6FF)
    and flags (U+1F1E0-1F1FF).
    args:
        seq: String
    output:
        seq: String
    """
    emoji_ranges = (
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    )
    without_emojis = re.sub("[" + emoji_ranges + "]+", "", seq, flags=re.UNICODE)
    return without_emojis.strip()
def merge_abbreviation(seq):
    """
    Glue consecutive words ending in "." together (e.g. "u. s. a." -> "u.s.a.").

    Empty words (from repeated spaces) are dropped. The result is joined with
    single spaces.
    args:
        seq: String
    output:
        String
    """
    pieces = []
    pending = ""  # dotted words collected so far
    for word in seq.split(" "):
        if not word:
            continue
        if word.endswith("."):
            pending += word
            continue
        if pending:
            pieces.append(pending)
            pending = ""
        pieces.append(word)
    if pending:
        pieces.append(pending)
    return " ".join(pieces)
def remove_punctuation(seq):
    """
    Replace English and Chinese punctuation with spaces, and fix some typos
    and encoding issues.

    Hyphens are turned into spaces and full stops are deleted. Apostrophes
    inside words are kept, but some contractions ("'re", "'ll", "'ve",
    "i'll", "you're") are expanded.
    args:
        seq: String
    output:
        seq: String (space runs collapsed, stripped)
    """
    # The double quote is part of the first class; a raw string keeps the regex
    # escapes literal. `\s` also turns newlines into spaces.
    seq = re.sub(r'[\s+!/_,$%=^*?:@&^~`(+"]+|[+!,。?、~@#¥%……&*():;:;《)《》“”()»〔〕]+', " ", seq)
    # Detached apostrophes/backticks.
    seq = seq.replace(" ' ", " ")
    seq = seq.replace(" ’ ", " ")
    seq = seq.replace(" ' ", " ")
    seq = seq.replace(" ` ", " ")
    seq = seq.replace(" '", "'")
    seq = seq.replace(" ’", "’")
    seq = seq.replace(" '", "'")
    seq = seq.replace("' ", " ")
    seq = seq.replace("’ ", " ")
    seq = seq.replace("' ", " ")
    seq = seq.replace("` ", " ")
    # Remaining single-character punctuation.
    seq = seq.replace(".", "")
    seq = seq.replace("`", "")
    seq = seq.replace("-", " ")
    seq = seq.replace("?", " ")
    seq = seq.replace(":", " ")
    seq = seq.replace(";", " ")
    seq = seq.replace("]", " ")
    seq = seq.replace("[", " ")
    seq = seq.replace("}", " ")
    seq = seq.replace("{", " ")
    seq = seq.replace("|", " ")
    seq = seq.replace("_", " ")
    seq = seq.replace("(", " ")
    seq = seq.replace(")", " ")
    seq = seq.replace("=", " ")
    # Corpus-specific typo / spacing fixes.
    seq = seq.replace(" dont ", " don't ")
    seq = seq.replace("welcome外星人", "welcome 外星人")
    seq = seq.replace("doens't", "doesn't")
    seq = seq.replace("o' clock", "o'clock")
    seq = seq.replace("因为it's", "因为 it's")
    seq = seq.replace("it' s", "it's")
    seq = seq.replace("it ' s", "it's")
    seq = seq.replace("it' s", "it's")
    seq = seq.replace("y'", "y")
    seq = seq.replace("y ' ", "y")
    seq = seq.replace("看different", "看 different")
    seq = seq.replace("it'self", "itself")
    seq = seq.replace("it'ss", "it's")
    seq = seq.replace("don'r", "don't")
    seq = seq.replace("has't", "hasn't")
    seq = seq.replace("don'know", "don't know")
    # Contraction expansion.
    seq = seq.replace("i'll", "i will")
    seq = seq.replace("you're", "you are")
    seq = seq.replace("'re ", " are ")
    seq = seq.replace("'ll ", " will ")
    seq = seq.replace("'ve ", " have ")
    # NOTE(review): `\s` in the first regex already turned newlines into spaces,
    # so these three replacements never match.
    seq = seq.replace("'re\n", " are\n")
    seq = seq.replace("'ll\n", " will\n")
    seq = seq.replace("'ve\n", " have\n")
    seq = remove_space_in_between_words(seq)
    return seq
def remove_special_char(seq):
    """
    Replace every run of special characters (listed in the pattern below) with one space.
    args:
        seq: String
    output:
        seq: String
    """
    # "[" and "]" must be escaped: an unescaped "]" would close the character
    # class early, and the pattern would then only match "<char>±]".
    seq = re.sub(r"[【】·.%°℃×→①ぃγ ̄σς=~•+δ≤∶/⊥_ñãíå∈△β\[\]±]+", " ", seq)
    return seq
def remove_space_in_between_words(seq):
    """
    Collapse every run of consecutive spaces into a single space, then strip
    leading/trailing whitespace.
    args:
        seq: String
    output:
        seq: String
    """
    # One regex pass handles space runs of any length.
    return re.sub(r" {2,}", " ", seq).strip()
def remove_return(seq):
    """
    Delete every newline, carriage-return and tab character.
    args:
        seq: String
    output:
        seq: String
    """
    return seq.translate(str.maketrans("", "", "\n\r\t"))
def preprocess_mixed_language_sentence(seq, tokenize=False, en_nlp=None, zh_nlp=None, tokenize_lang=-1):
    """
    Normalize one code-switched Chinese/English transcript line.

    Steps: lower-case, merge dotted abbreviations, drop control characters and
    some punctuation, remove annotation markup (<...>, 【...】, (...), [...],
    {...}), split into monolingual segments (optionally tokenizing one language
    with Stanford CoreNLP), replace non-ASCII characters in English segments
    with spaces, and collapse spaces.
    args:
        seq: String
        tokenize: bool, use get_word_segments_per_language_with_tokenization
        en_nlp, zh_nlp: StanfordCoreNLP objects (only used when tokenize=True)
        tokenize_lang: int, -1 (none), 0 (english) or 1 (chinese)
    output:
        seq: String ("" when the result has at most one character)
    """
    if len(seq) == 0:
        return ""
    seq = seq.lower()
    seq = merge_abbreviation(seq)
    for old, new in (("\x7f", ""), ("\x80", ""), ("\u3000", " "), ("\xa0", ""),
                     ("[", " ["), ("]", "] "), ("#", ""), (",", ""), ("*", ""),
                     ("\n", ""), ("\r", ""), ("\t", ""), ("~", ""), ("—", "")):
        seq = seq.replace(old, new)
    seq = re.sub(r" {2,}", " ", seq)
    seq = re.sub(r'<.*?>', '', seq)  # REMOVE < >
    seq = re.sub(r'【.*?】', '', seq)  # REMOVE 【 】
    # NOTE(review): these two patterns were truncated in this copy of the file;
    # they are reconstructed as bracket-content removal. Confirm against history.
    seq = re.sub(r"[\(\[].*?[\)\]]", "", seq)  # REMOVE ( ) and [ ]
    seq = re.sub(r"[\{\[].*?[\}\]]", "", seq)  # REMOVE { } and [ ]
    if not tokenize:
        segments = get_word_segments_per_language(seq)
    else:
        segments = get_word_segments_per_language_with_tokenization(seq, en_nlp=en_nlp, zh_nlp=zh_nlp, tokenize_lang=tokenize_lang)
    cleaned = []
    for segment in segments:
        if not is_contain_chinese_word(segment):
            # English segments keep ASCII only; anything else becomes a space.
            segment = re.sub(r'[^\x00-\x7f]', ' ', segment)
        cleaned.append(segment.replace("\n", ""))
    # remove_space_in_between_words also strips the ends of the joined string.
    seq = remove_space_in_between_words(" ".join(cleaned))
    # Results of at most one character are discarded.
    if len(seq) <= 1:
        return ""
    return seq
"""
################################################
AUDIO PREPROCESSING
################################################
"""
def preprocess_wav(root, dirc, filename):
    """
    Cut one recording into per-line clips with sox, following its transcript.

    Reads <root>/<dirc>/proc_transcript/phaseII/<filename>.txt, which is
    tab-separated: field 1 = start, field 2 = end (divided by 1000, so
    presumably milliseconds), field 4 = text. For line n it writes:
        <root>/parts/<dirc>/flac/<filename>_<n>.flac          trimmed clip
        <root>/parts/<dirc>/wav/<filename>_<n>.wav            clip converted to wav
        <root>/parts/<dirc>/proc_transcript/<filename>_<n>.txt  the line's text
    If the wav cannot be produced or loaded by torchaudio, the clip is
    reported and its transcript is not written.
    args:
        root, dirc, filename: String path components
    raises:
        subprocess.CalledProcessError / OSError if the trimming sox call fails
    """
    source_audio = "/".join([root, dirc, "audio", filename + ".flac"])
    transcript_path = "/".join([root, dirc, "proc_transcript", "phaseII", filename + ".txt"])
    parts_dir = "/".join([root, "parts", dirc])
    with open(transcript_path, "r", encoding="utf-8") as transcript_file:
        for part_num, line in enumerate(transcript_file):
            data = line.replace("\n", "").split("\t")
            start_time = float(data[1]) / 1000
            end_time = float(data[2]) / 1000
            dif_time = end_time - start_time
            text = data[4]
            part_name = filename + "_" + str(part_num)
            target_flac_audio = parts_dir + "/flac/" + part_name + ".flac"
            target_wav_audio = parts_dir + "/wav/" + part_name + ".wav"
            # Argument lists (no shell) keep paths with spaces or shell
            # metacharacters from being split or interpreted.
            subprocess.check_output(["sox", source_audio, target_flac_audio, "trim", str(start_time), str(dif_time)])
            try:
                subprocess.run(["sox", target_flac_audio, target_wav_audio], stdout=subprocess.PIPE)
                # Loading validates the converted clip; the audio itself is not used.
                torchaudio.load(target_wav_audio)
                with open(parts_dir + "/proc_transcript/" + part_name + ".txt", "w+", encoding="utf-8") as text_file:
                    text_file.write(text + "\n")
            except Exception:
                print("Error reading audio file: unknown length, the audio is not with proper length, skip, target_flac_audio {}".format(target_flac_audio))
"""
################################################
COMMON FUNCTIONS
################################################
"""
def traverse(root, path, dev_conversation_phase2, test_conversation_phase2, dev_interview_phase2, test_interview_phase2, search_fix=".txt"):
    """
    Assign the files of directory `root + path` ending in `search_fix` to the
    train, dev and test splits.

    For a path containing "conversation", the key is characters [2:6] of the
    file name. For a path containing "interview", it is characters [:4]. A key
    found in the matching dev (or test) collection sends the file there;
    otherwise it goes to train. Any other path yields no files.
    args:
        root, path: String, concatenated to form the directory to scan
        dev_conversation_phase2, test_conversation_phase2: collections of name keys
        dev_interview_phase2, test_interview_phase2: collections of name keys
        search_fix: String suffix to match
    output:
        (f_train_list, f_dev_list, f_test_list): "<dir>/<file>" paths in sorted name order
    """
    directory = root + path
    train_files, dev_files, test_files = [], [], []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(search_fix):
            continue
        if "conversation" in path:
            key, dev_keys, test_keys = name[2:6], dev_conversation_phase2, test_conversation_phase2
        elif "interview" in path:
            key, dev_keys, test_keys = name[:4], dev_interview_phase2, test_interview_phase2
        else:
            print("hoho")
            continue
        print(">", path, name)
        full_path = directory + "/" + name
        if key in dev_keys:
            dev_files.append(full_path)
        elif key in test_keys:
            test_files.append(full_path)
        else:
            train_files.append(full_path)
    return train_files, dev_files, test_files
def traverse_all(root, path):
    """
    List every entry of directory `root + path` as "<dir>/<entry>", sorted by entry name.
    """
    directory = root + path
    return [directory + "/" + entry for entry in sorted(os.listdir(directory))]
def calculate_lm_score(seq, lm, id2label):
    """
    Score a decoded id sequence with a word-level language model.

    The ids are mapped to labels, and the PAD/SOS/EOS characters are removed.
    The text is split into monolingual segments. Chinese segments are split
    into one token per character; English segments stay whole. The
    space-separated result is passed to `lm.evaluate`.
    seq: (1, seq_len) sequence of label ids (only seq[0] is read)
    lm: object whose evaluate(str) returns (score, oov_token), e.g. LM
    id2label: map from label id to label string
    returns: (normalized score, token count + 1, oov_token),
             or (-999, 0, 0) when nothing is left to score
    """
    # print("hello")
    seq_str = "".join(id2label[char.item()] for char in seq[0]).replace(
        constant.PAD_CHAR, "").replace(constant.SOS_CHAR, "").replace(constant.EOS_CHAR, "")
    seq_str = seq_str.replace(" ", " ")
    seq_arr = get_word_segments_per_language(seq_str)
    seq_str = ""
    for i in range(len(seq_arr)):
        if is_contain_chinese_word(seq_arr[i]):
            # One LM token per Chinese character (spaces inside the segment
            # become extra separators, which split() below ignores).
            for char in seq_arr[i]:
                if seq_str != "":
                    seq_str += " "
                seq_str += char
        else:
            if seq_str != "":
                seq_str += " "
            seq_str += seq_arr[i]
    # print("seq_str:", seq_str)
    seq_str = seq_str.replace(" ", " ").replace(" ", " ")
    # print("seq str:", seq_str)
    if seq_str == "":
        return -999, 0, 0
    score, oov_token = lm.evaluate(seq_str)
    # a, b = lm.evaluate("除非 的 不会 improve 什么 东西 的 这些 esperience")
    # a2, b2 = lm.evaluate("除非 的 不会 improve 什么 东西 的 这些 experience")
    # print(a, a2)
    # NOTE(review): `-1 * score / n + 1` parses as `(-score / n) + 1`, while the
    # second value is `n + 1` (words + <eos>). If a per-token average over
    # n + 1 tokens was intended, this should be `-score / (n + 1)` — confirm.
    return -1 * score / len(seq_str.split()) + 1, len(seq_str.split()) + 1, oov_token
class LM(object):
    """
    Word-level recurrent language model used to score text sequences.

    args:
        model_path: path to a torch.save'd checkpoint dict with the keys
            "word2idx", "idx2word", "ntoken", "ninp", "nhid", "nlayers",
            "dropout", "tie_weights" and "model_state_dict"
    """

    def __init__(self, model_path):
        self.model_path = model_path
        print("load model path:", self.model_path)
        # Load on the CPU so GPU-saved checkpoints also open on CPU-only
        # machines; the model is moved to the GPU below when requested.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.word2idx = checkpoint["word2idx"]
        self.idx2word = checkpoint["idx2word"]
        self.model = RNNModel("LSTM", ntoken=checkpoint["ntoken"], ninp=checkpoint["ninp"],
                              nhid=checkpoint["nhid"], nlayers=checkpoint["nlayers"],
                              dropout=checkpoint["dropout"], tie_weights=checkpoint["tie_weights"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        if constant.args.cuda:
            self.model = self.model.cuda()
        self.criterion = nn.CrossEntropyLoss()

    def batchify(self, data, bsz, cuda):
        """Reshape a 1-D tensor into (len // bsz, bsz) columns, dropping the remainder."""
        # Work out how cleanly we can divide the dataset into bsz parts.
        nbatch = data.size(0) // bsz
        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, nbatch * bsz)
        # Evenly divide the data across the bsz batches.
        data = data.view(bsz, -1).t().contiguous()
        if cuda:
            data = data.cuda()
        return data

    def seq_to_tensor(self, seq):
        """
        Map the whitespace-separated words of `seq` plus a trailing '<eos>' to ids.
        Unknown words map to the id of '<oov>'.
        returns: (LongTensor of ids, number of out-of-vocabulary words)
        """
        words = seq.split() + ['<eos>']
        known = self.word2idx
        ids = [known[word] if word in known else known['<oov>'] for word in words]
        oov_token = sum(1 for word in words if word not in known)
        return torch.tensor(ids, dtype=torch.long), oov_token

    def get_batch(self, source, i, bptt, seq_len=None, evaluation=False):
        """
        Return (inputs, next-token targets) starting at row i.
        `evaluation` is unused and only kept for call compatibility.
        """
        seq_len = min(seq_len if seq_len else bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].view(-1)
        return data, target

    def evaluate(self, seq):
        """
        Score one sequence (batch_size = 1).
        returns: (summed cross-entropy over the predicted tokens as a 0-dim
                  tensor, number of out-of-vocabulary words)
        """
        tensor, oov_token = self.seq_to_tensor(seq)
        data_source = self.batchify(tensor, 1, constant.args.cuda)
        self.model.eval()
        ntokens = len(self.word2idx)
        hidden = self.model.init_hidden(1)
        data, targets = self.get_batch(
            data_source, 0, data_source.size(0), evaluation=True)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output, _ = self.model(data, hidden)
            output_flat = output.view(-1, ntokens)  # seq_len, vocab
            # Mean loss times the number of predictions gives the summed loss.
            total_loss = len(data) * self.criterion(output_flat, targets)
        return total_loss, oov_token

    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors,
        to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    args:
        rnn_type: 'LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU'
        ntoken: vocabulary size
        ninp: embedding size
        nhid: hidden size
        nlayers: number of recurrent layers
        dropout: dropout probability (applied to embeddings, RNN outputs and
            between RNN layers)
        tie_weights: share the embedding matrix with the output projection
            (requires nhid == ninp)
    raises:
        ValueError: unknown rnn_type, or tie_weights with nhid != ninp
    """

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(
                ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh',
                                'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") from None
            self.rnn = nn.RNN(ninp, nhid, nlayers,
                              nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    'When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight
        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) embedding/decoder weights and a zero decoder bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        """
        input: (seq_len, batch) token ids
        returns: (logits of shape (seq_len, batch, ntoken), new hidden state)
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(
            output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        """Zero initial state (an (h, c) pair for LSTM) of shape (nlayers, bsz, nhid)."""
        # Matches the dtype/device of the model parameters; no gradient tracking.
        weight = next(self.parameters()).data
        shape = (self.nlayers, bsz, self.nhid)
        if self.rnn_type == 'LSTM':
            return (weight.new_zeros(shape), weight.new_zeros(shape))
        return weight.new_zeros(shape)
def _join_segments_by_language(seq):
    """Split `seq` into (english_text, chinese_text), each made of its segments joined with spaces."""
    en_seq, zh_seq = "", ""
    for segment in get_word_segments_per_language(seq):
        if is_contain_chinese_word(segment):
            zh_seq = segment if zh_seq == "" else zh_seq + " " + segment
        else:
            en_seq = segment if en_seq == "" else en_seq + " " + segment
    return en_seq, zh_seq


def calculate_cer_en_zh(s1, s2):
    """
    Character edit distances computed separately for the English and the Chinese parts.

    Both sentences are split into monolingual segments. The segments of each
    language are joined with spaces and compared with calculate_cer.
    Arguments:
        s1 (string): space-separated sentence (hyp)
        s2 (string): space-separated sentence (gold)
    Returns:
        (en_distance, zh_distance, en_gold_len, zh_gold_len): the two edit
        distances, plus the lengths (spaces included) of the English and
        Chinese gold strings, usable as CER denominators.
    """
    en_s1_seq, zh_s1_seq = _join_segments_by_language(s1)
    en_s2_seq, zh_s2_seq = _join_segments_by_language(s2)
    return (calculate_cer(en_s1_seq, en_s2_seq), calculate_cer(zh_s1_seq, zh_s2_seq),
            len(en_s2_seq), len(zh_s2_seq))
def calculate_cer(s1, s2):
    """
    Character-level Levenshtein (edit) distance between two strings.

    This is the raw edit count, not a rate; callers divide by the reference
    length. Spaces count as characters, so strip them first to ignore them.
    Arguments:
        s1 (string): hypothesis
        s2 (string): reference (gold)
    """
    distance = Lev.distance(s1, s2)
    return distance
def calculate_wer(s1, s2):
    """
    Word-level edit distance between two whitespace-tokenized sentences.

    This is the raw edit count, not a rate.
    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """
    hyp_words = s1.split()
    ref_words = s2.split()
    # Levenshtein only compares strings, so give every distinct word its own character.
    vocab = {word: idx for idx, word in enumerate(set(hyp_words + ref_words))}
    hyp_chars = "".join(chr(vocab[word]) for word in hyp_words)
    ref_chars = "".join(chr(vocab[word]) for word in ref_words)
    return Lev.distance(hyp_chars, ref_chars)
def calculate_metrics(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Compute the loss and, for cross-entropy, the number of correct tokens.
    args:
        pred: B x T x C
        gold: B x T
        input_lengths: B (for CTC)
        target_lengths: B (for CTC)
    output:
        "ce":  (loss, number of non-pad positions where argmax(pred) == gold)
        "ctc": (loss, None)
        other: prints a message and returns (None, None)
    """
    loss = calculate_loss(pred, gold, input_lengths, target_lengths, smoothing, loss_type)
    if loss_type == "ctc":
        return loss, None
    if loss_type != "ce":
        print("loss is not defined")
        return None, None
    flat_pred = pred.view(-1, pred.size(2))  # (B*T) x C
    flat_gold = gold.contiguous().view(-1)  # (B*T)
    predicted_ids = flat_pred.max(1)[1]
    non_pad_mask = flat_gold.ne(constant.PAD_TOKEN)
    num_correct = predicted_ids.eq(flat_gold).masked_select(non_pad_mask).sum().item()
    return loss, num_correct
def calculate_loss(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Calculate loss
    args:
        pred: B x T x C, unnormalized scores
        gold: B x T (padded with constant.PAD_TOKEN for "ce")
        input_lengths: B (only for ctc)
        target_lengths: B (only for ctc)
        smoothing: label-smoothing factor for "ce" (0.0 disables it)
        loss_type: "ce" | "ctc" (ctc => pytorch 1.0.0 or later)
    output:
        scalar loss tensor: averaged over non-pad tokens ("ce"), or F.ctc_loss
        with reduction="mean" ("ctc")
    raises:
        ValueError: for any other loss_type
    """
    if loss_type == "ce":
        pred = pred.view(-1, pred.size(2))  # (B*T) x C
        gold = gold.contiguous().view(-1)  # (B*T)
        if smoothing > 0.0:
            eps = smoothing
            num_class = pred.size(1)
            non_pad_mask = gold.ne(constant.PAD_TOKEN)
            # Pad positions are mapped to class 0 only so that scatter gets a
            # valid index; they are masked out of the loss below.
            gold_for_scatter = non_pad_mask.long() * gold
            one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / num_class
            log_prob = F.log_softmax(pred, dim=1)
            num_word = non_pad_mask.sum().item()
            loss = -(one_hot * log_prob).sum(dim=1)
            return loss.masked_select(non_pad_mask).sum() / num_word
        return F.cross_entropy(pred, gold, ignore_index=constant.PAD_TOKEN, reduction="mean")
    if loss_type == "ctc":
        # F.ctc_loss expects T x B x C log-probabilities; gold stays B x S.
        log_probs = F.log_softmax(pred.transpose(0, 1), dim=2)
        return F.ctc_loss(log_probs, gold, input_lengths, target_lengths, reduction="mean")
    raise ValueError("loss is not defined: {}".format(loss_type))
class NoamOpt:
    """
    Optimizer wrapper implementing the Noam learning-rate schedule:

        lr(step) = max(min_lr, factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5))

    The rate warms up linearly for `warmup` steps, then decays with the
    inverse square root of the step, floored at `min_lr`.
    """

    def __init__(self, model_size, factor, warmup, optimizer, min_lr=1e-5):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        self.min_lr = min_lr

    def step(self):
        "Advance the schedule, set the new rate on every param group, then step the optimizer."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def rate(self, step=None):
        "Learning rate at `step` (defaults to the current step count)."
        if step is None:
            step = self._step
        if step <= 0:
            # step ** -0.5 is undefined at 0; the unclamped schedule is 0 there.
            return self.min_lr
        return max(self.min_lr, self.factor *
                   (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))))
class AnnealingOpt:
    """
    Optimizer wrapper for learning-rate annealing.

    Each call to `step` divides the learning rate of the first parameter
    group by `lr_anneal`. It does not perform a parameter update.
    args:
        lr: initial learning rate (only stored)
        lr_anneal: divisor applied on every `step`
        optimizer: wrapped torch optimizer
    """

    def __init__(self, lr, lr_anneal, optimizer):
        self.optimizer = optimizer
        self.lr = lr
        self.lr_anneal = lr_anneal

    def step(self):
        state = self.optimizer.state_dict()
        first_group = state['param_groups'][0]
        first_group['lr'] /= self.lr_anneal
        self.optimizer.load_state_dict(state)
# class SGDOpt:
# "Optim wrapper that implements SGD"
# def __init__(self, parameters, lr, momentum, nesterov=True):
# self.optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, nesterov=nesterov)
class Trainer():
    """
    Trainer class: runs the training loop with per-epoch validation,
    checkpointing and optional sampler shuffling.

    NOTE(review): `tqdm` and `save_model` are used below but are neither
    imported nor defined in this file — confirm they are provided elsewhere.
    """
    def __init__(self):
        logging.info("Trainer is initialized")
    def train(self, model, train_loader, train_sampler, valid_loader_list, opt, loss_type, start_epoch, num_epochs, label2id, id2label, last_metrics=None):
        """
        Training
        args:
            model: Model object; called as model(src, src_lengths, tgt, verbose=False)
                and expected to return (pred, gold, hyp_seq, gold_seq)
            train_loader: DataLoader object of the training set; each batch unpacks to
                (src, tgt, src_percentages, src_lengths, tgt_lengths)
            train_sampler: object whose shuffle(epoch) is called after each epoch
                when constant.args.shuffle is set
            valid_loader_list: a list of Validation DataLoader objects
            opt: Optimizer object; must provide zero_grad(), step() and `_rate`
                (NoamOpt does; AnnealingOpt does not)
            loss_type: "ce" or "ctc", forwarded to calculate_metrics
            start_epoch: start epoch (> 0 if you resume the process)
            num_epochs: last epoch
            label2id: label-to-id map (only passed to save_model)
            id2label: id-to-label map used to decode sequences for CER/WER
            last_metrics: (if resume) dict providing at least 'valid_loss'
        """
        history = []
        start_time = time.time()
        best_valid_loss = 1000000000 if last_metrics is None else last_metrics['valid_loss']
        smoothing = constant.args.label_smoothing
        logging.info("name " + constant.args.name)
        for epoch in range(start_epoch, num_epochs):
            sys.stdout.flush()
            total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0
            start_iter = 0
            logging.info("TRAIN")
            model.train()
            pbar = tqdm(iter(train_loader), leave=True, total=len(train_loader))
            for i, (data) in enumerate(pbar, start=start_iter):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data
                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()
                opt.zero_grad()
                pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)
                # Decode gold/hypothesis ids to strings, stopping at the first PAD.
                try: # handle case for CTC
                    strs_gold, strs_hyps = [], []
                    for ut_gold in gold_seq:
                        str_gold = ""
                        for x in ut_gold:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_gold = str_gold + id2label[int(x)]
                        strs_gold.append(str_gold)
                    for ut_hyp in hyp_seq:
                        str_hyp = ""
                        for x in ut_hyp:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_hyp = str_hyp + id2label[int(x)]
                        strs_hyps.append(str_hyp)
                except Exception as e:
                    print(e)
                    logging.info("NaN predictions")
                    continue
                seq_length = pred.size(1)
                # Per-utterance output lengths: the fraction of the padded input times the output length.
                sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)
                loss, num_correct = calculate_metrics(
                    pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=smoothing, loss_type=loss_type)
                # NOTE(review): only +inf is detected (NaN passes through), and the
                # masked `loss` is discarded by `continue` — the batch is just skipped.
                if loss.item() == float('Inf'):
                    logging.info("Found infinity loss, masking")
                    loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
                    continue
                # if constant.args.verbose:
                # logging.info("GOLD", strs_gold)
                # logging.info("HYP", strs_hyps)
                # Accumulate raw edit distances; CER ignores spaces, WER uses word splits.
                for j in range(len(strs_hyps)):
                    strs_hyps[j] = strs_hyps[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                    strs_gold[j] = strs_gold[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                    cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
                    wer = calculate_wer(strs_hyps[j], strs_gold[j])
                    total_cer += cer
                    total_wer += wer
                    total_char += len(strs_gold[j].replace(' ', ''))
                    total_word += len(strs_gold[j].split(" "))
                loss.backward()
                if constant.args.clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), constant.args.max_norm)
                opt.step()
                total_loss += loss.item()
                non_pad_mask = gold.ne(constant.PAD_TOKEN)
                num_word = non_pad_mask.sum().item()
                # NOTE(review): raises ZeroDivisionError while total_char is still 0
                # (e.g. when every batch so far was skipped).
                pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                    (epoch+1), total_loss/(i+1), total_cer*100/total_char, opt._rate))
            logging.info("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                (epoch+1), total_loss/(len(train_loader)), total_cer*100/total_char, opt._rate))
            # evaluate
            print("")
            logging.info("VALID")
            model.eval()
            for ind in range(len(valid_loader_list)):
                valid_loader = valid_loader_list[ind]
                total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
                valid_pbar = tqdm(iter(valid_loader), leave=True, total=len(valid_loader))
                for i, (data) in enumerate(valid_pbar):
                    src, tgt, src_percentages, src_lengths, tgt_lengths = data
                    if constant.USE_CUDA:
                        src = src.cuda()
                        tgt = tgt.cuda()
                    with torch.no_grad():
                        pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)
                    seq_length = pred.size(1)
                    sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)
                    loss, num_correct = calculate_metrics(
                        pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=smoothing, loss_type=loss_type)
                    if loss.item() == float('Inf'):
                        logging.info("Found infinity loss, masking")
                        loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
                        continue
                    try: # handle case for CTC
                        strs_gold, strs_hyps = [], []
                        for ut_gold in gold_seq:
                            str_gold = ""
                            for x in ut_gold:
                                if int(x) == constant.PAD_TOKEN:
                                    break
                                str_gold = str_gold + id2label[int(x)]
                            strs_gold.append(str_gold)
                        for ut_hyp in hyp_seq:
                            str_hyp = ""
                            for x in ut_hyp:
                                if int(x) == constant.PAD_TOKEN:
                                    break
                                str_hyp = str_hyp + id2label[int(x)]
                            strs_hyps.append(str_hyp)
                    except Exception as e:
                        print(e)
                        logging.info("NaN predictions")
                        continue
                    for j in range(len(strs_hyps)):
                        strs_hyps[j] = strs_hyps[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                        strs_gold[j] = strs_gold[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                        cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
                        wer = calculate_wer(strs_hyps[j], strs_gold[j])
                        total_valid_cer += cer
                        total_valid_wer += wer
                        total_valid_char += len(strs_gold[j].replace(' ', ''))
                        total_valid_word += len(strs_gold[j].split(" "))
                    total_valid_loss += loss.item()
                    valid_pbar.set_description("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(ind,
                        total_valid_loss/(i+1), total_valid_cer*100/total_valid_char))
                logging.info("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(ind,
                    total_valid_loss/(len(valid_loader)), total_valid_cer*100/total_valid_char))
            # NOTE(review): the valid_* values below (and the best-model check) come
            # from the LAST loader in valid_loader_list only; *_cer / *_wer are raw
            # edit-distance totals, not rates.
            metrics = {}
            metrics["train_loss"] = total_loss / len(train_loader)
            metrics["valid_loss"] = total_valid_loss / (len(valid_loader))
            metrics["train_cer"] = total_cer
            metrics["train_wer"] = total_wer
            metrics["valid_cer"] = total_valid_cer
            metrics["valid_wer"] = total_valid_wer
            metrics["history"] = history
            history.append(metrics)
            if epoch % constant.args.save_every == 0:
                save_model(model, (epoch+1), opt, metrics,
                           label2id, id2label, best_model=False)
            # save the best model
            if best_valid_loss > total_valid_loss/len(valid_loader):
                best_valid_loss = total_valid_loss/len(valid_loader)
                save_model(model, (epoch+1), opt, metrics,
                           label2id, id2label, best_model=True)
            if constant.args.shuffle:
                logging.info("SHUFFLE")
                print("SHUFFLE")
                train_sampler.shuffle(epoch)
class Transformer(nn.Module):
    """
    Transformer class: optional CNN feature extractor, then encoder and decoder.
    args:
        encoder: Encoder object
        decoder: Decoder object; must provide `id2label`, `beam_search` and `greedy_search`
        feat_extractor: 'vgg_cnn' (default) or 'emb_cnn'; any other value skips the CNN
    """
    def __init__(self, encoder, decoder, feat_extractor='vgg_cnn'):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.id2label = decoder.id2label
        self.feat_extractor = feat_extractor
        # feature embedding
        if feat_extractor == 'emb_cnn':
            self.conv = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
                nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), ),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True)
            )
        elif feat_extractor == 'vgg_cnn':
            self.conv = nn.Sequential(
                nn.Conv2d(1, 64, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
                nn.Conv2d(64, 128, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2)
            )
        # Xavier-initialize every parameter with more than one dimension. This
        # includes the encoder's and decoder's parameters registered above, so it
        # overrides their own initialization.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def forward(self, padded_input, input_lengths, padded_target, verbose=False):
        """
        args:
            padded_input: B x 1 (channel for spectrogram=1) x (freq) x T when a CNN
                feature extractor is used; B x T x D otherwise
            input_lengths: B
            padded_target: B x T
        output:
            pred: B x T x vocab
            gold: B x T
            hyp_seq: B x T greedy (arg-max) ids taken from pred
            gold_seq: same object as gold
        """
        if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
            padded_input = self.conv(padded_input)
        # Reshaping features
        # Flatten channels x frequency into one feature axis, then put time second.
        sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
        padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
        padded_input = padded_input.transpose(1, 2).contiguous() # BxTxH
        encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
        pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
        hyp_best_scores, hyp_best_ids = torch.topk(pred, 1, dim=2)
        hyp_seq = hyp_best_ids.squeeze(2)
        gold_seq = gold
        return pred, gold, hyp_seq, gold_seq
    def evaluate(self, padded_input, input_lengths, padded_target, beam_search=False, beam_width=0, beam_nbest=0, lm=None, lm_rescoring=False, lm_weight=0.1, c_weight=1, verbose=False):
        """
        args:
            padded_input: B x T x D (B x 1 x freq x T when a CNN feature extractor is used)
            input_lengths: B
            padded_target: B x T
        output:
            NOTE(review): the first value is `_`, i.e. the leftover outputs of the
            decoder call (`hyp, gold, *_ = ...`), not the n-best ids; ids_hyps from
            beam search is discarded. beam_nbest is unused (nbest=1 is passed).
            batch_strs_nbest_hyps: hypothesis strings (beam search, or greedy search
                as fallback when beam search returns a different batch size)
            batch_strs_gold: gold strings (every id decoded, padding included)
        """
        if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
            padded_input = self.conv(padded_input)
        # Reshaping features
        sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
        padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
        padded_input = padded_input.transpose(1, 2).contiguous() # BxTxH
        encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
        hyp, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
        hyp_best_scores, hyp_best_ids = torch.topk(hyp, 1, dim=2)
        strs_gold = ["".join([self.id2label[int(x)] for x in gold_seq]) for gold_seq in gold]
        if beam_search:
            ids_hyps, strs_hyps = self.decoder.beam_search(encoder_padded_outputs, beam_width=beam_width, nbest=1, lm=lm, lm_rescoring=lm_rescoring, lm_weight=lm_weight, c_weight=c_weight)
            if len(strs_hyps) != sizes[0]:
                print(">>>>>>> switch to greedy")
                strs_hyps = self.decoder.greedy_search(encoder_padded_outputs)
        else:
            strs_hyps = self.decoder.greedy_search(encoder_padded_outputs)
        if verbose:
            print("GOLD", strs_gold)
            print("HYP", strs_hyps)
        return _, strs_hyps, strs_gold
class Encoder(nn.Module):
    """
    Encoder Transformer class: projects input features to dim_model, adds
    sinusoidal positional encodings and runs a stack of EncoderLayer blocks.
    args:
        num_layers: number of EncoderLayer blocks
        num_heads, dim_key, dim_value: multi-head attention sizes
        dim_model: hidden size
        dim_input: size of each incoming feature vector
        dim_inner: hidden size of the position-wise feed-forward sub-layer
        dropout: dropout rate handed to every layer
        src_max_length: number of positions in the positional-encoding table
    """
    def __init__(self, num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dropout=0.1, src_max_length=2500):
        super(Encoder, self).__init__()
        # plain hyper-parameters, kept for introspection
        self.dim_input, self.num_layers, self.num_heads = dim_input, num_layers, num_heads
        self.dim_model, self.dim_key, self.dim_value = dim_model, dim_key, dim_value
        self.dim_inner, self.src_max_length = dim_inner, src_max_length
        # NOTE(review): self.dropout is created but never applied in forward — confirm intent.
        self.dropout = nn.Dropout(dropout)
        self.dropout_rate = dropout
        # submodules, created in a fixed order (parameter order matters for optimizer checkpoints)
        self.input_linear = nn.Linear(dim_input, dim_model)
        self.layer_norm_input = nn.LayerNorm(dim_model)
        self.positional_encoding = PositionalEncoding(dim_model, src_max_length)
        stack = [EncoderLayer(num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=dropout)
                 for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, padded_input, input_lengths):
        """
        args:
            padded_input: B x T x D
            input_lengths: B
        return:
            output: B x T x H
            list with the self-attention weights of every layer
        """
        time_steps = padded_input.size(1)
        keep_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)  # B x T x 1
        attn_mask = get_attn_pad_mask(padded_input, input_lengths, time_steps)  # B x T x T
        hidden = self.layer_norm_input(self.input_linear(padded_input)) + self.positional_encoding(padded_input)
        attention_maps = []
        for encoder_layer in self.layers:
            hidden, layer_attn = encoder_layer(hidden, non_pad_mask=keep_mask, self_attn_mask=attn_mask)
            attention_maps.append(layer_attn)
        return hidden, attention_maps
class EncoderLayer(nn.Module):
    """
    Encoder Layer Transformer class: multi-head self-attention followed by a
    position-wise feed-forward sub-layer; padded positions are zeroed after each
    sub-layer when a non-pad mask is given.
    """
    def __init__(self, num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(
            num_heads, dim_model, dim_key, dim_value, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForwardWithConv(
            dim_model, dim_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, self_attn_mask=None):
        """
        args:
            enc_input: B x T x H
            non_pad_mask: optional multiplicative mask broadcastable to B x T x H
                (1 = keep, 0 = padding); no masking is applied when None
            self_attn_mask: optional attention mask forwarded to self-attention
        output:
            enc_output: B x T x H
            self_attn: attention weights returned by self-attention
        """
        enc_output, self_attn = self.self_attn(
            enc_input, enc_input, enc_input, mask=self_attn_mask)
        # The default None mask used to fail on `*=`; masking is optional now and
        # out-of-place so the attention output tensor is not modified.
        if non_pad_mask is not None:
            enc_output = enc_output * non_pad_mask
        enc_output = self.pos_ffn(enc_output)
        if non_pad_mask is not None:
            enc_output = enc_output * non_pad_mask
        return enc_output, self_attn
class Decoder(nn.Module):
    """
    Decoder Transformer class: target-token embedding plus positional encoding,
    a stack of DecoderLayer blocks and a bias-free projection onto the target
    vocabulary, with greedy and beam-search decoding.
    args:
        id2label: token id -> output string, used to turn ids into text
        num_src_vocab: stored only; not used by this class
        num_trg_vocab: target vocabulary size
        num_layers: number of DecoderLayer blocks
        num_heads, dim_key, dim_value: multi-head attention sizes
        dim_emb: embedding size (added to dim_model-sized positional encodings,
            so presumably equal to dim_model — TODO confirm)
        dim_model: hidden size
        dim_inner: hidden size of the position-wise feed-forward sub-layer
        dropout: dropout rate
        trg_max_length: number of positions in the positional-encoding table
        emb_trg_sharing: tie the output projection to the embedding matrix; the
            embeddings are then scaled by dim_model ** -0.5
    """
    def __init__(self, id2label, num_src_vocab, num_trg_vocab, num_layers, num_heads, dim_emb, dim_model, dim_inner, dim_key, dim_value, dropout=0.1, trg_max_length=1000, emb_trg_sharing=False):
        super(Decoder, self).__init__()
        self.sos_id = constant.SOS_TOKEN
        self.eos_id = constant.EOS_TOKEN
        self.id2label = id2label
        self.num_src_vocab = num_src_vocab
        self.num_trg_vocab = num_trg_vocab
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_emb = dim_emb
        self.dim_model = dim_model
        self.dim_inner = dim_inner
        self.dim_key = dim_key
        self.dim_value = dim_value
        self.dropout_rate = dropout
        self.emb_trg_sharing = emb_trg_sharing
        self.trg_max_length = trg_max_length
        self.trg_embedding = nn.Embedding(num_trg_vocab, dim_emb, padding_idx=constant.PAD_TOKEN)
        self.positional_encoding = PositionalEncoding(
            dim_model, max_length=trg_max_length)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(dim_model, dim_inner, num_heads,
                         dim_key, dim_value, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.output_linear = nn.Linear(dim_model, num_trg_vocab, bias=False)
        nn.init.xavier_normal_(self.output_linear.weight)
        if emb_trg_sharing:
            self.output_linear.weight = self.trg_embedding.weight
            self.x_logit_scale = (dim_model ** -0.5)
        else:
            self.x_logit_scale = 1.0

    def preprocess(self, padded_input):
        """
        Add SOS TOKEN and EOS TOKEN into padded_input.
        args:
            padded_input: B x T target ids padded with constant.PAD_TOKEN
        output:
            seq_in_pad: SOS + tokens, padded with EOS (decoder input)
            seq_out_pad: tokens + EOS, padded with PAD (training target)
        """
        seq = [y[y != constant.PAD_TOKEN] for y in padded_input]
        eos = seq[0].new([self.eos_id])
        sos = seq[0].new([self.sos_id])
        seq_in = [torch.cat([sos, y], dim=0) for y in seq]
        seq_out = [torch.cat([y, eos], dim=0) for y in seq]
        seq_in_pad = pad_list(seq_in, self.eos_id)
        seq_out_pad = pad_list(seq_out, constant.PAD_TOKEN)
        assert seq_in_pad.size() == seq_out_pad.size()
        return seq_in_pad, seq_out_pad

    def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths):
        """
        Teacher-forced decoding pass.
        args:
            padded_input: B x T
            encoder_padded_outputs: B x T x H
            encoder_input_lengths: B
        returns:
            pred: B x T x vocab
            gold: B x T
            decoder_self_attn_list, decoder_encoder_attn_list: per-layer attention weights
        """
        decoder_self_attn_list, decoder_encoder_attn_list = [], []
        seq_in_pad, seq_out_pad = self.preprocess(padded_input)
        # Prepare masks; the decoder input is padded with EOS, so EOS marks padding here
        non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx=constant.EOS_TOKEN)
        self_attn_mask_subseq = get_subsequent_mask(seq_in_pad)
        self_attn_mask_keypad = get_attn_key_pad_mask(
            seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx=constant.EOS_TOKEN)
        self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)
        output_length = seq_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(
            encoder_padded_outputs, encoder_input_lengths, output_length)
        decoder_output = self.dropout(self.trg_embedding(
            seq_in_pad) * self.x_logit_scale + self.positional_encoding(seq_in_pad))
        for layer in self.layers:
            decoder_output, decoder_self_attn, decoder_enc_attn = layer(
                decoder_output, encoder_padded_outputs, non_pad_mask=non_pad_mask, self_attn_mask=self_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask)
            decoder_self_attn_list += [decoder_self_attn]
            decoder_encoder_attn_list += [decoder_enc_attn]
        seq_logit = self.output_linear(decoder_output)
        pred, gold = seq_logit, seq_out_pad
        return pred, gold, decoder_self_attn_list, decoder_encoder_attn_list

    def post_process_hyp(self, hyp):
        """
        args:
            hyp: hypothesis dict whose 'yseq' is a sequence of token ids starting with SOS
        output:
            hypothesis string (the leading SOS is dropped)
        """
        return "".join([self.id2label[int(x)] for x in hyp['yseq'][1:]])

    def greedy_search(self, encoder_padded_outputs, beam_width=2, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1):
        """
        Greedy search: decode the 1-best transcription of every utterance in the batch.
        args:
            encoder_padded_outputs: B x T x H
            beam_width, lm, lm_weight, c_weight: accepted for interface compatibility; unused
            lm_rescoring: not supported here (raises NotImplementedError); use beam_search
        output:
            list of B decoded strings, each cut at the first EOS
        """
        if lm_rescoring:
            # The LM-rescoring branch that used to live here could never run: it applied
            # log_softmax/topk over the time axis instead of the vocabulary axis. Fail loudly.
            raise NotImplementedError("greedy_search does not support lm_rescoring; use beam_search instead")
        batch_size = encoder_padded_outputs.size(0)
        # Keep the 300-step cap but never run past the positional-encoding table.
        max_steps = min(300, self.trg_max_length)
        ys = torch.ones(batch_size, 1).fill_(constant.SOS_TOKEN).long()  # B x 1
        if constant.args.cuda:
            ys = ys.cuda()
        decoded_words = []  # one list of B tokens per decoding step
        for _ in range(max_steps):
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # B x t x 1
            self_attn_mask = get_subsequent_mask(ys)  # B x t x t
            decoder_output = self.dropout(self.trg_embedding(ys) * self.x_logit_scale
                                          + self.positional_encoding(ys))
            for layer in self.layers:
                decoder_output, _, _ = layer(
                    decoder_output, encoder_padded_outputs,
                    non_pad_mask=non_pad_mask,
                    self_attn_mask=self_attn_mask,
                    dec_enc_attn_mask=None
                )
            prob = self.output_linear(decoder_output)  # B x t x vocab
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([constant.EOS_CHAR if ni.item() == constant.EOS_TOKEN else self.id2label[ni.item()] for ni in next_word.view(-1)])
            ys = torch.cat([ys, next_word.unsqueeze(-1).to(ys.device)], dim=1)
        sent = []
        for row in np.transpose(decoded_words):  # steps x B -> B x steps
            st = ''
            for e in row:
                if e == constant.EOS_CHAR:
                    break
                st += e
            sent.append(st)
        return sent

    def beam_search(self, encoder_padded_outputs, beam_width=2, nbest=5, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1, prob_weight=1.0):
        """
        Beam search, decode nbest utterances for every item in the batch.
        args:
            encoder_padded_outputs: B x T x H
            beam_width: hypotheses kept per step (and candidates expanded per hypothesis)
            nbest: hypotheses returned per utterance
            lm_rescoring, lm, lm_weight: add lm_weight * LM score to finished hypotheses
            c_weight: weight of the sqrt(number of words) length bonus
            prob_weight: accepted for interface compatibility; unused
        output:
            batch_ids_nbest_hyps: flat list of token-id lists (up to nbest per utterance)
            batch_strs_nbest_hyps: flat list of decoded strings aligned with the ids
        """
        batch_size = encoder_padded_outputs.size(0)
        # Hypotheses are forced to end after min(T_enc, 300) steps, so every utterance
        # yields at least one finished hypothesis.
        max_steps = min(encoder_padded_outputs.size(1), 300)
        batch_ids_nbest_hyps = []
        batch_strs_nbest_hyps = []
        for x in range(batch_size):
            encoder_output = encoder_padded_outputs[x].unsqueeze(0)  # 1 x T x H
            # every hypothesis starts with SOS_TOKEN
            ys = torch.ones(1, 1).fill_(constant.SOS_TOKEN).type_as(encoder_output).long()
            hyps = [{'score': 0.0, 'yseq': ys}]
            ended_hyps = []
            for i in range(max_steps):
                hyps_best_kept = []
                for hyp in hyps:
                    ys = hyp['yseq']  # 1 x (i + 1)
                    non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x (i + 1) x 1
                    self_attn_mask = get_subsequent_mask(ys)
                    decoder_output = self.dropout(self.trg_embedding(ys) * self.x_logit_scale
                                                  + self.positional_encoding(ys))
                    for layer in self.layers:
                        decoder_output, _, _ = layer(
                            decoder_output, encoder_output,
                            non_pad_mask=non_pad_mask,
                            self_attn_mask=self_attn_mask,
                            dec_enc_attn_mask=None
                        )
                    seq_logit = self.output_linear(decoder_output[:, -1])
                    local_scores = F.log_softmax(seq_logit, dim=1)
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam_width, dim=1)
                    # extend this hypothesis with each of its beam_width best next tokens
                    for j in range(beam_width):
                        hyps_best_kept.append({
                            "score": hyp["score"] + local_best_scores[0, j],
                            "yseq": torch.cat([ys, local_best_ids[:, j:j + 1].to(ys.device)], dim=1),
                        })
                hyps = sorted(hyps_best_kept, key=lambda h: h["score"], reverse=True)[:beam_width]
                # force EOS_TOKEN on the last step
                if i == max_steps - 1:
                    eos = torch.ones(1, 1).fill_(constant.EOS_TOKEN).type_as(encoder_output).long()
                    for hyp in hyps:
                        hyp["yseq"] = torch.cat([hyp["yseq"], eos], dim=1)
                # move hypotheses ending in EOS_TOKEN to ended_hyps and score them
                unended_hyps = []
                for hyp in hyps:
                    if hyp["yseq"][0, -1] == constant.EOS_TOKEN:
                        if lm_rescoring:
                            # NOTE(review): the import of calculate_lm_score is commented out at the top
                            # of the file; presumably it is defined elsewhere in this file — verify.
                            hyp["lm_score"], hyp["num_words"], oov_token = calculate_lm_score(hyp["yseq"], lm, self.id2label)
                            num_words = hyp["num_words"]
                            hyp["lm_score"] -= oov_token * 2  # penalty of 2 per OOV (assuming oov_token is a count)
                            hyp["final_score"] = hyp["score"] + lm_weight * hyp["lm_score"] + math.sqrt(num_words) * c_weight
                        else:
                            seq_str = "".join(self.id2label[char.item()] for char in hyp["yseq"][0])
                            seq_str = seq_str.replace(constant.PAD_CHAR, "").replace(constant.SOS_CHAR, "").replace(constant.EOS_CHAR, "")
                            num_words = len(seq_str.split())
                            hyp["final_score"] = hyp["score"] + math.sqrt(num_words) * c_weight
                        ended_hyps.append(hyp)
                    else:
                        unended_hyps.append(hyp)
                hyps = unended_hyps
                if not hyps:
                    # decoding process is finished
                    break
            nbest_hyps = sorted(ended_hyps, key=lambda h: h["final_score"], reverse=True)[:nbest]
            for hyp in nbest_hyps:
                hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
                batch_ids_nbest_hyps.append(hyp["yseq"])
                batch_strs_nbest_hyps.append(self.post_process_hyp(hyp))
        return batch_ids_nbest_hyps, batch_strs_nbest_hyps
class DecoderLayer(nn.Module):
    """
    Decoder Transformer layer: masked self-attention, encoder-decoder attention and
    a position-wise feed-forward sub-layer; padded positions are zeroed after each
    sub-layer when a non-pad mask is given.
    """
    def __init__(self, dim_model, dim_inner, num_heads, dim_key, dim_value, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(
            num_heads, dim_model, dim_key, dim_value, dropout=dropout)
        self.encoder_attn = MultiHeadAttention(
            num_heads, dim_model, dim_key, dim_value, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForwardWithConv(
            dim_model, dim_inner, dropout=dropout)

    @staticmethod
    def _apply_mask(tensor, non_pad_mask):
        """Multiply by the non-pad mask out-of-place; no-op when the mask is None."""
        return tensor if non_pad_mask is None else tensor * non_pad_mask

    def forward(self, decoder_input, encoder_output, non_pad_mask=None, self_attn_mask=None, dec_enc_attn_mask=None):
        """
        args:
            decoder_input: B x T_dec x H
            encoder_output: B x T_enc x H
            non_pad_mask: optional multiplicative mask broadcastable to B x T_dec x H
                (1 = keep, 0 = padding); no masking is applied when None
            self_attn_mask: optional mask for self-attention
            dec_enc_attn_mask: optional mask for encoder-decoder attention
        output:
            decoder_output: B x T_dec x H
            decoder_self_attn, decoder_encoder_attn: attention weights of both attentions
        """
        decoder_output, decoder_self_attn = self.self_attn(
            decoder_input, decoder_input, decoder_input, mask=self_attn_mask)
        decoder_output = self._apply_mask(decoder_output, non_pad_mask)
        decoder_output, decoder_encoder_attn = self.encoder_attn(
            decoder_output, encoder_output, encoder_output, mask=dec_enc_attn_mask)
        decoder_output = self._apply_mask(decoder_output, non_pad_mask)
        decoder_output = self.pos_ffn(decoder_output)
        decoder_output = self._apply_mask(decoder_output, non_pad_mask)
        return decoder_output, decoder_self_attn, decoder_encoder_attn
"""
General purpose functions
"""
def pad_list(xs, pad_value, max_len=None):
    """
    Stack variable-length tensors into one batch tensor padded to a fixed length.
    # From: espnet/src/nets/e2e_asr_th.py: pad_list()
    args:
        xs: non-empty list of tensors sharing every dimension except the first
        pad_value: value written into padded positions
        max_len: padded length; defaults to constant.args.tgt_max_len
    output:
        tensor of shape len(xs) x max_len x (remaining dims of xs[0])
    raises:
        ValueError: if a sequence is longer than max_len
    """
    if max_len is None:
        max_len = constant.args.tgt_max_len
    longest = max(x.size(0) for x in xs)
    if longest > max_len:
        raise ValueError(
            "pad_list: sequence of length {} exceeds max_len={}".format(longest, max_len))
    pad = xs[0].new(len(xs), max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i, x in enumerate(xs):
        pad[i, :x.size(0)] = x
    return pad
"""
Transformer common layers
"""
def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
    """
    Return a mask that is 1 at real positions and 0 at padding, with a trailing
    singleton dimension so it broadcasts over the feature axis.

    Padding is described by ``input_lengths`` and/or ``pad_idx``:
      * input_lengths: valid length per example; positions at index >= length
        along dim 1 are zeroed. The mask has padded_input's dtype and its shape
        without the last dimension (B x T for a B x T x D input).
      * pad_idx: token id marking padding; padded_input must be 2-D (B x T)
        and the mask is float.
    When both are given, the pad_idx mask is the one returned.
    """
    assert input_lengths is not None or pad_idx is not None
    mask = None
    if input_lengths is not None:
        mask = padded_input.new_ones(padded_input.size()[:-1])
        batch = padded_input.size(0)
        for b in range(batch):
            mask[b, input_lengths[b]:] = 0
    if pad_idx is not None:
        assert padded_input.dim() == 2
        mask = (padded_input != pad_idx).float()
    return mask.unsqueeze(-1)
def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
    """
    Mask out the padding part of the key sequence.
    args:
        seq_k: B x T_K key token ids
        seq_q: B x T_Q query token ids (only T_Q is used)
        pad_idx: id treated as padding
    output:
        B x T_Q x T_K bool tensor, True at padded keys (an expanded view)
    """
    query_len = seq_q.size(1)
    key_is_pad = seq_k == pad_idx
    return key_is_pad.unsqueeze(1).expand(-1, query_len, -1)
def get_attn_pad_mask(padded_input, input_lengths, expand_length):
    """
    Attention mask that is True (masked) at padded key positions.
    args:
        padded_input: N x Ti x ... (only shape, dtype and device are used)
        input_lengths: valid length of each example
        expand_length: number of query positions to repeat the mask over
    output:
        N x expand_length x Ti bool tensor (an expanded view)
    """
    keep = get_non_pad_mask(padded_input, input_lengths=input_lengths).squeeze(-1)  # N x Ti, 1 = real
    is_padding = keep.lt(1)  # logical NOT of keep
    return is_padding.unsqueeze(1).expand(-1, expand_length, -1)
def get_subsequent_mask(seq):
    """
    Causal mask for masking out the subsequent info.
    args:
        seq: B x T tensor (only its shape and device are used)
    output:
        B x T x T uint8 tensor with 1 strictly above the diagonal (future positions)
    """
    batch, length = seq.size()
    ones = torch.ones((length, length), device=seq.device, dtype=torch.uint8)
    future = ones.triu(diagonal=1)
    return future.unsqueeze(0).expand(batch, -1, -1)
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding table:
    PE[pos, 2i] = sin(pos * 10000^(-2i/d)), PE[pos, 2i+1] = cos(pos * 10000^(-2i/d)).
    args:
        dim_model: encoding size d (odd sizes are supported)
        max_length: number of positions in the table
    """
    def __init__(self, dim_model, max_length=2000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_length, dim_model, requires_grad=False)
        position = torch.arange(0, max_length).unsqueeze(1).float()
        exp_term = torch.exp(torch.arange(0, dim_model, 2).float() * -(math.log(10000.0) / dim_model))
        pe[:, 0::2] = torch.sin(position * exp_term)  # even feature indices
        # Odd feature indices: for an odd dim_model there is one fewer odd column than
        # exp_term has entries, so trim it to fit (no-op when dim_model is even).
        pe[:, 1::2] = torch.cos(position * exp_term[: dim_model // 2])
        pe = pe.unsqueeze(0)  # 1 x max_length x dim_model
        self.register_buffer('pe', pe)

    def forward(self, input):
        """
        args:
            input: B x T x ... (only T = input.size(1) is used)
        output:
            tensor: 1 x T x dim_model, broadcastable over the batch
        """
        return self.pe[:, :input.size(1)]
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer with residual connection and post-LayerNorm:
    y = LayerNorm(x + Dropout(relu(x W1 + b1) W2 + b2))
    """
    def __init__(self, dim_model, dim_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear_1 = nn.Linear(dim_model, dim_ff)
        self.linear_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        """
        args:
            x: tensor whose last dimension is dim_model
        output:
            tensor of the same shape as x
        """
        hidden = F.relu(self.linear_1(x))
        projected = self.dropout(self.linear_2(hidden))
        return self.layer_norm(x + projected)
class PositionwiseFeedForwardWithConv(nn.Module):
    """
    Position-wise feed-forward sub-layer built from two kernel-size-1 Conv1d layers,
    with residual connection and post-LayerNorm.
    """
    def __init__(self, dim_model, dim_hidden, dropout=0.1):
        super(PositionwiseFeedForwardWithConv, self).__init__()
        self.conv_1 = nn.Conv1d(dim_model, dim_hidden, 1)
        self.conv_2 = nn.Conv1d(dim_hidden, dim_model, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        """
        args:
            x: B x T x dim_model
        output:
            B x T x dim_model
        """
        channels_first = x.transpose(1, 2)  # Conv1d expects B x C x T
        hidden = F.relu(self.conv_1(channels_first))
        restored = self.conv_2(hidden).transpose(1, 2)  # back to B x T x C
        return self.layer_norm(x + self.dropout(restored))
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with a residual connection and
    post-LayerNorm.
    args:
        num_heads: number of attention heads
        dim_model: input/output size
        dim_key: per-head query/key size
        dim_value: per-head value size
        dropout: dropout on attention weights and on the output projection
    """
    def __init__(self, num_heads, dim_model, dim_key, dim_value, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_key = dim_key
        self.dim_value = dim_value
        self.query_linear = nn.Linear(dim_model, num_heads * dim_key)
        self.key_linear = nn.Linear(dim_model, num_heads * dim_key)
        self.value_linear = nn.Linear(dim_model, num_heads * dim_value)
        qk_std = np.sqrt(2.0 / (dim_model + dim_key))
        v_std = np.sqrt(2.0 / (dim_model + dim_value))
        # initialised in the order query, key, value
        for projection, std in ((self.query_linear, qk_std), (self.key_linear, qk_std), (self.value_linear, v_std)):
            nn.init.normal_(projection.weight, mean=0, std=std)
        self.attention = ScaledDotProductAttention(temperature=np.power(dim_key, 0.5), attn_dropout=dropout)
        self.layer_norm = nn.LayerNorm(dim_model)
        self.output_linear = nn.Linear(num_heads * dim_value, dim_model)
        nn.init.xavier_normal_(self.output_linear.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        args:
            query: B x T_Q x H, key: B x T_K x H, value: B x T_V x H
            mask: optional B x T_Q x T_K attention mask, non-zero where attention is blocked
        output:
            output: B x T_Q x H (after residual + LayerNorm)
            attn: (num_heads * B) x T_Q x T_K attention weights
        """
        _, len_q, _ = query.size()
        _, len_k, _ = key.size()
        n_batch, len_v, _ = value.size()
        heads = self.num_heads
        residual = query

        def split_heads(projected, length, depth):
            # B x L x (heads * depth) -> (heads * B) x L x depth, head-major
            per_head = projected.view(n_batch, length, heads, depth)
            return per_head.permute(2, 0, 1, 3).contiguous().view(-1, length, depth)

        q = split_heads(self.query_linear(query), len_q, self.dim_key)
        k = split_heads(self.key_linear(key), len_k, self.dim_key)
        v = split_heads(self.value_linear(value), len_v, self.dim_value)
        if mask is not None:
            # head-major layout: the mask for (head h, batch b) sits at index h * B + b
            mask = mask.repeat(heads, 1, 1)
        context, attn = self.attention(q, k, v, mask=mask)
        context = context.view(heads, n_batch, len_q, self.dim_value)
        context = context.permute(1, 2, 0, 3).contiguous().view(n_batch, len_q, -1)  # B x T_Q x (heads * H_V)
        projected = self.dropout(self.output_linear(context))
        return self.layer_norm(projected + residual), attn
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention: softmax(q k^T / temperature) v '''
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """
        args:
            q: N x T_Q x D_K, k: N x T_K x D_K, v: N x T_K x D_V
            mask: optional N x T_Q x T_K; non-zero/True entries are excluded from attention
        output:
            output: N x T_Q x D_V
            attn: N x T_Q x T_K attention weights (after dropout)
        """
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            # masked_fill needs a bool mask in current PyTorch, while get_subsequent_mask
            # produces uint8; converting keeps the same 0/1 meaning for either dtype.
            attn = attn.masked_fill(mask.bool(), -np.inf)
        attn = self.dropout(self.softmax(attn))
        output = torch.bmm(attn, v)
        return output, attn
"""
LAS common layers
"""
class DotProductAttention(nn.Module):
    """
    Dot product attention.
    Given a set of vector values, and a vector query, attention is a technique
    to compute a weighted sum of the values, dependent on the query.
    NOTE: Here we use the terminology in Stanford cs224n-2018-lecture11.
    """
    def __init__(self):
        super(DotProductAttention, self).__init__()

    def forward(self, queries, values):
        """
        Args:
            queries: N x To x H
            values : N x Ti x H
        Returns:
            output: N x To x H
            attention_distribution: N x To x Ti (softmax over Ti)
        """
        # (N, To, H) * (N, H, Ti) -> (N, To, Ti)
        attention_scores = torch.bmm(queries, values.transpose(1, 2))
        attention_distribution = F.softmax(attention_scores, dim=-1)
        # (N, To, Ti) * (N, Ti, H) -> (N, To, H)
        attention_output = torch.bmm(attention_distribution, values)
        return attention_output, attention_distribution
def save_model(model, epoch, opt, metrics, label2id, id2label, best_model=False):
    """
    Saving model, TODO adding history.

    Writes vocabulary, run arguments, epoch, model/optimizer state and metrics to
    ``<save_folder>/<name>/best_model.th`` (best_model=True) or
    ``<save_folder>/<name>/epoch_<epoch>.th``.
    raises:
        ValueError: if constant.args.loss is neither "ce" nor "ctc"
    """
    loss = constant.args.loss
    # Which scheduler fields are stored depends on the loss (and thus the optimizer wrapper).
    if loss == "ce":
        optimizer_params = {
            '_step': opt._step,
            '_rate': opt._rate,
            'warmup': opt.warmup,
            'factor': opt.factor,
            'model_size': opt.model_size
        }
    elif loss == "ctc":
        optimizer_params = {
            'lr': opt.lr,
            'lr_anneal': opt.lr_anneal
        }
    else:
        # Fail before touching the filesystem: there is no checkpoint layout for other losses.
        raise ValueError("Loss is not defined: {}".format(loss))

    save_dir = "{}/{}".format(constant.args.save_folder, constant.args.name)
    if best_model:
        save_path = "{}/best_model.th".format(save_dir)
    else:
        save_path = "{}/epoch_{}.th".format(save_dir, epoch)
    os.makedirs(save_dir, exist_ok=True)
    print("SAVE MODEL to", save_path)
    checkpoint = {
        'label2id': label2id,
        'id2label': id2label,
        'args': constant.args,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.optimizer.state_dict(),
        'optimizer_params': optimizer_params,
        'metrics': metrics
    }
    torch.save(checkpoint, save_path)
def load_model(load_path):
    """
    Loading model: rebuild a model and its optimizer from a checkpoint written by save_model.
    args:
        load_path: string
    output:
        (model, opt, epoch, metrics, args, label2id, id2label)
    raises:
        ValueError: if the checkpoint has no saved 'args'
    """
    # Load onto CPU first so GPU checkpoints also open on CPU-only machines; the model is
    # moved to GPU below when args.cuda is set.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects the
    # pickled args namespace; pass weights_only=False for trusted checkpoints if needed.
    checkpoint = torch.load(load_path, map_location="cpu")
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    if 'args' not in checkpoint:
        raise ValueError("checkpoint {} has no 'args'; cannot rebuild the model".format(load_path))
    args = checkpoint['args']
    label2id = checkpoint['label2id']
    id2label = checkpoint['id2label']
    model = init_transformer_model(args, label2id, id2label)
    model.load_state_dict(checkpoint['model_state_dict'])
    if args.cuda:
        model = model.cuda()
    opt = init_optimizer(args, model)
    if opt is not None:
        opt.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # optimizer_params was written according to the loss of the saved run, so decode it
        # with the saved args (falling back to the current run's setting).
        loss = getattr(args, 'loss', constant.args.loss)
        params = checkpoint['optimizer_params']
        if loss == "ce":
            opt._step = params['_step']
            opt._rate = params['_rate']
            opt.warmup = params['warmup']
            opt.factor = params['factor']
            opt.model_size = params['model_size']
        elif loss == "ctc":
            opt.lr = params['lr']
            opt.lr_anneal = params['lr_anneal']
        else:
            print("Need to define loss type")
    return model, opt, epoch, metrics, args, label2id, id2label
def init_optimizer(args, model, opt_type="noam"):
    """
    Build the training optimizer.
    args:
        args: config providing dim_input, warmup and lr (always read), plus k_lr and
            min_lr for "noam", or lr_anneal and momentum for "sgd"
        model: module whose parameters are optimised
        opt_type: "noam" (Adam wrapped in NoamOpt) or "sgd" (Nesterov SGD wrapped in AnnealingOpt)
    output:
        the wrapped optimizer, or None (after printing a message) for any other opt_type
    """
    # NOTE(review): NoamOpt receives dim_input (not dim_model) as its model size — confirm intent.
    model_size, warmup_steps, base_lr = args.dim_input, args.warmup, args.lr
    if opt_type == "noam":
        return NoamOpt(model_size, args.k_lr, warmup_steps,
                       torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
                       min_lr=args.min_lr)
    if opt_type == "sgd":
        return AnnealingOpt(base_lr, args.lr_anneal,
                            torch.optim.SGD(model.parameters(), lr=base_lr, momentum=args.momentum, nesterov=True))
    print("Optimizer is not defined")
    return None
def init_transformer_model(args, label2id, id2label):
    """
    Initiate a new transformer object from the hyper-parameters in ``args``.

    Side effect: when args.feat_extractor is 'emb_cnn' or 'vgg_cnn',
    args.dim_input is overwritten with the flattened CNN output size
    (output channels x remaining frequency bins).
    args:
        args: configuration namespace
        label2id: token -> id mapping (its length is the vocabulary size)
        id2label: id -> token mapping passed to the decoder
    output:
        Transformer, wrapped in nn.DataParallel when args.parallel is set
    """
    if args.feat_extractor == 'emb_cnn':
        hidden_size = int(math.floor(
            (args.sample_rate * args.window_size) / 2) + 1)  # frequency bins of the spectrogram
        hidden_size = int(math.floor(hidden_size - 41) / 2 + 1)  # first conv: freq kernel 41, stride 2, no freq padding
        hidden_size = int(math.floor(hidden_size - 21) / 2 + 1)  # second conv: freq kernel 21, stride 2
        hidden_size *= 32  # output channels
        args.dim_input = hidden_size
    elif args.feat_extractor == 'vgg_cnn':
        hidden_size = int(math.floor((args.sample_rate * args.window_size) / 2) + 1)  # 161
        hidden_size = int(math.floor(int(math.floor(hidden_size) / 2) / 2)) * 128  # two 2x max-pools, 128 channels
        args.dim_input = hidden_size
    else:
        print("the model is initialized without feature extractor")
    num_layers = args.num_layers
    num_heads = args.num_heads
    dim_model = args.dim_model
    dim_key = args.dim_key
    dim_value = args.dim_value
    dim_input = args.dim_input
    dim_inner = args.dim_inner
    dim_emb = args.dim_emb
    src_max_len = args.src_max_len
    tgt_max_len = args.tgt_max_len
    dropout = args.dropout
    emb_trg_sharing = args.emb_trg_sharing
    feat_extractor = args.feat_extractor
    encoder = Encoder(num_layers, num_heads=num_heads, dim_model=dim_model, dim_key=dim_key,
                      dim_value=dim_value, dim_input=dim_input, dim_inner=dim_inner, src_max_length=src_max_len, dropout=dropout)
    decoder = Decoder(id2label, num_src_vocab=len(label2id), num_trg_vocab=len(label2id), num_layers=num_layers, num_heads=num_heads,
                      dim_emb=dim_emb, dim_model=dim_model, dim_inner=dim_inner, dim_key=dim_key, dim_value=dim_value, trg_max_length=tgt_max_len, dropout=dropout, emb_trg_sharing=emb_trg_sharing)
    model = Transformer(encoder, decoder, feat_extractor=feat_extractor)
    if args.parallel:
        # NOTE(review): device ids come from the current run (constant.args), not from the
        # possibly-loaded checkpoint args — presumably so a checkpoint runs on this machine's GPUs.
        if constant.args.device_ids:
            print("load with device_ids", constant.args.device_ids)
            model = nn.DataParallel(model, device_ids=constant.args.device_ids)
        else:
            model = nn.DataParallel(model)
    return model
def load_audio(path):
    """
    Load an audio file as a float numpy array, averaging multiple channels to mono.
    args:
        path: audio file path
    output:
        numpy array of samples
    """
    try:
        sound, _ = torchaudio.load(path, normalize=True)
    except TypeError:
        # Older torchaudio releases name this keyword `normalization`.
        sound, _ = torchaudio.load(path, normalization=True)
    sound = sound.numpy().T  # torchaudio returns channels x frames; transpose to frames x channels
    if len(sound.shape) > 1:
        if sound.shape[1] == 1:
            sound = sound.squeeze()
        else:
            sound = sound.mean(axis=1)  # multiple channels, average
    return sound
def get_audio_length(path):
    """
    Return the duration of an audio file in seconds, as reported by ``soxi -D``.
    raises:
        subprocess.CalledProcessError: if soxi fails
    """
    # The path is passed as its own argv entry (no shell), so quotes or shell
    # metacharacters in file names cannot break or hijack the command.
    output = subprocess.check_output(['soxi', '-D', path.strip()])
    return float(output)
def audio_with_sox(path, sample_rate, start_time, end_time):
    """
    Crop and resample the recording with sox and load it.

    The segment from start_time to end_time (sox ``trim`` with an absolute end
    position) is written as 16-bit signed mono at ``sample_rate`` to a temporary
    wav file, which is then loaded with load_audio.
    raises:
        subprocess.CalledProcessError: if sox fails
    """
    with NamedTemporaryFile(suffix=".wav") as tar_file:
        tar_filename = tar_file.name
        # argv list, no shell: the path cannot inject shell syntax.
        sox_params = ["sox", path, "-r", "{}".format(sample_rate), "-c", "1", "-b", "16", "-e", "si",
                      tar_filename, "trim", "{}".format(start_time), "={}".format(end_time)]
        subprocess.run(sox_params, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        y = load_audio(tar_filename)
        return y
def augment_audio_with_sox(path, sample_rate, tempo, gain):
    """
    Changes tempo and gain of the recording with sox and loads it.

    The result is written as 16-bit signed mono at ``sample_rate`` to a temporary
    wav file and loaded with load_audio.
    raises:
        subprocess.CalledProcessError: if sox fails
    """
    with NamedTemporaryFile(suffix=".wav") as augmented_file:
        augmented_filename = augmented_file.name
        sox_augment_params = ["tempo", "{:.3f}".format(tempo), "gain", "{:.3f}".format(gain)]
        # argv list, no shell: the path cannot inject shell syntax.
        sox_params = ["sox", path, "-r", "{}".format(sample_rate), "-c", "1", "-b", "16", "-e", "si",
                      augmented_filename] + sox_augment_params
        subprocess.run(sox_params, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        y = load_audio(augmented_filename)
        return y
def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), gain_range=(-6, 8)):
    """
    Draw a tempo factor and a gain uniformly from the given (low, high) ranges,
    apply them to the utterance with sox and return the augmented audio.
    """
    tempo_low, tempo_high = tempo_range
    chosen_tempo = np.random.uniform(low=tempo_low, high=tempo_high)
    gain_low, gain_high = gain_range
    chosen_gain = np.random.uniform(low=gain_low, high=gain_high)
    return augment_audio_with_sox(path=path, sample_rate=sample_rate,
                                  tempo=chosen_tempo, gain=chosen_gain)
class AudioParser:
    """
    Interface for turning manifest entries into model inputs and targets.
    Subclasses override both methods; the base versions raise NotImplementedError.
    """
    def parse_transcript(self, transcript_path):
        """
        :param transcript_path: path of the transcript file listed in the manifest
        :return: the transcript in the form used for training/testing
        """
        raise NotImplementedError

    def parse_audio(self, audio_path):
        """
        :param audio_path: path of the audio file listed in the manifest
        :return: the audio in the form used for training/testing
        """
        raise NotImplementedError
class SpectrogramParser(AudioParser):
    def __init__(self, audio_conf, normalize=False, augment=False):
        """
        Parses audio file into spectrogram with optional normalization and various augmentations
        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds;
            optional 'noise_dir', 'noise_levels' and 'noise_prob' enable noise injection
        :param normalize(default False): Apply standard mean and deviation normalization to audio tensor
        :param augment(default False): Apply random tempo and gain perturbations
        """
        super(SpectrogramParser, self).__init__()
        self.window_stride = audio_conf['window_stride']
        self.window_size = audio_conf['window_size']
        self.sample_rate = audio_conf['sample_rate']
        self.window = windows.get(audio_conf['window'], windows['hamming'])
        self.normalize = normalize
        self.augment = augment
        self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], self.sample_rate,
                                            audio_conf['noise_levels']) if audio_conf.get(
            'noise_dir') is not None else None
        self.noise_prob = audio_conf.get('noise_prob')

    def parse_audio(self, audio_path):
        """
        Load an audio file and return its log(1 + magnitude) STFT spectrogram as a FloatTensor.
        Optionally applies random tempo/gain augmentation, noise injection (with
        probability noise_prob) and per-utterance mean/std normalisation.
        """
        import logging  # the module-level `import logging` is commented out in this file
        if self.augment:
            y = load_randomly_augmented_audio(audio_path, self.sample_rate)
        else:
            y = load_audio(audio_path)
        if self.noiseInjector:
            logging.info("inject noise")
            add_noise = np.random.binomial(1, self.noise_prob)
            if add_noise:
                y = self.noiseInjector.inject_noise(y)
        n_fft = int(self.sample_rate * self.window_size)
        win_length = n_fft
        hop_length = int(self.sample_rate * self.window_stride)
        # Short-time Fourier transform (STFT)
        D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                         win_length=win_length, window=self.window)
        spect, phase = librosa.magphase(D)
        # S = log(S+1)
        spect = np.log1p(spect)
        spect = torch.FloatTensor(spect)
        if self.normalize:
            mean = spect.mean()
            std = spect.std()
            spect.add_(-mean)
            # A constant (e.g. silent) spectrogram has zero std; dividing would yield NaN.
            if std > 0:
                spect.div_(std)
        return spect

    def parse_transcript(self, transcript_path):
        raise NotImplementedError
class SpectrogramDataset(Dataset, SpectrogramParser):
    def __init__(self, audio_conf, manifest_filepath_list, label2id, normalize=False, augment=False):
        """
        Dataset that loads tensors via csv manifests containing file paths to audio files and
        transcripts separated by a comma. Each non-empty line is a different sample. Example below:
        /path/to/audio.wav,/path/to/audio.txt
        ...
        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
        :param manifest_filepath_list: list of manifest csv paths as described above
        :param label2id: dict mapping each character to its token id
        :param normalize: Apply standard mean and deviation normalization to audio tensor
        :param augment(default False): Apply random tempo and gain perturbations
        """
        self.max_size = 0
        self.ids_list = []
        for manifest_filepath in manifest_filepath_list:
            # utf8 to match parse_transcript; blank lines (e.g. a trailing newline) have no
            # transcript column and would otherwise raise IndexError when sampled
            with open(manifest_filepath, encoding='utf8') as f:
                ids = [line.strip().split(',') for line in f if line.strip()]
            self.ids_list.append(ids)
            self.max_size = max(len(ids), self.max_size)
        self.manifest_filepath_list = manifest_filepath_list
        self.label2id = label2id
        super(SpectrogramDataset, self).__init__(
            audio_conf, normalize, augment)

    def __getitem__(self, index):
        """
        Return (spectrogram, transcript ids) for ``index``.

        A manifest is chosen uniformly at random on every call and ``index`` is
        wrapped modulo its length, so the same index can yield different samples.
        The spectrogram is truncated to constant.args.src_max_len along its second axis.
        """
        random_id = random.randint(0, len(self.ids_list) - 1)
        ids = self.ids_list[random_id]
        sample = ids[index % len(ids)]
        audio_path, transcript_path = sample[0], sample[1]
        spect = self.parse_audio(audio_path)[:, :constant.args.src_max_len]
        transcript = self.parse_transcript(transcript_path)
        return spect, transcript

    def parse_transcript(self, transcript_path):
        """
        Read a transcript, drop newlines, lowercase it, wrap it in SOS/EOS characters and
        map each character to its id; characters missing from label2id are dropped.
        """
        with open(transcript_path, 'r', encoding='utf8') as transcript_file:
            transcript = constant.SOS_CHAR + transcript_file.read().replace('\n', '').lower() + constant.EOS_CHAR
        # Drop only unknown characters (None); filter(None, ...) would also drop token id 0.
        return [token_id for token_id in (self.label2id.get(x) for x in transcript) if token_id is not None]

    def __len__(self):
        """Number of samples in the largest manifest."""
        return self.max_size
class NoiseInjection(object):
    def __init__(self,
                 path=None,
                 sample_rate=16000,
                 noise_levels=(0, 0.5)):
        """
        Adds noise to an input signal with specific SNR. Higher the noise level, the more noise added.
        Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py

        :param path: Directory whose files (found with librosa.util.find_files) are used as
            noise sources, or None to disable injection (inject_noise then returns its input)
        :param sample_rate: Sample rate used when reading noise segments
        :param noise_levels: (low, high) range from which the noise level is drawn uniformly
        :raises IOError: if ``path`` is given but does not exist
        """
        if path is not None and not os.path.exists(path):
            raise IOError("Directory doesn't exist: {}".format(path))
        self.paths = librosa.util.find_files(path) if path is not None else []
        self.sample_rate = sample_rate
        self.noise_levels = noise_levels

    def inject_noise(self, data):
        """
        Add a randomly chosen noise file at a random level to ``data``.

        Returns ``data`` unchanged when no noise files are available.
        """
        if len(self.paths) == 0:
            return data
        noise_path = np.random.choice(self.paths)
        noise_level = np.random.uniform(*self.noise_levels)
        return self.inject_noise_sample(data, noise_path, noise_level)

    def inject_noise_sample(self, data, noise_path, noise_level):
        """
        Mix a random segment of ``noise_path`` into ``data`` in place and return it.

        A segment as long as ``data`` is read from a random offset of the noise file
        and scaled so its RMS is ``noise_level`` times the RMS of ``data``.

        :param data: Signal sampled at ``self.sample_rate`` (assumed a 1-D numpy array)
        :param noise_path: Path of the noise audio file
        :param noise_level: Ratio of noise RMS to signal RMS
        :return: ``data`` with noise added (the same object, modified in place)
        :raises ValueError: if the noise segment length differs from ``len(data)``
        """
        noise_len = get_audio_length(noise_path)
        data_len = len(data) / self.sample_rate
        # NOTE(review): if the noise file is shorter than ``data`` this start is
        # negative; the outcome then depends on audio_with_sox.
        noise_start = np.random.rand() * (noise_len - data_len)
        noise_end = noise_start + data_len
        noise_dst = audio_with_sox(
            noise_path, self.sample_rate, noise_start, noise_end)
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(data) != len(noise_dst):
            raise ValueError("Noise segment length {} does not match signal length {}"
                             .format(len(noise_dst), len(data)))
        noise_energy = np.sqrt(noise_dst.dot(noise_dst) / noise_dst.size)
        if noise_energy == 0:
            # A silent segment adds nothing; scaling it would divide by zero (NaN).
            return data
        data_energy = np.sqrt(data.dot(data) / data.size)
        data += noise_level * noise_dst * data_energy / noise_energy
        return data
def _collate_fn(batch):
def func(p):
return p[0].size(1)
def func_tgt(p):
return len(p[1])
# descending sorted
batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
max_seq_len = max(batch, key=func)[0].size(1)
freq_size = max(batch, key=func)[0].size(0)
max_tgt_len = len(max(batch, key=func_tgt)[1])
inputs = torch.zeros(len(batch), 1, freq_size, max_seq_len)
input_sizes = torch.IntTensor(len(batch))
input_percentages = torch.FloatTensor(len(batch))
targets = torch.zeros(len(batch), max_tgt_len).long()
target_sizes = torch.IntTensor(len(batch))
for x in range(len(batch)):
sample = batch[x]
input_data = sample[0]
target = sample[1]
seq_length = input_data.size(1)
input_sizes[x] = seq_length
inputs[x][0].narrow(1, 0, seq_length).copy_(input_data)
input_percentages[x] = seq_length / float(max_seq_len)
target_sizes[x] = len(target)
targets[x][:len(target)] = torch.IntTensor(target)
return inputs, targets, input_percentages, input_sizes, target_sizes
class AudioDataLoader(DataLoader):
    """
    DataLoader that always batches samples with ``_collate_fn``.

    Every argument is forwarded unchanged to ``torch.utils.data.DataLoader``.
    Any ``collate_fn`` passed there is overridden once construction finishes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace whatever collate function the base constructor installed.
        self.collate_fn = _collate_fn
class BucketingSampler(Sampler):
    def __init__(self, data_source, batch_size=1):
        """
        Batch sampler that cuts ``range(len(data_source))`` into consecutive buckets of
        ``batch_size`` indices (the last bucket may be shorter). If the data source is
        ordered by size, each bucket groups similarly sized samples.
        """
        super(BucketingSampler, self).__init__(data_source)
        self.data_source = data_source
        total = len(data_source)
        self.bins = [list(range(begin, min(begin + batch_size, total)))
                     for begin in range(0, total, batch_size)]

    def __iter__(self):
        # Each bucket is shuffled in place right before it is handed out.
        for bucket in self.bins:
            np.random.shuffle(bucket)
            yield bucket

    def __len__(self):
        """Number of buckets, i.e. batches per epoch."""
        return len(self.bins)

    def shuffle(self, epoch):
        """Shuffle the order of the buckets in place (``epoch`` is not used)."""
        np.random.shuffle(self.bins)
if __name__ == '__main__':
    # Training entry point: builds the data loaders from the manifests in
    # constant.args, creates or restores a model, then runs Trainer.train.
    import logging

    args = constant.args
    print("="*50)
    print("THE EXPERIMENT LOG IS SAVED IN: " + "log/" + args.name)
    print("TRAINING MANIFEST: ", args.train_manifest_list)
    print("VALID MANIFEST: ", args.valid_manifest_list)
    print("TEST MANIFEST: ", args.test_manifest_list)
    print("="*50)

    os.makedirs("./log", exist_ok=True)
    logging.basicConfig(filename="log/" + args.name, filemode='w+', format='%(asctime)s - %(message)s', level=logging.INFO)

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))
    logging.info(audio_conf)

    # The label file is a JSON list of strings that may contain non-ASCII
    # (e.g. Chinese) characters, so decode it as UTF-8 on every platform.
    with open(args.labels_path, encoding='utf-8') as label_file:
        labels = ''.join(json.load(label_file))
    # add PAD_CHAR, SOS_CHAR, EOS_CHAR; PAD gets id 0, the value _collate_fn pads targets with
    labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels
    label2id, id2label = {}, {}
    for label in labels:
        if label in label2id:
            print("multiple label: ", label)
            continue
        label_id = len(label2id)  # ids are assigned densely in first-seen order
        label2id[label] = label_id
        id2label[label_id] = label

    start_epoch = 0
    metrics = None
    loaded_args = None
    if constant.args.continue_from != "":
        logging.info("Continue from checkpoint: " + constant.args.continue_from)
        # Load the checkpoint BEFORE building the datasets: its label2id is the
        # model's vocabulary and must be the map used to encode transcripts.
        model, opt, epoch, metrics, loaded_args, label2id, id2label = load_model(
            constant.args.continue_from)
        start_epoch = epoch  # index starts from zero

    train_data = SpectrogramDataset(audio_conf, manifest_filepath_list=args.train_manifest_list, label2id=label2id, normalize=True, augment=args.augment)
    train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
    train_loader = AudioDataLoader(
        train_data, num_workers=args.num_workers, batch_sampler=train_sampler)

    valid_loader_list, test_loader_list = [], []
    for valid_manifest in args.valid_manifest_list:
        valid_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[valid_manifest], label2id=label2id,
                                        normalize=True, augment=False)
        valid_loader = AudioDataLoader(valid_data, num_workers=args.num_workers, batch_size=args.batch_size)
        valid_loader_list.append(valid_loader)
    for test_manifest in args.test_manifest_list:
        test_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[test_manifest], label2id=label2id,
                                       normalize=True, augment=False)
        test_loader = AudioDataLoader(test_data, num_workers=args.num_workers)
        test_loader_list.append(test_loader)

    if loaded_args is not None:
        # Unwrap nn.DataParallel so the bare model can be re-wrapped (or not)
        # according to the current --parallel setting.
        if loaded_args.parallel:
            logging.info("unwrap from DataParallel")
            model = model.module
    elif constant.args.model == "TRFS":
        model = init_transformer_model(constant.args, label2id, id2label)
        opt = init_optimizer(constant.args, model, "noam")
    else:
        # Without a model there is nothing to train; stop instead of failing
        # later with a NameError on ``model``.
        logging.info("The model is not supported, check args --h")
        sys.exit("The model is not supported, check args --h")

    # Parallelize the batch (applies to restored and freshly built models alike)
    if args.parallel:
        model = nn.DataParallel(model, device_ids=args.device_ids)

    loss_type = args.loss

    if constant.USE_CUDA:
        model = model.cuda(0)

    logging.info(model)
    num_epochs = constant.args.epochs

    trainer = Trainer()
    trainer.train(model, train_loader, train_sampler, valid_loader_list, opt, loss_type, start_epoch, num_epochs, label2id, id2label, metrics)
# 参考：【1】End2End-ASR-Pytorch - 深度学习 - Hello Mat - Powered by Discuz! (halcom.cn)
# Copyright © 2003-2013 www.wpsshop.cn 版权所有，并保留所有权利。