|
36 | 36 | from typing import List, Dict, Optional, Any |
37 | 37 |
|
38 | 38 | from llm_router_api.base.constants import REDIS_PORT, REDIS_HOST |
39 | | -from llm_router_api.base.lb.strategy import ChooseProviderStrategyI |
40 | 39 | from llm_router_api.base.lb.provider_monitor import RedisProviderMonitor |
| 40 | +from llm_router_api.base.lb.first_available_i import FirstAvailableStrategyI |
41 | 41 |
|
42 | 42 |
|
43 | | -class FirstAvailableStrategy(ChooseProviderStrategyI): |
| 43 | +class FirstAvailableStrategy(FirstAvailableStrategyI): |
44 | 44 | """ |
45 | 45 | Strategy that selects the first free provider for a model using Redis. |
46 | 46 |
|
@@ -87,53 +87,16 @@ def __init__( |
87 | 87 | clear_buffers: |
88 | 88 | Whether to clear all buffers when starting. Default is ``True``. |
89 | 89 | """ |
90 | | - if not REDIS_IS_AVAILABLE: |
91 | | - raise RuntimeError("Redis is not available. Please install it first.") |
92 | | - |
93 | | - super().__init__(models_config_path=models_config_path, logger=logger) |
94 | | - |
95 | | - self.redis_client = redis.Redis( |
96 | | - host=redis_host, port=redis_port, db=redis_db, decode_responses=True |
97 | | - ) |
98 | | - self.timeout = timeout |
99 | | - self.check_interval = check_interval |
100 | | - |
101 | | - # Atomic acquire script – treat missing field as “available” |
102 | | - self._acquire_script = self.redis_client.register_script( |
103 | | - """ |
104 | | - local redis_key = KEYS[1] |
105 | | - local field = ARGV[1] |
106 | | - local v = redis.call('HGET', redis_key, field) |
107 | | - -- v == false -> field does not exist (nil) |
108 | | - -- v == 'false' -> explicitly marked as free |
109 | | - if v == false or v == 'false' then |
110 | | - redis.call('HSET', redis_key, field, 'true') |
111 | | - return 1 |
112 | | - end |
113 | | - return 0 |
114 | | - """ |
115 | | - ) |
116 | | - |
117 | | - # Atomic release script – simply delete the field (no race condition) |
118 | | - self._release_script = self.redis_client.register_script( |
119 | | - """ |
120 | | - local redis_key = KEYS[1] |
121 | | - local field = ARGV[1] |
122 | | - -- Delete the field; returns 1 if field existed, 0 otherwise |
123 | | - redis.call('HDEL', redis_key, field) |
124 | | - return 1 |
125 | | - """ |
126 | | - ) |
127 | | - |
128 | | - if clear_buffers: |
129 | | - self._clear_buffers() |
130 | | - |
131 | | - # Start providers monitor |
132 | | - self._monitor = RedisProviderMonitor( |
133 | | - redis_client=self.redis_client, |
134 | | - check_interval=30, |
| 90 | + super().__init__( |
| 91 | + models_config_path=models_config_path, |
| 92 | + redis_host=redis_host, |
| 93 | + redis_port=redis_port, |
| 94 | + redis_db=redis_db, |
| 95 | + timeout=timeout, |
| 96 | + check_interval=check_interval, |
135 | 97 | clear_buffers=clear_buffers, |
136 | | - logger=self.logger, |
| 98 | + logger=logger, |
| 99 | + strategy_prefix="fa_", |
137 | 100 | ) |
138 | 101 |
|
139 | 102 | def get_provider( |
@@ -185,24 +148,14 @@ def get_provider( |
185 | 148 | * Call :meth:`put_provider` to release the lock once the provider is no |
186 | 149 | longer needed. |
187 | 150 | """ |
188 | | - if not providers: |
189 | | - return None |
190 | 151 |
|
191 | | - # Register providers for monitoring (only once per model) |
192 | | - self._monitor.add_providers(model_name, providers) |
| 152 | + redis_key, is_random = self.init_provider( |
| 153 | + model_name=model_name, providers=providers, options=options |
| 154 | + ) |
| 155 | + if not redis_key: |
| 156 | + return None |
193 | 157 |
|
194 | | - redis_key = self._get_redis_key(model_name) |
195 | 158 | start_time = time.time() |
196 | | - |
197 | | - # Ensure fields exist; if someone removed the hash, recreate it |
198 | | - if not self.redis_client.exists(redis_key): |
199 | | - for p in providers: |
200 | | - self.redis_client.hset(redis_key, self._provider_field(p), "false") |
201 | | - |
202 | | - # self._print_provider_status(redis_key, providers) |
203 | | - |
204 | | - is_random = options and options.get("random_choice", False) |
205 | | - |
206 | 159 | while True: |
207 | 160 | _providers = self._get_active_providers( |
208 | 161 | model_name=model_name, providers=providers |
@@ -348,145 +301,3 @@ def _get_active_providers( |
348 | 301 | model_name=model_name, only_active=True |
349 | 302 | ) |
350 | 303 | return active_providers |
351 | | - |
352 | | - def _get_redis_key(self, model_name: str) -> str: |
353 | | - """ |
354 | | - Return Redis key prefix for a given model. |
355 | | - """ |
356 | | - for ch in self.REPLACE_PROVIDER_KEY: |
357 | | - model_name = model_name.replace(ch, "_") |
358 | | - return f"model:{model_name}" |
359 | | - |
360 | | - def _provider_field(self, provider: dict) -> str: |
361 | | - """ |
362 | | - Build the Redis hash field name that stores the chosen flag |
363 | | - for a given provider. |
364 | | -
|
365 | | - Parameters |
366 | | - ---------- |
367 | | - provider : dict |
368 | | - Provider configuration dictionary. |
369 | | -
|
370 | | - Returns |
371 | | - ------- |
372 | | - str |
373 | | - Field name in the format ``{provider_id}:is_chosen``. |
374 | | - """ |
375 | | - provider_id = self._provider_key(provider) |
376 | | - return f"{provider_id}:is_chosen" |
377 | | - |
378 | | - def _init_flag(self, model_name: str) -> str: |
379 | | - """ |
380 | | - Build the Redis key used as an initialization flag for a model. |
381 | | -
|
382 | | - Parameters |
383 | | - ---------- |
384 | | - model_name : str |
385 | | - Name of the model. |
386 | | -
|
387 | | - Returns |
388 | | - ------- |
389 | | - str |
390 | | - Flag key in the format ``model:{model_name}:initialized``. |
391 | | - """ |
392 | | - return f"{self._get_redis_key(model_name)}:initialized" |
393 | | - |
394 | | - def _initialize_providers(self, model_name: str, providers: List[Dict]) -> None: |
395 | | - """ |
396 | | - Ensure that the provider lock fields for *model_name* exist in Redis. |
397 | | -
|
398 | | - This method is idempotent – it will create the hash fields only if the |
399 | | - model has not been initialized before. An auxiliary flag key |
400 | | - ``model:{model_name}:initialized`` is used to guard against repeated |
401 | | - initialization, which could otherwise overwrite the current lock state |
402 | | - of providers that are already in use. |
403 | | -
|
404 | | - Parameters |
405 | | - ---------- |
406 | | - model_name : str |
407 | | - The name of the model whose providers are being prepared. |
408 | | - providers : List[Dict] |
409 | | - A list of provider configuration dictionaries. Each dictionary must |
410 | | - contain enough information for :meth:`_provider_field` to generate a |
411 | | - unique field name. |
412 | | -
|
413 | | - Notes |
414 | | - ----- |
415 | | - * The provider fields are stored in a Redis hash whose key is |
416 | | - ``model:{model_name}``. Each field is set to the string ``'false'`` |
417 | | - to indicate that the provider is currently free. |
418 | | - * The initialization flag is a simple Redis key with value ``'1'``. |
419 | | - Its existence signals that the hash has already been populated. |
420 | | - """ |
421 | | - redis_key = self._get_redis_key(model_name) |
422 | | - |
423 | | - # Check if already initialized using a flag |
424 | | - init_flag = self._init_flag(model_name) |
425 | | - if self.redis_client.exists(init_flag): |
426 | | - return |
427 | | - |
428 | | - # Initialize all providers as available |
429 | | - for provider in providers: |
430 | | - provider_field = self._provider_field(provider) |
431 | | - self.redis_client.hset(redis_key, provider_field, "false") |
432 | | - |
433 | | - # Set initialization flag |
434 | | - self.redis_client.set(init_flag, "1") |
435 | | - |
436 | | - def _clear_buffers(self) -> None: |
437 | | - """ |
438 | | - Reset the Redis state for all active models. |
439 | | -
|
440 | | - This method removes any existing initialization flags and provider |
441 | | - lock fields, then re‑initialises the providers as available. It is |
442 | | - typically invoked during strategy start‑up to ensure a clean slate. |
443 | | - """ |
444 | | - active_models = self._api_model_config.active_models |
445 | | - models_configs = self._api_model_config.models_configs |
446 | | - for _, models_names in active_models.items(): |
447 | | - for model_name in models_names: |
448 | | - redis_key = self._get_redis_key(model_name) |
449 | | - providers = models_configs[model_name]["providers"] |
450 | | - if len(providers) > 0: |
451 | | - model_path = providers[0].get("model_path", "").strip() |
452 | | - if model_path: |
453 | | - model_name = model_path |
454 | | - |
455 | | - init_flag = self._init_flag(model_name) |
456 | | - self.redis_client.delete(init_flag) |
457 | | - |
458 | | - for provider in providers: |
459 | | - provider_field = self._provider_field(provider) |
460 | | - self.redis_client.hset(redis_key, provider_field, "false") |
461 | | - |
462 | | - self._initialize_providers( |
463 | | - model_name=model_name, providers=providers |
464 | | - ) |
465 | | - |
466 | | - def _print_provider_status(self, redis_key: str, providers: List[Dict]) -> None: |
467 | | - """ |
468 | | - Print the lock status of each provider stored in the Redis hash |
469 | | - ``redis_key``. Uses emojis for a quick visual cue: |
470 | | -
|
471 | | - * 🟢 – provider is free (`'false'` or missing) |
472 | | - * 🔴 – provider is currently taken (`'true'`) |
473 | | -
|
474 | | - The output is formatted in a table‑like layout for readability. |
475 | | - """ |
476 | | - try: |
477 | | - # Retrieve the entire hash; missing fields default to None |
478 | | - hash_data = self.redis_client.hgetall(redis_key) |
479 | | - except Exception as exc: |
480 | | - print(f"[⚠️] Could not read Redis key '{redis_key}': {exc}") |
481 | | - return |
482 | | - |
483 | | - print("\nProvider lock status:") |
484 | | - print("-" * 40) |
485 | | - for provider in providers: |
486 | | - field = self._provider_field(provider) |
487 | | - status = hash_data.get(field, "false") |
488 | | - icon = "🔴" if status == "true" else "🟢" |
489 | | - # Show a short identifier for the provider (fallback to field) |
490 | | - provider_id = provider.get("id") or provider.get("name") or field |
491 | | - print(f"{icon} {provider_id:<30} [{field}]") |
492 | | - print("-" * 40) |
0 commit comments