```python
model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
          validation_data=(X_val, y_val))
```
The above is a typical Keras `fit` call for training a model. What does its actual execution flow look like?

The call ultimately reaches the `training.Model.fit()` method. Once the preparatory work in `fit()` is done, the remaining work is delegated to `training_arrays.fit_loop()`. Setting aside data handling and preparation, the core of training is the following loop, which is the key piece:
```python
callbacks.set_model(callback_model)
callbacks.set_params({
    'batch_size': batch_size,
    'epochs': epochs,
    'steps': steps_per_epoch,
    'samples': num_train_samples,
    'verbose': verbose,
    'do_validation': do_validation,
    'metrics': callback_metrics or [],
})
callbacks.on_train_begin()
for epoch in range(initial_epoch, nb_epoch):
    # Record this epoch's history information.
    callbacks.on_epoch_begin(epoch)
    # Shuffle the index array, batch-wise if requested.
    if shuffle == 'batch':
        index_array = batch_shuffle(index_array, batch_size)
    elif shuffle:
        np.random.shuffle(index_array)
    # Compute the (start, end) index pairs for each batch.
    batches = make_batches(nb_train_sample, batch_size)
    epoch_logs = {}
    # ........
    # (per-batch logic omitted here; shown below)
    # ........
    callbacks.on_epoch_end(epoch, epoch_logs)
    if callback_model.stop_training:
        break

callbacks.on_train_end()
```
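When `shuffle == 'batch'`, whole batches are shuffled relative to one another while the indices inside each batch stay contiguous, which is what HDF5-backed inputs need (they can only be sliced with sorted, contiguous indices). A minimal sketch of that behavior, simplified from the Keras helper of the same name:

```python
import numpy as np

def batch_shuffle(index_array, batch_size):
    """Shuffle whole batches of indices; indices inside a batch stay together."""
    batch_count = len(index_array) // batch_size
    # The trailing partial batch is left in place at the end.
    last_batch = index_array[batch_count * batch_size:]
    index_array = index_array[:batch_count * batch_size]
    index_array = index_array.reshape((batch_count, batch_size))
    np.random.shuffle(index_array)  # shuffles along the first axis only
    index_array = index_array.flatten()
    return np.append(index_array, last_batch)

shuffled = batch_shuffle(np.arange(10), batch_size=3)
```

Every full batch in `shuffled` is still a run of consecutive indices; only the order of the batches changes.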
The `for epoch in ...` loop above iterates over epochs. The core per-batch processing elided from it is the following code:
```python
for batch_index, (batch_start, batch_end) in enumerate(batches):
    batch_ids = index_array[batch_start:batch_end]
    try:
        if isinstance(ins[-1], float):
            # Do not slice the training phase flag.
            ins_batch = slice_arrays(
                ins[:-1], batch_ids) + [ins[-1]]
        else:
            ins_batch = slice_arrays(ins, batch_ids)
    except TypeError:
        raise TypeError('TypeError while preparing batch. '
                        'If using HDF5 input data, '
                        'pass shuffle="batch".')
    batch_logs = {}
    batch_logs['batch'] = batch_index
    batch_logs['size'] = len(batch_ids)

    # Callback at the start of each batch: logs contains 'size',
    # the number of samples in the current batch.
    callbacks.on_batch_begin(batch_index, batch_logs)
    for i in indices_for_conversion_to_dense:
        ins_batch[i] = ins_batch[i].toarray()

    outs = f(ins_batch)
    outs = to_list(outs)
    for l, o in zip(out_labels, outs):
        batch_logs[l] = o
    # Callback at the end of each batch: logs contains 'loss',
    # plus 'acc' if accuracy is enabled.
    callbacks.on_batch_end(batch_index, batch_logs)
    if callback_model.stop_training:
        break

    if batch_index == len(batches) - 1:  # Last batch.
        if do_validation:
            val_outs = test_loop(model, val_f, val_ins,
                                 batch_size=batch_size,
                                 verbose=0)
            val_outs = to_list(val_outs)
            # Same labels assumed.
            for l, o in zip(out_labels, val_outs):
                epoch_logs['val_' + l] = o
```
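The two helpers used above are simple. Here is a sketch of what `make_batches` and `slice_arrays` do; these are simplified reimplementations for illustration, not the Keras originals:

```python
import numpy as np

def make_batches(size, batch_size):
    """Return the (start, end) index pairs covering `size` samples."""
    num_batches = (size + batch_size - 1) // batch_size  # ceiling division
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(num_batches)]

def slice_arrays(arrays, indices):
    """Pick the rows given by `indices` out of every array in the list."""
    return [None if x is None else x[indices] for x in arrays]

batches = make_batches(10, batch_size=3)
# → [(0, 3), (3, 6), (6, 9), (9, 10)]

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
X_b, y_b = slice_arrays([X, y], np.array([0, 2, 4]))
```

The last batch may be smaller than `batch_size`, which is why the batch loop reads the actual size from `len(batch_ids)` rather than assuming `batch_size`.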
【1. Callbacks】
The above is the whole of the `fit_loop()` call; at every key point it fires a callback.
`on_batch_end` is the main callback. For example, `keras.callbacks.BaseLogger` accumulates the batch's loss and acc values into `totals`, each weighted by `batch_size`:
```python
def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    self.seen += batch_size

    for k, v in logs.items():
        if k in self.stateful_metrics:
            self.totals[k] = v
        else:
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size
```
The matching `on_epoch_end` hook, again e.g. in `keras.callbacks.BaseLogger`, divides the accumulated totals by the number of samples seen, yielding the epoch's mean loss and acc:
```python
def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
        for k in self.params['metrics']:
            if k in self.totals:
                # Make value available to next callbacks.
                if k in self.stateful_metrics:
                    logs[k] = self.totals[k]
                else:
                    logs[k] = self.totals[k] / self.seen
```
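Together, the two hooks compute a sample-weighted mean over the epoch. A self-contained simulation of that bookkeeping (plain Python, a stand-in rather than the actual `BaseLogger` class):

```python
class MeanLogger:
    """Minimal stand-in for BaseLogger's weighted-mean bookkeeping."""
    def __init__(self):
        self.totals = {}
        self.seen = 0

    def on_batch_end(self, logs):
        # Accumulate each metric weighted by the number of samples in the batch.
        batch_size = logs.get('size', 0)
        self.seen += batch_size
        for k, v in logs.items():
            if k != 'size':
                self.totals[k] = self.totals.get(k, 0.0) + v * batch_size

    def on_epoch_end(self):
        # Divide the weighted sums by the total number of samples seen.
        return {k: v / self.seen for k, v in self.totals.items()}

logger = MeanLogger()
# A full batch of 32 samples with loss 0.5, then a last batch of 16 with loss 0.2.
logger.on_batch_end({'size': 32, 'loss': 0.5})
logger.on_batch_end({'size': 16, 'loss': 0.2})
epoch_logs = logger.on_epoch_end()  # mean loss ≈ (0.5*32 + 0.2*16) / 48 = 0.4
```

The weighting matters because the last batch is usually smaller: an unweighted average of per-batch losses would over-count its samples.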
Additional notes:
keras.callbacks.ModelCheckpoint
saves the model to a file in `on_epoch_end`.
keras.callbacks.History
records the result of each training epoch, including the loss and acc values.
keras.callbacks.ProgbarLogger
outputs intermediate training-state information, mainly progress-related.
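How these hooks fit together can be sketched with a minimal callback container and a History-style recorder; these are simplified stand-ins for illustration, not the real Keras classes:

```python
class History:
    """Records every epoch's logs, in the spirit of keras.callbacks.History."""
    def on_train_begin(self):
        self.history = {}

    def on_epoch_end(self, epoch, logs):
        for k, v in (logs or {}).items():
            self.history.setdefault(k, []).append(v)

class CallbackList:
    """Dispatches each training event to every registered callback."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_train_begin(self):
        for cb in self.callbacks:
            cb.on_train_begin()

    def on_epoch_end(self, epoch, logs):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)

history = History()
callbacks = CallbackList([history])
callbacks.on_train_begin()
for epoch, loss in enumerate([0.5, 0.25]):
    # epoch_logs would normally be filled in by the batch loop / BaseLogger.
    callbacks.on_epoch_end(epoch, {'loss': loss})
# history.history → {'loss': [0.5, 0.25]}
```

This is why `fit()` can return a `History` object whose `history` dict holds one list per metric, with one entry per epoch.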
【2. outs = f(ins_batch)】
The function `f()` is passed in as a parameter. Stepping through with a debugger shows that the call goes straight into the Keras backend, consistent with Keras being a wrapper around TensorFlow: this is the point where the different backend engines are dispatched.
After some validation of the inputs, execution reaches `tensorflow_backend.Function._call`, where the real TensorFlow work happens; the `Function` class wraps the machinery for running a TensorFlow computation graph:
```python
def _call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
        raise TypeError('`inputs` should be a list or tuple.')

    session = get_session()
    feed_arrays = []
    array_vals = []
    feed_symbols = []
    symbol_vals = []
    # Convert the input data into feedable form.
    for tensor, value in zip(self.inputs, inputs):
        if value is None:
            continue
        if is_tensor(value):
            # Case: feeding symbolic tensor.
            feed_symbols.append(tensor)
            symbol_vals.append(value)
        else:
            feed_arrays.append(tensor)
            # We need to do array conversion and type casting
            # at this level, since
            # `callable_fn` only supports exact matches.
            array_vals.append(
                np.asarray(value,
                           dtype=tf.as_dtype(tensor.dtype).as_numpy_dtype))
    if self.feed_dict:
        for key in sorted(self.feed_dict.keys()):
            array_vals.append(
                np.asarray(self.feed_dict[key],
                           dtype=tf.as_dtype(key.dtype).as_numpy_dtype))

    # Refresh callable if anything has changed.
    if (self._callable_fn is None or
            feed_arrays != self._feed_arrays or
            symbol_vals != self._symbol_vals or
            feed_symbols != self._feed_symbols or
            session != self._session):
        # Build a callable that runs the graph.
        self._make_callable(feed_arrays,
                            feed_symbols,
                            symbol_vals,
                            session)
    # Run the graph.
    if self.run_metadata:
        fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
    else:
        fetched = self._callable_fn(*array_vals)
    # Return the results.
    return fetched[:len(self.outputs)]
```
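The key trick in `_call` is that the compiled callable is cached and only rebuilt when the feed signature or session changes. A stripped-down illustration of that caching pattern, with a counter standing in for the expensive `_make_callable` step (plain Python, not the real class):

```python
class CachedFunction:
    """Rebuilds its 'compiled' callable only when the feed signature changes."""
    def __init__(self, fn):
        self._fn = fn
        self._callable_fn = None
        self._feed_signature = None
        self.builds = 0  # counts how often we had to recompile

    def _make_callable(self, signature):
        # Stand-in for session._make_callable_from_options(...).
        self.builds += 1
        self._callable_fn = self._fn
        self._feed_signature = signature

    def __call__(self, feeds):
        signature = tuple(sorted(feeds))  # names of the tensors being fed
        # Refresh the callable only if anything has changed.
        if self._callable_fn is None or signature != self._feed_signature:
            self._make_callable(signature)
        return self._callable_fn(feeds)

f = CachedFunction(lambda feeds: sum(feeds.values()))
f({'x': 1, 'y': 2})   # first call: compiles
f({'x': 3, 'y': 4})   # same signature: cached callable reused
f({'x': 5})           # new signature: recompiled
# f.builds → 2
```

Since every training batch feeds the same set of placeholders, the graph callable is built once on the first batch and then reused for the rest of training, which keeps per-batch overhead low.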
Summary:
1. Keras drives the TensorFlow computation batch by batch; at the end of each batch, Keras can store or otherwise process the returned values (loss and metrics) before moving on.