torch-parameter-groups
Group PyTorch Parameters according to Rules.
Installation
Need Python 3.6+.
pip install torch-parameter-groups
Usage
import torch
import torch.nn as nn
import torch_basic_models
import torch_parameter_groups
model = torch_basic_models.MobileNetV2.factory()
optimizer = torch_parameter_groups.optimizer_factory(
model=model,
config={
'type': 'SGD',
'kwargs': {
'momentum': 0.9,
'nesterov': True,
'weight_decay': 0.0001,
},
'rules': [
{
'param_name_list': ['weight'],
'kwargs': {
'weight_decay': 0
}
},
{
}
]
},
)
criterion = nn.CrossEntropyLoss()
output = model(torch.randn(1, 3, 224, 224))
loss = criterion(output, torch.Tensor([0]).long())
loss.backward()
optimizer.step(closure=None)