keras中模型保存的4种格式

这篇文章记录了keras中常用的四种模型保存和加载,包括h5(常用)、json和yaml、pb(这种模式在opencv中可以直接读取)。

keras可以保存的模型种类

  • h5:此格式需要python提前安装h5py库,sudo pip install h5py
  • json
  • yaml
  • pb

几种文件格式的区别

h5

我们可以将网络结构和权重一起保存为.h5文件,另外还包括模型训练时的配置(包括损失函数、优化器等),以及优化器的状态,比如优化器中变化的参数的状态等,优点是:操作非常方便,代码很少,缺点是:占用空间较大

model.save('my_model.h5') #用于保存模型

model = load_model('my_model.h5') #用于加载模型

当然,我们也可以单独将网络的权重保存为.h5文件,并不保存结构。

#只保存了网络的权重
model.save_weights('my_model_weights.h5') 

#将网络初始化
...

#加载权重到搭建的网络
model.load_weights('my_model_weights.h5')

json和yaml

一般只保存网络的结构,并不保存权重,配合save_weightsload_weights使用

import json
import yaml

json_model = model.to_json() #将网络结构保存为json字符串
open('model.json', 'w').write(json_model) #写入json文件

# yaml_model = model.to_yaml() #将网络结构保存为yaml字符串
# open('model.yaml', 'w').write(yaml_model) #写入yaml文件

model.save_weights('my_model_weights.h5') #单独保存权重参数
from keras.models import model_from_json,model_from_yaml
#从文件中加载网络结构
json_model = model_from_json(open('model.json').read()) 

# yaml_model = model_from_yaml(open('model.yaml').read())

#单独加载权重参数
model.load_weights('my_model_weights.h5') 

pb

pb格式的文件里是保存了网络的结构和权重的,方便tensorflow和opencv调用。

#1、定义网络结构
model = ...

#2、输出节点名称,用于查看output_node_names
for layer in model.layers:
    print(layer.output.name)

#3、生成pb格式的模型
sess = K.get_session() #获得当前的sess

#这个是将模型的结构和权重固化,注意第二步输出的节点名称形式为dense_1/Softmax:0,但是这里传入时需要将:0删除,不然会提示找不到tensor,不太清楚原因,之前用tensorflow时是需要写全的
frozen_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph.as_graph_def(),output_node_names=['dense_1/Softmax']) 

tf.train.write_graph(frozen_def, './', 'test_model.pb', as_text=False) #将网络的结构和权重写进pb文件中


#4、在tensorflow中使用pb模型
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()

    with open('test_model.pb', "rb") as f: #打开pb模型文件
        output_graph_def.ParseFromString(f.read()) #进行解析
        _ = tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess: #下面就是使用模型
        sess.run(tf.global_variables_initializer()) 
        #下面的input_1:0和dense_1/Softmax:0都是通过步骤2查看的
        #然而这里节点名称却带了:0,我猜测是keras的graph中保存的节点名是不包含:0的,当你选择保存为pb格式时,此时注意我们是使用tensorflow的形式进行保存的,所以又都被加上了:0,当你读取的时候自然也要加上:0
        input_x = sess.graph.get_tensor_by_name("input_1:0")
        output = sess.graph.get_tensor_by_name("dense_1/Softmax:0")
        print(sess.run(output, feed_dict={input_x: img}))

发表评论

电子邮件地址不会被公开。