Samsung AI Challenge solution
๋ณธ repository๋ฅผ ํตํด 2021๋ DACON์ ํตํด ๊ฐ์ต๋ Samsung AI Challenge for Scientific Discovery ๊ฒฝ์ง๋ํ์ 5์ ์๋ฃจ์ ์ฝ๋๋ฅผ ์ ๋ฆฌํ์ฌ ๊ณต๊ฐํฉ๋๋ค.
1. ๊ฐ์
๋ณธ ์ฑ๋ฆฐ์ง์์๋ ๋ถ์์ 3์ฐจ์ ๊ตฌ์กฐ ์ ๋ณด๋ฅผ ์ด์ฉํ์ฌ S1-T1 ์ฌ์ด์ ์๋์ง ๊ฐญ์ ์ถ์ ํ ์ ์๋ Machine Learning ์๊ณ ๋ฆฌ์ฆ์ ์ฑ๋ฅ์ ๊ฒจ๋ฃน๋๋ค.
2. ์ ๊ทผ
๋ชจ๋ธ ์ค๋ช
WIP
ํ์ต ๋ฐฉ๋ฒ
Pretraining
- ์๋์ ๋ฐ์ดํฐ์ ์ ์ด์ฉํ์ฌ HOMO ๋ฐ LUMO๋ฅผ ์์ธกํ๋ ๋ฉํฐํ์คํฌ ์ฌ์ ํ์ต์ ์ํํฉ๋๋ค.
- Pretraining์ ์ฌ์ฉ๋๋ molecule sdf ๋ฐ์ดํฐ์ ๋ฉํ๋ฐ์ดํฐ(
pretrain_metadata.csv
)๋ ์ฌ๊ธฐ์์ ๋ค์ด๋ก๋ ๋ฐ์ ์ ์์ต๋๋ค. - Pretraining์ ์ํด ํ๋ก์ธ์ฑ ์๋ฃ๋ molecule sdf๋ค์ ๋ชจ์ ๋ ๋๋ ํ ๋ฆฌ
pretrain_sdf
๋ ์ฌ๊ธฐ์์ ๋ค์ด๋ก๋ ๋ฐ์ ์ ์์ต๋๋ค.
Fine-tuning
- ์ฌ์ ํ์ต๋ stem์ ์ด์ฉํฉ๋๋ค.
- ์ฒซ 9 epoch์ pretrained weight๋ฅผ freeze ์ํจ ์ํ๋ก ํ์ตํ๊ณ , 10 epoch ๋ถํฐ weight unfreeze ํ ๋ชจ๋ weight๋ฅผ ์ ๋ฐ์ดํธ ์ํต๋๋ค.
- ์ ๊ณต๋ ํ์ต ๋ฐ์ดํฐ๋ก S1-T1 gap๊ณผ, S1, T1 ๊ฐ๊ฐ์ ๊ฐ์ ์์ธกํ๋ regression head๋ฅผ ํ์ตํฉ๋๋ค.
- Gap, S1, T1 regression์ MSE loss๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- Gap์ weight๋ 1.0์ด๊ณ , S1, T1 regression์ weight๋ 0.05๋ก ํ์ตํฉ๋๋ค.
- Optimizer =
AdamW(lr=3e-5)
- Scheduler =
ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15, threshold=0.005, threshold_mode='rel')
- Batch size = 64
3. ์ค์น ๋ฐ ์ฌ์ฉ๋ฒ
๋ณธ ์๋ฃจ์ ์ฝ๋ ๋ฐ ๋ชจ๋ธ์ PyPI์ ๋ฐฐํฌ๋์ด ์์ต๋๋ค. ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ์ด ์ค์นํ ์ ์์ต๋๋ค.
- ์ฌ์ฉ์ ์์์
openbabel
ํจํค์ง๊ฐ ํ๊ฒฝ์ ์ค์น๋์ด ์์ด์ผ ํฉ๋๋ค.conda install -c conda-forge openbabel
๋ก ์ค์น ๊ฐ๋ฅํฉ๋๋ค.
$ pip install sac2021
Pretraining
- ๋ฉํ๋ฐ์ดํฐ
pretrain_metadata.csv
์ sdf ํ์ผ ๋๋ ํ ๋ฆฌpretrain_sdf
๋ฅผ ๋ค์ด๋ก๋ ํ, ์๋ ๋ช ๋ น์--meta
์--data
ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ ํ ์ค์ ํ์ฌ ์ฌ์ ํ์ต์ ์งํํฉ๋๋ค.
$ python -m sac2021.pretrain \
--meta [path/to/pretrain_metadata.csv] \
--data [path/to/pretrain_sdf] \
--output [OUTPUT_CHECKPOINT] \
--model-id [ID] \
--fold 0 \ # For validation purpose. (2.5% of the data will be held out)
--loss mse \
Pretraining ํ์ต ๋ก๊ทธ๋ ์ด Weight & Biases Project์์ ํ์ธ ๊ฐ๋ฅํฉ๋๋ค.
Fine-tuning
- ํ์ต ๋ฐ์ดํฐ
traindev.csv
์ sdf ํ์ผ ๋๋ ํ ๋ฆฌtraindev_sdf
๋ฅผ ๋ค์ด๋ก๋ ํ, ์๋ ๋ช ๋ น์--meta
์--data
ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ ํ ์ค์ ํ์ฌ fine-tuning์ ์งํํฉ๋๋ค.
$ python -m sac2021.finetune \
--meta [path/to/traindev.csv] \
--data [path/to/traindev_sdf] \
--ckpt [path/to/pretrained_checkpoint] \
--output [OUTPUT_CHECKPOINT] \
--model-id [ID] \
--fold 0 \
--loss mse