Skip to content

Commit

Permalink
Fix loading Timm models with ov_config (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
helena-intel authored Jan 16, 2024
1 parent 2f2a764 commit e22a2ac
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def from_pretrained(
model = TimmForImageClassification.from_pretrained(model_id, **kwargs)
onnx_config = TimmOnnxConfig(model.config)

return cls._to_load(model=model, config=config, onnx_config=onnx_config, stateful=False)
return cls._to_load(model=model, config=config, onnx_config=onnx_config, stateful=False, **kwargs)
else:
return super().from_pretrained(
model_id=model_id,
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def test_pipeline(self, model_arch):
@parameterized.expand(TIMM_MODELS)
def test_compare_to_timm(self, model_id):
ov_model = OVModelForImageClassification.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertEqual(ov_model.request.get_property("INFERENCE_PRECISION_HINT").to_string(), "f32")
self.assertIsInstance(ov_model.config, PretrainedConfig)
timm_model = timm.create_model(model_id, pretrained=True)
preprocessor = TimmImageProcessor.from_pretrained(model_id)
Expand Down

0 comments on commit e22a2ac

Please sign in to comment.