赞
踩
keras中文官方文档
keras是对tensorflow的high-level封装,之前一直都用pytorch,但是pytorch训练出来的模型不太好部署应用,最近需要使用keras训练模型,因此简单了解了一些keras的训练推断使用方法。总体而言,keras上手比较简单,搭建网络图纸计算也很方便,给的API接口对训练调参很友好(简直就是调参器:),本文写了一部分训练流程,随着我学习的深入会不断更新本教程。
只要安装了keras-gpu版本,自动调用GPU进行训练,默认只会单卡训练,
一定要注意,使用pip或keras安装
conda install keras(-gpu)
,不带gpu只会装cpu版本!!
os.environ["CUDA_VISIBLE_DEVICES"]="xx"
指定可见显卡
tf会申请资源,但是不一定全都是利用的,可以显式地控制显存资源的用量
- import tensorflow as tf
- from keras.backend.tensorflow_backend import set_session
- config= tf.ConfigProto()
- config.gpu_options.per_process_gpu_memory_fraction = set_session(tf.Session(config=config))
多卡训练:
只需要在外层使用multi_gpu_model() wrap一下并返回就OK了,默认会调用所有的GPU资源
- from keras.utils import multi_gpu_model
- multi_train_model = multi_gpu_model(model, gpus=, cpu_relocation=False)#对原来的model外层包一下并返回
该方法的并行方式也是将模型copy到多个GPU上面,训练时将batch_size平均分配到多个GPU上面,最后再reduce到CPU或者GPU上处理后更新多个GPU的权重参数 注意:如果batch_size=1,原则上是无法进行多卡并行的,但是keras会在可见GPUs上面开辟空间,但是只会在一张卡上训练,因此会看到只有一张卡的utils在活动,其余卡只会占用显存而不具体工作
- from keras import Input, Model#搭建模型输入库,模型库
- from keras.applications.vgg16/resnet50 import VGG16/ResNet50#网络模型库
- from keras.layers import Concatenate, Conv2D, upSampling2D, BatchNormalization#网络层
- from keras.utils import multi_gpu_model#多GPU训练
- from keras.optimizers import Adam, SGD#优化器
- from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard#训练时后端callbacks
- import tensorflow as tf
- from keras import Input, Model
- from keras.applications.vgg import VGG16
- class Network:
- def __init__(self, *argv, **kwargs):
- self._input = Input(name',#定义一个Input的实例
- shape=(),
- dtype='float')
- vgg = VGG16(input_tensor=self.input_img#送入输入的input_tensor
- weights='imagenet',
- dtype='float32')
- def build_graph(self):
- Conv2D(**kwargs)(vgg)...#此处定义自己的网络层
- return Model(inputs=self._input, outputs=)#
- net_instance = Network(...)
- model = net_instance.build_graph()
在该过程中需要指定loss函数和优化器,compile完成后相当于固定网络的计算图纸
model.compile(loss=, optimizer=)
model.summary()#打印网络信息,包括网络层数,中间层尺寸变化,以及参数量等信息
generator=送入需要训练的数据,需要自己定义数据生成的方法
- model.fit_generator(generator=data_gen_func(),
- steps_per_epoch=,#每个epoch迭代多少次,注意 step * batchsize = total_nums
- epochs=,#共训练多少个Epoch
- validation_data=,
- verbose=,
- initial_epoch=,#开始训练的epoch
- callbacks= [ModelCheckpoint(filepath=,#存储文件路径,str最后预留剩余%s可以在命名时保存训练的一些loss
- save_best_only=,#之存储
- verbose=),
- TensorBoard(log_dir=,
- histogram_freq=0,
- write_graph=True,
- writer_images=False)])
- model.save(filepath=)#保存网络对象
- model.save_weights(filepath=)#只保存网络权重
介绍:
callbacks: 训练时后台的一些调用服务,常用的有ModelCheckpoint和TensorBoard..
Model.predict(x)#x为输入测试数据
pred = model.predict(x, batch_size=None, steps=None)
x:输入数据,Numpy array
输出即为网络output
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。