# Insert Instructions at the end of System Prompt #3614
**Groq mapper** (`_map_messages`):

```diff
@@ -395,6 +395,9 @@ def _map_messages(
     ) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         groq_messages: list[chat.ChatCompletionMessageParam] = []
+        system_prompt_count = sum(
+            1 for m in messages if isinstance(m, ModelRequest) for p in m.parts if isinstance(p, SystemPromptPart)
+        )
         for message in messages:
             if isinstance(message, ModelRequest):
                 groq_messages.extend(self._map_user_message(message))
@@ -428,7 +431,9 @@ def _map_messages(
             else:
                 assert_never(message)
         if instructions := self._get_instructions(messages, model_request_parameters):
-            groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
+            groq_messages.insert(
+                system_prompt_count, chat.ChatCompletionSystemMessageParam(role='system', content=instructions)
+            )
         return groq_messages

     @staticmethod
```

> **Collaborator** (on the `system_prompt_count` lines): Same issue as up.
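To make the repeated concern concrete, here is a minimal sketch (hypothetical history; `ModelRequest`, `SystemPromptPart`, and `UserPromptPart` are pydantic-ai's public message types) of how a count derived from the Pydantic AI history can point at the wrong slot in the provider-format list being built:

```python
from pydantic_ai.messages import ModelRequest, SystemPromptPart, UserPromptPart

# A history where the system prompt is not the first part: the first
# request is user-only, and the system prompt arrives in a later request.
messages = [
    ModelRequest(parts=[UserPromptPart(content='Hi')]),
    ModelRequest(parts=[SystemPromptPart(content='Be terse'), UserPromptPart(content='Hi again')]),
]

# The counting expression from this PR:
system_prompt_count = sum(
    1 for m in messages if isinstance(m, ModelRequest) for p in m.parts if isinstance(p, SystemPromptPart)
)
assert system_prompt_count == 1

# The provider-format list built from this history starts with a *user*
# message, so inserting the instructions at index 1 places them after that
# user message rather than directly after the system prompt.
```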
**Hugging Face mapper** (`_map_messages`):

```diff
@@ -327,6 +327,9 @@ async def _map_messages(
     ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
         """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
         hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
+        system_prompt_count = sum(
+            1 for m in messages if isinstance(m, ModelRequest) for p in m.parts if isinstance(p, SystemPromptPart)
+        )
         for message in messages:
             if isinstance(message, ModelRequest):
                 async for item in self._map_user_message(message):
@@ -361,7 +364,7 @@ async def _map_messages(
             else:
                 assert_never(message)
         if instructions := self._get_instructions(messages, model_request_parameters):
-            hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
+            hf_messages.insert(system_prompt_count, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
         return hf_messages

     @staticmethod
```

> **Collaborator** (on the `system_prompt_count` lines): Same :)
**Mistral mapper** (`_map_messages`):

```diff
@@ -528,6 +528,9 @@ def _map_messages(
     ) -> list[MistralMessages]:
         """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
         mistral_messages: list[MistralMessages] = []
+        system_prompt_count = sum(
+            1 for m in messages if isinstance(m, ModelRequest) for p in m.parts if isinstance(p, SystemPromptPart)
+        )
         for message in messages:
             if isinstance(message, ModelRequest):
                 mistral_messages.extend(self._map_user_message(message))
@@ -557,7 +560,7 @@ def _map_messages(
             else:
                 assert_never(message)
         if instructions := self._get_instructions(messages, model_request_parameters):
-            mistral_messages.insert(0, MistralSystemMessage(content=instructions))
+            mistral_messages.insert(system_prompt_count, MistralSystemMessage(content=instructions))

         # Post-process messages to insert fake assistant message after tool message if followed by user message
         # to work around `Unexpected role 'user' after role 'tool'` error.
```

> **Collaborator** (on the `system_prompt_count` lines): Same!
**OpenAI chat mapper** (`_map_messages`) and the Responses API path (`_responses_create`):

```diff
@@ -831,7 +831,11 @@ async def _map_messages(
     ) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         openai_messages: list[chat.ChatCompletionMessageParam] = []
+        system_prompt_count = 0
         for message in messages:
+            for part in message.parts:
+                if isinstance(part, SystemPromptPart):
+                    system_prompt_count += 1
             if isinstance(message, ModelRequest):
                 async for item in self._map_user_message(message):
                     openai_messages.append(item)
@@ -840,7 +844,9 @@ async def _map_messages(
             else:
                 assert_never(message)
         if instructions := self._get_instructions(messages, model_request_parameters):
-            openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
+            openai_messages.insert(
+                system_prompt_count, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')
+            )
         return openai_messages

     @staticmethod
@@ -1313,7 +1319,12 @@ async def _responses_create(  # noqa: C901
             # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'.
             # Apparently they're only checking input messages for "JSON", not instructions.
             assert isinstance(instructions, str)
-            openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions))
+            system_prompt_count = sum(
+                1 for m in messages if isinstance(m, ModelRequest) for p in m.parts if isinstance(p, SystemPromptPart)
+            )
+            openai_messages.insert(
+                system_prompt_count, responses.EasyInputMessageParam(role='system', content=instructions)
+            )
             instructions = OMIT

         if verbosity := model_settings.get('openai_text_verbosity'):
```

> **Collaborator** (on the counting loop): And same :)

> **Collaborator** (on the `system_prompt_count` computation in `_responses_create`): And here as well. We should find the right index in the actual …
> **Collaborator** (review comment): This will result in a wrong index when the system prompt parts are not all at the beginning, and it seems brittle to use an index derived from Pydantic AI `messages` in a list named `cohere_messages`, which may or may not map 1:1. So I think we should count system parts at the start of the actual `cohere_messages` list instead.
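A minimal sketch of that suggestion (the helper name `_leading_system_count` is hypothetical, not part of this PR; the real change would live inside each provider's `_map_messages`): count the system-role messages already sitting at the front of the built list and insert the instructions there.

```python
from typing import Any


def _leading_system_count(provider_messages: list[Any]) -> int:
    """Count consecutive system-role messages at the start of the built list."""
    count = 0
    for m in provider_messages:
        # Provider message shapes differ (TypedDicts vs. objects), so check both.
        role = m.get('role') if isinstance(m, dict) else getattr(m, 'role', None)
        if role != 'system':
            break
        count += 1
    return count


# E.g. in the Groq mapper, instead of counting SystemPromptPart in `messages`:
#
#     if instructions := self._get_instructions(messages, model_request_parameters):
#         groq_messages.insert(
#             _leading_system_count(groq_messages),
#             chat.ChatCompletionSystemMessageParam(role='system', content=instructions),
#         )
```

This ties the insertion index to the list actually sent to the provider, so it stays correct even when the mapping is not 1:1 with the Pydantic AI history or when system prompt parts are not all at the front.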