Understanding PaliGemma in 50 minutes or less

Ritwik Raha
9 min read

PaliGemma is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short-video captioning, visual question answering, text reading, object detection, and object segmentation.

Note of some importance
A note of thanks, acknowledgment, and warning: this post collates wonderful resources from the PaliGemma model card, the Hugging Face blog post, and the big_vision repository. The point of this post is to simplify that material and present it from my perspective. I am open to constructive criticism and

What is PaliGemma?

PaliGemma is a new family of vision-language models from Google. These models can process both images and text to produce text outputs.

Google has released three types of PaliGemma models:

  • Pretrained (pt) models: Trained on large datasets without task-specific tuning.

  • Mix models: Pretrained models further fine-tuned on a mixture of tasks, suitable for general-purpose prompting.

  • Fine-tuned (ft) models: Pretrained models fine-tuned further for a specific task or academic benchmark.

Each type comes in different resolutions and multiple precisions for convenience. All models are available on the Hugging Face Hub with model cards, licenses, and integration with transformers.

"How do I get it running?"

The model works out of the box with Hugging Face Transformers. A simple Colab notebook is linked here.


from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from huggingface_hub import notebook_login
import torch
from PIL import Image
import requests

# The google/paligemma-* checkpoints are gated: accept the license on the Hub first,
# then authenticate. In Colab/Jupyter, uncomment the next line if you are not logged in.
# notebook_login()

input_text = "Who is this person?"
img_url = "https://huggingface.co/datasets/ritwikraha/random-storage/resolve/main/cohen.jpeg"
input_image = Image.open(requests.get(img_url, stream=True).raw)

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "google/paligemma-3b-mix-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Tokenize the prompt, preprocess the image, and move everything to the same device as the model
inputs = processor(text=input_text, images=input_image,
                  padding="longest", do_convert_rgb=True, return_tensors="pt").to(device)
# Cast the floating-point tensors (the pixel values) to the model's dtype; token ids stay integers
inputs = inputs.to(dtype=model.dtype)

with torch.no_grad():
  output = model.generate(**inputs, max_length=496)
print(processor.decode(output[0], skip_special_tokens=True))

"What is the architecture like?"

PaliGemma is a family of vision-language models whose architecture pairs SigLIP-So400m as the image encoder with Gemma-2B as the text decoder.

  • SigLIP is a state-of-the-art model capable of understanding both images and text. Similar to CLIP, it consists of an image encoder and a text encoder trained jointly; PaliGemma uses only its image encoder.

  • Gemma is a decoder-only model designed for text generation.

But what is SigLIP?

SigLIP introduces a straightforward modification to the widely used CLIP recipe, as detailed in the paper https://arxiv.org/abs/2303.15343: it changes the training loss rather than the architecture. CLIP's architecture includes an image encoder and a text encoder, both Transformer-based.

The pre-training of CLIP employs a contrastive approach to ensure that embeddings of corresponding images and texts are close in the embedding space, while non-matching pairs are positioned far apart. Traditionally, CLIP is trained with a softmax loss function, which necessitates a global view of all pairwise similarities for probability normalization.

SigLIP simplifies this process by replacing the softmax loss with a sigmoid loss. Unlike softmax, the sigmoid loss does not require a global view of the pairwise similarities. This turns the task into an independent binary classification problem for every image-text pair in the batch: do this image and this text belong together, yes or no?
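To make the difference concrete, here is a minimal NumPy sketch of both objectives for a batch of n paired embeddings. It follows the formulation in the SigLIP paper, with two simplifications that are my own assumptions: the temperature t and bias b, which SigLIP learns, are fixed at the paper's initial values (t = 10, b = -10), and the same fixed t is reused for the CLIP-style loss. The function names are illustrative; this is not the big_vision implementation.

```python
import numpy as np


def l2_normalize(x):
    """Scale each row to unit length so dot products are cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def softmax_contrastive_loss(img_emb, txt_emb, t=10.0):
    """CLIP-style loss: every probability is normalized over a whole row or column."""
    logits = t * l2_normalize(img_emb) @ l2_normalize(txt_emb).T  # (n, n)
    # Log-softmax over all texts for each image, and over all images for each text.
    log_p_img2txt = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)
    log_p_txt2img = logits - np.logaddexp.reduce(logits, axis=0, keepdims=True)
    # Matching pairs sit on the diagonal; average the two directions.
    return -0.5 * (np.diag(log_p_img2txt).mean() + np.diag(log_p_txt2img).mean())


def sigmoid_loss(img_emb, txt_emb, t=10.0, b=-10.0):
    """SigLIP-style loss: an independent yes/no decision for every (image, text) pair."""
    # t and b are learned in SigLIP; here they are fixed at the paper's initial values.
    logits = t * l2_normalize(img_emb) @ l2_normalize(txt_emb).T + b
    n = logits.shape[0]
    labels = 2.0 * np.eye(n) - 1.0  # +1 for matching pairs, -1 for all others
    # -log(sigmoid(z)) == log(1 + exp(-z)), computed stably with logaddexp.
    return np.logaddexp(0.0, -labels * logits).sum() / n


rng = np.random.default_rng(0)
img = rng.normal(size=(4, 8))
txt = img + 0.1 * rng.normal(size=(4, 8))  # each caption embedding lies near its image
print(softmax_contrastive_loss(img, txt), sigmoid_loss(img, txt))
```

Note that sigmoid_loss is a plain sum of per-pair terms, so the similarity matrix can be split into chunks (for example across devices) and the partial sums added up. In softmax_contrastive_loss, every entry of log_p_img2txt depends on its entire row, which is exactly the global normalization SigLIP avoids.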

But what is Gemma?

Gemma is a family of lightweight, state-of-the-art open models derived from the research and technology that underpinned the creation of Gemini models. These models exhibit robust performance across various academic benchmarks, particularly in language understanding, reasoning, and safety. The Gemma models are available in two sizes, featuring 2 billion and 7 billion parameters, with both pretrained and fine-tuned checkpoints provided.

PaliGemma combines these two components: the image encoder from SigLIP and the text decoder from Gemma, connected through a linear adapter. This combination creates a powerful vision-language model that can be pre-trained on image-text data and fine-tuned for tasks like captioning and referring segmentation.
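A shape-level sketch of that wiring, with random NumPy arrays standing in for the real networks, may help. The sizes are assumptions based on the public configurations: SigLIP-So400m uses 14-pixel patches, so a 224x224 image yields (224/14)² = 256 patch tokens of width 1152, and Gemma-2B has a hidden size of 2048. The weights and the prompt length are made-up placeholders; only the shapes are meant to be informative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes (see the lead-in); every array below is a random placeholder.
IMG_RES, PATCH = 224, 14
SIGLIP_WIDTH, GEMMA_WIDTH = 1152, 2048
num_img_tokens = (IMG_RES // PATCH) ** 2  # 16 x 16 = 256 patches

# 1) Stand-in for the SigLIP image encoder output: one vector per patch.
image_features = rng.normal(size=(1, num_img_tokens, SIGLIP_WIDTH))

# 2) The linear adapter: a single affine map from SigLIP width to Gemma width.
W = rng.normal(scale=0.02, size=(SIGLIP_WIDTH, GEMMA_WIDTH))
bias = np.zeros(GEMMA_WIDTH)
image_tokens = image_features @ W + bias  # (1, 256, 2048)

# 3) Stand-in for Gemma's embeddings of a tokenized prompt (hypothetical length).
num_text_tokens = 12
text_tokens = rng.normal(size=(1, num_text_tokens, GEMMA_WIDTH))

# 4) Image tokens first, then text tokens: one sequence for the Gemma decoder.
decoder_input = np.concatenate([image_tokens, text_tokens], axis=1)
print(decoder_input.shape)  # (1, 268, 2048)
```

Because the adapter is only a projection, the image tokens enter the decoder as if they were ordinary (if unusual) word embeddings, and the decoder's attention does the rest.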

Let us look at boilerplate code showing what this might look like. The actual code lives in the big_vision repository; this is simply a reduced version of it.

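import importlib
from typing import Optional

import flax.linen as nn
from ml_collections import ConfigDict

class Model(nn.Module):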
    """Two towers transformer for image and text."""
    img_model: str = "vit"
    img: Optional[ConfigDict] = None
    llm_model: str = "proj.paligemma.gemma_bv"
    llm: Optional[ConfigDict] = None

    def setup(self):
        # Initialize LLM and image models
        self._llm = importlib.import_module(f"big_vision.models.{self.llm_model}").Model(**(self.llm or {}), name="llm")
        img_config = {"num_classes": self._llm.embdim, **(self.img or {})}
        self._img_model = importlib.import_module(f"big_vision.models.{self.img_model}").Model(**img_config, name="img")

    def embed_image(self, image, train=False):
        """Embeds the input image."""
        # Preprocess image and call image model
        # Return image embeddings and any auxiliary outputs
        pass

    def embed_text(self, tokens, train=False):
        """Embeds the input text tokens."""
        # Call LLM to embed text tokens
        # Return text embeddings and any auxiliary outputs
        pass

    def embed_image_and_text(self, image, text, input_mask=None, mask_ar=None, train=False):
        """Concatenates image and text embeddings."""
        # Embed image and text separately
        # Combine embeddings into a single sequence
        pass

    def __call__(self, image, text, mask_ar, train=False):
        """Processes input image and text and returns logits."""
        # Embed image and text
        # Create attention mask and call transformer
        # Extract and return logits
        pass

    def prefill_cache(self, x, input_mask, mask_ar, cache_size):
        """Initializes decoding cache with prompt."""
        # Initialize cache for decoding
        pass

    def extend_cache(self, x):
        """Advances decoding cache with new input."""
        # Extend cache for decoding
        pass

    def _fallback_prefill_cache(self, x, input_mask, mask_ar, cache_size):
        # Fallback method for initializing cache
        pass

    def _fallback_extend_cache(self, x):
        # Fallback method for extending cache
        pass

The above boilerplate code defines a Model class as a Flax (JAX) module, the framework big_vision is built on, implementing a two-tower transformer architecture for vision-language models (VLMs). The model consists of two primary components: an image model (a ViT) and a language model (a variant of Gemma).

The Model class includes methods for embedding images and text, concatenating these embeddings, and processing them together to generate output logits.

Two-Tower Transformer Architecture for VLMs

The two-tower transformer architecture in vision-language models (VLMs) involves separate towers (or networks) for processing images and text, which are later combined for joint tasks. Here's how it works:

  1. Image Embedding Tower:

    • The image model, specified by img_model (e.g., ViT), processes input images.

    • The model is initialized with parameters specified in the img configuration.

    • The embed_image method preprocesses the image and generates embeddings using the image model.

  2. Text Embedding Tower:

    • The language model, specified by llm_model (e.g., Gemma variant), processes input text tokens.

    • The model is initialized with parameters specified in the llm configuration.

    • The embed_text method generates embeddings for text tokens using the language model's token-embedding layer.

  3. Combining Embeddings:

    • The embed_image_and_text method separately embeds images and text, then concatenates these embeddings into a single sequence.

    • This combined sequence is used for tasks that require joint image-text understanding.

  4. Processing Inputs:

    • The __call__ method processes input images and text, creates an attention mask, and passes the combined embeddings through the transformer decoder (Gemma) to generate output logits.

By leveraging this two-tower approach, PaliGemma can learn and use the relationships between visual and textual information, in a way not very different from how other vision-language models are pre-trained.
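The __call__ skeleton mentions creating an attention mask, and its mask_ar argument hints at how. PaliGemma uses a prefix-LM scheme: the image tokens and the text prompt (the prefix) attend to each other bidirectionally, while the generated text (the suffix) attends causally. Below is a minimal NumPy sketch of one way to build such a mask. It assumes that mask_ar marks each token with 0 (part of the bidirectional prefix) or 1 (causal) and that input_mask marks real, non-padding tokens; the function name make_prefix_lm_mask is hypothetical.

```python
import numpy as np


def make_prefix_lm_mask(input_mask, mask_ar):
    """Return bool[B, N, N] where entry [b, i, j] means token i may attend to token j.

    input_mask: bool[B, N], True for real (non-padding) tokens.
    mask_ar:    int[B, N], 0 for prefix tokens (bidirectional), 1 for causal tokens.
    (These semantics are an assumption made for this sketch.)
    """
    # Tokens with the same running count form a block; a token may look at its
    # own block and every earlier block, but never at later ones.
    cumsum = np.cumsum(mask_ar, axis=1)
    can_attend = cumsum[:, None, :] <= cumsum[:, :, None]
    # Padding tokens neither attend nor get attended to.
    valid = input_mask[:, None, :] & input_mask[:, :, None]
    return can_attend & valid


# Example: 3 prefix tokens (image + prompt) followed by 2 generated tokens.
input_mask = np.array([[True, True, True, True, True]])
mask_ar = np.array([[0, 0, 0, 1, 1]])
print(make_prefix_lm_mask(input_mask, mask_ar)[0].astype(int))
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```

Setting mask_ar to all ones turns the same function into plain causal attention, so a single cumulative-sum rule covers both the prefix and the generated suffix.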

As mentioned before, the PaliGemma release includes three types of models:

  • PT checkpoints: Pre-trained models that can be further fine-tuned for specific downstream tasks.

  • Mix checkpoints: These models are pre-trained and then fine-tuned on a mixture of tasks. They are suitable for general-purpose inference with free-text prompts and are intended for research purposes only.

  • FT checkpoints: These are specialized fine-tuned models, each optimized for a different academic benchmark. They come in various resolutions and are also intended for research purposes only.

PaliGemma Model Sizes
The PaliGemma models come in three different resolutions: 224x224, 448x448, and 896x896. They are also available in three different precisions: bfloat16, float16, and float32. Each model repository contains checkpoints for a specific resolution and task, with three revisions corresponding to the available precisions. The main branch of each repository contains float32 checkpoints, while the bfloat16 and float16 revisions contain the respective precision models.

There are separate repositories for models compatible with HuggingFace Transformers and those using the original JAX implementation.
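A back-of-the-envelope sketch of what the resolution and precision choices imply may be useful. It assumes the SigLIP encoder's 14-pixel patches, so the number of image tokens is (resolution / 14)², and it uses the nominal 3B parameter count from the checkpoint names. Real checkpoint sizes differ somewhat, and activations, the KV cache, and framework overhead are ignored. The helper names are my own.

```python
PATCH_SIZE = 14  # assumed SigLIP-So400m/14 patch size
NOMINAL_PARAMS = 3e9  # the "3b" in the checkpoint names; an approximation
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def image_token_count(resolution: int) -> int:
    """Number of patch tokens the image encoder produces for a square image."""
    return (resolution // PATCH_SIZE) ** 2


def weight_memory_gb(precision: str, params: float = NOMINAL_PARAMS) -> float:
    """Approximate memory for the weights alone, in gigabytes (1e9 bytes)."""
    return params * BYTES_PER_PARAM[precision] / 1e9


for res in (224, 448, 896):
    print(f"{res}x{res}: {image_token_count(res)} image tokens")
for precision in BYTES_PER_PARAM:
    print(f"{precision}: ~{weight_memory_gb(precision):.0f} GB of weights")
```

Doubling the resolution quadruples the number of image tokens the decoder must process, while halving the precision halves the weight memory. To download one of the reduced-precision revisions with Transformers, from_pretrained accepts a revision argument, e.g. revision="bfloat16".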

What was the Pretraining like?

PaliGemma is pre-trained on the following mixture of datasets:

Datasets

  • WebLI (Web Language Image):

    A web-scale multilingual image-text dataset sourced from the public web. Various splits of WebLI are used to develop versatile model capabilities such as visual semantic understanding, object localization, visually situated text understanding, and multilingual proficiency.

  • CC3M-35L:

    Curated English image-alt_text pairs from webpages (Sharma et al., 2018). Translated into 34 additional languages using the Google Cloud Translation API.

  • VQ²A-CC3M-35L/VQG-CC3M-35L:
    A subset of VQ2A-CC3M (Changpinyo et al., 2022a), translated into the same 34 languages as CC3M-35L, using the Google Cloud Translation API.

  • OpenImages:

    Detection and object-aware questions and answers (Piergiovanni et al. 2022) generated by handcrafted rules on the OpenImages dataset.

  • WIT:
    Images and texts collected from Wikipedia (Srinivasan et al., 2021).

"Wait, can I fine-tune it?"

Fine-Tuning Methods

1. JAX Fine-Tuning Script:

  • PaliGemma was trained in the big_vision codebase, which has also been used for models like BiT, ViT, LiT, CapPa, and SigLIP.

  • The project configuration folder configs/proj/paligemma/ contains a README.md.

  • Pretrained models can be transferred using configuration files in the transfers/ subfolder.

  • To transfer your own model, fork transfers/forkme.py and follow the instructions in the comments to adapt it to your use case.

  • A Colab notebook, finetune_paligemma.ipynb, provides a simplified fine-tuning process on a free T4 GPU runtime, updating only the weights in the attention layers (170M parameters) and using SGD instead of Adam.

2. Fine-Tuning with Hugging Face Transformers:

  • Fine-tuning PaliGemma is straightforward using the transformers library.

  • Methods such as QLoRA or LoRA fine-tuning can be employed.

  • An example process involves briefly fine-tuning the decoder, followed by switching to QLoRA fine-tuning.

  • Make sure to install the latest version of the transformers library.

3. Fine-Tuning with a Vanilla PyTorch Script:

  • A small and lean PyTorch script to fine-tune the PaliGemma model.

  • Developed by Aritra Roy Gosthipaty in this repository.

  • Fine-tune on any dataset containing image-caption pairs.
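The LoRA and QLoRA options listed for the Transformers route keep the pretrained weight matrix frozen and learn a small low-rank update instead; QLoRA applies the same idea on top of a quantized frozen base model. Here is a minimal NumPy sketch of a single LoRA-wrapped linear layer to show why this is cheap. The class name, rank, scaling, and matrix size are illustrative assumptions; in practice you would use a library such as Hugging Face PEFT rather than code like this.

```python
import numpy as np


class LoRALinear:
    """y = x @ W.T + (alpha / r) * x @ A.T @ B.T, with W frozen and only A, B trainable."""

    def __init__(self, weight, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        out_features, in_features = weight.shape
        self.weight = weight  # frozen pretrained weight, shape (out, in)
        self.A = rng.normal(scale=0.01, size=(r, in_features))  # trainable, (r, in)
        self.B = np.zeros((out_features, r))  # trainable, zero-init so training starts from W
        self.scale = alpha / r

    def __call__(self, x):
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def merged_weight(self):
        """Fold the low-rank update into a single matrix for inference."""
        return self.weight + self.scale * self.B @ self.A

    def num_trainable(self):
        return self.A.size + self.B.size


rng = np.random.default_rng(1)
W = rng.normal(size=(2048, 2048))  # e.g. one square projection matrix (hypothetical size)
layer = LoRALinear(W, r=8)
x = rng.normal(size=(3, 2048))

print(np.allclose(layer(x), x @ W.T))  # True: B starts at zero
print(layer.num_trainable(), W.size)  # 32768 4194304
```

With rank 8 the layer trains under 1% of the parameters of the full matrix, and after training the update can be merged back into W so inference costs nothing extra. PEFT's LoraConfig exposes the same knobs as r, lora_alpha, and target_modules.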

Training and Model Information

  • PaliGemma models have been released in various fine-tuned versions by Google.

  • These models were trained using the big_vision codebase, which has also been used to develop models such as BiT, ViT, LiT, CapPa, and SigLIP.

Model Performance Table on Fine-tuned Checkpoints

| Model name | Dataset/Task | Score on transferred task |
| --- | --- | --- |
| paligemma-3b-ft-vqav2-448 | Visual Question Answering | 85.64 accuracy on VQAv2 |
| paligemma-3b-ft-cococap-448 | COCO Captions | 144.6 CIDEr |
| paligemma-3b-ft-science-qa-448 | Science Question Answering | 95.93 accuracy on ScienceQA (image subset, no CoT) |
| paligemma-3b-ft-refcoco-seg-896 | Understanding references to specific objects in images | 76.94 mean IoU on RefCOCO |
| paligemma-3b-ft-rsvqa-hr-224 | Remote Sensing Visual Question Answering | 92.61 accuracy on test |

Why PaliGemma?

Changing the pretraining strategy and using larger datasets such as LAION could significantly enhance PaliGemma's capabilities as a multimodal model. Pretraining on vast and diverse datasets tends to improve a model's understanding and its ability to generate nuanced, contextually rich outputs. Scaling the architecture, for example by replacing the autoregressive Gemma decoder with a more capable model such as Gemini, and training the SigLIP encoder on higher-quality, finer-grained images could improve performance on tasks that require detailed visual-semantic understanding, precise object localization, and robust multilingual text generation. Together, these changes could make the model more versatile and powerful across a wide range of multimodal applications.
