Skip to content

Commit

Permalink
Possible fix for #198: memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
dairiki committed Sep 24, 2022
1 parent 6391003 commit 5123577
Showing 1 changed file with 84 additions and 49 deletions.
133 changes: 84 additions & 49 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class User:
Any,
Callable,
Dict,
Generic,
List,
Mapping,
NewType as typing_NewType,
Expand Down Expand Up @@ -79,9 +80,6 @@ class User:
# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024

# Recursion guard for class_schema()
_RECURSION_GUARD = threading.local()


@overload
def dataclass(
Expand Down Expand Up @@ -352,20 +350,56 @@ def class_schema(
clazz_frame = current_frame.f_back
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
del current_frame
_RECURSION_GUARD.seen_classes = {}
try:
return _internal_class_schema(clazz, base_schema, clazz_frame)
finally:
_RECURSION_GUARD.seen_classes.clear()

with _SchemaContext(clazz_frame):
return _internal_class_schema(clazz, base_schema)


class _SchemaContext:
"""Global context for an invocation of class_schema."""

def __init__(self, frame: Optional[types.FrameType]):
self.seen_classes: Dict[Type, str] = {}
self.frame = frame

def get_type_hints(self, cls: Type) -> Dict[str, Any]:
frame = self.frame
localns = frame.f_locals if frame is not None else None
return get_type_hints(cls, localns=localns)

def __enter__(self) -> "_SchemaContext":
_schema_ctx_stack.push(self)
return self

def __exit__(self, _typ, _value, _tb) -> None:
_schema_ctx_stack.pop()


class _LocalStack(threading.local, Generic[_U]):
def __init__(self):
self.stack: List[_U] = []

def push(self, value: _U) -> None:
self.stack.append(value)

def pop(self) -> None:
self.stack.pop()

@property
def top(self) -> _U:
return self.stack[-1]


_schema_ctx_stack = _LocalStack[_SchemaContext]()


@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
def _internal_class_schema(
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
clazz_frame: types.FrameType = None,
) -> Type[marshmallow.Schema]:
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
schema_ctx = _schema_ctx_stack.top
schema_ctx.seen_classes[clazz] = clazz.__name__
try:
# noinspection PyDataclass
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
Expand All @@ -383,7 +417,7 @@ def _internal_class_schema(
"****** WARNING ******"
)
created_dataclass: type = dataclasses.dataclass(clazz)
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
return _internal_class_schema(created_dataclass, base_schema)
except Exception:
raise TypeError(
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
Expand All @@ -397,18 +431,15 @@ def _internal_class_schema(
}

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = get_type_hints(
clazz, localns=clazz_frame.f_locals if clazz_frame else None
)
type_hints = schema_ctx.get_type_hints(clazz)
attributes.update(
(
field.name,
field_for_schema(
_field_for_schema(
type_hints[field.name],
_get_field_default(field),
field.metadata,
base_schema,
clazz_frame,
),
)
for field in fields
Expand All @@ -433,7 +464,6 @@ def _field_by_supertype(
newtype_supertype: Type,
metadata: dict,
base_schema: Optional[Type[marshmallow.Schema]],
typ_frame: Optional[types.FrameType],
) -> marshmallow.fields.Field:
"""
Return a new field for fields based on a super field. (Usually spawned from NewType)
Expand All @@ -459,12 +489,11 @@ def _field_by_supertype(
if field:
return field(**metadata)
else:
return field_for_schema(
return _field_for_schema(
newtype_supertype,
metadata=metadata,
default=default,
base_schema=base_schema,
typ_frame=typ_frame,
)


Expand All @@ -488,7 +517,6 @@ def _generic_type_add_any(typ: type) -> type:
def _field_for_generic_type(
typ: type,
base_schema: Optional[Type[marshmallow.Schema]],
typ_frame: Optional[types.FrameType],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
"""
Expand All @@ -501,9 +529,7 @@ def _field_for_generic_type(
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin in (list, List):
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
list_type = cast(
Type[marshmallow.fields.List],
type_mapping.get(List, marshmallow.fields.List),
Expand All @@ -512,32 +538,25 @@ def _field_for_generic_type(
if origin in (collections.abc.Sequence, Sequence):
from . import collection_field

child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
if origin in (set, Set):
from . import collection_field

child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
return collection_field.Set(
cls_or_instance=child_type, frozen=False, **metadata
)
if origin in (frozenset, FrozenSet):
from . import collection_field

child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
return collection_field.Set(
cls_or_instance=child_type, frozen=True, **metadata
)
if origin in (tuple, Tuple):
children = tuple(
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
for arg in arguments
_field_for_schema(arg, base_schema=base_schema) for arg in arguments
)
tuple_type = cast(
Type[marshmallow.fields.Tuple],
Expand All @@ -549,12 +568,8 @@ def _field_for_generic_type(
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
),
values=field_for_schema(
arguments[1], base_schema=base_schema, typ_frame=typ_frame
),
keys=_field_for_schema(arguments[0], base_schema=base_schema),
values=_field_for_schema(arguments[1], base_schema=base_schema),
**metadata,
)
elif typing_inspect.is_union_type(typ):
Expand All @@ -566,23 +581,21 @@ def _field_for_generic_type(
metadata.setdefault("required", False)
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
if len(subtypes) == 1:
return field_for_schema(
return _field_for_schema(
subtypes[0],
metadata=metadata,
base_schema=base_schema,
typ_frame=typ_frame,
)
from . import union_field

return union_field.Union(
[
(
subtyp,
field_for_schema(
_field_for_schema(
subtyp,
metadata={"required": True},
base_schema=base_schema,
typ_frame=typ_frame,
),
)
for subtyp in subtypes
Expand Down Expand Up @@ -618,6 +631,29 @@ def field_for_schema(
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
<class 'marshmallow.fields.Url'>
"""
with _SchemaContext(typ_frame):
return _field_for_schema(typ, default, metadata, base_schema)


def _field_for_schema(
typ: type,
default=marshmallow.missing,
metadata: Mapping[str, Any] = None,
base_schema: Optional[Type[marshmallow.Schema]] = None,
) -> marshmallow.fields.Field:
"""
Get a marshmallow Field corresponding to the given python type.
The metadata of the dataclass field is used as arguments to the marshmallow Field.
This is an internal version of field_for_schema. It assumes a _SchemaContext
has been pushed onto the local stack.
:param typ: The type for which a field should be generated
:param default: value to use for (de)serialization when the field is missing
:param metadata: Additional parameters to pass to the marshmallow field constructor
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
"""

metadata = {} if metadata is None else dict(metadata)
Expand Down Expand Up @@ -690,10 +726,10 @@ def field_for_schema(
)
else:
subtyp = Any
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)
return _field_for_schema(subtyp, default, metadata, base_schema)

# Generic types
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
if generic_field:
return generic_field

Expand All @@ -707,7 +743,6 @@ def field_for_schema(
newtype_supertype=newtype_supertype,
metadata=metadata,
base_schema=base_schema,
typ_frame=typ_frame,
)

# enumerations
Expand All @@ -726,8 +761,8 @@ def field_for_schema(
nested = (
nested_schema
or forward_reference
or _RECURSION_GUARD.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema, typ_frame)
or _schema_ctx_stack.top.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema)
)

return marshmallow.fields.Nested(nested, **metadata)
Expand Down

0 comments on commit 5123577

Please sign in to comment.