Skip to content

Commit bf2903a

Browse files
amullick-git and Amarnath Mullick authored
feat(graph): Add Custom Retrievers for Spanner Graph RAG. (#122)
* Add Spanner Graph QA Chain * Formatted notebook. Added copyright message to prompts file. * Add missing imports for random graph name * Make input table name randomized in integration tests to avoid name collision for tests running in parallel from different python environments * Provide timeout to graph cleanup * Make default timeout of 300 secs for ddl application * Increase timeout of integration test * Change integration test timeout * Minor formatting fixes * Make the ddl operations test fixture scoped for the module * Addressed review comments * Addressed a few other review comments. * Remove unused function * fix type check errors * Addressed review comments * Addressed review comments * Clear default project id from notebook * Add import statement for SpannerGraphQAChain to notebook * Add retrievers for Spanner Graph RAG * Add licence headers * Fix DATABASE name key * Fix lint error on import ordering * Fix lint errors * Few minor changes to the SpannerGraphNodeVectorRetriever * Fix lint error * Add an option to expand context graph by hops * Fix lint error * Addressed review comments * Remove expansion query options * Add backticks to property names * Change copyright year * Address review comments * Rename the retrievers. Merge the semantic retriever with the gql retriever. * Fixed lint errors * Change vertex ai version to latest * Fix lint errors * Add documentation. Fixes the case where expand_by_hops is 0 * Add unit test for expand_by_hops=0 * Fix formatting for documentation * Addressed review comments --------- Co-authored-by: Amarnath Mullick <amullick@google.com>
1 parent fd788d8 commit bf2903a

File tree

7 files changed

+714
-46
lines changed

7 files changed

+714
-46
lines changed

README.rst

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,53 @@ See the full `Spanner Graph QA Chain`_ tutorial.
179179

180180
.. _`Spanner Graph QA Chain`: https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_qa_chain.ipynb
181181

182+
Spanner Graph Retrievers Usage
183+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
184+
185+
Use ``SpannerGraphTextToGQLRetriever`` to translate a natural language question to GQL and query the SpannerGraphStore.
186+
187+
.. code:: python
188+
189+
from langchain_google_spanner import SpannerGraphStore, SpannerGraphTextToGQLRetriever
190+
from langchain_google_vertexai import ChatVertexAI
191+
192+
193+
graph = SpannerGraphStore(
194+
instance_id="my-instance",
195+
database_id="my-database",
196+
graph_name="my_graph",
197+
)
198+
llm = ChatVertexAI()
199+
retriever = SpannerGraphTextToGQLRetriever.from_params(
200+
graph_store=graph,
201+
llm=llm
202+
)
203+
retriever.invoke("Where does Elias Thorne's sibling live?")
204+
205+
Use ``SpannerGraphVectorContextRetriever`` to perform vector search on embeddings that are stored in the nodes of a SpannerGraphStore. If ``expand_by_hops`` is provided, the nodes and edges within ``expand_by_hops`` hops of the nodes found by the vector search are also returned.
206+
207+
.. code:: python
208+
209+
from langchain_google_spanner import SpannerGraphStore, SpannerGraphVectorContextRetriever
210+
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
211+
212+
213+
graph = SpannerGraphStore(
214+
instance_id="my-instance",
215+
database_id="my-database",
216+
graph_name="my_graph",
217+
)
218+
embedding_service = VertexAIEmbeddings(model_name="text-embedding-004")
219+
retriever = SpannerGraphVectorContextRetriever.from_params(
220+
graph_store=graph,
221+
embedding_service=embedding_service,
222+
label_expr="Person",
223+
embeddings_column="embeddings",
224+
top_k=1,
225+
expand_by_hops=1,
226+
)
227+
retriever.invoke("Who lives in desert?")
228+
182229
183230
Contributions
184231
~~~~~~~~~~~~~

docs/graph_qa_chain.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
"source": [
151151
"# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
152152
"\n",
153-
"PROJECT_ID = \"\" # @param {type:\"string\"}\n",
153+
"PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n",
154154
"\n",
155155
"# Set the project id\n",
156156
"!gcloud config set project {PROJECT_ID}\n",

src/langchain_google_spanner/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory
1616
from langchain_google_spanner.graph_qa import SpannerGraphQAChain
17+
from langchain_google_spanner.graph_retriever import (
18+
SpannerGraphTextToGQLRetriever,
19+
SpannerGraphVectorContextRetriever,
20+
)
1721
from langchain_google_spanner.graph_store import SpannerGraphStore
1822
from langchain_google_spanner.vector_store import (
1923
DistanceStrategy,
@@ -38,4 +42,6 @@
3842
"SecondaryIndex",
3943
"QueryParameters",
4044
"DistanceStrategy",
45+
"SpannerGraphTextToGQLRetriever",
46+
"SpannerGraphVectorContextRetriever",
4147
]

src/langchain_google_spanner/graph_qa.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
import re
1817
from typing import Any, Dict, List, Optional
1918

2019
from langchain.chains.base import Chain
@@ -28,6 +27,7 @@
2827

2928
from langchain_google_spanner.graph_store import SpannerGraphStore
3029

30+
from .graph_utils import extract_gql, fix_gql_syntax
3131
from .prompts import (
3232
DEFAULT_GQL_FIX_TEMPLATE,
3333
DEFAULT_GQL_TEMPLATE,
@@ -71,50 +71,6 @@ class VerifyGqlOutput(BaseModel):
7171
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
7272

7373

74-
def fix_gql_syntax(query: str) -> str:
75-
"""Fixes the syntax of a GQL query.
76-
Example 1:
77-
Input:
78-
MATCH (p:paper {id: 0})-[c:cites*8]->(p2:paper)
79-
Output:
80-
MATCH (p:paper {id: 0})-[c:cites]->{8}(p2:paper)
81-
Example 2:
82-
Input:
83-
MATCH (p:paper {id: 0})-[c:cites*1..8]->(p2:paper)
84-
Output:
85-
MATCH (p:paper {id: 0})-[c:cites]->{1,8}(p2:paper)
86-
87-
Args:
88-
query: The input GQL query.
89-
90-
Returns:
91-
Possibly modified GQL query.
92-
"""
93-
94-
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]->", r"-[\1:\2]->{\3,\4}", query)
95-
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]->", r"-[\1:\2]->{\3}", query)
96-
query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"<-[\1:\2]-{\3,\4}", query)
97-
query = re.sub(r"<-\[(.*?):(\w+)\*(\d+)\]-", r"<-[\1:\2]-{\3}", query)
98-
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\.\.(\d+)\]-", r"-[\1:\2]-{\3,\4}", query)
99-
query = re.sub(r"-\[(.*?):(\w+)\*(\d+)\]-", r"-[\1:\2]-{\3}", query)
100-
return query
101-
102-
103-
def extract_gql(text: str) -> str:
104-
"""Extract GQL query from a text.
105-
106-
Args:
107-
text: Text to extract GQL query from.
108-
109-
Returns:
110-
GQL query extracted from the text.
111-
"""
112-
pattern = r"```(.*?)```"
113-
matches = re.findall(pattern, text, re.DOTALL)
114-
query = matches[0] if matches else text
115-
return fix_gql_syntax(query)
116-
117-
11874
class SpannerGraphQAChain(Chain):
11975
"""Chain for question-answering against a Spanner Graph database by
12076
generating GQL statements from natural language questions.

0 commit comments

Comments
 (0)