Skip to content

Commit 6f71969

Browse files
Luis Garzatimm4205
authored andcommitted
feat(cache): Added LRU (Least Recently Used) cache for prepared statements
1 parent fd81061 commit 6f71969

File tree

4 files changed

+156
-16
lines changed

4 files changed

+156
-16
lines changed

redshift_connector/core.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import socket
44
import typing
5-
from collections import deque
5+
from collections import deque, OrderedDict
66
from copy import deepcopy
77
from datetime import datetime as Datetime
88
from datetime import timedelta as Timedelta
@@ -1826,11 +1826,14 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
18261826
# transforms user provided bind parameters to server friendly bind parameters
18271827
params: typing.Tuple[typing.Optional[typing.Tuple[int, int, typing.Callable]], ...] = ()
18281828
has_bind_parameters: bool = False if vals is None else True
1829-
# multi dimensional dictionary to store the data
1830-
# cache = self._caches[cursor.paramstyle][pid]
1831-
# cache = {'statement': {}, 'ps': {}}
1832-
# statement stores the data of the statement, ps store the data of the prepared statement
1833-
# statement = {operation(query): tuple from 'convert_paramstyle'(statement, make_args)}
1829+
statements_to_close = []
1830+
# Cache structure for prepared statements:
1831+
# self._caches[paramstyle][pid] contains:
1832+
# - 'statement': stores SQL statements and their parameter processors
1833+
# - 'ps': stores prepared statement metadata
1834+
# - 'statement_dict': OrderedDict tracking most recently used prepared statements
1835+
# (used when max_prepared_statements is set)
1836+
# Each statement entry contains (processed_statement, parameter_binding_function)
18341837
try:
18351838
cache = self._caches[cursor.paramstyle][pid]
18361839
except KeyError:
@@ -1842,7 +1845,11 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
18421845
try:
18431846
cache = param_cache[pid]
18441847
except KeyError:
1845-
cache = param_cache[pid] = {"statement": {}, "ps": {}}
1848+
cache = param_cache[pid] = {
1849+
"statement": {},
1850+
"ps": {},
1851+
"statement_dict": OrderedDict() if self.max_prepared_statements > 0 else None
1852+
}
18461853

18471854
try:
18481855
statement, make_args = cache["statement"][operation]
@@ -1863,6 +1870,9 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
18631870

18641871
try:
18651872
ps = cache["ps"][key]
1873+
# If statement exists, move it to end of ordered dict (most recently used)
1874+
if self.max_prepared_statements > 0 and 'statement_dict' in cache and key in cache["statement_dict"]:
1875+
cache["statement_dict"].move_to_end(key)
18661876
_logger.debug("Using cached prepared statement")
18671877
cursor.ps = ps
18681878
except KeyError:
@@ -1978,12 +1988,21 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
19781988

19791989
ps["bind_2"] = h_pack(len(output_fc)) + pack("!" + "h" * len(output_fc), *output_fc)
19801990

1981-
if len(cache["ps"]) >= self.max_prepared_statements:
1982-
for p in cache["ps"].values():
1983-
self.close_prepared_statement(p["statement_name_bin"])
1984-
cache["ps"].clear()
19851991
if self.max_prepared_statements > 0:
1992+
# Ensure consistency between ps and statement_dict
1993+
if len(cache["ps"]) != len(cache["statement_dict"]):
1994+
for existing_key in cache["ps"]:
1995+
cache["statement_dict"][existing_key] = None
1996+
1997+
# If cache is full, remove oldest statement
1998+
if len(cache["ps"]) >= self.max_prepared_statements:
1999+
oldest_key, _ = cache["statement_dict"].popitem(last=False)
2000+
statements_to_close.append(cache["ps"][oldest_key]["statement_name_bin"])
2001+
del cache["ps"][oldest_key]
2002+
2003+
# Add new statement to cache and queue
19862004
cache["ps"][key] = ps
2005+
cache["statement_dict"][key] = None
19872006

