Skip to content

Commit 2b4a091

Browse files
authored
Create neo4j_driver.py
1 parent f55d66d commit 2b4a091

File tree

1 file changed

+367
-0
lines changed

1 file changed

+367
-0
lines changed

knowledge/graph/neo4j_driver.py

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
"""
Enterprise Neo4j Driver - High-performance Graph Database Integration with Cluster Support
"""
4+
5+
from __future__ import annotations

import asyncio
import logging
import ssl
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from neo4j import (
    READ_ACCESS,
    WRITE_ACCESS,
    AsyncGraphDatabase,
    AsyncSession,
    AsyncTransaction,
    EagerResult,
    RoutingControl,
)
from neo4j.exceptions import (
    Neo4jError,
    ServiceUnavailable,
    SessionExpired,
    TransientError,
)
from prometheus_client import (  # type: ignore
    Counter,
    Gauge,
    Histogram,
)

from ..utils.metrics import MetricsSystem
from ..utils.serialization import deserialize_neo4j
32+
33+
# Prometheus metrics, registered once at import time.

# Distribution of query execution time, labelled by query type and cluster.
QUERY_DURATION = Histogram(
    name='neo4j_query_duration_seconds',
    documentation='Query execution time distribution',
    labelnames=['query_type', 'cluster'],
)

# Gauge labelled by cluster; incremented and decremented by the driver below.
CONNECTION_GAUGE = Gauge(
    name='neo4j_connections_active',
    documentation='Active Neo4j connections',
    labelnames=['cluster'],
)

# Count of retry attempts, labelled by cluster and exception class name.
RETRY_COUNTER = Counter(
    name='neo4j_retries_total',
    documentation='Total query retry attempts',
    labelnames=['cluster', 'error_type'],
)
49+
50+
@dataclass(frozen=True)
class Neo4jConfig:
    """Immutable configuration for Neo4j cluster connectivity.

    All fields have defaults, so ``Neo4jConfig()`` describes an encrypted
    connection to ``neo4j://localhost:7687``. Instances are frozen: build a
    new one (e.g. with ``dataclasses.replace``) instead of mutating.

    Raises:
        ValueError: if ``max_connection_pool_size`` is below 1, or if
            ``max_retries`` or ``retry_delay`` is negative.
    """

    uri: str = "neo4j://localhost:7687"
    # (user, password) pair; the default is a placeholder, override it.
    auth: tuple = ("neo4j", "password")
    max_connection_pool_size: int = 100
    connection_timeout: int = 30  # seconds
    encrypted: bool = True
    # NOTE(review): the default disables certificate verification.
    # Alternative value: TRUST_SYSTEM_CA_SIGNED_CERTIFICATES
    trust: str = "TRUST_ALL_CERTIFICATES"
    max_transaction_retry_time: int = 30  # seconds
    database: str = "neo4j"
    load_balancing_strategy: str = "ROUND_ROBIN"
    max_retries: int = 5
    retry_delay: float = 0.5  # seconds
    fetch_size: int = 1000
    # Optional path to a CA bundle used to verify the server certificate.
    cert_path: Optional[str] = None

    def __post_init__(self):
        # Fail at construction time rather than deep inside a retry loop.
        if self.max_connection_pool_size < 1:
            raise ValueError("max_connection_pool_size must be at least 1")
        if self.max_retries < 0:
            raise ValueError("max_retries must not be negative")
        if self.retry_delay < 0:
            raise ValueError("retry_delay must not be negative")
