当前位置:   article > 正文

自用GPTneo源码_gpt neo

gpt neo
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # coding=utf-8
  4. # Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """ PyTorch GPT Neo model."""
  18. import os
  19. from typing import Optional, Tuple, Union
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from transformers.activations import ACT2FN
  25. from transformers.modeling_outputs import (
  26. BaseModelOutputWithPast,
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. CausalLMOutputWithPast,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutputWithPast,
  32. TokenClassifierOutput,
  33. )
  34. from transformers.modeling_utils import PreTrainedModel
  35. from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  36. from transformers.models.gpt_neo.configuration_gpt_neo import GPTNeoConfig
# Module-level logger from the transformers logging utilities; used by the TF loader and the model forward.
logger = logging.get_logger(__name__)

# Config class name passed to `add_code_sample_docstrings` when rendering the forward docs.
_CONFIG_FOR_DOC = "GPTNeoConfig"

# Identifiers of the pretrained GPT-Neo checkpoints listed by this module.
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "EleutherAI/gpt-neo-1.3B",
    # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
]

# Checkpoint name used in the auto-generated code samples of the forward docstrings.
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
  44. def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
  45. """Load tf checkpoints in a pytorch model"""
  46. try:
  47. import re
  48. import tensorflow as tf
  49. except ImportError:
  50. logger.error(
  51. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  52. "https://www.tensorflow.org/install/ for installation instructions."
  53. )
  54. raise
  55. tf_path = os.path.abspath(gpt_neo_checkpoint_path)
  56. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  57. # Load weights from TF model
  58. init_vars = tf.train.list_variables(tf_path)
  59. names = []
  60. arrays = []
  61. for name, shape in init_vars:
  62. if "global_step" not in name and "adam" not in name:
  63. array = tf.train.load_variable(tf_path, name)
  64. array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
  65. name = name.replace("attn/q", "attn/attention/q_proj/w")
  66. name = name.replace("attn/k", "attn/attention/k_proj/w")
  67. name = name.replace("attn/v", "attn/attention/v_proj/w")
  68. name = name.replace("attn/o", "attn/attention/out_proj/w")
  69. name = name.replace("norm_1", "ln_1")
  70. name = name.replace("norm_2", "ln_2")
  71. name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
  72. name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
  73. name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
  74. name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
  75. name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")
  76. names.append(name)
  77. arrays.append(array)
  78. for name, array in zip(names, arrays):
  79. name = name[5:] # skip "gpt2/"
  80. name = name.split("/")
  81. pointer = model.transformer
  82. for m_name in name:
  83. if re.fullmatch(r"[A-Za-z]+\d+", m_name):
  84. scope_names = re.split(r"(\d+)", m_name)
  85. else:
  86. scope_names = [m_name]
  87. if scope_names[0] == "w" or scope_names[0] == "g":
  88. pointer = getattr(pointer, "weight")
  89. elif scope_names[0] == "b":
  90. pointer = getattr(pointer, "bias")
  91. elif scope_names[0] == "wpe" or scope_names[0] == "wte":
  92. pointer = getattr(pointer, scope_names[0])
  93. pointer = getattr(pointer, "weight")
  94. else:
  95. pointer = getattr(pointer, scope_names[0])
  96. if len(scope_names) >= 2:
  97. num = int(scope_names[1])
  98. pointer = pointer[num]
  99. if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
  100. array = array.transpose()
  101. if name == ["wte"]:
  102. # if vocab is padded, then trim off the padding embeddings
  103. array = array[: config.vocab_size]
  104. if pointer.shape != array.shape:
  105. raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}")
  106. print(f"Initialize PyTorch weight {name}")
  107. pointer.data = torch.from_numpy(array)
  108. # init the final linear layer using word embeddings
  109. embs = model.transformer.wte.weight
  110. lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
  111. lin.weight = embs
  112. model.set_output_embeddings(lin)
  113. return model
  114. class GPTNeoSelfAttention(nn.Module):
  115. def __init__(self, config, attention_type):
  116. super().__init__()
  117. max_positions = config.max_position_embeddings
  118. bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
  119. 1, 1, max_positions, max_positions
  120. )
  121. # local causal self attention is a sliding window where each token can only attend to the previous
  122. # window_size tokens. This is implemented by updating the causal mask such that for each token
  123. # all other tokens are masked except the previous window_size tokens.
  124. if attention_type == "local":
  125. bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
  126. self.register_buffer("bias", bias)
  127. self.register_buffer("masked_bias", torch.tensor(-1e9))
  128. self.attn_dropout = nn.Dropout(float(config.attention_dropout))
  129. self.resid_dropout = nn.Dropout(float(config.resid_dropout))
  130. self.embed_dim = config.hidden_size
  131. self.num_heads = config.num_heads
  132. self.head_dim = self.embed_dim // self.num_heads
  133. if self.head_dim * self.num_heads != self.embed_dim:
  134. raise ValueError(
  135. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  136. f" {self.num_heads})."
  137. )
  138. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  139. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  140. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  141. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
  142. def _split_heads(self, tensor, num_heads, attn_head_size):
  143. """
  144. Splits hidden_size dim into attn_head_size and num_heads
  145. """
  146. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  147. tensor = tensor.view(new_shape)
  148. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  149. def _merge_heads(self, tensor, num_heads, attn_head_size):
  150. """
  151. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  152. """
  153. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  154. new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
  155. return tensor.view(new_shape)
  156. def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  157. # Keep the attention weights computation in fp32 to avoid overflow issues
  158. query = query.to(torch.float32)
  159. key = key.to(torch.float32)
  160. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  161. query_length, key_length = query.size(-2), key.size(-2)
  162. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  163. mask_value = torch.finfo(attn_weights.dtype).min
  164. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  165. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  166. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
  167. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  168. if attention_mask is not None:
  169. # Apply the attention mask
  170. attn_weights = attn_weights + attention_mask
  171. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  172. attn_weights = attn_weights.to(value.dtype)
  173. attn_weights = self.attn_dropout(attn_weights)
  174. # Mask heads if we want to
  175. if head_mask is not None:
  176. attn_weights = attn_weights * head_mask
  177. attn_output = torch.matmul(attn_weights, value)
  178. return attn_output, attn_weights
  179. def forward(
  180. self,
  181. hidden_states,
  182. attention_mask=None,
  183. layer_past=None,
  184. head_mask=None,
  185. use_cache=False,
  186. output_attentions=False,
  187. ):
  188. query = self.q_proj(hidden_states)
  189. key = self.k_proj(hidden_states)
  190. value = self.v_proj(hidden_states)
  191. query = self._split_heads(query, self.num_heads, self.head_dim)
  192. key = self._split_heads(key, self.num_heads, self.head_dim)
  193. value = self._split_heads(value, self.num_heads, self.head_dim)
  194. if layer_past is not None:
  195. past_key = layer_past[0]
  196. past_value = layer_past[1]
  197. key = torch.cat((past_key, key), dim=-2)
  198. value = torch.cat((past_value, value), dim=-2)
  199. if use_cache is True:
  200. present = (key, value)
  201. else:
  202. present = None
  203. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  204. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  205. attn_output = self.out_proj(attn_output)
  206. attn_output = self.resid_dropout(attn_output)
  207. outputs = (attn_output, present)
  208. if output_attentions:
  209. outputs += (attn_weights,)
  210. return outputs # a, present, (attentions)
  211. class GPTNeoAttention(nn.Module):
  212. def __init__(self, config, layer_id=0):
  213. super().__init__()
  214. self.layer_id = layer_id
  215. self.attention_layers = config.attention_layers
  216. self.attention_type = self.attention_layers[layer_id]
  217. if self.attention_type in ["global", "local"]:
  218. self.attention = GPTNeoSelfAttention(config, self.attention_type)
  219. else:
  220. raise NotImplementedError(
  221. "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
  222. f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
  223. )
  224. def forward(
  225. self,
  226. hidden_states,
  227. layer_past=None,
  228. attention_mask=None,
  229. head_mask=None,
  230. use_cache=False,
  231. output_attentions=False,
  232. ):
  233. return self.attention(
  234. hidden_states,
  235. attention_mask=attention_mask,
  236. layer_past=layer_past,
  237. head_mask=head_mask,
  238. use_cache=use_cache,
  239. output_attentions=output_attentions,
  240. )
  241. class GPTNeoMLP(nn.Module):
  242. def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
  243. super().__init__()
  244. embed_dim = config.hidden_size
  245. self.c_fc = nn.Linear(embed_dim, intermediate_size)
  246. self.c_proj = nn.Linear(intermediate_size, embed_dim)
  247. self.act = ACT2FN[config.activation_function]
  248. self.dropout = nn.Dropout(float(config.resid_dropout))
  249. def forward(self, hidden_states):
  250. hidden_states = self.c_fc(hidden_states)
  251. hidden_states = self.act(hidden_states)
  252. hidden_states = self.c_proj(hidden_states)
  253. hidden_states = self.dropout(hidden_states)
  254. return hidden_states
  255. class GPTNeoBlock(nn.Module):
  256. def __init__(self, config, layer_id):
  257. super().__init__()
  258. hidden_size = config.hidden_size
  259. inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
  260. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  261. self.attn = GPTNeoAttention(config, layer_id)
  262. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  263. self.mlp = GPTNeoMLP(inner_dim, config)
  264. def forward(
  265. self,
  266. hidden_states,
  267. layer_past=None,
  268. attention_mask=None,
  269. head_mask=None,
  270. use_cache=False,
  271. output_attentions=False,
  272. ):
  273. residual = hidden_states
  274. hidden_states = self.ln_1(hidden_states)
  275. attn_outputs = self.attn(
  276. hidden_states,
  277. layer_past=layer_past,
  278. attention_mask=attention_mask,
  279. head_mask=head_mask,
  280. use_cache=use_cache,
  281. output_attentions=output_attentions,
  282. )
  283. attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
  284. outputs = attn_outputs[1:]
  285. # residual connection
  286. hidden_states = attn_output + residual
  287. residual = hidden_states
  288. hidden_states = self.ln_2(hidden_states)
  289. feed_forward_hidden_states = self.mlp(hidden_states)
  290. # residual connection
  291. hidden_states = residual + feed_forward_hidden_states
  292. if use_cache:
  293. outputs = (hidden_states,) + outputs
  294. else:
  295. outputs = (hidden_states,) + outputs[1:]
  296. return outputs # hidden_states, present, (attentions, cross_attentions)
  297. class GPTNeoPreTrainedModel(PreTrainedModel):
  298. """
  299. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  300. models.
  301. """
  302. config_class = GPTNeoConfig
  303. load_tf_weights = load_tf_weights_in_gpt_neo
  304. base_model_prefix = "transformer"
  305. supports_gradient_checkpointing = True
  306. _no_split_modules = ["GPTNeoBlock"]
  307. def __init__(self, *inputs, **kwargs):
  308. super().__init__(*inputs, **kwargs)
  309. def _init_weights(self, module):
  310. """Initialize the weights."""
  311. if isinstance(module, (nn.Linear,)):
  312. # Slightly different from the TF version which uses truncated_normal for initialization
  313. # cf https://github.com/pytorch/pytorch/pull/5617
  314. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  315. if module.bias is not None:
  316. module.bias.data.zero_()
  317. elif isinstance(module, nn.Embedding):
  318. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  319. if module.padding_idx is not None:
  320. module.weight.data[module.padding_idx].zero_()
  321. elif isinstance(module, nn.LayerNorm):
  322. module.bias.data.zero_()
  323. module.weight.data.fill_(1.0)
  324. def _set_gradient_checkpointing(self, module, value=False):
  325. if isinstance(module, GPTNeoModel):
  326. module.gradient_checkpointing = value
  327. GPT_NEO_START_DOCSTRING = r"""
  328. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  329. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  330. etc.)
  331. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  332. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  333. and behavior.
  334. Parameters:
  335. config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model.
  336. Initializing with a config file does not load the weights associated with the model, only the
  337. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  338. """
  339. GPT_NEO_INPUTS_DOCSTRING = r"""
  340. Args:
  341. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  342. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  343. `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
  344. sequence tokens in the vocabulary.
  345. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  346. `input_ids`.
  347. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  348. [`PreTrainedTokenizer.__call__`] for details.
  349. [What are input IDs?](../glossary#input-ids)
  350. past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
  351. Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
  352. `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
  353. their past given to this model should not be passed as `input_ids` as they have already been computed.
  354. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  355. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  356. - 1 for tokens that are **not masked**,
  357. - 0 for tokens that are **masked**.
  358. [What are attention masks?](../glossary#attention-mask)
  359. token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  360. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  361. 1]`:
  362. - 0 corresponds to a *sentence A* token,
  363. - 1 corresponds to a *sentence B* token.
  364. [What are token type IDs?](../glossary#token-type-ids)
  365. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  366. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  367. config.max_position_embeddings - 1]`.
  368. [What are position IDs?](../glossary#position-ids)
  369. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  370. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  371. - 1 indicates the head is **not masked**,
  372. - 0 indicates the head is **masked**.
  373. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  374. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  375. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  376. model's internal embedding lookup matrix.
  377. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
  378. `past_key_values`).
  379. use_cache (`bool`, *optional*):
  380. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  381. `past_key_values`).
  382. output_attentions (`bool`, *optional*):
  383. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  384. tensors for more detail.
  385. output_hidden_states (`bool`, *optional*):
  386. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  387. more detail.
  388. return_dict (`bool`, *optional*):
  389. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  390. """
  391. @add_start_docstrings(
  392. "The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
  393. GPT_NEO_START_DOCSTRING,
  394. )
  395. class GPTNeoModel(GPTNeoPreTrainedModel):
  396. def __init__(self, config):
  397. super().__init__(config)
  398. self.embed_dim = config.hidden_size
  399. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  400. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  401. self.drop = nn.Dropout(float(config.embed_dropout))
  402. self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
  403. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  404. self.gradient_checkpointing = False
  405. # Initialize weights and apply final processing
  406. self.post_init()
  407. def get_input_embeddings(self):
  408. return self.wte
  409. def set_input_embeddings(self, new_embeddings):
  410. self.wte = new_embeddings
  411. @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
  412. @add_code_sample_docstrings(
  413. checkpoint=_CHECKPOINT_FOR_DOC,
  414. output_type=BaseModelOutputWithPastAndCrossAttentions,
  415. config_class=_CONFIG_FOR_DOC,
  416. )
  417. def forward(
  418. self,
  419. input_ids: Optional[torch.Tensor] = None,
  420. past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
  421. attention_mask: Optional[torch.Tensor] = None,
  422. token_type_ids: Optional[torch.Tensor] = None,
  423. position_ids: Optional[torch.Tensor] = None,
  424. head_mask: Optional[torch.Tensor] = None,
  425. inputs_embeds: Optional[torch.Tensor] = None,
  426. use_cache: Optional[bool] = None,
  427. output_attentions: Optional[bool] = None,
  428. output_hidden_states: Optional[bool] = None,
  429. return_dict: Optional[bool] = None,
  430. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  431. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  432. output_hidden_states = (
  433. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  434. )
  435. use_cache = use_cache if use_cache is not None else self.config.use_cache
  436. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  437. if input_ids is not None and inputs_embeds is not None:
  438. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  439. elif input_ids is not None:
  440. input_shape = input_ids.size()
  441. input_ids = input_ids.view(-1, input_shape[-1])
  442. batch_size = input_ids.shape[0]
  443. elif inputs_embeds is not None:
  444. input_shape = inputs_embeds.size()[:-1]
  445. batch_size = inputs_embeds.shape[0]
  446. else:
  447. raise ValueError("You have to specify either input_ids or inputs_embeds")
  448. device = input_ids.device if input_ids is not None else inputs_embeds.device
  449. if token_type_ids is not None:
  450. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  451. if position_ids is not None:
  452. position_ids = position_ids.view(-1, input_shape[-1])
  453. if past_key_values is None:
  454. past_length = 0
  455. past_key_values = tuple([None] * len(self.h))
  456. else:
  457. past_length = past_key_values[0][0].size(-2)
  458. if position_ids is None:
  459. position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
  460. position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
  461. # Attention mask.
  462. if attention_mask is not None:
  463. if batch_size <= 0:
  464. raise ValueError("batch_size has to be defined and > 0")
  465. attention_mask = attention_mask.view(batch_size, -1)
  466. # We create a 3D attention mask from a 2D tensor mask.
  467. # Sizes are [batch_size, 1, 1, to_seq_length]
  468. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  469. # this attention mask is more simple than the triangular masking of causal attention
  470. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  471. attention_mask = attention_mask[:, None, None, :]
  472. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  473. # masked positions, this operation will create a tensor which is 0.0 for
  474. # positions we want to attend and the dtype's smallest value for masked positions.
  475. # Since we are adding it to the raw scores before the softmax, this is
  476. # effectively the same as removing these entirely.
  477. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
  478. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  479. # Prepare head mask if needed
  480. # 1.0 in head_mask indicate we keep the head
  481. # attention_probs has shape bsz x num_heads x N x N
  482. # head_mask has shape n_layer x batch x num_heads x N x N
  483. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  484. if inputs_embeds is None:
  485. inputs_embeds = self.wte(input_ids)
  486. position_embeds = self.wpe(position_ids)
  487. hidden_states = inputs_embeds + position_embeds
  488. if token_type_ids is not None:
  489. token_type_embeds = self.wte(token_type_ids)
  490. hidden_states = hidden_states + token_type_embeds
  491. hidden_states = self.drop(hidden_states)
  492. output_shape = input_shape + (hidden_states.size(-1),)
  493. if self.gradient_checkpointing and self.training:
  494. if use_cache:
  495. logger.warning_once(
  496. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  497. )
  498. use_cache = False
  499. presents = () if use_cache else None
  500. all_self_attentions = () if output_attentions else None
  501. all_hidden_states = () if output_hidden_states else None
  502. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  503. if output_hidden_states:
  504. all_hidden_states = all_hidden_states + (hidden_states,)
  505. if self.gradient_checkpointing and self.training:
  506. def create_custom_forward(module):
  507. def custom_forward(*inputs):
  508. # None for past_key_value
  509. return module(*inputs, use_cache, output_attentions)
  510. return custom_forward
  511. outputs = torch.utils.checkpoint.checkpoint(
  512. create_custom_forward(block),
  513. hidden_states,
  514. None,
  515. attention_mask,
  516. head_mask[i],
  517. )
  518. else:
  519. outputs = block(
  520. hidden_states,
  521. layer_past=layer_past,
  522. attention_mask=attention_mask,
  523. head_mask=head_mask[i],
  524. use_cache=use_cache,
  525. output_attentions=output_attentions,
  526. )
  527. hidden_states = outputs[0]
  528. if use_cache is True:
  529. presents = presents + (outputs[1],)
  530. if output_attentions:
  531. all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
  532. hidden_states = self.ln_f(hidden_states)
  533. hidden_states = hidden_states.view(output_shape)
  534. # Add last hidden state
  535. if output_hidden_states:
  536. all_hidden_states = all_hidden_states + (hidden_states,)
  537. if not return_dict:
  538. return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
  539. return BaseModelOutputWithPast(
  540. last_hidden_state=hidden_states,
  541. past_key_values=presents,
  542. hidden_states=all_hidden_states,
  543. attentions=all_self_attentions,
  544. )
  545. @add_start_docstrings(
  546. """
  547. The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
  548. embeddings).
  549. """,
  550. GPT_NEO_START_DOCSTRING,
  551. )
  552. class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
  553. _keys_to_ignore_on_load_missing = [
  554. r"h\.\d+\.attn\.masked_bias",
  555. r"lm_head.weight",
  556. r"h\.\d+\.attn\.attention\.bias",
  557. ]
  558. _keys_to_ignore_on_save = [r"lm_head.weight"]
  559. def __init__(self, config):
  560. super().__init__(config)
  561. self.transformer = GPTNeoModel(config)
  562. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  563. # Initialize weights and apply final processing
  564. self.post_init()
  565. def get_output_embeddings(self):
  566. return self.lm_head
  567. def set_output_embeddings(self, new_embeddings):
  568. self.lm_head = new_embeddings
  569. def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
  570. token_type_ids = kwargs.get("token_type_ids", None)
  571. # only last token for inputs_ids if past is defined in kwargs
  572. if past_key_values:
  573. input_ids = input_ids[:, -1].unsqueeze(-1)
  574. if token_type_ids is not None:
  575. token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
  576. attention_mask = kwargs.get("attention_mask", None)
  577. position_ids = kwargs.get("position_ids", None)
  578. if attention_mask is not None and position_ids is None:
  579. # create position_ids on the fly for batch generation
  580. position_ids = attention_mask.long().cumsum(-1) - 1
  581. position_ids.masked_fill_(attention_mask == 0, 1)
  582. if past_key_values:
  583. position_ids = position_ids[:, -1].unsqueeze(-1)
  584. else:
  585. position_ids = None
  586. return {
  587. "input_ids": input_ids,
  588. "past_key_values": past_key_values,
  589. "use_cache": kwargs.get("use_cache"),
  590. "position_ids": position_ids,
  591. "attention_mask": attention_mask,
  592. "token_type_ids": token_type_ids,
  593. }
  594. @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
  595. @add_code_sample_docstrings(
  596. checkpoint=_CHECKPOINT_FOR_DOC,
  597. output_type=CausalLMOutputWithCrossAttentions,
  598. config_class=_CONFIG_FOR_DOC,
  599. )
  600. def forward(
  601. self,
  602. input_ids: Optional[torch.Tensor] = None,
  603. past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
  604. attention_mask: Optional[torch.Tensor] = None,
  605. token_type_ids: Optional[torch.Tensor] = None,
  606. position_ids: Optional[torch.Tensor] = None,
  607. head_mask: Optional[torch.Tensor] = None,
  608. inputs_embeds: Optional[torch.Tensor] = None,
  609. labels: Optional[torch.Tensor] = None,
  610. use_cache: Optional[bool] = None,
  611. output_attentions: Optional[bool] = None,
  612. output_hidden_states: Optional[bool] = None,
  613. return_dict: Optional[bool] = None,
  614. ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  615. r"""
  616. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  617. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  618. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  619. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  620. """
  621. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  622. transformer_outputs = self.transformer(
  623. input_ids,
  624. past_key_values=past_key_values,
  625. attention_mask=attention_mask,
  626. token_type_ids=token_type_ids,
  627. position_ids=position_ids,
  628. head_mask=head_mask,
  629. inputs_embeds=inputs_embeds,
  630. use_cache=use_cache,
  631. output_attentions=output_attentions,
  632. output_hidden_states=output_hidden_states,
  633. return_dict=return_dict,
  634. )
  635. hidden_states = transformer_outputs[0]
  636. lm_logits = self.lm_head(hidden_states)
  637. loss = None
  638. if labels is not None:
  639. # move labels to correct device to enable model parallelism
  640. labels = labels.to(lm_logits.device)
  641. # Compute loss in fp32 to match with mesh-tf version
  642. # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
  643. lm_logits = lm_logits.to(torch.float32)
  644. # Shift so that tokens < n predict n
  645. shift_logits = lm_logits[..., :-1, :].contiguous()
  646. shift_labels = labels[..., 1:].contiguous()
  647. # Flatten the tokens
  648. loss_fct = CrossEntropyLoss()
  649. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  650. lm_logits = lm_logits.to(hidden_states.dtype)
  651. loss = loss.to(hidden_states.dtype)
  652. if not return_dict:
  653. output = (lm_logits,) + transformer_outputs[1:]
  654. return ((loss,) + output) if loss is not None else output
  655. return CausalLMOutputWithPast(
  656. loss=loss,
  657. logits=lm_logits,
  658. past_key_values=transformer_outputs.past_key_values,
  659. hidden_states=transformer_outputs.hidden_states,
  660. attentions=transformer_outputs.attentions,
  661. )
  662. @staticmethod
  663. def _reorder_cache(
  664. past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
  665. ) -> Tuple[Tuple[torch.Tensor]]:
  666. """
  667. This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
  668. [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  669. beam_idx at every generation step.
  670. """
  671. return tuple(
  672. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
  673. for layer_past in past_key_values
  674. )
  675. @add_start_docstrings(
  676. """
  677. The GPTNeo Model transformer with a sequence classification head on top (linear layer).
  678. [`GPTNeoForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  679. (e.g. GPT-1) do.
  680. Since it does classification on the last token, it requires to know the position of the last token. If a
  681. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  682. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  683. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  684. each row of the batch).
  685. """,
  686. GPT_NEO_START_DOCSTRING,
  687. )
  688. class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
  689. _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
  690. def __init__(self, config):
  691. super().__init__(config)
  692. self.num_labels = config.num_labels
  693. self.transformer = GPTNeoModel(config)
  694. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  695. # Initialize weights and apply final processing
  696. self.post_init()
  697. @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
  698. @add_code_sample_docstrings(
  699. checkpoint=_CHECKPOINT_FOR_DOC,
  700. output_type=SequenceClassifierOutputWithPast,
  701. config_class=_CONFIG_FOR_DOC,
  702. )
  703. def forward(
  704. self,
  705. input_ids: Optional[torch.Tensor] = None,
  706. past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
  707. attention_mask: Optional[torch.Tensor] = None,
  708. token_type_ids: Optional[torch.Tensor] = None,
  709. position_ids: Optional[torch.Tensor] = None,
  710. head_mask: Optional[torch.Tensor] = None,
  711. inputs_embeds: Optional[torch.Tensor] = None,
  712. labels: Optional[torch.Tensor] = None,
  713. use_cache: Optional[bool] = None,
  714. output_attentions: Optional[bool] = None,
  715. output_hidden_states: Optional[bool] = None,
  716. return_dict: Optional[bool] = None,
  717. ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  718. r"""
  719. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  720. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  721. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  722. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  723. """
  724. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  725. transformer_outputs = self.transformer(
  726. input_ids,
  727. past_key_values=past_key_values,
  728. attention_mask=attention_mask,
  729. token_type_ids=token_type_ids,
  730. position_ids=position_ids,
  731. head_mask=head_mask,
  732. inputs_embeds=inputs_embeds,
  733. use_cache=use_cache,
  734. output_attentions=output_attentions,
  735. output_hidden_states=output_hidden_states,
  736. return_dict=return_dict,
  737. )
  738. hidden_states = transformer_outputs[0]
  739. logits = self.score(hidden_states)
  740. if input_ids is not None:
  741. batch_size, sequence_length = input_ids.shape[:2]
  742. else:
  743. batch_size, sequence_length = inputs_embeds.shape[:2]
  744. if self.config.pad_token_id is None and batch_size != 1:
  745. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  746. if self.config.pad_token_id is None:
  747. sequence_lengths = -1
  748. else:
  749. if input_ids is not None:
  750. sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
  751. else:
  752. sequence_lengths = -1
  753. logger.warning(
  754. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  755. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  756. )
  757. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  758. loss = None
  759. if labels is not None:
  760. if self.config.problem_type is None:
  761. if self.num_labels == 1:
  762. self.config.problem_type = "regression"
  763. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  764. self.config.problem_type = "single_label_classification"
  765. else:
  766. self.config.problem_type = "multi_label_classification"
  767. if self.config.problem_type == "regression":
  768. loss_fct = MSELoss()
  769. if self.num_labels == 1:
  770. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  771. else:
  772. loss = loss_fct(pooled_logits, labels)
  773. elif self.config.problem_type == "single_label_classification":
  774. loss_fct = CrossEntropyLoss()
  775. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  776. elif self.config.problem_type == "multi_label_classification":
  777. loss_fct = BCEWithLogitsLoss()
  778. loss = loss_fct(pooled_logits, labels)
  779. if not return_dict:
  780. output = (pooled_logits,) + transformer_outputs[1:]
  781. return ((loss,) + output) if loss is not None else output
  782. return SequenceClassifierOutputWithPast(
  783. loss=loss,
  784. logits=pooled_logits,
  785. past_key_values=transformer_outputs.past_key_values,
  786. hidden_states=transformer_outputs.hidden_states,
  787. attentions=transformer_outputs.attentions,
  788. )
  789. @add_start_docstrings(
  790. """
  791. GPT Neo model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  792. Named-Entity-Recognition (NER) tasks.
  793. """,
  794. GPT_NEO_START_DOCSTRING,
  795. )
  796. class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
  797. def __init__(self, config):
  798. super().__init__(config)
  799. self.num_labels = config.num_labels
  800. self.transformer = GPTNeoModel(config)
  801. self.dropout = nn.Dropout(config.classifier_dropout)
  802. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  803. # Initialize weights and apply final processing
  804. self.post_init()
  805. @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
  806. @add_code_sample_docstrings(
  807. checkpoint="EleutherAI/gpt-neo-125m",
  808. output_type=TokenClassifierOutput,
  809. config_class=_CONFIG_FOR_DOC,
  810. expected_loss=0.25,
  811. )
  812. def forward(
  813. self,
  814. input_ids: Optional[torch.LongTensor] = None,
  815. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  816. attention_mask: Optional[torch.FloatTensor] = None,
  817. token_type_ids: Optional[torch.LongTensor] = None,
  818. position_ids: Optional[torch.LongTensor] = None,
  819. head_mask: Optional[torch.FloatTensor] = None,
  820. inputs_embeds: Optional[torch.FloatTensor] = None,
  821. labels: Optional[torch.LongTensor] = None,
  822. use_cache: Optional[bool] = None,
  823. output_attentions: Optional[bool] = None,
  824. output_hidden_states: Optional[bool] = None,
  825. return_dict: Optional[bool] = None,
  826. ) -> Union[Tuple, TokenClassifierOutput]:
  827. r"""
  828. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  829. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  830. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  831. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  832. """
  833. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  834. transformer_outputs = self.transformer(
  835. input_ids,
  836. past_key_values=past_key_values,
  837. attention_mask=attention_mask,
  838. token_type_ids=token_type_ids,
  839. position_ids=position_ids,
  840. head_mask=head_mask,
  841. inputs_embeds=inputs_embeds,
  842. use_cache=use_cache,
  843. output_attentions=output_attentions,
  844. output_hidden_states=output_hidden_states,
  845. return_dict=return_dict,
  846. )
  847. hidden_states = transformer_outputs[0]
  848. hidden_states = self.dropout(hidden_states)
  849. logits = self.classifier(hidden_states)
  850. loss = None
  851. if labels is not None:
  852. labels = labels.to(logits.device)
  853. loss_fct = CrossEntropyLoss()
  854. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  855. if not return_dict:
  856. output = (logits,) + transformer_outputs[2:]
  857. return ((loss,) + output) if loss is not None else output
  858. return TokenClassifierOutput(
  859. loss=loss,
  860. logits=logits,
  861. hidden_states=transformer_outputs.hidden_states,
  862. attentions=transformer_outputs.attentions,
  863. )
  864. @add_start_docstrings(
  865. """
  866. The GPT-Neo Model transformer with a span classification head on top for extractive question-answering tasks like
  867. SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  868. """,
  869. GPT_NEO_START_DOCSTRING,
  870. )
  871. class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
  872. _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
  873. def __init__(self, config):
  874. super().__init__(config)
  875. self.num_labels = config.num_labels
  876. self.transformer = GPTNeoModel(config)
  877. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  878. # Initialize weights and apply final processing
  879. self.post_init()
  880. @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  881. @add_code_sample_docstrings(
  882. checkpoint=_CHECKPOINT_FOR_DOC,
  883. output_type=QuestionAnsweringModelOutput,
  884. config_class=_CONFIG_FOR_DOC,
  885. real_checkpoint=_CHECKPOINT_FOR_DOC,
  886. )
  887. def forward(
  888. self,
  889. input_ids: Optional[torch.LongTensor] = None,
  890. attention_mask: Optional[torch.FloatTensor] = None,
  891. token_type_ids: Optional[torch.LongTensor] = None,
  892. position_ids: Optional[torch.LongTensor] = None,
  893. head_mask: Optional[torch.FloatTensor] = None,
  894. inputs_embeds: Optional[torch.FloatTensor] = None,
  895. start_positions: Optional[torch.LongTensor] = None,
  896. end_positions: Optional[torch.LongTensor] = None,
  897. output_attentions: Optional[bool] = None,
  898. output_hidden_states: Optional[bool] = None,
  899. return_dict: Optional[bool] = None,
  900. ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  901. r"""
  902. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  903. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  904. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  905. are not taken into account for computing the loss.
  906. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  907. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  908. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  909. are not taken into account for computing the loss.
  910. """
  911. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  912. outputs = self.transformer(
  913. input_ids,
  914. attention_mask=attention_mask,
  915. token_type_ids=token_type_ids,
  916. position_ids=position_ids,
  917. head_mask=head_mask,
  918. inputs_embeds=inputs_embeds,
  919. output_attentions=output_attentions,
  920. output_hidden_states=output_hidden_states,
  921. return_dict=return_dict,
  922. )
  923. sequence_output = outputs[0]
  924. logits = self.qa_outputs(sequence_output)
  925. start_logits, end_logits = logits.split(1, dim=-1)
  926. start_logits = start_logits.squeeze(-1).contiguous()
  927. end_logits = end_logits.squeeze(-1).contiguous()
  928. total_loss = None
  929. if start_positions is not None and end_positions is not None:
  930. # If we are on multi-GPU, split add a dimension
  931. if len(start_positions.size()) > 1:
  932. start_positions = start_positions.squeeze(-1)
  933. if len(end_positions.size()) > 1:
  934. end_positions = end_positions.squeeze(-1)
  935. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  936. ignored_index = start_logits.size(1)
  937. start_positions = start_positions.clamp(0, ignored_index)
  938. end_positions = end_positions.clamp(0, ignored_index)
  939. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  940. start_loss = loss_fct(start_logits, start_positions)
  941. end_loss = loss_fct(end_logits, end_positions)
  942. total_loss = (start_loss + end_loss) / 2
  943. if not return_dict:
  944. output = (start_logits, end_logits) + outputs[2:]
  945. return ((total_loss,) + output) if total_loss is not None else output
  946. return QuestionAnsweringModelOutput(
  947. loss=total_loss,
  948. start_logits=start_logits,
  949. end_logits=end_logits,
  950. hidden_states=outputs.hidden_states,
  951. attentions=outputs.attentions,
  952. )

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读