Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix device selection for compilation language model in vlm and model saving #967

Merged
merged 13 commits into from
Nov 15, 2024
5 changes: 4 additions & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def __init__(
for inputs in self.model.inputs
}
self.ov_config = ov_config or {**self.parent_model.ov_config}
self.request = None
self.request = None if not self.parent_model._compile_only else self.model
self._model_name = model_name
self.config = self.parent_model.config
self._model_dir = Path(model_dir or parent_model._model_save_dir)
Expand Down Expand Up @@ -832,3 +832,6 @@ def __call__(self, *args, **kwargs):

def forward(self, *args, **kwargs):
raise NotImplementedError

def clear_requests(self):
self.request = None
150 changes: 99 additions & 51 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForCausalLM,
GenerationConfig,
GenerationMixin,
PretrainedConfig,
Expand All @@ -30,7 +31,23 @@
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
from .utils import TemporaryDirectory
from .utils import (
OV_LANGUAGE_MODEL_NAME,
OV_TEXT_EMBEDDINGS_MODEL_NAME,
OV_VISION_EMBEDDINGS_MODEL_NAME,
TemporaryDirectory,
)


try:
from transformers import LlavaForConditionalGeneration
except ImportError:
LlavaForConditionalGeneration = None

try:
from transformers import LlavaNextForConditionalGeneration
except ImportError:
LlavaNextForConditionalGeneration = None


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,13 +84,19 @@ def __init__(
def compile(self):
if self.request is None:
logger.info(f"Compiling the Language model to {self._device} ...")
self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
super().compile()
self._compile_text_emb()

def _compile_text_emb(self):
if self.text_emb_request is None:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = core.compile_model(self.text_emb_model, self._device, self.ov_config)
if self._compile_only:
eaidova marked this conversation as resolved.
Show resolved Hide resolved
self.text_emb_request = self.text_emb_model
else:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = self._compile_model(
self.text_emb_model, self._device, self.ov_config, self.model_save_dir
)

def clear_requests(self):
if self._compile_only:
Expand Down Expand Up @@ -238,12 +261,18 @@ def forward(self, img_features):
return self.request(img_features)[0]


MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection}
MODEL_PARTS_CLS_MAPPING = {
"resampler": OVResampler,
"language_model": OVModelWithEmbedForCausalLM,
"vision_embeddings": OVVisionEmbedding,
"vision_projection": OVVisionProjection,
}


class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
export_feature = "image-text-to-text"
additional_parts = []
auto_model_class = AutoModelForCausalLM

def __init__(
self,
Expand Down Expand Up @@ -285,11 +314,11 @@ def __init__(
self.lm_model,
self.text_embeddings_model,
config=config,
deivce=device,
device=device,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
compile=not self._compile_only and enable_compilation,
compile=self._compile_only or enable_compilation,
compile_only=self._compile_only,
)
self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
Expand All @@ -315,19 +344,15 @@ def clear_requests(self):
"`clear_requests()` is not supported with `compile_only` mode, please intialize model without this option"
)

self.language_model.clear_requests()
components = [self.vision_embeddings] + [getattr(self, part) for part in self.additional_parts]
for component in components:
if component is not None:
component.request = None
for _, component in self.components.items():
component.clear_requests()

def compile(self):
self.language_model.compile()
self.vision_embeddings._compile()
for part in self.additional_parts:
part_model = getattr(self, part, None)
if part_model is not None:
part_model._compile()
for _, component in self.components.items():
if isinstance(component, OVModelPart):
component._compile()
else:
component.compile()

def _save_config(self, save_directory):
"""
Expand All @@ -345,21 +370,21 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
save_directory (`str` or `Path`):
The directory where to save the model files.
"""
src_files = [self.lm_model, self.text_embeddings_model, self.vision_embeddings_model]
dst_file_names = [
"openvino_language_model.xml",
"openvino_text_embeddings_model.xml",
"openvino_vision_embeddings_model.xml",
]
for part in self.additional_parts:
model = getattr(self, f"{part}_model", None)
if model is not None:
src_files.append(model)
dst_file_names.append(f"openvino_{part}_model.xml")
src_models = self.submodels
dst_file_names = {
"lm_model": OV_LANGUAGE_MODEL_NAME,
"text_embeddings_model": OV_TEXT_EMBEDDINGS_MODEL_NAME,
"vision_embeddings_model": OV_VISION_EMBEDDINGS_MODEL_NAME,
}
for name in self._submodel_names:
if name not in dst_file_names:
dst_file_names[name] = f"openvino_{name}.xml"

