Commit
[feat] Update multimodal lmi-dist/vllm input handling, and add lmi-dist mllama integration test (#2608)

davidthomas426 authored and siddvenk committed Jan 9, 2025
1 parent f39534e commit fbfcb52
Showing 4 changed files with 30 additions and 21 deletions.
Changed file 1 of 4: lmi-dist rolling batch (LmiDistRollingBatch)

@@ -30,8 +30,7 @@
 from djl_python.telemetry import telemetry_manager
 from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties

-LMI_DIST_GENERATION_PARAMS = set(RequestParams().__dict__.keys()).union(
-    set(SamplingParams().__struct_fields__)) - {"sampling_params"}
+LMI_DIST_GENERATION_PARAMS = set(RequestParams().__struct_fields__)


 class LmiDistRollingBatch(RollingBatch):
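The old allowlist merged RequestParams attributes with vLLM SamplingParams fields and then subtracted the nested "sampling_params" key; the new line derives the allowlist from RequestParams alone. Below is a minimal sketch of that pattern, assuming msgspec-style structs (both vLLM's SamplingParams and, per this diff, lmi-dist's RequestParams expose __struct_fields__); the RequestParams class here is a reduced stand-in, not the real lmi-dist class.

import msgspec


class RequestParams(msgspec.Struct):
    # Reduced stand-in for lmi-dist's RequestParams; the real struct has
    # many more generation fields.
    max_new_tokens: int = 30
    temperature: float = 1.0
    top_p: float = 1.0


# __struct_fields__ lists every field of a msgspec.Struct, so the allowlist
# tracks RequestParams directly instead of merging in SamplingParams fields.
LMI_DIST_GENERATION_PARAMS = set(RequestParams().__struct_fields__)


def filter_generation_params(user_params: dict) -> dict:
    # Drop anything RequestParams would reject, so RequestParams(**params)
    # cannot raise TypeError on unknown keys.
    return {
        k: v
        for k, v in user_params.items() if k in LMI_DIST_GENERATION_PARAMS
    }


params = filter_generation_params({"temperature": 0.7, "details": True})
request_params = RequestParams(**params)  # "details" is silently dropped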
@@ -188,7 +187,7 @@ def inference(self, new_requests: List[Request]) -> List:
         new_lmi_dist_requests = []
         for request in new_requests:
             request_id = str(request.id)
-            llm_input = get_prompt_inputs(request)
+            prompt_inputs = get_prompt_inputs(request)
             params = self.translate_lmi_dist_params(request.parameters)
             request_params = RequestParams(**params)
             lora_request_params = dict()
@@ -197,13 +196,10 @@
                 lora_request_params["lora_request"] = get_lora_request(
                     adapter_name, self.lora_requests)
             # Constructing Request in lmi-dist library
-            lmi_dist_request = Request(
-                id=request_id,
-                prompt=llm_input.get("prompt"),
-                prompt_token_ids=llm_input.get("prompt_token_ids"),
-                multi_modal_input=llm_input.get("multi_modal_data"),
-                params=request_params,
-                **lora_request_params)
+            lmi_dist_request = Request(id=request_id,
+                                       prompt=prompt_inputs,
+                                       params=request_params,
+                                       **lora_request_params)
             new_lmi_dist_requests.append(lmi_dist_request)
             self.request_cache[request_id] = {
                 "request_output": request.request_output
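The second change collapses the three prompt-related keyword arguments (prompt, prompt_token_ids, multi_modal_input) into a single prompt argument that receives the whole dict returned by get_prompt_inputs. A short sketch of the new call shape using stand-in types, since lmi-dist's Request and RequestParams are internal to that library and accept more fields than shown here:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class RequestParams:  # stand-in for lmi-dist's RequestParams
    max_new_tokens: int = 30


@dataclass
class Request:  # stand-in for lmi-dist's Request
    id: str
    prompt: dict  # a vLLM-style TextPrompt/TokensPrompt dict
    params: RequestParams
    lora_request: Optional[Any] = None


# get_prompt_inputs() now returns the engine-ready prompt dict, including any
# multi_modal_data, so it is passed through as-is.
prompt_inputs = {
    "prompt": "<|image|> Describe this image.",
    "multi_modal_data": {"image": ["<PIL.Image object>"]},
}
lmi_dist_request = Request(id="0",
                           prompt=prompt_inputs,
                           params=RequestParams(max_new_tokens=64))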
Changed file 2 of 4: vLLM rolling batch utilities

@@ -304,7 +304,7 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
     )


-def get_multi_modal_data(request: Request) -> dict:
+def get_multi_modal_data(request: Request) -> Optional[dict]:
     parameters = request.parameters
     images = parameters.pop("images", None)
     multi_modal_data = None
@@ -320,8 +320,9 @@ def get_prompt_inputs(request: Request):
     # In both HuggingFace and mistral cases, that process can also yield token-ids directly
     # that we may want to consider passing directly to the engine
     if isinstance(text_prompt, list):
-        return TokensPrompt(prompt_token_ids=text_prompt,
-                            multi_modal_data=multi_modal_data)
+        prompt = TokensPrompt(prompt_token_ids=text_prompt)
     else:
-        return TextPrompt(prompt=text_prompt,
-                          multi_modal_data=multi_modal_data)
+        prompt = TextPrompt(prompt=text_prompt)

+    if multi_modal_data is not None:
+        prompt["multi_modal_data"] = multi_modal_data
Changed file 3 of 4: tests/integration/llm/prepare.py (12 additions, 6 deletions)

@@ -629,6 +629,12 @@
         "option.tensor_parallel_degree": 4,
         "option.device_map": "auto"
     },
+    "llama-3.1-8b": {
+        "option.model_id": "s3://djl-llm/llama-3.1-8b-hf/",
+        "option.task": "text-generation",
+        "option.tensor_parallel_degree": 4,
+        "option.max_rolling_batch_size": 4
+    },
     "llava_v1.6-mistral": {
         "option.model_id": "s3://djl-llm/llava-v1.6-mistral-7b-hf/",
         "option.limit_mm_per_prompt": "image=4",
@@ -643,12 +649,6 @@
         "option.trust_remote_code": True,
         "option.max_model_len": 8192,
     },
-    "llama-3.1-8b": {
-        "option.model_id": "s3://djl-llm/llama-3.1-8b-hf/",
-        "option.task": "text-generation",
-        "option.tensor_parallel_degree": 4,
-        "option.max_rolling_batch_size": 4
-    },
     "pixtral-12b": {
         "option.model_id": "s3://djl-llm/pixtral-12b/",
         "option.max_model_len": 8192,
@@ -657,6 +657,12 @@
         "option.limit_mm_per_prompt": "image=4",
         "option.entryPoint": "djl_python.huggingface"
     },
+    "llama32-11b-multimodal": {
+        "option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/",
+        "option.max_model_len": 8192,
+        "option.max_rolling_batch_size": 16,
+        "option.enforce_eager": True,
+    },
     "llama32-3b-multi-worker-tp1-pp1": {
         "option.model_id": "s3://djl-llm/llama-3-2-3b-instruct/",
         "option.tensor_parallel_degree": 1,
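These entries are model configurations used by the integration-test harness; keys of the form option.* correspond to lines in a DJL Serving serving.properties file. A hedged sketch of how such a map could be rendered, with a hypothetical helper name and output path (the harness's actual plumbing lives in prepare.py and is not shown here):

from pathlib import Path

llama32_11b_multimodal = {
    "option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/",
    "option.max_model_len": 8192,
    "option.max_rolling_batch_size": 16,
    "option.enforce_eager": True,
}


def write_serving_properties(options: dict, model_dir: str) -> Path:
    # Write each key=value pair on its own line, Java-properties style.
    path = Path(model_dir) / "serving.properties"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(f"{k}={v}" for k, v in options.items()) + "\n")
    return path


# write_serving_properties(llama32_11b_multimodal, "models/llama32-11b-multimodal")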
Changed file 4 of 4: tests/integration/tests.py (6 additions, 0 deletions)

@@ -982,6 +982,12 @@ def test_pixtral_12b(self):
             r.launch()
             client.run("multimodal pixtral-12b".split())

+    def test_mllama_11b(self):
+        with Runner('lmi', 'llama32-11b-multimodal') as r:
+            prepare.build_lmi_dist_model('llama32-11b-multimodal')
+            r.launch()
+            client.run("multimodal llama32-11b-multimodal".split())
+

 class TestMultiModalVllm:
