
Conversation

@anibohara2000
Contributor

This PR adds support for the LLaVa-v1.5 model in the serving engine. Use the HF weights and config from https://huggingface.co/llava-hf/llava-1.5-7b-hf.

Passing image input is supported as a URL (reference: https://platform.openai.com/docs/guides/vision).
Example:

data = {
    "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": "https://llava-vl.github.io/static/images/view.jpg",
                },
                {"type": "text", "text": "What does this image represent?"},
            ],
        }
    ],
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=data)
print("Response body:", response.text)
@tqchen
Contributor

tqchen commented Mar 16, 2024

Would be great to also support base64 images as per the reference.

@anibohara2000
Contributor Author

Yes, I will work on supporting base64 images as well. Can you review this PR and merge it in the meantime?
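
For reference, the base64 variant would presumably mirror the URL example from the PR description, with a data URL in the image_url field. This is only a sketch: whether the server accepts data URLs this way depends on the pending base64 support, and the payload shape follows the PR description rather than a finalized API.

import base64

import requests

# Hypothetical sketch: encode a local image as a data URL, following the
# OpenAI vision reference linked above. Acceptance of data URLs in
# "image_url" depends on the base64 support that is still to be added.
with open("view.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

data = {
    "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{b64_image}"},
                {"type": "text", "text": "What does this image represent?"},
            ],
        }
    ],
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=data)
print("Response body:", response.text)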

Member

@MasterJH5574 left a comment


Thank you @anibohara2000, great job! I did a quick first pass.


ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final {
  CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. ";
  auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape());
Member

Here we want to pass the maximum possible shape as the third argument, which reserves the NDArray at the maximum possible size when allocating. Do all images passed to the image embedding func have the same shape? If that is always true, we can pass image.Shape(); otherwise, we need to pass the maximum shape.

Contributor Author

@anibohara2000 Mar 18, 2024

Yes, in the get_image_from_url function, we process the image to a fixed size. This image is then passed to the model, so all images will be of the same size for a given model.
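
A minimal sketch of the kind of fixed-size preprocessing described here (the helper name is hypothetical, and the 336x336 size is an assumption taken from the CLIP vision tower in the llava-hf config, not from this diff):

from io import BytesIO

import numpy as np
from PIL import Image


def preprocess_image(image_bytes: bytes, image_size: int = 336) -> np.ndarray:
    """Resize an image to the model's fixed input size (hypothetical helper)."""
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    image = image.resize((image_size, image_size))
    # NCHW float32 array; the real flow also applies CLIP-style normalization.
    return np.asarray(image, dtype="float32").transpose(2, 0, 1)[None, :]

Since the output shape is fixed for a given model, passing image.Shape() to CopyToWorker0 is equivalent to passing the maximum shape.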

Comment on lines 420 to 433
picojson::object vision_config = config["vision_config"].get<picojson::object>();
int image_size = -1;
int patch_size = -1;
if (vision_config.count("image_size")) {
  CHECK(vision_config["image_size"].is<int64_t>());
} else {
  LOG(FATAL) << "Key \"image_size\" not found in vision_config.";
}
if (vision_config.count("patch_size")) {
  CHECK(vision_config["patch_size"].is<int64_t>());
} else {
  LOG(FATAL) << "Key \"patch_size\" not found in vision_config.";
}
this->image_embed_size_ = (image_size * image_size) / (patch_size * patch_size);
Member

Hmmm, did you set the values of image_size and patch_size?

Contributor Author

Thanks for pointing this out. These values are included in mlc-chat-config.json for the LLaVa model, but this was part of an old flow that I am not using right now, so I am removing this code for now.
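
For reference, plugging in the values from the upstream llava-hf vision_config (image_size = 336 and patch_size = 14; these numbers come from the HF config, not from this diff), the formula yields 576 image embeddings:

image_size = 336   # vision_config["image_size"] in the llava-hf config (assumed here)
patch_size = 14    # vision_config["patch_size"]
image_embed_size = (image_size * image_size) // (patch_size * patch_size)  # 24 * 24 = 576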


import fastapi
import requests
import tvm
Member

Let's also lazily import requests and tvm in get_image_from_url, given that not all the functions in this file depend on them.

Contributor Author

Got it!
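
For illustration, the lazy-import pattern being requested might look like this (a sketch only; preprocess_image refers to the hypothetical fixed-size helper sketched earlier in this thread):

def get_image_from_url(url: str):
    """Fetch an image and wrap it for the runtime, importing heavy deps lazily."""
    # Deferred imports: only this helper needs requests and tvm, so importing
    # this module for other entrypoints stays lightweight.
    import requests  # pylint: disable=import-outside-toplevel
    import tvm  # pylint: disable=import-outside-toplevel

    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return tvm.nd.array(preprocess_image(response.content))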


_models: Dict[str, async_engine.AsyncThreadedEngine] = {}
_conv_templates: Dict[str, Conversation] = {}
_model_config_paths: Dict[str, str] = {}
Member

Let's load the model config JSON and save the config dictionary in ServerContext, so we don't need to parse the JSON every time in the entrypoint.

Contributor Author

Yes that makes sense. Changed.
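
A sketch of the direction agreed on here (the method names are assumptions; only the idea of caching the parsed dict comes from the discussion):

import json
from typing import Any, Dict


class ServerContext:
    """Sketch: store the parsed model config dict instead of its file path."""

    _model_configs: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def add_model_config(model: str, config_file_path: str) -> None:
        # Parse the JSON once when the model is registered.
        with open(config_file_path, "r", encoding="utf-8") as file:
            ServerContext._model_configs[model] = json.load(file)

    @staticmethod
    def get_model_config(model: str) -> Dict[str, Any]:
        return ServerContext._model_configs[model]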

Comment on lines 125 to 132
def get_image_embed_size(config_file_path: str) -> int:
    """Get the image embedding size from the model config file."""
    with open(config_file_path, "r", encoding="utf-8") as file:
        config = json.load(file)
    image_size = config["model_config"]["vision_config"]["image_size"]
    patch_size = config["model_config"]["vision_config"]["patch_size"]
    embed_size = (image_size // patch_size) ** 2
    return embed_size
Member

See the other comment: here we can accept the config dict and avoid loading the JSON again.
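
That is, something along these lines (a direct adaptation of the function above; the dict layout follows the keys it already reads):

from typing import Any, Dict


def get_image_embed_size(config: Dict[str, Any]) -> int:
    """Get the image embedding size from an already-parsed model config dict."""
    image_size = config["model_config"]["vision_config"]["image_size"]
    patch_size = config["model_config"]["vision_config"]["patch_size"]
    return (image_size // patch_size) ** 2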

Comment on lines 409 to 420
model_config_path = ServerContext.get_model_config_path(request.model)
image_embed_size = entrypoint_utils.get_image_embed_size(model_config_path)

if content_has_list:
    prompts = entrypoint_utils.process_prompts(
        conv_template.as_prompt_list(image_embed_size=image_embed_size),
        async_engine.tokenizer.encode,
    )
else:
    prompts = entrypoint_utils.process_prompts(
        conv_template.as_prompt(), async_engine.tokenizer.encode
    )
Member

Just want to mark a future todo item. Let's unify as_prompt_list and as_prompt in the future. Could you help add a TODO in the code here?

async_engine.record_event(request_id, event="invoke generate")
finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id):
async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): # type: ignore # pylint: disable=line-too-long
Member

The original line has 85 characters, which fits within our black and pylint limit (100 characters). I assume there is no need to add the type ignore and pylint disable here. Is there anything wrong with your settings?

mlc-llm/pyproject.toml

Lines 22 to 34 in edffce4

[tool.black]
line-length = 100

[tool.mypy]
ignore_missing_imports = true
show_column_numbers = true
show_error_context = true
follow_imports = "skip"
ignore_errors = false
strict_optional = false

[tool.pylint.messages_control]
max-line-length = 100

Contributor Author

I was getting mypy errors; I changed List -> Sequence to handle them. Removing # type: ignore # pylint: disable=line-too-long now.
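
The underlying reason, illustrated generically rather than with the exact signatures from this PR: List is invariant under mypy while Sequence is covariant, so a parameter typed as Sequence accepts lists of more specific element types without a cast.

from typing import List, Sequence, Union

Prompt = Union[str, List[int]]


def process_list(prompts: List[Prompt]) -> int:
    return len(prompts)


def process_seq(prompts: Sequence[Prompt]) -> int:
    return len(prompts)


token_prompts: List[List[int]] = [[1, 2, 3]]
# process_list(token_prompts)  # rejected by mypy: List is invariant.
process_seq(token_prompts)  # accepted: Sequence is covariant in its element type.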

)
async_engine.record_event(request_id, event="invoke generate")
async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id):
async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): # type: ignore # pylint: disable=line-too-long
Member

Ditto

"max_batch_size": int,
"max_total_seq_len": int,
"prefill_chunk_size": int,
"page_size": int,
Member

#1967 supports another parameter support_sliding_window. You may need to rebase to the latest main and add the parameter here and in the definition of create_paged_kv_cache.
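
For reference, a sketch of the extended field set after rebasing (the existing keys are copied from the snippet above; only support_sliding_window is new, and its exact placement and type are assumptions based on this comment):

# Sketch: metadata fields for KV cache creation, extended with the
# parameter introduced by #1967 (placement and type assumed).
kv_cache_metadata_fields = {
    "max_batch_size": int,
    "max_total_seq_len": int,
    "prefill_chunk_size": int,
    "page_size": int,
    "support_sliding_window": int,
}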

Member

And also rebase to resolve the conflict in conversation_protocol.py.

Contributor Author

I did something wrong when rebasing. Closing this PR and opening a new one. Sorry for that

@MasterJH5574
Member

MasterJH5574 commented Mar 17, 2024

Besides the comments above, could you also add test cases for the Llava support? Test cases are helpful since they not only ensure correctness but also enable others to understand the basic flow of how things work and to reproduce your tests.

Specifically, here are two tests that I think would be good to have.

  1. Could you add tests/python/serve/test_serve_engine_image.py to test Llava through Engine.generate? The test can be adapted from test_engine_generate in test_serve_engine.py:

    def test_engine_generate():
        # Initialize model loading info and KV cache config
        model = ModelInfo(
            "dist/Llama-2-7b-chat-hf-q0f16-MLC",
            model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so",
        )
        kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096)
        # Create engine
        engine = Engine(model, kv_cache_config)

        num_requests = 10
        max_tokens = 256

        # Generate output.
        output_texts, _ = engine.generate(
            prompts[:num_requests], GenerationConfig(max_tokens=max_tokens)
        )
        for req_id, outputs in enumerate(output_texts):
            print(f"Prompt {req_id}: {prompts[req_id]}")
            if len(outputs) == 1:
                print(f"Output {req_id}:{outputs[0]}\n")
            else:
                for i, output in enumerate(outputs):
                    print(f"Output {req_id}({i}):{output}\n")

  2. Could you add tests/python/serve/server/test_server_image.py to test Llava through the OpenAI API? You can refer to test_server.py for examples. One test covering the basic functionality should be enough, and we can iterate on the test cases and add more in the future; a sketch is included after this list.
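
For the second test, a minimal sketch adapted from the request example in the PR description (the endpoint constant, fixture name, and response-shape assertions are assumptions following the usual test_server.py conventions):

import requests

OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions"


def test_openai_v1_chat_completion_image(served_model: str):
    """Basic Llava request through the OpenAI chat completion API (sketch)."""
    payload = {
        "model": served_model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": "https://llava-vl.github.io/static/images/view.jpg",
                    },
                    {"type": "text", "text": "What does this image represent?"},
                ],
            }
        ],
        "max_tokens": 128,
    }
    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60)
    assert response.status_code == 200
    choices = response.json()["choices"]
    assert len(choices) == 1
    assert choices[0]["message"]["role"] == "assistant"
    assert isinstance(choices[0]["message"]["content"], str)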

MasterJH5574 and others added 11 commits March 18, 2024 11:22
This PR supports detecting whether FlashInfer is enabled when building TVM, so that FlashInfer won't be enabled when TVM is not built with FlashInfer support.
* small fix * small fix * Update stablelm_model.py
`test_server::is_json_or_json_prefix` is used to check that the output is JSON or a prefix of JSON. It uses json.loads internally. However, json.loads (i.e. json.decode) is token-based rather than character-based: if half a token is left at the end of the string, it cannot be matched. This PR adds another check for the remaining "half a token" if it exists.
This PR migrates the Mistral model to the PagedKVCache interface, which supports sliding window attention with a paged attention kernel written in TensorIR. We thereby introduce a `support_sliding_window` mode for the KV cache, which leaves space for supporting sliding windows for any model at runtime. This PR tests Mistral with both chat and serve. The chat performance of Mistral 7B improves over before, benefiting from the paged attention implementation.
* [Docs][Upd] Server launch, examples for endpoints for MLC Serve * remove v1/completions * add api docs to rest --------- Co-authored-by: Shrey Gupta <shrey2809@gmail.com>
2. Save model config instead of path in ServerContext 3. Sliding window parameter in create_paged_kv_cache 4. Remove pylint line-too-long