Last update: 22.10.2021

1. Overview

In many real-world settings, the labeled training sample is smaller than the unlabeled test data. This blogpost demonstrates a technique that can improve the performance of neural networks in such settings by learning from both training and test data. This is done by pre-training a model on the complete data set using a surrogate label. The approach can help to reduce the impact of sampling bias by exposing the model to the test data, and it lets the model benefit from a larger sample size while learning.

We will focus on a computer vision application, but the idea can be used with deep learning models in other domains. We will use data from the SIIM-ISIC Melanoma Classification Kaggle competition to distinguish malignant and benign lesions on medical images. The modeling is performed in TensorFlow. A shorter, interactive version of this blogpost is also available as a Kaggle notebook.

2. Intuition

How can we make use of the test sample at the pre-training stage? The labels are only observed for the training data. Luckily, in many settings, plenty of meta-data is available for both labeled and unlabeled images. Consider the task of lung cancer detection. The CT scans of cancer patients may come with information on the patient's age and gender. Unlike the label, which requires medical tests or an expert's diagnosis, meta-data is available at no additional cost. Another example is bird image classification, where image meta-data such as the time and location of the photo can serve the same purpose. In this blogpost, we focus on malignant lesion classification, where patient meta-data is available for all images.

We can leverage meta-data in the following way:

  1. Pre-train a supplementary model on the complete train + test data using one of the meta-features as a surrogate label.
  2. Initialize from the pre-trained weights when training the main model.

The intuition behind this approach is that by learning to classify images according to one of the meta-variables, the model can learn visual features that might also be useful for the main task, which in our case is malignant lesion classification. For instance, lesion size and skin color can be helpful in determining both the lesion location (surrogate label) and the lesion type (actual label). Exposing the model to the test data also allows it to take a sneak peek at test images, which may help it learn patterns prevalent in the test distribution.
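To make the two steps concrete, here is a minimal pure-Python sketch of how a pre-training sample could be assembled from meta-data. It assumes the meta-data is available as one dict per image, using the field names from the tfrecords (image_name, anatom_site_general_challenge) with site codes 0 to 5; the helper names and the toy records are hypothetical. The notebook itself does this inside the tfrecord pipeline rather than on Python lists.

```python
from typing import Dict, List, Tuple

N_SITES = 6  # lesion-location categories, assumed to be coded 0..5


def one_hot(index: int, depth: int) -> List[float]:
    """Return a one-hot vector of length `depth` with 1.0 at `index`."""
    vec = [0.0] * depth
    vec[index] = 1.0
    return vec


def build_pretraining_sample(
    train_meta: List[Dict], test_meta: List[Dict]
) -> List[Tuple[str, List[float]]]:
    """Pair every image (train AND test) with its surrogate label.

    The real target is never read, which is what allows the
    unlabeled test images to join the pre-training sample.
    """
    sample = []
    for record in train_meta + test_meta:
        site = record["anatom_site_general_challenge"]
        sample.append((record["image_name"], one_hot(site, N_SITES)))
    return sample


# Hypothetical toy meta-data: two labeled training images, one unlabeled test image.
train_meta = [
    {"image_name": "train_img_1", "anatom_site_general_challenge": 2, "target": 0},
    {"image_name": "train_img_2", "anatom_site_general_challenge": 0, "target": 1},
]
test_meta = [{"image_name": "test_img_1", "anatom_site_general_challenge": 5}]

for name, label in build_pretraining_sample(train_meta, test_meta):
    print(name, label)
```

The training records carry a target, but the sketch ignores it; the test record has no target at all and still gets a label.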

P.S. The notebook relies heavily on the great modeling pipeline developed by Chris Deotte for the SIIM-ISIC competition and reuses much of his original code. For general questions on the pipeline, kindly refer to his notebook, where he provides comments and documentation.

3. Initialization

#collapse-hide

### PACKAGES

!pip install -q efficientnet >> /dev/null

import pandas as pd, numpy as np
from kaggle_datasets import KaggleDatasets
import tensorflow as tf, re, math
import tensorflow.keras.backend as K
import efficientnet.tfkeras as efn
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from scipy.stats import rankdata
import PIL, cv2

Let's set up the training parameters, such as image size, number of folds and batch size. In addition to these parameters, we introduce a USE_PRETRAIN_WEIGHTS variable that controls whether we pre-train a supplementary model on the full data before training the main melanoma classification model.

For demonstration purposes, we use EfficientNet B0, 128x128 image size and no TTA. Feel free to experiment with larger architectures and image sizes by editing this notebook.

#collapse-show

# DEVICE
DEVICE = "TPU"

# USE DIFFERENT SEED FOR DIFFERENT STRATIFIED KFOLD
SEED = 42

# NUMBER OF FOLDS. USE 3, 5, OR 15 
FOLDS = 5

# WHICH IMAGE SIZES TO LOAD EACH FOLD
IMG_SIZES = [128]*FOLDS

# BATCH SIZE AND EPOCHS
BATCH_SIZES = [32]*FOLDS
EPOCHS      = [10]*FOLDS

# WHICH EFFICIENTNET TO USE
EFF_NETS = [0]*FOLDS

# WEIGHTS FOR FOLD MODELS WHEN PREDICTING TEST
WGTS = [1/FOLDS]*FOLDS

# PRETRAINED WEIGHTS
USE_PRETRAIN_WEIGHTS = True

Below, we connect to TPU or GPU for faster training.

#collapse-hide

# CONNECT TO DEVICE
if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

    if tpu:
        try:
            print("initializing  TPU ...")
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
            print("TPU initialized")
        except Exception:
            # fall back to the default strategy below if TPU setup fails
            print("failed to initialize TPU")
            DEVICE = "GPU"
    else:
        DEVICE = "GPU"

if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

AUTO     = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')
connecting to TPU...
Running on TPU  grpc://10.0.0.2:8470
initializing  TPU ...
TPU initialized
REPLICAS: 8

4. Image processing

First, we specify the data paths. The data is stored as TFRecords to enable fast processing. You can read more about the data here.

#collapse-show

# IMAGE PATHS
GCS_PATH = [None]*FOLDS

for i,k in enumerate(IMG_SIZES):
    GCS_PATH[i]  = KaggleDatasets().get_gcs_path('melanoma-%ix%i'%(k,k))
    
files_train = np.sort(np.array(tf.io.gfile.glob(GCS_PATH[0] + '/train*.tfrec')))
files_test  = np.sort(np.array(tf.io.gfile.glob(GCS_PATH[0] + '/test*.tfrec')))

The read_labeled_tfrecord() function provides two outputs:

  1. Image tensor.
  2. Either anatom_site_general_challenge or target as a label. The former is a categorical feature with six possible values indicating the lesion location; it is stored as an integer and one-hot encoded when the record is read. The latter is a binary target indicating whether the lesion is malignant. The label is selected by the pretraining argument, which is passed in from the get_dataset() function below. Setting pretraining = True means using anatom_site_general_challenge as a surrogate label.

We also define read_unlabeled_tfrecord(), which returns the image and the image name.
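One detail of these readers is easy to misread: in `return a, b if cond else c`, the conditional expression binds tighter than the comma, so only the second element of the returned tuple switches. The sketch below reproduces the same expression shapes on plain Python dicts, used here as a hypothetical stand-in for parsed tfrecords with made-up values, to show what each reader returns.

```python
from typing import List


def one_hot(index: int, depth: int) -> List[float]:
    """Plain-Python stand-in for tf.one_hot on a single integer."""
    vec = [0.0] * depth
    vec[index] = 1.0
    return vec


def read_labeled(example: dict, pretraining: bool = False):
    # Parsed as (image, (surrogate if pretraining else target)):
    # the image is always returned, only the label changes.
    return example["image"], one_hot(example["anatom_site_general_challenge"], 6) if pretraining else example["target"]


def read_unlabeled(example: dict, return_image_name: bool = True):
    # Parsed as (image, (image_name if return_image_name else 0)).
    return example["image"], example["image_name"] if return_image_name else 0


# Hypothetical parsed example.
ex = {"image": b"jpeg-bytes", "image_name": b"img_0001",
      "anatom_site_general_challenge": 1, "target": 1}

print(read_labeled(ex))                    # image with the binary target
print(read_labeled(ex, pretraining=True))  # image with the one-hot site
print(read_unlabeled(ex, False))           # image with a dummy 0
```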

#collapse-show

def read_labeled_tfrecord(example, pretraining = False):
    if pretraining:
        tfrec_format = {
            'image'                        : tf.io.FixedLenFeature([], tf.string),
            'image_name'                   : tf.io.FixedLenFeature([], tf.string),
            'anatom_site_general_challenge': tf.io.FixedLenFeature([], tf.int64),
        }      
    else:
        tfrec_format = {
            'image'                        : tf.io.FixedLenFeature([], tf.string),
            'image_name'                   : tf.io.FixedLenFeature([], tf.string),
            'target'                       : tf.io.FixedLenFeature([], tf.int64)
        }   
    example = tf.io.parse_single_example(example, tfrec_format)
    return example['image'], tf.one_hot(example['anatom_site_general_challenge'], 6) if pretraining else example['target']


def read_unlabeled_tfrecord(example, return_image_name=True):
    tfrec_format = {
        'image'                        : tf.io.FixedLenFeature([], tf.string),
        'image_name'                   : tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, tfrec_format)
    return example['image'], example['image_name'] if return_image_name else 0

 
def prepare_image(img, dim = 256):    
    img = tf.image.decode_jpeg(img, channels = 3)
    img = tf.cast(img, tf.float32) / 255.0
    img = img * circle_mask
    img = tf.reshape(img, [dim,dim, 3])
            
    return img

def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) 
         for filename in filenames]
    return np.sum(n)
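count_data_items() never opens the files: it relies on the convention that each tfrecord file name ends with the number of images it contains, e.g. a hypothetical train00-2182.tfrec holding 2,182 images. Below is a stdlib sketch of the same idea with one small variation that is not in the original: the regex is applied to the base name only. Because search() returns the first match in the full path, a bucket or directory name containing something like "-1." would otherwise be picked up instead. All paths below are made up.

```python
import posixpath
import re
from typing import Iterable

# Same pattern as in the notebook: the digits between "-" and "." in a file name.
ITEM_COUNT = re.compile(r"-([0-9]*)\.")


def count_data_items(filenames: Iterable[str]) -> int:
    total = 0
    for filename in filenames:
        # Variation (assumption): match on the base name only, so hyphens and
        # dots in the bucket or directory part of the path are ignored.
        match = ITEM_COUNT.search(posixpath.basename(filename))
        total += int(match.group(1))
    return total


files = ["gs://my_bucket/train00-2182.tfrec", "gs://my_bucket/train01-2185.tfrec"]
print(count_data_items(files))  # 4367
```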

The get_dataset() function is a wrapper that loads and processes images; its arguments control the import options (shuffling, repetition, labeled or unlabeled reading, the label type and the batch size).

#collapse-show

def get_dataset(files, 
                shuffle            = False, 
                repeat             = False, 
                labeled            = True, 
                pretraining        = False,
                return_image_names = True, 
                batch_size         = 16, 
                dim                = 256):
    
    ds = tf.data.TFRecordDataset(files, num_parallel_reads = AUTO)
    ds = ds.cache()
    
    if repeat:
        ds = ds.repeat()
    
    if shuffle: 
        ds = ds.shuffle(1024*2) #if too large causes OOM in GPU CPU
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)
        
    if labeled: 
        ds = ds.map(lambda example: read_labeled_tfrecord(example, pretraining), 
                    num_parallel_calls = AUTO)
    else:
        ds = ds.map(lambda example: read_unlabeled_tfrecord(example, return_image_names), 
                    num_parallel_calls = AUTO)
    
    ds = ds.map(lambda img, imgname_or_label: (
                prepare_image(img, dim = dim), 
                imgname_or_label), 
                num_parallel_calls = AUTO)
    
    ds = ds.batch(batch_size * REPLICAS)
    ds = ds.prefetch(AUTO)
    return ds
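Because get_dataset() batches with batch_size * REPLICAS, its batch_size argument is a per-replica size, and each training step feeds one such batch to every TPU core. The steps_per_epoch expression used in model.fit() further down follows from this. A small arithmetic sketch, assuming a hypothetical image count of 40,000 together with this notebook's BATCH_SIZES = 32 and the REPLICAS = 8 printed above:

```python
def global_batch_size(batch_size_per_replica: int, replicas: int) -> int:
    # get_dataset() calls ds.batch(batch_size * REPLICAS): one per-replica
    # batch for each core, combined into a single global batch.
    return batch_size_per_replica * replicas


def steps_per_epoch(n_items: int, batch_size_per_replica: int, replicas: int) -> float:
    # Same expression as in model.fit(): true division, then floor division.
    # The result is a float such as 156.0, not an int.
    return n_items / batch_size_per_replica // replicas


N_ITEMS = 40_000  # hypothetical number of images
batch = global_batch_size(32, 8)
steps = steps_per_epoch(N_ITEMS, 32, 8)
print(batch, steps, int(steps) * batch)  # 256 156.0 39936
```

The 64 images left over in this example are not dropped: since the training dataset is repeated, the shuffled stream simply carries on into the next epoch.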

We also use a circular crop (a.k.a. microscope augmentation) to improve image consistency. The snippet below creates a circular mask, which is applied in the prepare_image() function.

#collapse-show

# CIRCLE CROP PREPARATIONS
circle_img  = np.zeros((IMG_SIZES[0], IMG_SIZES[0]), np.uint8)
circle_img  = cv2.circle(circle_img, (int(IMG_SIZES[0]/2), int(IMG_SIZES[0]/2)), int(IMG_SIZES[0]/2), 1, thickness = -1)
circle_img  = np.repeat(circle_img[:, :, np.newaxis], 3, axis = 2)
circle_mask = tf.cast(circle_img, tf.float32)
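To see what the mask does without OpenCV or TensorFlow, here is a pure-Python sketch on a tiny 8x8 image: a pixel is kept when its squared distance from the centre is at most the squared radius, and the image is multiplied element-wise by the mask, as in prepare_image(). It only approximates the filled cv2.circle() call above; OpenCV's rasterization may differ from this distance test on individual border pixels.

```python
from typing import List


def make_circle_mask(size: int) -> List[List[int]]:
    """1 inside a filled circle centred in a size x size grid, 0 outside."""
    c = size // 2  # centre, same as int(size / 2) for positive sizes
    r = size // 2  # radius equal to half the image size
    return [[1 if (x - c) ** 2 + (y - c) ** 2 <= r * r else 0
             for x in range(size)]
            for y in range(size)]


def apply_mask(image, mask):
    """Element-wise product over an image stored as image[y][x] = [r, g, b]."""
    size = len(mask)
    return [[[v * mask[y][x] for v in image[y][x]] for x in range(size)]
            for y in range(size)]


mask = make_circle_mask(8)
for row in mask:
    print("".join("#" if m else "." for m in row))

white = [[[1.0, 1.0, 1.0] for _ in range(8)] for _ in range(8)]
masked = apply_mask(white, mask)  # corners become black, the centre stays white
```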

Let's have a quick look at a batch of our images:

#collapse-hide

# LOAD DATA AND APPLY AUGMENTATIONS
def show_dataset(thumb_size, cols, rows, ds):
    mosaic = PIL.Image.new(mode='RGB', size=(thumb_size*cols + (cols-1), 
                                             thumb_size*rows + (rows-1)))
    for idx, data in enumerate(iter(ds)):
        img, target_or_imgid = data
        ix  = idx % cols
        iy  = idx // cols
        img = np.clip(img.numpy() * 255, 0, 255).astype(np.uint8)
        img = PIL.Image.fromarray(img)
        img = img.resize((thumb_size, thumb_size), resample = PIL.Image.BILINEAR)
        mosaic.paste(img, (ix*thumb_size + ix, 
                           iy*thumb_size + iy))
        nn = target_or_imgid.numpy().decode("utf-8")

    display(mosaic)
    return nn

files_train = tf.io.gfile.glob(GCS_PATH[0] + '/train*.tfrec')
ds = tf.data.TFRecordDataset(files_train, num_parallel_reads = AUTO).shuffle(1024)
ds = ds.take(10).cache()
ds = ds.map(read_unlabeled_tfrecord, num_parallel_calls = AUTO)
ds = ds.map(lambda img, target: (prepare_image(img, dim = IMG_SIZES[0]),
                                 target), num_parallel_calls = AUTO)
ds = ds.take(12*5)
ds = ds.prefetch(AUTO)

# DISPLAY IMAGES
name = show_dataset(128, 5, 2, ds)

5. Modeling

Pre-trained model with surrogate label

The build_model() function incorporates three important features that depend on the training regime:

  1. When building a model for pre-training, we use CategoricalCrossentropy as the loss because anatom_site_general_challenge is a categorical variable. When building a model that classifies lesions as benign/malignant, we use BinaryCrossentropy as the loss.
  2. When training a final binary classification model, we load the pre-trained weights using base.load_weights('base_weights.h5') if use_pretrain_weights == True.
  3. We use a dense layer with six output nodes and softmax activation when doing pre-training and a dense layer with a single output node and sigmoid activation when training a final model.
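For reference, here is what the two losses compute for a single example, as a pure-Python sketch. It assumes probabilities strictly between 0 and 1 and skips the clipping Keras applies internally. For the binary loss, Keras label smoothing moves hard targets toward 0.5 (y * (1 - ls) + 0.5 * ls), so with label_smoothing = 0.01 a malignant label of 1 becomes 0.995.

```python
import math
from typing import Sequence


def smooth_binary_target(y_true: float, label_smoothing: float) -> float:
    # Pull hard 0/1 targets toward 0.5: 1 -> 0.995 and 0 -> 0.005 for 0.01.
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing


def binary_crossentropy(y_true: float, p: float, label_smoothing: float = 0.0) -> float:
    y = smooth_binary_target(y_true, label_smoothing)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def categorical_crossentropy(y_true: Sequence[float], p: Sequence[float]) -> float:
    # With a one-hot target, only the probability of the true class contributes.
    return -sum(t * math.log(q) for t, q in zip(y_true, p) if t > 0)


uniform = [1 / 6] * 6
print(categorical_crossentropy([0, 0, 1, 0, 0, 0], uniform))  # ln 6, about 1.79
print(binary_crossentropy(1, 0.9))                          # -ln 0.9, about 0.105
print(binary_crossentropy(1, 0.9, label_smoothing=0.01))    # slightly larger
```

A model that predicts the six sites uniformly scores ln 6 ≈ 1.79 on the categorical loss, a handy baseline when reading the pre-training log further down, where the first epoch ends at 1.7556.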

#collapse-show

EFNS = [efn.EfficientNetB0, efn.EfficientNetB1, efn.EfficientNetB2, efn.EfficientNetB3, 
        efn.EfficientNetB4, efn.EfficientNetB5, efn.EfficientNetB6, efn.EfficientNetB7]

def build_model(dim = 256, ef = 0, pretraining = False, use_pretrain_weights = False):
    
    # base
    inp  = tf.keras.layers.Input(shape = (dim,dim,3))
    base = EFNS[ef](input_shape = (dim,dim,3), weights = 'imagenet', include_top = False)
    
    # base weights
    if use_pretrain_weights:
        base.load_weights('base_weights.h5')
    
    x = base(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    
    if pretraining:
        x     = tf.keras.layers.Dense(6, activation = 'softmax')(x)
        model = tf.keras.Model(inputs = inp, outputs = x)
        opt   = tf.keras.optimizers.Adam(learning_rate = 0.001)
        loss  = tf.keras.losses.CategoricalCrossentropy()    
        model.compile(optimizer = opt, loss = loss)
    else:
        x     = tf.keras.layers.Dense(1, activation = 'sigmoid')(x)
        model = tf.keras.Model(inputs = inp, outputs = x)
        opt   = tf.keras.optimizers.Adam(learning_rate = 0.001)
        loss  = tf.keras.losses.BinaryCrossentropy(label_smoothing = 0.01)  
        model.compile(optimizer = opt, loss = loss, metrics = ['AUC'])
    
    return model

#collapse-hide

### LEARNING RATE SCHEDULE

def get_lr_callback(batch_size=8):
    
    lr_start   = 0.000005
    lr_max     = 0.00000125 * REPLICAS * batch_size
    lr_min     = 0.000001
    lr_ramp_ep = 5
    lr_sus_ep  = 0
    lr_decay   = 0.8
   
    def lrfn(epoch):
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
            
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
            
        else:
            lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
            
        return lr

    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)
    return lr_callback
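The schedule warms up linearly from lr_start to lr_max over lr_ramp_ep epochs and then decays exponentially toward lr_min, with lr_max scaled by the global batch size (REPLICAS * batch_size). The pure-Python copy below, assuming REPLICAS = 8 and a batch size of 32 as in this run, reproduces the lr values printed in the pre-training log further down. Keras counts epochs from 0, so "Epoch 1/10" uses lrfn(0).

```python
def make_lrfn(batch_size: int = 8, replicas: int = 8):
    lr_start = 0.000005
    lr_max = 0.00000125 * replicas * batch_size  # 3.2e-4 for 8 replicas x 32
    lr_min = 0.000001
    lr_ramp_ep = 5
    lr_sus_ep = 0
    lr_decay = 0.8

    def lrfn(epoch: int) -> float:
        if epoch < lr_ramp_ep:
            # linear warm-up from lr_start toward lr_max
            return (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
        if epoch < lr_ramp_ep + lr_sus_ep:
            # plateau at lr_max (empty here, since lr_sus_ep = 0)
            return lr_max
        # exponential decay toward lr_min
        return (lr_max - lr_min) * lr_decay ** (epoch - lr_ramp_ep - lr_sus_ep) + lr_min

    return lrfn


lrfn = make_lrfn(batch_size=32, replicas=8)
for epoch in range(10):
    print(f"Epoch {epoch + 1}: lr = {lrfn(epoch):.4e}")
```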

The supplementary model is pre-trained on both training and test data: we merge the original training data with the complete test set into one training sample. We fix the number of training epochs to EPOCHS and do not perform early stopping. You can also experiment with setting up a small validation sample from both training and test data to perform early stopping.

#collapse-show

### PRE-TRAINED MODEL
if USE_PRETRAIN_WEIGHTS:

    # USE VERBOSE=0 for silent, VERBOSE=1 for interactive, VERBOSE=2 for commit
    VERBOSE = 2

    # DISPLAY INFO
    if DEVICE == 'TPU':
        if tpu: tf.tpu.experimental.initialize_tpu_system(tpu)

    # COMBINE TRAIN AND TEST FILES INTO ONE PRE-TRAINING SAMPLE
    files_train = tf.io.gfile.glob(GCS_PATH[0] + '/train*.tfrec')
    print('#### Using 2020 train data')
    files_train += tf.io.gfile.glob(GCS_PATH[0] + '/test*.tfrec')
    print('#### Using 2020 test data')
    np.random.shuffle(files_train)

    # BUILD MODEL
    K.clear_session()
    tf.random.set_seed(SEED)
    with strategy.scope():
        model = build_model(dim         = IMG_SIZES[0],
                            ef          = EFF_NETS[0], 
                            pretraining = True)

    # SAVE WEIGHTS WITH THE LOWEST TRAINING LOSS
    sv = tf.keras.callbacks.ModelCheckpoint(
        'weights.h5', monitor='loss', verbose=0, save_best_only=True,
        save_weights_only=True, mode='min', save_freq='epoch')

    # TRAIN
    print('Training...')
    history = model.fit(
        get_dataset(files_train, 
                    dim         = IMG_SIZES[0], 
                    batch_size  = BATCH_SIZES[0],
                    shuffle     = True, 
                    repeat      = True, 
                    pretraining = True), 
        epochs          = EPOCHS[0], 
        callbacks       = [sv, get_lr_callback(BATCH_SIZES[0])], 
        steps_per_epoch = count_data_items(files_train)/BATCH_SIZES[0]//REPLICAS,
        verbose = VERBOSE)
    
else:
    
    print('#### NOT using a pre-trained model')
#### Using 2020 train data
#### Using 2020 test data
Downloading data from https://github.com/Callidior/keras-applications/releases/download/efficientnet/efficientnet-b0_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5
16809984/16804768 [==============================] - 0s 0us/step
Training...
Epoch 1/10
170/170 - 10s - loss: 1.7556 - lr: 5.0000e-06
Epoch 2/10
170/170 - 10s - loss: 1.1257 - lr: 6.8000e-05
Epoch 3/10
170/170 - 11s - loss: 0.8906 - lr: 1.3100e-04
Epoch 4/10
170/170 - 10s - loss: 0.8118 - lr: 1.9400e-04
Epoch 5/10
170/170 - 9s - loss: 0.8222 - lr: 2.5700e-04
Epoch 6/10
170/170 - 9s - loss: 0.8626 - lr: 3.2000e-04
Epoch 7/10
170/170 - 9s - loss: 0.8402 - lr: 2.5620e-04
Epoch 8/10
170/170 - 9s - loss: 0.8257 - lr: 2.0516e-04
Epoch 9/10
170/170 - 10s - loss: 0.8091 - lr: 1.6433e-04
Epoch 10/10
170/170 - 10s - loss: 0.7865 - lr: 1.3166e-04
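The checkpoint callback used here (monitor='loss', mode='min', save_best_only=True) overwrites weights.h5 only when the epoch's training loss improves on the best value so far. A simplified pure-Python version of that comparison, fed with the losses from the log above, shows which epochs would have been written:

```python
from typing import List


def checkpoint_epochs(losses: List[float]) -> List[int]:
    """1-based epochs at which a save_best_only, mode='min' checkpoint is written."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:  # strict improvement over the best loss so far
            best = loss
            saved.append(epoch)
    return saved


# Training losses copied from the pre-training log above.
losses = [1.7556, 1.1257, 0.8906, 0.8118, 0.8222,
          0.8626, 0.8402, 0.8257, 0.8091, 0.7865]
print(checkpoint_epochs(losses))  # [1, 2, 3, 4, 9, 10]
```

The loss stayed above its epoch-4 value during epochs 5 to 8, so nothing was written then; the weights reloaded in the next step come from epoch 10.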

The pre-training is complete! Now, we re-save the weights of our pre-trained model to make them easier to load later. We are not really interested in the classification head, so we only export the weights of the convolutional part of the network. This base is the second layer of the model and can be accessed as model.layers[1].

#collapse-show

# LOAD WEIGHTS AND CHECK MODEL
if USE_PRETRAIN_WEIGHTS:
    model.load_weights('weights.h5')
    model.summary()
    
# EXPORT BASE WEIGHTS
if USE_PRETRAIN_WEIGHTS:
    model.layers[1].save_weights('base_weights.h5')
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
efficientnet-b0 (Model)      (None, 4, 4, 1280)        4049564   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 6)                 7686      
=================================================================
Total params: 4,057,250
Trainable params: 4,015,234
Non-trainable params: 42,016
_________________________________________________________________
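The summary also shows why only the base is worth exporting. The sketch below redoes the parameter arithmetic: a Dense layer has n_in * n_out weights plus n_out biases, so the six-class head holds only 7,686 of the 4,057,250 parameters, and the one-unit head of the main model is a new layer trained from scratch. The 1,280 features and the base parameter count are read off the summary above; the final-model total is derived here, not taken from a printed summary.

```python
def dense_params(n_in: int, n_out: int) -> int:
    # one weight per input-output pair plus one bias per output unit
    return n_in * n_out + n_out


FEATURES = 1280          # channels after GlobalAveragePooling2D on EfficientNet-B0
BASE_PARAMS = 4_049_564  # "efficientnet-b0 (Model)" row of the summary

pretrain_head = dense_params(FEATURES, 6)  # discarded when exporting base_weights.h5
final_head = dense_params(FEATURES, 1)     # freshly initialized in the main model

print(pretrain_head)                # 7686
print(BASE_PARAMS + pretrain_head)  # 4057250, the "Total params" line above
print(final_head)                   # 1281
print(BASE_PARAMS + final_head)     # 4050845
```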

Main classification model

Now we can train a final classification model using a cross-validation framework on the training data!

We need to make a couple of changes:

  1. Make sure that we don't use test data in the training folds.
  2. Set use_pretrain_weights = True and pretraining = False in the build_model() function to initialize from the pre-trained weights at the beginning of each fold.
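The cross-validation below splits files, not images: KFold runs over the 15 training-file indices, and each fold globs only the matching train files, so test files can never enter a training or validation fold. The stdlib sketch below mimics that structure. It is a stand-in for scikit-learn's KFold(shuffle=True, random_state=SEED) and will not reproduce its exact shuffled order; gs://my_bucket is a hypothetical path.

```python
import random
from typing import List, Tuple


def file_folds(n_files: int, n_folds: int, seed: int) -> List[Tuple[List[int], List[int]]]:
    """(train, valid) file-index pairs, one per fold, like a shuffled KFold."""
    idx = list(range(n_files))
    random.Random(seed).shuffle(idx)
    # The first n_files % n_folds folds get one extra file, as in KFold.
    sizes = [n_files // n_folds + (1 if i < n_files % n_folds else 0)
             for i in range(n_folds)]
    splits, start = [], 0
    for size in sizes:
        valid = sorted(idx[start:start + size])
        train = sorted(i for i in idx if i not in valid)
        splits.append((train, valid))
        start += size
    return splits


def glob_patterns(gcs_path: str, indices: List[int]) -> List[str]:
    # Same '%.2i' formatting as the notebook: index 3 -> 'train03*.tfrec'.
    return [gcs_path + '/train%.2i*.tfrec' % i for i in indices]


for fold, (train_idx, valid_idx) in enumerate(file_folds(15, 5, seed=42)):
    print(fold + 1, glob_patterns("gs://my_bucket", valid_idx))
```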

#collapse-show

# USE VERBOSE=0 for silent, VERBOSE=1 for interactive, VERBOSE=2 for commit
VERBOSE = 0

skf = KFold(n_splits = FOLDS, shuffle = True, random_state = SEED)
oof_pred = []; oof_tar = []; oof_val = []; oof_names = []; oof_folds = []
preds = np.zeros((count_data_items(files_test),1))

for fold,(idxT,idxV) in enumerate(skf.split(np.arange(15))):
    
    # DISPLAY FOLD INFO
    if DEVICE == 'TPU':
        if tpu: tf.tpu.experimental.initialize_tpu_system(tpu)
    print('#'*25); print('#### FOLD',fold+1)
    
    # CREATE TRAIN AND VALIDATION SUBSETS
    files_train = tf.io.gfile.glob([GCS_PATH[fold] + '/train%.2i*.tfrec'%x for x in idxT])      
    np.random.shuffle(files_train); print('#'*25)
    
    files_valid = tf.io.gfile.glob([GCS_PATH[fold] + '/train%.2i*.tfrec'%x for x in idxV])
    files_test = np.sort(np.array(tf.io.gfile.glob(GCS_PATH[fold] + '/test*.tfrec')))
    
    # BUILD MODEL
    K.clear_session()
    tf.random.set_seed(SEED)
    with strategy.scope():
        model = build_model(dim                  = IMG_SIZES[fold],
                            ef                   = EFF_NETS[fold],
                            use_pretrain_weights = USE_PRETRAIN_WEIGHTS, 
                            pretraining          = False)
        
    # SAVE BEST MODEL EACH FOLD
    sv = tf.keras.callbacks.ModelCheckpoint(
        'fold-%i.h5'%fold, monitor='val_auc', verbose=0, save_best_only=True,
        save_weights_only=True, mode='max', save_freq='epoch')
   
    # TRAIN
    print('Training...')
    history = model.fit(
        get_dataset(files_train, 
                    shuffle    = True, 
                    repeat     = True, 
                    dim        = IMG_SIZES[fold], 
                    batch_size = BATCH_SIZES[fold]), 
        epochs = EPOCHS[fold], 
        callbacks = [sv,get_lr_callback(BATCH_SIZES[fold])], 
        steps_per_epoch = count_data_items(files_train)/BATCH_SIZES[fold]//REPLICAS,
        validation_data = get_dataset(files_valid,
                                      shuffle = False,
                                      repeat  = False, 
                                      dim     = IMG_SIZES[fold]),
        verbose = VERBOSE
    )
    model.load_weights('fold-%i.h5'%fold)
    
    # PREDICT OOF
    print('Predicting OOF...')
    ds_valid = get_dataset(files_valid,labeled=False,return_image_names=False,shuffle=False,dim=IMG_SIZES[fold],batch_size=BATCH_SIZES[fold]*4)
    ct_valid = count_data_items(files_valid); STEPS = ct_valid/BATCH_SIZES[fold]/4/REPLICAS
    pred     = model.predict(ds_valid,steps=STEPS,verbose=VERBOSE)[:ct_valid,] 
    oof_pred.append(pred)      

    # GET OOF TARGETS AND NAMES
    ds_valid = get_dataset(files_valid,dim=IMG_SIZES[fold],labeled=True, return_image_names=True)
    oof_tar.append(np.array([target.numpy() for img, target in iter(ds_valid.unbatch())]) )
    oof_folds.append(np.ones_like(oof_tar[-1],dtype='int8')*fold )
    ds = get_dataset(files_valid,dim=IMG_SIZES[fold],labeled=False,return_image_names=True)
    oof_names.append(np.array([img_name.numpy().decode("utf-8") for img, img_name in iter(ds.unbatch())]))
    
    # PREDICT TEST
    print('Predicting Test...')
    ds_test     = get_dataset(files_test,labeled=False,return_image_names=False,shuffle=False,dim=IMG_SIZES[fold],batch_size=BATCH_SIZES[fold]*4)
    ct_test     = count_data_items(files_test); STEPS = ct_test/BATCH_SIZES[fold]/4/REPLICAS
    pred        = model.predict(ds_test,steps=STEPS,verbose=VERBOSE)[:ct_test,]
    preds[:,0] += (pred * WGTS[fold]).reshape(-1)
#########################
#### FOLD 1
#########################
Training...
Predicting OOF...
Predicting Test...
#########################
#### FOLD 2
#########################
Training...
Predicting OOF...
Predicting Test...
#########################
#### FOLD 3
#########################
Training...
Predicting OOF...
Predicting Test...
#########################
#### FOLD 4
#########################
Training...
Predicting OOF...
Predicting Test...
#########################
#### FOLD 5
#########################
Training...
Predicting OOF...
Predicting Test...

#collapse-show

# COMPUTE OOF AUC
oof      = np.concatenate(oof_pred);  true  = np.concatenate(oof_tar);
names    = np.concatenate(oof_names); folds = np.concatenate(oof_folds)
auc      = roc_auc_score(true,oof)
print('Overall OOF AUC = %.4f'%auc)
Overall OOF AUC = 0.8414
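For intuition on the metric: ROC AUC equals the probability that a randomly chosen malignant lesion receives a higher score than a randomly chosen benign one, with ties counted as one half. The quadratic-time sketch below computes it directly from that definition; roc_auc_score gives the same value far more efficiently. The toy labels and scores are made up.

```python
from typing import Sequence


def auc_pairwise(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Share of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(auc_pairwise([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```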

How does the OOF AUC compare to a model without the pre-training stage? To check this, we can simply set USE_PRETRAIN_WEIGHTS = False at the beginning of the notebook. This is done in this version of the Kaggle notebook, yielding a model with a lower OOF AUC (0.8329 compared to 0.8414 with pre-training).

Compared to a model initialized from the ImageNet weights, pre-training on a surrogate label brings a CV improvement. The AUC gain also translates into a performance gain on the competition leaderboard (an increase from 0.8582 to 0.8809). Great news!

6. Closing words

This is the end of this blogpost. Using a computer vision application, we demonstrated how to use meta-data to construct a surrogate label and pre-train a CNN on both training and test data to improve performance.

The pre-trained model can be further optimized to increase performance gains. Using a validation subset at the pre-training stage can help to tune the number of epochs and other learning parameters. Another idea could be to construct a surrogate label with more unique values (e.g., a combination of anatom_site_general_challenge and sex) to make the pre-training task more challenging and encourage the model to learn richer features. On the other hand, further optimizing the main classification model may reduce the benefit of pre-training.
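As an illustration of the combined-label idea, here is a pure-Python sketch that maps each (site, sex) pair to a single class index. It assumes the site is coded 0 to 5 and sex is coded 0/1 with no missing values, which would need checking against the real meta-data; the pre-training head in build_model() would then need 12 output nodes instead of 6.

```python
from typing import Tuple

N_SITES = 6  # anatom_site_general_challenge codes, assumed 0..5
N_SEX = 2    # assumption: sex coded 0/1; missing values would need an extra category


def combined_label(site: int, sex: int) -> int:
    # Each (site, sex) pair maps to a unique index in [0, N_SITES * N_SEX).
    return site * N_SEX + sex


def split_label(label: int) -> Tuple[int, int]:
    # Inverse mapping, handy for checking which combination a class stands for.
    return divmod(label, N_SEX)


n_classes = N_SITES * N_SEX
print(n_classes)              # 12
print(combined_label(5, 1))   # 11
print(split_label(7))         # (3, 1)
```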