Training PyTorch Models on TPU
Tutorial on using PyTorch/XLA 1.7 with TPUs
- 1. Overview
- 2. Preparations and packages
- 3. Data preparation
- 4. Model setup
- 5. Modeling
- 6. Closing words
Last update: 19.10.2021
1. Overview
Deep learning heavily relies on Graphics Processing Units (GPUs) to enable fast training. Recently, Google introduced Tensor Processing Units (TPUs) to further speed up the computations used in neural networks. Cloud TPUs are available on Kaggle and Google Colab. While TPU chips have been optimized for TensorFlow, PyTorch users can also take advantage of their compute power. This requires using PyTorch/XLA and implementing certain changes in the modeling pipeline.
Moving a PyTorch pipeline to TPU includes the following steps:
- installing relevant packages and setting up TPU
- adjusting syntax of some modeling steps such as initialization, optimizer and verbosity
- distributing data loaders over multiple TPU cores
- wrapping data processing, training and inference into a master function
This post provides a tutorial on using PyTorch/XLA to build the TPU pipeline. The code is optimized for multi-core TPU training. Many of the ideas are adapted from here and here. We will focus on a computer vision application, but the framework can be used with other deep learning models as well. We will use data from the RSNA STR Pulmonary Embolism Detection Kaggle competition on detecting pulmonary embolism in more than 1.7 million CT scan images.
2. Preparations and packages
When setting up a script, it is important to introduce two TPU-related parameters: batch size and number of workers.
Google recommends using 128 images per batch for the best performance on the current TPU v3 hardware. A TPU v3-8 device has 8 cores. This implies that each of the 8 cores can receive a batch of 128 images at each training step, and the modeling can be performed simultaneously on the separate cores. The gradients computed on each core are then averaged across the cores before the model weights are updated. Therefore, a batch size of 128 actually implies 128 * 8 images in each iteration.
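The arithmetic behind the batch size can be written out explicitly. Below is a minimal sketch that uses the values from the parameter block (batch_size = 128, num_tpu_workers = 8, batches_per_epoch = 1000, batch_verbose = 100); the helper name `images_seen` is hypothetical and only multiplies the numbers together.

```python
def images_seen(batch_size: int, num_cores: int, num_batches: int = 1) -> int:
    """Images processed when every core consumes `num_batches` batches of
    `batch_size` images in parallel (hypothetical helper)."""
    return batch_size * num_cores * num_batches

batch_size = 128      # images per core per training step
num_tpu_workers = 8   # cores of a TPU v3-8 device

global_batch = images_seen(batch_size, num_tpu_workers)                    # one optimizer step
per_epoch = images_seen(batch_size, num_tpu_workers, num_batches=1000)     # batches_per_epoch
per_log_line = images_seen(batch_size, num_tpu_workers, num_batches=100)   # batch_verbose
print(global_batch, per_epoch, per_log_line)  # 1024 1024000 102400
```

The same factor of num_tpu_workers applies to every count printed by a single core, which is why the logs of a multi-core run cover many more images than the batch indices suggest.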
# partitioning
num_folds = 5
use_fold = 0
# image params
image_size = 128
load_jpegs = True # passed to PEData below; the class always reads JPG files
# modeling
model_name = 'resnet34' # used in the file name of the saved weights
batch_size = 128 # num_images = batch_size*num_tpu_workers
batches_per_epoch = 1000 # num_images = batch_size*batches_per_epoch*num_tpu_workers
num_epochs = 1
batch_verbose = 100
num_tpu_workers = 8
# learning rate
eta = 0.0001
step = 1
gamma = 0.5
# paths
data_path = '/kaggle/input/rsna-str-pulmonary-embolism-detection/'
image_path = '/kaggle/input/rsna-str-pe-detection-jpeg-256/train-jpegs/'
After specifying the parameters, we need to set up TPU by installing and importing torch_xla using the snippet below. There are two options: install the latest stable XLA version (1.7 as of 30.03.2021) or the so-called 'nightly' version that includes the latest updates but may be unstable. I recommend going for the stable version.
We also specify the XLA_USE_BF16 variable, which makes XLA convert float (and double) tensors to the lower-precision bfloat16 format when they are sent to the TPU, and the XLA_TENSOR_ALLOCATOR_MAXSIZE variable (maximum tensor allocator size). When working in Google Colab, we can also run assert os.environ['COLAB_TPU_ADDR'] to check that Colab is correctly connected to a TPU instance.
Don't be discouraged if you see error messages during the installation of fastai, kornia and allennlp. The installation will still proceed and set up the versions of torch and torchvision needed to work with TPUs. This may take a few minutes.
# XLA version
xla_version = '1.7' # 'nightly' or '1.7'
# installation
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $xla_version
# XLA imports
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
# configurations
import os
os.environ['XLA_USE_BF16'] = '1'
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '1000000000'
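To get a feel for what XLA_USE_BF16 = '1' trades away, the pure-Python sketch below rounds a float32 value to bfloat16 (8 exponent bits, 7 explicit mantissa bits). It only illustrates the number format and does not call torch_xla; the helper name `to_bfloat16` is hypothetical, and round-to-nearest-even is assumed as the conversion mode.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round a float32 value to the nearest bfloat16 value (ties to even)
    and return it as a Python float. NaN handling is omitted for brevity."""
    bits = struct.unpack('>I', struct.pack('>f', x))[0]   # raw float32 bit pattern
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)           # round half to even
    bits = (bits + rounding_bias) & 0xFFFF0000            # keep sign, exponent, 7 mantissa bits
    return struct.unpack('>f', struct.pack('>I', bits))[0]

print(to_bfloat16(1.001))         # 1.0
print(to_bfloat16(3.14159))       # 3.140625
print(to_bfloat16(1e-30) > 0)     # True: same exponent range as float32
```

bfloat16 keeps the exponent range of float32, so small gradients do not underflow as easily as in float16, but it keeps only about 2 to 3 significant decimal digits, which is usually acceptable for neural network training.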
Next, we import all other relevant libraries.
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import cv2
from sklearn.model_selection import GroupKFold
import glob
import random
import time
import sys
import os
import gc
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
3. Data preparation
Mostly, the data processing pipeline does not need adjustments when training on TPU instead of GPU. It is only necessary to change the data loaders such that they distribute image batches over the TPU cores. This is covered in the next section, which describes the modeling stage, since the data samplers need to be wrapped into the modeling function. Feel free to skip this section if you already know how to process the image data.
It is important to note that, since training on TPU is so fast, data import usually becomes the computational bottleneck. Hence, it is crucial to optimize data reading and processing as much as possible. For the best read performance, I recommend transforming the data to the .tfrec format. You can read more about using .tfrec with PyTorch here.
Below, we construct a standard Dataset class to read JPG images of the CT scans. Each image has ten binary labels indicating the presence of pulmonary embolism and its characteristics.
label_names = ['pe_present_on_image',
'negative_exam_for_pe',
'rv_lv_ratio_gte_1',
'rv_lv_ratio_lt_1',
'leftsided_pe',
'chronic_pe',
'rightsided_pe',
'acute_and_chronic_pe',
'central_pe',
'indeterminate']
### DATASET
class PEData(Dataset):
    def __init__(self, data, directory, transform = None, load_jpg = False, labeled = False):
        self.data = data
        self.directory = directory
        self.transform = transform
        self.load_jpg = load_jpg # stored but not used: images are always read from JPG
        self.labeled = labeled

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # import: files are stored as <StudyInstanceUID>/<SeriesInstanceUID>/*<SOPInstanceUID>.jpg
        img_name = glob.glob(os.path.join(self.directory, '/'.join(self.data.iloc[idx][['StudyInstanceUID', 'SeriesInstanceUID']]) + '/*' + self.data.iloc[idx]['SOPInstanceUID'] + '.jpg'))[0]
        image = cv2.imread(img_name)
        # switch channels and normalize
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0
        # convert to a CHW float tensor
        image = torch.tensor(image, dtype = torch.float)
        image = image.permute(2, 0, 1)
        # augmentations (skipped when no transform is given)
        if self.transform is not None:
            image = self.transform(image)
        # output
        if self.labeled:
            labels = torch.tensor(self.data.iloc[idx][label_names].values.astype('int'), dtype = torch.float)
            return image, labels
        else:
            return image
### AUGMENTATIONS
train_trans = test_trans = transforms.Compose([transforms.ToPILImage(),
transforms.Resize(image_size),
transforms.ToTensor()])
We split the data into training and validation folds such that images from the same patient study (StudyInstanceUID) do not appear in both. We will only use a single fold for demonstration purposes.
# partitioning
train = pd.read_csv(data_path + 'train.csv')
gkf = GroupKFold(n_splits = num_folds)
train['fold'] = -1
for fold, (_, val_idx) in enumerate(gkf.split(train, groups = train['StudyInstanceUID'])):
    train.loc[val_idx, 'fold'] = fold
# load splits
data_train = train.loc[train.fold != use_fold].reset_index(drop = True)
data_valid = train.loc[train.fold == use_fold].reset_index(drop = True)
# datasets
train_dataset = PEData(data = data_train,
                       directory = image_path,
                       transform = train_trans,
                       load_jpg = load_jpegs,
                       labeled = True)
valid_dataset = PEData(data = data_valid,
                       directory = image_path,
                       transform = test_trans,
                       load_jpg = load_jpegs,
                       labeled = True)
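GroupKFold guarantees that all images of one StudyInstanceUID end up in the same fold. The pure-Python sketch below mimics, roughly, the greedy strategy scikit-learn's GroupKFold is based on: groups are sorted by size and each one is placed into the currently smallest fold. It is an illustration only; the function name `group_kfold` is hypothetical, and the library's exact fold assignments may differ.

```python
from collections import Counter

def group_kfold(groups, n_splits):
    """Return a fold index for every sample so that all samples sharing a
    group label land in the same fold. Largest groups are placed first,
    each into the currently lightest fold, to keep fold sizes balanced."""
    sizes = Counter(groups)                       # samples per group
    fold_sizes = [0] * n_splits
    group_to_fold = {}
    for group, size in sorted(sizes.items(), key=lambda kv: -kv[1]):
        lightest = fold_sizes.index(min(fold_sizes))
        group_to_fold[group] = lightest
        fold_sizes[lightest] += size
    return [group_to_fold[g] for g in groups]

# toy example: three studies with 4, 3 and 2 images
groups = ['a'] * 4 + ['b'] * 3 + ['c'] * 2
folds = group_kfold(groups, n_splits=2)
print(folds)  # [0, 0, 0, 0, 1, 1, 1, 1, 1]
```

Because whole studies are assigned at once, the folds can only be approximately equal in size, which is the price paid for avoiding leakage between training and validation images of the same patient.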
Before proceeding to modeling, let's take a look at a sample batch of training images using our processing pipeline.
# sample loader
sample_loader = torch.utils.data.DataLoader(valid_dataset,
                                            shuffle = False,
                                            batch_size = 8,
                                            num_workers = 1)
# display images
for batch_idx, (inputs, labels) in enumerate(sample_loader):
    fig = plt.figure(figsize = (14, 7))
    for i in range(8):
        ax = fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
        ax.set_title(labels.numpy()[i]) # the ten labels of the i-th image
    break
4. Model setup
The modeling stage needs to be modified because the modeling is performed simultaneously on multiple TPU cores. This requires changes to the model initialization and the optimizer, as well as a master function that distributes data loaders, training and inference over multi-core TPU chips. Let's dive in!
We start with the model. We use ResNet-34 with ten output nodes corresponding to each of the binary labels. After initializing the model, we need to wrap it into an MpModelWrapper object that can be sent to TPU. The wrapper keeps a single copy of the weights in shared host memory, so that the processes forked for the TPU cores do not each hold their own copy. This is done by a simple command: mx = xmp.MpModelWrapper(model).
# initialization
def init_model():
    model = models.resnet34(pretrained = True)
    model.fc = torch.nn.Linear(in_features = 512, out_features = len(label_names), bias = True)
    return model
# model wrapper
model = init_model()
mx = xmp.MpModelWrapper(model)
Tracking the running loss when training on multiple TPU cores can be a bit difficult, since each core only processes its own share of the batches. The following helper class stores the loss values and updates the running average based on the batch outputs of each worker.
class AverageMeter(object):
    '''Computes and stores the average and current value'''
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
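Each core keeps its own meter, so its average covers only the batches that core has seen. If you want one loss over all cores, one option in torch_xla is to gather the per-core values with xm.mesh_reduce(tag, data, reduce_fn). The pure-Python sketch below (no TPU needed; the helper name `combine_core_losses` and the numbers are hypothetical) shows what such a reduction should compute: combining loss sums and image counts gives the true average, while averaging per-core averages is only correct when every core processed the same number of images.

```python
def combine_core_losses(core_stats):
    """Combine per-core (loss_sum, n_images) pairs into a single average loss,
    i.e. what one AverageMeter would report if it had seen every batch."""
    total_loss = sum(loss_sum for loss_sum, _ in core_stats)
    total_images = sum(n for _, n in core_stats)
    return total_loss / total_images

# hypothetical per-core sums and counts; the last core saw fewer images
stats = [(10.0, 100), (30.0, 100), (2.0, 50)]
weighted = combine_core_losses(stats)
naive = sum(s / n for s, n in stats) / len(stats)   # mean of per-core averages
print(round(weighted, 3), round(naive, 3))  # 0.168 0.147
```

With DistributedSampler, every core receives the same number of samples, so in this pipeline a plain mean of the per-core averages gives the same result as the weighted version.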
Next, we need to wrap modeling into a single master function that can be distributed over TPU cores. We will first define functions for training and inference and then introduce the wrapper function.
The training pass requires several changes:
- the optimizer step is done with xm.optimizer_step(optimizer)
- the printing statements need to be defined as xm.master_print() instead of print() in order to only print a statement once (otherwise, each TPU core will print it)
- the data loader should be defined outside of the function and passed as an argument so that it can be distributed over the cores
- the running loss can be computed using the AverageMeter() object defined above
In addition, it is important to free memory as often as possible to ensure that the modeling does not crash:
- del [object] to delete objects once they are not needed
- gc.collect() to collect garbage left in memory
### TRAINING
def train_fn(epoch, para_loader, optimizer, criterion, scheduler, device):
    # initialize
    # note: `model` is the module-level model, which mx.to(device) moved to this core's device in place
    model.train()
    trn_loss_meter = AverageMeter()
    # training loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):
        # extract inputs and labels (the per-device loader already places them on the TPU)
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # forward and backward pass
        preds = model(inputs)
        loss = criterion(preds, labels)
        loss.backward()
        # averages gradients across cores, then steps; barrier = True is only
        # required without a ParallelLoader, which already marks the step
        xm.optimizer_step(optimizer, barrier = True)
        # compute loss (.item() copies the value to the host)
        trn_loss_meter.update(loss.detach().item(), inputs.size(0))
        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss = {:.6f}'.format(
                batch_idx, loss.item(), trn_loss_meter.avg))
        # clear memory
        del inputs, labels, preds, loss
        gc.collect()
        # early stop after batches_per_epoch batches
        if batch_idx + 1 >= batches_per_epoch:
            break
    # scheduler step
    scheduler.step()
    # clear memory
    del para_loader, batch_idx
    gc.collect()
    # average the per-core losses so that every core returns the same value
    return xm.mesh_reduce('trn_loss', trn_loss_meter.avg, np.mean)
Similar to the training pass, the inference function takes the data loader as an argument, updates the loss using the AverageMeter() object and clears memory after each batch.
### INFERENCE
def valid_fn(epoch, para_loader, criterion, device):
    # initialize
    model.eval()
    val_loss_meter = AverageMeter()
    # validation loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):
        # extract inputs and labels
        inputs = inputs.to(device)
        labels = labels.to(device)
        # compute preds
        with torch.no_grad():
            preds = model(inputs)
            loss = criterion(preds, labels)
        # compute loss
        val_loss_meter.update(loss.detach().item(), inputs.size(0))
        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss = {:.6f}'.format(
                batch_idx, loss.item(), val_loss_meter.avg))
        # clear memory
        del inputs, labels, preds, loss
        gc.collect()
    # clear memory
    del para_loader, batch_idx
    gc.collect()
    # average the per-core losses so that all cores return (and compare) the same value
    return xm.mesh_reduce('val_loss', val_loss_meter.avg, np.mean)
The master modeling function also includes several TPU-based modifications.
First, we need to create a distributed data sampler that reads our Dataset object and distributes batches over TPU cores. This is done with torch.utils.data.distributed.DistributedSampler(), which lets the data loader on each core take only its own portion of the whole dataset. Setting num_replicas to xm.xrt_world_size() splits the data into as many portions as there are TPU cores, and setting rank to xm.get_ordinal() tells each core which portion to read. After defining the sampler, we can set up a data loader that uses the sampler.
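The sharding done by the sampler can be illustrated without a TPU. The pure-Python sketch below reproduces the non-shuffled case under our reading of DistributedSampler: the index list is padded by repeating its start until its length is a multiple of num_replicas, and each rank then takes every num_replicas-th index starting at its own rank. The function name `shard_indices` is hypothetical. With shuffle = True, the index list is first permuted using a seed that depends on the epoch, which is why the master function calls train_loader.sampler.set_epoch(epoch).

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Indices read by replica `rank` when shuffle=False (hypothetical helper)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)   # samples per replica
    total_size = num_samples * num_replicas
    # pad by repeating the start of the list so every replica gets num_samples items
    while len(indices) < total_size:
        indices += indices[:total_size - len(indices)]
    return indices[rank:total_size:num_replicas]

for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

The padding means a few samples are seen twice per epoch, but it keeps the number of batches identical on all cores, so none of them waits for the others at the end of an epoch.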
Second, the model is sent to TPU with the following code:
device = xm.xla_device()
model = mx.to(device)
Third, we need to scale the learning rate, since each optimizer step now uses batches from all cores, i.e. a global batch that is xm.xrt_world_size() times larger: scaled_eta = eta * xm.xrt_world_size().
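Scaling the learning rate linearly with the number of cores is a common heuristic for the larger global batch, not a guarantee of identical convergence. Combined with the StepLR scheduler used in the master function, the learning rate of each epoch can be computed by hand. The pure-Python helper below is hypothetical and only reproduces that arithmetic with the values from the parameter block (eta = 0.0001, step = 1, gamma = 0.5) on 8 cores.

```python
def scaled_step_lr(base_lr, world_size, epoch, step_size, gamma):
    """Learning rate in a 0-based epoch when the base rate is multiplied by the
    number of cores and then decayed by `gamma` every `step_size` epochs,
    mirroring eta * xm.xrt_world_size() followed by StepLR."""
    scaled = base_lr * world_size
    return scaled * gamma ** (epoch // step_size)

eta, step, gamma = 0.0001, 1, 0.5
for epoch in range(3):
    print(epoch, scaled_step_lr(eta, 8, epoch, step, gamma))
# 0 0.0008
# 1 0.0004
# 2 0.0002
```

The master function prints the current rate divided by xm.xrt_world_size(), so the logged value corresponds to the unscaled eta (0.0001 in the first epoch).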
Finally, we continue keeping track of the memory and clearing it whenever possible, and use xm.master_print() for displaying intermediate results. We also set up the function to return lists with the training and validation loss values.
### MASTER FUNCTION
def _run(model):
    ### DATA PREP
    # data samplers
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank = xm.get_ordinal(),
                                                                    shuffle = True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank = xm.get_ordinal(),
                                                                    shuffle = False)
    # data loaders
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size = batch_size,
                                               sampler = valid_sampler,
                                               num_workers = 0,
                                               pin_memory = True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size = batch_size,
                                               sampler = train_sampler,
                                               num_workers = 0,
                                               pin_memory = True)
    ### MODEL PREP
    # send to TPU
    device = xm.xla_device()
    model = mx.to(device)
    # scale LR
    scaled_eta = eta * xm.xrt_world_size()
    # optimizer and loss
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr = scaled_eta)
    scheduler = lr_scheduler.StepLR(optimizer, step_size = step, gamma = gamma)
    ### MODELING
    # placeholders
    trn_losses = []
    val_losses = []
    best_val_loss = np.inf # BCE loss is not bounded by 1
    # modeling loop
    gc.collect()
    for epoch in range(num_epochs):
        # display info (the LR is divided back by the number of cores to show the unscaled value)
        xm.master_print('-'*55)
        xm.master_print('EPOCH {}/{}'.format(epoch + 1, num_epochs))
        xm.master_print('-'*55)
        xm.master_print('- initialization | TPU cores = {}, lr = {:.6f}'.format(
            xm.xrt_world_size(), scheduler.get_last_lr()[-1] / xm.xrt_world_size()))
        epoch_start = time.time()
        gc.collect()
        # update train_loader shuffling
        train_loader.sampler.set_epoch(epoch)
        # training pass
        train_start = time.time()
        xm.master_print('- training...')
        para_loader = pl.ParallelLoader(train_loader, [device])
        trn_loss = train_fn(epoch = epoch + 1,
                            para_loader = para_loader.per_device_loader(device),
                            criterion = criterion,
                            optimizer = optimizer,
                            scheduler = scheduler,
                            device = device)
        del para_loader
        gc.collect()
        # validation pass
        valid_start = time.time()
        xm.master_print('- validation...')
        para_loader = pl.ParallelLoader(valid_loader, [device])
        val_loss = valid_fn(epoch = epoch + 1,
                            para_loader = para_loader.per_device_loader(device),
                            criterion = criterion,
                            device = device)
        del para_loader
        gc.collect()
        # save weights: xm.save writes from the master core only and synchronizes
        # the cores, so all cores should agree on val_loss and enter this branch together
        if val_loss < best_val_loss:
            xm.save(model.state_dict(), 'weights_{}.pt'.format(model_name))
            best_val_loss = val_loss
        # display info
        xm.master_print('- elapsed time | train = {:.2f} min, valid = {:.2f} min'.format(
            (valid_start - train_start) / 60, (time.time() - valid_start) / 60))
        xm.master_print('- average loss | train = {:.6f}, valid = {:.6f}'.format(
            trn_loss, val_loss))
        xm.master_print('-'*55)
        xm.master_print('')
        # save losses
        trn_losses.append(trn_loss)
        val_losses.append(val_loss)
        del trn_loss, val_loss
        gc.collect()
    # print results
    xm.master_print('Best results: loss = {:.6f} (epoch {})'.format(np.min(val_losses), np.argmin(val_losses) + 1))
    return trn_losses, val_losses
5. Modeling
After all helper functions have been introduced, we can finally launch the training! To do that, we need to define the last wrapper function that runs the modeling on multiple TPU cores: _mp_fn(rank, flags)
. Within the wrapper function, we set default tensor type to make sure that new tensors are initialized as float torch tensors on TPU and then run the modeling. It is important to set nprocs
to the number of available TPU cores. The FLAGS
object can be used to pass further training arguments to the modeling function.
Running xmp.spawn() will launch our _mp_fn() on multiple TPU cores! Since it does not return any output, it is useful to save the losses or other objects you might be interested in after running the _run() function.
# wrapper function
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    trn_losses, val_losses = _run(model)
    # save from one process only to avoid concurrent writes to the same files
    if xm.is_master_ordinal():
        np.save('trn_losses.npy', np.array(trn_losses))
        np.save('val_losses.npy', np.array(val_losses))
# modeling
gc.collect()
FLAGS = {}
xmp.spawn(_mp_fn, args = (FLAGS,), nprocs = num_tpu_workers, start_method = 'fork')
The training is working! Note that every 100 batches reported in the training log actually correspond to 100 * batch_size * num_tpu_workers images, since every core simultaneously processes the same number of different images, but printing is done from one core only.
6. Closing words
This is the end of this blog post. Using a computer vision application, we demonstrated how to use PyTorch/XLA to take advantage of TPU when training deep learning models. We covered important changes that need to be implemented into the modeling pipeline to enable TPU-based training, including data processing, modeling and displaying results. I hope this post will help you to get started with TPUs!
If you are interested in further reading, make sure to check the tutorial notebooks developed by the PyTorch/XLA team, available at their GitHub repo.