Skip to content

Enable specifying load-balancing strategy on per-command basis #3598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
REPLICA,
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
LoadBalancingStrategy,
block_pipeline_command,
get_node_name,
parse_cluster_slots,
Expand All @@ -63,6 +61,7 @@
TimeoutError,
TryAgainError,
)
from redis.load_balancer import LoadBalancer, LoadBalancingStrategy
from redis.typing import AnyKeyT, EncodableT, KeyT
from redis.utils import (
SSL_AVAILABLE,
Expand Down
50 changes: 1 addition & 49 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import threading
import time
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from redis._parsers import CommandsParser, Encoder
Expand Down Expand Up @@ -37,6 +36,7 @@
TimeoutError,
TryAgainError,
)
from redis.load_balancer import LoadBalancer, LoadBalancingStrategy
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import (
Expand Down Expand Up @@ -1328,54 +1328,6 @@ def __del__(self):
self.redis_connection.close()


class LoadBalancingStrategy(Enum):
ROUND_ROBIN = "round_robin"
ROUND_ROBIN_REPLICAS = "round_robin_replicas"
RANDOM_REPLICA = "random_replica"


class LoadBalancer:
"""
Round-Robin Load Balancing
"""

def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index

def get_server_index(
self,
primary: str,
list_size: int,
load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,
) -> int:
if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA:
return self._get_random_replica_index(list_size)
else:
return self._get_round_robin_index(
primary,
list_size,
load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
)

def reset(self) -> None:
self.primary_to_idx.clear()

def _get_random_replica_index(self, list_size: int) -> int:
return random.randint(1, list_size - 1)

def _get_round_robin_index(
self, primary: str, list_size: int, replicas_only: bool
) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
if replicas_only and server_index == 0:
# skip the primary node index
server_index = 1
# Update the index for the next round
self.primary_to_idx[primary] = (server_index + 1) % list_size
return server_index


class NodesManager:
def __init__(
self,
Expand Down
17 changes: 9 additions & 8 deletions redis/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.load_balancer import LoadBalancingStrategy
from redis.typing import (
AnyKeyT,
ClusterCommandsProtocol,
Expand Down Expand Up @@ -124,7 +125,7 @@ def _partition_pairs_by_slot(
return slots_to_pairs

def _execute_pipeline_by_slot(
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None
) -> List[Any]:
read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
pipe = self.pipeline()
Expand All @@ -133,7 +134,7 @@ def _execute_pipeline_by_slot(
command,
*slot_args,
target_nodes=[
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy)
],
)
for slot, slot_args in slots_to_args.items()
Expand All @@ -153,7 +154,7 @@ def _reorder_keys_by_command(
}
return [results[key] for key in keys]

def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
Expand All @@ -171,7 +172,7 @@ def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
slots_to_keys = self._partition_keys_by_slot(keys)

# Execute commands using a pipeline
res = self._execute_pipeline_by_slot("MGET", slots_to_keys)
res = self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy)

# Reorder keys in the order the user provided & return
return self._reorder_keys_by_command(keys, slots_to_keys, res)
Expand Down Expand Up @@ -265,7 +266,7 @@ class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
A class containing commands that handle more than one key
"""

async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
async def mget_nonatomic(self, keys: KeysT, *args: KeyT, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
Expand All @@ -283,7 +284,7 @@ async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
slots_to_keys = self._partition_keys_by_slot(keys)

# Execute commands using a pipeline
res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)
res = await self._execute_pipeline_by_slot("MGET", slots_to_keys, load_balancing_strategy=load_balancing_strategy)

# Reorder keys in the order the user provided & return
return self._reorder_keys_by_command(keys, slots_to_keys, res)
Expand Down Expand Up @@ -320,7 +321,7 @@ async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))

async def _execute_pipeline_by_slot(
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]], *, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None
) -> List[Any]:
if self._initialize:
await self.initialize()
Expand All @@ -331,7 +332,7 @@ async def _execute_pipeline_by_slot(
command,
*slot_args,
target_nodes=[
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
self.nodes_manager.get_node_from_slot(slot, read_from_replicas, load_balancing_strategy)
],
)
for slot, slot_args in slots_to_args.items()
Expand Down
49 changes: 49 additions & 0 deletions redis/load_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from enum import Enum


class LoadBalancingStrategy(Enum):
ROUND_ROBIN = "round_robin"
ROUND_ROBIN_REPLICAS = "round_robin_replicas"
RANDOM_REPLICA = "random_replica"


class LoadBalancer:
"""
Round-Robin Load Balancing
"""

def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index

def get_server_index(
self,
primary: str,
list_size: int,
load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN,
) -> int:
if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA:
return self._get_random_replica_index(list_size)
else:
return self._get_round_robin_index(
primary,
list_size,
load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS,
)

def reset(self) -> None:
self.primary_to_idx.clear()

def _get_random_replica_index(self, list_size: int) -> int:
return random.randint(1, list_size - 1)

def _get_round_robin_index(
self, primary: str, list_size: int, replicas_only: bool
) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
if replicas_only and server_index == 0:
# skip the primary node index
server_index = 1
# Update the index for the next round
self.primary_to_idx[primary] = (server_index + 1) % list_size
return server_index
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
PIPELINE_BLOCKED_COMMANDS,
PRIMARY,
REPLICA,
LoadBalancingStrategy,
get_node_name,
)
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
Expand All @@ -34,6 +33,7 @@
RedisError,
ResponseError,
)
from redis.load_balancer import LoadBalancingStrategy
from redis.utils import str_if_bytes
from tests.conftest import (
assert_resp_response,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
REDIS_CLUSTER_HASH_SLOTS,
REPLICA,
ClusterNode,
LoadBalancingStrategy,
NodesManager,
RedisCluster,
get_node_name,
Expand All @@ -39,6 +38,7 @@
ResponseError,
TimeoutError,
)
from redis.load_balancer import LoadBalancingStrategy
from redis.retry import Retry
from redis.utils import str_if_bytes
from tests.test_pubsub import wait_for_message
Expand Down