Commit

[python] Beam search support
xyang16 committed Jan 6, 2025
1 parent 2df533b commit 7d44fbf
Showing 1 changed file with 124 additions and 9 deletions.
133 changes: 124 additions & 9 deletions engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -12,7 +12,11 @@
# the specific language governing permissions and limitations under the License.
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
from vllm import LLMEngine
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.utils import random_uuid, AtomicCounter

from djl_python.request import Request
@@ -24,9 +28,8 @@
from typing import List, Optional

# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
SamplingParams(), "__struct_fields__") else set(
SamplingParams().__dict__.keys())
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__).union(
{"use_beam_search"})


class VLLMRollingBatch(RollingBatch):
@@ -78,6 +81,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters["output_kind"] = RequestOutputKind.DELTA
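        # RequestOutputKind.DELTA makes each engine step return only the newly
        # generated tokens for a request, rather than the full output so far.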
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])
@@ -123,24 +127,33 @@ def inference(self, new_requests: List[Request]) -> List:
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
self.add_new_requests(new_requests)

# step 0: register new requests to engine
for request in new_requests:
request_id = random_uuid()
prompt_inputs = get_prompt_inputs(request)
params = self.translate_vllm_params(request.parameters)
sampling_params = SamplingParams(**params)
request_params = dict()
if request.adapter is not None:
adapter_name = request.adapter.get_property("name")
request_params["lora_request"] = get_lora_request(
adapter_name, self.lora_requests)
self.engine.add_request(request_id=request_id,
inputs=prompt_inputs,
params=sampling_params,
**request_params)
self.request_cache[request_id] = {
"request_output": request.request_output
}
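            # Beam search requests are completed synchronously via beam_search();
            # everything else goes through the engine's continuous batching below.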
if "use_beam_search" in params:
beam_search_params = self.to_beam_search_params(params)
request_output = self.beam_search(request_id=request_id,
prompt=prompt_inputs,
params=beam_search_params)
self.request_cache = update_request_cache_with_output(
self.request_cache, request_output, self.get_tokenizer())
else:
sampling_params = SamplingParams(**params)
self.engine.add_request(request_id=request_id,
inputs=prompt_inputs,
params=sampling_params,
**request_params)
request_outputs = self.engine.step()

# step 1: put result to cache and request_output
@@ -155,6 +168,108 @@ def inference(self, new_requests: List[Request]) -> List:

return self.postprocess_results()

def to_beam_search_params(self, parameters: dict) -> BeamSearchParams:
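        """Map generic request parameters to vLLM BeamSearchParams ("n" is the beam width)."""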
return BeamSearchParams(beam_width=parameters.get("n", 1),
max_tokens=parameters["max_tokens"],
ignore_eos=parameters.get("ignore_eos", False),
temperature=parameters.get("temperature", 0.0),
length_penalty=parameters.get(
"length_penalty", 1.0))

def beam_search(
self,
request_id: str,
prompt: PromptType,
params: BeamSearchParams,
):
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty

        prompt_token_ids = prompt["prompt_token_ids"] if "prompt_token_ids" in prompt \
            else self.get_tokenizer().encode(prompt["prompt"])
tokenized_length = len(prompt_token_ids)

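        # Beams are ranked by length-penalized cumulative log probability, using
        # the same key function vLLM's own beam search implementation uses.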
sort_beams_key = create_sort_beams_key_function(
self.get_tokenizer().eos_token_id, length_penalty)

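        # Each expansion step generates exactly one token per beam and returns
        # logprobs for 2 * beam_width candidates so every beam can branch.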
beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(tokens=prompt_token_ids, cum_logprob=0)
]
completed = []

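        # Grow all live beams one token at a time, for at most max_tokens iterations.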
for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

beam_search_request_id = f"beam_search-{random_uuid()}"

# output = self.generate(prompts_batch, beam_search_params,
# request_id)
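            # Submit a single-token request per live beam, then step the engine to
            # collect candidate tokens and their logprobs.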
for i, individual_prompt in enumerate(prompts_batch):
self.engine.add_request(
request_id=f"{beam_search_request_id}-{i}",
prompt=individual_prompt,
params=beam_search_params)
output = self.engine.step()
output = self.engine.step()

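            # Branch each beam on its candidate tokens; beams that emit EOS are
            # moved to the completed list unless ignore_eos is set.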
new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]

if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == self.get_tokenizer().eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)

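            # Keep only the top beam_width beams for the next iteration.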
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]

completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]

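        # Decode only the generated suffix of each beam, skipping the prompt tokens.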
for beam in best_beams:
beam.text = self.get_tokenizer().decode(
beam.tokens[tokenized_length:])

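        # Wrap the best beams in a single finished RequestOutput so it flows through
        # the normal request cache and postprocessing path.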
beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None)

return beam_search_output

def preprocess_requests(self, requests):
"""
Currently not applicable for VLLM.

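For reference, a request that should take the new beam search path would carry parameters along these lines before translate_vllm_params renames them (a minimal sketch; the values are illustrative and only the parameter keys read in the diff above are assumed):

# Hypothetical inbound parameters for a single request; the mere presence of the
# "use_beam_search" key routes the request to beam_search() in inference().
parameters = {
    "max_new_tokens": 64,    # renamed to max_tokens by translate_vllm_params
    "n": 4,                  # used as beam_width by to_beam_search_params
    "temperature": 0.0,      # beam expansion temperature in to_beam_search_params
    "use_beam_search": True,
}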