tensorflow2.0保存和恢复模型3种方法

 更新时间:2020年02月03日 16:31:55   作者:李宜君  
今天小编就为大家分享一篇tensorflow2.0保存和恢复模型3种方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
(福利推荐:【腾讯云】服务器最新限时优惠活动,云服务器1核2G仅99元/年、2核4G仅768元/3年,立即抢购>>>:9i0i.cn/qcloud

(福利推荐:你还在原价购买阿里云服务器?现在阿里云0.8折限时抢购活动来啦!4核8G企业云服务器仅2998元/3年,立即抢购>>>:9i0i.cn/aliyun

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持程序员之家。

相关文章

  • python3 json数据格式的转换(dumps/loads的使用、dict to str/str to dict、json字符串/字典的相互转换)

    python3 json数据格式的转换(dumps/loads的使用、dict to str/str to dict、j

    JSON (JavaScript Object Notation) 是一种轻量级的数据交换格式。它基于ECMAScript的一个子集。这篇文章主要介绍了python3 json数据格式的转换(dumps/loads的使用、dict to str/str to dict、json字符串/字典的相互转换) ,需要的朋友可以参考下
    2019-04-04
  • python实现对服务器脚本敏感信息的加密解密功能

    python实现对服务器脚本敏感信息的加密解密功能

    这篇文章主要介绍了python实现对服务器脚本敏感信息的加密解密功能,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • 深入探索Python解码神器Chardet自动检测文本编码

    深入探索Python解码神器Chardet自动检测文本编码

    Chardet,洞察编码的清晰水晶球,一个让你与编码不再“失联”的神器,本文带大家走近这个隐藏在Python工具箱中的小宝贝,探索它的秘密
    2024-01-01
  • 简单谈谈python的反射机制

    简单谈谈python的反射机制

    本文主要介绍python中的反射,以及该机制的简单应用,熟悉JAVA的程序员,一定经常和Class.forName打交道。在很多框架中(Spring,eclipse plugin机制)都依赖于JAVA的反射能力,而在python中,也同样有着强大的反射能力,本文将做简单的介绍
    2016-06-06
  • 16行Python代码实现微信聊天机器人并自动智能回复功能

    16行Python代码实现微信聊天机器人并自动智能回复功能

    聊天机器人自动智能回复给我们的生活带来了极大的便利,尤其在业务比较繁忙的时候,智能机器人给我们带来极大的方便,今天小编教大家一招通过16行代码实现微信聊天智能机器人,感兴趣的朋友一起看看吧
    2022-01-01
  • 对pandas处理json数据的方法详解

    对pandas处理json数据的方法详解

    今天小编就为大家分享一篇对pandas处理json数据的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python抽象和自定义类定义与用法示例

    Python抽象和自定义类定义与用法示例

    这篇文章主要介绍了Python抽象和自定义类定义与用法,结合实例形式分析了Python抽象方法、抽象类相关功能、定义、用法及相关操作注意事项,需要的朋友可以参考下
    2018-08-08
  • Python3实现的反转单链表算法示例

    Python3实现的反转单链表算法示例

    这篇文章主要介绍了Python3实现的反转单链表算法,结合实例形式总结分析了Python基于迭代算法与递归算法实现的翻转单链表相关操作技巧,需要的朋友可以参考下
    2019-03-03
  • Python实现随机创建电话号码的方法示例

    Python实现随机创建电话号码的方法示例

    这篇文章主要介绍了Python实现随机创建电话号码的方法,涉及Python随机数运算相关操作技巧,需要的朋友可以参考下
    2018-12-12
  • Python实现计算文件夹下.h和.cpp文件的总行数

    Python实现计算文件夹下.h和.cpp文件的总行数

    这篇文章主要介绍了Python实现计算文件夹下.h和.cpp文件的总行数,本文直接给出实现代码,需要的朋友可以参考下
    2015-04-04

最新评论

?


http://www.vxiaotou.com