19882007
cursor._cached_rows.clear()
19892008
cursor._row_count = -1
@@ -2031,6 +2050,10 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
20312050
else:
20322051
self.handle_messages(cursor)
20332052

2053+
# Clean up prepared statements after query execution and results are returned
2054+
for stmt in statements_to_close:
2055+
self.close_prepared_statement(stmt)
2056+
20342057
def _send_message(self: "Connection", code: bytes, data: bytes) -> None:
20352058
_logger.debug("Sending message with code %s to BE", code)
20362059
try:
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
import redshift_connector
3+
4+
def test_lru_prepared_statements_cache(db_kwargs):
5+
"""
6+
Tests LRU (Least Recently Used) behavior of prepared statements cache:
7+
1. Verifies cache size limits
8+
2. Verifies statement ordering
9+
3. Verifies LRU eviction policy
10+
4. Verifies statement reuse behavior
11+
"""
12+
db_kwargs['max_prepared_statements'] = 5
13+
conn = redshift_connector.connect(**db_kwargs)
14+
cursor = conn.cursor()
15+
16+
try:
17+
# Track statement execution order
18+
executed_statements = []
19+
20+
# Execute 7 unique statements (exceeds cache size of 5)
21+
for i in range(7):
22+
query = f"SELECT %s::int as col1, %s::int as col2, {i} as unique_id"
23+
cursor.execute(query, (i, i + 1))
24+
executed_statements.append(query)
25+
26+
# Get cache and statement queue
27+
cache = conn._caches[cursor.paramstyle][cursor.ps['pid']]
28+
statement_dict = cache['statement_dict']
29+
30+
# Basic cache size verification
31+
assert len(cache['ps']) == 5, f"Cache size should be 5, but was {len(cache['ps'])}"
32+
assert len(statement_dict) == 5, f"Statement dict size should be 5, but was {len(statement_dict)}"
33+
34+
# Verify the most recent statements are in the queue
35+
cached_statements = [key[0] for key in statement_dict.keys()]
36+
last_five_statements = executed_statements[-5:]
37+
assert all(stmt in cached_statements for stmt in last_five_statements), \
38+
"Last 5 statements should be in cache"
39+
40+
# Verify the first two statements were evicted
41+
first_two_statements = executed_statements[:2]
42+
for stmt in first_two_statements:
43+
assert stmt not in cached_statements, f"Statement should have been evicted: {stmt}"
44+
45+
# Test statement reuse
46+
reuse_stmt = executed_statements[-3]
47+
cursor.execute(reuse_stmt, (100, 101))
48+
49+
# Verify the reused statement is now at the end of the queue
50+
assert list(statement_dict.keys())[-1][0] == reuse_stmt, "Reused statement should be most recent"
51+
52+
# Add new statement and verify LRU behavior
53+
new_stmt = "SELECT %s::int as col1, %s::int as col2, 999 as unique_id"
54+
# Track which statement should be evicted (least recently used)
55+
statements_before_new = [key[0] for key in statement_dict]
56+
cursor.execute(new_stmt, (999, 1000))
57+
58+
# Verify cache size and new statement presence
59+
assert len(cache['ps']) == 5, f"Cache size should still be 5, but was {len(cache['ps'])}"
60+
assert list(statement_dict.keys())[-1][0] == new_stmt, "New statement should be most recent"
61+
62+
# Verify LRU eviction - the least recently used statement should be gone
63+
current_statements = [key[0] for key in statement_dict]
64+
lru_statement = statements_before_new[0] # First statement in queue before adding new one
65+
assert lru_statement not in current_statements, \
66+
f"Least recently used statement should have been evicted: {lru_statement}"
67+
68+
finally:
69+
cursor.close()
70+
conn.close()

