Extracting Intermediate Layer Outputs in PyTorch
Simple way to extract activations from deep networks with hooks
- 1. Overview
- 2. Why do we need intermediate features?
- 3. How to extract activations?
- 4. Closing words
Last update: 23.10.2021
1. Overview
In deep learning tasks, we usually work with the predictions produced by the final layer of a neural network. In some cases, we might also be interested in the outputs of intermediate layers. Whether we want to extract data embeddings or inspect what earlier layers have learned, it may not be obvious how to extract the intermediate features from the network.
This blog post provides a quick tutorial on extracting intermediate activations from any layer of a deep learning model in PyTorch using the forward hook functionality. The main advantages of this method are its simplicity and its ability to extract features without running inference twice: a single forward pass through the model is enough to save multiple outputs.
2. Why do we need intermediate features?
Extracting intermediate activations (also called features) can be useful in many applications. In computer vision problems, outputs of intermediate CNN layers are frequently used to visualize the learning process and illustrate the visual features recognized by the model at different layers. Another popular use case is extracting intermediate outputs to create image or text embeddings, which can be used to detect duplicate items, serve as input features for a classical ML model, visualize data clusters and much more. When working with Encoder-Decoder architectures, outputs of intermediate layers can also be used to compress the data into a smaller vector containing the data representation. There are many further use cases in which intermediate activations can be useful. So, let's discuss how to get them!
3. How to extract activations?
To extract activations from intermediate layers, we will need to register a so-called forward hook for the layers of interest in our neural network and perform inference to store the relevant outputs.
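Before turning to the real data, here is a minimal sketch of the whole mechanism on a toy model. The two-layer `nn.Sequential` network, its sizes, the `save_hidden` function and the `captured` dictionary are illustrative assumptions, not part of the Cassava setup used below. The point is that the hook fills `captured` as a side effect of the ordinary forward pass.

```python
import torch
import torch.nn as nn

# Hypothetical toy network: index 0 is a hidden Linear layer, index 2 the output layer
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = {}  # will hold the intermediate output

def save_hidden(module, inputs, output):
    # Runs automatically right after toy_model[0] computes its output
    captured['hidden'] = output.detach()

toy_model[0].register_forward_hook(save_hidden)

x = torch.randn(5, 4)            # batch of 5 samples with 4 features each
final = toy_model(x)             # one forward pass also fills `captured`

print(final.shape)               # torch.Size([5, 2])
print(captured['hidden'].shape)  # torch.Size([5, 8])
```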
For the purpose of this tutorial, I will use image data from the Cassava Leaf Disease Classification Kaggle competition. In the next few cells, we will import the relevant libraries and set up a DataLoader object. Feel free to skip them if you are familiar with standard PyTorch data loading practices and go directly to the feature extraction part.
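##### PACKAGES
!pip install timm  # notebook shell command: install timm before importing it
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import os
# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')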
##### DATASET
class ImageData(Dataset):
    # init
    def __init__(self,
                 data,
                 directory,
                 transform):
        self.data = data
        self.directory = directory
        self.transform = transform

    # length
    def __len__(self):
        return len(self.data)

    # get item
    def __getitem__(self, idx):
        # read the image from disk and convert OpenCV's BGR channel order to RGB
        image = cv2.imread(os.path.join(self.directory, self.data.iloc[idx]['image_id']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # augmentations
        image = self.transform(image = image)['image']
        return image
We will use a standard PyTorch DataLoader to load the data in batches of 32 images.
##### DATA LOADER
# import data
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
display(df.head())
# augmentations
transforms = A.Compose([A.Resize(height = 128, width = 128),
                        A.Normalize(),
                        ToTensorV2()])
# dataset
data_set = ImageData(data = df,
                     directory = '../input/cassava-leaf-disease-classification/train_images/',
                     transform = transforms)
# dataloader
data_loader = DataLoader(data_set,
                         batch_size = 32,
                         shuffle = False,
                         num_workers = 2)
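If you do not have the Kaggle data at hand, a stand-in dataset that returns random tensors in the same format lets us try the rest of the pipeline. `RandomImageData` and `fake_loader` are hypothetical names, not part of the original setup; the only assumption is that each item looks like what `ImageData` returns after the 128x128 resize and `ToTensorV2`, i.e. a float tensor of shape (channels, height, width).

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomImageData(Dataset):
    """Hypothetical stand-in for ImageData that returns random 3x128x128 tensors."""

    def __init__(self, n_images, height=128, width=128):
        self.n_images = n_images
        self.height = height
        self.width = width

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        # Same format as ImageData after ToTensorV2: float tensor of shape (C, H, W)
        return torch.randn(3, self.height, self.width)

fake_loader = DataLoader(RandomImageData(n_images=100), batch_size=32, shuffle=False)

batch = next(iter(fake_loader))
print(batch.shape)       # torch.Size([32, 3, 128, 128])
print(len(fake_loader))  # 4 batches: 32 + 32 + 32 + 4 images
```

Swapping `data_loader` for `fake_loader` in the extraction loop is enough to check that shapes line up before working with real images.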
Model
To extract anything from a neural net, we first need to set up this net, right? In the cell below, we define a simple resnet18 model with a two-node output layer. We use the timm library to instantiate the model, but feature extraction will also work with any neural network written in PyTorch.
We also print out the architecture of our network. As you can see, there are many intermediate layers through which our image travels during a forward pass before turning into a two-number output. We should note the names of the layers because we will need them to register hooks on the layers of interest.
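##### DEFINE MODEL
model = timm.create_model(model_name = 'resnet18', pretrained = True)
# replace the classification head with a two-node output layer
model.fc = nn.Linear(512, 2)
model.to(device)
# print the architecture to see the layer names
print(model)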
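Instead of reading layer names off the printed architecture, we can also list them programmatically with `named_modules()`, which yields `(name, module)` pairs using dotted names (for the timm model, these include names like `layer1.1.act2` and `global_pool`). The sketch below uses a small hypothetical CNN, `TinyCNN`, so that it runs without downloading weights. It also uses `get_submodule()`, which is available in recent PyTorch versions, to turn a dotted name back into the module we want to hook.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN standing in for resnet18
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.features(x)
        x = self.global_pool(x).flatten(1)
        return self.fc(x)

tiny = TinyCNN()

# Print every named submodule; the empty name '' is the model itself, so skip it
for name, module in tiny.named_modules():
    if name:
        print(name, '->', type(module).__name__)

# A dotted name can be turned back into the module object we want to hook
relu_layer = tiny.get_submodule('features.1')
out = tiny(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 2])
```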
Feature extraction
The implementation of feature extraction requires two simple steps:
- Registering a forward hook on a certain layer of the network.
- Performing standard inference to extract features of that layer.
First, we need to define a helper function that will introduce a so-called hook. A hook is simply a function that is executed automatically when a forward or backward call to a certain layer is performed. If you want to know more about hooks, you can check out this link.
In our setup, we are interested in a forward hook that simply detaches the layer outputs from the computational graph and saves them to a dictionary object we call features. Moving the outputs to the CPU happens later, in the inference loop.
The hook is defined in the cell below. The name argument in get_features() specifies the dictionary key under which we will store our intermediate activations.
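##### HELPER FUNCTION FOR FEATURE EXTRACTION
def get_features(name):
    # return a hook that stores the layer output under features[name]
    def hook(model, input, output):
        # detach so the stored tensor is not part of the autograd graph
        features[name] = output.detach()
    return hook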
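To see what PyTorch actually passes to a forward hook, the sketch below registers an inspection hook on a single hypothetical `nn.Linear` layer and records what it receives. The `input` argument arrives as a tuple of the layer's positional inputs, while `output` is the tensor the layer returned. It also shows why `get_features()` calls `.detach()`: outside `torch.no_grad()`, the raw output is still attached to the autograd graph.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)  # hypothetical layer to inspect
seen = {}

def inspect_hook(module, inputs, output):
    seen['module_is_layer'] = module is layer            # the hooked module itself
    seen['inputs_type'] = type(inputs).__name__          # positional inputs come as a tuple
    seen['n_inputs'] = len(inputs)
    seen['output_requires_grad'] = output.requires_grad  # still attached to the graph
    seen['detached_requires_grad'] = output.detach().requires_grad

layer.register_forward_hook(inspect_hook)
layer(torch.randn(4, 3))

for key, value in seen.items():
    print(key, value)
```

Inside `torch.no_grad()`, the output would not require gradients in the first place, so `.detach()` then acts as a harmless safeguard.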
After the helper function is defined, we can register a hook using the .register_forward_hook() method. The hook can be applied to any layer of the neural network.
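`register_forward_hook()` also returns a handle. Keeping it lets us remove the hook once extraction is done, so later forward passes (for example, during further training) no longer write to `features`. A minimal sketch on a hypothetical single-layer model, reusing a copy of `get_features()` so the block runs on its own:

```python
import torch
import torch.nn as nn

features = {}

def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook

layer = nn.Linear(4, 3)  # hypothetical layer
handle = layer.register_forward_hook(get_features('feats'))

layer(torch.randn(2, 4))
fired_before = 'feats' in features  # the hook ran during this forward pass

features.clear()
handle.remove()                     # detach the hook from the layer

layer(torch.randn(2, 4))
fired_after = 'feats' in features   # nothing was captured this time

print(fired_before, fired_after)    # True False
```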
Since we work with a CNN, extracting features from the end of the convolutional part of the network might be useful to get image embeddings. Therefore, we register a hook for the outputs of the (global_pool) layer, which pools the output of the last convolutional block into a single vector per image. To extract features from an earlier layer, we could access it with, e.g., model.layer1[1].act2 and save its output under a different name in the features dictionary. With this method, we can actually register multiple hooks (one for every layer of interest), but we will only keep one for the purpose of this example.
##### REGISTER HOOK
model.global_pool.register_forward_hook(get_features('feats'))
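As a sketch of registering several hooks at once, the snippet below reuses a copy of `get_features()` on a hypothetical toy CNN (`toy_cnn`, with made-up layer sizes). Hooking the activation after a convolution yields 4D feature maps of shape (batch, channels, height, width), whereas hooking the global pooling block yields one flat vector per image, which is why the pooled output is the more convenient embedding.

```python
import torch
import torch.nn as nn

features = {}

def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook

# Hypothetical toy CNN: conv + activation, global average pooling, 2-class head
toy_cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),            # index 0
    nn.ReLU(),                                             # index 1
    nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten()),  # index 2: global pool
    nn.Linear(16, 2),                                      # index 3
)

# One hook per layer of interest, each writing to its own dictionary key
toy_cnn[1].register_forward_hook(get_features('conv_maps'))
toy_cnn[2].register_forward_hook(get_features('pooled'))

preds = toy_cnn(torch.randn(8, 3, 32, 32))  # a single forward pass fills both keys

print(preds.shape)                  # torch.Size([8, 2])
print(features['conv_maps'].shape)  # torch.Size([8, 16, 32, 32])
print(features['pooled'].shape)     # torch.Size([8, 16])
```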
Now we are ready to extract features! The nice thing about hooks is that we can now perform inference as we usually would and get multiple outputs at the same time:
- outputs of the final layer
- outputs of every layer with a registered hook
The feature extraction happens automatically during the forward pass whenever we run model(inputs). To store intermediate features and concatenate them over batches, we just need to include the following in our inference loop:
- Create a placeholder list FEATS = []. This list will store intermediate outputs from all batches.
- Create a placeholder dict features = {}. We will use this dictionary for storing intermediate outputs from each batch.
- Iteratively extract batch features to features, send them to CPU and append them to the list FEATS.
##### FEATURE EXTRACTION LOOP
# placeholders
PREDS = []
FEATS = []
# placeholder for batch features
features = {}
# inference mode: batch norm uses its running statistics instead of batch statistics
model.eval()
# loop through batches without tracking gradients
with torch.no_grad():
    for idx, inputs in enumerate(data_loader):
        # move to device
        inputs = inputs.to(device)
        # forward pass [with feature extraction]
        preds = model(inputs)
        # add feats and preds to lists
        PREDS.append(preds.detach().cpu().numpy())
        FEATS.append(features['feats'].cpu().numpy())
        # early stop after 10 batches
        if idx == 9:
            break
This is it! Looking at the shapes of the resulting arrays, we can see that the code worked: we extracted both the final layer outputs as PREDS and the intermediate activations as FEATS. We can now save these features and work with them further.
##### INSPECT FEATURES
PREDS = np.concatenate(PREDS)
FEATS = np.concatenate(FEATS)
print('- preds shape:', PREDS.shape)
print('- feats shape:', FEATS.shape)
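As one example of working with the features further, the embeddings can be saved to disk and compared with cosine similarity to find the most similar image for each one, which is the duplicate-detection use case mentioned earlier. The sketch assumes FEATS is a 2D NumPy array of shape (n_images, n_features); to stay self-contained, it uses a small hypothetical array instead of the real features, and the file name `feats.npy` is arbitrary.

```python
import numpy as np

# Hypothetical stand-in for FEATS: 4 embeddings of size 3
# (rows 0 and 2 point in the same direction; row 3 is closest to row 1)
FEATS = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [2.0, 0.0, 0.0],
    [0.0, 1.0, 1.0],
])

# Save the features and load them back later
np.save('feats.npy', FEATS)
FEATS = np.load('feats.npy')

# Cosine similarity: scale rows to unit length, then take pairwise dot products
norms = np.linalg.norm(FEATS, axis=1, keepdims=True)
unit = FEATS / np.clip(norms, 1e-12, None)
similarity = unit @ unit.T

# Ignore each image's similarity with itself, then pick the closest other image
np.fill_diagonal(similarity, -np.inf)
nearest = similarity.argmax(axis=1)
print(nearest)  # [2 3 0 1]
```

Pairs with a similarity close to 1 are candidates for duplicates; with thousands of images, an approximate nearest-neighbor library may be preferable to the full similarity matrix.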
4. Closing words
The purpose of this tutorial was to teach you how to extract intermediate outputs from the most interesting layers of your neural networks. With hooks, you can do all feature extraction in a single inference run and avoid complex modifications of your model. I hope you found this post helpful.
If you are interested, check out my other blog posts to see more tips on deep learning and PyTorch. Happy learning!