Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 13, 2024
1 parent f13f559 commit 0e67735
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
3 changes: 2 additions & 1 deletion optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForCausalLM,
GenerationConfig,
GenerationMixin,
PretrainedConfig,
Expand Down Expand Up @@ -360,7 +361,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
save_directory (`str` or `Path`):
The directory where to save the model files.
"""
src_files = [self.lm_model, self.text_embdings_model, self.vision_embeddings_model]
src_files = [self.lm_model, self.text_embeddings_model, self.vision_embeddings_model]
dst_file_names = [OV_LANGUAGE_MODEL_NAME, OV_TEXT_EMBEDDINGS_MODEL_NAME, OV_VISION_EMBEDDINGS_MODEL_NAME]
for part in self.additional_parts:
model = getattr(self, f"{part}_model", None)
Expand Down
17 changes: 13 additions & 4 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,10 +2018,19 @@ def test_compare_to_transformers(self, model_arch):
for additional_part in ov_model.additional_parts:
self.assertTrue("CPU" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is not None)
# pytorch minicpmv and internvl are not designed to be used via forward
if model_arch not in ["minicpmv", "internvl2"]:
set_seed(SEED)
ov_outputs = ov_model(**inputs)
ov_model.clear_requests()
self.assertTrue("CPU" in ov_model._device)
self.assertTrue("CPU" in ov_model.vision_embeddings._device)
self.assertTrue(ov_model.vision_embeddings.request is None)
self.assertTrue("CPU" in ov_model.language_model._device)
self.assertTrue(ov_model.language_model.request is None)
self.assertTrue(ov_model.language_model.text_emb_request is None)
for additional_part in ov_model.additional_parts:
self.assertTrue("CPU" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is None)

# pytorch minicpmv is not designed to be used via forward
if model_arch in ["minicpmv", "internvl2"]:
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
Expand Down

0 comments on commit 0e67735

Please sign in to comment.