Skip to content

Commit c12ded9

Browse files
authored
feat: Add key to engine constructor (#156)
* feat: feat: Add `key` to engine constructor * add column test for tc
1 parent 2d4b268 commit c12ded9

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,19 @@ class PostgresEngine:
9191
"""A class for managing connections to a Cloud SQL for Postgres database."""
9292

9393
_connector: Optional[Connector] = None
94+
__create_key = object()
9495

9596
def __init__(
9697
self,
98+
key: object,
9799
engine: AsyncEngine,
98100
loop: Optional[asyncio.AbstractEventLoop],
99101
thread: Optional[Thread],
100102
):
103+
if key != PostgresEngine.__create_key:
104+
raise Exception(
105+
"Only create class through 'create' or 'create_sync' methods!"
106+
)
101107
self._engine = engine
102108
self._loop = loop
103109
self._thread = thread
@@ -191,7 +197,7 @@ async def getconn() -> asyncpg.Connection:
191197
"postgresql+asyncpg://",
192198
async_creator=getconn,
193199
)
194-
return cls(engine, loop, thread)
200+
return cls(cls.__create_key, engine, loop, thread)
195201

196202
@classmethod
197203
async def afrom_instance(
@@ -218,7 +224,7 @@ async def afrom_instance(
218224

219225
@classmethod
220226
def from_engine(cls, engine: AsyncEngine) -> PostgresEngine:
221-
return cls(engine, None, None)
227+
return cls(cls.__create_key, engine, None, None)
222228

223229
async def _aexecute(self, query: str, params: Optional[dict] = None):
224230
"""Execute a SQL query."""

tests/test_postgresql_engine.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ async def getconn() -> asyncpg.Connection:
173173
async def test_column(self, engine):
174174
with pytest.raises(ValueError):
175175
Column("test", VARCHAR)
176+
with pytest.raises(ValueError):
177+
Column(1, "INTEGER")
176178

177179

178180
@pytest.mark.asyncio
@@ -275,3 +277,11 @@ async def test_password(
275277
assert engine
276278
engine._execute("SELECT 1")
277279
PostgresEngine._connector = None
280+
281+
async def test_engine_constructor_key(
282+
self,
283+
engine,
284+
):
285+
key = object()
286+
with pytest.raises(Exception):
287+
PostgresEngine(key, engine)

0 commit comments

Comments
 (0)