Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
OSSNodeMigratedNotification,
OSSNodeMigratingNotification,
)
from redis.utils import safe_str

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
Expand Down Expand Up @@ -194,8 +195,9 @@ def parse_oss_maintenance_completed_msg(response):
# Expected message format is:
# SMIGRATED <seq_number> <host:port> <slot, range1-range2,...>
id = response[1]
node_address = response[2]
node_address = safe_str(response[2])
slots = response[3]

return OSSNodeMigratedNotification(id, node_address, slots)

@staticmethod
Expand Down Expand Up @@ -225,9 +227,7 @@ def parse_moving_msg(response):
if response[3] is None:
host, port = None, None
else:
value = response[3]
if isinstance(value, bytes):
value = value.decode()
value = safe_str(response[3])
host, port = value.split(":")
port = int(port) if port is not None else None

Expand Down
18 changes: 17 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@
from redis.lock import Lock
from redis.maint_notifications import (
MaintNotificationsConfig,
OSSMaintNotificationsHandler,
)
from redis.retry import Retry
from redis.utils import (
_set_info_logger,
check_protocol_version,
deprecated_args,
get_lib_version,
safe_str,
Expand Down Expand Up @@ -250,6 +252,9 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = None,
maint_notifications_config: Optional[MaintNotificationsConfig] = None,
oss_cluster_maint_notifications_handler: Optional[
OSSMaintNotificationsHandler
] = None,
) -> None:
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -288,6 +293,11 @@ def __init__(
will be enabled by default (logic is included in the connection pool
initialization).
Argument is ignored when connection_pool is provided.
oss_cluster_maint_notifications_handler:
handler for OSS cluster notifications - see
`redis.maint_notifications.OSSMaintNotificationsHandler` for details.
Only supported with RESP3
Argument is ignored when connection_pool is provided.
"""
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
Expand Down Expand Up @@ -357,7 +367,7 @@ def __init__(
"ssl_ciphers": ssl_ciphers,
}
)
if (cache_config or cache) and protocol in [3, "3"]:
if (cache_config or cache) and check_protocol_version(protocol, 3):
kwargs.update(
{
"cache": cache,
Expand All @@ -380,6 +390,12 @@ def __init__(
"maint_notifications_config": maint_notifications_config,
}
)
if oss_cluster_maint_notifications_handler:
kwargs.update(
{
"oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
}
)
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
Expand Down
142 changes: 125 additions & 17 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@
WatchError,
)
from redis.lock import Lock
from redis.maint_notifications import MaintNotificationsConfig
from redis.maint_notifications import (
MaintNotificationsConfig,
OSSMaintNotificationsHandler,
)
from redis.retry import Retry
from redis.utils import (
check_protocol_version,
deprecated_args,
dict_merge,
list_keys_to_dict,
Expand Down Expand Up @@ -214,6 +218,67 @@ def cleanup_kwargs(**kwargs):
return connection_kwargs


class MaintNotificationsAbstractRedisCluster:
"""
Abstract class for handling maintenance notifications logic.
This class is expected to be used as base class together with RedisCluster.

This class is intended to be used with multiple inheritance!

All logic related to maintenance notifications is encapsulated in this class.
"""

def __init__(
self,
maint_notifications_config: Optional[MaintNotificationsConfig],
**kwargs,
):
# Initialize maintenance notifications
is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3)

if (
maint_notifications_config
and maint_notifications_config.enabled
and not is_protocol_supported
):
raise RedisError(
"Maintenance notifications handlers on connection are only supported with RESP version 3"
)
if maint_notifications_config is None and is_protocol_supported:
maint_notifications_config = MaintNotificationsConfig()

self.maint_notifications_config = maint_notifications_config

if self.maint_notifications_config and self.maint_notifications_config.enabled:
self._oss_cluster_maint_notifications_handler = (
OSSMaintNotificationsHandler(self, self.maint_notifications_config)
)
# Update connection kwargs for all future nodes connections
self._update_connection_kwargs_for_maint_notifications(
self._oss_cluster_maint_notifications_handler
)
# Update existing nodes connections - they are created as part of the RedisCluster constructor
for node in self.get_nodes():
node.redis_connection.connection_pool.update_maint_notifications_config(
self.maint_notifications_config,
oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
)
else:
self._oss_cluster_maint_notifications_handler = None

def _update_connection_kwargs_for_maint_notifications(
self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
):
"""
Update the connection kwargs for all future connections.
"""
self.nodes_manager.connection_kwargs.update(
{
"oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
}
)


class AbstractRedisCluster:
RedisClusterRequestTTL = 16

Expand Down Expand Up @@ -461,7 +526,9 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
self.nodes_manager.default_node = random.choice(replicas)


class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
class RedisCluster(
AbstractRedisCluster, MaintNotificationsAbstractRedisCluster, RedisClusterCommands
):
@classmethod
def from_url(cls, url, **kwargs):
"""
Expand Down Expand Up @@ -612,8 +679,7 @@ def __init__(
`redis.maint_notifications.MaintNotificationsConfig` for details.
Only supported with RESP3.
If not provided and protocol is RESP3, the maintenance notifications
will be enabled by default (logic is included in the NodesManager
initialization).
will be enabled by default.
:**kwargs:
Extra arguments that will be sent into Redis instance when created
(See Official redis-py doc for supported kwargs - the only limitation
Expand Down Expand Up @@ -695,9 +761,16 @@ def __init__(
kwargs.get("decode_responses", False),
)
protocol = kwargs.get("protocol", None)
if (cache_config or cache) and protocol not in [3, "3"]:
if (cache_config or cache) and not check_protocol_version(protocol, 3):
raise RedisError("Client caching is only supported with RESP version 3")

if maint_notifications_config and not check_protocol_version(protocol, 3):
raise RedisError(
"Maintenance notifications are only supported with RESP version 3"
)
if check_protocol_version(protocol, 3) and maint_notifications_config is None:
maint_notifications_config = MaintNotificationsConfig()

self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.node_flags = self.__class__.NODE_FLAGS.copy()
self.read_from_replicas = read_from_replicas
Expand All @@ -709,6 +782,7 @@ def __init__(
else:
self._event_dispatcher = event_dispatcher
self.startup_nodes = startup_nodes

self.nodes_manager = NodesManager(
startup_nodes=startup_nodes,
from_url=from_url,
Expand Down Expand Up @@ -763,6 +837,10 @@ def __init__(
self._aggregate_nodes = None
self._lock = threading.RLock()

MaintNotificationsAbstractRedisCluster.__init__(
self, maint_notifications_config, **kwargs
)

def __enter__(self):
return self

Expand Down Expand Up @@ -1632,9 +1710,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
cache_factory: Optional[CacheFactoryInterface] = None,
event_dispatcher: Optional[EventDispatcher] = None,
maint_notifications_config: Optional[
MaintNotificationsConfig
] = MaintNotificationsConfig(),
maint_notifications_config: Optional[MaintNotificationsConfig] = None,
**kwargs,
):
self.nodes_cache: Dict[str, Redis] = {}
Expand Down Expand Up @@ -1879,11 +1955,29 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache):

return target_node

def initialize(self):
def initialize(
self,
additional_startup_nodes_info: List[Tuple[str, int]] = [],
disconnect_startup_nodes_pools: bool = True,
):
"""
Initializes the nodes cache, slots cache and redis connections.
:startup_nodes:
Responsible for discovering other nodes in the cluster
:disconnect_startup_nodes_pools:
Whether to disconnect the connection pool of the startup nodes
after the initialization is complete. This is useful when the
startup nodes are not part of the cluster and we want to avoid
keeping the connection open.
:additional_startup_nodes_info:
Additional nodes to add temporarily to the startup nodes.
The additional nodes will be used just in the process of extraction of the slots
and nodes information from the cluster.
This is useful when we want to add new nodes to the cluster
and initialize the client
with them.
The format of the list is a list of tuples, where each tuple contains
the host and port of the node.
"""
self.reset()
tmp_nodes_cache = {}
Expand All @@ -1893,9 +1987,25 @@ def initialize(self):
fully_covered = False
kwargs = self.connection_kwargs
exception = None

# Create cache if it's not provided and cache config is set
# should be done before initializing the first connection
# so that it will be applied to all connections
if self._cache is None and self._cache_config is not None:
if self._cache_factory is None:
self._cache = CacheFactory(self._cache_config).get_cache()
else:
self._cache = self._cache_factory.get_cache()

additional_startup_nodes = [
ClusterNode(host, port) for host, port in additional_startup_nodes_info
]
# Convert to tuple to prevent RuntimeError if self.startup_nodes
# is modified during iteration
for startup_node in tuple(self.startup_nodes.values()):
for startup_node in (
*self.startup_nodes.values(),
*additional_startup_nodes,
):
try:
if startup_node.redis_connection:
r = startup_node.redis_connection
Expand All @@ -1911,7 +2021,11 @@ def initialize(self):
# Make sure cluster mode is enabled on this node
try:
cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS"))
r.connection_pool.disconnect()
if disconnect_startup_nodes_pools:
# Disconnect the connection pool to avoid keeping the connection open
# For some cases we might not want to disconnect current pool and
# lose in flight commands responses
r.connection_pool.disconnect()
except ResponseError:
raise RedisClusterException(
"Cluster mode is not enabled on this node"
Expand Down Expand Up @@ -1992,12 +2106,6 @@ def initialize(self):
f"one reachable node: {str(exception)}"
) from exception

if self._cache is None and self._cache_config is not None:
if self._cache_factory is None:
self._cache = CacheFactory(self._cache_config).get_cache()
else:
self._cache = self._cache_factory.get_cache()

# Create Redis connections to all nodes
self.create_redis_connections(list(tmp_nodes_cache.values()))

Expand Down
Loading