A simple, easy-to-use module to get the intermediate results from chosen submodules. Supports nested submodules. Inspired by this, but does not assume that submodules are executed sequentially.
Installation
pip install torch-intermediate-layer-getter
Usage
Example
import torch
import torch.nn as nn
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter
class Model(nn.Module):
    """Toy network used to demonstrate capturing intermediate outputs.

    Two linear layers are applied to the same input, their outputs are
    multiplied elementwise, and the product is fed through a nested
    ``nn.Sequential`` stack.
    """

    def __init__(self):
        """Create the two parallel layers, the nested stack and the tap."""
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.nested = nn.Sequential(
            nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3)),
            nn.Linear(3, 1),
        )
        # Simple trick for operations not performed as modules: wrapping the
        # value in an Identity submodule gives it a name that can be selected
        # just like any other submodule.
        self.interaction_idty = nn.Identity()

    def forward(self, x):
        """Return the network output for ``x``.

        Args:
            x: Input tensor whose last dimension is 2 (the ``in_features``
                of ``fc1`` and ``fc2``).

        Returns:
            Tensor whose last dimension is 1 (the ``out_features`` of the
            final linear layer in ``nested``).
        """
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        # The elementwise product is not produced by a module, so it is
        # passed through the Identity tap. Identity returns its input
        # unchanged; using the return value keeps the tap in the data flow.
        interaction = self.interaction_idty(x1 * x2)
        x_out = self.nested(interaction)
        return x_out
# Instantiate the example network whose intermediate outputs will be captured.
model = Model()

# Each key is the dotted path of a submodule inside `model`; each value is the
# name under which that submodule's output appears in the returned mapping.
return_layers = {
    "fc2": "fc2",
    "nested.0.1": "nested",
    "interaction_idty": "interaction",
}

mid_getter = MidGetter(
    model,
    return_layers=return_layers,
    keep_output=True,
)

# A single random sample with two features, matching the input size of `model`.
sample_input = torch.randn(1, 2)
mid_outputs, model_output = mid_getter(sample_input)
print(model_output)
>> tensor([[0.3219]], grad_fn=<AddmmBackward>)
print(mid_outputs)
>> OrderedDict([('fc2', tensor([[-1.5125, 0.9334]], grad_fn=<AddmmBackward>)),
('interaction', tensor([[-0.0687, -0.1462]], grad_fn=<MulBackward0>)),
('nested', tensor([[-0.1697, 0.1432, 0.2959]], grad_fn=<AddmmBackward>))])
# model_output is None if keep_output is False
# if keep_output is True the model_output contains the final model's output