
Commit 6b8b494

krrishdholakia, BardiaKh, and jcmorrow authored
Fix azure max retries error (BerriAI#8340)
* fix(azure.py): ensure max_retries=0 is respected

  Fixes BerriAI#6129

* fix(test_openai.py): add unit test to ensure openai sdk calls always respect max_retries = 0
* test(test_azure_openai.py): add unit testing for azure_text/ route
* fix(azure.py): fix passing max retries on streaming
* fix(azure.py): fix azure max retries on async completion + streaming
* fix(completion/handler.py): fix azure text async completion + streaming
* test(test_azure_openai.py): ensure azure openai max retries always respected
* test(test_azure_o_series.py): add testing to ensure max retries always respected
* Added gemini providers for 2.0-flash and 2.0-flash lite (BerriAI#8321)
  * Update model_prices_and_context_window.json: added gemini providers for 2.0-flash and 2.0-flash lite
  * Update model_prices_and_context_window.json: fixed URL
* Convert tool use arguments to string before counting tokens (BerriAI#6989)

  In at least some cases `messages["tool_calls"]["function"]["arguments"]` is a dict, not a string. To tokenize it properly it needs to be a string; when it is already a string this is a no-op, which is also fine.

* build(model_prices_and_context_window.json): add gemini 2.0 flash lite pricing
* build(model_prices_and_context_window.json): add gemini commercial rate limits
* fix(utils.py): fix linting error
* refactor(utils.py): refactor to maintain function size

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
Co-authored-by: Bardia Khosravi <bardiakhosravi95@gmail.com>
Co-authored-by: Josh Morrow <josh@jcmorrow.com>
1 parent d720744 commit 6b8b494
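
What the headline fix enables: a caller-supplied max_retries of 0 now reaches the underlying Azure OpenAI SDK client instead of being silently replaced by the old hard-coded default of 2. A minimal usage sketch (the deployment name is illustrative, taken from the tests below, and credentials are assumed to be configured via environment variables):

    import litellm

    # With this fix, max_retries=0 disables SDK-level retries entirely;
    # previously the Azure code path fell back to a hard-coded default of 2.
    response = litellm.completion(
        model="azure/gpt-4o",  # illustrative Azure deployment name
        messages=[{"role": "user", "content": "hi"}],
        max_retries=0,
    )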

File tree

7 files changed: +176, -28 lines changed


litellm/llms/azure/azure.py

Lines changed: 11 additions & 21 deletions

@@ -9,6 +9,7 @@
 import litellm
 from litellm.caching.caching import DualCache
+from litellm.constants import DEFAULT_MAX_RETRIES
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -98,14 +99,6 @@ def map_openai_params_create_message_params(


 def select_azure_base_url_or_endpoint(azure_client_params: dict):
-    # azure_client_params = {
-    #     "api_version": api_version,
-    #     "azure_endpoint": api_base,
-    #     "azure_deployment": model,
-    #     "http_client": litellm.client_session,
-    #     "max_retries": max_retries,
-    #     "timeout": timeout,
-    # }
     azure_endpoint = azure_client_params.get("azure_endpoint", None)
     if azure_endpoint is not None:
         # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
@@ -353,7 +346,9 @@ def completion(  # noqa: PLR0915
                 status_code=422, message="Missing model or messages"
             )

-        max_retries = optional_params.pop("max_retries", 2)
+        max_retries = optional_params.pop("max_retries", None)
+        if max_retries is None:
+            max_retries = DEFAULT_MAX_RETRIES
         json_mode: Optional[bool] = optional_params.pop("json_mode", False)

         ### CHECK IF CLOUDFLARE AI GATEWAY ###
@@ -415,6 +410,7 @@ def completion(  # noqa: PLR0915
                     azure_ad_token_provider=azure_ad_token_provider,
                     timeout=timeout,
                     client=client,
+                    max_retries=max_retries,
                 )
             else:
                 return self.acompletion(
@@ -430,6 +426,7 @@ def completion(  # noqa: PLR0915
                     timeout=timeout,
                     client=client,
                     logging_obj=logging_obj,
+                    max_retries=max_retries,
                     convert_tool_call_to_json_mode=json_mode,
                 )
         elif "stream" in optional_params and optional_params["stream"] is True:
@@ -445,6 +442,7 @@ def completion(  # noqa: PLR0915
                 azure_ad_token_provider=azure_ad_token_provider,
                 timeout=timeout,
                 client=client,
+                max_retries=max_retries,
             )
         else:
             ## LOGGING
@@ -553,19 +551,14 @@ async def acompletion(
         dynamic_params: bool,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
+        max_retries: int,
         azure_ad_token: Optional[str] = None,
         azure_ad_token_provider: Optional[Callable] = None,
         convert_tool_call_to_json_mode: Optional[bool] = None,
         client=None,  # this is the AsyncAzureOpenAI
     ):
         response = None
         try:
-            max_retries = data.pop("max_retries", 2)
-            if not isinstance(max_retries, int):
-                raise AzureOpenAIError(
-                    status_code=422, message="max retries must be an int"
-                )
-
             # init AzureOpenAI Client
             azure_client_params = {
                 "api_version": api_version,
@@ -671,15 +664,11 @@ def streaming(
         data: dict,
         model: str,
         timeout: Any,
+        max_retries: int,
         azure_ad_token: Optional[str] = None,
         azure_ad_token_provider: Optional[Callable] = None,
         client=None,
     ):
-        max_retries = data.pop("max_retries", 2)
-        if not isinstance(max_retries, int):
-            raise AzureOpenAIError(
-                status_code=422, message="max retries must be an int"
-            )
         # init AzureOpenAI Client
         azure_client_params = {
             "api_version": api_version,
@@ -742,6 +731,7 @@ async def async_streaming(
         data: dict,
         model: str,
         timeout: Any,
+        max_retries: int,
         azure_ad_token: Optional[str] = None,
         azure_ad_token_provider: Optional[Callable] = None,
         client=None,
@@ -753,7 +743,7 @@ async def async_streaming(
             "azure_endpoint": api_base,
             "azure_deployment": model,
             "http_client": litellm.aclient_session,
-            "max_retries": data.pop("max_retries", 2),
+            "max_retries": max_retries,
             "timeout": timeout,
         }
         azure_client_params = select_azure_base_url_or_endpoint(
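
The thread running through these hunks: max_retries is resolved once at the synchronous entry point (popped from optional_params, falling back to DEFAULT_MAX_RETRIES only when the caller passed nothing) and is then handed to acompletion/streaming/async_streaming as an explicit argument, instead of each of those re-popping it from `data` with a hard-coded default of 2. A hedged sketch of the resolution step in isolation (the constant's value of 2 is an assumption mirroring the old default):

    DEFAULT_MAX_RETRIES = 2  # assumption: stands in for litellm.constants.DEFAULT_MAX_RETRIES

    def resolve_max_retries(optional_params: dict) -> int:
        # None means "caller did not specify"; 0 is a legitimate value
        # and must not be swallowed by the default.
        max_retries = optional_params.pop("max_retries", None)
        if max_retries is None:
            max_retries = DEFAULT_MAX_RETRIES
        return max_retries

    assert resolve_max_retries({"max_retries": 0}) == 0          # explicit 0 respected
    assert resolve_max_retries({}) == DEFAULT_MAX_RETRIES        # unset -> default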

litellm/llms/azure/completion/handler.py

Lines changed: 2 additions & 6 deletions

@@ -131,6 +131,7 @@ def completion(  # noqa: PLR0915
                 timeout=timeout,
                 client=client,
                 logging_obj=logging_obj,
+                max_retries=max_retries,
             )
         elif "stream" in optional_params and optional_params["stream"] is True:
             return self.streaming(
@@ -236,17 +237,12 @@ async def acompletion(
         timeout: Any,
         model_response: ModelResponse,
         logging_obj: Any,
+        max_retries: int,
         azure_ad_token: Optional[str] = None,
         client=None,  # this is the AsyncAzureOpenAI
     ):
         response = None
         try:
-            max_retries = data.pop("max_retries", 2)
-            if not isinstance(max_retries, int):
-                raise AzureOpenAIError(
-                    status_code=422, message="max retries must be an int"
-                )
-
             # init AzureOpenAI Client
             azure_client_params = {
                 "api_version": api_version,

litellm/main.py

Lines changed: 2 additions & 0 deletions

@@ -1222,6 +1222,8 @@ def completion(  # type: ignore # noqa: PLR0915

         if extra_headers is not None:
             optional_params["extra_headers"] = extra_headers
+        if max_retries is not None:
+            optional_params["max_retries"] = max_retries

         if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
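
This two-line change is the missing link: completion() previously never copied the top-level max_retries kwarg into optional_params, so the Azure handler's pop always hit its default. The is-not-None guard (rather than a truthiness check) is what lets an explicit 0 through. A simplified, hypothetical reduction of that plumbing:

    def build_optional_params(max_retries=None, extra_headers=None) -> dict:
        # Hypothetical condensation of litellm.main.completion's kwarg handling.
        optional_params: dict = {}
        if extra_headers is not None:
            optional_params["extra_headers"] = extra_headers
        if max_retries is not None:  # `if max_retries:` would drop a falsy 0
            optional_params["max_retries"] = max_retries
        return optional_params

    assert build_optional_params(max_retries=0) == {"max_retries": 0}
    assert build_optional_params() == {}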

litellm/utils.py

Lines changed: 4 additions & 1 deletion

@@ -1785,7 +1785,10 @@ def token_counter(
                 for tool_call in message["tool_calls"]:
                     if "function" in tool_call:
                         function_arguments = tool_call["function"]["arguments"]
-                        text += function_arguments
+                        text = (
+                            text if isinstance(text, str) else "".join(text or [])
+                        ) + (str(function_arguments) if function_arguments else "")
+
         else:
             raise ValueError("text and messages cannot both be None")
     elif isinstance(text, List):
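
For context, `messages["tool_calls"]["function"]["arguments"]` can arrive either as a JSON string or as an already-parsed dict (the BerriAI#6989 report), and token counting needs a string either way. A standalone sketch of the normalization this hunk performs, with simplified names rather than the exact utils.py code:

    from typing import Any, List, Optional, Union

    def append_tool_arguments(
        text: Optional[Union[str, List[str]]], function_arguments: Any
    ) -> str:
        # Normalize `text` to a str (it may arrive as None or a list of
        # strings), then append the arguments, stringified in case the
        # provider returned them as an already-parsed dict.
        base = text if isinstance(text, str) else "".join(text or [])
        return base + (str(function_arguments) if function_arguments else "")

    assert append_tool_arguments(None, {"location": "Boston"}) == "{'location': 'Boston'}"
    assert append_tool_arguments("prefix ", "raw json") == "prefix raw json"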

tests/llm_translation/test_azure_o_series.py

Lines changed: 15 additions & 0 deletions

@@ -152,3 +152,18 @@ def test_azure_o_series_routing():
         print(e)
     assert mock_create.call_count == 1
     assert "stream" not in mock_create.call_args.kwargs
+
+
+@patch("litellm.main.azure_o1_chat_completions._get_openai_client")
+def test_openai_o_series_max_retries_0(mock_get_openai_client):
+    import litellm
+
+    litellm.set_verbose = True
+    response = litellm.completion(
+        model="azure/o1-preview",
+        messages=[{"role": "user", "content": "hi"}],
+        max_retries=0,
+    )
+
+    mock_get_openai_client.assert_called_once()
+    assert mock_get_openai_client.call_args.kwargs["max_retries"] == 0

tests/llm_translation/test_azure_openai.py

Lines changed: 127 additions & 0 deletions

@@ -436,3 +436,130 @@ def test_map_openai_params():
     optional_params = azure_openai_config.map_openai_params(**received_args)
     assert "tools" in optional_params
     assert len(optional_params["tools"]) > 1
+
+
+@pytest.mark.parametrize("max_retries", [0, 4])
+@pytest.mark.parametrize("stream", [True, False])
+@patch(
+    "litellm.main.azure_chat_completions.make_sync_azure_openai_chat_completion_request"
+)
+def test_azure_max_retries_0(
+    mock_make_sync_azure_openai_chat_completion_request, max_retries, stream
+):
+    from litellm import completion
+
+    try:
+        completion(
+            model="azure/gpt-4o",
+            messages=[{"role": "user", "content": "Hello world"}],
+            max_retries=max_retries,
+            stream=stream,
+        )
+    except Exception as e:
+        print(e)
+
+    mock_make_sync_azure_openai_chat_completion_request.assert_called_once()
+    assert (
+        mock_make_sync_azure_openai_chat_completion_request.call_args.kwargs[
+            "azure_client"
+        ].max_retries
+        == max_retries
+    )
+
+
+@pytest.mark.parametrize("max_retries", [0, 4])
+@pytest.mark.parametrize("stream", [True, False])
+@patch("litellm.main.azure_chat_completions.make_azure_openai_chat_completion_request")
+@pytest.mark.asyncio
+async def test_async_azure_max_retries_0(
+    make_azure_openai_chat_completion_request, max_retries, stream
+):
+    from litellm import acompletion
+
+    try:
+        await acompletion(
+            model="azure/gpt-4o",
+            messages=[{"role": "user", "content": "Hello world"}],
+            max_retries=max_retries,
+            stream=stream,
+        )
+    except Exception as e:
+        print(e)
+
+    make_azure_openai_chat_completion_request.assert_called_once()
+    assert (
+        make_azure_openai_chat_completion_request.call_args.kwargs[
+            "azure_client"
+        ].max_retries
+        == max_retries
+    )
+
+
+@pytest.mark.parametrize("max_retries", [0, 4])
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.parametrize("sync_mode", [True, False])
+@patch("litellm.llms.azure.completion.handler.select_azure_base_url_or_endpoint")
+@pytest.mark.asyncio
+async def test_azure_instruct(
+    mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
+):
+    from litellm import completion, acompletion
+
+    args = {
+        "model": "azure_text/instruct-model",
+        "messages": [
+            {"role": "user", "content": "What is the weather like in Boston?"}
+        ],
+        "max_tokens": 10,
+        "max_retries": max_retries,
+    }
+
+    try:
+        if sync_mode:
+            completion(**args)
+        else:
+            await acompletion(**args)
+    except Exception:
+        pass
+
+    mock_select_azure_base_url_or_endpoint.assert_called_once()
+    assert (
+        mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
+            "max_retries"
+        ]
+        == max_retries
+    )
+
+
+@pytest.mark.parametrize("max_retries", [0, 4])
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.parametrize("sync_mode", [True, False])
+@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
+@pytest.mark.asyncio
+async def test_azure_embedding_max_retries_0(
+    mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
+):
+    from litellm import aembedding, embedding
+
+    args = {
+        "model": "azure/azure-embedding-model",
+        "input": "Hello world",
+        "max_retries": max_retries,
+        "stream": stream,
+    }
+
+    try:
+        if sync_mode:
+            embedding(**args)
+        else:
+            await aembedding(**args)
+    except Exception as e:
+        print(e)
+
+    mock_select_azure_base_url_or_endpoint.assert_called_once()
+    assert (
+        mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
+            "max_retries"
+        ]
+        == max_retries
+    )
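
All four tests above share one mechanism: patch a function on the client-construction path, let the call run (and possibly fail) inside try/except, then inspect call_args.kwargs at the patch point. A minimal, litellm-free illustration of that mock-inspection idiom:

    from unittest.mock import MagicMock

    # Stand-in for a patched hook such as select_azure_base_url_or_endpoint.
    hook = MagicMock()
    hook(azure_client_params={"max_retries": 0, "timeout": 600})

    # Even if the surrounding request later fails (the mock returns a
    # MagicMock, not a usable client), the kwargs captured here still
    # prove what the caller forwarded.
    hook.assert_called_once()
    assert hook.call_args.kwargs["azure_client_params"]["max_retries"] == 0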

tests/llm_translation/test_openai.py

Lines changed: 15 additions & 0 deletions

@@ -314,3 +314,18 @@ def test_completion_bad_org():
         os.environ["OPENAI_ORGANIZATION"] = _old_org
     else:
         del os.environ["OPENAI_ORGANIZATION"]
+
+
+@patch("litellm.main.openai_chat_completions._get_openai_client")
+def test_openai_max_retries_0(mock_get_openai_client):
+    import litellm
+
+    litellm.set_verbose = True
+    response = litellm.completion(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi"}],
+        max_retries=0,
+    )
+
+    mock_get_openai_client.assert_called_once()
+    assert mock_get_openai_client.call_args.kwargs["max_retries"] == 0
