这篇文章记录了keras中常用的四种模型保存和加载,包括h5(常用)、json和yaml、pb(这种模式在opencv中可以直接读取)。
keras可以保存的模型种类
- h5:此格式需要python提前安装h5py库,
sudo pip install h5py
- json
- yaml
- pb
几种文件格式的区别
h5
我们可以将网络结构和权重一起保存为.h5文件,另外还包括模型训练时的配置(包括损失函数、优化器等),以及优化器的状态,比如优化器中变化的参数的状态等,优点是:操作非常方便,代码很少,缺点是:占用空间较大
model.save('my_model.h5') #用于保存模型
model = load_model('my_model.h5') #用于加载模型
当然,我们也可以单独将网络的权重保存为.h5文件,并不保存结构。
#只保存了网络的权重
model.save_weights('my_model_weights.h5')
#将网络初始化
...
#加载权重到搭建的网络
model.load_weights('my_model_weights.h5')
json和yaml
一般只保存网络的结构,并不保存权重,配合save_weights
和load_weights
使用
import json
import yaml
json_model = model.to_json() #将网络结构保存为json字符串
open('model.json', 'w').write(json_model) #写入json文件
# yaml_model = model.to_yaml() #将网络结构保存为yaml字符串
# open('model.yaml', 'w').write(yaml_model) #写入yaml文件
model.save_weights('my_model_weights.h5') #单独保存权重参数
from keras.models import model_from_json,model_from_yaml
#从文件中加载网络结构
json_model = model_from_json(open('model.json').read())
# yaml_model = model_from_yaml(open('model.yaml').read())
#单独加载权重参数
model.load_weights('my_model_weights.h5')
pb
pb格式的文件里是保存了网络的结构和权重的,方便tensorflow和opencv调用。
#1、定义网络结构
model = ...
#2、输出节点名称,用于查看output_node_names
for layer in model.layers:
print(layer.output.name)
#3、生成pb格式的模型
sess = K.get_session() #获得当前的sess
#这个是将模型的结构和权重固化,注意第二步输出的节点名称形式为dense_1/Softmax:0,但是这里传入时需要将:0删除,不然会提示找不到tensor,不太清楚原因,之前用tensorflow时是需要写全的
frozen_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph.as_graph_def(),output_node_names=['dense_1/Softmax'])
tf.train.write_graph(frozen_def, './', 'test_model.pb', as_text=False) #将网络的结构和权重写进pb文件中
#4、在tensorflow中使用pb模型
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open('test_model.pb', "rb") as f: #打开pb模型文件
output_graph_def.ParseFromString(f.read()) #进行解析
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess: #下面就是使用模型
sess.run(tf.global_variables_initializer())
#下面的input_1:0和dense_1/Softmax:0都是通过步骤2查看的
#然而这里节点名称却带了:0,我猜测是keras的graph中保存的节点名是不包含:0的,当你选择保存为pb格式时,此时注意我们是使用tensorflow的形式进行保存的,所以又都被加上了:0,当你读取的时候自然也要加上:0
input_x = sess.graph.get_tensor_by_name("input_1:0")
output = sess.graph.get_tensor_by_name("dense_1/Softmax:0")
print(sess.run(output, feed_dict={input_x: img}))