1. Overview

In computer vision, it is common practice to normalize image pixel values using the dataset mean and standard deviation. This helps produce consistent results when applying the model to new images and is also useful for transfer learning. In practice, computing these statistics is slightly non-trivial, since we usually can't load the whole dataset into memory and have to loop through it in batches.

This blog post is a quick tutorial on computing the dataset mean and std for each RGB channel using a regular PyTorch dataloader. Computing the mean is easy: we can simply average the batch means, as long as all batches have the same size. The standard deviation is trickier, because averaging the stds of different batches does not give the overall std. Let's see how to do it correctly!
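To see why the std needs special care, here is a minimal numpy sketch with two hypothetical batches of four pixel values each. Each batch is constant, so both batch stds are 0, while the pooled data clearly varies:

```python
import numpy as np

# Two hypothetical "batches" of pixel values with different means
batch_a = np.array([0.0, 0.0, 0.0, 0.0])
batch_b = np.array([1.0, 1.0, 1.0, 1.0])

# Averaging per-batch means works here because the batches have equal size
avg_of_means = np.mean([batch_a.mean(), batch_b.mean()])

# Averaging per-batch stds does not: each batch on its own has std 0
avg_of_stds = np.mean([batch_a.std(), batch_b.std()])

# Statistics of the pooled data
pooled = np.concatenate([batch_a, batch_b])
print(avg_of_means, pooled.mean())  # 0.5 0.5
print(avg_of_stds, pooled.std())    # 0.0 0.5
```

The per-batch stds only measure spread within each batch and miss the spread between batches, which is why the pooled std has to be built from sums instead.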

2. Preparations

To demonstrate how to compute image statistics, we will use data from the Cassava Leaf Disease Classification Kaggle competition, which contains about 21,000 plant images. Feel free to scroll down to Section 3 if you want to jump directly to the calculations. First, we import the usual libraries and specify the relevant parameters. We don't need a GPU, since no modeling is involved.

####### PACKAGES

import os

import numpy as np
import pandas as pd

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

from PIL import Image, ImageFile
import cv2

from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline


####### PARAMS

device      = torch.device('cpu') 
num_workers = 4
image_size  = 512 
batch_size  = 8
data_path   = '/kaggle/input/cassava-leaf-disease-classification/'

Now, let's import the dataframe with the image paths and create a Dataset class that will read images and supply them to the dataloader.

df = pd.read_csv(data_path + 'train.csv')
df.head()
         image_id  label
0  1000015157.jpg      0
1  1000201771.jpg      3
2   100042118.jpg      1
3  1000723321.jpg      1
4  1000812911.jpg      3

class LeafData(Dataset):
    
    def __init__(self, 
                 data, 
                 directory, 
                 transform = None):
        self.data      = data
        self.directory = directory
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        
        # import
        path  = os.path.join(self.directory, self.data.iloc[idx]['image_id'])
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
        # augmentations
        if self.transform is not None:
            image = self.transform(image = image)['image']
        
        # output
        return image

Since we want to compute statistics for the original images, our augmentation pipeline should be minimal and exclude any heavy augmentations we might use during training. Below, we use A.Normalize() with mean = 0 and std = 1 to scale pixel values from [0, 255] to [0, 1] (with its default max_pixel_value = 255, this simply divides by 255) and ToTensorV2() to convert numpy arrays into channels-first torch tensors.

augs = A.Compose([A.Resize(height = image_size, 
                           width  = image_size),
                  A.Normalize(mean = (0, 0, 0),
                              std  = (1, 1, 1)),
                  ToTensorV2()])
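As a rough illustration of what this pipeline does to a single image, here is a numpy-only sketch that mimics (rather than calls) albumentations on a hypothetical tiny image: with mean = 0, std = 1 and the default max_pixel_value = 255, A.Normalize() reduces to dividing by 255, and ToTensorV2() moves the channel axis first (HWC to CHW). The plotting code in the next snippet reverses this with transpose(1, 2, 0).

```python
import numpy as np

# Hypothetical 2x3 RGB image in HWC layout, as returned by cv2 after BGR->RGB
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3) * 10

# Normalize(mean=0, std=1) with the default max_pixel_value of 255 reduces to x / 255
scaled = img.astype(np.float32) / 255.0

# ToTensorV2 moves channels first: HWC -> CHW
chw = scaled.transpose(2, 0, 1)
print(img.shape, chw.shape)  # (2, 3, 3) (3, 2, 3)

# The plotting code undoes this with transpose(1, 2, 0): CHW -> HWC
hwc = chw.transpose(1, 2, 0)
print(np.array_equal(hwc, scaled))  # True
```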

Let's check if our code above works correctly. We define a DataLoader that loads images in batches from LeafData and plot images from the first batch.

####### EXAMINE SAMPLE BATCH

# dataset
sample_dataset = LeafData(data      = df, 
                          directory = data_path + 'train_images/',
                          transform = augs)

# data loader
sample_loader = DataLoader(sample_dataset, 
                           batch_size  = batch_size, 
                           shuffle     = False, 
                           num_workers = num_workers,
                           pin_memory  = True)

# display images from the first batch
# (LeafData returns only images, so each batch is a single tensor)
for batch_idx, inputs in enumerate(sample_loader):
    fig = plt.figure(figsize = (14, 7))
    for i in range(8):
        ax = fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])
        # CHW tensor -> HWC array for matplotlib
        plt.imshow(inputs[i].numpy().transpose(1, 2, 0))
    break

Looks like everything is working correctly! Now we can use our sample_loader to compute image stats.

3. Computing image stats

The computation is done in three steps:

  1. Define placeholders to store two batch-level statistics: sum and squared sum of pixel values. The first will be used to compute means, while the latter will be needed for standard deviation calculations.
  2. Loop through the batches and add up channel-specific sum and squared sum values.
  3. Perform final calculations to obtain data-level mean and standard deviation.

The first two steps are done in the snippet below. Note that we sum over dimensions [0, 2, 3] (batch, height and width), which leaves one value per channel (dimension 1). The shape of inputs is [batch_size x 3 x image_size x image_size], so this aggregates values for each RGB channel separately.
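A quick shape check on a hypothetical random batch, using numpy as a stand-in for torch (numpy's axis=(0, 2, 3) plays the role of torch's dim = [0, 2, 3]):

```python
import numpy as np

# Hypothetical batch shaped like the dataloader output: [batch_size, 3, H, W]
batch_size, channels, height, width = 8, 3, 4, 4
rng = np.random.default_rng(0)
inputs = rng.random((batch_size, channels, height, width), dtype=np.float32)

# Sum over batch, height and width (axes 0, 2, 3); keep the channel axis (1)
per_channel_sum = inputs.sum(axis=(0, 2, 3))
print(per_channel_sum.shape)  # (3,)

# Same result as summing each channel slice separately
manual = np.array([inputs[:, c].sum() for c in range(channels)])
print(np.allclose(per_channel_sum, manual))  # True
```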

####### COMPUTE MEAN / STD

# placeholders
psum   = torch.tensor([0.0, 0.0, 0.0])
psumsq = torch.tensor([0.0, 0.0, 0.0])

# loop through images (each batch is a tensor of shape [batch_size, 3, H, W])
for inputs in tqdm(sample_loader):
    psum   += inputs.sum(dim = [0, 2, 3])
    psumsq += (inputs ** 2).sum(dim = [0, 2, 3])
100%|██████████| 2675/2675 [04:21<00:00, 10.23it/s]

Finally, we perform the last calculations:

  • to get the mean, we divide the sum of pixel values by the total count, i.e., the number of pixels per channel in the dataset, computed as len(df) * image_size * image_size
  • to get the standard deviation, we use the following equation: total_std = sqrt(psumsq / count - total_mean ** 2). Why is that so? Because the variance (the mean of squared deviations from the mean) can be rewritten as the mean of the squared values minus the squared mean, which only requires the sum and the sum of squares. If you are not sure about this, feel free to check out this link for some details.

$$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \mu^2$$
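For a quick numeric check of this identity, take four hypothetical values x = (1, 2, 3, 4), so N = 4. Both forms give the same variance; in the code below, the 7.5 corresponds to psumsq / count and the 6.25 to total_mean ** 2:

```latex
\begin{aligned}
\mu &= \tfrac{1}{4}(1 + 2 + 3 + 4) = 2.5 \\
\tfrac{1}{N}\textstyle\sum_i (x_i - \mu)^2 &= \tfrac{1}{4}(2.25 + 0.25 + 0.25 + 2.25) = 1.25 \\
\tfrac{1}{N}\textstyle\sum_i x_i^2 - \mu^2 &= \tfrac{1}{4}(1 + 4 + 9 + 16) - 2.5^2 = 7.5 - 6.25 = 1.25
\end{aligned}
```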

####### FINAL CALCULATIONS

# pixel count
count = len(df) * image_size * image_size

# mean and std
total_mean = psum / count
total_var  = (psumsq / count) - (total_mean ** 2)
total_std  = torch.sqrt(total_var)

# output
print('mean: '  + str(total_mean))
print('std:  '  + str(total_std))
mean: tensor([0.4417, 0.5110, 0.3178])
std:  tensor([0.2330, 0.2358, 0.2247])
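As a sanity check of the whole procedure, here is a numpy sketch (a stand-in for the PyTorch loop, run on a small random array rather than the Cassava images) that accumulates sums batch by batch and compares the result with statistics computed on the full array at once. The helper name batch_stats is hypothetical. It counts pixels from each batch's shape, so a smaller last batch is handled automatically, and it accumulates in float64, a choice made here to limit rounding error rather than something the code above does.

```python
import numpy as np

def batch_stats(batches):
    """Accumulate per-channel sum and squared sum over an iterable of
    [B, C, H, W] arrays, then return per-channel mean and std."""
    psum, psumsq, count = None, None, 0
    for inputs in batches:
        # float64 accumulators to limit rounding error over many pixels
        x = inputs.astype(np.float64)
        s = x.sum(axis=(0, 2, 3))
        sq = (x ** 2).sum(axis=(0, 2, 3))
        psum = s if psum is None else psum + s
        psumsq = sq if psumsq is None else psumsq + sq
        # pixels per channel in this batch: B * H * W
        count += x.shape[0] * x.shape[2] * x.shape[3]
    mean = psum / count
    var = psumsq / count - mean ** 2
    # clip tiny negative values caused by floating-point rounding
    return mean, np.sqrt(np.clip(var, 0, None))

# Hypothetical dataset: 21 images, 3 channels, 16x16, split into uneven batches
rng = np.random.default_rng(42)
data = rng.random((21, 3, 16, 16))
batches = [data[i:i + 8] for i in range(0, len(data), 8)]  # sizes 8, 8, 5

mean, std = batch_stats(batches)

# Reference: statistics over the full array at once
ref_mean = data.mean(axis=(0, 2, 3))
ref_std = data.std(axis=(0, 2, 3))
print(np.allclose(mean, ref_mean), np.allclose(std, ref_std))  # True True
```

Note that data.std() defaults to ddof = 0 (the population std), which is what the sum-of-squares formula computes.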

This is it! Now you can plug the mean and std values into A.Normalize() in your data augmentation pipeline to make sure your dataset is normalized :)
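The stats can be passed as-is, e.g. A.Normalize(mean = (0.4417, 0.5110, 0.3178), std = (0.2330, 0.2358, 0.2247)). They were computed on [0, 1]-scaled pixels, while A.Normalize() receives raw [0, 255] pixels and multiplies mean and std by its default max_pixel_value = 255. The numpy sketch below re-implements that documented formula on a hypothetical random image (it does not call albumentations) and checks that it equals scaling to [0, 1] first and then standardizing:

```python
import numpy as np

# Per-channel stats reported above (computed on [0, 1]-scaled pixels)
mean = np.array([0.4417, 0.5110, 0.3178], dtype=np.float32)
std = np.array([0.2330, 0.2358, 0.2247], dtype=np.float32)
max_pixel_value = 255.0  # A.Normalize default

# Hypothetical raw uint8 RGB image in HWC layout
rng = np.random.default_rng(1)
img = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

# Formula documented for A.Normalize, applied per channel
normalized = (img.astype(np.float32) - mean * max_pixel_value) / (std * max_pixel_value)

# Equivalent: scale to [0, 1] first, then standardize with the [0, 1]-scale stats
expected = (img.astype(np.float32) / 255.0 - mean) / std
print(np.allclose(normalized, expected, atol=1e-5))  # True
```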

4. Closing words

I hope this tutorial is helpful for those looking for a quick guide on computing image stats. In my experience, normalizing images with respect to the dataset mean and std does not always improve performance, but it is one of the first things I try. Happy learning and stay tuned for the next posts!