diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index cbcdd74b90..8bf32f1266 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,9 +1,9 @@ import inspect import json import os -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import Field, asdict, dataclass, is_dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union import packaging.version @@ -24,9 +24,6 @@ ) -if TYPE_CHECKING: - from _typeshed import DataclassInstance - if is_torch_available(): import torch # type: ignore @@ -38,6 +35,12 @@ logger = logging.get_logger(__name__) + +# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[Dict[str, Field]] + + # Generic variable that is either ModelHubMixin or a subclass thereof T = TypeVar("T", bound="ModelHubMixin") # Generic variable to represent an args type @@ -175,7 +178,7 @@ class ModelHubMixin: ``` """ - _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None + _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None # ^ optional config attribute automatically set in `from_pretrained` _hub_mixin_info: MixinInfo # ^ information about the library integrating ModelHubMixin (used to generate model card) @@ -366,7 +369,7 @@ def save_pretrained( self, save_directory: Union[str, Path], *, - config: Optional[Union[dict, "DataclassInstance"]] = None, + config: Optional[Union[dict, DataclassInstance]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, model_card_kwargs: Optional[Dict[str, Any]] = None, @@ -618,7 +621,7 @@ def push_to_hub( self, repo_id: str, *, - config: Optional[Union[dict, "DataclassInstance"]] = None, + config: Optional[Union[dict, DataclassInstance]] = None, commit_message: str = "Push model using huggingface_hub.", private: Optional[bool] = None, token: Optional[str] = None, @@ -825,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric return model -def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance": +def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: """Load a dataclass instance from a dictionary. Fields not expected by the dataclass are ignored. diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 10f12f7a25..f2043a877e 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -4,7 +4,7 @@ import unittest from dataclasses import dataclass from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, get_type_hints from unittest.mock import Mock, patch import jedi @@ -474,3 +474,24 @@ def dummy_example_for_test(self, x: str) -> str: source_lines = source.split("\n") completions = script.complete(len(source_lines), len(source_lines[-1])) assert any(completion.name == "dummy_example_for_test" for completion in completions) + + def test_get_type_hints_works_as_expected(self): + """ + Ensure that `typing.get_type_hints` works as expected when inheriting from `ModelHubMixin`. + + See https://github.com/huggingface/huggingface_hub/issues/2727. + """ + + class ModelWithHints(ModelHubMixin): + def method_with_hints(self, x: int) -> str: + return str(x) + + assert get_type_hints(ModelWithHints) != {} + + # Test method type hints on class + hints = get_type_hints(ModelWithHints.method_with_hints) + assert hints == {"x": int, "return": str} + + # Test method type hints on instance + model = ModelWithHints() + assert get_type_hints(model.method_with_hints) == {"x": int, "return": str}