66+
67+
class Neo4jDriver:
    """Async Neo4j driver wrapper with retries, metrics and result caching.

    Use it as an async context manager (``async with Neo4jDriver(cfg) as d``)
    or call :meth:`connect` / :meth:`close` explicitly. Every method that
    talks to the database raises ``RuntimeError`` if the driver is not
    connected.
    """

    # Upper bound on the number of entries kept by cached_query().
    _QUERY_CACHE_MAX_ENTRIES = 1000

    def __init__(self, config: Neo4jConfig):
        """Store *config* and prepare the TLS context; no I/O happens here."""
        self._config = config
        self._driver = None
        self._metrics = MetricsSystem([])
        self._logger = logging.getLogger("aelion.neo4j")
        self._ssl_context = self._configure_ssl()
        self._cluster_nodes = []
        # Maps cache key -> (expiry on the event-loop clock or None, data).
        self._query_cache = {}

    def _require_driver(self):
        """Return the underlying driver, or raise if connect() was not called."""
        if self._driver is None:
            raise RuntimeError(
                "Neo4j driver is not connected; call connect() first"
            )
        return self._driver

    async def connect(self):
        """Create the underlying driver and run cluster discovery.

        Raises:
            Exception: whatever cluster discovery raises; in that case the
                freshly created driver is closed again before re-raising.
        """
        kwargs = {
            "auth": self._config.auth,
            "max_connection_pool_size": self._config.max_connection_pool_size,
            "connection_timeout": self._config.connection_timeout,
            "max_transaction_retry_time": self._config.max_transaction_retry_time,
            "user_agent": "AelionAI/1.0",
            "keep_alive": True,
            "fetch_size": self._config.fetch_size,
        }

        # URI schemes such as neo4j+s / bolt+ssc already carry their TLS
        # settings, and the driver rejects explicit TLS options next to them.
        scheme = self._config.uri.split("://", 1)[0]
        if "+s" not in scheme:
            if self._ssl_context is not None:
                # The driver option is named "ssl_context" (not "ssl"). The
                # context built by _configure_ssl() already encodes the
                # encrypted / trust / cert_path settings.
                kwargs["ssl_context"] = self._ssl_context
            else:
                kwargs["encrypted"] = False

        driver = AsyncGraphDatabase.driver(self._config.uri, **kwargs)
        self._driver = driver
        try:
            await self._discover_cluster()
        except Exception:
            # Do not leak the connection pool when start-up fails.
            self._driver = None
            await driver.close()
            raise
        CONNECTION_GAUGE.labels(cluster=self.cluster_name).inc()

    async def _discover_cluster(self):
        """Query the ``system`` database and refresh ``self._cluster_nodes``.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        driver = self._require_driver()
        try:
            # AsyncSession is an async context manager; it cannot be awaited
            # or used with a plain "with".
            async with driver.session(database="system") as session:
                # NOTE(review): confirm these column names against the
                # SHOW SERVERS output of the deployed Neo4j version.
                result = await session.run(
                    "SHOW SERVERS YIELD id, address, role, currentStatus"
                )
                nodes = await result.values()
            self._cluster_nodes = [
                {
                    "id": node[0],
                    "address": node[1],
                    "role": node[2],
                    "status": node[3],
                }
                for node in nodes
            ]
            self._logger.info(f"Discovered {len(nodes)} cluster nodes")
        except Exception as e:
            self._logger.error(f"Cluster discovery failed: {str(e)}")
            raise

    def _configure_ssl(self) -> Optional[ssl.SSLContext]:
        """Build the TLS context implied by the configuration.

        Returns:
            ``None`` when encryption is disabled, otherwise a client context
            that verifies against the system CAs
            (``trust == "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES"``), against
            ``cert_path`` if set, or performs no verification at all.
        """
        if not self._config.encrypted:
            return None

        # PROTOCOL_TLS_CLIENT starts with CERT_REQUIRED and hostname checking
        # enabled; both are kept whenever certificates are verified.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

        if self._config.trust == "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES":
            ctx.load_default_certs()
        elif self._config.cert_path:
            ctx.load_verify_locations(self._config.cert_path)
        else:
            # Trust-all mode: hostname checking must be switched off before
            # verify_mode may be lowered to CERT_NONE.
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

        return ctx

    @property
    def cluster_name(self) -> str:
        """Host[:port] part of the connection URI, used as a metrics label."""
        # Drop the scheme first; otherwise a URI without userinfo would
        # yield "neo4j:" instead of the host.
        remainder = self._config.uri.split("://", 1)[-1]
        authority = remainder.split("/")[0].split("?")[0]
        return authority.split("@")[-1]

    @MetricsSystem.time_method(QUERY_DURATION, labels=["query_type", "cluster"])
    async def execute_query(
        self,
        query: str,
        parameters: Optional[Dict] = None,
        *,
        tx: Optional[AsyncTransaction] = None,
        routing: RoutingControl = RoutingControl.WRITE,
        **kwargs
    ) -> EagerResult:
        """Execute a Cypher query and return its eager result.

        Args:
            query: Cypher text.
            parameters: query parameters; ``None`` means no parameters.
            tx: run inside this transaction instead of a new session. In
                that case the query runs exactly once, because a failed
                transaction cannot be reused and the owner of the
                transaction is responsible for retrying it.
            routing: read or write routing for the session that is opened
                when no transaction is given.
            **kwargs: forwarded to ``run()``.

        Raises:
            RuntimeError: if the driver is not connected (and no ``tx``).
            Neo4jError: when a retryable error persists after
                ``max_retries`` retries.
        """
        parameters = parameters or {}

        if tx is not None:
            result = await tx.run(query, parameters, **kwargs)
            return await result.to_eager_result()

        driver = self._require_driver()
        # Sessions expect READ_ACCESS / WRITE_ACCESS, not RoutingControl.
        access_mode = (
            READ_ACCESS if routing == RoutingControl.READ else WRITE_ACCESS
        )

        retries = 0
        while retries <= self._config.max_retries:
            try:
                # "async with" closes the session on success and on every
                # error, including the retryable ones.
                async with driver.session(
                    database=self._config.database,
                    default_access_mode=access_mode,
                ) as session:
                    result = await session.run(query, parameters, **kwargs)
                    return await result.to_eager_result()

            except (ServiceUnavailable, SessionExpired, TransientError) as e:
                RETRY_COUNTER.labels(
                    cluster=self.cluster_name,
                    error_type=type(e).__name__
                ).inc()

                if retries >= self._config.max_retries:
                    raise Neo4jError(
                        f"Max retries ({self._config.max_retries}) exceeded"
                    ) from e

                await self._handle_retry(e, retries)
                retries += 1

        raise Neo4jError("Unexpected execution path")  # Should never reach here

    async def _handle_retry(self, error: Exception, retry_count: int):
        """Sleep with exponential backoff, then try to refresh the topology.

        Rediscovery is best effort: a failure is logged and swallowed so that
        it cannot abort the caller's retry loop.
        """
        delay = self._config.retry_delay * (2 ** retry_count)
        self._logger.warning(
            f"Retry {retry_count+1} in {delay:.2f}s: {str(error)}"
        )
        await asyncio.sleep(delay)
        try:
            await self._discover_cluster()
        except Exception as discovery_error:
            self._logger.warning(
                f"Cluster rediscovery failed during retry: {discovery_error}"
            )

    async def transactional(
        self,
        query: str,
        parameters: Optional[Dict] = None,
        **kwargs
    ) -> Any:
        """Run *query* inside a managed write transaction.

        Commit, rollback and retry are handled by ``session.execute_write``.

        Raises:
            RuntimeError: if the driver is not connected.
            Neo4jError: logged via ``_log_transaction_error`` and re-raised.
        """
        driver = self._require_driver()

        async def _work(tx):
            return await self.execute_query(query, parameters, tx=tx, **kwargs)

        async with driver.session(
            database=self._config.database
        ) as session:
            try:
                return await session.execute_write(_work)
            except Neo4jError as e:
                await self._log_transaction_error(e)
                raise

    async def _log_transaction_error(self, error: Neo4jError):
        """Log the code and message of a failed transaction."""
        self._logger.error(
            f"Transaction failed: {error.code} - {error.message}"
        )
        if error.classification == "ClientError":
            # Neo4jError has no "parameters" attribute; log its metadata.
            metadata = getattr(error, "metadata", None)
            self._logger.debug(f"Error metadata: {metadata}")

    async def cached_query(
        self,
        query: str,
        parameters: Optional[Dict] = None,
        ttl: int = 300
    ) -> List[Dict]:
        """Execute a query, caching the deserialized result per instance.

        Args:
            query: Cypher text.
            parameters: query parameters; part of the cache key.
            ttl: seconds to keep a new entry; ``ttl <= 0`` keeps it until
                it is evicted.

        Returns:
            The deserialized result. The same object is handed to every
            caller that hits the cache, so treat it as read-only.
        """
        cache = self._query_cache
        # repr() keeps the key hashable even when parameter values are
        # lists or dicts.
        ordered_params = sorted(
            (parameters or {}).items(), key=lambda item: str(item[0])
        )
        cache_key = (query, repr(ordered_params))
        now = asyncio.get_running_loop().time()

        entry = cache.pop(cache_key, None)
        if entry is not None:
            expires_at, data = entry
            if expires_at is None or now < expires_at:
                # Re-insert so the entry becomes the most recently used.
                cache[cache_key] = entry
                return data

        result = await self.execute_query(query, parameters)
        data = deserialize_neo4j(result)

        # Evict the least recently used entry (dicts keep insertion order).
        # pop() with a default tolerates a concurrent caller having already
        # removed that key while this coroutine was awaiting the query.
        while len(cache) >= self._QUERY_CACHE_MAX_ENTRIES:
            cache.pop(next(iter(cache)), None)

        expires_at = now + ttl if ttl > 0 else None
        cache[cache_key] = (expires_at, data)
        return data

    async def batch_operations(
        self,
        queries: List[str],
        parameters_list: List[Dict],
        batch_size: int = 1000
    ) -> List[Any]:
        """Run queries concurrently in chunks of *batch_size*.

        Args:
            queries: Cypher statements.
            parameters_list: one parameter dict per statement, same order.
            batch_size: number of queries started together.

        Returns:
            Results in the same order as *queries*.

        Raises:
            ValueError: if the two lists differ in length or *batch_size*
                is smaller than 1.
        """
        if len(queries) != len(parameters_list):
            raise ValueError(
                "queries and parameters_list must have the same length"
            )
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        results = []
        for start in range(0, len(queries), batch_size):
            end = start + batch_size
            tasks = [
                self.execute_query(q, p)
                for q, p in zip(queries[start:end], parameters_list[start:end])
            ]
            results.extend(await asyncio.gather(*tasks))

        return results

    async def close(self):
        """Close the driver and drop cached results; safe to call twice."""
        if self._driver is None:
            return
        # Clear the reference first so a second close() is a no-op and the
        # gauge is decremented only once.
        driver, self._driver = self._driver, None
        self._query_cache.clear()
        await driver.close()
        CONNECTION_GAUGE.labels(cluster=self.cluster_name).dec()
        self._logger.info("Neo4j driver closed")

    async def __aenter__(self):
        """Connect and return this instance."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the driver; exceptions from the body are not suppressed."""
        await self.close()
