Implementing GraphRAG for Query-Focused Summarization

Stephen Collins
5 min read

In this tutorial, I'll explore the implementation of the GraphRAG (Graph-based Retrieval-Augmented Generation) approach to query-focused summarization, as described in the research paper "From Local to Global: A GraphRAG Approach to Query-Focused Summarization" by Darren Edge et al. This method is designed to generate comprehensive and diverse answers to global questions over entire text corpora by leveraging a graph-based text index and an LLM (Large Language Model).

I'll walk through a Python-based implementation that includes key steps from the paper, demonstrating how to process documents, build a graph, detect communities, and generate a final answer to a query.

All of the code for this tutorial is available on GitHub.

Overview

GraphRAG enhances traditional RAG methods by addressing global questions directed at an entire text corpus. This is achieved through a pipeline that first builds an entity knowledge graph from source documents and then generates community summaries for groups of closely-related entities. Given a query, community summaries are used to generate partial responses, which are then summarized into a final global answer.

Prerequisites

To follow along with this tutorial, you'll need Python 3.12 or later. Install the necessary packages using pip:

pip install openai networkx leidenalg cdlib python-igraph python-dotenv

In addition to the above, you'll need to sign up for an OpenAI API key.
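All of the code below shares a single OpenAI client and a few imports. Here is a minimal setup sketch, assuming your key lives in a local .env file under the name OPENAI_API_KEY (the file and variable name are conventions, not requirements):

```python
import os

import networkx as nx
from cdlib import algorithms
from dotenv import load_dotenv
from openai import OpenAI

# Pull OPENAI_API_KEY from a local .env file into the environment
load_dotenv()

# One client instance shared by every LLM call in this tutorial
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
```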

Implementation Steps

The GraphRAG pipeline involves the following steps:

1. Source Documents → Text Chunks

First, we split the input texts into manageable chunks for processing.

def split_documents_into_chunks(documents, chunk_size=600, overlap_size=100):
    chunks = []
    for document in documents:
        for i in range(0, len(document), chunk_size - overlap_size):
            chunk = document[i:i + chunk_size]
            chunks.append(chunk)
    return chunks
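As a quick sanity check (the function is repeated here so the snippet runs standalone, and the toy document is fabricated), a 1,200-character document with the defaults produces windows starting every chunk_size - overlap_size = 500 characters:

```python
def split_documents_into_chunks(documents, chunk_size=600, overlap_size=100):
    chunks = []
    for document in documents:
        for i in range(0, len(document), chunk_size - overlap_size):
            chunks.append(document[i:i + chunk_size])
    return chunks

# Windows start at offsets 0, 500, and 1000
chunks = split_documents_into_chunks(["x" * 1200])
print(len(chunks))               # 3
print([len(c) for c in chunks])  # [600, 600, 200]
```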

2. Text Chunks → Element Instances

Next, we'll extract entities and relationships from each chunk of text using OpenAI's GPT-4.

def extract_elements_from_chunks(chunks):
    elements = []
    for chunk in chunks:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "Extract entities and relationships from the following text."},
                {"role": "user", "content": chunk}
            ]
        )
        entities_and_relations = response.choices[0].message.content
        elements.append(entities_and_relations)
    return elements

3. Element Instances → Element Summaries

We summarize the extracted entities and relationships into a structured format.

def summarize_elements(elements):
    summaries = []
    for element in elements:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "Summarize the following entities and relationships in a structured format. Use \"->\" to represent relationships, after the \"Relationships:\" word."},
                {"role": "user", "content": element}
            ]
        )
        summary = response.choices[0].message.content
        summaries.append(summary)
    return summaries

4. Element Summaries → Graph Communities

We build a graph from the element summaries and detect communities using the Leiden algorithm.

def build_graph_from_summaries(summaries):
    G = nx.Graph()
    for summary in summaries:
        lines = summary.split("\n")
        entities_section = False
        relationships_section = False
        entities = []
        for line in lines:
            if line.startswith("### Entities:") or line.startswith("**Entities:**"):
                entities_section = True
                relationships_section = False
                continue
            elif line.startswith("### Relationships:") or line.startswith("**Relationships:**"):
                entities_section = False
                relationships_section = True
                continue
            if entities_section and line.strip():
                entity = line.strip()
                # Strip a leading "1."-style list marker (handles multi-digit numbers,
                # which the naive line[1] == "." check would miss)
                head, _, rest = entity.partition(".")
                if head.isdigit() and rest:
                    entity = rest.strip()
                entity = entity.replace("**", "")
                entities.append(entity)
                G.add_node(entity)
            elif relationships_section and line.strip():
                parts = line.split("->")
                if len(parts) >= 2:
                    source = parts[0].strip()
                    target = parts[-1].strip()
                    relation = " -> ".join(parts[1:-1]).strip()
                    G.add_edge(source, target, label=relation)
    return G
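The parser above assumes summaries arrive with entity and relationship sections plus "->" arrows, as requested by the step 3 prompt. A stripped-down version of the same section-tracking logic, run on a fabricated summary, shows the expected shape:

```python
# A toy element summary in the format the step 3 prompt requests
summary = """### Entities:
1. Alice
2. Bob

### Relationships:
Alice -> works with -> Bob"""

entities, relationships = [], []
section = None
for line in summary.split("\n"):
    if line.startswith("### Entities:"):
        section = "entities"
        continue
    if line.startswith("### Relationships:"):
        section = "relationships"
        continue
    if section == "entities" and line.strip():
        # Drop the "1." list marker
        entities.append(line.split(".", 1)[1].strip())
    elif section == "relationships" and line.strip():
        parts = [p.strip() for p in line.split("->")]
        relationships.append((parts[0], parts[1], parts[-1]))

print(entities)       # ['Alice', 'Bob']
print(relationships)  # [('Alice', 'works with', 'Bob')]
```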

def detect_communities(graph):
    communities = []
    for component in nx.connected_components(graph):
        subgraph = graph.subgraph(component)
        if len(subgraph.nodes) > 1:
            try:
                sub_communities = algorithms.leiden(subgraph)
                for community in sub_communities.communities:
                    communities.append(list(community))
            except Exception as e:
                print(f"Error processing community: {e}")
        else:
            communities.append(list(subgraph.nodes))
    return communities
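cdlib's Leiden call depends on python-igraph under the hood. If that stack is hard to install, networkx (2.8 or later) ships a Louvain implementation that drops into the same per-component loop (a rough substitute, not the paper's exact algorithm):

```python
import networkx as nx
from networkx.algorithms.community import louvain_communities

def detect_communities_louvain(graph, seed=42):
    """Same output shape as detect_communities, but using Louvain."""
    communities = []
    for component in nx.connected_components(graph):
        subgraph = graph.subgraph(component)
        if len(subgraph) > 1:
            for community in louvain_communities(subgraph, seed=seed):
                communities.append(list(community))
        else:
            communities.append(list(subgraph.nodes))
    return communities

# Two triangles joined by one edge should split into two communities
G = nx.Graph([("a", "b"), ("b", "c"), ("c", "a"),
              ("x", "y"), ("y", "z"), ("z", "x"), ("c", "x")])
parts = [set(c) for c in detect_communities_louvain(G)]
print(parts)
```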

5. Graph Communities → Community Summaries

We summarize each detected community.

def summarize_communities(communities, graph):
    community_summaries = []
    for community in communities:
        subgraph = graph.subgraph(community)
        nodes = list(subgraph.nodes)
        edges = list(subgraph.edges(data=True))
        description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
        relationships = []
        for edge in edges:
            relationships.append(
                f"{edge[0]} -> {edge[2]['label']} -> {edge[1]}")
        description += ", ".join(relationships)

        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "Summarize the following community of entities and relationships."},
                {"role": "user", "content": description}
            ]
        )
        summary = response.choices[0].message.content.strip()
        community_summaries.append(summary)
    return community_summaries
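Before any API call is made, it helps to see what the description string looks like for a small community (the Alice/Bob graph below is fabricated for illustration):

```python
import networkx as nx

G = nx.Graph()
G.add_edge("Alice", "Bob", label="works with")

# Same string-building logic as summarize_communities
nodes = list(G.nodes)
edges = list(G.edges(data=True))
description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
description += ", ".join(f"{s} -> {d['label']} -> {t}" for s, t, d in edges)
print(description)
# Entities: Alice, Bob
# Relationships: Alice -> works with -> Bob
```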

6. Community Summaries → Community Answers → Global Answer

Finally, we generate answers from community summaries and combine them into a final global answer.

def generate_answers_from_communities(community_summaries, query):
    intermediate_answers = []
    for summary in community_summaries:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "Answer the following query based on the provided summary."},
                {"role": "user", "content": f"Query: {query} Summary: {summary}"}
            ]
        )
        intermediate_answers.append(response.choices[0].message.content)

    final_response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Combine these answers into a final, concise response."},
            {"role": "user", "content": f"Intermediate answers: {intermediate_answers}"}
        ]
    )
    final_answer = final_response.choices[0].message.content
    return final_answer

Putting It All Together

We can now combine these steps into a single pipeline function.

def graph_rag_pipeline(documents, query, chunk_size=600, overlap_size=100):
    chunks = split_documents_into_chunks(documents, chunk_size, overlap_size)
    elements = extract_elements_from_chunks(chunks)
    summaries = summarize_elements(elements)
    graph = build_graph_from_summaries(summaries)
    communities = detect_communities(graph)
    community_summaries = summarize_communities(communities, graph)
    final_answer = generate_answers_from_communities(community_summaries, query)
    return final_answer

# Example usage (DOCUMENTS is your list of raw document strings)
query = "What are the main themes in these documents?"
answer = graph_rag_pipeline(DOCUMENTS, query)
print('Answer:', answer)

Limitations and Improvements

While this example implementation provides a starting point for graph-augmented summarization and question-answering tasks, there are several areas for potential improvement:

  1. Integration with Graph Databases: Using graph databases like Neo4j could enhance the scalability and efficiency of the graph operations. Neo4j's powerful graph traversal and querying capabilities would allow for more complex and large-scale analyses.

  2. Leveraging LlamaIndex: Incorporating LlamaIndex could further streamline the process of indexing and retrieving document chunks. LlamaIndex provides efficient methods for handling large datasets, which could improve the performance of the Graph RAG pipeline.

  3. Enhanced Entity and Relationship Extraction: The current implementation uses GPT-4 for extracting entities and relationships. Fine-tuning the prompts or using domain-specific models could improve the accuracy and relevance of the extracted elements.

  4. Community Detection Algorithms: While the Leiden algorithm is used here, experimenting with other community detection algorithms could yield better results depending on the nature of the dataset. Algorithms like Louvain or Infomap might offer alternative insights.

  5. Validation: Developing a validation strategy to assess the quality of the summaries and answers generated by the GraphRAG pipeline is crucial. This could involve techniques such as cross-validation, precision-recall analysis, or even human-in-the-loop validation.

  6. Interactive User Interface: Developing an interactive UI for visualizing the graph and its communities could make the summarization process more intuitive. Tools like D3.js or Cytoscape.js can be used to create dynamic visualizations.
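As a concrete starting point for item 1, the in-memory networkx graph can be serialized into naive Cypher statements for loading into Neo4j. This is a sketch only: there is no escaping, batching, or real driver call, and the Entity label and RELATED relationship type are arbitrary choices:

```python
import networkx as nx

def graph_to_cypher(graph):
    """Emit naive Cypher MERGE statements for a networkx graph (no escaping)."""
    statements = []
    for node in graph.nodes:
        statements.append(f'MERGE (:Entity {{name: "{node}"}})')
    for source, target, data in graph.edges(data=True):
        # The graph is undirected; the arrow direction here is arbitrary
        statements.append(
            f'MATCH (a:Entity {{name: "{source}"}}), (b:Entity {{name: "{target}"}}) '
            f'MERGE (a)-[:RELATED {{label: "{data.get("label", "")}"}}]->(b)'
        )
    return statements

G = nx.Graph()
G.add_edge("Alice", "Bob", label="works with")
statements = graph_to_cypher(G)
for stmt in statements:
    print(stmt)
```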

Conclusion

In this tutorial, we implemented a GraphRAG pipeline for query-focused summarization based on the research paper "From Local to Global: A GraphRAG Approach to Query-Focused Summarization." This approach uses a graph-based text index to generate comprehensive and diverse answers to global questions over entire text corpora. By following the steps outlined above, you can leverage the capabilities of LLMs and graph-based methods to achieve advanced summarization and question-answering tasks.

Written by Stephen Collins, a senior software engineer currently working with a climate-tech startup.