Fine-Tune Llama 3.2 Vision-Language Model on Custom Datasets

Meta AI recently released Llama 3.2, a powerful multimodal large language model (LLM) that can understand both visual and textual information. While the pre-trained model is impressive out of the box, fine-tuning it on your own dataset can significantly enhance its performance for your particular use case.

In this guide, we'll walk you through the process of fine-tuning Llama 3.2 Vision-Language Model (VLM) on a custom dataset. We'll cover everything from setting up your environment to testing your fine-tuned model.

Key Steps in Fine-Tuning Llama 3.2 VLM:

  1. Define your use case

  2. Set up the development environment

  3. Prepare the dataset

  4. Fine-tune the VLM using TRL, SFTTrainer, and Unsloth

  5. Save the fine-tuned model

  6. Load and test the fine-tuned model

Let's dive into each step in detail.

1. Defining Your Multimodal Use Case

Before fine-tuning, it's crucial to have a clear understanding of the problem you're trying to solve. For VLMs, this typically involves tasks that integrate visual and textual data.

Example Use Case: Imagine you're building an e-commerce platform where sellers upload product images and metadata, and you need a system that automatically generates detailed product descriptions based on this input.


Why Fine-Tune? Though pre-trained VLMs may perform well, they may not fully capture the unique attributes of your dataset or use case. Fine-tuning allows you to adapt the model to your specific needs, such as generating SEO-optimized product descriptions or performing advanced image-based tasks.

Note: Llama 3.2 11B Vision requires at least 24 GB of GPU memory for efficient training or fine-tuning. More memory (32 GB or 40 GB) is preferable, especially if you're working with large batch sizes or complex datasets.
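If you're not sure how much memory your GPU has, a quick PyTorch check (a minimal sketch) can tell you before you commit to a training run:

import torch

# Quick check of the available GPU memory before starting (illustrative)
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}, total memory: {total_gb:.1f} GB")
else:
    print("No CUDA GPU detected; fine-tuning Llama 3.2 11B Vision will not be practical.")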

2. Setting Up the Development Environment

First, you'll need to install the necessary libraries:

pip install "torch==2.4.0" tensorboard pillow torchvision accelerate huggingface_hub
pip install --upgrade \
 "transformers==4.45.1" \
 "datasets==3.0.1" \
 "accelerate==0.34.2" \
 "evaluate==0.4.3" \
 "bitsandbytes==0.44.0" \
 "trl==0.11.1" \
 "peft==0.13.0" \
 "qwen_vl_utils"

Next, log in to Hugging Face to access the model and dataset:

from huggingface_hub import login

login(token="YOUR_HF_TOKEN")

Note: To generate your HF token, go to your Hugging Face account settings (Access Tokens), or refer to our demos: Inference API and Llama 3.2: How to Run Meta’s Multimodal AI in Minutes.

Llama 3.2 Vision is a gated model, so make sure you agree to share your contact information on the model page to gain access before creating your API key.
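If you'd rather not hard-code the token in your script, calling login() without arguments prompts for the token interactively and caches it locally:

from huggingface_hub import login

# Prompts for the token and stores it locally, so scripts don't need to embed it
login()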

3. Preparing the Dataset

Preparing a high-quality dataset is crucial for effective fine-tuning of vision-language models. For this tutorial, we'll use the Amazon Product Descriptions VLM dataset created by philschmid. This dataset is specifically designed for fine-tuning vision-language models on e-commerce product description generation tasks.

Dataset Overview:

  • Source: Amazon Product Descriptions VLM dataset

  • Content: Product images paired with product names, categories, and descriptions

  • Image format: JPEG (accessed via URLs)

  • Text data: Product names, categories, and descriptions

Downloading and Preparing the Dataset:

We'll use the Hugging Face datasets library to load and prepare our dataset. Here's how you can do this:

from datasets import load_dataset

prompt = """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""

system_message = "You are an expert product description writer for Amazon."

def format_data(sample):
    return {"messages": [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt.format(product_name=sample["Product Name"], category=sample["Category"]),
                },{
                    "type": "image",
                    "image": sample["image"],
                }
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": sample["description"]}],
        },
    ],
    }

dataset_id = "philschmid/amazon-product-descriptions-vlm"
dataset = load_dataset(dataset_id, split="train")
dataset = [format_data(sample) for sample in dataset]

This code does the following:

  1. Defines a prompt template for generating product descriptions.

  2. Sets up a system message to guide the model's behavior.

  3. Creates a format_data function to structure each sample in the format expected by the model.

  4. Loads the dataset from Hugging Face and applies the formatting to each sample.

Data Preprocessing:

The format_data function handles the main preprocessing steps (a quick sanity check follows the list):

  1. Structuring the input: It creates a conversation-like structure with system, user, and assistant messages.

  2. Incorporating images: The image URLs are included in the user message content.

  3. Formatting text data: Product names and categories are inserted into the prompt template.
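To sanity-check the result, you can print one formatted sample; the indices below follow the structure produced by format_data above:

sample = dataset[0]
print(sample["messages"][0])                          # system message
print(sample["messages"][1]["content"][0]["text"])    # filled-in user prompt
print(sample["messages"][2]["content"][0]["text"])    # target product description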

Note on Data Augmentation:

While this dataset doesn't require extensive augmentation, for other use cases you might consider techniques such as:

  • Random cropping or resizing of images

  • Text augmentation (e.g., synonym replacement, random insertion/deletion)

  • Generating additional product descriptions using other LLMs

These techniques can help improve model generalization, especially when working with smaller datasets.
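For example, a minimal image-augmentation sketch with torchvision (installed in step 2) could look like the following; the transform parameters are illustrative, not tuned values:

from torchvision import transforms

# Illustrative augmentation pipeline for the PIL product images; values are examples only
augment = transforms.Compose([
    transforms.RandomResizedCrop(size=560, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
])

def augment_sample(sample):
    # Augment the image and leave the text fields untouched
    sample["image"] = augment(sample["image"])
    return sample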

Now that we have our dataset prepared, we're ready to move on to the fine-tuning process.

4. Fine-Tuning VLMs Using TRL, SFTTrainer, and Unsloth

For this fine-tuning process, we'll be using Unsloth, a powerful library that optimizes LLM fine-tuning. Unsloth can significantly speed up the process and reduce memory usage, making it easier to fine-tune large models like Llama 3.2 Vision.

Install Unsloth

First, let's install Unsloth:

pip install unsloth

Unsloth is designed to optimize the fine-tuning of large language models, providing various techniques to reduce memory usage and increase training speed.

Initialize the Llama 3.2 Vision Model with Unsloth

import torch
from transformers import AutoProcessor
from unsloth import FastLanguageModel

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Initialize the model with Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=2048,
    dtype=torch.bfloat16,
    load_in_4bit=True,
)

processor = AutoProcessor.from_pretrained(model_id)

This code initializes the Llama 3.2 Vision model using Unsloth's FastLanguageModel. We're using 4-bit quantization and bfloat16 precision for efficiency, which helps reduce memory usage and potentially speed up training.
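To confirm the quantized model fits comfortably on your GPU, you can check its memory footprint; this assumes the Unsloth-wrapped model exposes the standard transformers API, which it normally does:

# Rough size of the 4-bit model in GPU memory (illustrative check)
print(f"Model footprint: {model.get_memory_footprint() / 1024**3:.1f} GB")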

Set Up LoRA Configuration for Fine-Tuning

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM"
)

Here, we're setting up LoRA (Low-Rank Adaptation), a technique that significantly reduces the number of trainable parameters. We're using the PEFT (Parameter-Efficient Fine-Tuning) library to implement LoRA, which makes fine-tuning more efficient, especially for large models. In the training step below, we apply these same LoRA settings through Unsloth's own get_peft_model helper, so this config mainly documents the hyperparameters we use.
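To see just how few parameters LoRA trains, you can count them once the adapters are attached (a small helper sketch; call it after the Unsloth get_peft_model step in the training section below):

def print_trainable_parameters(model):
    # Compare trainable (LoRA) parameters against the full model size
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")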

Training Configuration

from trl import SFTConfig

args = SFTConfig(
    output_dir="fine-tuned-visionllama-unsloth",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=5,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,
    report_to="tensorboard",
    dataset_kwargs={"skip_prepare_dataset": True},
)

This SFTConfig sets up the training arguments using the TRL (Transformer Reinforcement Learning) library. We're employing techniques like gradient checkpointing and gradient accumulation to manage memory usage effectively during training.
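One detail worth spelling out: with a per-device batch size of 4 and 8 gradient-accumulation steps, each optimizer update effectively sees 32 examples. A quick back-of-the-envelope check using the formatted dataset from step 3:

per_device_batch_size = 4
gradient_accumulation_steps = 8
effective_batch_size = per_device_batch_size * gradient_accumulation_steps  # 32

steps_per_epoch = len(dataset) // effective_batch_size
print(f"Effective batch size: {effective_batch_size}, ~{steps_per_epoch} optimizer steps per epoch")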

Collate Data and Train the Model

from qwen_vl_utils import process_vision_info
from trl import SFTTrainer

def collate_fn(examples):
    # Apply the chat template to every conversation in the batch
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    # Extract the image from each conversation
    image_inputs = [process_vision_info(example["messages"])[0] for example in examples]

    # Tokenize the texts and preprocess the images together
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    # Mask padding and image tokens so they don't contribute to the loss
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

# Attach the LoRA adapters with Unsloth before building the trainer
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collate_fn,
    tokenizer=tokenizer,  # Use the tokenizer from Unsloth
)

trainer.train()

This final part sets up the training process using the SFTTrainer from TRL. The collate_fn prepares batches of data: it applies the chat template, processes the images, and masks padding and image tokens out of the labels so the loss is computed only on real text.

Before creating the trainer, we attach the LoRA adapters to the model with Unsloth's FastLanguageModel.get_peft_model, using the same hyperparameters as the LoraConfig above; because Unsloth injects the adapters itself, we don't also pass peft_config to the trainer. We then initialize the trainer with the adapted model, training arguments, dataset, and data collator, and the trainer.train() call starts the actual fine-tuning process.

This comprehensive setup allows for efficient fine-tuning of the Llama 3.2 Vision model on your custom dataset, leveraging advanced techniques like LoRA and Unsloth optimizations to make the process more manageable on consumer-grade hardware.

5. Saving the Fine-Tuned Model

After fine-tuning your model, it's crucial to save the PEFT (Parameter-Efficient Fine-Tuning) weights. These weights contain the changes made to the base model during fine-tuning, allowing you to reuse them without storing the entire model.

Saving PEFT Weights

To save the PEFT weights, you can use the save_pretrained method from the PEFT library. Here's how you can do it:

# Assuming 'trainer' is your SFTTrainer object from the previous step
peft_model = trainer.model

# Save the PEFT weights
output_dir = "path/to/save/peft_weights"
peft_model.save_pretrained(output_dir)

This code will save the PEFT weights to the specified directory. These weights are much smaller than the full model, making them easier to store and share.

Saving the Tokenizer

It's also a good practice to save the tokenizer along with your model weights:

# Save the tokenizer
tokenizer.save_pretrained(output_dir)

Saving the tokenizer ensures that you can properly tokenize inputs when using the model for inference later.
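For a vision-language model, it's also worth saving the processor with the adapter, so the image preprocessing configuration travels with your weights (optional; the loading example below reloads it from the base model instead):

# Optionally save the image/text processor alongside the adapter weights
processor.save_pretrained(output_dir)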

Pushing to Hugging Face Hub (Optional)

If you want to share your model or use it across different environments, you can push it to the Hugging Face Hub:

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=output_dir,
    repo_id="your-username/your-model-name",
    repo_type="model"
)

Replace "your-username/your-model-name" with your Hugging Face username and desired model name.

6. Loading the Fine-Tuned Model for Inference

Once you've saved your fine-tuned model, you can load it for inference. Here's how you can do that:

Loading the Base Model and PEFT Weights

First, we need to load the base model and then apply our fine-tuned PEFT weights:

import torch
from transformers import MllamaForConditionalGeneration, AutoTokenizer, AutoProcessor
from peft import PeftModel, PeftConfig

# Load the base vision-language model (Mllama is the Llama 3.2 Vision architecture)
base_model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
base_model = MllamaForConditionalGeneration.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)

# Load the PEFT configuration
peft_model_id = "path/to/saved/peft_weights"
config = PeftConfig.from_pretrained(peft_model_id)

# Apply the fine-tuned adapter weights on top of the base model
model = PeftModel.from_pretrained(base_model, peft_model_id)

# Load the tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
processor = AutoProcessor.from_pretrained(base_model_id)

Preparing the Model for Inference

Before using the model for inference, it's a good practice to put it in evaluation mode and move it to the appropriate device (CPU or GPU):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
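If you don't plan to train the adapters further, you can optionally merge them into the base weights with PEFT's merge_and_unload, which removes the small LoRA overhead at inference time:

# Optionally fold the LoRA weights into the base model for slightly faster inference
model = model.merge_and_unload()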

Example Inference

Now, let's use our fine-tuned model to generate a product description based on an image and some basic product information:

from PIL import Image
import requests

# Load an image (replace with your own product image URL)
image_url = "https://encrypted-tbn2.gstatic.com/shopping?q=tbn:ANd9GcTQ_qXzjL2INrn9jZCzv0gOfUzy3Ua-BaCuucKrdBhCnYI5dbxcAhiI8AwqnNux8aiqeJlMJbJ4AbeiM2za5b8Eh5_EMtInlwG_PGHtrBIRGKkfzQSHFCPi"
image = Image.open(requests.get(image_url, stream=True).raw)

# Prepare the input
product_name = "Ergonomic Office Chair"
category = "Furniture"

prompt = f"""Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""

# Build a chat-formatted input so the processor inserts the image placeholder token
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": prompt},
    ]},
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

# Process the text and image together
inputs = processor(text=input_text, images=image, return_tensors="pt").to(device)

# Generate the description
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

# Decode only the newly generated tokens (skip the prompt)
generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print("Generated Product Description:")
print(generated_text)

This code does the following:

  1. Loads an image from a URL (replace with your actual image URL).

  2. Prepares the input prompt with the product name and category.

  3. Applies the chat template (which inserts the image placeholder token) and processes the text and image with the processor.

  4. Generates a description using the fine-tuned model.

  5. Decodes and prints the generated description.

Example Output

The output might look something like this:

Generated Product Description:
Experience ultimate comfort and productivity with our Ergonomic Office Chair.
Designed with your well-being in mind, this chair features adjustable lumbar
support, breathable mesh backrest, and customizable armrests. The sleek,
modern design complements any office space while promoting proper posture
and reducing fatigue during long work hours. Elevate your workspace with
this premium furniture piece, perfect for professionals seeking both style
and functionality in their daily work environment.

This example demonstrates how your fine-tuned model can generate a detailed, SEO-friendly product description based on the provided image, product name, and category. The description highlights key features, benefits, and target audience, showcasing the model's ability to create compelling content for e-commerce applications.

Conclusion

In this comprehensive guide, we've walked through the process of fine-tuning the Llama 3.2 Vision-Language Model for product description generation. We covered several key steps:

  1. Setting up the development environment with necessary libraries.

  2. Preparing the Amazon Product Descriptions VLM dataset.

  3. Fine-tuning the model using advanced techniques like LoRA and Unsloth optimizations.

  4. Saving the fine-tuned model weights and tokenizer.

  5. Loading the model for inference.

By following this guide, you should now have a custom-tuned vision-language model capable of generating SEO-optimized product descriptions based on images and basic product information. This fine-tuned model can be a powerful tool for e-commerce platforms, content creators, and marketers looking to automate and enhance their product description processes.

Remember that fine-tuning is an iterative process. Don't hesitate to experiment with different hyperparameters, dataset sizes, or even model architectures to achieve the best results for your specific use case.

Resources

For those interested in diving deeper into the concepts and tools used in this tutorial, here are some valuable resources:

  1. Llama 3.2: How to Run Meta’s Multimodal AI in Minutes

  2. Inference API: The easiest way to integrate NLP models for inference!

  3. Llama Official Documentation: Learn more about the Llama model family and its capabilities.

  4. Hugging Face Transformers Library: Comprehensive documentation for the Transformers library, which we used extensively in this tutorial.

  5. PEFT (Parameter-Efficient Fine-Tuning) Library: Explore more about efficient fine-tuning techniques like LoRA.

  6. Unsloth GitHub Repository: Dive into the details of Unsloth and its optimization techniques for LLM fine-tuning.

  7. TRL (Transformer Reinforcement Learning) Documentation: Learn more about the SFTTrainer and other tools for training language models.

  8. PyTorch Documentation: For a deeper understanding of the underlying deep learning framework used.

  9. Hugging Face Hub: Explore other models, datasets, and spaces that could be useful for your projects.

  10. Hugging Face Course: A free course that covers many of the concepts used in this tutorial in greater depth.

  11. Colab Notebook

By leveraging these resources, you can continue to expand your knowledge and skills in the exciting field of vision-language models and natural language processing.

Next Steps: Bringing AI into Your Business

Whether you're looking to integrate cutting-edge NLP models or deploy multimodal AI systems, we're here to support your journey. Reach out to us at contact@futuresmart.ai to learn more about how we can help.

Don't forget to check out our case studies at futuresmart.ai/case-studies to see how we've successfully partnered with companies to implement transformative AI solutions.

Let us help you take the next step in your AI journey.

Written by Manish Singh Parihar