Gradient Accumulation in PyTorch
Increasing batch size to overcome memory constraints
Last updated: 03.02.2021
1. Overview
Deep learning models keep getting bigger, and it becomes difficult to fit such networks in GPU memory. This is especially relevant in computer vision applications, where we also need to reserve memory for high-resolution images. As a result, we are sometimes forced to use small batches during training, which may lead to slower convergence and lower accuracy.
This blog post provides a quick tutorial on how to increase the effective batch size with a trick called gradient accumulation. Simply speaking, gradient accumulation means that we use a small batch size but save the gradients and update the network weights only once every couple of batches. Automated solutions for this exist in higher-level frameworks such as fast.ai or lightning, but those who love using plain PyTorch might find this tutorial useful.
2. What is gradient accumulation?
When training a neural network, we usually divide our data into mini-batches and go through them one by one. The network predicts batch labels, which are used to compute the loss with respect to the actual targets. Next, we perform a backward pass to compute the gradients and update the model weights in the opposite direction of those gradients.
Gradient accumulation modifies the last step of the training process. Instead of updating the network weights after every batch, we save the gradient values, proceed to the next batch, and add the new gradients on top. The weight update is then done only after several batches have been processed by the model.
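In fact, PyTorch already accumulates gradients by default: each call to backward() adds to the .grad attribute of a tensor rather than overwriting it. A minimal sketch of this behavior (the toy parameter and expressions are made up for illustration):

```python
import torch

# a single trainable parameter
w = torch.tensor([2.0], requires_grad=True)

# first backward pass: d(3*w)/dw = 3
(3 * w).sum().backward()
print(w.grad)  # tensor([3.])

# second backward pass without a reset: gradients are added, 3 + 5 = 8
(5 * w).sum().backward()
print(w.grad)  # tensor([8.])

# resetting clears the accumulated gradients
w.grad.zero_()
print(w.grad)  # tensor([0.])
```

This is exactly why the trick works: as long as we do not reset the gradients, every backward pass contributes to the same running sum.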
Gradient accumulation helps to imitate a larger batch size. Imagine you want to use 32 images in one batch, but your hardware crashes once you go beyond 8. In that case, you can use batches of 8 images and update the weights once every 4 batches. If you accumulate the gradients from every batch in between, the results will be identical up to floating-point error (assuming the loss is averaged over the batch and the model contains no batch-dependent layers such as batch normalization), and you will be able to train on a less expensive machine.
First, let us look at a standard PyTorch training loop without accumulation:

# loop through the batches
for (inputs, labels) in data_loader:

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):

        # forward pass
        preds = model(inputs)
        loss = criterion(preds, labels)

        # backward pass
        loss.backward()

        # weights update
        optimizer.step()
        optimizer.zero_grad()
Now let's implement gradient accumulation! There are three things we need to do:
- Specify the accum_iter parameter. This is an integer that sets once in how many batches we would like to update the network weights.
- Condition the weight update on the index of the running batch. This requires using enumerate(data_loader) to get the batch index.
- Divide the running loss by accum_iter. This normalizes the loss so that each mini-batch contributes a proportionally smaller share of the gradient, matching what a single large batch would produce.
# batch accumulation parameter
accum_iter = 4

# loop through the numbered range of batches
for batch_idx, (inputs, labels) in enumerate(data_loader):

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):

        # forward pass
        preds = model(inputs)
        loss = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()
It is really that simple! The gradients are computed when we call loss.backward() and are accumulated by PyTorch until we call optimizer.zero_grad(). Therefore, we need to move the weight update performed by optimizer.step() and the gradient reset under the if condition that checks the batch index. It is important to also update the weights on the last batch, when batch_idx + 1 == len(data_loader): this makes sure that data from the last batches is not discarded but used for optimizing the network.
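To convince yourself that the trick is exact for a mean-reduced loss, you can compare the gradients from one large batch against those from several accumulated mini-batches. A small self-contained check (the linear model, seeds, and random data here are made up for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
data = torch.randn(32, 10)
targets = torch.randn(32, 1)

def make_model():
    # fix the seed so both runs start from identical weights
    torch.manual_seed(1)
    return nn.Linear(10, 1)

criterion = nn.MSELoss()

# run 1: one full batch of 32
model_full = make_model()
criterion(model_full(data), targets).backward()

# run 2: four accumulated mini-batches of 8, loss divided by accum_iter
model_accum = make_model()
accum_iter = 4
for chunk_x, chunk_y in zip(data.chunk(accum_iter), targets.chunk(accum_iter)):
    loss = criterion(model_accum(chunk_x), chunk_y) / accum_iter
    loss.backward()

# the accumulated gradients match the full-batch gradients
print(torch.allclose(model_full.weight.grad, model_accum.weight.grad, atol=1e-6))  # True
```

The division by accum_iter is what makes the sums line up: each chunk's mean loss over 8 samples, divided by 4, contributes exactly 1/32 of the total, the same as a single batch of 32.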