From 51235772205f5ef6960cd5f1f7cb55cffdc26a34 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sat, 24 Sep 2022 11:30:22 -0700 Subject: [PATCH] Possible fix for #198: memory leak --- marshmallow_dataclass/__init__.py | 133 +++++++++++++++++++----------- 1 file changed, 84 insertions(+), 49 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index c7b1e0a..2e6aa48 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -46,6 +46,7 @@ class User: Any, Callable, Dict, + Generic, List, Mapping, NewType as typing_NewType, @@ -79,9 +80,6 @@ class User: # Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates. MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 -# Recursion guard for class_schema() -_RECURSION_GUARD = threading.local() - @overload def dataclass( @@ -352,20 +350,56 @@ def class_schema( clazz_frame = current_frame.f_back # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack del current_frame - _RECURSION_GUARD.seen_classes = {} - try: - return _internal_class_schema(clazz, base_schema, clazz_frame) - finally: - _RECURSION_GUARD.seen_classes.clear() + + with _SchemaContext(clazz_frame): + return _internal_class_schema(clazz, base_schema) + + +class _SchemaContext: + """Global context for an invocation of class_schema.""" + + def __init__(self, frame: Optional[types.FrameType]): + self.seen_classes: Dict[Type, str] = {} + self.frame = frame + + def get_type_hints(self, cls: Type) -> Dict[str, Any]: + frame = self.frame + localns = frame.f_locals if frame is not None else None + return get_type_hints(cls, localns=localns) + + def __enter__(self) -> "_SchemaContext": + _schema_ctx_stack.push(self) + return self + + def __exit__(self, _typ, _value, _tb) -> None: + _schema_ctx_stack.pop() + + +class _LocalStack(threading.local, Generic[_U]): + def __init__(self): + self.stack: List[_U] = [] + + def push(self, value: _U) -> None: + self.stack.append(value) + + def pop(self) -> None: + self.stack.pop() + + @property + def top(self) -> _U: + return self.stack[-1] + + +_schema_ctx_stack = _LocalStack[_SchemaContext]() @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, - clazz_frame: types.FrameType = None, ) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ + schema_ctx = _schema_ctx_stack.top + schema_ctx.seen_classes[clazz] = clazz.__name__ try: # noinspection PyDataclass fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) @@ -383,7 +417,7 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) + return _internal_class_schema(created_dataclass, base_schema) except Exception: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -397,18 +431,15 @@ def _internal_class_schema( } # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None - ) + type_hints = schema_ctx.get_type_hints(clazz) attributes.update( ( field.name, - field_for_schema( + _field_for_schema( type_hints[field.name], _get_field_default(field), field.metadata, base_schema, - clazz_frame, ), ) for field in fields @@ -433,7 +464,6 @@ def _field_by_supertype( newtype_supertype: Type, metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -459,12 +489,11 @@ def _field_by_supertype( if field: return field(**metadata) else: - return field_for_schema( + return _field_for_schema( newtype_supertype, metadata=metadata, default=default, base_schema=base_schema, - typ_frame=typ_frame, ) @@ -488,7 +517,6 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -501,9 +529,7 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -512,32 +538,25 @@ def _field_for_generic_type( if origin in (collections.abc.Sequence, Sequence): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) - for arg in arguments + _field_for_schema(arg, base_schema=base_schema) for arg in arguments ) tuple_type = cast( Type[marshmallow.fields.Tuple], @@ -549,12 +568,8 @@ def _field_for_generic_type( elif origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ), - values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame - ), + keys=_field_for_schema(arguments[0], base_schema=base_schema), + values=_field_for_schema(arguments[1], base_schema=base_schema), **metadata, ) elif typing_inspect.is_union_type(typ): @@ -566,11 +581,10 @@ def _field_for_generic_type( metadata.setdefault("required", False) subtypes = [t for t in arguments if t is not NoneType] # type: ignore if len(subtypes) == 1: - return field_for_schema( + return _field_for_schema( subtypes[0], metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) from . import union_field @@ -578,11 +592,10 @@ def _field_for_generic_type( [ ( subtyp, - field_for_schema( + _field_for_schema( subtyp, metadata={"required": True}, base_schema=base_schema, - typ_frame=typ_frame, ), ) for subtyp in subtypes @@ -618,6 +631,29 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ + """ + with _SchemaContext(typ_frame): + return _field_for_schema(typ, default, metadata, base_schema) + + +def _field_for_schema( + typ: type, + default=marshmallow.missing, + metadata: Mapping[str, Any] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, +) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + + This is an internal version of field_for_schema. It assumes a _SchemaContext + has been pushed onto the local stack. + + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + :param base_schema: marshmallow schema used as a base class when deriving dataclass schema + """ metadata = {} if metadata is None else dict(metadata) @@ -690,10 +726,10 @@ def field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) + return _field_for_schema(subtyp, default, metadata, base_schema) # Generic types - generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) + generic_field = _field_for_generic_type(typ, base_schema, **metadata) if generic_field: return generic_field @@ -707,7 +743,6 @@ def field_for_schema( newtype_supertype=newtype_supertype, metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) # enumerations @@ -726,8 +761,8 @@ def field_for_schema( nested = ( nested_schema or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) + or _schema_ctx_stack.top.seen_classes.get(typ) + or _internal_class_schema(typ, base_schema) ) return marshmallow.fields.Nested(nested, **metadata)