Skip to content

Commit

Permalink
Add type support and mypy checking (#100)
Browse files Browse the repository at this point in the history
* Add type support and mypy checking

* Don't need models direct import

* Simplify some auto-generated types

* Pull mypy settings into config

* Disallow untyped calls

* Make type annotations 3.5 compatible

* Remove ModelBase usage

* Explicitly run mypy in tox

* Python 3.5 needs typing_extensions

* Correct ModelBase usage

* Cope with inference from union types

* Deal with new action_generator work in tests

* Don't use test class in main code!

* _skip_field can use Field

* Loosen typing-extensions requirement

* Fix create_files attribute

Co-authored-by: Bernardo Fontes <[email protected]>

* Fix unique/ambiguous_models types

* Rip out action_generator

* Don't need typing-extensions any more

* Added changelog entry for type annotations

Co-authored-by: Bernardo Fontes <[email protected]>
  • Loading branch information
palfrey and berinhard authored Oct 5, 2020
1 parent 91019ce commit 805c090
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 120 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
## [Unreleased](https://github.com/model-bakers/model_bakery/tree/master)

### Added
- Support to django 3.1 `JSONField`[PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106)
- Support to django 3.1 `JSONField` [PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106)
- [dev] Changelog reminder (GitHub action)
- Added type annotations [PR #100](https://github.com/model-bakers/model_bakery/pull/100)

### Changed
- [dev] CI switched to GitHub Actions
Expand Down
136 changes: 78 additions & 58 deletions model_bakery/baker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os.path import dirname, join
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast

from django.apps import apps
from django.conf import settings
Expand All @@ -10,13 +11,15 @@
FileField,
ForeignKey,
ManyToManyField,
Model,
OneToOneField,
)
from django.db.models.base import ModelBase
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import (
ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor,
)
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel

from . import generators, random_gen
from .exceptions import (
Expand All @@ -39,18 +42,18 @@
MAX_MANY_QUANTITY = 5


def _valid_quantity(quantity):
def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


def make(
_model,
_quantity=None,
make_m2m=False,
_save_kwargs=None,
_refresh_after_create=False,
_create_files=False,
**attrs
_model: str,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
**attrs: Any
):
"""Create a persisted instance from a given model its associated models.
Expand All @@ -76,7 +79,7 @@ def make(
)


def prepare(_model, _quantity=None, _save_related=False, **attrs):
def prepare(_model: str, _quantity=None, _save_related=False, **attrs) -> Model:
"""Create but do not persist an instance from a given model.
It fill the fields with random values or you can specify which
Expand All @@ -95,7 +98,7 @@ def prepare(_model, _quantity=None, _save_related=False, **attrs):
return baker.prepare(_save_related=_save_related, **attrs)


def _recipe(name):
def _recipe(name: str) -> Any:
app, recipe_name = name.rsplit(".", 1)
return import_from_str(".".join((app, "baker_recipes", recipe_name)))

Expand All @@ -113,10 +116,10 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new
class ModelFinder(object):
"""Encapsulates all the logic for finding a model to Baker."""

_unique_models = None
_ambiguous_models = None
_unique_models = None # type: Optional[Dict[str, Type[Model]]]
_ambiguous_models = None # type: Optional[List[str]]

def get_model(self, name):
def get_model(self, name: str) -> Type[Model]:
"""Get a model.
Args:
Expand All @@ -140,26 +143,26 @@ def get_model(self, name):

return model

def get_model_by_name(self, name):
def get_model_by_name(self, name: str) -> Optional[Type[Model]]:
"""Get a model by name.
If a model with that name exists in more than one app, raises
AmbiguousModelName.
"""
name = name.lower()

if self._unique_models is None:
if self._unique_models is None or self._ambiguous_models is None:
self._populate()

if name in self._ambiguous_models:
if name in cast(List, self._ambiguous_models):
raise AmbiguousModelName(
"%s is a model in more than one app. "
'Use the form "app.model".' % name.title()
)

return self._unique_models.get(name)
return cast(Dict, self._unique_models).get(name)

def _populate(self):
def _populate(self) -> None:
"""Cache models for faster self._get_model."""
unique_models = {}
ambiguous_models = []
Expand All @@ -180,14 +183,14 @@ def _populate(self):
self._unique_models = unique_models


def is_iterator(value):
def is_iterator(value: Any) -> bool:
if not hasattr(value, "__iter__"):
return False

return hasattr(value, "__next__")


def _custom_baker_class():
def _custom_baker_class() -> Optional[Type]:
"""Return the specified custom baker class.
Returns:
Expand Down Expand Up @@ -217,36 +220,43 @@ def _custom_baker_class():


class Baker(object):
attr_mapping = {}
type_mapping = None
attr_mapping = {} # type: Dict[str, Any]
type_mapping = {} # type: Dict

# Note: we're using one finder for all Baker instances to avoid
# rebuilding the model cache for every make_* or prepare_* call.
finder = ModelFinder()

@classmethod
def create(cls, _model, make_m2m=False, create_files=False):
def create(
cls, _model: str, make_m2m: bool = False, create_files: bool = False
) -> "Baker":
"""Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting."""
baker_class = _custom_baker_class() or cls
return baker_class(_model, make_m2m, create_files)

def __init__(self, _model, make_m2m=False, create_files=False):
def __init__(
self,
_model: Union[str, Type[ModelBase]],
make_m2m: bool = False,
create_files: bool = False,
) -> None:
self.make_m2m = make_m2m
self.create_files = create_files
self.m2m_dict = {}
self.iterator_attrs = {}
self.model_attrs = {}
self.rel_attrs = {}
self.rel_fields = []
self.m2m_dict = {} # type: Dict[str, List]
self.iterator_attrs = {} # type: Dict[str, Iterator]
self.model_attrs = {} # type: Dict[str, Any]
self.rel_attrs = {} # type: Dict[str, Any]
self.rel_fields = [] # type: List[str]

if isinstance(_model, ModelBase):
self.model = _model
else:
if isinstance(_model, str):
self.model = self.finder.get_model(_model)
else:
self.model = _model

self.init_type_mapping()

def init_type_mapping(self):
def init_type_mapping(self) -> None:
self.type_mapping = generators.get_type_mapping()
generators_from_settings = getattr(settings, "BAKER_CUSTOM_FIELDS_GEN", {})
for k, v in generators_from_settings.items():
Expand All @@ -256,10 +266,10 @@ def init_type_mapping(self):

def make(
self,
_save_kwargs=None,
_refresh_after_create=False,
_save_kwargs: Optional[Dict[str, Any]] = None,
_refresh_after_create: bool = False,
_from_manager=None,
**attrs
**attrs: Any
):
"""Create and persist an instance of the model associated with Baker instance."""
params = {
Expand All @@ -272,14 +282,16 @@ def make(
params.update(attrs)
return self._make(**params)

def prepare(self, _save_related=False, **attrs):
def prepare(self, _save_related=False, **attrs: Any) -> Model:
"""Create, but do not persist, an instance of the associated model."""
return self._make(commit=False, commit_related=_save_related, **attrs)

def get_fields(self):
def get_fields(self) -> Any:
return self.model._meta.fields + self.model._meta.many_to_many

def get_related(self):
def get_related(
self,
) -> List[Union[ManyToOneRel, OneToOneRel]]:
return [r for r in self.model._meta.related_objects if not r.many_to_many]

def _make(
Expand All @@ -289,8 +301,8 @@ def _make(
_save_kwargs=None,
_refresh_after_create=False,
_from_manager=None,
**attrs
):
**attrs: Any
) -> Model:
_save_kwargs = _save_kwargs or {}

self._clean_attrs(attrs)
Expand Down Expand Up @@ -336,21 +348,23 @@ def _make(

return instance

def m2m_value(self, field):
def m2m_value(self, field: ManyToManyField) -> List[Any]:
if field.name in self.rel_fields:
return self.generate_value(field)
if not self.make_m2m or field.null and not field.fill_optional:
return []
return self.generate_value(field)

def instance(self, attrs, _commit, _save_kwargs, _from_manager):
def instance(
self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager
) -> Model:
one_to_many_keys = {}
for k in tuple(attrs.keys()):
field = getattr(self.model, k, None)
if isinstance(field, ForeignRelatedObjectsDescriptor):
one_to_many_keys[k] = attrs.pop(k)

instance = self.model(**attrs)
instance = self.model(**attrs) # type: Model
# m2m only works for persisted instances
if _commit:
instance.save(**_save_kwargs)
Expand All @@ -367,19 +381,20 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager):

return instance

def create_by_related_name(self, instance, related):
def create_by_related_name(
self, instance: Model, related: Union[ManyToOneRel, OneToOneRel]
) -> None:
rel_name = related.get_accessor_name()
if rel_name not in self.rel_fields:
return

kwargs = filter_rel_attrs(rel_name, **self.rel_attrs)
kwargs[related.field.name] = instance
kwargs["_model"] = related.field.model

make(**kwargs)
make(related.field.model, **kwargs)

def _clean_attrs(self, attrs):
def is_rel_field(x):
def _clean_attrs(self, attrs: Dict[str, Any]) -> None:
def is_rel_field(x: str):
return "__" in x

self.fill_in_optional = attrs.pop("_fill_optional", False)
Expand All @@ -401,7 +416,7 @@ def is_rel_field(x):
x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x)
]

def _skip_field(self, field):
def _skip_field(self, field: Field) -> bool:
from django.contrib.contenttypes.fields import GenericRelation

# check for fill optional argument
Expand Down Expand Up @@ -444,7 +459,7 @@ def _skip_field(self, field):

return False

def _handle_one_to_many(self, instance, attrs):
def _handle_one_to_many(self, instance: Model, attrs: Dict[str, Any]):
for k, v in attrs.items():
manager = getattr(instance, k)

Expand All @@ -454,7 +469,7 @@ def _handle_one_to_many(self, instance, attrs):
# for many-to-many relationships the bulk keyword argument doesn't exist
manager.set(v, clear=True)

def _handle_m2m(self, instance):
def _handle_m2m(self, instance: Model):
for key, values in self.m2m_dict.items():
for value in values:
if not value.pk:
Expand All @@ -473,10 +488,12 @@ def _handle_m2m(self, instance):
}
make(through_model, **base_kwargs)

def _remote_field(self, field):
def _remote_field(
self, field: Union[ForeignKey, OneToOneField]
) -> Union[OneToOneRel, ManyToOneRel]:
return field.remote_field

def generate_value(self, field, commit=True):
def generate_value(self, field: Field, commit: bool = True) -> Any:
"""Call the associated generator with a field passing all required args.
Generator Resolution Precedence Order:
Expand Down Expand Up @@ -517,20 +534,23 @@ def generate_value(self, field, commit=True):

if not commit:
generator = getattr(generator, "prepare", generator)

return generator(**generator_attrs)


def get_required_values(generator, field):
def get_required_values(
generator: Callable, field: Field
) -> Dict[str, Union[bool, int, str, List[Callable]]]:
"""Get required values for a generator from the field.
If required value is a function, calls it with field as argument. If
required value is a string, simply fetch the value from the field
and return.
"""
# FIXME: avoid abbreviations
rt = {}
rt = {} # type: Dict[str, Any]
if hasattr(generator, "required"):
for item in generator.required:
for item in generator.required: # type: ignore[attr-defined]

if callable(item): # baker can deal with the nasty hacking too!
key, value = item(field)
Expand All @@ -549,7 +569,7 @@ def get_required_values(generator, field):
return rt


def filter_rel_attrs(field_name, **rel_attrs):
def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]:
clean_dict = {}

for k, v in rel_attrs.items():
Expand Down
Loading

0 comments on commit 805c090

Please sign in to comment.