Last updated: 03.02.2021

1. Overview

Deep learning models are getting bigger and bigger, and it is becoming difficult to fit them into GPU memory. This is especially relevant in computer vision applications, where some memory must also be reserved for high-resolution input images. As a result, we are sometimes forced to train with small batches, which may lead to slower convergence and lower accuracy.

This blog post provides a quick tutorial on how to increase the effective batch size with a trick called gradient accumulation. Simply put, gradient accumulation means using a small batch size but accumulating the gradients and updating the network weights only once every few batches. Higher-level frameworks such as fast.ai or PyTorch Lightning offer automated solutions for this, but those who prefer plain PyTorch might find this tutorial useful.

2. What is gradient accumulation

When training a neural network, we usually divide our data into mini-batches and go through them one by one. The network predicts labels for each batch, and these predictions are used to compute the loss with respect to the actual targets. Next, we perform a backward pass to compute the gradients and update the model weights in the direction opposite to those gradients, i.e. the direction that decreases the loss.

Gradient accumulation modifies the last step of this process. Instead of updating the network weights after every batch, we keep the computed gradients, proceed to the next batch and add the new gradients to them. The weight update is then performed only after several batches have been processed by the model.
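Accumulation works because PyTorch does not overwrite gradients on each backward pass: calling backward() adds the new gradients to whatever is already stored in each parameter's .grad attribute. A minimal sketch with a single toy parameter (the tensor w and the two toy losses are purely illustrative):

```python
import torch

# a single trainable parameter (toy example)
w = torch.tensor(2.0, requires_grad=True)

# first "batch": d(3 * w)/dw = 3
(3 * w).backward()
first_grad = w.grad.item()

# second "batch": d(w ** 2)/dw = 2 * w = 4, which is added to the stored 3
(w ** 2).backward()
accumulated_grad = w.grad.item()

# clear the stored gradient before the next accumulation cycle
# (optimizer.zero_grad() does this for every parameter it manages)
w.grad = None

print(first_grad, accumulated_grad)  # 3.0 7.0
```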

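Gradient accumulation helps to imitate a larger batch size. Imagine you want to use 32 images in one batch, but your hardware runs out of memory once you go beyond 8. In that case, you can use batches of 8 images and update the weights once every 4 batches. If you accumulate the gradients from every batch in between (and scale the loss accordingly, as described below), the weight updates will be equivalent to training with batches of 32, up to floating-point differences and layers such as batch normalization that compute statistics per mini-batch. This way, you will be able to train on a less expensive machine.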
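This equivalence can be checked numerically. The sketch below assumes a toy linear model with no batch-dependent layers, random data and the default mean-reduced MSELoss, and compares the gradients of one batch of 32 samples with the gradients accumulated over four batches of 8 whose losses are each divided by 4:

```python
import torch

torch.manual_seed(0)

# toy regression data: 32 samples with 5 features (float64 for a tight comparison)
X = torch.randn(32, 5, dtype=torch.float64)
y = torch.randn(32, 1, dtype=torch.float64)

model = torch.nn.Linear(5, 1).double()
criterion = torch.nn.MSELoss()  # default reduction="mean" averages over samples

# (a) one large batch of 32 samples
model.zero_grad()
criterion(model(X), y).backward()
big_grads = [p.grad.clone() for p in model.parameters()]

# (b) 4 mini-batches of 8 samples, each loss divided by accum_iter
accum_iter = 4
model.zero_grad()
for X_small, y_small in zip(X.split(8), y.split(8)):
    loss = criterion(model(X_small), y_small) / accum_iter
    loss.backward()  # adds this batch's gradients to p.grad
accum_grads = [p.grad.clone() for p in model.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(big_grads, accum_grads))
print(same)  # True
```

The match holds because the mean over 32 samples equals the average of the four per-batch means when all batches have 8 samples.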

3. How to make it work

The implementation of gradient accumulation is rather straightforward. A standard training loop without accumulation usually looks like this:

# loop through the batches
for (inputs, labels) in data_loader:

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # backward pass
        loss.backward() 

        # weights update
        optimizer.step()
        optimizer.zero_grad()

Now let's implement gradient accumulation! There are three things we need to do:

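  1. Specify the accum_iter parameter. This is an integer indicating how many batches to process before each update of the network weights.
  2. Condition the weight update on the index of the current batch. This requires iterating over enumerate(data_loader) to keep track of the batch index.
  3. Divide the loss of each batch by accum_iter. Because loss.backward() sums the gradients of consecutive batches, this normalization turns the accumulated gradient into the average over the accum_iter batches, matching the mean loss over one large batch (assuming the criterion averages the loss over samples, as PyTorch losses do by default).
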
# batch accumulation parameter
accum_iter = 4  

# loop through the numbered range of batches
for batch_idx, (inputs, labels) in enumerate(data_loader):

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter 

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()
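To try the loop end to end, here is a self-contained sketch that wraps it in a function and runs it on random data. The function name train_one_epoch, the toy dataset and the linear classifier are hypothetical choices for illustration; the function also counts the weight updates so you can see when optimizer.step() is called.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, data_loader, criterion, optimizer, device, accum_iter=4):
    """Runs one epoch with gradient accumulation; returns the number of weight updates."""
    model.train()
    num_updates = 0
    optimizer.zero_grad()  # start from clean gradients
    for batch_idx, (inputs, labels) in enumerate(data_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # forward pass; loss normalized by the accumulation factor
        loss = criterion(model(inputs), labels) / accum_iter

        # backward pass: gradients are added to the .grad buffers
        loss.backward()

        # update weights every accum_iter batches and on the final batch
        if (batch_idx + 1) % accum_iter == 0 or (batch_idx + 1) == len(data_loader):
            optimizer.step()
            optimizer.zero_grad()
            num_updates += 1
    return num_updates


# hypothetical toy setup: 80 random samples in batches of 8 -> 10 batches
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randn(80, 5), torch.randint(0, 3, (80,)))
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

model = torch.nn.Linear(5, 3).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

num_updates = train_one_epoch(model, data_loader, criterion, optimizer, device, accum_iter=4)
print(num_updates)  # 3: after batches 4, 8 and the final batch 10
```

The with torch.set_grad_enabled(True) block is left out here because gradient tracking is already enabled by default.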

It is really that simple! loss.backward() computes the gradients and adds them to the ones already stored for each parameter, and PyTorch keeps accumulating them until we call optimizer.zero_grad(). Therefore, we only need to move the weight update performed by optimizer.step() and the gradient reset under the if condition that checks the batch index. It is important to also update the weights on the last batch, when batch_idx + 1 == len(data_loader): this makes sure that the gradients from the final batches are used to optimize the network instead of being left unused.
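To see what the last-batch condition changes, the hypothetical helper below replays the if statement for a given number of batches and lists after which batches the weights get updated. With 10 batches and accum_iter = 4, the last two batches would otherwise only accumulate gradients that are not applied within the epoch.

```python
def update_batches(num_batches, accum_iter):
    """Returns the 1-based batch numbers after which optimizer.step() is called."""
    return [
        batch_idx + 1
        for batch_idx in range(num_batches)
        # same condition as in the training loop above
        if (batch_idx + 1) % accum_iter == 0 or batch_idx + 1 == num_batches
    ]


print(update_batches(10, 4))  # [4, 8, 10]
print(update_batches(8, 4))   # [4, 8]
```

Note that in the first case the final update is based on only 2 batches whose losses are still divided by 4, so its gradient is smaller than that of a full accumulation step.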

4. Closing words

This is it! I hope this brief tutorial helps you finally fit that model on your machine and train it with the batch size it deserves. If you are interested, check out my other blog posts with tips on deep learning and PyTorch. Happy learning!