Skip to content
8 changes: 7 additions & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def _get_command_keys(self, *args):
redis_conn = self.get_default_node().redis_connection
return self.commands_parser.get_keys(redis_conn, *args)

def determine_slot(self, *args) -> int:
def determine_slot(self, *args) -> Optional[int]:
"""
Figure out what slot to use based on args.

Expand All @@ -1156,6 +1156,12 @@ def determine_slot(self, *args) -> int:

# Get the keys in the command

# CLIENT TRACKING is a special case.
# It doesn't have any keys, it needs to be sent to the provided nodes
# By default it will be sent to all nodes.
if command.upper() == "CLIENT TRACKING":
return None

# EVAL and EVALSHA are common enough that it's wasteful to go to the
# redis server to parse the keys. Besides, there is a bug in redis<7.0
# where `self._get_command_keys()` fails anyway. So, we special case
Expand Down
143 changes: 143 additions & 0 deletions redis/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Dict,
Iterable,
Iterator,
Expand All @@ -11,6 +12,7 @@
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)

Expand All @@ -25,6 +27,7 @@
PatternT,
ResponseT,
)
from redis.utils import deprecated_function

from .core import (
ACLCommands,
Expand Down Expand Up @@ -755,6 +758,76 @@ def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
self.read_from_replicas = False
return self.execute_command("READWRITE", target_nodes=target_nodes)

@deprecated_function(
version="7.2.0",
reason="Use client-side caching feature instead.",
)
def client_tracking_on(
self,
clientid: Optional[int] = None,
prefix: Sequence[KeyT] = [],
bcast: bool = False,
optin: bool = False,
optout: bool = False,
noloop: bool = False,
target_nodes: Optional["TargetNodesT"] = "all",
) -> ResponseT:
"""
Enables the tracking feature of the Redis server, that is used
for server assisted client side caching.

When clientid is provided - in target_nodes only the node that owns the
connection with this id should be provided.
When clientid is not provided - target_nodes can be any node.

For more information see https://redis.io/commands/client-tracking
"""
return self.client_tracking(
True,
clientid,
prefix,
bcast,
optin,
optout,
noloop,
target_nodes=target_nodes,
)

@deprecated_function(
version="7.2.0",
reason="Use client-side caching feature instead.",
)
def client_tracking_off(
self,
clientid: Optional[int] = None,
prefix: Sequence[KeyT] = [],
bcast: bool = False,
optin: bool = False,
optout: bool = False,
noloop: bool = False,
target_nodes: Optional["TargetNodesT"] = "all",
) -> ResponseT:
"""
Disables the tracking feature of the Redis server, that is used
for server assisted client side caching.

When clientid is provided - in target_nodes only the node that owns the
connection with this id should be provided.
When clientid is not provided - target_nodes can be any node.

For more information see https://redis.io/commands/client-tracking
"""
return self.client_tracking(
False,
clientid,
prefix,
bcast,
optin,
optout,
noloop,
target_nodes=target_nodes,
)


class AsyncClusterManagementCommands(
ClusterManagementCommands, AsyncManagementCommands
Expand Down Expand Up @@ -782,6 +855,76 @@ async def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
)
)

@deprecated_function(
version="7.2.0",
reason="Use client-side caching feature instead.",
)
async def client_tracking_on(
self,
clientid: Optional[int] = None,
prefix: Sequence[KeyT] = [],
bcast: bool = False,
optin: bool = False,
optout: bool = False,
noloop: bool = False,
target_nodes: Optional["TargetNodesT"] = "all",
) -> Awaitable[ResponseT]:
"""
Enables the tracking feature of the Redis server, that is used
for server assisted client side caching.

When clientid is provided - in target_nodes only the node that owns the
connection with this id should be provided.
When clientid is not provided - target_nodes can be any node.

For more information see https://redis.io/commands/client-tracking
"""
return await self.client_tracking(
True,
clientid,
prefix,
bcast,
optin,
optout,
noloop,
target_nodes=target_nodes,
)

@deprecated_function(
version="7.2.0",
reason="Use client-side caching feature instead.",
)
async def client_tracking_off(
self,
clientid: Optional[int] = None,
prefix: Sequence[KeyT] = [],
bcast: bool = False,
optin: bool = False,
optout: bool = False,
noloop: bool = False,
target_nodes: Optional["TargetNodesT"] = "all",
) -> Awaitable[ResponseT]:
"""
Disables the tracking feature of the Redis server, that is used
for server assisted client side caching.

When clientid is provided - in target_nodes only the node that owns the
connection with this id should be provided.
When clientid is not provided - target_nodes can be any node.

For more information see https://redis.io/commands/client-tracking
"""
return await self.client_tracking(
False,
clientid,
prefix,
bcast,
optin,
optout,
noloop,
target_nodes=target_nodes,
)


class ClusterDataAccessCommands(DataAccessCommands):
"""
Expand Down
2 changes: 1 addition & 1 deletion redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def client_tracking(
if noloop:
pieces.append("NOLOOP")

return self.execute_command("CLIENT TRACKING", *pieces)
return self.execute_command("CLIENT TRACKING", *pieces, **kwargs)

def client_trackinginfo(self, **kwargs) -> ResponseT:
"""
Expand Down
24 changes: 24 additions & 0 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1742,6 +1742,30 @@ def test_client_trackinginfo(self, r):
assert len(res) > 2
assert "prefixes" in res or b"prefixes" in res

@skip_if_server_version_lt("6.0.0")
@skip_if_redis_enterprise()
def test_client_tracking(self, r):
# simple case - will execute on all node
assert r.client_tracking_on()
assert r.client_tracking_off()

# id based
node = r.get_default_node()
# when id is provided - the command should be sent to the node that
# owns the connection with this id
client_id = node.redis_connection.client_id()
assert r.client_tracking_on(clientid=client_id, target_nodes=node)
assert r.client_tracking_off(clientid=client_id, target_nodes=node)

# execute with client id and prefixes and bcast
assert r.client_tracking_on(
clientid=client_id, prefix=["foo", "bar"], bcast=True, target_nodes=node
)

# now with some prefixes and without bcast
with pytest.raises(DataError):
assert r.client_tracking_on(prefix=["foo", "bar", "blee"])

@skip_if_server_version_lt("2.9.50")
def test_client_pause(self, r):
node = r.get_primaries()[0]
Expand Down
Loading