Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 5, 2024
1 parent b978df7 commit 233b7b8
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 21 deletions.
5 changes: 1 addition & 4 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,10 @@ def main_export(

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default thta is not available for cpu
logger.warn(model_type)
if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]

logger.warn(loading_kwargs)
# there are some difference between remote and in library representation of past key values for some models,
# for avoiding confusion we disable remote code for them
if (
Expand Down
8 changes: 6 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,12 +683,16 @@ def export_from_model(

model_name_or_path = model.config._name_or_path
if preprocessors is not None:
# phi3-vision processor does not have chat_template attribute that breaks Processor saving on disk
if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1:
if not hasattr(preprocessors[1], "chat_template"):
preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None)
for processor in preprocessors:
try:
processor.save_pretrained(output)
except Exception as ex:
logger.error(f"Saving {type(processor)} failed with {ex}")
else:
else:
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
Expand Down Expand Up @@ -849,7 +853,7 @@ def _get_multi_modal_submodels_and_export_configs(

if model_type == "internvl-chat" and preprocessors is not None:
model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("<IMG_CONTEXT>")

if model_type == "phi3-v":
model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist()
model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist()
Expand Down
6 changes: 3 additions & 3 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from packaging import version
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, AutoConfig
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
Expand Down Expand Up @@ -2010,7 +2010,7 @@ def patch_model_for_export(
return MiniCPMVResamplerModelPatcher(self, model, model_kwargs)

return super().patch_model_for_export(model, model_kwargs)


class Phi3VisionConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
Expand Down Expand Up @@ -2216,4 +2216,4 @@ def patch_model_for_export(
model_kwargs = model_kwargs or {}
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)
5 changes: 2 additions & 3 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ def phi3_442_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
Expand Down Expand Up @@ -2976,7 +2976,6 @@ def __exit__(self, exc_type, exc_value, traceback):
layer.self_attn.forward = layer.self_attn._orig_forward



def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor):
return self.get_img_features(pixel_values)

Expand All @@ -2994,4 +2993,4 @@ def __init__(

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
self._model.forward = self._model.__orig_forward
12 changes: 6 additions & 6 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,7 @@ def forward(self, img_features):
return self.request(img_features)[0]


MODEL_PARTS_CLS_MAPPING = {
"resampler": OVResampler,
"vision_projection": OVVisionProjection
}
MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection}


class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
Expand Down Expand Up @@ -1675,12 +1672,15 @@ def get_multimodal_embeddings(
input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
if has_image:
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs)
vision_embeds = self.get_vision_embeddings(
pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
)
image_features_proj = torch.from_numpy(vision_embeds)
inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)

return inputs_embeds, attention_mask, position_ids


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
Expand Down
15 changes: 12 additions & 3 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,9 +1884,12 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next"]
if is_transformers_version(">=", "4.45.0"):
SUPPORTED_ARCHITECTURES += ["minicpmv"]
SUPPORTED_ARCHITECTURES += ["minicpmv", "phi3_v"]

TASK = "image-text-to-text"

REMOTE_CODE_MODELS = ["phi3_v", "minicpmv"]

IMAGE = Image.open(
requests.get(
"http://images.cocodataset.org/val2017/000000039769.jpg",
Expand All @@ -1907,15 +1910,21 @@ def get_transformer_model_class(self, model_arch):

def gen_inputs(self, model_arch, base_text_prompt, image=None):
model_id = MODEL_NAMES[model_arch]
if "llava" in model_arch:
if image is None:
prompt = base_text_prompt
elif "llava" in model_arch:
prompt = f"<image>\n {base_text_prompt}"
elif "minicpmv" in model_arch:
prompt = "<|im_start|>user\n(<image>./</image>)\n {base_text_prompt}<|im_end|>\n<|im_start|>assistant\n"
elif "phi3_v" in model_arch:
prompt = f"<|user|>\n<|image_1|>\n{base_text_prompt}<|end|>\n<|assistant|>\n"
if model_arch != "nanollava":
processor = AutoProcessor.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
inputs = processor(images=[self.IMAGE.resize((600, 600))], text=[prompt], return_tensors="pt")
inputs = processor(
images=[image.resize((600, 600))] if image is not None else None, text=prompt, return_tensors="pt"
)
else:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
processor = AutoProcessor.from_pretrained(
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
"phi3_v": "katuni4ka/tiny-random-phi3-vision",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "fxmarty/tiny-dummy-qwen2",
Expand Down

0 comments on commit 233b7b8

Please sign in to comment.