300+
301+
# Schema Management Utilities
302+
class Neo4jSchemaManager:
    """Schema bootstrap and migration helpers built on a driver object.

    The driver only needs an awaitable ``execute_query(query)`` method.
    Schema-changing calls made through one manager are serialized by an
    ``asyncio.Lock``.
    """

    def __init__(self, driver: "Neo4jDriver"):
        self.driver = driver
        self._lock = asyncio.Lock()

    async def initialize_schema(self):
        """Create the Agent uniqueness constraint and type index if missing."""
        statements = [
            "CREATE CONSTRAINT unique_agent_id IF NOT EXISTS "
            "FOR (a:Agent) REQUIRE a.id IS UNIQUE",
            "CREATE INDEX agent_type_index IF NOT EXISTS "
            "FOR (a:Agent) ON (a.type)",
        ]

        async with self._lock:
            for statement in statements:
                await self.driver.execute_query(statement)

    async def migrate_data(self, migration_script: str):
        """Reset the schema via APOC, then run *migration_script*.

        The two steps are separate ``execute_query`` calls, so they are NOT
        atomic: if the script fails, the schema reset has already happened.

        Args:
            migration_script: a single Cypher statement to run.

        Raises:
            ValueError: if *migration_script* is empty or whitespace only;
                nothing is executed in that case.
        """
        # Validate before touching the database so a bad call cannot
        # trigger the schema reset on its own.
        if not migration_script or not migration_script.strip():
            raise ValueError(
                "migration_script must be a non-empty Cypher statement"
            )

        async with self._lock:
            # NOTE(review): per the APOC documentation, dropExisting=true
            # with empty maps drops every index and constraint - confirm
            # this is intended before running against production data.
            await self.driver.execute_query(
                "CALL apoc.schema.assert({}, {}, true) YIELD label, key, unique, action "
                "RETURN *"
            )
            await self.driver.execute_query(migration_script)
