{"id":16285,"date":"2021-05-19T17:36:25","date_gmt":"2021-05-19T17:36:25","guid":{"rendered":"https:\/\/www.askpython.com\/?p=16285"},"modified":"2021-05-19T17:36:26","modified_gmt":"2021-05-19T17:36:26","slug":"saving-loading-models-tensorflow","status":"publish","type":"post","link":"https:\/\/www.askpython.com\/python-modules\/saving-loading-models-tensorflow","title":{"rendered":"Saving and Loading Models Using TensorFlow 2.0+"},"content":{"rendered":"\n<p>In this article, we will discuss saving and loading models using TensorFlow 2.0+. This is a beginner-to-intermediate level article meant for people who have just started using TensorFlow for their deep learning projects.<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">Why do you need to save a model?<\/h2>\n\n\n\n<p>One of the most common mistakes beginners make in deep learning is not saving their models.<\/p>\n\n\n\n<p>Saving a deep learning model both during and after training is good practice. It saves time and makes your results reproducible. Here are a few more reasons to save a model:<\/p>\n\n\n\n<ul class=\"wp-block-list\"><li>Training modern deep learning models with millions of parameters on huge datasets is expensive in both computation and time. Moreover, different training runs can produce different results and accuracy. It is therefore a good idea to present your results using a saved model rather than retraining on the spot.<\/li><li>Saving different versions of the same model allows you to inspect and understand how the model behaves as training progresses.<\/li><li>You can reuse a saved model on other platforms and in other languages that support TensorFlow, e.g. TensorFlow Lite and TensorFlow.js, by converting the saved model with their converter tools instead of rewriting your model code.<\/li><\/ul>\n\n\n\n<p>TensorFlow offers several ways to save a model. 
We will discuss each of them in the following sections.<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">How to save a model during training?<\/h3>\n\n\n\n<p>It is often useful to save model weights during training. If your results show an anomaly after a certain epoch, checkpointing makes it easier to inspect earlier states of the model or even restore them.<\/p>\n\n\n\n<p>Keras models in TensorFlow are trained using the <code>Model.fit()<\/code> method. To save the model weights at regular intervals during training, we define a checkpoint callback using <code>tf.keras.callbacks.ModelCheckpoint()<\/code> and pass it to <code>Model.fit()<\/code>.<\/p>\n\n\n\n<p>Callbacks may sound complicated, but they are easy to use. Here is an example.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Setup code: not needed to understand saving and loading,\n# but required to run the code cells that follow\n\nimport tensorflow as tf\nfrom tensorflow import keras\n\n# Define a simple sequential model.\n# This model-building function is used throughout the article\n\ndef create_model():\n  model = tf.keras.models.Sequential(&#x5B;\n    keras.layers.Dense(512, activation=&#039;relu&#039;, input_shape=(784,)),\n    keras.layers.Dropout(0.2),\n    keras.layers.Dense(10)\n  ])\n\n  model.compile(optimizer=&#039;adam&#039;,\n                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n                metrics=&#x5B;tf.metrics.SparseCategoricalAccuracy()])\n\n  return model\n\n# Create a basic model instance\nmodel = create_model()\n\n# Load MNIST and keep only the first 1000 examples to speed up training\n\n(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n\ntrain_labels = train_labels&#x5B;:1000]\ntest_labels = test_labels&#x5B;:1000]\n\ntrain_images = train_images&#x5B;:1000].reshape(-1, 28 
* 28) \/ 255.0\ntest_images = test_images&#x5B;:1000].reshape(-1, 28 * 28) \/ 255.0\n<\/pre><\/div>\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport math\nimport os\n\n# Create a new model using the function\nmodel = create_model()\n\n# Specify the checkpoint file\n# str.format() fills in the epoch number to name each checkpoint file\ncheckpoint_path = &quot;training_2\/cp-{epoch:04d}.ckpt&quot;\n\n# Get the directory of the checkpoint\ncheckpoint_dir = os.path.dirname(checkpoint_path)\n\n# Define the batch size\nbatch_size = 32\n\n# save_freq is measured in batches, so compute the number of batches per epoch\nn_batches = math.ceil(len(train_images) \/ batch_size)\n\n# Create a callback that saves the model&#039;s weights every 5 epochs\ncp_callback = tf.keras.callbacks.ModelCheckpoint(\n    filepath=checkpoint_path,\n    verbose=1,\n    save_weights_only=True,\n    save_freq=5*n_batches)\n\n# Save the initial weights using the `checkpoint_path` format\nmodel.save_weights(checkpoint_path.format(epoch=0))\n\n# Train the model with the checkpoint callback\nmodel.fit(train_images, train_labels,\n          epochs=50,\n          batch_size=batch_size,\n          callbacks=&#x5B;cp_callback],\n          verbose=0)\n<\/pre><\/div>\n\n\n<h3 class=\"wp-block-heading\">Loading from a checkpoint <\/h3>\n\n\n\n<p>To restore a checkpoint you created, use the <code>Model.load_weights()<\/code> method on a model with the same architecture.<\/p>\n\n\n\n<p>Here is the syntax, followed by an example of loading the weights.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Syntax\n\nmodel.load_weights(&quot;&lt;path to the checkpoint (*.ckpt)&gt;&quot;)\n\n# Example\n\n# Find the latest checkpoint in the directory\nlatest = tf.train.latest_checkpoint(checkpoint_dir)\n\n# Create a new model\nmodel = create_model()\n\n# Load the weights of the latest checkpoint\nmodel.load_weights(latest)\n<\/pre><\/div>\n\n\n<h3 class=\"wp-block-heading\">Save the weights of a trained model<\/h3>\n\n\n\n<p>A model can also be saved 
after training. This is simpler than checkpointing during training.<\/p>\n\n\n\n<p>To save the weights after a model is trained, use the <code>Model.save_weights()<\/code> method. Here is an example:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Save the weights\nmodel.save_weights(&#039;.\/checkpoints\/my_checkpoint&#039;)\n\n# Create a new model instance with the same architecture\nmodel = create_model()\n\n# Restore the weights\nmodel.load_weights(&#039;.\/checkpoints\/my_checkpoint&#039;)\n<\/pre><\/div>\n\n\n<h3 class=\"wp-block-heading\">Load the weights of the trained model<\/h3>\n\n\n\n<p>To load saved weights into a model, use <code>Model.load_weights()<\/code>, just as when loading checkpoint weights. In fact, the weights are stored in the same checkpoint format.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Restore the weights into a model created with create_model()\nmodel.load_weights(&#039;.\/checkpoints\/my_checkpoint&#039;)\n<\/pre><\/div>\n\n\n<h3 class=\"wp-block-heading\">Saving and loading an entire model<\/h3>\n\n\n\n<p>In the previous sections, we saw how to save the weights of a model. This approach has a drawback: the model must be defined before its weights can be loaded, and any structural difference between the original model and the model you load the weights into can lead to errors.<\/p>\n\n\n\n<p>Moreover, saving only the weights makes it difficult to use models across different platforms. For example, you may want to use a model trained in Python in your browser with TensorFlow.js.<\/p>\n\n\n\n<p>In such cases, you need to save the whole model, i.e. the architecture along with the weights. TensorFlow lets you do this with <code>Model.save()<\/code>. 
Here is an example of doing so.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Save the whole model in the SavedModel format\n# (this creates a directory named my_model)\nmodel.save(&#039;my_model&#039;)\n<\/pre><\/div>\n\n\n<p>TensorFlow also lets you save the model in the HDF5 format. To do so, give the filename an <code>.h5<\/code> extension.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Save the model in the HDF5 format\n\n# The .h5 extension indicates that the model is to be saved in the HDF5 format.\nmodel.save(&#039;my_model.h5&#039;)\n<\/pre><\/div>\n\n\n<p><em>Note: HDF5 is the format Keras used before it became part of TensorFlow. TensorFlow uses the newer SavedModel format, which is the recommended choice.<\/em><\/p>\n\n\n\n<p>You can load these saved models using <code>tf.keras.models.load_model()<\/code>. The function automatically detects whether the model was saved in the SavedModel format or the HDF5 format. Here is an example:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# For both formats, pass the same path that was used when saving\n\n# SavedModel format (a directory)\nloaded_model = tf.keras.models.load_model(&#039;my_model&#039;)\n\n# HDF5 format (a single file)\nloaded_model = tf.keras.models.load_model(&#039;my_model.h5&#039;)\n<\/pre><\/div>\n\n\n<h2 class=\"wp-block-heading\">Conclusion<\/h2>\n\n\n\n<p>This brings us to the end of the tutorial. You should now be able to save and load models during and after training. Stay tuned to learn more about deep learning frameworks like PyTorch, TensorFlow and JAX.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>In this article, we will discuss saving and loading models using TensorFlow 2.0+. 
This is a beginner-intermediate level article meant for people who have just started out using TensorFlow for their deep learning projects. Why do you need to save a model? One of the very common mistakes people make as a beginner in deep [&hellip;]<\/p>\n","protected":false},"author":25,"featured_media":16330,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[2],"tags":[],"class_list":["post-16285","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-python-modules"],"blocksy_meta":[],"_links":{"self":[{"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/posts\/16285","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/users\/25"}],"replies":[{"embeddable":true,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/comments?post=16285"}],"version-history":[{"count":0,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/posts\/16285\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/media\/16330"}],"wp:attachment":[{"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/media?parent=16285"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/categories?post=16285"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.askpython.com\/wp-json\/wp\/v2\/tags?post=16285"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}