Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/langchain_google_spanner/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def create_chat_history_table(
instance_id: str,
database_id: str,
table_name: str,
client: spanner.Client = spanner.Client(),
client: Optional[spanner.Client] = None,
) -> None:
"""
Create a chat history table in a Cloud Spanner database.
Expand All @@ -142,7 +142,7 @@ def create_chat_history_table(
instance_id (str): The ID of the Cloud Spanner instance.
database_id (str): The ID of the Cloud Spanner database.
table_name (str): The name of the table to be created.
client (spanner.Client, optional): An instance of the Cloud Spanner client. Defaults to spanner.Client().
client (spanner.Client, optional): An instance of the Cloud Spanner client. Defaults to None.

Raises:
Exception: If the specified instance or database does not exist.
Expand Down
9 changes: 5 additions & 4 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def init_vector_store_table(
instance_id: str,
database_id: str,
table_name: str,
client: spanner.Client = spanner.Client(),
client: Optional[spanner.Client] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that this is None we need to get a client via client_with_user_agent(client, USER_AGENT_VECTOR_STORE)

id_column: Union[str, TableColumn] = ID_COLUMN_NAME,
content_column: str = CONTENT_COLUMN_NAME,
embedding_column: str = EMBEDDING_COLUMN_NAME,
Expand All @@ -300,6 +300,7 @@ def init_vector_store_table(
- vector_size (Optional[int]): The size of the vector. Defaults to None.
"""

client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
instance = client.instance(instance_id)

if not instance.exists():
Expand Down Expand Up @@ -446,7 +447,7 @@ def __init__(
id_column: str = ID_COLUMN_NAME,
content_column: str = CONTENT_COLUMN_NAME,
embedding_column: str = EMBEDDING_COLUMN_NAME,
client: spanner.Client = spanner.Client(),
client: Optional[spanner.Client] = None,
metadata_columns: Optional[List[str]] = None,
ignore_metadata_columns: Optional[List[str]] = None,
metadata_json_column: Optional[str] = None,
Expand Down Expand Up @@ -1109,7 +1110,7 @@ def from_documents( # type: ignore[override]
content_column: str = CONTENT_COLUMN_NAME,
embedding_column: str = EMBEDDING_COLUMN_NAME,
ids: Optional[List[str]] = None,
client: spanner.Client = spanner.Client(),
client: Optional[spanner.Client] = None,
metadata_columns: Optional[List[str]] = None,
ignore_metadata_columns: Optional[List[str]] = None,
metadata_json_column: Optional[str] = None,
Expand Down Expand Up @@ -1170,7 +1171,7 @@ def from_texts( # type: ignore[override]
content_column: str = CONTENT_COLUMN_NAME,
embedding_column: str = EMBEDDING_COLUMN_NAME,
ids: Optional[List[str]] = None,
client: spanner.Client = spanner.Client(),
client: Optional[spanner.Client] = None,
metadata_columns: Optional[List[str]] = None,
ignore_metadata_columns: Optional[List[str]] = None,
metadata_json_column: Optional[str] = None,
Expand Down