From ce341aa82cc302f5893b86c2a0fcf7a252a9e181 Mon Sep 17 00:00:00 2001 From: Tommy Beadle Date: Thu, 12 Sep 2024 12:11:09 -0400 Subject: [PATCH] Support async view methods in extend_schema_view. --- drf_spectacular/drainage.py | 19 ++++++++++++++----- drf_spectacular/openapi.py | 7 +++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/drf_spectacular/drainage.py b/drf_spectacular/drainage.py index 7c44890b..9ad0ea56 100644 --- a/drf_spectacular/drainage.py +++ b/drf_spectacular/drainage.py @@ -182,8 +182,12 @@ def get_view_method_names(view, schema=None) -> List[str]: return [ item for item in dir(view) if callable(getattr(view, item, None)) and ( item in view.http_method_names - or item in schema.method_mapping.values() - or item == 'list' + or ( + item in schema.async_method_mapping.values() + if view.view_is_async + else item in schema.method_mapping.values() + ) + or item == ('alist' if view.view_is_async else 'list') or hasattr(getattr(view, item, None), 'mapping') ) ] @@ -202,9 +206,14 @@ def isolate_view_method(view, method_name): if method_name in view.__dict__ and method.__name__ != 'handler': return method - @functools.wraps(method) - def wrapped_method(self, request, *args, **kwargs): - return method(self, request, *args, **kwargs) + if getattr(view, "view_is_async", False): + @functools.wraps(method) + async def wrapped_method(self, request, *args, **kwargs): + return await method(self, request, *args, **kwargs) + else: + @functools.wraps(method) + def wrapped_method(self, request, *args, **kwargs): + return method(self, request, *args, **kwargs) # wraps() will only create a shallow copy of method.__dict__. Updates to "kwargs" # via @extend_schema would leak to the original method. Isolate by creating a copy. diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index ead7220b..77042e68 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -57,6 +57,13 @@ class AutoSchema(ViewInspector): 'patch': 'partial_update', 'delete': 'destroy', } + async_method_mapping = { + 'get': 'aretrieve', + 'post': 'acreate', + 'put': 'aupdate', + 'patch': 'partial_aupdate', + 'delete': 'adestroy', + } def get_operation( self,