Commit e1d7c99

update unit tests

siddvenk committed Jan 8, 2025
1 parent 1db8a0e commit e1d7c99
Showing 4 changed files with 277 additions and 70 deletions.
@@ -21,36 +21,62 @@
 from djl_python.properties_manager.properties import Properties
 
+DTYPE_MAPPER = {
+    "float32": "float32",
+    "fp32": "float32",
+    "float16": "float16",
+    "fp16": "float16",
+    "bfloat16": "bfloat16",
+    "bf16": "bfloat16",
+    "auto": "auto"
+}
+
+
+def construct_vllm_args_list(vllm_engine_args: dict,
+                             parser: FlexibleArgumentParser):
+    # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258
+    args_list = []
+    store_boolean_arguments = {
+        action.dest
+        for action in parser._actions if isinstance(action, StoreBoolean)
+    }
+    for engine_arg, engine_arg_value in vllm_engine_args.items():
+        if str(engine_arg_value).lower() in {
+                'true', 'false'
+        } and engine_arg not in store_boolean_arguments:
+            if str(engine_arg_value).lower() == 'true':
+                args_list.append(f"--{engine_arg}")
+        else:
+            args_list.append(f"--{engine_arg}={engine_arg_value}")
+    return args_list
+
 
 class VllmRbProperties(Properties):
     engine: Optional[str] = None
     # The following configs have different names in DJL compared to vLLM, we only accept DJL name currently
     tensor_parallel_degree: int = 1
     pipeline_parallel_degree: int = 1
     # The following configs have different names in DJL compared to vLLM, either is accepted
-    quantize: Optional[str] = Field(alias="quantization", default=None)
+    quantize: Optional[str] = Field(alias="quantization",
+                                    default=EngineArgs.quantization)
     max_rolling_batch_prefill_tokens: Optional[int] = Field(
-        alias="max_num_batched_tokens", default=None)
-    cpu_offload_gb_per_gpu: Optional[float] = Field(alias="cpu_offload_gb",
-                                                    default=None)
+        alias="max_num_batched_tokens",
+        default=EngineArgs.max_num_batched_tokens)
+    cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb",
+                                          default=EngineArgs.cpu_offload_gb)
     # The following configs have different defaults, or additional processing in DJL compared to vLLM
     dtype: str = "auto"
    max_loras: int = 4
     # The following configs have broken processing in vllm via the FlexibleArgumentParser
     long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
     use_v2_block_manager: bool = True
 
     # Neuron vLLM properties
-    device: Optional[str] = None
+    device: str = 'auto'
     preloaded_model: Optional[Any] = None
     generation_config: Optional[Any] = None
 
     # This allows generic vllm engine args to be passed in and set with vllm
-    model_config = ConfigDict(extra='allow')
+    model_config = ConfigDict(extra='allow', populate_by_name=True)
 
     @field_validator('engine')
     def validate_engine(cls, engine):
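A quick sanity check of the helper now that it lives at module level; this is an illustration, not part of the commit. It assumes vllm 0.6.x, where FlexibleArgumentParser and StoreBoolean live in vllm.utils, and that this module is importable as djl_python.properties_manager.vllm_rb_properties.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

from djl_python.properties_manager.vllm_rb_properties import construct_vllm_args_list  # path assumed

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# 'true' becomes a bare flag, 'false' is dropped (store_true semantics),
# and any other value is rendered as --key=value.
args_list = construct_vllm_args_list(
    {
        "dtype": "float16",
        "enable_lora": "true",
        "enforce_eager": "false"
    }, parser)
print(args_list)  # ['--dtype=float16', '--enable_lora']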
@@ -59,6 +85,14 @@ def validate_engine(cls, engine):
             f"Need python engine to start vLLM RollingBatcher")
         return engine
 
+    @field_validator('dtype')
+    def validate_dtype(cls, val):
+        if val not in DTYPE_MAPPER:
+            raise ValueError(
+                f"Invalid dtype={val} provided. Must be one of {DTYPE_MAPPER.keys()}"
+            )
+        return DTYPE_MAPPER[val]
+
     @model_validator(mode='after')
     def validate_pipeline_parallel(self):
         if self.pipeline_parallel_degree != 1:
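A sketch of the new dtype validator's effect; illustration only, not part of the commit. It assumes the base Properties needs only model_id_or_path here and that validate_engine accepts "Python".

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

# Shorthand dtypes are normalized through DTYPE_MAPPER at construction time.
props = VllmRbProperties(engine="Python",
                         model_id_or_path="my-model",  # assumed base field
                         dtype="fp16")
assert props.dtype == "float16"

# Unknown dtypes fail fast instead of surfacing later in the vLLM parser
# (pydantic's ValidationError subclasses ValueError).
try:
    VllmRbProperties(engine="Python",
                     model_id_or_path="my-model",
                     dtype="int8")
except ValueError as err:
    print(err)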
@@ -67,9 +101,9 @@ def validate_pipeline_parallel(self):
             )
         return self
 
-    @field_validator('long_lora_scaling_factors', mode='before')
     # TODO: processing of this field is broken in vllm via from_cli_args
     # we should upstream a fix for this to vllm
+    @field_validator('long_lora_scaling_factors', mode='before')
     def validate_long_lora_scaling_factors(cls, val):
         if isinstance(val, str):
             val = ast.literal_eval(val)
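Why the mode='before' hook matters: values from a properties file arrive as strings, and vllm's own CLI handling of this field is broken (hence the TODO above). An illustration, not part of the commit, with the same constructor assumptions as the previous sketch.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

# The string form is parsed with ast.literal_eval into a tuple of floats.
props = VllmRbProperties(engine="Python",
                         model_id_or_path="my-model",  # assumed base field
                         long_lora_scaling_factors="(2.0, 4.0)")
assert props.long_lora_scaling_factors == (2.0, 4.0)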
@@ -96,7 +130,7 @@ def validate_potential_lmi_vllm_config_conflict(
                 if vllm_config_val != lmi_config_val:
                     raise ValueError(
                         f"Both the DJL {lmi_config_name}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values. "
-                        f"We currently only accept the DJL config {lmi_config_val}, please remove the vllm {vllm_config_name} configuration."
+                        f"We currently only accept the DJL config {lmi_config_name}, please remove the vllm {vllm_config_name} configuration."
                     )
 
         validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree",
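How the conflict check interacts with populate_by_name=True: either spelling alone is accepted, but both together with different values should raise. Illustration only, not part of the commit; it assumes the vLLM-side name checked against tensor_parallel_degree is tensor_parallel_size.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

try:
    VllmRbProperties(engine="Python",
                     model_id_or_path="my-model",  # assumed base field
                     tensor_parallel_degree=2,
                     tensor_parallel_size=4)  # assumed vLLM-native name
except ValueError as err:
    print(err)  # conflicting values -> remove the vllm config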
@@ -117,20 +151,18 @@ def generate_vllm_engine_arg_dict(self,
             'revision': self.revision,
             'max_loras': self.max_loras,
             'enable_lora': self.enable_lora,
             'trust_remote_code': self.trust_remote_code,
+            'cpu_offload_gb': self.cpu_offload_gb_per_gpu,
+            'use_v2_block_manager': self.use_v2_block_manager,
+            'quantization': self.quantize,
+            'device': self.device,
         }
-        if self.quantize is not None:
-            vllm_engine_args['quantization'] = self.quantize
-        if self.max_rolling_batch_prefill_tokens is not None:
-            vllm_engine_args[
-                'max_num_batched_tokens'] = self.max_rolling_batch_prefill_tokens
-        if self.cpu_offload_gb_per_gpu is not None:
-            vllm_engine_args['cpu_offload_gb'] = self.cpu_offload_gb_per_gpu
-        if self.device is not None:
-            vllm_engine_args['device'] = self.device
         if self.preloaded_model is not None:
             vllm_engine_args['preloaded_model'] = self.preloaded_model
         if self.generation_config is not None:
             vllm_engine_args['generation_config'] = self.generation_config
+        if self.device == 'neuron':
+            vllm_engine_args['block_size'] = passthrough_vllm_engine_args.get(
+                "max_model_len")
         vllm_engine_args.update(passthrough_vllm_engine_args)
         return vllm_engine_args
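A sketch of the neuron special case (not part of the commit; same constructor assumptions as above): block_size is pinned to the passthrough max_model_len before the generic args are merged over the DJL-derived ones.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

props = VllmRbProperties(engine="Python",
                         model_id_or_path="my-model",  # assumed base field
                         device="neuron")
arg_dict = props.generate_vllm_engine_arg_dict({"max_model_len": 4096})
assert arg_dict["block_size"] == 4096
assert arg_dict["device"] == "neuron"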

@@ -143,11 +175,15 @@ def get_engine_args(self) -> EngineArgs:
             f"Constructing vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}"
         )
         parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-        args_list = self.construct_vllm_args_list(vllm_engine_arg_dict, parser)
+        args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser)
         args = parser.parse_args(args=args_list)
         engine_args = EngineArgs.from_cli_args(args)
         # we have to do this separately because vllm converts it into a string
         engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
+        # These neuron configs are not implemented in the vllm arg parser
+        if self.device == 'neuron':
+            setattr(engine_args, 'preloaded_model', self.preloaded_model)
+            setattr(engine_args, 'generation_config', self.generation_config)
         return engine_args
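End to end, the DJL names round-trip through the CLI parser into a typed EngineArgs. Illustration, not part of the commit; assumes vllm 0.6.x and the constructor assumptions above.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

props = VllmRbProperties(engine="Python",
                         model_id_or_path="my-model",  # assumed base field
                         dtype="bf16",
                         quantize="awq")
engine_args = props.get_engine_args()
assert engine_args.dtype == "bfloat16"  # normalized by the dtype validator
assert engine_args.quantization == "awq"  # DJL 'quantize' -> vLLM 'quantization'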

@@ -156,21 +192,3 @@ def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
             for k, v in self.__pydantic_extra__.items()
             if k in EngineArgs.__annotations__
         }
-
-    def construct_vllm_args_list(self, vllm_engine_args: dict,
-                                 parser: FlexibleArgumentParser):
-        # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258
-        args_list = []
-        store_boolean_arguments = {
-            action.dest
-            for action in parser._actions if isinstance(action, StoreBoolean)
-        }
-        for engine_arg, engine_arg_value in vllm_engine_args.items():
-            if str(engine_arg_value).lower() in {
-                    'true', 'false'
-            } and engine_arg not in store_boolean_arguments:
-                if str(engine_arg_value).lower() == 'true':
-                    args_list.append(f"--{engine_arg}")
-            else:
-                args_list.append(f"--{engine_arg}={engine_arg_value}")
-        return args_list
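The extras filter in action (illustration, not part of the commit): with extra='allow', unrecognized kwargs land in __pydantic_extra__, and only keys that are genuine EngineArgs fields are forwarded to vLLM.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties  # path assumed

props = VllmRbProperties(engine="Python",
                         model_id_or_path="my-model",  # assumed base field
                         max_model_len=4096,  # a real EngineArgs field
                         some_random_setting="x")  # ignored by this helper
assert props.get_additional_vllm_engine_args() == {"max_model_len": 4096}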