test/integration/test_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,8 @@ def test_handle_COMMAND_COMPLETE_closed_ps(con, mocker) -> None:
474474
"name": "max_prepared_statements_limit_2",
475475
"max_prepared_statements": 2,
476476
"queries": ["SELECT 1", "SELECT 2"],
477-
"expected_close_calls": 2,
478-
"expected_cache_size": 1
477+
"expected_close_calls": 1,
478+
"expected_cache_size": 2
479479
}
480480
])
481481
def test_max_prepared_statement(con, mocker, test_case) -> None:

test/unit/test_core.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ def test_handle_command_complete_no_cache_cleanup(command_status, connection):
178178
"name": "max_prepared_statements_limit_1",
179179
"max_prepared_statements": 2,
180180
"queries": ["SELECT 1", "SELECT 2", "SELECT 3"],
181-
"expected_close_calls": 2
181+
"expected_close_calls": 1
182182
},
183-
{
183+
{
184184
"name": "max_prepared_statements_limit_2",
185185
"max_prepared_statements": 2,
186186
"queries": ["SELECT 1", "SELECT 2"],
187187
"expected_close_calls": 0
188188
}
189189
])
190-
def test_max_prepared_statement_zero(mocker, test_case):
190+
def test_prepared_statement_cache_behavior(mocker, test_case):
191191
"""
192192
Test prepared statement cache management in execute() with different configurations.
193193
:type mocker: object
@@ -217,6 +217,53 @@ def test_max_prepared_statement_zero(mocker, test_case):
217217
assert mock_connection.close_prepared_statement.call_count == test_case["expected_close_calls"]
218218

219219

220+
@pytest.mark.parametrize("test_case", [
221+
{
222+
"name": "statement_reuse",
223+
"queries": ["SELECT 1", "SELECT 2", "SELECT 1", "SELECT 3"],
224+
"expected_in_cache": ["SELECT 1", "SELECT 3"],
225+
"expected_not_in_cache": ["SELECT 2"],
226+
"expected_close_calls": 1
227+
},
228+
{
229+
"name": "lru_order",
230+
"queries": ["SELECT 1", "SELECT 2", "SELECT 1", "SELECT 2", "SELECT 3"],
231+
"expected_in_cache": ["SELECT 2", "SELECT 3"],
232+
"expected_not_in_cache": ["SELECT 1"],
233+
"expected_close_calls": 1
234+
}
235+
])
236+
def test_prepared_statement_lru_behavior(mocker, test_case):
237+
"""Test LRU behavior of prepared statement cache."""
238+
from os import getpid
239+
240+
mock_connection = Connection.__new__(Connection)
241+
mock_connection.max_prepared_statements = 2 # Set to 2 for LRU testing
242+
mock_connection.merge_socket_read = True
243+
mock_connection._caches = {}
244+
mock_connection._send_message = mocker.Mock()
245+
mock_connection._write = mocker.Mock()
246+
mock_connection._flush = mocker.Mock()
247+
mock_connection.handle_messages = mocker.Mock()
248+
mock_connection.handle_messages_merge_socket_read = mocker.Mock()
249+
mock_connection.close_prepared_statement = mocker.Mock()
250+
251+
mock_cursor = mocker.Mock()
252+
mock_cursor.paramstyle = "named"
253+
254+
for query in test_case["queries"]:
255+
mock_connection.execute(mock_cursor, query, None)
256+
257+
assert mock_connection.close_prepared_statement.call_count == test_case["expected_close_calls"]
258+
259+
# Verify cache contents
260+
pid = getpid()
261+
cache = mock_connection._caches["named"][pid]
262+
for query in test_case["expected_in_cache"]:
263+
assert any(key[0] == query for key in cache["statement_dict"])
264+
for query in test_case["expected_not_in_cache"]:
265+
assert not any(key[0] == query for key in cache["statement_dict"])
266+
220267
@pytest.mark.parametrize("test_case", [
221268
{
222269
"max_prepared_statements" : 0,

0 commit comments

Comments
 (0)