tf-centernet
CenterNet implementation with Tensorflow 2.
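For orientation, CenterNet represents each object by its centre point. The network predicts a per-class heatmap. Decoding keeps only the cells that equal the maximum of their 3x3 neighbourhood, a max-pool that replaces box NMS, and then takes the top-k scores. The NumPy sketch below shows only that peak-extraction step. It is not this package's code, and both the (H, W, C) heatmap layout and the `extract_peaks` name are assumptions made for illustration.

```python
import numpy as np


def extract_peaks(heatmap: np.ndarray, k: int = 100):
    """Return the top-k local maxima of a float (H, W, C) class heatmap.

    Hypothetical helper: a cell is kept only if it equals the maximum of its
    3x3 neighbourhood in the same channel (the CenterNet-style max-pool NMS).
    Returns (scores, ys, xs, classes) sorted by descending score.
    """
    h, w, c = heatmap.shape
    # Pad with -inf so border cells are compared only with real neighbours.
    padded = np.pad(heatmap, ((1, 1), (1, 1), (0, 0)), constant_values=-np.inf)
    # 3x3 max filter with stride 1: elementwise max over the nine shifted views.
    pooled = np.max(
        [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)],
        axis=0,
    )
    # Zero out every cell that is not the maximum of its neighbourhood.
    peaks = np.where(heatmap == pooled, heatmap, 0.0)
    flat = peaks.ravel()
    k = min(k, flat.size)
    # Indices of the k largest values, then order them by score.
    top = np.argpartition(-flat, k - 1)[:k]
    top = top[np.argsort(-flat[top])]
    ys, xs, classes = np.unravel_index(top, (h, w, c))
    return flat[top], ys, xs, classes
```

The returned (y, x) positions are in heatmap coordinates. In CenterNet the heatmap is downsampled from the input (by 4 in the original paper), and separate offset and size heads turn each peak into a bounding box.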
Install
pip install tf-centernet
Example
Object detection
import numpy as np
import PIL.Image
import centernet
# Default: num_classes=80
obj = centernet.ObjectDetection(num_classes=80)
# Default: weights_path=None
# With num_classes=80 and weights_path=None, the pre-trained COCO model is loaded.
# Otherwise, the user-specified weight file is loaded.
obj.load_weights(weights_path=None)
# PIL loads images in RGB order; [..., ::-1] reverses the channels to BGR
img = np.array(PIL.Image.open('./data/sf.jpg'))[..., ::-1]
# If `debug=True`, an image with the predicted bounding boxes is also created
boxes, classes, scores = obj.predict(img, debug=True)
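Both examples pass `predict` an array in BGR channel order (`[..., ::-1]` reverses PIL's RGB). Presumably the model expects BGR input, but treat that as an assumption. The one-liner can also go wrong for inputs that are not 3-channel RGB: a grayscale file gives a 2-D array, and an RGBA PNG would become ABGR. Below is a minimal NumPy sketch of a more defensive conversion. The `to_bgr` name is hypothetical and not part of tf-centernet.

```python
import numpy as np


def to_bgr(image: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) BGR copy of an RGB, RGBA or grayscale array."""
    if image.ndim == 2:
        # Grayscale: replicate the single channel three times.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: drop the alpha channel before reordering.
        image = image[..., :3]
    # Reverse the channel axis (RGB -> BGR), as `[..., ::-1]` does above;
    # ascontiguousarray turns the reversed view into a regular array.
    return np.ascontiguousarray(image[..., ::-1])
```

Usage: `img = to_bgr(np.array(PIL.Image.open('./data/sf.jpg')))`. For a 3-channel RGB JPEG this gives the same result as the original one-liner.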
Pose estimation
import numpy as np
import PIL.Image
import centernet
# Default: num_joints=17
pe = centernet.PoseEstimation(num_joints=17)
# Default: weights_path=None
# With num_joints=17 and weights_path=None, the pre-trained COCO model is loaded.
# Otherwise, the user-specified weight file is loaded.
pe.load_weights(weights_path=None)
# Adjust this threshold to tune the predictions
pe.score_threshold = 0.1
img = np.array(PIL.Image.open('./data/chi.jpg'))[..., ::-1]
# If `debug=True`, an image with the predicted keypoints is also created
boxes, keypoints, scores = pe.predict(img, debug=True)
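Raising `score_threshold` keeps fewer, more confident predictions, and lowering it keeps more. To compare several thresholds without re-running the model, one option is to predict once with a low threshold and then filter the results yourself. The sketch below assumes that each returned array has one entry per detection along its first axis, which this README does not confirm. `filter_by_score` is a hypothetical helper, and it keeps scores `>= threshold`; the package's own comparison may differ.

```python
import numpy as np


def filter_by_score(scores, *arrays, threshold=0.1):
    """Keep detections with score >= threshold (hypothetical helper).

    `arrays` are per-detection arrays aligned with `scores` along axis 0,
    e.g. boxes or keypoints. Returns the filtered scores followed by the
    filtered arrays, in the order they were passed.
    """
    scores = np.asarray(scores)
    keep = scores >= threshold
    return (scores[keep],) + tuple(np.asarray(a)[keep] for a in arrays)
```

Usage under the stated assumption: `scores_hi, boxes_hi, keypoints_hi = filter_by_score(scores, boxes, keypoints, threshold=0.3)`.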
TODO
- Object detection
- Pre-trained model for object detection with Hourglass-104
- Pose estimation
- Pre-trained model for pose estimation with Hourglass-104
- DLA-34 backbone and pre-trained models
- Training function and loss definition
- Training data augmentation