pytorchsummary

Summary of PyTorch models, just like `model.summary()` in Keras.


Keywords
python, PyTorch, model, summary, parameter, cnn-model, deep-learning, deep-neural-networks, model-summary
License
MIT
Install
pip install pytorchsummary==1.3.0

Documentation

PyTorch Model Parameters Summary

Install using pip

pip install pytorchsummary

Example 1

from torch import nn
from pytorchsummary import parameter_summary

class CNNET(nn.Module):
    """Small example CNN used to demonstrate ``parameter_summary``.

    The inline comments trace the spatial size of a 3x28x28 input through
    the stack: 28 -> 24 -> 12 -> 10 -> 5 -> 1, so the network maps a
    (N, 3, 28, 28) batch to a (N, 10, 1, 1) output.
    """

    def __init__(self):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 5),    # 28 - 5 + 1 = 24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 24 / 2 = 12

            nn.Conv2d(16, 32, 3),   # 12 - 3 + 1 = 10
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 10 / 2 = 5

            nn.Conv2d(32, 64, 5),   # 5 - 5 + 1 = 1
            nn.ReLU(),

            nn.Conv2d(64, 10, 1),   # 1x1 conv: 64 -> 10 channels, size stays 1
        )

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        return self.layer(x)

m = CNNET()

# Freeze the first Conv2d (its weight and bias) *before* summarising,
# so the table below reports that layer as non-trainable.
for param in m.layer[0].parameters():
    param.requires_grad = False

# border=False prints the table without separator lines;
# with border=True a line is printed between every layer.
# parameter_summary returns the total number of parameters (int),
# which a notebook echoes as the last line of the output (57770 here).
parameter_summary(m, border=False)

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)         requires_grad         
____________________________________________________________________________________________________
 Conv2d-1                  [16, 3, 5, 5]    	1,216                (1200 + 16)          False False          
 ReLU-2                          -          	-                          -                               
 MaxPool2d-3                     -          	-                          -                               
 Conv2d-4                  [32, 16, 3, 3]   	4,640                (4608 + 32)           True True           
 ReLU-5                          -          	-                          -                               
 MaxPool2d-6                     -          	-                          -                               
 Conv2d-7                  [64, 32, 5, 5]   	51,264               (51200 + 64)           True True           
 ReLU-8                          -          	-                          -                               
 Conv2d-9                  [10, 64, 1, 1]   	650                 (640 + 10)           True True           
====================================================================================================

Total parameters 57,770
Total Non-Trainable parameters 1,216
Total Trainable parameters 56,554

57770

Example 2

from torchvision import models
from pytorchsummary import parameter_summary

# weights=None builds an untrained AlexNet (torchvision >= 0.13 API);
# on older torchvision releases use models.alexnet(pretrained=False).
m = models.alexnet(weights=None)
parameter_summary(m)
# this function returns the total number of
# parameters (int) in a model

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)         requires_grad         
____________________________________________________________________________________________________
____________________________________________________________________________________________________
 Conv2d-1                 [64, 3, 11, 11]   	23,296               (23232 + 64)           True True           
____________________________________________________________________________________________________
 ReLU-2                          -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-3                     -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-4                 [192, 64, 5, 5]   	307,392             (307200 + 192)           True True           
____________________________________________________________________________________________________
 ReLU-5                          -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-6                     -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-7                 [384, 192, 3, 3]  	663,936             (663552 + 384)           True True           
____________________________________________________________________________________________________
 ReLU-8                          -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-9                 [256, 384, 3, 3]  	884,992             (884736 + 256)           True True           
____________________________________________________________________________________________________
 ReLU-10                         -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-11                [256, 256, 3, 3]  	590,080             (589824 + 256)           True True           
____________________________________________________________________________________________________
 ReLU-12                         -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-13                    -          	-                          -                               
____________________________________________________________________________________________________
 AdaptiveAvgPool2d-14            -          	-                          -                               
____________________________________________________________________________________________________
 Dropout-15                      -          	-                          -                               
____________________________________________________________________________________________________
 Linear-16                  [4096, 9216]    	37,752,832          (37748736 + 4096)           True True           
____________________________________________________________________________________________________
 ReLU-17                         -          	-                          -                               
____________________________________________________________________________________________________
 Dropout-18                      -          	-                          -                               
____________________________________________________________________________________________________
 Linear-19                  [4096, 4096]    	16,781,312          (16777216 + 4096)           True True           
____________________________________________________________________________________________________
 ReLU-20                         -          	-                          -                               
____________________________________________________________________________________________________
 Linear-21                  [1000, 4096]    	4,097,000           (4096000 + 1000)           True True           
====================================================================================================

Total parameters 61,100,840
Total Non-Trainable parameters 0
Total Trainable parameters 61,100,840