329+
330+
# Example Usage
331+
# NOTE: this module uses relative imports, so run the example as a module
# of its package (python -m ...) rather than as a standalone file.
if __name__ == "__main__":
    import json
    import os  # needed for os.getenv below; not imported at module level

    from dotenv import load_dotenv

    load_dotenv()

    async def main():
        """Connect, bootstrap the schema, then run a write and a read."""
        config = Neo4jConfig(
            uri="neo4j://cluster.aelion.ai:7687",
            auth=("neo4j", os.getenv("NEO4J_PASSWORD")),
            encrypted=True,
            cert_path="/etc/ssl/neo4j-ca.pem",
        )

        async with Neo4jDriver(config) as driver:
            # Schema helpers live on Neo4jSchemaManager, not on the driver.
            await Neo4jSchemaManager(driver).initialize_schema()

            # Create agent node
            result = await driver.execute_query(
                "CREATE (a:Agent {id: $id, type: $type}) RETURN a",
                {"id": "agent_001", "type": "supervisor"},
            )
            print("Created agent:", json.dumps(deserialize_neo4j(result), indent=2))

            # Complex query example
            result = await driver.execute_query(
                """
                MATCH (src:Agent)-[rel:COMMUNICATES_WITH]->(dest:Agent)
                WHERE src.type = $type
                RETURN src.id AS source, collect(dest.id) AS targets
                """,
                {"type": "worker"},
                routing=RoutingControl.READ,
            )
            print("Communication graph:", json.dumps(deserialize_neo4j(result), indent=2))

    asyncio.run(main())

0 commit comments

Comments
 (0)