diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 66abbf811..35697f3c8 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -12,7 +12,11 @@
 # the specific language governing permissions and limitations under the License.
 from collections import OrderedDict, defaultdict
 
-from vllm import LLMEngine, SamplingParams
+from vllm import LLMEngine
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
+from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
 from vllm.utils import random_uuid, AtomicCounter
 
 from djl_python.request import Request
@@ -24,9 +28,8 @@ from typing import List, Optional
 
 
 # FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
-VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
-    SamplingParams(), "__struct_fields__") else set(
-        SamplingParams().__dict__.keys())
+VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__).union(
+    {"use_beam_search"})
 
 
 class VLLMRollingBatch(RollingBatch):
@@ -78,6 +81,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:
 
         :return: The same parameters dict, but with VLLM style parameter names.
         """
+        parameters["output_kind"] = RequestOutputKind.DELTA
         parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
         if "seed" in parameters.keys():
             parameters["seed"] = int(parameters["seed"])
@@ -123,24 +127,33 @@ def inference(self, new_requests: List[Request]) -> List:
         :return results: List of dictionaries, one for each request, that contain output tokens and other data.
""" self.add_new_requests(new_requests) + # step 0: register new requests to engine for request in new_requests: request_id = random_uuid() prompt_inputs = get_prompt_inputs(request) params = self.translate_vllm_params(request.parameters) - sampling_params = SamplingParams(**params) request_params = dict() if request.adapter is not None: adapter_name = request.adapter.get_property("name") request_params["lora_request"] = get_lora_request( adapter_name, self.lora_requests) - self.engine.add_request(request_id=request_id, - inputs=prompt_inputs, - params=sampling_params, - **request_params) self.request_cache[request_id] = { "request_output": request.request_output } + if "use_beam_search" in params: + beam_search_params = self.to_beam_search_params(params) + request_output = self.beam_search(request_id=request_id, + prompt=prompt_inputs, + params=beam_search_params) + self.request_cache = update_request_cache_with_output( + self.request_cache, request_output, self.get_tokenizer()) + else: + sampling_params = SamplingParams(**params) + self.engine.add_request(request_id=request_id, + inputs=prompt_inputs, + params=sampling_params, + **request_params) request_outputs = self.engine.step() # step 1: put result to cache and request_output @@ -155,6 +168,108 @@ def inference(self, new_requests: List[Request]) -> List: return self.postprocess_results() + def to_beam_search_params(self, parameters: dict) -> BeamSearchParams: + return BeamSearchParams(beam_width=parameters.get("n", 1), + max_tokens=parameters["max_tokens"], + ignore_eos=parameters.get("ignore_eos", False), + temperature=parameters.get("temperature", 0.0), + length_penalty=parameters.get( + "length_penalty", 1.0)) + + def beam_search( + self, + request_id: str, + prompt: PromptType, + params: BeamSearchParams, + ): + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + + prompt_token_ids = prompt if "prompt_token_ids" in prompt else self.get_tokenizer( + ).encode(prompt["prompt"]) + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function( + self.get_tokenizer().eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence(tokens=prompt_token_ids, cum_logprob=0) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + beam_search_request_id = f"beam_search-{random_uuid()}" + + # output = self.generate(prompts_batch, beam_search_params, + # request_id) + for i, individual_prompt in enumerate(prompts_batch): + self.engine.add_request( + request_id=f"{beam_search_request_id}-{i}", + prompt=individual_prompt, + params=beam_search_params) + output = self.engine.step() + output = self.engine.step() + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == self.get_tokenizer().eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + 
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = self.get_tokenizer().decode(
+                beam.tokens[tokenized_length:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=prompt_token_ids,
+            prompt_logprobs=None)
+
+        return beam_search_output
+
     def preprocess_requests(self, requests):
         """
         Currently not applicable for VLLM.
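
Note for reviewers: the ordering produced by `create_sort_beams_key_function` is not visible in this diff, but the pruning in `beam_search()` (sorting `new_beams` and `completed` and keeping the top `beam_width`) depends on it. The sketch below is an independent illustration of the length-penalized cumulative-logprob ranking that vLLM's helper is understood to apply; `beam_score`, the candidate token ids, and the logprob values are invented for the example and are not part of this PR.

```python
# Minimal sketch of the beam ordering assumed by sort_beams_key above:
# cumulative log probability normalized by sequence length raised to
# length_penalty, with a trailing EOS token excluded from the length.
from typing import List


def beam_score(token_ids: List[int], cum_logprob: float, eos_token_id: int,
               length_penalty: float = 1.0) -> float:
    seq_len = len(token_ids)
    if token_ids and token_ids[-1] == eos_token_id:
        seq_len -= 1  # a finished beam is not penalized for its EOS token
    return cum_logprob / (seq_len**length_penalty)


# Candidates ordered the same way sorted(new_beams, key=sort_beams_key,
# reverse=True) orders them: higher (less negative) normalized score first.
candidates = [([11, 42, 7], -1.8), ([11, 42, 7, 2], -2.1)]  # (tokens, cum_logprob)
ranked = sorted(candidates,
                key=lambda beam: beam_score(beam[0], beam[1], eos_token_id=2),
                reverse=True)
print(ranked)
```

Since cumulative log probabilities are negative, dividing by a larger length term moves the score closer to zero, so a `length_penalty` above 1.0 favors longer completions. That is why the request's `length_penalty` is threaded through `create_sort_beams_key_function` and used for both the per-step pruning and the final sort of `completed`.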