Optimize RAG with Contextual Retrieval in PostgreSQL

Keming
6 min read

In the standard approach to Retrieval Augmented Generation (RAG), large documents are broken down into smaller, more manageable chunks. This is primarily done to improve retrieval efficiency: with sentence embeddings and keyword BM25 scoring, systems can quickly find and pull the relevant pieces of information. While effective in many scenarios, this method introduces a significant challenge: individual chunks can lose their surrounding context.

Consider a collection of historical documents. A user asks about a specific event, and a retrieved chunk says: "The treaty was signed, bringing an end to hostilities."

Without the surrounding text, this chunk lacks vital context such as which treaty was signed, who the involved parties were, and when or where this significant event took place. Relying solely on this decontextualized snippet would provide an incomplete and potentially inaccurate picture to the LLM, hindering its ability to generate a comprehensive and accurate answer about the historical event.

The example highlights that while chunking aids retrieval efficiency, it can inadvertently strip away the necessary context that gives meaning and relevance to the individual pieces of information. This loss of context can significantly impact the quality and accuracy of the responses generated by LLMs in a RAG system.

To tackle this problem, Anthropic explores a method called Contextual Retrieval.

In essence, each chunk is augmented with a brief context derived from the whole document before it is embedded and indexed.

This article walks through the steps to improve retrieval accuracy with the help of contextual information. We use the same prompt, dataset, and metrics described in the Anthropic tutorial. Note that we use a different embedding model and LLM in this experiment, so the results will differ. The absolute numbers should not be read as a measure of the models' general accuracy, since the dataset consists of code rather than natural language.

We mainly use Gemini models to generate the chunk context and the text embeddings, but other models also work. Embedding quality heavily affects the vector retrieval metrics, so we recommend checking an embedding leaderboard such as MTEB to find the model that suits your use case. For keyword tokenization, we use the default BERT tokenizer; other tokenizers can also be used.

Reranking is not the main focus of this article; if you want to know more about reranking methods, check our articles on cross-encoder reranking and token-level late-interaction reranking. Since we support a multi-vector MaxSim index, token-level late-interaction reranking can be done more efficiently.

Prepare the environment

We provide both vector search and keyword search features in our VectorChord-suite docker image.

docker run --rm -d --name vdb -e POSTGRES_PASSWORD=postgres -p 5432:5432 ghcr.io/tensorchord/vchord-suite:pg17-20250414

We will use the vechord Python library to reduce the boilerplate code and simplify the process.

pip install 'vechord[gemini]'

Let’s get started with the basic RAG process.

from typing import Annotated, Optional

from vechord.spec import (
    ForeignKey,
    Keyword,
    PrimaryKeyAutoIncrease,
    Table,
    UniqueIndex,
    Vector,
)

# 768-dimensional dense vectors; match this to your embedding model's output size
DenseVector = Vector[768]

class Document(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    uuid: Annotated[str, UniqueIndex()]
    content: str

class Chunk(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
    index: int
    content: str
    vector: DenseVector
    keyword: Keyword

class Query(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    content: str
    answer: str
    doc_uuids: list[str]
    chunk_index: list[int]
    vector: DenseVector

The Chunk table carries both a dense vector column for vector search and a keyword column for keyword (BM25) search.

To load the datasets:

import json

from vechord.embedding import GeminiDenseEmbedding
from vechord.registry import VechordRegistry

emb = GeminiDenseEmbedding()
# 172.17.0.1 is the default Docker bridge address; use localhost if PostgreSQL runs natively
vr = VechordRegistry("anthropic", "postgresql://postgres:postgres@172.17.0.1:5432/")

vr.register([Document, Chunk, Query])

def load_data(filepath: str):
    with open(filepath, "r", encoding="utf-8") as f:
        docs = json.load(f)
        for doc in docs:
            vr.insert(
                Document(
                    uuid=doc["original_uuid"],
                    content=doc["content"],
                )
            )
            for chunk in doc["chunks"]:
                vr.insert(
                    Chunk(
                        doc_uuid=doc["original_uuid"],
                        index=chunk["original_index"],
                        content=chunk["content"],
                        vector=emb.vectorize_chunk(chunk["content"]),
                        keyword=Keyword(chunk["content"]),
                    )
                )

def load_query(filepath: str):
    queries = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            query = json.loads(line)
            queries.append(
                Query(
                    content=query["query"],
                    answer=query["answer"],
                    doc_uuids=[x[0] for x in query["golden_chunk_uuids"]],
                    chunk_index=[x[1] for x in query["golden_chunk_uuids"]],
                    vector=emb.vectorize_query(query["query"]),
                )
            )
    vr.copy_bulk(queries)
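
With the two loaders in place, ingestion is just a matter of pointing them at the dataset files. The paths below follow the layout of the Anthropic contextual-retrieval tutorial and are assumptions; adjust them to wherever you stored the data.

# hypothetical paths, following the Anthropic tutorial's data layout
load_data("data/codebase_chunks.json")    # documents and their pre-split chunks
load_query("data/evaluation_set.jsonl")   # queries with golden chunk references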

Evaluation

Now we have everything for the basic RAG process. Let’s define the evaluation metric Pass@k: for each query, the fraction of its ground-truth chunks that appear in the retrieved top-k results, averaged over all queries.

from time import perf_counter

# `search_func` is one of the retrieval strategies defined in the next section
def evaluate(topk=5, search_func=vector_search):
    print(f"TopK={topk}, search by: {search_func.__name__}")
    queries: list[Query] = vr.select_by(Query.partial_init())
    total_score = 0
    start = perf_counter()
    for query in queries:
        chunks: list[Chunk] = search_func(query, topk)
        count = 0
        for doc_uuid, chunk_index in zip(
            query.doc_uuids, query.chunk_index, strict=True
        ):
            for chunk in chunks:
                if chunk.doc_uuid == doc_uuid and chunk.index == chunk_index:
                    count += 1
                    break
        score = count / len(query.doc_uuids)
        total_score += score

    print(
        f"Pass@{topk}: {total_score / len(queries):.4f}, total queries: {len(queries)}, QPS: {len(queries) / (perf_counter() - start):.3f}"
    )

We can try different retrieval strategies: vector search, keyword search, and hybrid search with either rank fusion or reranking.
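
Reciprocal Rank Fusion (RRF) merges several ranked lists by scoring each chunk with the sum of 1/(k + rank) over every list it appears in, where k is a smoothing constant (60 is a conventional default, not necessarily what the library uses). Here is a minimal standalone sketch of the idea; the vechord ReciprocalRankFusion class handles this for you.

def rrf_fuse(rankings: list[list[int]], k: int = 60) -> list[int]:
    """Merge several ranked lists of chunk ids by reciprocal-rank scores."""
    scores: dict[int, float] = {}
    for ranking in rankings:
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # higher fused score first
    return sorted(scores, key=scores.get, reverse=True)

print(rrf_fuse([[1, 2, 3], [1, 3, 4]]))  # [1, 3, 2, 4]: chunk 1 ranks first in both lists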

Those strategies can be defined as:

from vechord.rerank import CohereReranker, ReciprocalRankFusion

def vector_search(query: Query, topk: int) -> list[Chunk]:
    return vr.search_by_vector(Chunk, query.vector, topk=topk)

def keyword_search(query: Query, topk: int) -> list[Chunk]:
    return vr.search_by_keyword(Chunk, query.content, topk=topk)

def hybrid_search_fuse(query: Query, topk: int) -> list[Chunk]:
    rrf = ReciprocalRankFusion()
    return rrf.fuse([vector_search(query, topk), keyword_search(query, topk)])[:topk]

def hybrid_search_rerank(query: Query, topk: int, boost=3) -> list[Chunk]:
    ranker = CohereReranker()
    vecs = vector_search(query, topk * boost)
    keys = keyword_search(query, topk * boost)
    chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
    indices = ranker.rerank(query.content, [chunk.content for chunk in chunks])
    return [chunks[i] for i in indices[:topk]]
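
With the strategies defined, running the evaluation is a simple driver loop over the functions above:

for search in (vector_search, keyword_search, hybrid_search_fuse, hybrid_search_rerank):
    for topk in (5, 10):
        evaluate(topk=topk, search_func=search)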

Contextual retrieval

The LLM generates the contextual information for each chunk with a prompt like:

prompt = (
    "<document>\n{whole_document}\n</document>"
    "Here is the chunk we want to situate within the whole document \n"
    "<chunk>\n{chunk}\n</chunk>\n"
    "Please give a short succinct context to situate this chunk within "
    "the overall document for the purposes of improving search retrieval "
    "of the chunk. Answer only with the succinct context and nothing else."
)
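
For illustration, here is roughly how the template could be filled and sent to a model using the google-genai SDK directly; the model name and placeholders are assumptions, and in practice the vechord GeminiAugmenter shown below wraps this step for you.

from google import genai  # assumes the `google-genai` package and an API key in the environment

document_text = "..."  # placeholder: the whole document
chunk_text = "..."     # placeholder: one chunk of that document

client = genai.Client()
context = client.models.generate_content(
    model="gemini-2.0-flash",  # hypothetical model choice
    contents=prompt.format(whole_document=document_text, chunk=chunk_text),
).text
# the chunk plus its generated context is what gets embedded and indexed
contextual_content = f"{chunk_text}\n\n{context}"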

The prompt above is a general one designed for retrieval tasks. You may find that a more specialized prompt offers better performance for your specific use case.

This feature is already included in the vechord library.

GeminiAugmenter also supports prompt caching, similar to Claude, but it requires the cached content to be at least 32,768 tokens, which the documents in this dataset do not reach.

from vechord.augment import GeminiAugmenter

class ContextualChunk(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
    index: int
    content: str
    context: str
    vector: DenseVector
    keyword: Keyword

def load_contextual_chunks(filepath: str):
    augmenter = GeminiAugmenter()

    with open(filepath, "r", encoding="utf-8") as f:
        docs = json.load(f)
        for doc in docs:
            augmenter.reset(doc["content"])
            chunks = doc["chunks"]
            augments = augmenter.augment_context([chunk["content"] for chunk in chunks])
            if len(augments) != len(chunks):
                print(f"augments length not match for uuid: {doc['original_uuid']}, {len(augments)} != {len(chunks)}")
            for chunk, context in zip(chunks, augments, strict=False):
                contextual_content = f"{chunk['content']}\n\n{context}"
                vr.insert(
                    ContextualChunk(
                        doc_uuid=doc["original_uuid"],
                        index=chunk["original_index"],
                        content=chunk["content"],
                        context=context,
                        vector=emb.vectorize_chunk(contextual_content),
                        keyword=Keyword(contextual_content),
                    )
                )
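
Like the other tables, ContextualChunk has to be registered before it can be used. A sketch of the remaining ingestion step; the extra register call is an assumption based on the earlier usage, and the path is a placeholder.

# assumption: the new table is registered the same way as the earlier ones
vr.register([ContextualChunk])

# placeholder path: the same document file used by load_data above
load_contextual_chunks("data/codebase_chunks.json")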

Then we can extend the retrieval strategies to use the contextual chunks:

def vector_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
    return vr.search_by_vector(ContextualChunk, query.vector, topk=topk)

def keyword_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
    return vr.search_by_keyword(ContextualChunk, query.content, topk=topk)

def hybrid_contextual_search_fuse(query: Query, topk: int) -> list[ContextualChunk]:
    rrf = ReciprocalRankFusion()
    return rrf.fuse(
        [vector_contextual_search(query, topk), keyword_contextual_search(query, topk)]
    )[:topk]

def hybrid_contextual_search_rerank(
    query: Query, topk: int, boost=3
) -> list[ContextualChunk]:
    ranker = CohereReranker()
    vecs = vector_contextual_search(query, topk * boost)
    keys = keyword_contextual_search(query, topk * boost)
    chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
    indices = ranker.rerank(
        query.content, [f"{chunk.content}\n{chunk.context}" for chunk in chunks]
    )
    return [chunks[i] for i in indices[:topk]]
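
Evaluating the contextual strategies works the same way as before:

for search in (
    vector_contextual_search,
    keyword_contextual_search,
    hybrid_contextual_search_fuse,
    hybrid_contextual_search_rerank,
):
    for topk in (5, 10):
        evaluate(topk=topk, search_func=search)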

Benchmark

  • topk=5

| Strategy | Pass@5 | QPS |
| --- | --- | --- |
| vector search | 0.8071 | 289.024 |
| keyword search | 0.7003 | 337.211 |
| vector contextual search | 0.8404 | 307.253 |
| keyword contextual search | 0.8033 | 331.239 |

  • topk=10

| Strategy | Pass@10 | QPS |
| --- | --- | --- |
| vector search | 0.8574 | 254.730 |
| keyword search | 0.7577 | 219.653 |
| vector contextual search | 0.8807 | 247.819 |
| keyword contextual search | 0.8563 | 216.560 |

All the code can be found in our vechord repository.
