Callbacks in Keras

Contents
1. Introduction
2. What are callbacks in Keras?
3. Types of callbacks
3.1 EarlyStopping
3.2 ModelCheckpoint
3.3 LearningRateScheduler
3.4 TensorBoard
3.5 Custom Callbacks
3.6 LambdaCallback

1. Introduction

Let's start by assuming that you are training a deep learning model built with the Keras library. A few things can go wrong: the model can train for a very long time without any significant improvement in its performance, it can start overfitting after a few epochs, or it might never learn anything and keep training without us noticing. These situations can be a nightmare in practice, so we want to keep an eye on our models while they train. Training without any such monitoring is like driving a car with no control over its speed or steering, which can easily lead to an accident. Keras addresses these issues with callbacks, which we'll discuss in this article. Let's get started.

2. What are callbacks in Keras?

In simple terms, callbacks are functions that give us some control over the model while it trains: for example, stopping training when a certain accuracy or loss is reached, adjusting the learning rate after some epochs, or saving the model after each epoch.

The following is the formal definition from the Keras documentation:

"A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view of the internal states and statistics of the model during training."

This is a brief, rough overview of what callbacks are. In the upcoming sections, we'll see how to use them with the Keras library.
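To make the idea of "stages" concrete, here is a minimal sketch, assuming Keras with a TensorFlow backend is installed. StageTracer is a hypothetical helper (not part of Keras) that subclasses keras.callbacks.Callback and simply records each hook that Keras calls; the tiny random dataset exists only so that fit has something to train on.

```python
import numpy as np
import keras


class StageTracer(keras.callbacks.Callback):
    """Hypothetical helper: records the name of each training stage as Keras reaches it."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        # logs holds this epoch's metrics, e.g. {"loss": ...}
        self.events.append(f"epoch_end {epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Toy regression data, for illustration only
x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

tracer = StageTracer()
model.fit(x, y, epochs=2, batch_size=16, verbose=0, callbacks=[tracer])
print(tracer.events)
```

The recorded list starts with train_begin, contains an epoch_begin/epoch_end pair for each epoch (numbered from 0), and ends with train_end.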

3. Types of Callbacks

Keras offers many built-in callbacks. In this section, we'll explore a few commonly used ones. Let's get started.

3.1 EarlyStopping

EarlyStopping is one of the most commonly used callbacks, because overfitting can be a nightmare. Imagine training a model for a whole day only to realize that it overfits. EarlyStopping stops training once a monitored quantity stops improving, so overfitting can be caught in its early stages.

Let's have a look at a few of its arguments:

  1. monitor - the quantity to watch when deciding whether to stop training, e.g. val_loss.
  2. min_delta - the minimum change in the monitored quantity that counts as an improvement. For example, with min_delta=0.1, an epoch in which the monitored value changes by less than 0.1 in absolute terms counts as no improvement.
  3. patience - the number of epochs with no improvement after which training is stopped.
  4. restore_best_weights - setting this to True restores the model weights from the epoch with the best value of the monitored quantity; with False (the default), the weights from the last epoch of training are kept.

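from keras.callbacks import EarlyStopping

# Stop once val_loss has not improved by at least min_delta for 3 epochs (illustrative values)
earlystop = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3,
                          restore_best_weights=True)

# Add the callback to your model (model, x_train and y_train are assumed to be defined);
# validation data is needed so that val_loss is available to monitor
model.fit(x_train, y_train, epochs=20, validation_split=0.2, callbacks=[earlystop])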


Notice that we pass the callback in a list, which means we can pass several callbacks of different kinds to monitor the progress of our model.
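Because min_delta and patience are easy to mix up, here is a plain-Python sketch of the stopping rule for a quantity that should decrease, such as val_loss. It is a simplified re-implementation written for illustration, not the Keras source: it ignores options such as baseline, and early_stopping_epoch is a hypothetical helper name.

```python
def early_stopping_epoch(val_losses, min_delta=0.0, patience=0):
    """Return (stopped_epoch, best_epoch) for a list of per-epoch val_loss values.

    An epoch only counts as an improvement if it beats the best value so far
    by more than min_delta; training stops once `patience` consecutive epochs
    fail to improve. If that never happens, the last epoch is returned.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


losses = [1.0, 0.8, 0.79, 0.795, 0.81, 0.7]
# With min_delta=0.05 the drop from 0.8 to 0.79 is too small to count,
# so training stops at epoch 3 and epoch 1 has the best value.
print(early_stopping_epoch(losses, min_delta=0.05, patience=2))  # (3, 1)
# With min_delta=0 that drop counts, so training runs one epoch longer.
print(early_stopping_epoch(losses, min_delta=0.0, patience=2))   # (4, 2)
```

Epochs are numbered from 0 here, like the epoch argument Keras passes to callbacks. With restore_best_weights=True, the model would end up with the weights from best_epoch; otherwise it keeps the weights from stopped_epoch.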

3.2 ModelCheckpoint

This callback saves the model (or just its weights), by default at the end of every epoch. It is useful when we want to keep the best model, for example the one at which val_accuracy reaches its maximum during training. It can also monitor other quantities, such as the training accuracy (accuracy) or val_loss.

Following are some arguments to know about before using this callback:

  1. filepath: the path of the file where the model is saved
  2. monitor: the quantity that is monitored
  3. save_best_only: set to True to save only when the monitored quantity improves, so the best model so far is not overwritten by a worse one
  4. mode: auto, min, or max. It tells the callback whether the monitored quantity should be minimized or maximized; for example, use mode='min' for val_loss and mode='max' for val_accuracy.

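from keras.callbacks import ModelCheckpoint

#autosave best model
# Recent Keras versions expect a .keras extension for full-model checkpoints
best_model_file = "best.keras"
best_model = ModelCheckpoint(best_model_file, monitor='val_accuracy',
                             mode='max', save_best_only=True, verbose=1)

# model (compiled with metrics=['accuracy']), x_train and y_train are assumed to be defined
model.fit(x_train, y_train, epochs=20, validation_split=0.2, callbacks=[best_model])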
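The filepath passed to ModelCheckpoint can also contain formatting placeholders such as {epoch:02d}, which Keras fills in at save time, so each epoch gets its own file. The following minimal sketch assumes a recent Keras version with a TensorFlow backend that supports the native .keras format; the toy model, random data, and ckpt_ file prefix are illustrative.

```python
import os
import tempfile

import numpy as np
import keras

# Toy binary-classification data (random, for illustration only)
x = np.random.rand(64, 4).astype("float32")
y = (x.sum(axis=1) > 2).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()
# {epoch:02d} is replaced by the 1-based epoch number, giving one file per epoch
checkpoint = keras.callbacks.ModelCheckpoint(os.path.join(ckpt_dir, "ckpt_{epoch:02d}.keras"))
model.fit(x, y, epochs=3, verbose=0, callbacks=[checkpoint])

saved = sorted(os.listdir(ckpt_dir))
print(saved)  # ['ckpt_01.keras', 'ckpt_02.keras', 'ckpt_03.keras']

# Any checkpoint can be reloaded as a complete model
restored = keras.models.load_model(os.path.join(ckpt_dir, saved[-1]))
```

Placeholders can also reference logged metrics, for example {val_loss:.2f}, as long as that key is present in the epoch's logs.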

3.3 LearningRateScheduler

This callback changes the learning rate over time according to a scheduling function that we define (the scheduler function in the code below). At the beginning of every epoch, the callback calls this function to get the updated learning rate and applies it to the optimizer. Depending on how we write the function, the learning rate can change every epoch or only after a certain number of epochs.

The following arguments need to be passed to this callback:

  1. schedule: a function that needs to be defined before creating the callback. It takes the current epoch number (starting from 0) and the current learning rate, and returns the new learning rate based on the conditions defined in the function.
  2. verbose: 0 is quiet; 1 prints a message whenever the callback updates the learning rate.
import math

from keras.callbacks import LearningRateScheduler

# This function keeps the initial learning rate for the first ten epochs
# and decreases it exponentially after that
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    # math.exp returns a plain Python float; the schedule must return a float
    return lr * math.exp(-0.1)

lrs = LearningRateScheduler(scheduler, verbose=0)  # scheduler is a function

# Add the callback while fitting the model (model, x_train and y_train are assumed to be defined)
model.fit(x_train, y_train, epochs=15, callbacks=[lrs], verbose=0)
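Because the schedule receives the optimizer's current learning rate, the multiplication by exp(-0.1) compounds from one epoch to the next. The following plain-Python sketch (no Keras needed) replays that feedback loop to show which rate each epoch would use; simulate_schedule is a hypothetical helper, and 0.01 is an assumed initial learning rate.

```python
import math


def scheduler(epoch, lr):
    # Same rule as above: constant for epochs 0-9, then multiply by exp(-0.1)
    if epoch < 10:
        return lr
    return lr * math.exp(-0.1)


def simulate_schedule(schedule, initial_lr, epochs):
    """Feed each epoch's output back in as the next epoch's input, mirroring how
    LearningRateScheduler reads the optimizer's current rate at every epoch start."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = schedule(epoch, lr)
        rates.append(lr)
    return rates


rates = simulate_schedule(scheduler, 0.01, 15)
for epoch, lr in enumerate(rates):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Epochs 0 to 9 keep 0.01; from epoch 10 on, the rate shrinks by a factor of exp(-0.1), about 0.905, every epoch, so by epoch 14 it is 0.01 * exp(-0.5), about 0.0061.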

3.4 TensorBoard

The TensorBoard callback writes logs for TensorFlow's TensorBoard visualization tool: loss and metric curves and, depending on the arguments, histograms, the model graph, and weight images. The logs are written to a directory, which TensorBoard then reads and displays.

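import keras
tbCallBack = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=1,  # illustrative values
                                         write_graph=True, write_images=True)
model.fit(..., callbacks=[tbCallBack])  # "..." stands for your usual fit arguments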

Following are some important arguments of this callback:

  1. histogram_freq: frequency (in epochs) at which to compute activation and weight histograms for the layers of the model. If set to 0, histograms won't be computed. Validation data (or a validation split) must be specified for histogram visualizations.
  2. write_graph: whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph is set to True.
  3. write_images: whether to write model weights as images so they can be visualized in TensorBoard.

To visualize the logs written during training, run the following command in your terminal:

tensorboard --logdir=path_to_your_logs
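A common practice is to give each training run its own timestamped subdirectory under one log root, so that pointing tensorboard --logdir at the root lets you compare runs side by side. Below is a minimal sketch assuming Keras with TensorFlow installed (the TensorBoard callback writes its logs through TensorFlow); the logs/fit layout and the toy model are illustrative choices, and a temporary directory stands in for your project folder.

```python
import os
import tempfile
from datetime import datetime

import numpy as np
import keras

# Toy regression data, for illustration only
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

log_root = os.path.join(tempfile.mkdtemp(), "logs", "fit")
# One subdirectory per run, named after the current date and time
run_dir = os.path.join(log_root, datetime.now().strftime("%Y%m%d-%H%M%S"))

tb = keras.callbacks.TensorBoard(log_dir=run_dir, histogram_freq=1)
model.fit(x, y, epochs=2, validation_split=0.25, verbose=0, callbacks=[tb])

# Training and validation summaries are written to separate subfolders of the run directory
print(sorted(os.listdir(run_dir)))
# Afterwards, in a terminal: tensorboard --logdir=<log_root>
```

In a real project you would use a fixed path such as logs/fit instead of a temporary directory.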

3.5 Custom Callbacks

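When the built-in callbacks don't cover a specific need, we can write our own callback by subclassing keras.callbacks.Callback and overriding the methods for the training stages we care about. The following example stops training once the training accuracy crosses a threshold: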

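import tensorflow as tf

ACCURACY_THRESHOLD = 0.95  # illustrative value

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # 'accuracy' is the logs key when compiling with metrics=['accuracy'] (older Keras used 'acc')
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy > ACCURACY_THRESHOLD:
            print("\nWe have reached %2.2f%% accuracy, so we will stop training." % (accuracy * 100))
            self.model.stop_training = True

callbacks = myCallback()
model.fit(x_train, y_train, epochs=20, callbacks=[callbacks])  # model, x_train, y_train assumed defined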

Let's try to understand what's happening here. We create a child class, myCallback, that inherits from Keras's Callback class. We override its on_epoch_end method, which Keras calls at the end of each epoch with that epoch's metrics in logs. There we read the training accuracy and, if it is greater than our threshold, set self.model.stop_training = True so that training stops. Finally, we create an instance of myCallback and pass it to model.fit(), where model is the Keras model we want to train.
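Callbacks are not limited to epoch boundaries. As a further sketch, assuming Keras with a TensorFlow backend, the hypothetical BatchBudget callback below overrides on_train_batch_end, one of the batch-level hooks of keras.callbacks.Callback, to stop training after a fixed number of batches. Setting self.model.stop_training = True is the same mechanism the accuracy-threshold example relies on.

```python
import numpy as np
import keras


class BatchBudget(keras.callbacks.Callback):
    """Hypothetical callback: stop training once max_batches batches have been processed."""

    def __init__(self, max_batches):
        super().__init__()
        self.max_batches = max_batches
        self.batches_seen = 0

    def on_train_batch_end(self, batch, logs=None):
        self.batches_seen += 1
        if self.batches_seen >= self.max_batches:
            # Keras checks this flag after each batch and ends training
            self.model.stop_training = True


# Toy regression data, for illustration only
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

budget = BatchBudget(max_batches=6)
# 64 samples / batch_size 16 = 4 batches per epoch, so the budget runs out during the second epoch
history = model.fit(x, y, epochs=10, batch_size=16, verbose=0, callbacks=[budget])
print(budget.batches_seen, len(history.history["loss"]))
```

Although epochs=10 was requested, only two epochs appear in the history because the flag was set partway through the second one.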

3.6 LambdaCallback

Keras provides numerous callback classes that cover almost all our requirements, but sometimes we still want a custom callback for a specific task, such as stopping training once a certain accuracy is reached or saving the model at each epoch. Instead of writing a class with keras.callbacks.Callback as its parent, we can use LambdaCallback. It takes arguments such as on_epoch_end, each of which is a function that Keras calls at the corresponding stage (for on_epoch_end, at the end of each epoch).

Each of these arguments expects a function with fixed positional arguments:

  • on_epoch_begin and on_epoch_end expect two positional arguments: epoch, logs
  • on_batch_begin and on_batch_end expect two positional arguments: batch, logs
  • on_train_begin and on_train_end expect one positional argument: logs

Let's use this callback in the following example, in which the model weights are saved whenever the training accuracy goes beyond a set threshold:

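from keras.callbacks import LambdaCallback
# 'accuracy' is the logs key when compiling with metrics=['accuracy'] (older Keras used 'acc');
# recent Keras versions require weight files to end in .weights.h5
call = LambdaCallback(on_epoch_end=lambda epoch, logs:
                      model.save_weights('kang.weights.h5') if logs.get('accuracy', 0) > 0.99 else None)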

Here is another example, in which training stops once the accuracy goes beyond a certain point:

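from keras.callbacks import LambdaCallback
# A lambda cannot assign to an attribute (not even with :=), so setattr sets model.stop_training
call = LambdaCallback(on_epoch_end=lambda epoch, logs:
                      setattr(model, 'stop_training', True) if logs.get('accuracy', 0) > 0.99 else None)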
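LambdaCallback can also collect values or run wrap-up code without writing a class. In this minimal sketch (assuming Keras with a TensorFlow backend, a toy model, and random data), an on_epoch_end lambda appends each epoch's training loss to an ordinary Python list, and an on_train_end lambda prints a short summary; epoch_losses and recorder are illustrative names.

```python
import numpy as np
import keras
from keras.callbacks import LambdaCallback

# Toy regression data, for illustration only
x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

epoch_losses = []
recorder = LambdaCallback(
    # on_epoch_end receives (epoch, logs); logs["loss"] is the epoch's training loss
    on_epoch_end=lambda epoch, logs: epoch_losses.append(logs["loss"]),
    # on_train_end receives only logs
    on_train_end=lambda logs: print(f"Recorded {len(epoch_losses)} epoch losses"),
)
model.fit(x, y, epochs=3, verbose=0, callbacks=[recorder])
```

Because the lambdas close over epoch_losses, the values are available after fit returns without subclassing Callback.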


With this, we come to the end of this article. Keras has many other callbacks, such as CSVLogger, ReduceLROnPlateau, and TerminateOnNaN, which are worth trying. One useful tip is to combine several callbacks in a single training run, which gives you better insight into, and control over, the model during the training phase. Happy Learning!
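As an example of combining callbacks, the sketch below passes three built-in callbacks in one list: EarlyStopping, ModelCheckpoint, and CSVLogger (which writes each epoch's metrics as a row of a CSV file). It assumes a recent Keras version with a TensorFlow backend and support for the native .keras format; the toy data, file names, and argument values are illustrative.

```python
import csv
import os
import tempfile

import numpy as np
import keras

# Toy binary-classification data (random, for illustration only)
x = np.random.rand(128, 4).astype("float32")
y = (x.sum(axis=1) > 2).astype("float32")
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

out_dir = tempfile.mkdtemp()
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
    keras.callbacks.ModelCheckpoint(os.path.join(out_dir, "best.keras"),
                                    monitor="val_loss", save_best_only=True),
    keras.callbacks.CSVLogger(os.path.join(out_dir, "training_log.csv")),
]
history = model.fit(x, y, epochs=30, validation_split=0.25, verbose=0, callbacks=callbacks)

# One CSV row per epoch that actually ran (fewer than 30 if EarlyStopping triggered)
with open(os.path.join(out_dir, "training_log.csv"), newline="") as f:
    rows = list(csv.DictReader(f))
print(len(rows), "epochs logged; columns:", list(rows[0].keys()))
```

Each callback handles one concern (stopping, saving, logging), and Keras calls all of them at the same training stages.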


