drlex: library for DL and DRL experiments
drlex is an open-source library for deep learning (DL) and deep reinforcement learning (DRL) experiments.
Introduction
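- A small tool I originally wrote for my own use. I open-sourced it in the hope that it helps others (and not just to show off my messy code).
- It is still under rapid iteration, and many features are yet to be added (time and ability permitting).
- Issues and PRs are very welcome (and stars too :))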
Features
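- Record, track, and manage experiments and the files they produce, cutting down on boilerplate code
- Log to a file and to standard output (stdout) at the same time
- Record data through an interface that is more convenient than TensorBoard's
- Save both the best checkpoint so far and the latest checkpoint of each run
- Record key experiment data (e.g. loss, accuracy, reward) and run simple analyses on it
- On multi-GPU machines, automatically select the GPU with the most free memory
- Flexible, lightweight, and easy to pick up (a painless switch for argparse users), without breaking your existing code structure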
- ... and more in the future!
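Logging to a file and to stdout at the same time can be approximated with Python's standard `logging` module by attaching two handlers to one logger. This is a minimal sketch, not drlex's implementation; the helper name `make_dual_logger` and the log format are hypothetical.

```python
import logging
import sys


def make_dual_logger(name: str, log_path: str) -> logging.Logger:
    """Hypothetical helper: return a logger that writes each record to stdout and to a file."""
    logger = logging.getLogger(name)
    if logger.handlers:  # already configured earlier: reuse it instead of adding duplicate handlers
        return logger

    logger.setLevel(logging.INFO)
    logger.propagate = False  # do not also pass records to the root logger (avoids double printing)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    stream_handler = logging.StreamHandler(sys.stdout)  # handler 1: standard output
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")  # handler 2: append to a file
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger
```

For example, `make_dual_logger("expt", "expt.log").info("Epoch [1], Step [100], Loss: 0.1234")` prints the formatted line and appends the same line to `expt.log`.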
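One common way to pick the GPU with the most free memory is to query `nvidia-smi` for each card's free memory and take the maximum. The sketch below assumes that approach (drlex's actual mechanism may differ), and the function names are hypothetical. Accepting the CSV text as an argument lets the selection logic be tested on a machine without a GPU.

```python
from __future__ import annotations

import subprocess


def parse_free_memory(csv_text: str) -> dict[int, int]:
    """Parse `nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits`
    output (one "index, free_MiB" line per GPU) into {gpu_index: free_MiB}."""
    free = {}
    for line in csv_text.strip().splitlines():
        index, mem = (field.strip() for field in line.split(","))
        free[int(index)] = int(mem)
    return free


def pick_gpu_with_most_free_memory(csv_text: str | None = None) -> int:
    """Return the index of the GPU with the most free memory.

    If csv_text is None, nvidia-smi is called (requires an NVIDIA driver)."""
    if csv_text is None:
        csv_text = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
    free = parse_free_memory(csv_text)
    if not free:
        raise RuntimeError("nvidia-smi reported no GPUs")
    return max(free, key=free.get)  # index whose free memory is largest
```

The returned index could then be written to `CUDA_VISIBLE_DEVICES` before PyTorch initializes CUDA, which is what the example in the Examples section does with `args.gpu_id`.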
Installation
- Using pip

```bash
pip install drlex
```
- Build from source

```bash
git clone git@github.com:bat67/drlex.git
# OR using: git clone https://github.com/bat67/drlex.git
cd drlex/
python setup.py install
```
- For developers

```bash
git clone git@github.com:bat67/drlex.git
# OR using: git clone https://github.com/bat67/drlex.git
cd drlex/
pip install -e .
```
Examples
```python
import os

import torch
import torch.nn as nn
import torchvision

from drlex import Experiment

expt = Experiment.from_json_file('examples/example.json')  # initialize an experiment
args = expt.args  # hyperparameters

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = torchvision.datasets...
test_dataset = torchvision.datasets...
train_loader = torch.utils.data.DataLoader...
test_loader = torch.utils.data.DataLoader...

class ConvNet(nn.Module): ...

model = ConvNet(args.num_classes).to(device)
criterion = ...
optimizer = ...

for epoch in range(args.num_epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        ...
        expt.log_msg(f'Epoch [{epoch+1}], Step [{i+1}], Loss: {loss.item():.4f}')  # log to both stdout and file
        expt.log_scalar('loss', loss.item())  # write to TensorBoard

    model.eval()
    with torch.no_grad():
        ...
        expt.log_msg(f'Test Accuracy: {accuracy} %')  # log to both stdout and file
        expt.log_metric('acc', accuracy, tensorboard=True)  # record the value for later analysis and also write it to TensorBoard

    save_dict = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    expt.save_model(save_dict, metric=accuracy, epoch=epoch+1, higher_better=True, delete_old=True)  # save checkpoints

# load model
_, _ = expt.load_model(what='last')  # load the latest checkpoint
_, _ = expt.load_model(what='best')  # load the best checkpoint

# data process
print(expt.watch['acc'].last_n_mean(5))  # mean of the last 5 recorded values
```
See the examples folder for more.
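The example reads hyperparameters such as `args.gpu_id` and `args.num_epochs` from a JSON file via `Experiment.from_json_file`. The schema of `examples/example.json` is not shown here, so the sketch below is only an illustration of the idea using the standard library: a hypothetical `args_from_json` helper that turns a flat JSON object into an `argparse.Namespace` and lets command-line flags override scalar entries, which is roughly what an argparse user would expect.

```python
from __future__ import annotations

import argparse
import json


def args_from_json(path: str, argv: list[str] | None = None) -> argparse.Namespace:
    """Hypothetical helper: load hyperparameters from a flat JSON object and let
    command-line flags such as --lr 0.01 override the int/float/str entries."""
    with open(path, encoding="utf-8") as f:
        config = json.load(f)

    parser = argparse.ArgumentParser()
    for key, value in config.items():
        # bool is a subclass of int, so exclude it explicitly; bools, lists and dicts
        # are kept as-is without a command-line override in this sketch
        if isinstance(value, bool) or not isinstance(value, (int, float, str)):
            continue
        parser.add_argument(f"--{key}", type=type(value), default=value)

    args = parser.parse_args(argv)  # argv=None means "read sys.argv[1:]"
    for key, value in config.items():
        if not hasattr(args, key):  # entries that were not exposed as flags
            setattr(args, key, value)
    return args
```

Running a script that calls `args_from_json('config.json')` as `python train.py --num_epochs 20` would then use 20 epochs while every other value comes from the (hypothetical) `config.json`.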
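In the example, `expt.save_model(..., metric=accuracy, epoch=epoch+1, higher_better=True, delete_old=True)` keeps a best and a latest checkpoint. The sketch below shows one way such bookkeeping can work; it is not drlex's code. The `CheckpointKeeper` class, its file naming, the reading of `delete_old` as "remove the superseded file", and the use of `pickle` instead of `torch.save` are all assumptions made to keep the example dependency-free.

```python
import os
import pickle


class CheckpointKeeper:
    """Hypothetical helper: keep the latest checkpoint and the best one seen so far."""

    def __init__(self, directory: str, higher_better: bool = True, delete_old: bool = True):
        self.directory = directory
        self.higher_better = higher_better
        self.delete_old = delete_old
        self.best_metric = None
        self.best_path = None
        self.last_path = None
        os.makedirs(directory, exist_ok=True)

    def _is_improvement(self, metric: float) -> bool:
        if self.best_metric is None:
            return True  # the first checkpoint is always the best so far
        return metric > self.best_metric if self.higher_better else metric < self.best_metric

    def _write(self, obj, prefix: str, epoch: int, old_path):
        path = os.path.join(self.directory, f"{prefix}_epoch{epoch}.pkl")
        with open(path, "wb") as f:
            pickle.dump(obj, f)
        # Remove the checkpoint this one supersedes, if requested
        if self.delete_old and old_path is not None and old_path != path and os.path.exists(old_path):
            os.remove(old_path)
        return path

    def save(self, obj, metric: float, epoch: int) -> bool:
        """Always save obj as the latest checkpoint; also save it as best if metric improved.
        Returns True when a new best checkpoint was written."""
        self.last_path = self._write(obj, "last", epoch, self.last_path)
        if self._is_improvement(metric):
            self.best_metric = metric
            self.best_path = self._write(obj, "best", epoch, self.best_path)
            return True
        return False

    def load(self, what: str = "last"):
        """Load the 'last' or 'best' checkpoint."""
        path = self.best_path if what == "best" else self.last_path
        if path is None:
            raise FileNotFoundError(f"no {what} checkpoint saved yet")
        with open(path, "rb") as f:
            return pickle.load(f)
```

With `higher_better=True`, saving metrics 0.5, 0.7, 0.6 over three epochs leaves only `best_epoch2.pkl` and `last_epoch3.pkl` on disk. Writing the latest checkpoint unconditionally means a crashed run can always resume from the most recent epoch, while the best checkpoint is what you would evaluate.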
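The example records accuracy with `expt.log_metric('acc', ...)` and later reads `expt.watch['acc'].last_n_mean(5)`. A minimal, hypothetical sketch of such a per-metric history is shown below; `MetricHistory` and the module-level `watch` dictionary are not drlex's classes, and averaging over all available values when fewer than `n` have been recorded is an assumed behavior.

```python
from collections import defaultdict


class MetricHistory:
    """Hypothetical stand-in for a per-metric history: store values in order, summarize the tail."""

    def __init__(self):
        self.values = []

    def append(self, value: float) -> None:
        self.values.append(float(value))

    def last_n_mean(self, n: int) -> float:
        """Mean of the last n values (or of all values if fewer than n were recorded)."""
        if n <= 0:
            raise ValueError("n must be positive")
        if not self.values:
            raise ValueError("no values recorded")
        tail = self.values[-n:]
        return sum(tail) / len(tail)


watch = defaultdict(MetricHistory)  # metric name -> MetricHistory, created on first access
```

Averaging the last few evaluations, rather than reporting only the final one, smooths out epoch-to-epoch noise when comparing runs.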
Documentation
- See the docs folder.
Contribution
- Bug reports, feature requests, and other feedback are welcome in GitHub Issues.
Changelog
- See CHANGELOG.md.