torch-parameter-groups

Group PyTorch Parameters according to Rules


Keywords
PyTorch, Parameter, Groups
License
MIT
Install
pip install torch-parameter-groups==0.0.2

Documentation

torch-parameter-groups Build Status codecov PyPI version

Group PyTorch parameters into optimizer parameter groups according to configurable rules.

Installation

Requires Python 3.6+.

pip install torch-parameter-groups

Usage

import torch
import torch.nn as nn
import torch_basic_models
import torch_parameter_groups


# Build a model and an SGD optimizer whose parameter groups are derived from
# rules. NOTE(review): the top-level 'kwargs' presumably act as the default
# optimizer arguments, with each rule's 'kwargs' overriding them for the
# parameters that rule matches -- confirm against the library docs.
model = torch_basic_models.MobileNetV2.factory()
optimizer = torch_parameter_groups.optimizer_factory(
    model=model,
    config={
        'type': 'SGD',
        'kwargs': {
            'momentum': 0.9,
            'nesterov': True,
            'weight_decay': 0.0001,
        },
        'rules': [
            {
                # Presumably selects parameters by name ('weight') and
                # disables weight decay for them -- confirm.
                'param_name_list': ['weight'],
                'kwargs': {
                    'weight_decay': 0
                }
            },
            {
                # NOTE(review): empty rule -- presumably a catch-all for the
                # remaining parameters using the default kwargs; confirm.
            }
        ]
    },
)

criterion = nn.CrossEntropyLoss()
# Clear any stale gradients before the backward pass, as a real training
# step would; without this, gradients accumulate across iterations.
optimizer.zero_grad()
output = model(torch.randn(1, 3, 224, 224))
# torch.tensor infers int64 (torch.long) directly, which CrossEntropyLoss
# requires for class-index targets; no float-then-cast round trip.
loss = criterion(output, torch.tensor([0]))
loss.backward()
optimizer.step(closure=None)