Skip to content

Commit

Permalink
Fix typing.get_type_hints call on a ModelHubMixin (#2729)
Browse files Browse the repository at this point in the history
* Add DataclassInstance for runtime type_checking

* Add suggestion from code review

* Fix type annotation

* Add test

* Actually fix type annotation (3.8)

---------

Co-authored-by: Celina Hanouti <[email protected]>
  • Loading branch information
aliberts and hanouticelina committed Jan 6, 2025
1 parent 6bfa5dd commit d3dec49
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
21 changes: 12 additions & 9 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
import json
import os
from dataclasses import asdict, dataclass, is_dataclass
from dataclasses import Field, asdict, dataclass, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union

import packaging.version

Expand All @@ -24,9 +24,6 @@
)


if TYPE_CHECKING:
from _typeshed import DataclassInstance

if is_torch_available():
import torch # type: ignore

Expand All @@ -38,6 +35,12 @@

logger = logging.get_logger(__name__)


# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[Dict[str, Field]]


# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
Expand Down Expand Up @@ -175,7 +178,7 @@ class ModelHubMixin:
```
"""

_hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
_hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
# ^ optional config attribute automatically set in `from_pretrained`
_hub_mixin_info: MixinInfo
# ^ information about the library integrating ModelHubMixin (used to generate model card)
Expand Down Expand Up @@ -366,7 +369,7 @@ def save_pretrained(
self,
save_directory: Union[str, Path],
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
config: Optional[Union[dict, DataclassInstance]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
model_card_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -618,7 +621,7 @@ def push_to_hub(
self,
repo_id: str,
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
config: Optional[Union[dict, DataclassInstance]] = None,
commit_message: str = "Push model using huggingface_hub.",
private: Optional[bool] = None,
token: Optional[str] = None,
Expand Down Expand Up @@ -825,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
return model


def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
"""Load a dataclass instance from a dictionary.
Fields not expected by the dataclass are ignored.
Expand Down
23 changes: 22 additions & 1 deletion tests/test_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, get_type_hints
from unittest.mock import Mock, patch

import jedi
Expand Down Expand Up @@ -474,3 +474,24 @@ def dummy_example_for_test(self, x: str) -> str:
source_lines = source.split("\n")
completions = script.complete(len(source_lines), len(source_lines[-1]))
assert any(completion.name == "dummy_example_for_test" for completion in completions)

def test_get_type_hints_works_as_expected(self):
"""
Ensure that `typing.get_type_hints` works as expected when inheriting from `ModelHubMixin`.
See https://github.com/huggingface/huggingface_hub/issues/2727.
"""

class ModelWithHints(ModelHubMixin):
def method_with_hints(self, x: int) -> str:
return str(x)

assert get_type_hints(ModelWithHints) != {}

# Test method type hints on class
hints = get_type_hints(ModelWithHints.method_with_hints)
assert hints == {"x": int, "return": str}

# Test method type hints on instance
model = ModelWithHints()
assert get_type_hints(model.method_with_hints) == {"x": int, "return": str}

0 comments on commit d3dec49

Please sign in to comment.