torch-intermediate-layer-getter

Simple, easy-to-use module for getting the intermediate results of chosen submodules


Keywords
pytorch
License
GPL-3.0
Install
pip install torch-intermediate-layer-getter==0.1.post1

Documentation

A simple, easy-to-use module for getting the intermediate results of chosen submodules. Supports nested submodules. Inspired by torchvision's `IntermediateLayerGetter`, but does not assume that submodules are executed sequentially.

Installation

pip install torch-intermediate-layer-getter

Usage

Example

import torch
import torch.nn as nn

from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter

class Model(nn.Module):
    """Toy network used to demonstrate intermediate-output capture.

    Two parallel ``Linear(2, 2)`` layers are applied to the input and their
    results multiplied element-wise. The product is then fed through a nested
    ``nn.Sequential`` ending in ``Linear(3, 1)``. The input's last dimension
    must be 2; the output's last dimension is 1.
    """

    def __init__(self):
        super().__init__()
        # Submodules are built in the same order (fc1, fc2, the two inner
        # layers, the final layer) so a fixed RNG seed gives the same weights.
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        inner = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3))
        self.nested = nn.Sequential(inner, nn.Linear(3, 1))
        # The element-wise product in forward() is not produced by any module,
        # so it has no submodule name to select. Routing it through this
        # Identity gives it one: 'interaction_idty'.
        self.interaction_idty = nn.Identity()

    def forward(self, x):
        product = self.fc1(x) * self.fc2(x)
        # Invoked only so that hooks registered on 'interaction_idty' observe
        # `product`; the returned value is intentionally ignored.
        self.interaction_idty(product)
        return self.nested(product)
        
model = Model()
# Keys are dotted submodule paths inside `model` (e.g. 'nested.0.1' is
# model.nested[0][1], the Linear(2, 3) layer); values are the names under
# which the captured outputs appear in the returned dict.
return_layers = {
    'fc2': 'fc2',
    'nested.0.1': 'nested',
    'interaction_idty': 'interaction',
}
mid_getter = MidGetter(model, return_layers=return_layers, keep_output=True)
mid_outputs, model_output = mid_getter(torch.randn(1, 2))

# The weights and the input are random, so the numbers printed when you run
# this will differ from the sample output shown in the comments below.
print(model_output)
# tensor([[0.3219]], grad_fn=<AddmmBackward>)
print(mid_outputs)
# OrderedDict([('fc2', tensor([[-1.5125,  0.9334]], grad_fn=<AddmmBackward>)),
#   ('interaction', tensor([[-0.0687, -0.1462]], grad_fn=<MulBackward0>)),
#   ('nested', tensor([[-0.1697,  0.1432,  0.2959]], grad_fn=<AddmmBackward>))])
# In the sample output above, the entries follow the order in which the
# submodules ran (fc2, interaction, nested), not the order of return_layers.

# model_output is None if keep_output is False.
# If keep_output is True, model_output contains the model's final output.