Unlocking 70% Faster Response Times Through Token Pooling


Background
This post examines improvements made to ColiVara, our ColPali-based retrieval API. We focus on hybrid search and hierarchical clustering token pooling. By benchmarking these two approaches, we aim to evaluate their impact on latency and performance.
The conventional approach to handling documents for RAG or data extraction typically involves a multi-stage process: Optical Character Recognition (OCR) to extract text, Layout Recognition to understand the document structure, Figure Captioning to interpret images, Chunking to segment the text, and finally, Embedding to represent each segment in a vector space.
This pipeline is not only complex and computationally demanding but also prone to error propagation. Inaccuracies in any stage, for example, OCR errors or misinterpretation of visual layouts, can significantly degrade the quality of downstream retrieval and generation.
A more streamlined approach, as pioneered in the ColPali paper, leverages the power of vision models. Instead of complex pre-processing, this method directly embeds entire document pages as images, simplifying the retrieval process to a similarity search on these visual embeddings.
This eliminates the need for laborious pre-processing steps and potentially captures richer contextual information from the visual layout and style of the document. It forms the foundation of our work in ColiVara.
This post details our research on using hierarchical clustering token pooling, building upon and extending the core ideas presented in ColPali.
Research Question
In our previous post, we looked into whether we could speed up our inference by using different similarity calculations.
We found that the bottleneck is directly linked to the number of embeddings for each document: the more embeddings there are, the longer the calculation takes. The scoring cost grows as O(n * m), where n is the number of query vectors and m is the number of document vectors.
The question we wanted to answer was: Can we maintain the same state-of-the-art performance with either fewer document candidates or a significantly lower embedding count?
The goal was to see if we could improve latency by reducing the number of documents (using hybrid search) or the number of document vectors (using token pooling) without losing any performance.
Our metric, consistent with the ColPali paper, was the NDCG@5 score on ArxivQA, together with our typical API request latency, using:
Hybrid Keyword Search with Postgres native search capabilities to reduce the number of candidates for inference
Token pooling at indexing time to reduce the total number of embeddings by averaging within similarity clusters
The ColPali paper uses late-interaction style computations to determine the relevancy between a query and documents. Here is a simple GitHub gist explaining how late-interaction style computations work: https://gist.github.com/Jonathan-Adly/cb905cc3958b973d7b3b6f25d9915c39
The key point is that late interaction relies on multi-vector representations. Multi-vectors introduce significant storage and memory overhead and are computationally demanding to search through.
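For a concrete picture, here is a minimal sketch (ours, not ColiVara's production code) of the MaxSim scoring that late interaction performs. The n-by-m similarity matrix it builds is exactly the cost that grows with the number of document vectors:

import torch

def late_interaction_score(query: torch.Tensor, doc: torch.Tensor) -> float:
    """
    MaxSim late-interaction score between one query and one document page.

    query: (n, d) query token embeddings
    doc:   (m, d) document patch embeddings
    Both are assumed L2-normalized, so dot products are cosine similarities.
    """
    # (n, m) matrix: every query token against every document vector.
    # Building this matrix is the O(n * m) cost discussed above.
    similarities = query @ doc.t()
    # For each query token, keep its single best-matching document vector,
    # then sum those maxima into one relevance score.
    return similarities.max(dim=1).values.sum().item()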
Baseline implementation
To get a realistic picture of real-life performance, we set the parameters as close to our production setup at ColiVara as possible. Embeddings were stored in Postgres with the pgvector extension. Everything ran on an AWS r6g.xlarge (4-core CPU, 32 GB RAM) and was called from our Python backend code hosted on a VPS.
We use ColQwen2 as the base model and colpali-engine (v0.3.4) to generate embeddings. It improves upon the paper's base implementation with ~25% fewer embeddings and better performance. Our work builds on top of those improvements to enhance them further.
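For reference, generating page embeddings with colpali-engine looks roughly like the sketch below. The checkpoint name is an assumption based on the public vidore release, and exact call signatures may differ slightly across versions:

import torch
from PIL import Image
from colpali_engine.models import ColQwen2, ColQwen2Processor

# Checkpoint name is an assumption based on the public vidore release.
model_name = "vidore/colqwen2-v0.1"
model = ColQwen2.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)

pages = [Image.open("page_1.png")]  # one image per document page
batch = processor.process_images(pages).to(model.device)
with torch.no_grad():
    embeddings = model(**batch)  # (batch, n_patches, 128) multi-vector output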
Results:
On ArxivQA, using NDCG@5 and end-to-end latency, we had the following:
Average NDCG@5 score: 0.88
Average latency: 3.58 seconds
The dataset is composed of 500 pages. This score matches the leader of the Vidore leaderboard and is considered state-of-the-art retrieval.
Hybrid search
Having run a few large RAG applications before, we were well aware of the power of hybrid search. The details of the implementation can vary, but the main idea is to quickly narrow down your candidate documents using Postgres search capabilities, then re-rank them with more advanced semantic search techniques.
In our implementation, we used gemini-flash-8B to create captions and add keywords to each document during indexing. At query time, we used the same LLM to convert the query into keywords, then employed standard Postgres search with a GIN index to retrieve the top 25 documents.
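A minimal sketch of the candidate-narrowing stage, assuming a hypothetical documents table with a tsvector column caption_tsv (built from the LLM-generated captions and keywords) behind a GIN index; ColiVara's actual schema differs:

import psycopg

# Assumed setup, for illustration only:
#   documents(id bigint, caption_tsv tsvector)
#   CREATE INDEX ON documents USING GIN (caption_tsv);
CANDIDATE_SQL = """
    SELECT id
    FROM documents
    WHERE caption_tsv @@ plainto_tsquery('english', %(keywords)s)
    ORDER BY ts_rank(caption_tsv, plainto_tsquery('english', %(keywords)s)) DESC
    LIMIT 25;
"""

def keyword_candidates(conn: psycopg.Connection, keywords: str) -> list[int]:
    """Stage 1 of hybrid search: cheap keyword filtering with Postgres
    full-text search. The 25 candidates returned here are then re-ranked
    with the (much more expensive) late-interaction scoring."""
    with conn.cursor() as cur:
        cur.execute(CANDIDATE_SQL, {"keywords": keywords})
        return [row[0] for row in cur.fetchall()]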
Results:
On ArxivQA, using NDCG@5 and end-to-end latency, we had the following:
Average NDCG@5 score: 0.68
Average latency: 2.65 seconds
The good news is that we reduced our latency by about 25%. However, performance decreased significantly. We believe that tuning the hybrid search parameters could close some of the gap, but every adjustment trades latency for performance. For instance, using a larger model like Qwen2-VL 72B for keyword extraction might improve performance, but it would also be slower. Similarly, indexing the full text of the documents instead of just captions and keywords might score better, but it would also slow things down.
Hierarchical Clustering Token Pooling
The next optimization we wanted to test is hierarchical clustering token pooling. Despite the complex name, it's quite simple: you take your embeddings and average them within clusters of similar vectors.
You start with a desired compression level, called the pooling factor. At a pooling factor of 1, you keep your embeddings as they are. With a factor of 2, you reduce your embeddings by half. At a pooling factor of 3, you keep only a third of your original embeddings.
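As a quick sanity check on the arithmetic (the 1,038 figure is the per-page embedding count used in the function below):

# Vectors kept per page at each pooling factor
n_vectors = 1038
for pool_factor in (1, 2, 3):
    kept = max(n_vectors // pool_factor, 1)
    print(f"pool_factor={pool_factor}: {kept} vectors kept")
# pool_factor=1: 1038 vectors kept
# pool_factor=2: 519 vectors kept
# pool_factor=3: 346 vectors kept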
We got the idea from this excellent post by Answer.AI, which tried this approach with ColBERT and found success. The optimal point seemed to be a pooling factor of 3, so we chose that.
The implementation is ~10 lines of code: a standalone function that pools your embeddings at indexing time, or even after the fact.
from typing import List

import torch
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def pool_embeddings(embeddings: torch.Tensor, pool_factor: int = 3) -> List[List[float]]:
    """
    Reduces the number of embeddings by clustering similar ones together.

    Args:
        embeddings: Single-image embeddings of shape (1038, 128), assumed to be
            L2-normalized so that dot products are cosine similarities.

    Example with 4 vectors, 3 dimensions for simplicity:
        [[1,0,1],
         [1,0,1],
         [0,1,0],
         [0,1,0]]
    """
    # Step 1: Calculate similarity between all vectors.
    # For our example above, this creates a 4x4 similarity matrix:
    # [[1.0 1.0 0.0 0.0],   # Token 1 compared to all tokens (same, same, different, different)
    #  [1.0 1.0 0.0 0.0],   # Token 2 compared to all tokens
    #  [0.0 0.0 1.0 1.0],   # Token 3 compared to all tokens
    #  [0.0 0.0 1.0 1.0]]   # Token 4 compared to all tokens
    # High values (1.0) mean tokens are very similar.
    similarities = torch.mm(embeddings, embeddings.t())

    # Step 2: Convert to distances (1 - similarity).
    # (.float() guards against half-precision dtypes NumPy cannot convert.)
    # For our example:
    # [[0.0 0.0 1.0 1.0],   # Now low values mean similar
    #  [0.0 0.0 1.0 1.0],   # 0.0 = identical
    #  [1.0 1.0 0.0 0.0],   # 1.0 = completely different
    #  [1.0 1.0 0.0 0.0]]
    distances = 1 - similarities.float().cpu().numpy()

    # Step 3: Calculate the target number of clusters.
    # For our example with pool_factor=2: 4 tokens -> 2 clusters.
    target_clusters = max(embeddings.shape[0] // pool_factor, 1)

    # Step 4: Perform hierarchical clustering to group similar tokens.
    # squareform condenses the square distance matrix into the 1-D form that
    # scipy's linkage expects (checks=False tolerates floating-point noise).
    # For our example, cluster_labels would be:
    # [1, 1, 2, 2]   # Tokens 1&2 in cluster 1, tokens 3&4 in cluster 2
    condensed = squareform(distances, checks=False)
    clusters = linkage(condensed, method="ward")
    cluster_labels = fcluster(clusters, t=target_clusters, criterion="maxclust")

    # Step 5: Average the embeddings within each cluster.
    # For our example:
    # Cluster 1 average = [1,0,1] and [1,0,1] -> [1,0,1]
    # Cluster 2 average = [0,1,0] and [0,1,0] -> [0,1,0]
    # Final result: [[1,0,1], [0,1,0]]
    pooled = []
    for cluster_id in range(1, target_clusters + 1):
        mask = torch.from_numpy(cluster_labels == cluster_id)
        if not mask.any():  # maxclust can occasionally yield fewer clusters
            continue
        cluster_mean = embeddings[mask].mean(dim=0)
        pooled.append(cluster_mean.tolist())
    return pooled
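For illustration, here is how it would be called on a page's embedding matrix at indexing time (random, normalized vectors stand in for real ones):

page_embeddings = torch.randn(1038, 128)
page_embeddings = torch.nn.functional.normalize(page_embeddings, dim=1)
pooled = pool_embeddings(page_embeddings, pool_factor=3)
print(len(pooled))  # ~346 vectors instead of 1038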
This method of pooling is elegant in its simplicity: nothing else in the pipeline changes except the number of embeddings being stored and searched.
Results:
On ArxivQA, using NDCG@5 and end-to-end latency, we had the following:
Average NDCG@5 score: 0.87
Average latency: 2.13 seconds
This was as close to a free lunch as you will ever get. The storage cost went down by 66%, latency improved by about 40%, and there was very little performance loss. It was magnificent and beautiful, the kind of optimization that is usually only theoretical.
We decided to implement it in production and ran the entire evaluation suite. The final results were even more impressive, with up to 70% better latency on larger document collections and very minimal loss. You can see the full results here.
Conclusion
We tested and experimented with hybrid search and token pooling to improve our API's latency without sacrificing performance. Hybrid search improved latency by narrowing down results with keywords, but it reduced performance. On the other hand, hierarchical clustering token pooling greatly improved storage efficiency and latency with minimal performance loss.