pytorch-zero-lit

LiT: Zero-Shot Transfer with Locked-image text Tuning


Install
pip install pytorch-zero-lit==0.2.3

Official JAX models for LiT: Zero-Shot Transfer with Locked-image text Tuning, converted to PyTorch.

Conversion path: JAX -> TensorFlow -> ONNX -> PyTorch.

  • The image encoder is loaded into PyTorch and supports gradients.
  • The text encoder is not loaded into PyTorch; it runs via ONNX on the CPU.

Install

poetry add pytorch-zero-lit

or

pip install pytorch-zero-lit

Usage

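from PIL import Image
import torchvision.transforms.functional as TF

from lit import LiT

model = LiT()

# Load the image as a float tensor and add a batch dimension: (1, 3, 224, 224)
images = TF.to_tensor(
    Image.open("cat.png").convert("RGB").resize((224, 224))
)[None]
# Candidate captions to compare against the image
texts = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a fish",
]

# Encode both modalities
image_encodings = model.encode_images(images)
text_encodings = model.encode_texts(texts)

# Compare the image encodings with the text encodings
cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)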
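
To illustrate what the final step is for, here is a minimal, dependency-free sketch of how cosine similarity can rank candidate texts against one image encoding. It does not call the library: the helper names `cosine_similarity` and `rank_texts` and the toy 3-dimensional vectors are assumptions made up for this example, and the exact layout of the value returned by `model.cosine_similarity` is not described here.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_texts(image_encoding, text_encodings, texts):
    """Return (text, similarity) pairs sorted from best to worst match."""
    scores = [cosine_similarity(image_encoding, t) for t in text_encodings]
    return sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)


# Toy vectors standing in for real encodings (made-up values, not model output).
image_encoding = [1.0, 0.0, 0.0]
text_encodings = [
    [2.0, 0.0, 0.0],  # same direction as the image vector
    [0.0, 1.0, 0.0],  # orthogonal to the image vector
    [1.0, 1.0, 0.0],  # 45 degrees away from the image vector
]
texts = ["a photo of a cat", "a photo of a dog", "a photo of a bird"]

ranking = rank_texts(image_encoding, text_encodings, texts)
best_text, best_score = ranking[0]
print(best_text)
```

Because cosine similarity ignores vector length, the first text vector scores highest even though its magnitude differs from the image vector's. In zero-shot classification, the highest-scoring text is taken as the predicted label.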