Last update: 19.10.2021. All opinions are my own.

1. Overview

Deep learning relies heavily on Graphics Processing Units (GPUs) to enable fast training. Recently, Google introduced Tensor Processing Units (TPUs) to further speed up the computations used in neural networks. Cloud TPUs are available on Kaggle and Google Colab. While TPU chips have been optimized for TensorFlow, PyTorch users can also take advantage of the better compute. This requires using PyTorch/XLA and implementing certain changes in the modeling pipeline.

Moving a PyTorch pipeline to TPU includes the following steps:

  • installing relevant packages and setting up TPU
  • adjusting syntax of some modeling steps such as initialization, optimizer and verbosity
  • distributing data loaders over multiple TPU cores
  • wrapping data processing, training and inference into a master function

This post provides a tutorial on using PyTorch/XLA to build the TPU pipeline. The code is optimized for multi-core TPU training. Many of the ideas are adapted from here and here. We will focus on a computer vision application, but the framework can be used with other deep learning models as well. We will use data from the RSNA STR Pulmonary Embolism Detection Kaggle competition, which is about detecting pulmonary embolism in more than 1.7 million CT scans.

2. Preparations and packages

When setting up a script, it is important to introduce two TPU-related parameters: batch size and number of workers.

Google recommends using 128 images per batch for the best performance on the current TPU v3 chips. The v3 chips have 8 cores. This implies that each of the 8 cores can receive a batch of 128 images at each training step, and the modeling can be performed simultaneously on the separate cores. The model weights are then updated based on the outcomes observed on each core. Therefore, the batch size of 128 actually implies 128 * 8 images in each iteration.
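
To make the arithmetic concrete, here is a minimal sketch of the effective batch size. The helper name effective_batch_size is my own and is not part of PyTorch/XLA; the numbers are the ones used in this post.

```python
# Values taken from the parameter block used in this post.
batch_size = 128       # images per TPU core per training step
num_tpu_workers = 8    # cores on a TPU v3 chip


def effective_batch_size(per_core_batch, num_cores):
    """Number of images that contribute to a single weight update."""
    return per_core_batch * num_cores


print(effective_batch_size(batch_size, num_tpu_workers))  # 1024
```

So a per-core batch of 128 means that 1024 images stand behind every weight update.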


# partitioning
num_folds = 5
use_fold  = 0

# image params
image_size = 128

# modeling
batch_size        = 128  # num_images = batch_size*num_tpu_workers
batches_per_epoch = 1000 # num_images = batch_size*batches_per_epoch*num_tpu_workers
num_epochs        = 1
batch_verbose     = 100
num_tpu_workers   = 8

# learning rate
eta   = 0.0001
step  = 1
gamma = 0.5

# paths
data_path  = '/kaggle/input/rsna-str-pulmonary-embolism-detection/'
image_path = '/kaggle/input/rsna-str-pe-detection-jpeg-256/train-jpegs/'

After specifying the parameters, we need to set up TPU by installing and importing torch_xla using the snippet below. There are two options: install the latest stable XLA version (1.7 as of 30.03.2021) or the so-called 'nightly' version that includes the latest updates but may be unstable. I recommend going for the stable version.

We also specify XLA_USE_BF16 variable (default tensor precision format) and XLA_TENSOR_ALLOCATOR_MAXSIZE variable (maximum tensor allocator size). When working in Google Colab we can also run assert os.environ['COLAB_TPU_ADDR'] to check that Colab is correctly connected to a TPU instance.

Don't be discouraged if you see error messages during the installation of fastai, kornia and allennlp. The installation would still proceed to the required versions of torch and torchvision needed to work with TPUs. This may take a few minutes.

# XLA version
xla_version = '1.7' # 'nightly' or '1.7'

# installation 
!curl -o
!python --version $xla_version

# XLA imports
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
# configurations
import os
os.environ['XLA_USE_BF16']                 = '1'
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '1000000000'
Running on TPU  ['']
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0   9199      0 --:--:-- --:--:-- --:--:--  9183
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200515 ...


Setting up libopenblas-dev:amd64 (0.2.20+ds-4) ...
update-alternatives: using /usr/lib/x86_64-linux-gnu/openblas/ to provide /usr/lib/x86_64-linux-gnu/ ( in auto mode
update-alternatives: using /usr/lib/x86_64-linux-gnu/openblas/ to provide /usr/lib/x86_64-linux-gnu/ ( in auto mode
Processing triggers for libc-bin (2.27-3ubuntu1) ...

Next, we import all other relevant libraries.


import numpy as np
import pandas as pd

import torch
import torchvision

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset

from PIL import Image, ImageFile
import cv2

from sklearn.model_selection import GroupKFold

import glob

import random
import time
import sys
import os
import gc

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import warnings

3. Data preparation

Mostly, the data processing pipeline does not need adjustments when training on TPU instead of GPU. It is only necessary to change the data loaders such that they distribute image batches over the TPU cores. This is covered in the next section that overviews the modeling stage, since the data samplers need to be wrapped into the modeling function. Feel free to skip this section if you already know how to process the image data.

Note that because training on TPU is much faster, reading and preprocessing the data often becomes the computational bottleneck. Hence, it is crucial to optimize data reading/processing as much as possible. For the best read performance, I recommend transforming the data to .tfrec format. You can read more about using .tfrec with PyTorch here.

Below, we construct a standard Dataset class to read JPG images of the CT scans. Each image has ten binary labels indicating the presence of pulmonary embolism and its characteristics.

label_names = ['pe_present_on_image',



class PEData(Dataset):

    def __init__(self, data, directory, transform = None, load_jpg = False, labeled = False):
        self.data      = data
        self.directory = directory
        self.transform = transform
        self.load_jpg  = load_jpg
        self.labeled   = labeled

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        # import: find the JPG that matches the study, series and image IDs
        img_name = glob.glob(os.path.join(self.directory, '/'.join(self.data.iloc[idx][['StudyInstanceUID', 'SeriesInstanceUID']]) + '/*' + self.data.iloc[idx]['SOPInstanceUID'] + '.jpg'))[0]
        image    = cv2.imread(img_name)

        # switch channels and normalize
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0

        # convert to a (channels, height, width) tensor
        image = torch.tensor(image, dtype = torch.float)
        image = image.permute(2, 0, 1)

        # augmentations
        if self.transform is not None:
            image = self.transform(image)

        # output
        if self.labeled:
            labels = torch.tensor(self.data.iloc[idx][label_names].values.astype('int'), dtype = torch.float)
            return image, labels
        else:
            return image


train_trans = test_trans = transforms.Compose([transforms.ToPILImage(),

We split data into training and validation folds such that images from the same patient - StudyInstanceUID - do not appear in both. We will only use a single fold for demonstration purposes.


# partitioning
train = pd.read_csv(data_path + 'train.csv')
gkf = GroupKFold(n_splits = num_folds)
train['fold'] = -1
for fold, (_, val_idx) in enumerate(gkf.split(train, groups = train['StudyInstanceUID'])):
    train.loc[val_idx, 'fold'] = fold

# load splits
data_train = train.loc[train.fold != use_fold].reset_index(drop = True)
data_valid = train.loc[train.fold == use_fold].reset_index(drop = True)

# datasets
train_dataset = PEData(data      = data_train, 
                       directory = image_path,
                       transform = train_trans,
                       load_jpg  = load_jpegs,
                       labeled   = True)
valid_dataset = PEData(data      = data_valid, 
                       directory = image_path,
                       transform = test_trans,
                       load_jpg  = load_jpegs,
                       labeled   = True)
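
To illustrate why the split is done by group, here is a small pure-Python sketch. It is a simplified illustration and not scikit-learn's algorithm: the hypothetical helper group_folds deals groups to folds round-robin, whereas GroupKFold also balances fold sizes. The study IDs are made up.

```python
def group_folds(groups, num_folds):
    """Assign each row a fold so that all rows of one group share a fold.

    Simplified illustration: groups are dealt to folds round-robin in order
    of first appearance (GroupKFold additionally balances fold sizes).
    """
    fold_of_group = {}
    for g in groups:
        if g not in fold_of_group:
            fold_of_group[g] = len(fold_of_group) % num_folds
    return [fold_of_group[g] for g in groups]


# made-up study IDs: several images (rows) per study
studies = ['s1', 's1', 's2', 's3', 's3', 's3', 's4', 's5', 's5', 's6']
folds = group_folds(studies, num_folds=3)

use_fold = 0
train_studies = {s for s, f in zip(studies, folds) if f != use_fold}
valid_studies = {s for s, f in zip(studies, folds) if f == use_fold}

print(sorted(valid_studies))          # ['s1', 's4']
print(train_studies & valid_studies)  # set()
```

The empty intersection is the property we care about: no study contributes images to both the training and the validation fold.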

Before proceeding to modeling, let's take a look at a sample batch of training images using our processing pipeline.


# sample loader
sample_loader = torch.utils.data.DataLoader(train_dataset,
                                            shuffle     = False,
                                            batch_size  = 8, 
                                            num_workers = 1)

# display images
for batch_idx, (inputs, labels) in enumerate(sample_loader):
    fig = plt.figure(figsize = (14, 7))
    for i in range(8):
        ax = fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])     
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
        ax.set_title(labels.numpy()[i])

4. Model setup

The modeling stage needs to be modified because the modeling is performed simultaneously on multiple TPU cores. This requires changes to model initialization, optimizer and building a master function to distribute data loaders, training and inference over multi-core TPU chips. Let's dive in!

We start with the model. We use ResNet-34 with ten output nodes corresponding to each of the binary labels. After initializing the model, we need to wrap it into the MX object that can be sent to TPU. This is done by a simple command mx = xmp.MpModelWrapper(model).

# initialization
def init_model():
    model    = models.resnet34(pretrained = True)
    model.fc = torch.nn.Linear(in_features = 512, out_features = len(label_names), bias = True)
    return model

# model wrapper
model = init_model()
mx    = xmp.MpModelWrapper(model)

Tracking the running loss when training on multiple TPU cores can be a bit difficult since we need to aggregate batch losses between the TPU cores. The following helper class lets us store the loss values externally and update them based on the batch outputs from each worker.

class AverageMeter(object):
    '''Computes and stores the average and current value'''

    def __init__(self):
        self.reset()

    def reset(self):
        self.val   = 0
        self.avg   = 0
        self.sum   = 0
        self.count = 0

    def update(self, val, n = 1):
        self.val    = val
        self.sum   += val * n
        self.count += n
        self.avg    = self.sum / self.count
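
As a quick usage sketch, the block below repeats the same bookkeeping in a self-contained form and feeds it two batches of different sizes. The loss values and batch sizes are made up for illustration.

```python
class AverageMeter:
    """Same bookkeeping as the helper above, repeated to keep the sketch self-contained."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
meter.update(0.50, n=128)   # full batch with mean loss 0.50
meter.update(0.20, n=64)    # smaller final batch with mean loss 0.20

# weighted by batch size: (0.50 * 128 + 0.20 * 64) / 192
print(round(meter.avg, 4))  # 0.4
```

Passing the batch size as n is what keeps a smaller last batch from being over-weighted in the running average.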

Next, we need to wrap modeling into a single master function that can be distributed over TPU cores. We will first define functions for training and inference and then introduce the wrapper function.

The training pass requires several TPU-specific adjustments:

  • the optimizer step is done with xm.optimizer_step(optimizer)
  • the printing statements need to be defined as xm.master_print() instead of print() in order to only print a statement once (otherwise, each TPU core will print it)
  • dataloader should be defined outside of the function and read as an argument to distribute it over the cores
  • running loss can be computed using the defined AverageMeter() object
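
The idea behind printing only once can be sketched in plain Python. This is a conceptual illustration, not the torch_xla implementation: print_on_master is a hypothetical helper, and the loop simply simulates eight cores reaching the same logging statement.

```python
def print_on_master(ordinal, message, master_ordinal=0):
    """Print `message` only when called from the master process.

    Returns True if the message was printed, False otherwise.
    """
    if ordinal == master_ordinal:
        print(message)
        return True
    return False


# simulate 8 cores hitting the same logging statement
printed = [print_on_master(ordinal, '-- batch 100 | ...') for ordinal in range(8)]
print(sum(printed))  # 1
```

Every core executes the statement, but only the one with the master ordinal produces output, which is why the logs are not repeated eight times.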

In addition, it is important to clear the TPU memory as often as possible to ensure that the modeling does not crash:

  • del [object] to delete objects once they are not needed
  • gc.collect() to collect garbage left in memory



def train_fn(epoch, para_loader, optimizer, criterion, scheduler, device):

    # initialize
    model.train()
    trn_loss_meter = AverageMeter()

    # training loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):

        # extract inputs and labels
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward and backward pass
        optimizer.zero_grad()
        preds = model(inputs)
        loss  = criterion(preds, labels)
        loss.backward()
        xm.optimizer_step(optimizer, barrier = True) # barrier is required on single-core training but can be dropped with multiple cores

        # compute loss
        trn_loss_meter.update(loss.detach().item(), inputs.size(0))

        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss = {:.6f}'.format(
                batch_idx, loss.item(), trn_loss_meter.avg))

        # clear memory
        del inputs, labels, preds, loss
        gc.collect()

        # early stop
        if batch_idx > batches_per_epoch:
            break

    # scheduler step
    scheduler.step()

    # clear memory
    del para_loader, batch_idx
    gc.collect()

    return trn_loss_meter.avg

Similar to the training pass, the inference function takes the data loader as an argument, updates the loss using the AverageMeter() object and clears memory after each batch.



def valid_fn(epoch, para_loader, criterion, device):

    # initialize
    model.eval()
    val_loss_meter = AverageMeter()
    # validation loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):

        # extract inputs and labels
        inputs = inputs.to(device)
        labels = labels.to(device)

        # compute preds
        with torch.no_grad():
            preds = model(inputs)
            loss  = criterion(preds, labels)
        # compute loss
        val_loss_meter.update(loss.detach().item(), inputs.size(0))

        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss =  {:.6f}'.format(
                batch_idx, loss.item(), val_loss_meter.avg))
        # clear memory
        del inputs, labels, preds, loss

    # clear memory
    del para_loader, batch_idx
    return val_loss_meter.avg

The master modeling function also includes several TPU-based modifications.

First, we need to create a distributed data sampler that reads our Dataset object and distributes batches over TPU cores. This is done with torch.utils.data.distributed.DistributedSampler, which lets the data loader on each core take only its own portion of the whole dataset. Setting num_replicas to xm.xrt_world_size() tells the sampler how many TPU cores are available, and rank = xm.get_ordinal() identifies the current core. After defining the sampler, we can set up a data loader that uses the sampler.
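
To show what "taking a portion of the dataset" means, here is a simplified pure-Python sketch of strided sharding without shuffling. As far as I understand, this mirrors the default behavior of PyTorch's DistributedSampler (pad the index list so it divides evenly, then give each replica every num_replicas-th index), but shard_indices is my own illustrative helper and not the library code.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices that one replica would read, without shuffling.

    Simplified sketch: wrap around the dataset so the index list divides
    evenly, then take every num_replicas-th index starting at `rank`.
    """
    per_replica = math.ceil(dataset_len / num_replicas)
    total = per_replica * num_replicas
    indices = [i % dataset_len for i in range(total)]  # pad by wrapping around
    return indices[rank:total:num_replicas]


shards = [shard_indices(dataset_len=10, num_replicas=4, rank=r) for r in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Each replica gets the same number of samples, and together the shards cover the whole dataset; the few repeated indices come from the padding.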

Second, the model is sent to TPU with the following code:

device = xm.xla_device()
model  = mx.to(device)

Third, we need to scale the learning rate by the number of cores, since each update is based on batches processed simultaneously on different cores: scaled_eta = eta * xm.xrt_world_size().
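
The resulting learning rates can be worked out by hand. The sketch below uses the values of eta, step and gamma from the parameter block and assumes 8 cores and one scheduler step per epoch; lr_at_epoch is a hypothetical helper that reproduces a StepLR-style decay rule, not a call into PyTorch.

```python
eta, step, gamma = 0.0001, 1, 0.5   # from the parameter block in this post
num_cores = 8                       # assumed value of xm.xrt_world_size()


def lr_at_epoch(base_lr, epoch, step_size, gamma):
    """StepLR-style decay: multiply by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


scaled_eta = eta * num_cores
schedule = [lr_at_epoch(scaled_eta, epoch, step, gamma) for epoch in range(3)]

for epoch, lr in enumerate(schedule):
    print('epoch {} | lr = {:.6f}'.format(epoch + 1, lr))
```

With these settings the optimizer starts at eight times the base rate and halves it after every epoch.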

Finally, we continue keeping track of the memory and clearing it whenever possible, and use xm.master_print() for displaying intermediate results. We also set up the function to return lists with the training and validation loss values.



def _run(model):
    ### DATA PREP
    # data samplers
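    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank         = xm.get_ordinal(),
                                                                    shuffle      = True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank         = xm.get_ordinal(),
                                                                    shuffle      = False)
    # data loaders
    valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                               batch_size  = batch_size, 
                                               sampler     = valid_sampler, 
                                               num_workers = 0,
                                               pin_memory  = True) 
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size  = batch_size, 
                                               sampler     = train_sampler, 
                                               num_workers = 0,
                                               pin_memory  = True)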
    ### MODEL PREP
    # send to TPU
    device = xm.xla_device()
    model  = mx.to(device)
    # scale LR
    scaled_eta = eta * xm.xrt_world_size()
    # optimizer and loss
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr = scaled_eta)
    scheduler = lr_scheduler.StepLR(optimizer, step_size = step, gamma = gamma)
    ### MODELING
    # placeholders
    trn_losses = []
    val_losses = []
    best_val_loss = 1
    # modeling loop
    for epoch in range(num_epochs):
        # display info
        xm.master_print('EPOCH {}/{}'.format(epoch + 1, num_epochs))
        xm.master_print('- initialization | TPU cores = {}, lr = {:.6f}'.format(
            xm.xrt_world_size(), scheduler.get_lr()[len(scheduler.get_lr()) - 1] / xm.xrt_world_size()))
        epoch_start = time.time()
        # update train_loader shuffling
        train_loader.sampler.set_epoch(epoch)
        # training pass
        train_start = time.time()
        xm.master_print('- training...')
        para_loader = pl.ParallelLoader(train_loader, [device])
        trn_loss = train_fn(epoch       = epoch + 1, 
                            para_loader = para_loader.per_device_loader(device), 
                            criterion   = criterion,
                            optimizer   = optimizer, 
                            scheduler   = scheduler,
                            device      = device)
        del para_loader
        # validation pass
        valid_start = time.time()
        xm.master_print('- validation...')
        para_loader = pl.ParallelLoader(valid_loader, [device])
        val_loss = valid_fn(epoch       = epoch + 1, 
                            para_loader = para_loader.per_device_loader(device), 
                            criterion   = criterion, 
                            device      = device)
        del para_loader

        # save weights
        if val_loss < best_val_loss:
            xm.save(model.state_dict(), 'weights_{}.pt'.format(model_name))
            best_val_loss = val_loss
        # display info
        xm.master_print('- elapsed time | train = {:.2f} min, valid = {:.2f} min'.format(
            (valid_start - train_start) / 60, (time.time() - valid_start) / 60))
        xm.master_print('- average loss | train = {:.6f}, valid = {:.6f}'.format(
            trn_loss, val_loss))
        # save losses
        trn_losses.append(trn_loss)
        val_losses.append(val_loss)
        del trn_loss, val_loss
    # print results
    xm.master_print('Best results: loss = {:.6f} (epoch {})'.format(np.min(val_losses), np.argmin(val_losses) + 1))
    return trn_losses, val_losses

5. Modeling

After all helper functions have been introduced, we can finally launch the training! To do that, we need to define the last wrapper function that runs the modeling on multiple TPU cores: _mp_fn(rank, flags). Within the wrapper function, we set default tensor type to make sure that new tensors are initialized as float torch tensors on TPU and then run the modeling. It is important to set nprocs to the number of available TPU cores. The FLAGS object can be used to pass further training arguments to the modeling function.

Running xmp.spawn() will launch our _mp_fn() on multiple TPU cores! Since it does not provide any output, it is useful to save losses or other objects you might be interested in after running the _run() function.

# wrapper function
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    trn_losses, val_losses = _run(model)
    np.save('trn_losses.npy', np.array(trn_losses))
    np.save('val_losses.npy', np.array(val_losses))

# modeling
FLAGS = {}
xmp.spawn(_mp_fn, args = (FLAGS,), nprocs = num_tpu_workers, start_method = 'fork')
- initialization | TPU cores = 8, lr = 0.000010
- training...
-- batch 100 | cur_loss = 0.343750, avg_loss = 0.367342
-- batch 200 | cur_loss = 0.285156, avg_loss = 0.333299
-- batch 300 | cur_loss = 0.241211, avg_loss = 0.311770
-- batch 400 | cur_loss = 0.194336, avg_loss = 0.292837
-- batch 500 | cur_loss = 0.179688, avg_loss = 0.275745
-- batch 600 | cur_loss = 0.180664, avg_loss = 0.260264
-- batch 700 | cur_loss = 0.166016, avg_loss = 0.246196
-- batch 800 | cur_loss = 0.139648, avg_loss = 0.232805
-- batch 900 | cur_loss = 0.085449, avg_loss = 0.220550
-- batch 1000 | cur_loss = 0.109375, avg_loss = 0.209381
- validation...
-- batch 100 | cur_loss = 0.184570, avg_loss =  0.353043
- elapsed time | train = 125.82 min, valid = 26.05 min
- average loss | train = 0.209277, valid = 0.366887

Best results: loss = 0.366887 (epoch 1)

The training is working! Note that each 100 batches displayed in the snippet above actually correspond to 100 * batch_size * num_tpu_workers images, since every core processes the same number of different images simultaneously but printing is done from just one core.
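
As a rough sanity check, we can turn the log into a throughput estimate. This is a back-of-the-envelope sketch: it uses the training time printed above and ignores the extra batch or two processed before the early-stop check fires.

```python
batch_size, num_tpu_workers, batches_per_epoch = 128, 8, 1000
train_minutes = 125.82   # training time from the log above

# images behind every printed line (one line per 100 batches on each core)
images_per_log_line = 100 * batch_size * num_tpu_workers

# approximate number of images processed in the training pass
images_per_epoch = batches_per_epoch * batch_size * num_tpu_workers

images_per_second = images_per_epoch / (train_minutes * 60)

print(images_per_log_line)   # 102400
print(images_per_epoch)      # 1024000
print('{:.1f} images/sec'.format(images_per_second))
```

Under these assumptions the pipeline processes roughly 135 images per second across all eight cores, a figure that also includes the time spent on reading and preprocessing the data.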

6. Closing words

This is the end of this blog post. Using a computer vision application, we demonstrated how to use PyTorch/XLA to take advantage of TPU when training deep learning models. We covered important changes that need to be implemented into the modeling pipeline to enable TPU-based training, including data processing, modeling and displaying results. I hope this post will help you to get started with TPUs!

If you are interested in further reading, make sure to check out the tutorial notebooks developed by the PyTorch/XLA team, available at their GitHub repo.