for src_file, dst_file_name in zip(src_files, dst_file_names):
for name in self._submodel_names:
model = src_models[name]
dst_file_name = dst_file_names[name]
dst_path = os.path.join(save_directory, dst_file_name)
ov.save_model(src_file, dst_path, compress_to_fp16=False)
ov.save_model(model, dst_path, compress_to_fp16=False)

self._save_openvino_config(save_directory)
if self.generation_config is not None:
Expand Down Expand Up @@ -429,14 +454,18 @@ def _from_pretrained(
token = use_auth_token

model_file_names = {
"language_model": "openvino_language_model.xml",
"text_embeddings": "openvino_text_embeddings_model.xml",
"vision_embeddings": "openvino_vision_embeddings_model.xml",
"language_model": OV_LANGUAGE_MODEL_NAME,
"language_model_bin": OV_LANGUAGE_MODEL_NAME.replace(".xml", ".bin"),
"text_embeddings": OV_TEXT_EMBEDDINGS_MODEL_NAME,
"text_embeddings_bin": OV_TEXT_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
"vision_embeddings": OV_VISION_EMBEDDINGS_MODEL_NAME,
"vision_embeddings_bin": OV_VISION_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
}

model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
for part in model_cls.additional_parts:
model_file_names[part] = f"openvino_{part}_model.xml"
model_file_names[part + "_bin"] = f"openvino_{part}_model.bin"
compile_only = kwargs.get("compile_only", False)
if os.path.isdir(model_id):
# Load model from a local directory
Expand Down Expand Up @@ -593,6 +622,28 @@ def _from_transformers(
**kwargs,
)

@property
def _component_names(self):
base_components = ["language_model", "vision_embeddings"]
additional_components = [part for part in self.additional_parts if getattr(self, part, None) is not None]
return base_components + additional_components

@property
def components(self):
return {component_name: getattr(self, component_name) for component_name in self._component_names}

@property
def _submodel_names(self):
model_names = ["lm_model", "text_embeddings_model", "vision_embeddings_model"]
for part in self.additional_parts:
if getattr(self, part, None) is not None:
model_names.append(part + "_model")
return model_names

@property
def submodels(self):
return {submodel_name: getattr(self, submodel_name) for submodel_name in self._submodel_names}

def reshape(self, batch_size: int, sequence_length: int):
logger.warning("Static shapes are not supported for causal language model.")
return self
Expand All @@ -601,17 +652,14 @@ def half(self):
"""
Converts all the model weights to FP16 for more efficient inference on GPU.
"""
apply_moc_transformations(self.lm_model, cf=False)
compress_model_transformation(self.lm_model)
apply_moc_transformations(self.text_embeddings_model, cf=False)
compress_model_transformation(self.text_embeddings_model)
apply_moc_transformations(self.vision_embeddings_model, cf=False)
compress_model_transformation(self.vision_embeddings_model)
for part in self.additional_parts:
model = getattr(self, f"{part}_model", None)
if model is not None:
apply_moc_transformations(model, cf=False)
compress_model_transformation(model)
for _, submodel in self.submodels.items():
apply_moc_transformations(submodel, cf=False)
compress_model_transformation(submodel)
return self

def to(self, device):
self.language_model.to(device)
super().to(device)
AlexKoff88 marked this conversation as resolved.
Show resolved Hide resolved
return self

def forward(
Expand All @@ -625,11 +673,8 @@ def forward(
position_ids=None,
image_bound=None,
tgt_sizes=None,
images=None,
**kwargs,
):
if pixel_values is None and images is not None:
pixel_values = images
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
input_ids,
pixel_values,
Expand Down Expand Up @@ -733,7 +778,6 @@ def prepare_inputs_for_generation(
"image_sizes": image_sizes,
"image_bound": kwargs.get("image_bound"),
"tgt_sizes": kwargs.get("tgt_sizes"),
"images": kwargs.get("images"),
}
)
return model_inputs
Expand All @@ -756,6 +800,8 @@ def preprocess_inputs(


class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
auto_model_class = LlavaForConditionalGeneration

def __init__(
self,
language_model: ov.Model,
Expand Down Expand Up @@ -941,6 +987,8 @@ def preprocess_inputs(


class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
auto_model_class = LlavaNextForConditionalGeneration

# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
def pack_image_features(self, image_features, image_sizes, image_newline=None):
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
Expand Down Expand Up @@ -1211,7 +1259,7 @@ def get_text_embeddings(self, input_ids, **kwargs):
return super().get_text_embeddings(for_inputs_embeds_ids, **kwargs)


class _OvInternVLForCausalLM(OVModelForVisualCausalLM):
class _OVInternVLForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
Expand Down Expand Up @@ -1822,7 +1870,7 @@ def preprocess_inputs(
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
result = {"input_ids": input_ids, "attention_mask": attention_mask}
if image is not None:
result["images"] = torch.unsqueeze(processor(images=image, return_tensors="pt")["pixel_values"][0], 0)
result["pixel_values"] = processor(images=[image], return_tensors="pt")["pixel_values"]
return result


Expand Down Expand Up @@ -1979,8 +2027,8 @@ def preprocess_inputs(
MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"internvl_chat": _OvInternVLForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
"phi3_v": _OVPhi3VisionForCausalLM,
"internvl_chat": _OVInternVLForCausalLM,
}
3 changes: 3 additions & 0 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
OV_ENCODER_NAME = "openvino_encoder_model.xml"
OV_DECODER_NAME = "openvino_decoder_model.xml"
OV_DECODER_WITH_PAST_NAME = "openvino_decoder_with_past_model.xml"
OV_TEXT_EMBEDDINGS_MODEL_NAME = "openvino_text_embeddings_model.xml"
OV_LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
OV_VISION_EMBEDDINGS_MODEL_NAME = "openvino_vision_embeddings_model.xml"

OV_TOKENIZER_NAME = "openvino_tokenizer{}.xml"
OV_DETOKENIZER_NAME = "openvino_detokenizer{}.xml"
Expand Down
17 changes: 14 additions & 3 deletions tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
OVModelForVisualCausalLM,
OVStableDiffusion3Pipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.modeling_base import OVBaseModel
from optimum.intel.openvino.modeling_visual_language import MODEL_TYPE_TO_CLS_MAPPING
from optimum.intel.openvino.utils import TemporaryDirectory
from optimum.intel.utils.import_utils import _transformers_version, is_transformers_version
from optimum.utils.save_utils import maybe_load_preprocessors
Expand All @@ -70,12 +72,13 @@ class ExportModelTest(unittest.TestCase):
"stable-diffusion-xl": OVStableDiffusionXLPipeline,
"stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
"latent-consistency": OVLatentConsistencyModelPipeline,
"llava": OVModelForVisualCausalLM,
}

if is_transformers_version(">=", "4.45"):
SUPPORTED_ARCHITECTURES.update({"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline})

GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")
GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper", "llava")

def _openvino_export(
self,
Expand All @@ -93,6 +96,10 @@ def _openvino_export(
model_class = TasksManager.get_model_class_for_task(task, library=library_name)
model = model_class(f"hf_hub:{model_name}", pretrained=True, exportable=True)
TasksManager.standardize_model_attributes(model_name, model, library_name=library_name)
elif model_type == "llava":
model = MODEL_TYPE_TO_CLS_MAPPING[model_type].auto_model_class.from_pretrained(
model_name, **loading_kwargs
)
else:
model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)

Expand Down Expand Up @@ -144,8 +151,12 @@ def test_export_with_custom_gen_config(self, model_type):
task = auto_model.export_feature
model_name = MODEL_NAMES[model_type]
loading_kwargs = {"attn_implementation": "eager"} if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED else {}

model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)
if model_type == "llava":
model = MODEL_TYPE_TO_CLS_MAPPING[model_type].auto_model_class.from_pretrained(
model_name, **loading_kwargs
)
else:
model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)

model.generation_config.top_k = 42
model.generation_config.do_sample = True
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class OVCLIExportTestCase(unittest.TestCase):
"stable-diffusion-xl": 4 if is_tokenizers_version("<", "0.20") else 0,
"stable-diffusion-3": 6 if is_tokenizers_version("<", "0.20") else 2,
"flux": 4 if is_tokenizers_version("<", "0.20") else 0,
"llava": 2 if is_tokenizers_version("<", "0.20") else 0,
}

SUPPORTED_SD_HYBRID_ARCHITECTURES = [
Expand Down Expand Up @@ -244,6 +245,8 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"):
models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder]
models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2)
elif task.startswith("image-text-to-text"):
models = [model.language_model, model.vision_embeddings]
else:
models = [model]

Expand Down
Loading
Loading