赞
踩
参考论文:Anomaly Detection for Time Series Using VAE-LSTM Hybrid Model(ICASSP 2020,可在 IEEE Xplore 上自行检索)
代码来源:GitHub(论文作者开源的 VAE-LSTM-for-anomaly-detection 仓库)
运行环境:GPU(代码基于 TensorFlow 1.x 的 tf.Session、tf.placeholder 等接口编写,需在 TensorFlow 1.x 环境下运行)
VAE-LSTM原理图:
以下代码可根据自己的需求适当修改。
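上图原理可以简单理解为:首先由 VAE 的编码器网络将输入序列中的每个窗口压缩为低维特征(embedding);然后将连续窗口的 embedding 序列输入 LSTM 网络,由 LSTM 预测后续窗口的 embedding;再将预测得到的 embedding 输入 VAE 的解码器网络进行重构,并计算重构结果与原始信号之间的误差,以此作为异常检测的依据。需要注意的是,在下面的代码中 VAE 与 LSTM 是分两个阶段训练的:先训练 VAE(配置项 TRAIN_VAE),再利用训练好的 VAE 生成 embedding 来训练 LSTM(配置项 TRAIN_LSTM),并不是一次性更新整体网络参数。VAE 与 LSTM 二者结合,能够进一步提高模型的检测精度。(详细原理阐述可参见论文原文)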
相关程序由 6 个文件组成:1 个训练主程序(train.py)和 5 个支持子程序。(在加载数据阶段即划分好训练集、验证集与测试集。)
训练主程序如下:
train.py
"""Training entry point for the VAE-LSTM model.

Reads a JSON configuration (path taken from ``get_args().config``), then
optionally trains the VAE (``config['TRAIN_VAE']``) and the LSTM that works
on VAE embeddings (``config['TRAIN_LSTM']``), and finally plots LSTM
predictions for a few test sequences.
"""
import os
import sys

import tensorflow as tf

from data_loader import DataGenerator
from model import VAEmodel, lstmKerasModel
from trainer import vaeTrainer
from utils import process_config, create_dirs, get_args, save_config

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main():
    """Run the training pipeline described by the command-line config.

    Exits the process with status 1 if the arguments or the configuration
    file cannot be processed.
    """
    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        # Catch ordinary errors only: a bare ``except:`` also swallowed
        # SystemExit (e.g. from ``--help``) and KeyboardInterrupt. Exit with a
        # non-zero status so shells/schedulers can see that the run failed.
        print("missing or invalid arguments")
        print("{}: {}".format(type(err).__name__, err))
        sys.exit(1)

    # create the experiments dirs
    create_dirs([config['result_dir'], config['checkpoint_dir'],
                 config['checkpoint_dir_lstm']])
    # save the config in a txt file
    save_config(config)

    # create tensorflow session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    # create your data generator
    data = DataGenerator(config)
    # create the VAE model
    model_vae = VAEmodel(config)
    # create a trainer for VAE model
    trainer_vae = vaeTrainer(sess, model_vae, data, config)
    # Restore a saved VAE checkpoint, if any. This must come after the trainer
    # is built: assuming vaeTrainer derives from BaseTrain, its constructor
    # runs the global variable initialisers, which would overwrite weights
    # restored earlier.
    model_vae.load(sess)

    # here you train your model
    if config['TRAIN_VAE'] and config['num_epochs_vae'] > 0:
        trainer_vae.train()

    if config['TRAIN_LSTM']:
        # create a lstm model class instance
        lstm_model = lstmKerasModel(data)

        # produce the embedding of all sequences for training of lstm model
        # process the windows in sequence to get their VAE embeddings
        lstm_model.produce_embeddings(config, model_vae, data, sess)

        # Create a basic model instance
        lstm_nn_model = lstm_model.create_lstm_model(config)
        lstm_nn_model.summary()  # Display the model's architecture
        # checkpoint path; os.path.join yields a correct path whether or not
        # the configured directory ends with a path separator
        checkpoint_path = os.path.join(config['checkpoint_dir_lstm'], "cp.ckpt")
        # Create a callback that saves the model's weights
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                         save_weights_only=True,
                                                         verbose=1)
        # load weights if possible
        lstm_model.load_model(lstm_nn_model, config, checkpoint_path)

        # start training
        if config['num_epochs_lstm'] > 0:
            lstm_model.train(config, lstm_nn_model, cp_callback)

        # make a prediction on the test set using the trained model
        lstm_embedding = lstm_nn_model.predict(lstm_model.x_test,
                                               batch_size=config['batch_size_lstm'])
        print(lstm_embedding.shape)

        # visualise the first 10 test sequences; cap at the number of
        # predicted sequences so a small test set does not index past the end
        num_to_plot = min(10, len(lstm_embedding))
        for i in range(num_to_plot):
            lstm_model.plot_lstm_embedding_prediction(i, config, model_vae, sess, data, lstm_embedding)


if __name__ == '__main__':
    main()
五个支持程序如下:
base.py:定义基础类,包括数据生成器基类 BaseDataGenerator、模型基类 BaseModel 和训练器基类 BaseTrain。
import tensorflow as tf
import tensorflow_probability as tfp
import random
import numpy as np
import time
import matplotlib.pylab as plt
from matplotlib.pyplot import plot, savefig, figure
from utils import count_trainable_variables

# Shorthand for TensorFlow Probability distributions (not referenced in the
# code shown here; presumably used by code in other modules).
tfd = tfp.distributions


class BaseDataGenerator:
    """Base class for data generators; stores the experiment config dict."""

    def __init__(self, config):
        self.config = config

    # separate training and val sets
    def separate_train_and_val_set(self, n_win):
        """Randomly split window indices ``0 .. n_win-1`` into train and validation.

        ``floor(0.9 * n_win)`` indices go to training, the rest to validation.

        Args:
            n_win: number of windows to split.

        Returns:
            ``(idx_train, idx_val, n_train, n_val)``. ``idx_train`` is drawn
            with the module-level ``random`` generator (seed it for a
            reproducible split). ``idx_val`` is built from a set, so its order
            is not specified by Python.
        """
        n_train = int(np.floor((n_win * 0.9)))
        n_val = n_win - n_train
        idx_train = random.sample(range(n_win), n_train)
        # idx_train is a subset of range(n_win), so the symmetric difference
        # is exactly the set of indices not chosen for training.
        idx_val = list(set(idx_train) ^ set(range(n_win)))
        return idx_train, idx_val, n_train, n_val


class BaseModel:
    """Base class for the TF1 graph-mode VAE model.

    Provides global-step / epoch counter variables, checkpoint save/load, the
    ELBO loss, gradient clipping and the optimiser step. ``define_loss`` reads
    ``code_mean``, ``code_std_dev``, ``original_signal``, ``decoded``,
    ``sigma2`` and ``input_dims``, which are not set in this class —
    presumably a subclass defines them before calling it (TODO confirm).
    """

    def __init__(self, config):
        self.config = config
        # init the global step
        self.init_global_step()
        # init the epoch counter
        self.init_cur_epoch()
        # NOTE(review): not referenced by the code shown here (define_loss
        # builds its own constant); possibly used by a subclass.
        self.two_pi = tf.constant(2 * np.pi)

    # save function that saves the checkpoint in the path defined in the config file
    def save(self, sess):
        """Write a checkpoint of ``self.saver``'s variables to config['checkpoint_dir'].

        Requires ``self.saver`` (created by ``init_saver``). The value of
        ``global_step_tensor`` is passed as the checkpoint step number.
        NOTE(review): the saver's var_list is ``train_vars_VAE`` only, so the
        non-trainable global-step and epoch counter variables themselves are
        not saved or restored.
        """
        print("Saving model...")
        self.saver.save(sess, self.config['checkpoint_dir'],
                        self.global_step_tensor)
        print("Model saved.")

    # load latest checkpoint from the experiment path defined in the config file
    def load(self, sess):
        """Restore the newest checkpoint found in config['checkpoint_dir'], if any.

        Prints "No model loaded." and leaves the variables untouched when no
        checkpoint exists. Requires ``self.saver`` (see ``init_saver``).
        """
        print("checkpoint_dir at loading: {}".format(self.config['checkpoint_dir']))
        latest_checkpoint = tf.train.latest_checkpoint(self.config['checkpoint_dir'])

        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(sess, latest_checkpoint)
            print("Model loaded.")
        else:
            print("No model loaded.")

    # initialize a tensorflow variable to use it as epoch counter
    def init_cur_epoch(self):
        """Create the non-trainable epoch counter and the op that increments it by 1."""
        with tf.variable_scope('cur_epoch'):
            self.cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
            self.increment_cur_epoch_tensor = tf.assign(self.cur_epoch_tensor, self.cur_epoch_tensor + 1)

    # just initialize a tensorflow variable to use it as global step counter
    def init_global_step(self):
        """Create the non-trainable global-step counter and the op that increments it by 1."""
        # DON'T forget to add the global step tensor to the tensorflow trainer
        with tf.variable_scope('global_step'):
            self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step')
            self.increment_global_step_tensor = tf.assign(
                self.global_step_tensor, self.global_step_tensor + 1)

    def define_loss(self):
        """Build the VAE training objective ``elbo_loss`` and its components.

        Creates, under the "loss" name scope:
          * ``KL_loss``: batch mean of the analytic KL divergence between a
            diagonal Gaussian N(code_mean, code_std_dev**2) and N(0, I),
            summed over axis 1;
          * ``std_dev_norm``: mean of ``code_std_dev`` over axis 0;
          * ``weighted_reconstruction_error_dataset``: batch-mean squared
            error (summed over axes 1 and 2) divided by ``2 * sigma2``;
          * ``ls_reconstruction_error``: the same batch-mean squared error,
            unweighted;
          * ``sigma_regularisor_dataset``: ``input_dims / 2 * log(sigma2)``;
          * ``elbo_loss``: the total objective minimised by ``compute_gradients``.
        """
        with tf.name_scope("loss"):
            # KL divergence loss - analytical result
            KL_loss = 0.5 * (tf.reduce_sum(tf.square(self.code_mean), 1)
                             + tf.reduce_sum(tf.square(self.code_std_dev), 1)
                             - tf.reduce_sum(tf.log(tf.square(self.code_std_dev)), 1)
                             - self.config['code_size'])
            self.KL_loss = tf.reduce_mean(KL_loss)

            # norm 1 of standard deviation of the sample-wise encoder prediction
            # NOTE(review): this is a mean over axis 0 (presumably the batch
            # axis), not an L1 norm; the name and comment may be misleading.
            self.std_dev_norm = tf.reduce_mean(self.code_std_dev, axis=0)

            # Squared error summed over axes 1 and 2 (so original_signal and
            # decoded are at least rank 3; presumably (batch, window, channels),
            # TODO confirm), averaged over axis 0, then scaled by 1 / (2 * sigma2).
            weighted_reconstruction_error_dataset = tf.reduce_sum(
                tf.square(self.original_signal - self.decoded), [1, 2])
            weighted_reconstruction_error_dataset = tf.reduce_mean(weighted_reconstruction_error_dataset)
            self.weighted_reconstruction_error_dataset = weighted_reconstruction_error_dataset / (2 * self.sigma2)

            # least squared reconstruction error
            # NOTE(review): recomputes the same squared-error sum as above.
            ls_reconstruction_error = tf.reduce_sum(
                tf.square(self.original_signal - self.decoded), [1, 2])
            self.ls_reconstruction_error = tf.reduce_mean(ls_reconstruction_error)

            # sigma regularisor - input elbo
            self.sigma_regularisor_dataset = self.input_dims / 2 * tf.log(self.sigma2)
            # NOTE(review): the Gaussian negative log-likelihood constant is
            # (input_dims / 2) * log(2 * pi); this expression omits the log.
            # Because the term is constant, gradients are unaffected, but the
            # reported elbo_loss values are offset.
            two_pi = self.input_dims / 2 * tf.constant(2 * np.pi)

            # NOTE(review): weighted_reconstruction_error_dataset is already
            # divided by 2 * sigma2, so the extra 0.5 gives an effective weight
            # of 1 / (4 * sigma2). Confirm this is intended.
            self.elbo_loss = two_pi + self.sigma_regularisor_dataset + \
                             0.5 * self.weighted_reconstruction_error_dataset + self.KL_loss

    def training_variables(self):
        """Collect the VAE's trainable variables and print their parameter count.

        Sets ``train_vars_VAE`` to the trainable variables selected by the
        "encoder", "decoder" and "sigma2_dataset" scope filters. Sets
        ``num_vars_total`` to the sum of ``utils.count_trainable_variables``
        for the same three scopes.
        """
        encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoder")
        decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "decoder")
        sigma_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "sigma2_dataset")
        self.train_vars_VAE = encoder_vars + decoder_vars + sigma_vars

        num_encoder = count_trainable_variables('encoder')
        num_decoder = count_trainable_variables('decoder')
        num_sigma2 = count_trainable_variables('sigma2_dataset')
        self.num_vars_total = num_decoder + num_encoder + num_sigma2
        print("Total number of trainable parameters in the VAE network is: {}".format(self.num_vars_total))

    def compute_gradients(self):
        """Build the Adam training op ``train_step_gradient`` for ``elbo_loss``.

        Reads ``elbo_loss`` and ``train_vars_VAE``, so it must run after
        ``define_loss`` and ``training_variables``. The learning rate is fed
        at run time through the scalar placeholder ``self.lr``. Each gradient
        is clipped element-wise to [-1, 1].
        NOTE(review): ``apply_gradients`` is not given ``global_step``, so
        this op does not advance ``global_step_tensor``; confirm it is
        incremented elsewhere (e.g. via ``increment_global_step_tensor``).
        """
        self.lr = tf.placeholder(tf.float32, [])
        opt = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.95)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        gvs_dataset = opt.compute_gradients(self.elbo_loss, var_list=self.train_vars_VAE)
        print('gvs for dataset: {}'.format(gvs_dataset))
        capped_gvs = [(self.ClipIfNotNone(grad), var) for grad, var in gvs_dataset]

        # run any ops registered in UPDATE_OPS before applying the gradients
        with tf.control_dependencies(update_ops):
            self.train_step_gradient = opt.apply_gradients(capped_gvs)
        print("Reach the definition of loss for VAE")

    def ClipIfNotNone(self, grad):
        """Clip ``grad`` element-wise to [-1, 1]; pass ``None`` (no gradient) through unchanged."""
        if grad is None:
            return grad
        return tf.clip_by_value(grad, -1, 1)

    def init_saver(self):
        """Create ``self.saver`` over ``train_vars_VAE``, keeping only the latest checkpoint."""
        self.saver = tf.train.Saver(max_to_keep=1, var_list=self.train_vars_VAE)


class BaseTrain:
    """Base trainer: initialises graph variables and runs the VAE epoch loop.

    ``train_epoch`` and the attributes ``n_train_iter`` / ``n_val_iter`` used
    below are not defined in the code shown here; presumably a subclass
    provides them (TODO confirm).
    """

    def __init__(self, sess, model, data, config):
        self.model = model
        self.config = config
        self.sess = sess
        self.data = data
        # NOTE(review): this initialises *all* global and local variables, so
        # any weights restored into ``sess`` before the trainer is constructed
        # are overwritten. Restore checkpoints after creating the trainer.
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)

        # keep a record of the training result
        self.train_loss = []
        self.val_loss = []
        self.train_loss_ave_epoch = []
        self.val_loss_ave_epoch = []
        self.recons_loss_train = []
        self.recons_loss_val = []
        self.KL_loss_train = []
        self.KL_loss_val = []
        self.sample_std_dev_train = []
        self.sample_std_dev_val = []
        self.iter_epochs_list = []
        self.test_sigma2 = []

    def train(self):
        """Run ``config['num_epochs_vae']`` epochs, printing elapsed and estimated remaining time.

        After each epoch the model's epoch-counter variable is incremented.
        """
        self.start_time = time.time()
        for cur_epoch in range(0, self.config['num_epochs_vae'], 1):
            self.train_epoch()

            # compute current execution time
            self.current_time = time.time()
            elapsed_time = (self.current_time - self.start_time) / 60
            # linear extrapolation: average seconds per finished epoch times
            # the number of epochs still to run, then converted to minutes
            est_remaining_time = (
                self.current_time - self.start_time) / (cur_epoch + 1) * (
                self.config['num_epochs_vae'] - cur_epoch - 1)
            est_remaining_time = est_remaining_time / 60
            print("Already trained for {} min; Remaining {} min.".format(elapsed_time, est_remaining_time))
            self.sess.run(self.model.increment_cur_epoch_tensor)

    def save_variables_VAE(self):
        """Save the recorded training curves and run settings to an .npz file in config['result_dir']."""
        # save some variables for later inspection
        file_name = "{}{}-batch-{}-epoch-{}-code-{}-lr-{}.npz".format(self.config['result_dir'],
                                                                      self.config['exp_name'],
                                                                      self.config['batch_size'],
                                                                      self.config['num_epochs_vae'],
                                                                      self.config['code_size'],
                                                                      self.config['learning_rate_vae'])
        np.savez(file_name,
                 iter_list_val=self.iter_epochs_list,
                 train_loss=self.train_loss,
                 val_loss=self.val_loss,
                 n_train_iter=self.n_train_iter,
                 n_val_iter=self.n_val_iter,
                 recons_loss_train=self.recons_loss_train,
                 recons_loss_val=self.recons_loss_val,
                 KL_loss_train=self.KL_loss_train,
                 KL_loss_val=self.KL_loss_val,
                 num_para_all=self.model.num_vars_total,
                 sigma2=self.test_sigma2)

    def plot_train_and_val_loss(self):
        """Plot the recorded loss curves and save them as PNGs in config['result_dir']."""
        # plot the training and validation loss over epochs
        plt.clf()
        figure(num=1, figsize=(8, 6))
        plot(self.train_loss, 'b-')
        plot(self.iter_epochs_list, self.val_loss_ave_epoch, 'r-')
        plt.legend(('training loss (total)', 'validation loss'))
        plt.title('training loss over iterations (val @ epochs)')
        plt.ylabel('total loss')
        plt.xlabel('iterations')
        plt.grid(True)
        savefig(self.config['result_dir'] + '/loss.png')

        # plot individual components of validation loss over epochs
        plt.clf()
        figure(num=1, figsize=(8, 6))
        plot(self.recons_loss_val, 'b-')
        plot(self.KL_loss_val, 'r-')
        plt.legend(('Reconstruction loss', 'KL loss'))
        plt.title('validation loss breakdown')
        plt.ylabel('loss')
        plt.xlabel('num of batch')
        plt.grid(True)
        savefig(self.config['result_dir'] + '/val-loss.png')

        # plot the values collected in self.test_sigma2
        plt.clf()
        figure(num=1, figsize=(8, 6))
        # NOTE(review): the source listing is truncated inside the next call;
        # the remainder of this method (and of the file) is not shown.
        plot(self.test_sigma2,
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。