
VAE-LSTM Implementation in TensorFlow

Based on the paper: ANOMALY DETECTION FOR TIME SERIES USING VAE-LSTM HYBRID MODEL (available on IEEE Xplore)
Code source: GitHub
Runtime environment: GPU (the code is written in the TensorFlow 1.x style)
VAE-LSTM schematic (figure taken from the paper above):
The following can be adapted to your own needs.
The principle in the figure above can be understood roughly as follows: when data come in, the VAE encoder network first compresses the input and extracts its features; the extracted features (embeddings) are fed into the LSTM network for fault detection or classification, and the LSTM predicts the feature sequence; the predicted result is passed to the VAE decoder network for reconstruction, the reconstruction loss is computed, and the parameters of the whole network are updated. Combining the VAE with the LSTM further improves the model's diagnostic accuracy (see the original paper for the detailed derivation). A minimal sketch of this data flow is shown below.
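To make the data flow concrete, here is a minimal NumPy sketch. It is not the repository code: encode, lstm_predict and decode are hypothetical stand-ins for the VAE encoder, the LSTM predictor and the VAE decoder.

import numpy as np

def encode(window, code_size=6):
    # stand-in for the VAE encoder: compress a window into a low-dimensional embedding
    rng = np.random.default_rng(0)
    W = rng.standard_normal((window.size, code_size))
    return window.reshape(1, -1) @ W                      # shape (1, code_size)

def lstm_predict(past_embeddings):
    # stand-in for the LSTM: predict the next embedding from the previous ones
    return past_embeddings.mean(axis=0, keepdims=True)    # shape (1, code_size)

def decode(embedding, window_size=24):
    # stand-in for the VAE decoder: map an embedding back to a window
    rng = np.random.default_rng(1)
    W = rng.standard_normal((embedding.shape[1], window_size))
    return (embedding @ W).reshape(window_size)

window_size, code_size = 24, 6
windows = [np.sin(np.linspace(i, i + 3, window_size)) for i in range(5)]
embeddings = np.vstack([encode(w, code_size) for w in windows[:-1]])
pred_embedding = lstm_predict(embeddings)                 # prediction in embedding space
reconstruction = decode(pred_embedding, window_size)      # decode the predicted embedding
recon_error = np.sum((windows[-1] - reconstruction) ** 2) # reconstruction error
print("reconstruction error:", recon_error)

Roughly speaking, the paper flags a window as anomalous when this reconstruction error exceeds a threshold.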
The code consists of six parts: one main training script and five supporting modules. (During data loading, the data are already split into training, cross-validation, and test sets.)
The main training script is as follows:
train.py

import os
import tensorflow as tf
from data_loader import DataGenerator
from model import VAEmodel, lstmKerasModel
from trainer import vaeTrainer
from utils import process_config, create_dirs, get_args, save_config
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main():
    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except:
        print("missing or invalid arguments")
        exit(0)

    # create the experiments dirs
    create_dirs([config['result_dir'], config['checkpoint_dir'], config['checkpoint_dir_lstm']])
    # save the config in a txt file
    save_config(config)
    # create tensorflow session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    # create your data generator
    data = DataGenerator(config)
    # create the VAE model
    model_vae = VAEmodel(config)
    # create a trainer for VAE model
    trainer_vae = vaeTrainer(sess, model_vae, data, config)
    model_vae.load(sess)
    # here you train your model
    if config['TRAIN_VAE']:
        if config['num_epochs_vae'] > 0:
            trainer_vae.train()

    if config['TRAIN_LSTM']:
        # create a lstm model class instance
        lstm_model = lstmKerasModel(data)

        # produce the embedding of all sequences for training of lstm model
        # process the windows in sequence to get their VAE embeddings
        lstm_model.produce_embeddings(config, model_vae, data, sess)

        # Create a basic model instance
        lstm_nn_model = lstm_model.create_lstm_model(config)
        lstm_nn_model.summary()   # Display the model's architecture
        # checkpoint path
        checkpoint_path = config['checkpoint_dir_lstm']\
                          + "cp.ckpt"
        # Create a callback that saves the model's weights
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                         save_weights_only=True,
                                                         verbose=1)
        # load weights if possible
        lstm_model.load_model(lstm_nn_model, config, checkpoint_path)

        # start training
        if config['num_epochs_lstm'] > 0:
            lstm_model.train(config, lstm_nn_model, cp_callback)

        # make a prediction on the test set using the trained model
        lstm_embedding = lstm_nn_model.predict(lstm_model.x_test, batch_size=config['batch_size_lstm'])
        print(lstm_embedding.shape)

        # visualise the first 10 test sequences
        for i in range(10):
            lstm_model.plot_lstm_embedding_prediction(i, config, model_vae, sess, data, lstm_embedding)


if __name__ == '__main__':
    main()
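For reference, process_config(args.config) above loads a JSON configuration file. A hypothetical minimal example, written here as a Python dict and restricted to the keys actually referenced in this post (the repository's config files contain further data and model parameters not shown), might look like:

# hypothetical minimal config; all values are placeholders
config = {
    "exp_name": "my_experiment",
    "result_dir": "experiments/my_experiment/result/",
    "checkpoint_dir": "experiments/my_experiment/checkpoint/",
    "checkpoint_dir_lstm": "experiments/my_experiment/checkpoint_lstm/",
    "TRAIN_VAE": 1,
    "TRAIN_LSTM": 1,
    "num_epochs_vae": 40,
    "num_epochs_lstm": 20,
    "batch_size": 32,
    "batch_size_lstm": 32,
    "learning_rate_vae": 0.0002,
    "code_size": 6,
}

The path to the JSON file is read from the command line (args.config), and the parsed dictionary is then passed to every component.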

The five supporting modules are as follows:
base.py, which provides the base classes

import tensorflow as tf
import tensorflow_probability as tfp
import random
import numpy as np
import time
import matplotlib.pylab as plt
from matplotlib.pyplot import plot, savefig, figure
from utils import count_trainable_variables
tfd = tfp.distributions


class BaseDataGenerator:
  def __init__(self, config):
    self.config = config

  # separate training and val sets
  def separate_train_and_val_set(self, n_win):
    n_train = int(np.floor((n_win * 0.9)))
    n_val = n_win - n_train
    idx_train = random.sample(range(n_win), n_train)
    idx_val = list(set(idx_train) ^ set(range(n_win)))
    return idx_train, idx_val, n_train, n_val


class BaseModel:
  def __init__(self, config):
    self.config = config
    # init the global step
    self.init_global_step()
    # init the epoch counter
    self.init_cur_epoch()
    self.two_pi = tf.constant(2 * np.pi)

  # save function that saves the checkpoint in the path defined in the config file
  def save(self, sess):
    print("Saving model...")
    self.saver.save(sess, self.config['checkpoint_dir'],
                    self.global_step_tensor)
    print("Model saved.")

  # load latest checkpoint from the experiment path defined in the config file
  def load(self, sess):
    print("checkpoint_dir at loading: {}".format(self.config['checkpoint_dir']))
    latest_checkpoint = tf.train.latest_checkpoint(self.config['checkpoint_dir'])

    if latest_checkpoint:
      print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
      self.saver.restore(sess, latest_checkpoint)
      print("Model loaded.")
    else:
      print("No model loaded.")

  # initialize a tensorflow variable to use it as epoch counter
  def init_cur_epoch(self):
    with tf.variable_scope('cur_epoch'):
      self.cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
      self.increment_cur_epoch_tensor = tf.assign(self.cur_epoch_tensor, self.cur_epoch_tensor + 1)

  # just initialize a tensorflow variable to use it as global step counter
  def init_global_step(self):
    # DON'T forget to add the global step tensor to the tensorflow trainer
    with tf.variable_scope('global_step'):
      self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step')
      self.increment_global_step_tensor = tf.assign(
        self.global_step_tensor, self.global_step_tensor + 1)

  def define_loss(self):
    with tf.name_scope("loss"):
      # KL divergence loss - analytical result
      KL_loss = 0.5 * (tf.reduce_sum(tf.square(self.code_mean), 1)
                       + tf.reduce_sum(tf.square(self.code_std_dev), 1)
                       - tf.reduce_sum(tf.log(tf.square(self.code_std_dev)), 1)
                       - self.config['code_size'])
      self.KL_loss = tf.reduce_mean(KL_loss)

      # norm 1 of standard deviation of the sample-wise encoder prediction
      self.std_dev_norm = tf.reduce_mean(self.code_std_dev, axis=0)

      weighted_reconstruction_error_dataset = tf.reduce_sum(
        tf.square(self.original_signal - self.decoded), [1, 2])
      weighted_reconstruction_error_dataset = tf.reduce_mean(weighted_reconstruction_error_dataset)
      self.weighted_reconstruction_error_dataset = weighted_reconstruction_error_dataset / (2 * self.sigma2)

      # least squared reconstruction error
      ls_reconstruction_error = tf.reduce_sum(
        tf.square(self.original_signal - self.decoded), [1, 2])
      self.ls_reconstruction_error = tf.reduce_mean(ls_reconstruction_error)

      # sigma regularisor - input elbo
      self.sigma_regularisor_dataset = self.input_dims / 2 * tf.log(self.sigma2)
      two_pi = self.input_dims / 2 * tf.constant(2 * np.pi)

      self.elbo_loss = two_pi + self.sigma_regularisor_dataset + \
                       0.5 * self.weighted_reconstruction_error_dataset + self.KL_loss

  def training_variables(self):
    encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoder")
    decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "decoder")
    sigma_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "sigma2_dataset")
    self.train_vars_VAE = encoder_vars + decoder_vars + sigma_vars

    num_encoder = count_trainable_variables('encoder')
    num_decoder = count_trainable_variables('decoder')
    num_sigma2 = count_trainable_variables('sigma2_dataset')
    self.num_vars_total = num_decoder + num_encoder + num_sigma2
    print("Total number of trainable parameters in the VAE network is: {}".format(self.num_vars_total))

  def compute_gradients(self):
    self.lr = tf.placeholder(tf.float32, [])
    opt = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.95)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gvs_dataset = opt.compute_gradients(self.elbo_loss, var_list=self.train_vars_VAE)
    print('gvs for dataset: {}'.format(gvs_dataset))
    capped_gvs = [(self.ClipIfNotNone(grad), var) for grad, var in gvs_dataset]

    with tf.control_dependencies(update_ops):
      self.train_step_gradient = opt.apply_gradients(capped_gvs)
    print("Reach the definition of loss for VAE")

  def ClipIfNotNone(self, grad):
    if grad is None:
      return grad
    return tf.clip_by_value(grad, -1, 1)

  def init_saver(self):
    self.saver = tf.train.Saver(max_to_keep=1, var_list=self.train_vars_VAE)


class BaseTrain:
  def __init__(self, sess, model, data, config):
    self.model = model
    self.config = config
    self.sess = sess
    self.data = data
    self.init = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
    self.sess.run(self.init)

    # keep a record of the training result
    self.train_loss = []
    self.val_loss = []
    self.train_loss_ave_epoch = []
    self.val_loss_ave_epoch = []
    self.recons_loss_train = []
    self.recons_loss_val = []
    self.KL_loss_train = []
    self.KL_loss_val = []
    self.sample_std_dev_train = []
    self.sample_std_dev_val = []
    self.iter_epochs_list = []
    self.test_sigma2 = []

  def train(self):
    self.start_time = time.time()
    for cur_epoch in range(0, self.config['num_epochs_vae'], 1):
      self.train_epoch()

      # compute current execution time
      self.current_time = time.time()
      elapsed_time = (self.current_time - self.start_time) / 60
      est_remaining_time = (
                                   self.current_time - self.start_time) / (cur_epoch + 1) * (
                                     self.config['num_epochs_vae'] - cur_epoch - 1)
      est_remaining_time = est_remaining_time / 60
      print("Already trained for {} min; Remaining {} min.".format(elapsed_time, est_remaining_time))
      self.sess.run(self.model.increment_cur_epoch_tensor)

  def save_variables_VAE(self):
    # save some variables for later inspection
    file_name = "{}{}-batch-{}-epoch-{}-code-{}-lr-{}.npz".format(self.config['result_dir'],
                                                                  self.config['exp_name'],
                                                                  self.config['batch_size'],
                                                                  self.config['num_epochs_vae'],
                                                                  self.config['code_size'],
                                                                  self.config['learning_rate_vae'])
    np.savez(file_name,
             iter_list_val=self.iter_epochs_list,
             train_loss=self.train_loss,
             val_loss=self.val_loss,
             n_train_iter=self.n_train_iter,
             n_val_iter=self.n_val_iter,
             recons_loss_train=self.recons_loss_train,
             recons_loss_val=self.recons_loss_val,
             KL_loss_train=self.KL_loss_train,
             KL_loss_val=self.KL_loss_val,
             num_para_all=self.model.num_vars_total,
             sigma2=self.test_sigma2)

  def plot_train_and_val_loss(self):
    # plot the training and validation loss over epochs
    plt.clf()
    figure(num=1, figsize=(8, 6))
    plot(self.train_loss, 'b-')
    plot(self.iter_epochs_list, self.val_loss_ave_epoch, 'r-')
    plt.legend(('training loss (total)', 'validation loss'))
    plt.title('training loss over iterations (val @ epochs)')
    plt.ylabel('total loss')
    plt.xlabel('iterations')
    plt.grid(True)
    savefig(self.config['result_dir'] + '/loss.png')

    # plot individual components of validation loss over epochs
    plt.clf()
    figure(num=1, figsize=(8, 6))
    plot(self.recons_loss_val, 'b-')
    plot(self.KL_loss_val, 'r-')
    plt.legend(('Reconstruction loss', 'KL loss'))
    plt.title('validation loss breakdown')
    plt.ylabel('loss')
    plt.xlabel('num of batch')
    plt.grid(True)
    savefig(self.config['result_dir'] + '/val-loss.png')

    # plot the learned sigma2 over training
    # (the post is truncated here; the call is completed minimally, as an assumption)
    plt.clf()
    figure(num=1, figsize=(8, 6))
    plot(self.test_sigma2, 'b-')
    plt.grid(True)
    savefig(self.config['result_dir'] + '/sigma2.png')
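For reference, the loss assembled in define_loss above corresponds to the following terms, where $\mu$ and $\sigma$ are the encoder outputs code_mean and code_std_dev, $d$ is code_size, $D$ is input_dims, $\sigma_{\mathrm{out}}^{2}$ is the learnable output variance sigma2, and all terms are averaged over the batch:

$$
\mathcal{L}_{\mathrm{KL}} = \frac{1}{2}\left(\sum_{j=1}^{d}\mu_j^{2} + \sum_{j=1}^{d}\sigma_j^{2} - \sum_{j=1}^{d}\log\sigma_j^{2} - d\right),
\qquad
\mathcal{L}_{\mathrm{rec}} = \frac{\lVert x - \hat{x}\rVert_2^{2}}{2\,\sigma_{\mathrm{out}}^{2}},
$$

$$
\mathcal{L}_{\mathrm{ELBO}} = \frac{D}{2}\log\sigma_{\mathrm{out}}^{2} + \frac{1}{2}\,\mathcal{L}_{\mathrm{rec}} + \mathcal{L}_{\mathrm{KL}} + \text{const},
$$

i.e. the negative evidence lower bound of a Gaussian VAE with a learnable output variance; the constant term does not affect the gradients.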