Unlocking 70% Faster Response Times Through Token Pooling

Jonathan Adly
8 min read
TLDR
This post examines improvements made to ColiVara, our ColPali-based retrieval API. We focus on hybrid search and hierarchical clustering token pooling. By benchmarking these two approaches, we aim to evaluate their impact on latency and performance.

Background

The conventional approach to handling documents for RAG or data extraction typically involves a multi-stage process: Optical Character Recognition (OCR) to extract text, Layout Recognition to understand the document structure, Figure Captioning to interpret images, Chunking to segment the text, and finally, Embedding to represent each segment in a vector space.

This pipeline is not only complex and computationally demanding but also prone to error propagation. Inaccuracies in any stage, for example, OCR errors or misinterpretation of visual layouts, can significantly degrade the quality of downstream retrieval and generation.

A more streamlined approach, as pioneered in the ColPali paper, leverages the power of vision models. Instead of complex pre-processing, this method directly embeds entire document pages as images, simplifying the retrieval process to a similarity search on these visual embeddings.

It eliminates laborious pre-processing steps and can capture richer contextual information from the visual layout and style of the document. This approach forms the foundation of our work on ColiVara.

This post details our research on using hierarchical clustering token pooling, building upon and extending the core ideas presented in ColPali.

💡
A key contribution of ColiVara lies in its API-first design, which prioritizes developer experience and integration into real-world applications. This architecture, however, introduces practical considerations related to network latency and data storage.

Research Question

In our previous post, we looked into whether we could speed up our inference by using different similarity calculations.

We found that the bottleneck is directly linked to the number of embeddings for each document: the more embeddings there are, the longer the calculation takes. Late-interaction scoring compares every query vector with every document vector, so its cost is O(n·m), where n is the number of query vectors and m is the number of document vectors.

The question we wanted to answer was: Can we maintain the same state-of-the-art performance with either fewer document candidates or a significantly lower embedding count?

The goal was to see if we could improve latency by reducing the number of documents (using hybrid search) or the number of document vectors (using token pooling) without losing any performance.

Consistent with the ColPali paper, our quality metric was the NDCG@5 score on ArxivQA, and we also measured our typical end-to-end API request latency. We tested two approaches:

  • Hybrid Keyword Search with Postgres native search capabilities to reduce the number of candidates for inference

  • Token pooling at indexing time to reduce the total number of embeddings by averaging within similarity clusters

The ColPali paper uses late-interaction style computations to determine the relevancy between a query and documents. Here is a simple GitHub gist explaining how late-interaction style computations work: https://gist.github.com/Jonathan-Adly/cb905cc3958b973d7b3b6f25d9915c39

The key point is that late interaction relies on multi-vector representations: every page is stored as many vectors rather than one. Multi-vector representations introduce significant storage and memory overhead and are computationally demanding to search through.
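For intuition, here is a minimal NumPy sketch of late-interaction (MaxSim) scoring in the spirit of the gist above. It is an illustration, not ColiVara's code: the helper names, the 20-token query, and the random vectors are hypothetical, and both sides are assumed to be L2-normalized. Scoring a single page builds an n × m similarity matrix, which is why the cost tracks the number of page vectors m.

```python
import numpy as np


def normalize(x: np.ndarray) -> np.ndarray:
    """L2-normalize each row so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def maxsim_score(query: np.ndarray, page: np.ndarray) -> float:
    """Late-interaction (MaxSim) score of one page for one query.

    query: (n, d) query token embeddings.
    page:  (m, d) page embeddings.
    """
    # (n, m) similarity matrix: n * m dot products of length d.
    sim = query @ page.T
    # Keep each query token's best-matching page vector, then sum over tokens.
    return float(sim.max(axis=1).sum())


rng = np.random.default_rng(0)
query = normalize(rng.normal(size=(20, 128)))    # hypothetical 20-token query
page = normalize(rng.normal(size=(1038, 128)))   # one page: 1038 vectors of 128 dims
print(maxsim_score(query, page))
print(query.shape[0] * page.shape[0])  # 20760 similarity entries for this one page
```

Because the score only depends on the vectors being compared, shrinking either the number of pages scored or the number of vectors per page directly reduces this work, which is what the two experiments below target.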

Baseline implementation

To get a realistic picture of real-world performance, we set the parameters as close to our production setup at ColiVara as possible. Embeddings were stored in Postgres with the pgvector extension. Everything ran on an AWS r6g.xlarge instance (4-core CPU, 32 GB RAM) and was called from our Python backend hosted on a VPS.

💡
Because requests travel over the network, latency also depends on where the user is located and on their network conditions.

We use ColQwen2 as the base model and colpali-engine (v0.3.4) to generate embeddings. ColQwen2 improves on the paper's original implementation, producing ~25% fewer embeddings with better performance. Our work builds on top of those improvements to enhance them further.

Results:

On ArxivQA, measuring NDCG@5 and end-to-end latency, we got the following:

  • Average NDCG@5 score: 0.88

  • Average latency: 3.58 seconds

The dataset consists of 500 pages. This score matches the top of the ViDoRe leaderboard and is considered state-of-the-art retrieval.

💡
Discounted Cumulative Gain (DCG) is a measure of relevance that considers the position of relevant results in the returned list: results that appear earlier receive higher scores. Normalized Discounted Cumulative Gain (NDCG) divides DCG by the ideal DCG (IDCG) for a given query, giving a score between 0 and 1. In this project, we calculate NDCG@5 to evaluate the top 5 search results for each query.
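As a reference point, a minimal NDCG@k sketch. It assumes linear gains (the relevance grade itself) with the standard log2 rank discount, and builds the ideal ranking from all relevance judgments for the query; the evaluation harness behind the numbers in this post may differ in such details. The reported score is the average of this value over all queries.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int = 5) -> float:
    """Discounted cumulative gain of the first k results."""
    # 0-based position i is discounted by log2(i + 2), so rank 1 is undiscounted.
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances[:k]))


def ndcg_at_k(retrieved: Sequence[float], all_relevant: Sequence[float], k: int = 5) -> float:
    """retrieved: relevance of each returned result, in ranked order.
    all_relevant: relevance grades of every relevant item for the query,
    used to build the ideal (best possible) ranking."""
    ideal = dcg_at_k(sorted(all_relevant, reverse=True), k)
    return dcg_at_k(retrieved, k) / ideal if ideal > 0 else 0.0


# One relevant page for the query, returned at rank 2 of the top 5.
print(round(ndcg_at_k([0, 1, 0, 0, 0], [1]), 3))  # 0.631, i.e. 1 / log2(3)
```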

Hybrid Search

Having run a few large RAG applications before, we were well aware of the power of hybrid search. Implementation details vary, but the main idea is to quickly narrow down your candidate documents using Postgres's built-in search capabilities, then re-rank them with more advanced semantic search techniques.

In our implementation, we used gemini-flash-8B to create captions and keywords for each document during indexing. At query time, we used the same LLM to convert the query into keywords, then ran standard Postgres search, backed by a GIN index, to retrieve the top 25 documents.
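The two-stage flow can be sketched as follows. This is a simplified, in-memory stand-in rather than ColiVara's implementation: keyword-set overlap substitutes for Postgres full-text search with a GIN index, hand-written keyword sets stand in for the LLM-generated captions and keywords, and all names (`keyword_prefilter`, `hybrid_search`) are hypothetical. `top_n=25` mirrors the candidate count described above.

```python
from typing import Dict, List, Set, Tuple

import numpy as np


def unit(x: np.ndarray) -> np.ndarray:
    """L2-normalize rows."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def maxsim(query: np.ndarray, page: np.ndarray) -> float:
    """Late-interaction score: best page vector per query token, summed."""
    return float((query @ page.T).max(axis=1).sum())


def keyword_prefilter(query_keywords: Set[str], page_keywords: Dict[str, Set[str]],
                      top_n: int = 25) -> List[str]:
    """Stand-in for Postgres full-text search: rank pages by keyword overlap."""
    overlap = {pid: len(query_keywords & kws) for pid, kws in page_keywords.items()}
    matches = [pid for pid, count in overlap.items() if count > 0]
    return sorted(matches, key=lambda pid: overlap[pid], reverse=True)[:top_n]


def hybrid_search(query_emb: np.ndarray, query_keywords: Set[str],
                  page_keywords: Dict[str, Set[str]], page_embs: Dict[str, np.ndarray],
                  top_n: int = 25, k: int = 5) -> List[Tuple[str, float]]:
    # Stage 1: a cheap keyword filter narrows the candidate set.
    candidates = keyword_prefilter(query_keywords, page_keywords, top_n)
    # Stage 2: expensive MaxSim scoring runs only on the shortlisted pages.
    scored = [(pid, maxsim(query_emb, page_embs[pid])) for pid in candidates]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]


rng = np.random.default_rng(0)
page_embs = {f"page-{i}": unit(rng.normal(size=(50, 16))) for i in range(4)}
page_keywords = {
    "page-0": {"attention", "transformer"},
    "page-1": {"graph"},
    "page-2": {"transformer"},
    "page-3": {"protein"},
}
query_emb = unit(rng.normal(size=(8, 16)))
print(hybrid_search(query_emb, {"transformer"}, page_keywords, page_embs))  # only page-0 and page-2
```

One property of this design: a page whose keywords miss the query never reaches the MaxSim stage, so re-ranking cannot recover it. The shortlist size and keyword quality therefore cap the final retrieval quality.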

💡
There are many tunable parameters in a hybrid search implementation. Because our goal was to improve latency, we deliberately chose the fastest reasonable multimodal LLM at query time, the lowest reasonable candidate count, and out-of-the-box Postgres settings.

Results:

On ArxivQA, measuring NDCG@5 and end-to-end latency, we got the following:

  • Latency: 2.65 seconds

  • NDCG@5 score: 0.68

The good news is that we reduced our latency by about 25%. However, retrieval quality dropped significantly, from an NDCG@5 of 0.88 to 0.68. We believe that tuning the parameters of the hybrid search implementation could recover some of that quality, but only by trading away latency. For instance, using a larger model like Qwen2-VL 72B for keyword extraction might improve quality, but it would also be slower. Similarly, indexing the full text of the documents instead of just captions and keywords might help, but it would also slow things down.

Hierarchical Clustering Token Pooling

The next optimization we wanted to test is hierarchical clustering token pooling. Despite the complex name, it's actually quite simple: you group similar embeddings into clusters and average each cluster into a single vector.

You start with a desired compression level, called the pooling factor. At a pooling factor of 1, you keep your embeddings as they are. With a factor of 2, you cut the number of embeddings in half. At a pooling factor of 3, you keep only a third of your original embeddings.
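In numbers, the count of vectors kept per page is the original count divided by the pooling factor, rounded down. A tiny sketch using the same rule as Step 3 of the `pool_embeddings` function in this post, with the 1038-vector page from its docstring as the example (the helper name `vectors_kept` is hypothetical):

```python
def vectors_kept(num_vectors: int, pool_factor: int) -> int:
    """Number of pooled vectors (clusters) kept for one page."""
    # Floor division, but always keep at least one vector.
    return max(num_vectors // pool_factor, 1)


for pool_factor in (1, 2, 3):
    print(pool_factor, vectors_kept(1038, pool_factor))
# 1 1038
# 2 519
# 3 346
```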

We got the idea from this excellent post by Answer.AI, which tried this approach with ColBERT and found success. The optimal point seemed to be a pooling factor of 3, so we chose that.

💡
The most common implementation of pooling is to simply average everything into a single vector. Like Answer.AI, we are very skeptical that this is a good approach. Not all tokens are created equal, and single-vector pooling averages the important tokens together with the unimportant ones into one representation.
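A toy illustration of that concern, with made-up 3-dimensional vectors rather than real ColQwen2 embeddings: one distinctive token is averaged away by single-vector pooling but survives cluster-based pooling. For brevity, the two clusters are assigned by hand here instead of by hierarchical clustering, and MaxSim is used for scoring.

```python
import numpy as np


def unit(x: np.ndarray) -> np.ndarray:
    """L2-normalize rows."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def maxsim(query: np.ndarray, page: np.ndarray) -> float:
    """Late-interaction score: best page vector per query token, summed."""
    return float((query @ page.T).max(axis=1).sum())


# Page tokens: three near-duplicate "boilerplate" tokens and one distinctive token.
page = unit(np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [1.0, 0.1, 0.0],
    [0.0, 0.0, 1.0],  # the important token
]))
query = unit(np.array([[0.0, 0.0, 1.0]]))  # a query token aimed at the important one

single_vector = page.mean(axis=0, keepdims=True)         # everything pooled into one vector
clustered = np.vstack([page[:3].mean(axis=0), page[3]])  # two clusters, assigned by hand

print(round(maxsim(query, page), 3))           # 1.0  (unpooled)
print(round(maxsim(query, clustered), 3))      # 1.0  (distinctive token kept)
print(round(maxsim(query, single_vector), 3))  # 0.25 (diluted by the other tokens)
```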

The implementation is roughly 10 lines of code: a standalone function that pools your embeddings at indexing time, or even after the fact.

from typing import List

import numpy as np
import torch
from scipy.cluster.hierarchy import fcluster, linkage


def pool_embeddings(embeddings: torch.Tensor, pool_factor: int = 3) -> List[List[float]]:
    """
    Reduces number of embeddings by clustering similar ones together.

    Args:
        embeddings: Single image embeddings, e.g. of shape (1038, 128), assumed
                   L2-normalized so that dot products are cosine similarities.
                   Example with 4 vectors, 3 dimensions for simplicity:
                   [[1,0,1],
                    [1,0,1],
                    [0,1,0],
                    [0,1,0]]
        pool_factor: Keep roughly 1 / pool_factor of the original vectors.

    Returns:
        One averaged vector per cluster, as plain Python lists.
    """
    # Nothing to pool for a single vector or a pool factor of 1.
    if pool_factor <= 1 or embeddings.shape[0] <= 1:
        return embeddings.tolist()

    # Step 1: Calculate similarity between all vectors
    # For our example above (values shown as if normalized), this creates a 4x4 similarity matrix:
    # [[1.0  1.0  0.0  0.0],    # Token 1 compared to all tokens (same, same, different, different)
    #  [1.0  1.0  0.0  0.0],    # Token 2 compared to all tokens
    #  [0.0  0.0  1.0  1.0],    # Token 3 compared to all tokens
    #  [0.0  0.0  1.0  1.0]]    # Token 4 compared to all tokens
    # High values (1.0) mean tokens are very similar
    similarities = torch.mm(embeddings, embeddings.t())

    # Step 2: Convert to distances (1 - similarity)
    # For our example:
    # [[0.0  0.0  1.0  1.0],    # Now low values mean similar
    #  [0.0  0.0  1.0  1.0],    # 0.0 = identical
    #  [1.0  1.0  0.0  0.0],    # 1.0 = completely different
    #  [1.0  1.0  0.0  0.0]]
    # .float() is needed because NumPy has no bfloat16 dtype.
    distances = 1 - similarities.float().cpu().numpy()

    # Step 3: Calculate target number of clusters
    # For our example with pool_factor=2:
    # 4 tokens → 2 clusters
    target_clusters = max(embeddings.shape[0] // pool_factor, 1)

    # Step 4: Perform hierarchical clustering
    # Each row of `distances` (one token's distance profile to every token) is
    # treated as an observation vector for Ward linkage. SciPy may emit a
    # ClusterWarning because the matrix resembles an uncondensed distance matrix.
    # For our example, cluster_labels would be:
    # [1, 1, 2, 2]  # Tokens 1&2 in cluster 1, Tokens 3&4 in cluster 2
    clusters = linkage(distances, method="ward")
    cluster_labels = fcluster(clusters, t=target_clusters, criterion="maxclust")

    # Step 5: Average embeddings within each cluster
    # For our example:
    # Cluster 1 average = [1,0,1] and [1,0,1] → [1,0,1]
    # Cluster 2 average = [0,1,0] and [0,1,0] → [0,1,0]
    # Final result: [[1,0,1], [0,1,0]]
    # Iterate over the labels actually produced: "maxclust" may return fewer
    # than target_clusters clusters, and an empty cluster would average to NaN.
    pooled = []
    for cluster_id in np.unique(cluster_labels):
        mask = torch.from_numpy(cluster_labels == cluster_id).to(embeddings.device)
        cluster_embeddings = embeddings[mask]
        cluster_mean = cluster_embeddings.mean(dim=0)
        pooled.append(cluster_mean.tolist())

    return pooled

This method of pooling is elegant: nothing else changes except the number of embeddings being generated.

Results:

On ArxivQA, measuring NDCG@5 and end-to-end latency, we got the following:

  • Latency: 2.13 seconds

  • NDCG@5 score: 0.87

This was as close to a free lunch as you will ever get. The storage cost went down by 66%, latency improved by about 40%, and retrieval quality barely moved (NDCG@5 went from 0.88 to 0.87). It was magnificent and beautiful, the kind of optimization that usually exists only in theory.
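The headline percentages follow from the reported averages; the short calculation below reproduces them, with the hybrid search run included for comparison. The storage line is a back-of-the-envelope estimate that assumes 128-dimensional float32 vectors and the 1038-vector page from the docstring above, and it ignores row and index overhead.

```python
def pct_change(before: float, after: float) -> float:
    """Relative change in percent (negative means a reduction)."""
    return 100 * (after - before) / before


# Reported averages on ArxivQA: (latency in seconds, NDCG@5).
baseline = (3.58, 0.88)
runs = {"hybrid": (2.65, 0.68), "pooled": (2.13, 0.87)}

for name, (latency, ndcg) in runs.items():
    print(f"{name}: latency {pct_change(baseline[0], latency):+.1f}%, "
          f"NDCG@5 {pct_change(baseline[1], ndcg):+.1f}%")
# hybrid: latency -26.0%, NDCG@5 -22.7%
# pooled: latency -40.5%, NDCG@5 -1.1%

# Back-of-the-envelope storage per page, assuming 128-dim float32 vectors.
bytes_before = 1038 * 128 * 4  # 531,456 bytes
bytes_after = 346 * 128 * 4    # 177,152 bytes with pool_factor=3
print(f"storage {pct_change(bytes_before, bytes_after):+.1f}%")  # storage -66.7%
```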

We decided to implement it in production and ran the entire evaluation suite. The final results were even more impressive, with up to 70% lower latency on larger document collections and minimal loss in quality. You can see the full results here.

Conclusion

We experimented with hybrid search and token pooling to improve our API's latency without sacrificing retrieval quality. Hybrid search improved latency by narrowing down candidates with keywords, but it reduced retrieval quality. Hierarchical clustering token pooling, on the other hand, greatly improved storage efficiency and latency with minimal quality loss.
