Last update: 23.10.2021

1. Overview

In deep learning tasks, we usually work with predictions output by the final layer of a neural network. In some cases, we might also be interested in the outputs of intermediate layers. Whether we want to extract data embeddings or inspect what is learned by earlier layers, it may not be obvious how to extract the intermediate features from the network.

This blog post provides a quick tutorial on extracting intermediate activations from any layer of a deep learning model in PyTorch using the forward hook functionality. An important advantage of this method is its simplicity and its ability to extract features without running inference twice: a single forward pass through the model is enough to save multiple outputs.

2. Why do we need intermediate features?

Extracting intermediate activations (also called features) can be useful in many applications. In computer vision problems, outputs of intermediate CNN layers are frequently used to visualize the learning process and illustrate which visual features the model picks up at different layers. Another popular use case is extracting intermediate outputs to create image or text embeddings, which can be used to detect duplicate items, serve as input features for a classical ML model, visualize data clusters and much more. When working with Encoder-Decoder architectures, outputs of intermediate layers can also be used to compress the data into a smaller vector containing the data representation. There are many further use cases in which intermediate activations can be useful. So, let's discuss how to get them!

3. How to extract activations?

To extract activations from intermediate layers, we will need to register a so-called forward hook for the layers of interest in our neural network and perform inference to store the relevant outputs.

For the purpose of this tutorial, I will use image data from the Cassava Leaf Disease Classification Kaggle competition. In the next few cells, we will import relevant libraries and set up a DataLoader object. Feel free to skip them if you are familiar with standard PyTorch data loading practices and go directly to the feature extraction part.

Preparations

##### PACKAGES

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

!pip install timm
import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
import os

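# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')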

##### DATASET

class ImageData(Dataset):
    
    # init
    def __init__(self, 
                 data, 
                 directory, 
                 transform):
        self.data      = data
        self.directory = directory
        self.transform = transform
        
    # length
    def __len__(self):
        return len(self.data)
    
    # get item  
    def __getitem__(self, idx):
        
        # import
        image = cv2.imread(os.path.join(self.directory, self.data.iloc[idx]['image_id']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
        # augmentations
        image = self.transform(image = image)['image']
        
        return image

We will use a standard PyTorch DataLoader to load the data in batches of 32 images.

##### DATA LOADER

# import data
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
display(df.head())

# augmentations
transforms = A.Compose([A.Resize(height = 128, width = 128),
                        A.Normalize(),
                        ToTensorV2()])

# dataset
data_set = ImageData(data      = df, 
                     directory = '../input/cassava-leaf-disease-classification/train_images/',
                     transform = transforms)

# dataloader
data_loader = DataLoader(data_set, 
                         batch_size  = 32, 
                         shuffle     = False, 
                         num_workers = 2)
   image_id        label
0  1000015157.jpg  0
1  1000201771.jpg  3
2  100042118.jpg   1
3  1000723321.jpg  1
4  1000812911.jpg  3

Model

To extract anything from a neural net, we first need to set up this net, right? In the cell below, we define a simple resnet18 model with a two-node output layer. We use the timm library to instantiate the model, but feature extraction also works with any neural network written in PyTorch.

We also print out the architecture of our network. As you can see, there are many intermediate layers through which our image travels during a forward pass before turning into a two-number output. We should note the names of the layers, because we will use them to access the layers whose features we want to extract.

##### DEFINE MODEL

model    = timm.create_model(model_name = 'resnet18', pretrained = True)
model.fc = nn.Linear(512, 2)
model.to(device)

# inference mode: BatchNorm uses its running statistics instead of batch statistics
model.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )

  ...

  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=True)
  (fc): Linear(in_features=512, out_features=2, bias=True)
)
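Instead of reading layer names off the printout, we can also list them programmatically. The sketch below uses a small hypothetical stand-in model (`toy_model`, built from `nn.Sequential` so it runs without timm or downloaded weights). `named_modules()` yields the dotted name of every submodule, and `get_submodule()` (available in recent PyTorch versions) turns such a name back into the module object.

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical stand-in for the resnet18 above; nested containers give dotted names
toy_model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ('layer1', nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )),
    ('global_pool', nn.AdaptiveAvgPool2d(1)),
]))

# named_modules() walks the module tree; the root module itself has the empty name ''
layer_names = [name for name, _ in toy_model.named_modules() if name != '']
print(layer_names)

# get_submodule() resolves a dotted name back to the module object
act = toy_model.get_submodule('layer1.1')
print(act)
```

The same calls should work on the resnet18 defined above: for example, `model.get_submodule('layer1.1.act2')` refers to the same module as `model.layer1[1].act2`.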

Feature extraction

The implementation of feature extraction requires two simple steps:

  1. Registering a forward hook on a certain layer of the network.
  2. Performing standard inference to extract features of that layer.

First, we need to define a helper function that will introduce a so-called hook. A hook is simply a function that is called automatically whenever a forward or backward pass through a certain layer is performed. If you want to know more about hooks, you can check out this link.

In our setup, we are interested in a forward hook that simply detaches the layer outputs from the computation graph and saves them to a dictionary object we call features. Moving them to the CPU happens later, in the inference loop.

The hook is defined in the cell below. The name argument in get_features() specifies the dictionary key under which we will store our intermediate activations.

##### HELPER FUNCTION FOR FEATURE EXTRACTION

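def get_features(name):
    # a forward hook receives the layer, its inputs (as a tuple) and its output
    def hook(module, input, output):
        # store the output detached from the graph; `features` is the global
        # dictionary created right before the inference loop
        features[name] = output.detach()
    return hook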
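One caveat worth knowing: `output.detach()` returns a tensor that shares memory with the layer output. If a later operation modifies that output in place (for example a `ReLU(inplace=True)`, several of which appear in the architecture printed above), the stored features change as well. The toy example below (a hypothetical two-layer model with hand-picked weights, and a hypothetical `get_toy_features()` variant of the helper) illustrates this; storing `output.detach().clone()` keeps an independent copy at the cost of extra memory.

```python
import torch
import torch.nn as nn

toy_features = {}

def get_toy_features(name, clone=False):
    # Variant of get_features(); clone=True stores an independent copy
    def hook(module, input, output):
        out = output.detach()  # shares memory with `output`
        toy_features[name] = out.clone() if clone else out
    return hook

# Hypothetical toy model: a linear layer followed by an in-place ReLU
linear = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    linear.weight.copy_(torch.tensor([[1.0, 0.0], [0.0, -1.0]]))
toy_model = nn.Sequential(linear, nn.ReLU(inplace=True))

# Two hooks on the same layer: one keeps a shared view, one keeps a copy
linear.register_forward_hook(get_toy_features('shared'))
linear.register_forward_hook(get_toy_features('copied', clone=True))

with torch.no_grad():
    toy_model(torch.tensor([[1.0, 1.0]]))

# The linear layer produced [[1, -1]], but the in-place ReLU then zeroed the
# negative entry in the memory that 'shared' still points to
print(toy_features['shared'])
print(toy_features['copied'])
```

In the architecture printed above, global_pool is followed only by fc, so plain `detach()` appears to be sufficient for the hook used in this post; `clone()` matters mainly when hooking layers whose outputs are later modified in place.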

Once the helper function is defined, we can register a hook using the .register_forward_hook() method. The hook can be registered on any layer (any nn.Module) of the neural network.

Since we work with a CNN, extracting features from the last convolutional block might be useful to get image embeddings. Therefore, we register a hook on the global pooling layer (global_pool), which averages the feature maps produced by the last convolutional block and flattens them into a 512-dimensional vector. To extract features from an earlier layer, we could access it with, e.g., model.layer1[1].act2 and store its output under a different name in the features dictionary. With this method, we can actually register multiple hooks (one for every layer of interest), but we will only keep one for the purpose of this example.

##### REGISTER HOOK

model.global_pool.register_forward_hook(get_features('feats'))
<torch.utils.hooks.RemovableHandle at 0x7f2540254290>
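To illustrate registering multiple hooks, here is a minimal sketch with a small hypothetical CNN (`toy_cnn`, not the resnet18 used in this post). One forward pass fills one dictionary entry per hooked layer: the convolution output keeps its spatial dimensions (batch, channels, height, width), while the pooled and flattened output is already a (batch, channels) embedding, which is why a layer like global_pool is a convenient place to hook.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN standing in for resnet18
toy_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # [0] convolutional layer
    nn.ReLU(),                                  # [1]
    nn.AdaptiveAvgPool2d(1),                    # [2] global average pooling
    nn.Flatten(),                               # [3] flattened embedding, like global_pool
    nn.Linear(8, 2),                            # [4] two-node output layer
)

toy_features = {}

def get_toy_features(name):
    # Same pattern as get_features(), writing into toy_features instead
    def hook(module, input, output):
        toy_features[name] = output.detach()
    return hook

# One hook per layer of interest; keep the handles so the hooks can be removed
handles = [
    toy_cnn[0].register_forward_hook(get_toy_features('conv')),
    toy_cnn[3].register_forward_hook(get_toy_features('embedding')),
]

# A single forward pass fills every hooked entry
with torch.no_grad():
    preds = toy_cnn(torch.randn(4, 3, 16, 16))

for name, feats in toy_features.items():
    print(name, tuple(feats.shape))
print('preds', tuple(preds.shape))

# Clean up: otherwise the hooks run on every future forward pass
for handle in handles:
    handle.remove()
```

With a batch of four 16x16 images, this prints the shapes (4, 8, 16, 16) for conv, (4, 8) for embedding and (4, 2) for preds.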

Now we are ready to extract features! The nice thing about hooks is that we can now perform inference as we usually would and get multiple outputs at the same time:

  • outputs of the final layer
  • outputs of every layer with a registered hook

The feature extraction happens automatically during the forward pass whenever we run model(inputs). To store intermediate features and concatenate them over batches, we just need to include the following in our inference loop:

  1. Create a placeholder list FEATS = []. This list will store intermediate outputs from all batches.
  2. Create a placeholder dict features = {}. The hook writes the intermediate outputs of each batch into this dictionary.
  3. In each iteration, run the forward pass (which fills features with the current batch's activations), send the features to CPU and append them to the list FEATS.

##### FEATURE EXTRACTION LOOP

# placeholders
PREDS = []
FEATS = []

# placeholder for batch features
features = {}

# no gradients are needed for inference
with torch.no_grad():

    # loop through batches
    for idx, inputs in enumerate(data_loader):

        # move to device
        inputs = inputs.to(device)

        # forward pass [with feature extraction]
        preds = model(inputs)

        # add feats and preds to lists
        PREDS.append(preds.detach().cpu().numpy())
        FEATS.append(features['feats'].cpu().numpy())

        # early stop after 10 batches (320 images)
        if idx == 9:
            break

This is it! Looking at the shapes of the resulting arrays, we can see that the code worked: we extracted both the final layer outputs as PREDS and the intermediate activations as FEATS. We can now save these features and work with them further.

##### INSPECT FEATURES

PREDS = np.concatenate(PREDS)
FEATS = np.concatenate(FEATS)

print('- preds shape:', PREDS.shape)
print('- feats shape:', FEATS.shape)
- preds shape: (320, 2)
- feats shape: (320, 512)
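The steps of this section can be packaged into one function. Below is a minimal sketch of a hypothetical helper, `extract_features()`, which registers hooks by dotted layer name, runs inference under `model.eval()` and `torch.no_grad()`, concatenates the per-batch outputs, and always removes the hooks through the `RemovableHandle` objects returned by `.register_forward_hook()` (like the one printed when we registered our hook). It assumes the model is already on the target device and that the loader yields input tensors only, as the ImageData dataset above does.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def extract_features(model, loader, layer_names, device='cpu'):
    """Hypothetical helper: final-layer outputs plus features of the named layers.

    Assumes `model` is already on `device` and `loader` yields input tensors only.
    """
    batch_features = {}

    def get_features(name):
        def hook(module, input, output):
            batch_features[name] = output.detach()
        return hook

    # Register one hook per dotted layer name (e.g. 'global_pool' or 'layer1.1.act2')
    handles = [model.get_submodule(name).register_forward_hook(get_features(name))
               for name in layer_names]
    preds = []
    feats = {name: [] for name in layer_names}

    model.eval()
    try:
        with torch.no_grad():
            for inputs in loader:
                outputs = model(inputs.to(device))
                preds.append(outputs.cpu().numpy())
                for name in layer_names:
                    feats[name].append(batch_features[name].cpu().numpy())
    finally:
        # RemovableHandle.remove() unregisters each hook, even if inference failed
        for handle in handles:
            handle.remove()

    return np.concatenate(preds), {name: np.concatenate(f) for name, f in feats.items()}


# Usage with a hypothetical toy model and random inputs (10 images of size 3x8x8)
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
)
loader = DataLoader(torch.randn(10, 3, 8, 8), batch_size=4)
PREDS, FEATS = extract_features(toy_model, loader, ['3'])
print(PREDS.shape, FEATS['3'].shape)
```

Removing the hooks in a `finally` block is a design choice: it keeps the model free of stale hooks that would otherwise keep writing into a dictionary on every later forward pass.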

4. Closing words

The purpose of this tutorial was to teach you how to extract intermediate outputs from the most interesting layers of your neural networks. With hooks, you can do all feature extraction in a single inference run and avoid complex modifications of your model. I hope you found this post helpful.

If you are interested, check out my other blog posts to see more tips on deep learning and PyTorch. Happy learning!