A minimal implementation of a generative model with flow matching for tabular data. No deep learning: XGBoost is used to learn the generative model.
The original implementation is available in forest-diffusion. Another implementation is available in the torchcfm library.
Unlike the implementation in forest-diffusion, we simplify the implementation by utilising XGBoost's ability to predict multiple regression outputs.
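To make the role of multi-output regression concrete, the sketch below builds the kind of training pairs a flow matching regressor could be fit on. It assumes the common linear path `x_t = (1 - t) * x0 + t * x1` with Gaussian noise `x0`; the helper name `make_flow_matching_pairs` is hypothetical and is not part of this package's API, and the package may construct its training data differently.

```python
import numpy as np


def make_flow_matching_pairs(data, rng):
    """Build (inputs, targets) for a multi-output regressor.

    Assumption: linear interpolation path between noise x0 and data x1,
    whose velocity along the path is simply x1 - x0.
    """
    n, d = data.shape
    x1 = data                                  # data samples (end of the path)
    x0 = rng.standard_normal((n, d))           # noise samples (start of the path)
    t = rng.uniform(0.0, 1.0, size=(n, 1))     # one random time per row
    x_t = (1.0 - t) * x0 + t * x1              # point on the path at time t
    inputs = np.hstack([x_t, t])               # regressor sees position and time
    targets = x1 - x0                          # all d velocity components at once
    return inputs, targets


rng = np.random.default_rng(0)
data = rng.standard_normal((500, 2))           # stand-in for a tabular dataset
inputs, targets = make_flow_matching_pairs(data, rng)
```

Because `targets` has one column per feature, a single regressor that supports 2-D targets can learn every velocity component together instead of needing one model per column.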
pip install flowmatching-bdt
from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT
data, _ = make_moons(n_samples=1000, noise=0.1, random_state=0)
model = FlowMatchingBDT()
# train the model
model.fit(data)
# get new samples
num_samples = 1000
samples = model.predict(num_samples=num_samples)
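Sampling from a flow matching model generally means starting from noise and integrating the learned vector field from t = 0 to t = 1. The sketch below illustrates this with a fixed-step Euler solver. Both `euler_sample` and `toy_field` are hypothetical names used only for illustration; which solver `FlowMatchingBDT.predict` uses internally is not stated here, so treat the Euler scheme as an assumption.

```python
import numpy as np


def euler_sample(vector_field, x0, n_steps=100):
    """Integrate dx/dt = vector_field(x, t) from t=0 to t=1 with Euler steps."""
    x = np.array(x0, dtype=float)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        x = x + dt * vector_field(x, t)        # one explicit Euler update
    return x


def toy_field(x, t):
    # Stand-in for a trained regressor: a constant unit velocity,
    # so every point is shifted by exactly 1 over the unit time interval.
    return np.ones_like(x)


rng = np.random.default_rng(0)
noise = rng.standard_normal((1000, 2))         # starting points at t = 0
samples = euler_sample(toy_field, noise)
```

With a trained model, `toy_field` would be replaced by a function that calls the regressor on the current points and time.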
If you'd like to do conditional generation:
import numpy as np
from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT
data, labels = make_moons(n_samples=1000, noise=0.1, random_state=42)
model = FlowMatchingBDT()
# train the model
model.fit(data, conditions=labels)
# get new samples
num_samples = 1000
conditions = np.ones(num_samples)
samples = model.predict(num_samples=num_samples, conditions=conditions)
To learn more about flow matching for generative modelling, check out these resources:
- Introduction to Flow Matching, by Tor Fjelde, Emilie Mathieu, Vincent Dutordoir
- Generating Tabular Data with XGBoost, by Alexia Jolicoeur-Martineau (author of the ForestFlow paper)
@inproceedings{jolicoeur2024generating,
title={Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees},
author={Jolicoeur-Martineau, Alexia and Fatras, Kilian and Kachman, Tal},
booktitle={International Conference on Artificial Intelligence and Statistics},
pages={1288--1296},
year={2024},
organization={PMLR}
}
This repository is heavily inspired by, and borrows parts from, lucidrains (project structure) and torch-cfm.