
Natural Language Processing (NLP): Feedforward Networks


I. Overview of Feedforward Networks

1. Definition

A feedforward neural network (FNN) is the simplest neural network architecture: information flows in one direction, from the input layer to the output layer, with no feedback loops. The network consists of neurons organized into layers, typically an input layer, one or more hidden layers, and an output layer.
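
As a minimal sketch of this structure in PyTorch (the layer sizes here are made up purely for illustration), a two-layer feedforward network is just two linear maps with a nonlinearity in between:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyFeedforward(nn.Module):
    """Input -> hidden -> output; information flows strictly forward."""
    def __init__(self, input_dim=16, hidden_dim=8, output_dim=3):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)   # input layer -> hidden layer
        self.output = nn.Linear(hidden_dim, output_dim)  # hidden layer -> output layer

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))       # no feedback connections

print(TinyFeedforward()(torch.rand(2, 16)).shape)        # torch.Size([2, 3])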

2. Characteristics

Simple and intuitive: the structure of a feedforward network is straightforward, and information flows in only one direction, which makes it easy to understand and implement.

Strong performance: despite their simplicity, feedforward networks perform well on many tasks, particularly when handling large-scale data and complex patterns.

Scalable: feedforward networks can easily be made deeper or wider to suit different tasks and datasets.

Parallelizable: because the neurons within a layer compute independently of one another, feedforward networks lend themselves to efficient parallel acceleration.

Broadly applicable: feedforward networks are used in image recognition, speech recognition, text classification, regression analysis, and many other areas, and adapt well to different problems.

Easy to train: feedforward networks are usually trained with backpropagation and gradient descent, which minimize a loss function by adjusting the network's weights and biases to improve performance (a one-step sketch follows this list).
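
As the sketch promised in the last item, here is a single training step on toy data (all dimensions and values are made up): compute a loss, backpropagate to get gradients, and let the optimizer adjust the weights and biases.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.rand(32, 4), torch.rand(32, 1)              # toy inputs and targets

loss = nn.functional.mse_loss(model(x), y)               # forward pass + loss
optimizer.zero_grad()
loss.backward()                                          # backpropagation: compute gradients
optimizer.step()                                         # adjust weights and biases
print(loss.item(), nn.functional.mse_loss(model(x), y).item())  # loss usually drops slightly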

II. Feedforward Neural Networks for Natural Language Processing

Feedforward networks are structurally simple, easy to implement and train, and able to handle high-dimensional inputs. Because text data in NLP is typically high-dimensional once encoded as vectors, feedforward networks are well suited to many NLP tasks. They are now widely used across natural language processing, including but not limited to the following:

  1. Text classification: feedforward networks are used for tasks such as sentiment analysis, spam filtering, and news categorization. Once the input text has been encoded as a feature vector, the network can assign it to a category automatically (a minimal featurization sketch follows this list).

  2. Language modeling: feedforward networks can be used to build language models for natural language generation and understanding. Trained on text, such a model learns patterns and relationships between tokens in a sequence, which can be used to generate fluent text or to score and interpret it.

  3. Sequence labeling: feedforward networks are also widely applied to sequence labeling tasks such as part-of-speech tagging and named entity recognition, where each element of the input sequence is assigned a label.

  4. Machine translation: feedforward networks also appear in machine translation, for example as components of encoder-decoder architectures that map text in one language into another.

  5. Text generation: feedforward networks can further be used for text generation tasks such as dialogue generation and summarization, by learning the patterns and relationships in text sequences.
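
As referenced in item 1, here is a minimal sketch of turning raw text into a fixed-size vector that a feedforward network can consume; the toy vocabulary and sentence are invented for illustration. The surname classifier later in this post applies the same idea at the character level.

import numpy as np

vocab = {"good": 0, "bad": 1, "movie": 2, "plot": 3}      # toy vocabulary

def bag_of_words(text, vocab):
    """Bag-of-words vector: one slot per vocabulary word, counting occurrences."""
    vec = np.zeros(len(vocab), dtype=np.float32)
    for word in text.lower().split():
        if word in vocab:
            vec[vocab[word]] += 1
    return vec

print(bag_of_words("good movie good plot", vocab))         # [2. 0. 1. 1.]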

III. A Simple Implementation in PyTorch

Task description: we apply a multilayer perceptron (MLP) to the task of classifying surnames by their country of origin.

Implementation outline:

1. Data preprocessing

2. Build the multilayer perceptron model

3. Train the model

4. Make predictions

The code for each step follows.

1. Data Preprocessing

The dataset, surname.csv, collects 10,000 surnames from 18 different countries, gathered from several name sources on the internet. Preprocessing has two goals. The first is to rebalance the proportions of the 18 nationalities in the dataset, since a more even class distribution helps train an effective model. The second is to split the data into three parts, 70% for training, 15% for validation, and 15% for testing, so that the class label distribution is comparable across the splits. The training code later in this post loads surnames_with_splits.csv, a version of the dataset in which this split column has already been added.
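
The following sketch shows one way such a per-nationality 70/15/15 split column could be produced with pandas; the column names surname and nationality match those used by the code below, but the exact balancing procedure used for the original dataset is an assumption.

import numpy as np
import pandas as pd

raw_df = pd.read_csv("surname.csv")                       # columns assumed: surname, nationality
np.random.seed(1337)

split_frames = []
for _, group in raw_df.groupby("nationality"):
    group = group.sample(frac=1.0, random_state=1337)     # shuffle within each nationality
    n = len(group)
    n_train, n_val = int(0.70 * n), int(0.15 * n)
    group = group.copy()
    group["split"] = (["train"] * n_train + ["val"] * n_val
                      + ["test"] * (n - n_train - n_val))
    split_frames.append(group)

pd.concat(split_frames).to_csv("surnames_with_splits.csv", index=False)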

from argparse import Namespace
from collections import Counter
import json
import os
import string

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook


class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx=None, add_unk=True, unk_token="<UNK>"):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices
            add_unk (bool): a flag that indicates whether to add the UNK token
            unk_token (str): the UNK token to add into the Vocabulary
        """
        if token_to_idx is None:
            token_to_idx = {}
        self._token_to_idx = token_to_idx

        self._idx_to_token = {idx: token
                              for token, idx in self._token_to_idx.items()}

        self._add_unk = add_unk
        self._unk_token = unk_token

        self.unk_index = -1
        if add_unk:
            self.unk_index = self.add_token(unk_token)

    def to_serializable(self):
        """ returns a dictionary that can be serialized """
        return {'token_to_idx': self._token_to_idx,
                'add_unk': self._add_unk,
                'unk_token': self._unk_token}

    @classmethod
    def from_serializable(cls, contents):
        """ instantiates the Vocabulary from a serialized dictionary """
        return cls(**contents)

    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        try:
            index = self._token_to_idx[token]
        except KeyError:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary

        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token
          or the UNK index if token isn't present.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary)
              for the UNK functionality
        """
        if self.unk_index >= 0:
            return self._token_to_idx.get(token, self.unk_index)
        else:
            return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index

        Args:
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)


class SurnameVectorizer(object):
    """ The Vectorizer which coordinates the Vocabularies and puts them to use"""

    def __init__(self, surname_vocab, nationality_vocab):
        """
        Args:
            surname_vocab (Vocabulary): maps characters to integers
            nationality_vocab (Vocabulary): maps nationalities to integers
        """
        self.surname_vocab = surname_vocab
        self.nationality_vocab = nationality_vocab

    def vectorize(self, surname):
        """
        Args:
            surname (str): the surname
        Returns:
            one_hot (np.ndarray): a collapsed one-hot encoding
        """
        vocab = self.surname_vocab
        one_hot = np.zeros(len(vocab), dtype=np.float32)
        for token in surname:
            one_hot[vocab.lookup_token(token)] = 1
        return one_hot

    @classmethod
    def from_dataframe(cls, surname_df):
        """Instantiate the vectorizer from the dataset dataframe

        Args:
            surname_df (pandas.DataFrame): the surnames dataset
        Returns:
            an instance of the SurnameVectorizer
        """
        surname_vocab = Vocabulary(unk_token="@")
        nationality_vocab = Vocabulary(add_unk=False)

        for index, row in surname_df.iterrows():
            for letter in row.surname:
                surname_vocab.add_token(letter)
            nationality_vocab.add_token(row.nationality)

        return cls(surname_vocab, nationality_vocab)

    @classmethod
    def from_serializable(cls, contents):
        surname_vocab = Vocabulary.from_serializable(contents['surname_vocab'])
        nationality_vocab = Vocabulary.from_serializable(contents['nationality_vocab'])
        return cls(surname_vocab=surname_vocab, nationality_vocab=nationality_vocab)

    def to_serializable(self):
        return {'surname_vocab': self.surname_vocab.to_serializable(),
                'nationality_vocab': self.nationality_vocab.to_serializable()}


class SurnameDataset(Dataset):
    def __init__(self, surname_df, vectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): the dataset
            vectorizer (SurnameVectorizer): vectorizer instantiated from dataset
        """
        self.surname_df = surname_df
        self._vectorizer = vectorizer

        self.train_df = self.surname_df[self.surname_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.surname_df[self.surname_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.surname_df[self.surname_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')

        # Class weights
        class_counts = surname_df.nationality.value_counts().to_dict()
        def sort_key(item):
            return self._vectorizer.nationality_vocab.lookup_token(item[0])
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)

    @classmethod
    def load_dataset_and_make_vectorizer(cls, surname_csv):
        """Load dataset and make a new vectorizer from scratch

        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        train_surname_df = surname_df[surname_df.split=='train']
        return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
        """Load dataset and the corresponding vectorizer.
        Used in the case that the vectorizer has been cached for re-use

        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved vectorizer
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        return cls(surname_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """a static method for loading the vectorizer from file

        Args:
            vectorizer_filepath (str): the location of the serialized vectorizer
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """saves the vectorizer to disk using json

        Args:
            vectorizer_filepath (str): the location to save the vectorizer
        """
        with open(vectorizer_filepath, "w") as fp:
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """ returns the vectorizer """
        return self._vectorizer

    def set_split(self, split="train"):
        """ selects the splits in the dataset using a column in the dataframe """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """the primary entry point method for PyTorch datasets

        Args:
            index (int): the index to the data point
        Returns:
            a dictionary holding the data point's:
                features (x_surname)
                label (y_nationality)
        """
        row = self._target_df.iloc[index]

        surname_vector = \
            self._vectorizer.vectorize(row.surname)

        nationality_index = \
            self._vectorizer.nationality_vocab.lookup_token(row.nationality)

        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset

        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return len(self) // batch_size


def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """
    A generator function which wraps the PyTorch DataLoader. It will
      ensure each tensor is on the right device location.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict
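
A quick sanity check of the pieces above on a toy two-row dataframe (the rows are invented and are not part of the real dataset):

toy_df = pd.DataFrame({"surname": ["Liu", "Rossi"],
                       "nationality": ["Chinese", "Italian"]})
toy_vectorizer = SurnameVectorizer.from_dataframe(toy_df)

vec = toy_vectorizer.vectorize("Liu")
print(toy_vectorizer.surname_vocab)      # <Vocabulary(size=...)>: distinct characters + UNK
print(vec.shape, vec.sum())              # collapsed one-hot: one slot per character in "Liu"
print(toy_vectorizer.nationality_vocab.lookup_token("Italian"))  # integer class index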

2. Building the Multilayer Perceptron Model

class SurnameClassifier(nn.Module):
    """ A 2-layer Multilayer Perceptron for classifying surnames """
    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Args:
            input_dim (int): the size of the input vectors
            hidden_dim (int): the output size of the first Linear layer
            output_dim (int): the output size of the second Linear layer
        """
        super(SurnameClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x_in, apply_softmax=False):
        """The forward pass of the classifier

        Args:
            x_in (torch.Tensor): an input data tensor.
                x_in.shape should be (batch, input_dim)
            apply_softmax (bool): a flag for the softmax activation
                should be false if used with the Cross Entropy losses
        Returns:
            the resulting tensor. tensor.shape should be (batch, output_dim)
        """
        intermediate_vector = F.relu(self.fc1(x_in))
        prediction_vector = self.fc2(intermediate_vector)

        if apply_softmax:
            prediction_vector = F.softmax(prediction_vector, dim=1)

        return prediction_vector


def make_train_state(args):
    return {'stop_early': False,
            'early_stopping_step': 0,
            'early_stopping_best_val': 1e8,
            'learning_rate': args.learning_rate,
            'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1,
            'model_filename': args.model_state_file}


def update_train_state(args, model, train_state):
    """Handle the training state updates.

    Components:
     - Early Stopping: Prevent overfitting.
     - Model Checkpoint: Model is saved if the model is better

    :param args: main arguments
    :param model: model to train
    :param train_state: a dictionary representing the training state values
    :returns:
        a new train_state
    """

    # Save one model at least
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(), train_state['model_filename'])
        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_loss'][-2:]

        # If loss worsened
        if loss_t >= train_state['early_stopping_best_val']:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:
                torch.save(model.state_dict(), train_state['model_filename'])

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state


def compute_accuracy(y_pred, y_target):
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100


def set_seed_everywhere(seed, cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def handle_dirs(dirpath):
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)
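
A quick shape check for the classifier (a standalone sketch; the dimensions are made up rather than taken from the dataset):

toy_classifier = SurnameClassifier(input_dim=80, hidden_dim=300, output_dim=18)
x = torch.rand(4, 80)                                     # a batch of 4 collapsed one-hot vectors
logits = toy_classifier(x)                                # (4, 18) raw scores for CrossEntropyLoss
probs = toy_classifier(x, apply_softmax=True)             # (4, 18) rows sum to 1
print(logits.shape, probs.sum(dim=1))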

3. Training the Model

args = Namespace(
    # Data and path information
    surname_csv="surnames_with_splits.csv",
    vectorizer_file="vectorizer.json",
    model_state_file="model.pth",
    save_dir="model_storage/ch4/surname_mlp",
    # Model hyper parameters
    hidden_dim=300,
    # Training hyper parameters
    seed=1337,
    num_epochs=100,
    early_stopping_criteria=5,
    learning_rate=0.001,
    batch_size=64,
    # Runtime options
    cuda=False,
    reload_from_files=False,
    expand_filepaths_to_save_dir=True,
)

if args.expand_filepaths_to_save_dir:
    args.vectorizer_file = os.path.join(args.save_dir,
                                        args.vectorizer_file)
    args.model_state_file = os.path.join(args.save_dir,
                                         args.model_state_file)
    print("Expanded filepaths: ")
    print("\t{}".format(args.vectorizer_file))
    print("\t{}".format(args.model_state_file))

# Check CUDA
if not torch.cuda.is_available():
    args.cuda = False

args.device = torch.device("cuda" if args.cuda else "cpu")
print("Using CUDA: {}".format(args.cuda))

# Set seed for reproducibility
set_seed_everywhere(args.seed, args.cuda)

# handle dirs
handle_dirs(args.save_dir)

if args.reload_from_files:
    # training from a checkpoint
    print("Reloading!")
    dataset = SurnameDataset.load_dataset_and_load_vectorizer(args.surname_csv,
                                                              args.vectorizer_file)
else:
    # create dataset and vectorizer
    print("Creating fresh!")
    dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.surname_csv)
    dataset.save_vectorizer(args.vectorizer_file)

vectorizer = dataset.get_vectorizer()
classifier = SurnameClassifier(input_dim=len(vectorizer.surname_vocab),
                               hidden_dim=args.hidden_dim,
                               output_dim=len(vectorizer.nationality_vocab))
classifier = classifier.to(args.device)
dataset.class_weights = dataset.class_weights.to(args.device)

loss_func = nn.CrossEntropyLoss(dataset.class_weights)
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 mode='min', factor=0.5,
                                                 patience=1)

train_state = make_train_state(args)

epoch_bar = tqdm_notebook(desc='training routine',
                          total=args.num_epochs,
                          position=0)

dataset.set_split('train')
train_bar = tqdm_notebook(desc='split=train',
                          total=dataset.get_num_batches(args.batch_size),
                          position=1,
                          leave=True)
dataset.set_split('val')
val_bar = tqdm_notebook(desc='split=val',
                        total=dataset.get_num_batches(args.batch_size),
                        position=1,
                        leave=True)

try:
    for epoch_index in range(args.num_epochs):
        train_state['epoch_index'] = epoch_index

        # Iterate over training dataset

        # setup: batch generator, set loss and acc to 0, set train mode on
        dataset.set_split('train')
        batch_generator = generate_batches(dataset,
                                           batch_size=args.batch_size,
                                           device=args.device)
        running_loss = 0.0
        running_acc = 0.0
        classifier.train()

        for batch_index, batch_dict in enumerate(batch_generator):
            # the training routine is these 5 steps:

            # --------------------------------------
            # step 1. zero the gradients
            optimizer.zero_grad()

            # step 2. compute the output
            y_pred = classifier(batch_dict['x_surname'])

            # step 3. compute the loss
            loss = loss_func(y_pred, batch_dict['y_nationality'])
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # step 4. use loss to produce gradients
            loss.backward()

            # step 5. use optimizer to take gradient step
            optimizer.step()
            # -----------------------------------------
            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # update bar
            train_bar.set_postfix(loss=running_loss, acc=running_acc,
                                  epoch=epoch_index)
            train_bar.update()

        train_state['train_loss'].append(running_loss)
        train_state['train_acc'].append(running_acc)

        # Iterate over val dataset

        # setup: batch generator, set loss and acc to 0; set eval mode on
        dataset.set_split('val')
        batch_generator = generate_batches(dataset,
                                           batch_size=args.batch_size,
                                           device=args.device)
        running_loss = 0.
        running_acc = 0.
        classifier.eval()

        for batch_index, batch_dict in enumerate(batch_generator):

            # compute the output
            y_pred = classifier(batch_dict['x_surname'])

            # step 3. compute the loss
            loss = loss_func(y_pred, batch_dict['y_nationality'])
            loss_t = loss.to("cpu").item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)
            val_bar.set_postfix(loss=running_loss, acc=running_acc,
                                epoch=epoch_index)
            val_bar.update()

        train_state['val_loss'].append(running_loss)
        train_state['val_acc'].append(running_acc)

        train_state = update_train_state(args=args, model=classifier,
                                         train_state=train_state)

        scheduler.step(train_state['val_loss'][-1])

        if train_state['stop_early']:
            break

        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()
except KeyboardInterrupt:
    print("Exiting loop")
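
One detail in the loop above that is easy to miss: running_loss += (loss_t - running_loss) / (batch_index + 1) maintains an incremental mean of the per-batch losses, so the progress bars report averages rather than the most recent batch. A standalone check:

values = [0.9, 0.4, 0.7, 0.2]                             # pretend per-batch losses
running = 0.0
for i, v in enumerate(values):
    running += (v - running) / (i + 1)                    # same update as in the training loop
print(running, sum(values) / len(values))                 # both print 0.55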

4. Prediction

# compute the loss & accuracy on the test set using the best available model

classifier.load_state_dict(torch.load(train_state['model_filename']))

classifier = classifier.to(args.device)
dataset.class_weights = dataset.class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(dataset.class_weights)

dataset.set_split('test')
batch_generator = generate_batches(dataset,
                                   batch_size=args.batch_size,
                                   device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()

for batch_index, batch_dict in enumerate(batch_generator):
    # compute the output
    y_pred = classifier(batch_dict['x_surname'])

    # compute the loss
    loss = loss_func(y_pred, batch_dict['y_nationality'])
    loss_t = loss.item()
    running_loss += (loss_t - running_loss) / (batch_index + 1)

    # compute the accuracy
    acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
    running_acc += (acc_t - running_acc) / (batch_index + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

print("Test loss: {};".format(train_state['test_loss']))
print("Test Accuracy: {}".format(train_state['test_acc']))

def predict_nationality(surname, classifier, vectorizer):
    """Predict the nationality from a new surname

    Args:
        surname (str): the surname to classify
        classifier (SurnameClassifier): an instance of the classifier
        vectorizer (SurnameVectorizer): the corresponding vectorizer
    Returns:
        a dictionary with the most likely nationality and its probability
    """
    vectorized_surname = vectorizer.vectorize(surname)
    vectorized_surname = torch.tensor(vectorized_surname).view(1, -1)
    result = classifier(vectorized_surname, apply_softmax=True)

    probability_values, indices = result.max(dim=1)
    index = indices.item()

    predicted_nationality = vectorizer.nationality_vocab.lookup_index(index)
    probability_value = probability_values.item()

    return {'nationality': predicted_nationality, 'probability': probability_value}


new_surname = input("Enter a surname to classify: ")
classifier = classifier.to("cpu")
prediction = predict_nationality(new_surname, classifier, vectorizer)
print("{} -> {} (p={:0.2f})".format(new_surname,
                                    prediction['nationality'],
                                    prediction['probability']))

vectorizer.nationality_vocab.lookup_index(8)


def predict_topk_nationality(name, classifier, vectorizer, k=5):
    vectorized_name = vectorizer.vectorize(name)
    vectorized_name = torch.tensor(vectorized_name).view(1, -1)
    prediction_vector = classifier(vectorized_name, apply_softmax=True)
    probability_values, indices = torch.topk(prediction_vector, k=k)

    # returned size is 1,k
    probability_values = probability_values.detach().numpy()[0]
    indices = indices.detach().numpy()[0]

    results = []
    for prob_value, index in zip(probability_values, indices):
        nationality = vectorizer.nationality_vocab.lookup_index(index)
        results.append({'nationality': nationality,
                        'probability': prob_value})

    return results


new_surname = input("Enter a surname to classify: ")
classifier = classifier.to("cpu")

k = int(input("How many of the top predictions to see? "))
if k > len(vectorizer.nationality_vocab):
    print("Sorry! That's more than the # of nationalities we have.. defaulting you to max size :)")
    k = len(vectorizer.nationality_vocab)

predictions = predict_topk_nationality(new_surname, classifier, vectorizer, k=k)

print("Top {} predictions:".format(k))
print("===================")
for prediction in predictions:
    print("{} -> {} (p={:0.2f})".format(new_surname,
                                        prediction['nationality'],
                                        prediction['probability']))
