
Multimodal | Extracting Features from Each Modality in a Multimodal Task (with Code)

Multimodal Feature Extraction

In multimodal tasks, one common approach is to first extract features from each modality separately and then fuse the per-modality features. This article focuses on implementing the feature extraction for each modality.

(PyTorch framework)
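The fusion step itself is beyond the scope of this article, but to make the pipeline concrete, here is a minimal sketch of late fusion by simple concatenation. The feature dimensions are illustrative only, and the linear classifier is a hypothetical placeholder rather than the model used in any particular paper:

```python
import torch

# Illustrative per-modality feature vectors for a single sample (dimensions are
# examples only: 2048 for a ResNet pool5 vector, 283 for pooled librosa features,
# 768 for a BERT-base sentence vector).
video_feat = torch.randn(1, 2048)
audio_feat = torch.randn(1, 283)
text_feat = torch.randn(1, 768)

# Late fusion by concatenation, followed by a placeholder linear classifier.
fused = torch.cat([video_feat, audio_feat, text_feat], dim=-1)  # shape: (1, 3099)
classifier = torch.nn.Linear(fused.shape[-1], 2)
logits = classifier(fused)
print(logits.shape)  # torch.Size([1, 2])
```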

Visual Feature Extraction

Features are extracted from images; if the input is a video, it must first be split into individual frames.
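The SarcasmDataset class used below assumes the frames already exist on disk. As a minimal illustration of that preprocessing step (not part of the original script; paths are hypothetical and OpenCV is assumed to be installed), frames can be dumped like this:

```python
import os

import cv2  # assumed dependency: pip install opencv-python


def extract_frames(video_path: str, output_folder: str) -> int:
    """Write every frame of `video_path` into `output_folder` as JPEG files."""
    os.makedirs(output_folder, exist_ok=True)
    capture = cv2.VideoCapture(video_path)
    index = 0
    while True:
        ok, frame = capture.read()
        if not ok:  # No more frames (or the video could not be decoded).
            break
        cv2.imwrite(os.path.join(output_folder, f"{index:05d}.jpg"), frame)
        index += 1
    capture.release()
    return index


# Hypothetical paths, for illustration only:
# extract_frames("data/videos/utterances_final/1_60.mp4",
#                "data/frames/utterances_final/1_60")
```

The extraction script below (from the MUStARD repository [1]) then computes ResNet-152, C3D, and I3D features from these frames: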

```python
#!/usr/bin/env python
"""Script to extract ResNet features from video frames."""
import argparse
from typing import Any, Tuple

import h5py
from overrides import overrides
import torch
import torch.nn
import torch.utils.data
import torchvision
from tqdm import tqdm

from c3d import C3D
from i3d import I3D
from dataset import SarcasmDataset

# noinspection PyUnresolvedReferences
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pretrained_resnet152() -> torch.nn.Module:
    resnet152 = torchvision.models.resnet152(pretrained=True)
    resnet152.eval()
    for param in resnet152.parameters():
        param.requires_grad = False
    return resnet152


def pretrained_c3d() -> torch.nn.Module:
    c3d = C3D(pretrained=True)
    c3d.eval()
    for param in c3d.parameters():
        param.requires_grad = False
    return c3d


def pretrained_i3d() -> torch.nn.Module:
    i3d = I3D(pretrained=True)
    i3d.eval()
    for param in i3d.parameters():
        param.requires_grad = False
    return i3d


def save_resnet_features() -> None:
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = SarcasmDataset(transform=transforms)

    resnet = pretrained_resnet152().to(DEVICE)

    class Identity(torch.nn.Module):
        @overrides
        def forward(self, input_: torch.Tensor) -> torch.Tensor:
            return input_

    resnet.fc = Identity()  # Trick to avoid computing the fc1000 layer, as we don't need it here.

    with h5py.File(SarcasmDataset.features_file_path("resnet", "res5c"), "w") as res5c_features_file, \
            h5py.File(SarcasmDataset.features_file_path("resnet", "pool5"), "w") as pool5_features_file:

        for video_id in dataset.video_ids:
            video_frame_count = dataset.frame_count_by_video_id[video_id]
            res5c_features_file.create_dataset(video_id, shape=(video_frame_count, 2048, 7, 7))
            pool5_features_file.create_dataset(video_id, shape=(video_frame_count, 2048))

        res5c_output = None

        def avg_pool_hook(_module: torch.nn.Module, input_: Tuple[torch.Tensor], _output: Any) -> None:
            nonlocal res5c_output
            res5c_output = input_[0]

        resnet.avgpool.register_forward_hook(avg_pool_hook)

        total_frame_count = sum(dataset.frame_count_by_video_id[video_id] for video_id in dataset.video_ids)
        with tqdm(total=total_frame_count, desc="Extracting ResNet features") as progress_bar:
            for instance in torch.utils.data.DataLoader(dataset):
                video_id = instance["id"][0]
                frames = instance["frames"][0].to(DEVICE)

                batch_size = 32
                for start_index in range(0, len(frames), batch_size):
                    end_index = min(start_index + batch_size, len(frames))
                    frame_ids_range = range(start_index, end_index)
                    frame_batch = frames[frame_ids_range]

                    avg_pool_value = resnet(frame_batch)

                    res5c_features_file[video_id][frame_ids_range] = res5c_output.cpu()  # noqa
                    pool5_features_file[video_id][frame_ids_range] = avg_pool_value.cpu()

                    progress_bar.update(len(frame_ids_range))


def save_c3d_features() -> None:
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(128),
        torchvision.transforms.CenterCrop(112),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = SarcasmDataset(transform=transforms)

    c3d = pretrained_c3d().to(DEVICE)

    with h5py.File(SarcasmDataset.features_file_path("c3d", "fc7"), "w") as fc7_features_file:
        for video_id in dataset.video_ids:
            video_frame_count = dataset.frame_count_by_video_id[video_id]
            feature_count = video_frame_count - 16 + 1
            fc7_features_file.create_dataset(video_id, shape=(feature_count, 4096))

        for instance in tqdm(torch.utils.data.DataLoader(dataset), desc="Extracting C3D features"):
            video_id = instance["id"][0]  # noqa
            video_frame_count = dataset.frame_count_by_video_id[video_id]
            feature_count = video_frame_count - 16 + 1

            frames = instance["frames"][0].to(DEVICE)
            frames = frames.unsqueeze(0)     # Add batch dimension
            frames = frames.transpose(1, 2)  # C3D expects (B, C, T, H, W)

            for i in range(feature_count):
                output = c3d.extract_features(frames[:, :, i:i + 16, :, :]).squeeze()
                fc7_features_file[video_id][i, :] = output.cpu().data.numpy()


def save_i3d_features() -> None:
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
    ])
    dataset = SarcasmDataset(transform=transforms)

    i3d = pretrained_i3d().to(DEVICE)

    with h5py.File(SarcasmDataset.features_file_path("i3d", "avg_pool"), "w") as avg_pool_features_file:
        for video_id in dataset.video_ids:
            video_frame_count = dataset.frame_count_by_video_id[video_id]
            feature_count = video_frame_count - 16 + 1
            avg_pool_features_file.create_dataset(video_id, shape=(feature_count, 1024))

        for instance in tqdm(torch.utils.data.DataLoader(dataset), desc="Extracting I3D features"):
            video_id = instance["id"][0]  # noqa
            video_frame_count = dataset.frame_count_by_video_id[video_id]
            feature_count = video_frame_count - 16 + 1

            frames = instance["frames"][0].to(DEVICE)
            frames = frames.unsqueeze(0)     # Add batch dimension
            frames = frames.transpose(1, 2)  # I3D expects (B, C, T, H, W)

            for i in range(feature_count):
                output = i3d.extract_features(frames[:, :, i:i + 16, :, :]).squeeze()
                avg_pool_features_file[video_id][i, :] = output.cpu().data.numpy()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Extract video features.")
    parser.add_argument("network", choices=["resnet", "c3d", "i3d"])
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    if args.network == "resnet":
        save_resnet_features()
    elif args.network == "c3d":
        save_c3d_features()
    elif args.network == "i3d":
        save_i3d_features()
    else:
        raise ValueError(f"Network type not supported: {args.network}")


if __name__ == "__main__":
    main()
```

The script extracts the features and saves them to large HDF5 files. Run it with:

python extract_features.py resnet
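Once the run finishes, the features can be read back per video. The exact file name comes from SarcasmDataset.features_file_path (not shown here), so the path below is an assumption; the per-video dataset shapes match the ones created above, e.g. (frame_count, 2048) for pool5:

```python
import h5py

# The path is a guess; use whatever SarcasmDataset.features_file_path("resnet", "pool5") returns.
with h5py.File("data/features/resnet_pool5.hdf5", "r") as features_file:
    for video_id in features_file:            # iterate over the per-video datasets
        pool5 = features_file[video_id][()]   # ndarray of shape (frame_count, 2048)
        video_vector = pool5.mean(axis=0)     # e.g. average-pool over frames
        print(video_id, pool5.shape, video_vector.shape)
```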

Audio Feature Extraction

```python
#!/usr/bin/env python
import os
import pickle

import librosa
import numpy as np
from tqdm.auto import tqdm

AUDIOS_FOLDER = "data/audios/utterances_final"
AUDIO_FEATURES_PATH = "data/audio_features.p"


def get_librosa_features(path: str) -> np.ndarray:
    y, sr = librosa.load(path)

    hop_length = 512  # Set the hop length; at 22050 Hz, 512 samples ~= 23ms

    # Foreground/background separation (librosa's vocal-separation recipe); keep the foreground.
    D = librosa.stft(y, hop_length=hop_length)
    S_full, phase = librosa.magphase(D)
    S_filter = librosa.decompose.nn_filter(S_full, aggregate=np.median, metric="cosine",
                                           width=int(librosa.time_to_frames(0.2, sr=sr)))
    S_filter = np.minimum(S_full, S_filter)
    margin_i, margin_v = 2, 4
    power = 2
    mask_v = librosa.util.softmask(S_full - S_filter, margin_v * S_filter, power=power)
    S_foreground = mask_v * S_full

    # Rebuild the time-domain signal from the separated foreground.
    new_D = S_foreground * phase
    y = librosa.istft(new_D)

    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)  # Compute MFCC features from the raw signal
    mfcc_delta = librosa.feature.delta(mfcc)  # And the first-order differences (delta features)

    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    S_delta = librosa.feature.delta(S)

    spectral_centroid = librosa.feature.spectral_centroid(S=S_full)

    audio_feature = np.vstack((mfcc, mfcc_delta, S, S_delta, spectral_centroid))  # combine features

    # Binning data
    jump = int(audio_feature.shape[1] / 10)
    return librosa.util.sync(audio_feature, range(1, audio_feature.shape[1], jump))


def save_audio_features() -> None:
    audio_feature = {}
    for filename in tqdm(os.listdir(AUDIOS_FOLDER), desc="Computing the audio features"):
        id_ = filename.rsplit(".", maxsplit=1)[0]
        audio_feature[id_] = get_librosa_features(os.path.join(AUDIOS_FOLDER, filename))
        print(audio_feature[id_].shape)

    with open(AUDIO_FEATURES_PATH, "wb") as file:
        pickle.dump(audio_feature, file, protocol=2)


def get_audio_duration() -> None:
    filenames = os.listdir(AUDIOS_FOLDER)
    print(sum(librosa.core.get_duration(filename=os.path.join(AUDIOS_FOLDER, filename))
              for filename in tqdm(filenames, desc="Computing the average duration of the audios")) / len(filenames))


def main() -> None:
    get_audio_duration()
    # save_audio_features()
    #
    # with open(AUDIO_FEATURES_PATH, "rb") as file:
    #     pickle.load(file)


if __name__ == "__main__":
    main()
```
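A quick way to sanity-check the output of save_audio_features() is to load the pickle back. Each value stacks 13 MFCCs, 13 deltas, 128 mel bands, 128 deltas, and 1 spectral centroid (283 rows), binned into roughly ten columns; the exact column count depends on the clip length:

```python
import pickle

import numpy as np

# AUDIO_FEATURES_PATH as defined in the script above.
with open("data/audio_features.p", "rb") as file:
    audio_features = pickle.load(file)  # dict: utterance id -> (283, ~10) ndarray

some_id, some_feature = next(iter(audio_features.items()))
print(some_id, some_feature.shape, np.isfinite(some_feature).all())
```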

Text Feature Extraction (BERT)

See [2] for details.

```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from BERT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import collections
import json
import re

import modeling
import tokenization
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string("input_file", None, "")

flags.DEFINE_string("output_file", None, "")

flags.DEFINE_string("layers", "-1,-2,-3,-4", "")

flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string("master", None,
                    "If using a TPU, the address of the master.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")


class InputExample(object):

  def __init__(self, unique_id, text_a, text_b):
    self.unique_id = unique_id
    self.text_a = text_a
    self.text_b = text_b


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
    self.unique_id = unique_id
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.input_type_ids = input_type_ids


def input_fn_builder(features, seq_length):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  all_unique_ids = []
  all_input_ids = []
  all_input_mask = []
  all_input_type_ids = []

  for feature in features:
    all_unique_ids.append(feature.unique_id)
    all_input_ids.append(feature.input_ids)
    all_input_mask.append(feature.input_mask)
    all_input_type_ids.append(feature.input_type_ids)

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    num_examples = len(features)

    # This is for demo purposes and does NOT scale to large data sets. We do
    # not use Dataset.from_generator() because that uses tf.py_func which is
    # not TPU compatible. The right way to load data is with TFRecordReader.
    d = tf.data.Dataset.from_tensor_slices({
        "unique_ids":
            tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
        "input_ids":
            tf.constant(
                all_input_ids, shape=[num_examples, seq_length],
                dtype=tf.int32),
        "input_mask":
            tf.constant(
                all_input_mask,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
        "input_type_ids":
            tf.constant(
                all_input_type_ids,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
    })

    d = d.batch(batch_size=batch_size, drop_remainder=False)
    return d

  return input_fn


def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
                     use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    input_type_ids = features["input_type_ids"]

    model = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=input_type_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    if mode != tf.estimator.ModeKeys.PREDICT:
      raise ValueError("Only PREDICT modes are supported: %s" % (mode))

    tvars = tf.trainable_variables()
    scaffold_fn = None
    (assignment_map,
     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    all_layers = model.get_all_encoder_layers()

    predictions = {
        "unique_id": unique_ids,
    }

    for (i, layer_index) in enumerate(layer_indexes):
      predictions["layer_output_%d" % i] = all_layers[layer_index]

    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec

  return model_fn


def convert_examples_to_features(examples, seq_length, tokenizer):
  """Loads a data file into a list of `InputBatch`s."""

  features = []
  for (ex_index, example) in enumerate(examples):
    tokens_a = tokenizer.tokenize(example.text_a)

    tokens_b = None
    if example.text_b:
      tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
      # Modifies `tokens_a` and `tokens_b` in place so that the total
      # length is less than the specified length.
      # Account for [CLS], [SEP], [SEP] with "- 3"
      _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
    else:
      # Account for [CLS] and [SEP] with "- 2"
      if len(tokens_a) > seq_length - 2:
        tokens_a = tokens_a[0:(seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    input_type_ids = []
    tokens.append("[CLS]")
    input_type_ids.append(0)
    for token in tokens_a:
      tokens.append(token)
      input_type_ids.append(0)
    tokens.append("[SEP]")
    input_type_ids.append(0)

    if tokens_b:
      for token in tokens_b:
        tokens.append(token)
        input_type_ids.append(1)
      tokens.append("[SEP]")
      input_type_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < seq_length:
      input_ids.append(0)
      input_mask.append(0)
      input_type_ids.append(0)

    assert len(input_ids) == seq_length
    assert len(input_mask) == seq_length
    assert len(input_type_ids) == seq_length

    if ex_index < 5:
      tf.logging.info("*** Example ***")
      tf.logging.info("unique_id: %s" % (example.unique_id))
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in tokens]))
      tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
      tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
      tf.logging.info(
          "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))

    features.append(
        InputFeatures(
            unique_id=example.unique_id,
            tokens=tokens,
            input_ids=input_ids,
            input_mask=input_mask,
            input_type_ids=input_type_ids))
  return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""

  # This is a simple heuristic which will always truncate the longer sequence
  # one token at a time. This makes more sense than truncating an equal percent
  # of tokens from each, since if one sequence is very short then each token
  # that's truncated likely contains more information than a longer sequence.
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()


def read_examples(input_file):
  """Read a list of `InputExample`s from an input file."""
  examples = []
  unique_id = 0
  with tf.gfile.GFile(input_file, "r") as reader:
    while True:
      line = tokenization.convert_to_unicode(reader.readline())
      if not line:
        break
      line = line.strip()
      text_a = None
      text_b = None
      m = re.match(r"^(.*) \|\|\| (.*)$", line)
      if m is None:
        text_a = line
      else:
        text_a = m.group(1)
        text_b = m.group(2)
      examples.append(
          InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
      unique_id += 1
  return examples


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  layer_indexes = [int(x) for x in FLAGS.layers.split(",")]

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      master=FLAGS.master,
      tpu_config=tf.contrib.tpu.TPUConfig(
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  examples = read_examples(FLAGS.input_file)

  features = convert_examples_to_features(
      examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)

  unique_id_to_feature = {}
  for feature in features:
    unique_id_to_feature[feature.unique_id] = feature

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      layer_indexes=layer_indexes,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=FLAGS.batch_size)

  input_fn = input_fn_builder(
      features=features, seq_length=FLAGS.max_seq_length)

  with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
                                               "w")) as writer:
    for result in estimator.predict(input_fn, yield_single_examples=True):
      unique_id = int(result["unique_id"])
      feature = unique_id_to_feature[unique_id]
      output_json = collections.OrderedDict()
      output_json["linex_index"] = unique_id
      all_features = []
      for (i, token) in enumerate(feature.tokens):
        all_layers = []
        for (j, layer_index) in enumerate(layer_indexes):
          layer_output = result["layer_output_%d" % j]
          layers = collections.OrderedDict()
          layers["index"] = layer_index
          layers["values"] = [
              round(float(x), 6) for x in layer_output[i:(i + 1)].flat
          ]
          all_layers.append(layers)
        features = collections.OrderedDict()
        features["token"] = token
        features["layers"] = all_layers
        all_features.append(features)
      output_json["features"] = all_features
      writer.write(json.dumps(output_json) + "\n")


if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("init_checkpoint")
  flags.mark_flag_as_required("output_file")
  tf.app.run()
```
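The script writes one JSON object per input line, with one entry per token. A small sketch of turning that output into fixed-size sentence vectors follows; the invocation in the comment and the output path are placeholders, and averaging the last-layer token vectors is just one common choice rather than something the script prescribes:

```python
# Typical invocation (paths are placeholders; $BERT_BASE_DIR points at an
# unzipped pre-trained BERT checkpoint, see [2]):
#   python extract_features.py \
#     --input_file=input.txt --output_file=output.jsonl \
#     --vocab_file=$BERT_BASE_DIR/vocab.txt \
#     --bert_config_file=$BERT_BASE_DIR/bert_config.json \
#     --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
#     --layers=-1
import json

import numpy as np

sentence_vectors = []
with open("output.jsonl") as file:  # the file written via --output_file
    for line in file:
        example = json.loads(line)
        # "layers"[0] is the first entry requested via --layers (here -1, the last layer).
        token_vectors = np.array([token["layers"][0]["values"]
                                  for token in example["features"]])
        # Average over all tokens (including [CLS] and [SEP]) to get one vector per line.
        sentence_vectors.append(token_vectors.mean(axis=0))

print(len(sentence_vectors), sentence_vectors[0].shape)  # e.g. N lines, (768,) for BERT-base
```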

References

[1] GitHub - soujanyaporia/MUStARD: Multimodal Sarcasm Detection Dataset. https://github.com/soujanyaporia/MUStARD

[2] GitHub - google-research/bert at commit d66a146741588fb208450bde15aa7db143baaa69. https://github.com/google-research/bert
