Skip to content

Commit 0c3b7bb

Browse files
fix(router.py): handle edge case where user sets 'model_group' inside 'model_info' (BerriAI#10191)
* fix(router.py): handle edge case where user sets 'model_group' inside 'model_info' * fix(key_management_endpoints.py): security fix - return hashed token in 'token' field Ensures when creating a key on UI - only hashed token shown * test(test_key_management_endpoints.py): add unit test * test: update test
1 parent 03245c7 commit 0c3b7bb

File tree

5 files changed

+107
-22
lines changed

5 files changed

+107
-22
lines changed

litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -577,12 +577,16 @@ async def generate_key_fn( # noqa: PLR0915
577577
request_type="key", **data_json, table_name="key"
578578
)
579579

580-
response["soft_budget"] = (
581-
data.soft_budget
582-
) # include the user-input soft budget in the response
580+
response[
581+
"soft_budget"
582+
] = data.soft_budget # include the user-input soft budget in the response
583583

584584
response = GenerateKeyResponse(**response)
585585

586+
response.token = (
587+
response.token_id
588+
) # remap token to use the hash, and leave the key in the `key` field [TODO]: clean up generate_key_helper_fn to do this
589+
586590
asyncio.create_task(
587591
KeyManagementEventHooks.async_key_generated_hook(
588592
data=data,
@@ -1470,10 +1474,10 @@ async def delete_verification_tokens(
14701474
try:
14711475
if prisma_client:
14721476
tokens = [_hash_token_if_needed(token=key) for key in tokens]
1473-
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
1474-
await prisma_client.db.litellm_verificationtoken.find_many(
1475-
where={"token": {"in": tokens}}
1476-
)
1477+
_keys_being_deleted: List[
1478+
LiteLLM_VerificationToken
1479+
] = await prisma_client.db.litellm_verificationtoken.find_many(
1480+
where={"token": {"in": tokens}}
14771481
)
14781482

14791483
# Assuming 'db' is your Prisma Client instance
@@ -1575,9 +1579,9 @@ async def _rotate_master_key(
15751579
from litellm.proxy.proxy_server import proxy_config
15761580

15771581
try:
1578-
models: Optional[List] = (
1579-
await prisma_client.db.litellm_proxymodeltable.find_many()
1580-
)
1582+
models: Optional[
1583+
List
1584+
] = await prisma_client.db.litellm_proxymodeltable.find_many()
15811585
except Exception:
15821586
models = None
15831587
# 2. process model table
@@ -1864,11 +1868,11 @@ async def validate_key_list_check(
18641868
param="user_id",
18651869
code=status.HTTP_403_FORBIDDEN,
18661870
)
1867-
complete_user_info_db_obj: Optional[BaseModel] = (
1868-
await prisma_client.db.litellm_usertable.find_unique(
1869-
where={"user_id": user_api_key_dict.user_id},
1870-
include={"organization_memberships": True},
1871-
)
1871+
complete_user_info_db_obj: Optional[
1872+
BaseModel
1873+
] = await prisma_client.db.litellm_usertable.find_unique(
1874+
where={"user_id": user_api_key_dict.user_id},
1875+
include={"organization_memberships": True},
18721876
)
18731877

18741878
if complete_user_info_db_obj is None:
@@ -1929,10 +1933,10 @@ async def get_admin_team_ids(
19291933
if complete_user_info is None:
19301934
return []
19311935
# Get all teams that user is an admin of
1932-
teams: Optional[List[BaseModel]] = (
1933-
await prisma_client.db.litellm_teamtable.find_many(
1934-
where={"team_id": {"in": complete_user_info.teams}}
1935-
)
1936+
teams: Optional[
1937+
List[BaseModel]
1938+
] = await prisma_client.db.litellm_teamtable.find_many(
1939+
where={"team_id": {"in": complete_user_info.teams}}
19361940
)
19371941
if teams is None:
19381942
return []

litellm/router.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4983,8 +4983,12 @@ def _set_model_group_info( # noqa: PLR0915
49834983
)
49844984

49854985
if model_group_info is None:
4986-
model_group_info = ModelGroupInfo(
4987-
model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore
4986+
model_group_info = ModelGroupInfo( # type: ignore
4987+
**{
4988+
"model_group": user_facing_model_group_name,
4989+
"providers": [llm_provider],
4990+
**model_info,
4991+
}
49884992
)
49894993
else:
49904994
# if max_input_tokens > curr

tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,54 @@ async def test_list_keys():
4646
assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
4747
where_condition
4848
)
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_key_token_handling(monkeypatch):
53+
"""
54+
Test that token handling in key generation follows the expected behavior:
55+
1. token field should not equal key field
56+
2. if token_id exists, it should equal token field
57+
"""
58+
mock_prisma_client = AsyncMock()
59+
mock_insert_data = AsyncMock(
60+
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
61+
)
62+
mock_prisma_client.insert_data = mock_insert_data
63+
mock_prisma_client.db = MagicMock()
64+
mock_prisma_client.db.litellm_verificationtoken = MagicMock()
65+
mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
66+
return_value=None
67+
)
68+
mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock(
69+
return_value=[]
70+
)
71+
mock_prisma_client.db.litellm_verificationtoken.count = AsyncMock(return_value=0)
72+
mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock(
73+
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
74+
)
75+
76+
from litellm.proxy._types import GenerateKeyRequest, LitellmUserRoles
77+
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
78+
from litellm.proxy.management_endpoints.key_management_endpoints import (
79+
generate_key_fn,
80+
)
81+
from litellm.proxy.proxy_server import prisma_client
82+
83+
# Use monkeypatch to set the prisma_client
84+
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
85+
86+
# Test key generation
87+
response = await generate_key_fn(
88+
data=GenerateKeyRequest(),
89+
user_api_key_dict=UserAPIKeyAuth(
90+
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
91+
),
92+
)
93+
94+
# Verify token handling
95+
assert response.key != response.token, "Token should not equal key"
96+
if hasattr(response, "token_id"):
97+
assert (
98+
response.token == response.token_id
99+
), "Token should equal token_id if token_id exists"

tests/litellm/test_router.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,29 @@ def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata():
5252

5353
# 3) metadata lands under "metadata"
5454
assert kwargs["litellm_metadata"] == {"baz": 123}
55+
56+
57+
def test_router_with_model_info_and_model_group():
58+
"""
59+
Test edge case where user specifies model_group in model_info
60+
"""
61+
router = litellm.Router(
62+
model_list=[
63+
{
64+
"model_name": "gpt-3.5-turbo",
65+
"litellm_params": {
66+
"model": "gpt-3.5-turbo",
67+
},
68+
"model_info": {
69+
"tpm": 1000,
70+
"rpm": 1000,
71+
"model_group": "gpt-3.5-turbo",
72+
},
73+
}
74+
],
75+
)
76+
77+
router._set_model_group_info(
78+
model_group="gpt-3.5-turbo",
79+
user_facing_model_group_name="gpt-3.5-turbo",
80+
)

tests/store_model_in_db_tests/test_adding_passthrough_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def create_virtual_key():
142142
json={},
143143
)
144144
print(response.json())
145-
return response.json()["token"]
145+
return response.json()["key"]
146146

147147

148148
def add_assembly_ai_model_to_db(

0 commit comments

Comments
 (0)