[Model] [Serve] Add support for LLaVa model in serving engine #1962
Conversation
| It would be great to also support base64 images, as per the reference.
| Yes, I will work on supporting base64 images as well. Can you review this PR and merge it in the meantime?
MasterJH5574 left a comment
Thank you @anibohara2000, great job! I did a quick first pass.
|
| ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final {
|   CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";
|   auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape());
Here we want to pass the maximum possible shape as the third argument, through which we reserve the NDArray at its maximum possible size when allocating. Do all images passed to the image embedding function have the same shape? If that is always true, passing image.Shape() is fine; otherwise, we need to pass the maximum shape.
Yes. In the get_image_from_url function, we process the image to a fixed size before it is passed to the model, so all images will be of the same size for a given model.
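For context, a rough sketch of what such fixed-size preprocessing can look like (the helper name mirrors the get_image_from_url mentioned above, but the 336x336 target size and the use of requests/Pillow here are illustrative assumptions, not necessarily what this PR does):

```python
from io import BytesIO

import requests
from PIL import Image


def get_image_from_url(url: str, image_size: int = 336) -> Image.Image:
    """Fetch an image over HTTP and resize it to the fixed size the vision tower expects."""
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    # Every image is resized to the same (image_size, image_size) shape, so the
    # image embedding function always receives inputs of one fixed shape.
    return image.resize((image_size, image_size))
```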
cpp/serve/model.cc Outdated
| picojson::object vision_config = config["vision_config"].get<picojson::object>();
| int image_size = -1;
| int patch_size = -1;
| if (vision_config.count("image_size")) {
|   CHECK(vision_config["image_size"].is<int64_t>());
| } else {
|   LOG(FATAL) << "Key \"image_size\" not found in vision_config.";
| }
| if (vision_config.count("patch_size")) {
|   CHECK(vision_config["patch_size"].is<int64_t>());
| } else {
|   LOG(FATAL) << "Key \"patch_size\" not found in vision_config.";
| }
| this->image_embed_size_ = (image_size * image_size) / (patch_size * patch_size);
Hmmm, did you set the values of image_size and patch_size?
Thanks for pointing this out. These values are included in mlc-chat-config.json for the LLaVa model, but this belongs to an old flow which I am not using right now, so I am removing it for now.
|
| import fastapi
| import requests
| import tvm
Let's also lazily import requests and tvm inside get_image_from_url, since not all functions in this file depend on them.
Got it!
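A minimal sketch of the lazy-import pattern being suggested (only the import placement matters here; the function body is elided):

```python
def get_image_from_url(url: str):
    # Deferred imports: requests and tvm are only needed on the image path,
    # so pulling them in here keeps module import cheap for code that never
    # touches images.
    import requests  # pylint: disable=import-outside-toplevel
    import tvm  # pylint: disable=import-outside-toplevel

    ...  # fetch the image and convert it to an NDArray as before
```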
|
| _models: Dict[str, async_engine.AsyncThreadedEngine] = {}
| _conv_templates: Dict[str, Conversation] = {}
| _model_config_paths: Dict[str, str] = {}
Let's load the model config JSON and save the config dictionary in ServerContext, so we don't need to parse and read the JSON every time in the entrypoint.
Yes that makes sense. Changed.
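A sketch of that change, assuming ServerContext keeps class-level registries as in the excerpt above (the method names here are illustrative, not the PR's actual API):

```python
import json
from typing import Any, Dict


class ServerContext:
    """Server-global registries; the parsed model config replaces the raw config path."""

    _model_configs: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def add_model_config(model: str, config_file_path: str) -> None:
        # Parse the JSON once when the model is registered...
        with open(config_file_path, "r", encoding="utf-8") as file:
            ServerContext._model_configs[model] = json.load(file)

    @staticmethod
    def get_model_config(model: str) -> Dict[str, Any]:
        # ...and hand the cached dict to the entrypoints afterwards.
        return ServerContext._model_configs[model]
```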
| def get_image_embed_size(config_file_path: str) -> int:
|     """Get the image embedding size from the model config file."""
|     with open(config_file_path, "r", encoding="utf-8") as file:
|         config = json.load(file)
|     image_size = config["model_config"]["vision_config"]["image_size"]
|     patch_size = config["model_config"]["vision_config"]["patch_size"]
|     embed_size = (image_size // patch_size) ** 2
|     return embed_size
See the other comment: here we can accept the config dict directly and avoid loading the JSON again.
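A sketch of the helper refactored to take the cached config dict directly. As a concrete check of the formula, LLaVa-v1.5's vision tower typically uses image_size=336 and patch_size=14, which gives (336 // 14) ** 2 = 576 image embeddings (these specific numbers are cited for illustration, not taken from this PR):

```python
from typing import Any, Dict


def get_image_embed_size(config: Dict[str, Any]) -> int:
    """Compute the number of image embeddings from an already-parsed model config."""
    vision_config = config["model_config"]["vision_config"]
    image_size = vision_config["image_size"]
    patch_size = vision_config["patch_size"]
    # One embedding per (patch_size x patch_size) patch,
    # e.g. (336 // 14) ** 2 == 576 for a 336x336 input with 14x14 patches.
    return (image_size // patch_size) ** 2
```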
| model_config_path = ServerContext.get_model_config_path(request.model)
| image_embed_size = entrypoint_utils.get_image_embed_size(model_config_path)
|
| if content_has_list:
|     prompts = entrypoint_utils.process_prompts(
|         conv_template.as_prompt_list(image_embed_size=image_embed_size),
|         async_engine.tokenizer.encode,
|     )
| else:
|     prompts = entrypoint_utils.process_prompts(
|         conv_template.as_prompt(), async_engine.tokenizer.encode
|     )
Just want to mark a future TODO item: let's unify as_prompt_list and as_prompt in the future. Could you add a TODO in the code here?
| async_engine.record_event(request_id, event="invoke generate")
| finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
| async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id):
| async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): # type: ignore # pylint: disable=line-too-long
The original line has 85 characters, which fits our black and pylint limit (which is 100 characters). I assume there is no need to add the type ignore and the pylint disable here. Is there something wrong with your settings?
Lines 22 to 34 in edffce4
| [tool.black]
| line-length = 100
|
| [tool.mypy]
| ignore_missing_imports = true
| show_column_numbers = true
| show_error_context = true
| follow_imports = "skip"
| ignore_errors = false
| strict_optional = false
|
| [tool.pylint.messages_control]
| max-line-length = 100
I was getting mypy errors, so I changed List -> Sequence to handle them. Removing the # type: ignore # pylint: disable=line-too-long now.
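For reference, the usual reason a List -> Sequence change silences such mypy errors is variance: List is invariant in its element type, while Sequence is covariant, so a List[List[int]] is accepted where a Sequence[Union[str, List[int]]] is expected but not where a List[Union[str, List[int]]] is. A toy illustration (not the actual signatures in this PR):

```python
from typing import List, Sequence, Union


def use_list(prompts: List[Union[str, List[int]]]) -> None: ...
def use_seq(prompts: Sequence[Union[str, List[int]]]) -> None: ...


token_prompts: List[List[int]] = [[1, 2, 3]]
use_list(token_prompts)  # mypy error: List is invariant in its element type
use_seq(token_prompts)   # OK: Sequence is covariant, so List[List[int]] works here
```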
| )
| async_engine.record_event(request_id, event="invoke generate")
| async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id):
| async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): # type: ignore # pylint: disable=line-too-long
Ditto
| "max_batch_size": int, | ||
| "max_total_seq_len": int, | ||
| "prefill_chunk_size": int, | ||
| "page_size": int, |
#1967 adds another parameter, support_sliding_window. You may need to rebase onto the latest main and add the parameter both here and in the definition of create_paged_kv_cache.
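Concretely, assuming the excerpt above is the expected-metadata spec checked when creating the paged KV cache, the rebase would add the new key alongside the existing ones and thread it through to create_paged_kv_cache. A hedged sketch (key names follow the excerpt; whether the flag is an int or a bool, and the exact create_paged_kv_cache signature after #1967, should be taken from the rebased main):

```python
# Expected metadata fields for the paged KV cache creation function.
kv_cache_metadata_spec = {
    "max_batch_size": int,
    "max_total_seq_len": int,
    "prefill_chunk_size": int,
    "page_size": int,
    "support_sliding_window": int,  # new parameter from #1967 (0/1 flag assumed here)
}
```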
And also rebase to resolve the conflict in conversation_protocol.py.
I did something wrong when rebasing. Closing this PR and opening a new one. Sorry for that
| Besides the comments above, could you also add test cases for the LLaVa support? Test cases are helpful since they not only ensure correctness, but also enable others to understand the basic flow of how things work and to reproduce your tests. Specifically, here are two tests I think would be good to have.
|
* small fix * small fix * Update stablelm_model.py
`test_server::is_json_or_json_prefix` is used to check whether the output is JSON or a prefix of JSON. It uses json.loads internally. However, json.loads (i.e. json.decode) is token-based instead of char-based. If half a token is left at the end of the string, it cannot be matched. This PR adds another check for the remaining "half a token" if it exists.
This PR migrates the Mistral model to the PagedKVCache interface, which supports sliding window attention with a paged attention kernel written in TensorIR. We thereby introduce a `support_sliding_window` mode for the KV cache, which leaves space for supporting sliding windows for any model at runtime. This PR tests Mistral with both chat and serve. The chat performance of Mistral 7B improves over before, benefiting from the paged attention implementation.
* [Docs][Upd] Server launch, examples for endpoints for MLC Serve
* remove v1/completions
* add api docs to rest
Co-authored-by: Shrey Gupta <shrey2809@gmail.com>
2. Save model config instead of path in ServerContext
3. Sliding window parameter in create_paged_kv_cache
4. Remove pylint line-too-long
This PR adds support for the LLaVa-v1.5 model in the serving engine. Use the HF weights and config from https://huggingface.co/llava-hf/llava-1.5-7b-hf
Image input can be passed as a URL (reference: https://platform.openai.com/docs/guides/vision ).
Example:
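The original example is not included above; as a stand-in, here is a hedged sketch of a request following the OpenAI vision-style message format referenced earlier (the endpoint, port, model string, and image URL below are placeholders, and the exact schema accepted by the serving engine may differ):

```python
import requests

payload = {
    "model": "llava-hf/llava-1.5-7b-hf",  # placeholder model id
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
}
response = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=60
)
print(response.json()["choices"][0]["message"]["content"])
```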