当前位置:   article > 正文

【keras原理解析】Keras神经网络运行源码深入解析_for batch_index, batch_con in enumerate

for batch_index, batch_con in enumerate

model.fit(X_train,y_train,batch_size=BATCH_SIZE,nb_epoch=1,validation_data=(X_val,y_val))

以上是keras进行model训练的fit代码,它真正的实现流程是怎样的呢?

以上最终调用的是training.Model.fit()方法,在fit方法主要进行步骤如下:

  1. 模型参数的处理,验证数据的合法性相关的准备工作
  2. 准备好模型的输入数据和训练相关的函数

以上准备工作最好后,将后续的工作delegate委托给training_arrays.fit_loop()方法,撇开数据的处理、准备,训练的主要代码是这段循环,非常关键:

  1. callbacks.set_model(callback_model)
  2. callbacks.set_params({
  3. 'batch_size': batch_size,
  4. 'epochs': epochs,
  5. 'steps': steps_per_epoch,
  6. 'samples': num_train_samples,
  7. 'verbose': verbose,
  8. 'do_validation': do_validation,
  9. 'metrics': callback_metrics or [],
  10. })
  11. callbacks.on_train_begin()
  12. for epoch in range(initial_epoch, nb_epoch):
  13. # 记录本回epoch的历史信息
  14. callbacks.on_epoch_begin(epoch)
  15. # 按照batch批次打混索引
  16. if shuffle == 'batch':
  17. index_array = batch_shuffle(index_array, batch_size)
  18. elif shuffle:
  19. np.random.shuffle(index_array)
  20. # 得到一个批次的索引
  21. batches = make_batches(nb_train_sample, batch_size)
  22. epoch_logs = {}
  23. #........
  24. #省略逻辑见下 部分
  25. #........
  26. callbacks.on_epoch_end(epoch, epoch_logs)
  27. if callback_model.stop_training:
  28. break
  29. callbacks.on_train_end()

以上{for epoch in }代码逻辑主要是对每个epoch进行循环,其中核心针对每个batch的处理见下代码:=

  1. for batch_index, (batch_start, batch_end) in enumerate(batches):
  2. batch_ids = index_array[batch_start:batch_end]
  3. try:
  4. if isinstance(ins[-1], float):
  5. # Do not slice the training phase flag.
  6. ins_batch = slice_arrays(
  7. ins[:-1], batch_ids) + [ins[-1]]
  8. else:
  9. ins_batch = slice_arrays(ins, batch_ids)
  10. except TypeError:
  11. raise TypeError('TypeError while preparing batch. '
  12. 'If using HDF5 input data, '
  13. 'pass shuffle="batch".')
  14. batch_logs = {}
  15. batch_logs['batch'] = batch_index
  16. batch_logs['size'] = len(batch_ids)
  17. #回调:每个batch的开始处:logs包含size,即当前batch的样本数
  18. callbacks.on_batch_begin(batch_index, batch_logs)
  19. for i in indices_for_conversion_to_dense:
  20. ins_batch[i] = ins_batch[i].toarray()
  21. outs = f(ins_batch)
  22. outs = to_list(outs)
  23. for l, o in zip(out_labels, outs):
  24. batch_logs[l] = o
  25. #回调:batch结束:logs包含loss,若启用accuracy则还包含acc
  26. callbacks.on_batch_end(batch_index, batch_logs)
  27. if callback_model.stop_training:
  28. break
  29. if batch_index == len(batches) - 1: # Last batch.
  30. if do_validation:
  31. val_outs = test_loop(model, val_f, val_ins,
  32. batch_size=batch_size,
  33. verbose=0)
  34. val_outs = to_list(val_outs)
  35. # Same labels assumed.
  36. for l, o in zip(out_labels, val_outs):
  37. epoch_logs['val_' + l] = o

【1、回调函数callback】

以上就是整个fit_loop()函数的调用代码,其中代码关键点都存在回调函数:

  1. on_epoch_begin: 在每个epoch开始时调用
  2. on_epoch_end: 在每个epoch结束时调用
  3. on_batch_begin: 在每个batch开始时调用
  4. on_batch_end: 在每个batch结束时调用
  5. on_train_begin: 在训练开始时调用
  6. on_train_end: 在训练结束时调用

其中回调函数on_batch_end是主要的回调函数,e.gkeras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

  1. def on_batch_end(self, batch, logs=None):
  2. logs = logs or {}
  3. batch_size = logs.get('size', 0)
  4. self.seen += batch_size
  5. for k, v in logs.items():
  6. if k in self.stateful_metrics:
  7. self.totals[k] = v
  8. else:
  9. if k in self.totals:
  10. self.totals[k] += v * batch_size
  11. else:
  12. self.totals[k] = v * batch_size

其中回调函数on_epoch_end,e.gkeras.callbacks.BaseLogger

这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

  1. def on_epoch_end(self, epoch, logs=None):
  2. if logs is not None:
  3. for k in self.params['metrics']:
  4. if k in self.totals:
  5. # Make value available to next callbacks.
  6. if k in self.stateful_metrics:
  7. logs[k] = self.totals[k]
  8. else:
  9. logs[k] = self.totals[k] / self.seen

补充:

keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

【2、outs = f(ins_batch)】

其中函数f()是作为参数传递进入,通过debug我们进行调试,发现直接是进入了Keras后端,进行处理,这样符合keras是基于tf进行的二次封装这前提,这而就是调用不同后端引擎的函数。

经过部分数据检验后,进入到tensorflow_backend.Function._call进行真正的tf操作,其中Function类就是提供众多Tensorflow中运算图的工具。

  1. def _call(self, inputs):
  2. if not isinstance(inputs, (list, tuple)):
  3. raise TypeError('`inputs` should be a list or tuple.')
  4. session = get_session()
  5. feed_arrays = []
  6. array_vals = []
  7. feed_symbols = []
  8. symbol_vals = []
  9. #数据处理转换
  10. for tensor, value in zip(self.inputs, inputs):
  11. if value is None:
  12. continue
  13. if is_tensor(value):
  14. # Case: feeding symbolic tensor.
  15. feed_symbols.append(tensor)
  16. symbol_vals.append(value)
  17. else:
  18. feed_arrays.append(tensor)
  19. # We need to do array conversion and type casting
  20. # at this level, since
  21. # `callable_fn` only supports exact matches.
  22. array_vals.append(
  23. np.asarray(value,
  24. dtype=tf.as_dtype(tensor.dtype).as_numpy_dtype))
  25. if self.feed_dict:
  26. for key in sorted(self.feed_dict.keys()):
  27. array_vals.append(
  28. np.asarray(self.feed_dict[key],
  29. dtype=tf.as_dtype(key.dtype).as_numpy_dtype))
  30. # Refresh callable if anything has changed.
  31. if (self._callable_fn is None or
  32. feed_arrays != self._feed_arrays or
  33. symbol_vals != self._symbol_vals or
  34. feed_symbols != self._feed_symbols or
  35. session != self._session):
  36. #生成一个可以调用的graph
  37. self._make_callable(feed_arrays,
  38. feed_symbols,
  39. symbol_vals,
  40. session)
  41. #运行graph
  42. if self.run_metadata:
  43. fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
  44. else:
  45. fetched = self._callable_fn(*array_vals)
  46. #返回结果
  47. return fetched[:len(self.outputs)]

总结:

1、Keras调用tf进行计算,是分batch进行操作,每个batch结束keras可以对返回进行相应的存储等操作。

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/133733
推荐阅读
相关标签
  

闽ICP备14008679号