mltreelib

A real tree-based ML package


Keywords
nbdev, jupyter, notebook, python
License
Apache-2.0
Install
pip install mltreelib==0.0.2

Documentation

mltree

This package evolved from an attempt to build the right kind of decision trees, drawing on ideas from researchers such as Hastie, Tibshirani, Friedman, Quinlan, Loh, and Chaudhari.

Install

pip install mltreelib

How to use

Create sample data

import numpy as np
import pandas as pd
from mltreelib.data import Dataset
from mltreelib.tree import Tree
# Build a reproducible 1,000-row frame with integer, float, and string columns
n_size = 1000
rnd = np.random.RandomState(1234)
dummy_data = pd.DataFrame({'numericfull':rnd.randint(1,500,size=n_size),
                            'unitint':rnd.randint(1,25,size=n_size),
                            'floatfull':rnd.random_sample(size=n_size),
                            # floats rounded to 2 decimals on top of a small integer
                            'floatsmall':np.round(rnd.random_sample(size=n_size)+rnd.randint(1,25,size=n_size),2),
                            'categoryobj':rnd.choice(['a','b','c','d'],size=n_size),
                            # one random lowercase letter per row ('a'..'z')
                            'stringobj':rnd.choice(["{:c}".format(k) for k in range(97, 123)],size=n_size)})
dummy_data.head()
   numericfull  unitint  floatfull  floatsmall  categoryobj  stringobj
0          304       18   0.908959        8.56            a          c
1          212       24   0.348582       14.35            a          g
2          295       15   0.392977       21.98            a          y
3           54       20   0.720856        5.33            a          q
4          205       21   0.897588       23.03            c          k

Create a Dataset

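# Wrap the DataFrame in a Dataset; the printed repr lists the options in effect
dataset = Dataset(df=dummy_data)
print(dataset)
# Compare memory footprints in MB (1 MB = 1e6 bytes)
print('Pandas Data Frame        : ',np.round(dummy_data.memory_usage(deep=True).sum()*1e-6,2),'MB')
print('Dataset Structured Array : ',np.round(dataset.data.nbytes*1e-6,2),'MB')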
dataset.data[:5]
Dataset(df=Shape((1000, 6), reduce_datatype=True, encode_category=None, add_intercept=False, na_treatment=allow, copy=False, digits=None, n_category=None, split_ratio=None)
Pandas Data Frame        :  0.15 MB
Dataset Structured Array :  0.03 MB

array([(304, 18, 0.9089594 ,  8.56, 'a', 'c'),
       (212, 24, 0.34858167, 14.35, 'a', 'g'),
       (295, 15, 0.39297667, 21.98, 'a', 'y'),
       ( 54, 20, 0.7208556 ,  5.33, 'a', 'q'),
       (205, 21, 0.89758754, 23.03, 'c', 'k')],
      dtype=[('numericfull', '<u2'), ('unitint', 'u1'), ('floatfull', '<f4'), ('floatsmall', '<f4'), ('categoryobj', 'O'), ('stringobj', 'O')])
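The structured array above stores `numericfull` as `u2`, `unitint` as `u1`, and the float columns as `f4`, where pandas usually holds 64-bit types. A minimal sketch of that kind of downcasting follows, assuming simple rules inferred from the printed dtype: the smallest integer type that fits each column's range, `float32` for floats, and Python objects for everything else. `downcast_to_structured` is a hypothetical helper written for illustration only; it is not mltreelib's API or its actual `reduce_datatype` logic.

```python
import numpy as np
import pandas as pd


def downcast_to_structured(df: pd.DataFrame) -> np.ndarray:
    """Hypothetical helper: pack a DataFrame into a compact NumPy structured array.

    Assumed rules (not necessarily what mltreelib does):
    - non-negative integer columns -> smallest unsigned integer type that fits
    - other integer columns -> smallest signed integer type that fits
    - float columns -> float32
    - everything else (strings, categories) -> Python objects ('O')
    """
    fields = []
    for name in df.columns:
        col = df[name]
        if pd.api.types.is_integer_dtype(col):
            # Use Python ints so range checks never overflow
            lo, hi = int(col.min()), int(col.max())
            if lo >= 0:
                candidates = (np.uint8, np.uint16, np.uint32, np.uint64)
            else:
                candidates = (np.int8, np.int16, np.int32, np.int64)
            # Pick the first (narrowest) type whose range covers [lo, hi]
            dtype = next(t for t in candidates
                         if np.iinfo(t).min <= lo and hi <= np.iinfo(t).max)
        elif pd.api.types.is_float_dtype(col):
            dtype = np.float32
        else:
            dtype = object
        fields.append((str(name), dtype))

    # One record per row; fields are packed without padding
    out = np.empty(len(df), dtype=fields)
    for name, _ in fields:
        out[name] = df[name].to_numpy()
    return out


# Usage on a frame built like the sample data above
rnd = np.random.RandomState(1234)
n_size = 1000
df = pd.DataFrame({
    'numericfull': rnd.randint(1, 500, size=n_size),
    'unitint': rnd.randint(1, 25, size=n_size),
    'floatfull': rnd.random_sample(size=n_size),
    'categoryobj': rnd.choice(['a', 'b', 'c', 'd'], size=n_size),
})
packed = downcast_to_structured(df)
print(packed.dtype)
print(packed.dtype.itemsize, 'bytes per row,', packed.nbytes, 'bytes total')
```

One caveat when reading the memory comparison: `ndarray.nbytes` counts only the pointer-sized references stored in object fields such as `categoryobj` and `stringobj`, while `memory_usage(deep=True)` also counts the Python string objects those references point to, so the two figures are not strictly like-for-like.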