Last updated: 16.01.2021

1. Overview

Deep learning relies heavily on Graphics Processing Units (GPUs) to enable fast model training. Recently, Google introduced Tensor Processing Units (TPUs) to further speed up the computations used in neural networks. Cloud TPUs are available on Kaggle and Google Colab. While TPU chips have been optimized for TensorFlow pipelines, PyTorch users can also take advantage of the better compute. This requires using PyTorch/XLA and implementing certain changes in the modeling pipeline.

Moving a PyTorch pipeline to TPUs includes the following steps:

  • installing the relevant packages and setting up the TPU
  • adjusting the syntax of some modeling steps, such as model initialization, the optimizer and verbosity
  • distributing data loaders over multiple TPU cores
  • wrapping data processing, training and inference into a master function

This post provides a tutorial on using PyTorch/XLA to build a TPU pipeline. The code is optimized for multi-core TPU training. Many of the ideas are adapted from here and here. We will focus on a computer vision application, but the framework can be used with other deep learning models as well. We will use data from the RSNA STR Pulmonary Embolism Detection Kaggle competition, which involves detecting pulmonary embolism in more than 1.7 million CT scans.

2. Preparations and packages

When setting up a script, it is important to introduce two TPU-related parameters: batch size and number of workers.

Google recommends using 128 images per batch for the best performance on the current TPU v3 chips. The v3 chips have 8 cores. This implies that each of the 8 cores can receive a batch of 128 images at each training step, and the modeling can be performed simultaneously on the separate cores. Next, the model weights are updated based on the outcomes observed on each core. Therefore, a batch size of 128 actually implies 128 * 8 images in each iteration.


# partitioning
num_folds = 5
use_fold  = 0

# image params
image_size = 128

# modeling
batch_size        = 128  # num_images = batch_size*num_tpu_workers
batches_per_epoch = 1000 # num_images = batch_size*batches_per_epoch*num_tpu_workers
num_epochs        = 1
batch_verbose     = 100
num_tpu_workers   = 8

# learning rate
eta   = 0.0001
step  = 1
gamma = 0.5

# paths
data_path  = '/kaggle/input/rsna-str-pulmonary-embolism-detection/'
image_path = '/kaggle/input/rsna-str-pe-detection-jpeg-256/train-jpegs/'
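As a quick sanity check on these settings, the effective number of images per optimizer step and per epoch follows directly from the comments in the snippet above:

```python
# effective dataset coverage implied by the parameters above
batch_size        = 128   # per-core batch
num_tpu_workers   = 8     # cores on a TPU v3-8
batches_per_epoch = 1000  # early-stop limit per epoch

images_per_step  = batch_size * num_tpu_workers
images_per_epoch = batch_size * batches_per_epoch * num_tpu_workers

print(images_per_step)   # 1024 images processed per optimizer step
print(images_per_epoch)  # 1024000 images seen per (truncated) epoch
```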

After specifying the parameters, we need to set up the TPU by installing and importing torch_xla using the snippet below. There are two variants here: install the latest stable XLA version (1.7 as of 15.01.2021) or install the so-called 'nightly' version that includes the latest updates but can be unstable. I recommend going for the stable version.

We also specify the XLA_USE_BF16 variable (default tensor precision format) and the XLA_TENSOR_ALLOCATOR_MAXSIZE variable (maximum tensor allocator size). When working in Google Colab, we can also run assert os.environ['COLAB_TPU_ADDR'] to check that Colab is correctly connected to a TPU instance.

Don't be discouraged if you see error messages about fastai, kornia and allennlp during the installation. The installation will still proceed with the versions of torch and torchvision that are required to work with TPUs.


# XLA version
xla_version = 'nightly' # 'nightly' or '1.7'

# installation 
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $xla_version

# XLA imports
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
# configurations
import os
os.environ['XLA_USE_BF16']                 = '1'
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '1000000000'
Running on TPU  ['']
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200515 ...
Found existing installation: torch 1.5.0
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Done updating TPU runtime
ERROR: fastai 1.0.61 requires torchvision, which is not installed.
ERROR: kornia 0.3.1 has requirement torch==1.5.0, but you'll have torch 1.6.0a0+bf2bbd9 which is incompatible.
ERROR: allennlp 1.0.0 has requirement torch<1.6.0,>=1.5.0, but you'll have torch 1.6.0a0+bf2bbd9 which is incompatible.
Successfully installed torch-1.6.0a0+bf2bbd9
Successfully installed torch-xla-1.6+2b2085a
Successfully installed torchvision-0.7.0a0+a6073f0
[remaining pip and apt-get output truncated]

Next, we can import all other relevant libraries.


import numpy as np
import pandas as pd

import torch
import torchvision

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset

from PIL import Image, ImageFile
import cv2

from sklearn.model_selection import GroupKFold

import glob

import random
import time
import sys
import os
import gc

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import warnings

3. Data preparation

Most of the data processing pipeline does not need to be adjusted when training on TPU instead of GPU. In contrast to GPU-based code, it is only necessary to change the data loaders such that they can distribute image batches over the TPU cores. This is covered in the next section that overviews the modeling stage, since the distributed data samplers need to be wrapped into the modeling function.

It is also important to note that, given the more efficient training, data import usually becomes the computational bottleneck. Hence, it is crucial to optimize data processing as much as possible. For the best performance, I recommend transforming the data to the .tfrec format to improve the speed of data reading. You can read more about using .tfrec with PyTorch here.

Below, we construct a Dataset class to read JPG images of the corresponding CT scans. Each image has ten binary labels representing the presence of pulmonary embolism and its characteristics.


# label names
label_names = ['pe_present_on_image',
               'negative_exam_for_pe',
               'indeterminate',
               'chronic_pe',
               'acute_and_chronic_pe',
               'central_pe',
               'leftsided_pe',
               'rightsided_pe',
               'rv_lv_ratio_gte_1',
               'rv_lv_ratio_lt_1']

# label weights: one weight per label; the values after the first are truncated in this snippet
label_weights = torch.tensor([0.07361963,
                              ])



class PEData(Dataset):

    def __init__(self, data, directory, transform = None, load_jpg = False, labeled = False):
        self.data      = data
        self.directory = directory
        self.transform = transform
        self.load_jpg  = load_jpg
        self.labeled   = labeled

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # import
        img_name = glob.glob(os.path.join(self.directory, '/'.join(self.data.loc[idx][['StudyInstanceUID', 'SeriesInstanceUID']]) + '/*' + self.data.loc[idx]['SOPInstanceUID'] + '.jpg'))[0]
        image    = cv2.imread(img_name)

        # switch channels and normalize
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0

        # convert
        image = torch.tensor(image, dtype = torch.float)
        image = image.permute(2, 0, 1)

        # augmentations
        if self.transform is not None:
            image = self.transform(image)

        # output
        if self.labeled:
            labels = torch.tensor(self.data.loc[idx][label_names].values.astype('int'), dtype = torch.float)
            return image, labels
        else:
            return image


# minimal transform pipeline (Resize/ToTensor assumed); extend with augmentations as needed
train_trans = test_trans = transforms.Compose([transforms.ToPILImage(),
                                               transforms.Resize(image_size),
                                               transforms.ToTensor()])

We split the data into training and validation folds such that images from the same patient (StudyInstanceUID) do not appear in both folds. We will only use a single fold for demonstration purposes.


# partitioning
train = pd.read_csv(data_path + 'train.csv')
gkf = GroupKFold(n_splits = num_folds)
train['fold'] = -1
for fold, (_, val_idx) in enumerate(gkf.split(train, groups = train['StudyInstanceUID'])):
    train.loc[val_idx, 'fold'] = fold

# load splits
data_train = train.loc[train.fold != use_fold].reset_index(drop = True)
data_valid = train.loc[train.fold == use_fold].reset_index(drop = True)

# datasets
train_dataset = PEData(data      = data_train, 
                       directory = image_path,
                       transform = train_trans,
                       load_jpg  = True,
                       labeled   = True)
valid_dataset = PEData(data      = data_valid, 
                       directory = image_path,
                       transform = test_trans,
                       load_jpg  = True,
                       labeled   = True)

Before proceeding to modeling, let's take a look at a sample batch of training images produced by our processing pipeline.


# sample loader
sample_loader = torch.utils.data.DataLoader(train_dataset,
                                            shuffle     = False,
                                            batch_size  = 8, 
                                            num_workers = 1)

# display images
for batch_idx, (inputs, labels) in enumerate(sample_loader):
    fig = plt.figure(figsize = (14, 7))
    for i in range(8):
        ax = fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])     
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
        ax.set_title(labels.numpy()[i])
    break

4. Model setup

The modeling stage needs to be modified because the computations are performed simultaneously on multiple TPU cores. This requires changes to the model initialization and the optimizer, as well as building a master function to distribute data loaders, training and inference over the TPU cores. Let's dive in!

We start with the model. We are using ResNet-34 with ten output nodes corresponding to the ten binary labels. After initializing the model, it is important to wrap it into an object that can be sent to TPU. This is done with a simple command: mx = xmp.MpModelWrapper(model).


# initialization 
def init_model():
    model    = models.resnet34(pretrained = True)
    model.fc = torch.nn.Linear(in_features = 512, out_features = len(label_names), bias = True)
    return model

# model wrapper
model = init_model()
mx    = xmp.MpModelWrapper(model)

Tracking the running loss when training on multiple TPU cores can be a bit difficult, since we need to aggregate batch losses between the TPU cores. The following helper class allows us to store the loss values externally and update them based on the batch outputs from each worker.


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val   = 0
        self.avg   = 0
        self.sum   = 0
        self.count = 0

    def update(self, val, n = 1):
        self.val    = val
        self.sum   += val * n
        self.count += n
        self.avg    = self.sum / self.count
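To see how the meter weights each update by the number of images in the batch, here is a quick standalone check (the class logic is restated inline so the snippet runs on its own):

```python
class AverageMeter:
    """Minimal restatement of the meter above."""
    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, n = 1):
        self.val    = val
        self.sum   += val * n
        self.count += n
        self.avg    = self.sum / self.count

meter = AverageMeter()
meter.update(0.5, n = 128)  # first batch: 128 images, mean loss 0.5
meter.update(0.3, n = 64)   # smaller batch: 64 images, mean loss 0.3

# avg is weighted by batch size: (0.5*128 + 0.3*64) / 192
print(round(meter.avg, 4))  # 0.4333
```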

Next, we need to wrap modeling into a single master function that can be distributed over TPU cores. We will first define functions for training and inference and then introduce the wrapper function.

The training pass must have several steps:

  • the optimizer step is done with xm.optimizer_step(optimizer)
  • the printing statements need to be defined as xm.master_print() instead of print() in order to only print a statement once (otherwise, each TPU core will print it)
  • the data loader should be defined outside of the function and passed in as an argument so that it can be distributed over the cores
  • running loss can be computed using the defined AverageMeter() object

In addition, it is important to clear the TPU memory as often as possible to ensure that the modeling does not crash:

  • del [object] to delete objects once they are not needed
  • gc.collect() to collect garbage left in memory



def train_fn(epoch, para_loader, optimizer, criterion, scheduler, device):

    # initialize
    model.train()
    trn_loss_meter = AverageMeter()

    # training loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):

        # extract inputs and labels
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward and backward pass
        optimizer.zero_grad()
        preds = model(inputs)
        loss  = criterion(preds, labels)
        loss.backward()
        xm.optimizer_step(optimizer, barrier = True) # barrier is required on single-core training but can be dropped with multiple cores

        # compute loss
        trn_loss_meter.update(loss.detach().item(), inputs.size(0))

        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss = {:.6f}'.format(
                batch_idx, loss.item(), trn_loss_meter.avg))

        # clear memory
        del inputs, labels, preds, loss

        # early stop
        if batch_idx > batches_per_epoch:
            break

    # scheduler step
    scheduler.step()

    # clear memory
    del para_loader, batch_idx
    gc.collect()

    return trn_loss_meter.avg

Similar to the training pass, the inference function takes a data loader as an argument, updates the loss using the AverageMeter() object and clears memory after each batch.



def valid_fn(epoch, para_loader, criterion, device):

    # initialize
    model.eval()
    val_loss_meter = AverageMeter()

    # validation loop
    for batch_idx, (inputs, labels) in enumerate(para_loader):

        # extract inputs and labels
        inputs = inputs.to(device)
        labels = labels.to(device)

        # compute preds
        with torch.no_grad():
            preds = model(inputs)
            loss  = criterion(preds, labels)
        # compute loss
        val_loss_meter.update(loss.detach().item(), inputs.size(0))

        # feedback
        if (batch_idx > 0) and (batch_idx % batch_verbose == 0):
            xm.master_print('-- batch {} | cur_loss = {:.6f}, avg_loss =  {:.6f}'.format(
                batch_idx, loss.item(), val_loss_meter.avg))
        # clear memory
        del inputs, labels, preds, loss

    # clear memory
    del para_loader, batch_idx
    return val_loss_meter.avg

The master modeling function also includes several TPU-based modifications.

First, we need to create a distributed data sampler that reads our Dataset object and distributes batches over the TPU cores. This is done with torch.utils.data.distributed.DistributedSampler(), which ensures that the loader on each core only sees its own portion of the whole dataset. Setting num_replicas to xm.xrt_world_size() retrieves the number of available TPU cores. After defining the sampler, we can set up a data loader that uses it.
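The sampler mechanics can be illustrated outside of a TPU session with a toy dataset; here num_replicas and rank are hard-coded stand-ins for xm.xrt_world_size() and xm.get_ordinal():

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# toy dataset of 32 samples standing in for the image dataset
dataset = TensorDataset(torch.arange(32).float())

# replica 0 of 8: the sampler hands this loader 1/8 of the data
sampler = DistributedSampler(dataset, num_replicas = 8, rank = 0, shuffle = False)
loader  = DataLoader(dataset, batch_size = 2, sampler = sampler)

print(len(sampler))  # 4 samples assigned to this replica
print(len(loader))   # 2 batches of size 2
```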

Second, the model is sent to TPU with the following code:

device = xm.xla_device()
model  = mx.to(device)

Third, we need to scale up the learning rate, since each step now processes a batch on each of the cores simultaneously: scaled_eta = eta * xm.xrt_world_size().
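With the parameters used in this post, the scaling works out as follows (8 is a stand-in for xm.xrt_world_size() on a v3-8 chip):

```python
eta        = 0.0001  # base learning rate from the params section
world_size = 8       # xm.xrt_world_size() on a TPU v3-8

scaled_eta = eta * world_size
print(scaled_eta)  # 0.0008
```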

Finally, we continue keeping track of the memory and clearing it whenever possible, and use xm.master_print() for displaying intermediate results. We also set up the function to return lists with the training and validation loss values.



def _run(model):

    ### DATA PREP

    # data samplers
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank         = xm.get_ordinal(),
                                                                    shuffle      = True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                    num_replicas = xm.xrt_world_size(),
                                                                    rank         = xm.get_ordinal(),
                                                                    shuffle      = False)

    # data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size  = batch_size, 
                                               sampler     = train_sampler, 
                                               num_workers = 0,
                                               pin_memory  = True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                               batch_size  = batch_size, 
                                               sampler     = valid_sampler, 
                                               num_workers = 0,
                                               pin_memory  = True) 

    ### MODEL PREP

    # send to TPU
    device = xm.xla_device()
    model  = mx.to(device)

    # scale LR
    scaled_eta = eta * xm.xrt_world_size()

    # optimizer and loss
    criterion = nn.BCEWithLogitsLoss(pos_weight = label_weights.to(device))
    optimizer = optim.Adam(model.parameters(), lr = scaled_eta)
    scheduler = lr_scheduler.StepLR(optimizer, step_size = step, gamma = gamma)
    ### MODELING
    # placeholders
    trn_losses = []
    val_losses = []
    best_val_loss = 1
    # modeling loop
    for epoch in range(num_epochs):

        # display info
        xm.master_print('EPOCH {}/{}'.format(epoch + 1, num_epochs))
        xm.master_print('- initialization | TPU cores = {}, lr = {:.6f}'.format(
            xm.xrt_world_size(), scheduler.get_lr()[-1] / xm.xrt_world_size()))
        epoch_start = time.time()

        # update train_loader shuffling
        train_sampler.set_epoch(epoch)
        # training pass
        train_start = time.time()
        xm.master_print('- training...')
        para_loader = pl.ParallelLoader(train_loader, [device])
        trn_loss = train_fn(epoch       = epoch + 1, 
                            para_loader = para_loader.per_device_loader(device), 
                            criterion   = criterion,
                            optimizer   = optimizer, 
                            scheduler   = scheduler,
                            device      = device)
        del para_loader
        # validation pass
        valid_start = time.time()
        xm.master_print('- validation...')
        para_loader = pl.ParallelLoader(valid_loader, [device])
        val_loss = valid_fn(epoch       = epoch + 1, 
                            para_loader = para_loader.per_device_loader(device), 
                            criterion   = criterion, 
                            device      = device)
        del para_loader

        # save weights
        if val_loss < best_val_loss:
            xm.save(model.state_dict(), 'weights_fold_{}.pt'.format(use_fold))
            best_val_loss = val_loss

        # display info
        xm.master_print('- elapsed time | train = {:.2f} min, valid = {:.2f} min'.format(
            (valid_start - train_start) / 60, (time.time() - valid_start) / 60))
        xm.master_print('- average loss | train = {:.6f}, valid = {:.6f}'.format(
            trn_loss, val_loss))

        # save losses
        trn_losses.append(trn_loss)
        val_losses.append(val_loss)
        del trn_loss, val_loss

    # print results
    xm.master_print('Best results: loss = {:.6f} (epoch {})'.format(np.min(val_losses), np.argmin(val_losses) + 1))
    return trn_losses, val_losses

5. Modeling

After all helper functions have been introduced, we can finally launch the training! To do that, we need to define one last wrapper function that runs the modeling on multiple TPU cores: _mp_fn(rank, flags). Within the wrapper function, we set the default tensor type to make sure that new tensors are initialized as float torch tensors, and then run the modeling. It is important to set nprocs to the number of available TPU cores. The FLAGS object can be used to pass further training arguments to the modeling function.

Running xmp.spawn() will launch our _mp_fn() on multiple TPU cores! Since it does not provide any output, it is useful to save the losses or any other objects you might be interested in after running the _run() function.
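Since the spawned processes cannot return Python objects to the parent, one minimal pattern (sketched here with placeholder loss values and a temporary directory) is to write arrays to disk inside _mp_fn() and load them back afterwards:

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # inside _mp_fn(): persist the losses returned by _run()
    trn_losses = [0.209, 0.185]  # placeholder values
    np.save(os.path.join(tmp, 'trn_losses.npy'), np.array(trn_losses))

    # back in the main process after xmp.spawn() finishes
    restored = np.load(os.path.join(tmp, 'trn_losses.npy'))
    print(restored.tolist())  # [0.209, 0.185]
```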


# wrapper function
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    trn_losses, val_losses = _run(model)
    np.save('trn_losses.npy', np.array(trn_losses))
    np.save('val_losses.npy', np.array(val_losses))
# modeling
FLAGS = {}
xmp.spawn(_mp_fn, args = (FLAGS,), nprocs = num_tpu_workers, start_method = 'fork')
- initialization | TPU cores = 8, lr = 0.000010
- training...
-- batch 100 | cur_loss = 0.343750, avg_loss = 0.367342
-- batch 200 | cur_loss = 0.285156, avg_loss = 0.333299
-- batch 300 | cur_loss = 0.241211, avg_loss = 0.311770
-- batch 400 | cur_loss = 0.194336, avg_loss = 0.292837
-- batch 500 | cur_loss = 0.179688, avg_loss = 0.275745
-- batch 600 | cur_loss = 0.180664, avg_loss = 0.260264
-- batch 700 | cur_loss = 0.166016, avg_loss = 0.246196
-- batch 800 | cur_loss = 0.139648, avg_loss = 0.232805
-- batch 900 | cur_loss = 0.085449, avg_loss = 0.220550
-- batch 1000 | cur_loss = 0.109375, avg_loss = 0.209381
- validation...
-- batch 100 | cur_loss = 0.184570, avg_loss =  0.353043
- elapsed time | train = 125.82 min, valid = 26.05 min
- average loss | train = 0.209277, valid = 0.366887

Best results: loss = 0.366887 (epoch 1)

The training is working! Note that every 100 batches displayed in the snippet above actually correspond to 100 * batch_size * num_tpu_workers images, since every core simultaneously processes its own batches, but printing is done from a single core.

6. Closing words

This is the end of this blog post. Using a computer vision application, we demonstrated how to use PyTorch/XLA to take advantage of TPUs when training deep learning models. We covered the important changes that need to be implemented in the modeling pipeline to enable TPU-based training, including data processing, modeling and displaying results. I hope this post will help you get started with TPUs!

If you are interested in further reading, make sure to check the tutorial notebooks developed by the PyTorch/XLA team, available in their GitHub repo.