From 2a3f8af83fa05a1e26046b78ef6ecd5a43c3b251 Mon Sep 17 00:00:00 2001 From: Keming Date: Tue, 31 Dec 2024 18:33:49 +0800 Subject: [PATCH] feat: support user custom error (#398) * feat: support user custom error Signed-off-by: Keming * legacy validator Signed-off-by: Keming --------- Signed-off-by: Keming --- Makefile | 4 +- README.md | 4 +- docs/source/conf.py | 2 +- pyproject.toml | 2 +- spectree/plugins/falcon_plugin.py | 24 +++++- spectree/plugins/flask_plugin.py | 12 ++- spectree/plugins/quart_plugin.py | 14 +++- spectree/plugins/starlette_plugin.py | 14 +++- spectree/utils.py | 2 +- .../test_plugin_spec[falcon][full_spec].json | 53 ++++++++++++ .../test_plugin_spec[flask][full_spec].json | 53 ++++++++++++ ...ugin_spec[flask_blueprint][full_spec].json | 53 ++++++++++++ ...st_plugin_spec[flask_view][full_spec].json | 84 +++++++++++++++++++ ...est_plugin_spec[starlette][full_spec].json | 53 ++++++++++++ tests/common.py | 24 +++++- tests/flask_imports/__init__.py | 8 ++ tests/flask_imports/dry_plugin_flask.py | 12 ++- tests/quart_imports/__init__.py | 6 ++ tests/quart_imports/dry_plugin_quart.py | 10 +++ tests/test_plugin_falcon.py | 21 +++++ tests/test_plugin_flask.py | 13 ++- tests/test_plugin_flask_blueprint.py | 7 ++ tests/test_plugin_flask_view.py | 15 ++++ tests/test_plugin_quart.py | 7 ++ tests/test_plugin_starlette.py | 17 ++++ 25 files changed, 489 insertions(+), 25 deletions(-) diff --git a/Makefile b/Makefile index 5b3610f6..ed6d2655 100644 --- a/Makefile +++ b/Makefile @@ -16,9 +16,9 @@ import_test: test: import_test pip install -U -e .[email,flask,quart,falcon,starlette] - pytest tests -vv -rs + pytest tests -vv -rs --disable-warnings pip install --force-reinstall 'pydantic[email]<2' - pytest tests -vv -rs + pytest tests -vv -rs --disable-warnings update_snapshot: @pytest --snapshot-update diff --git a/README.md b/README.md index c074b676..97d33540 100644 --- a/README.md +++ b/README.md @@ -489,8 +489,8 @@ if __name__ == "__main__": > ValidationError: missing field for headers The HTTP headers' keys in Flask are capitalized, in Falcon are upper cases, in Starlette are lower cases. -You can use [`pydantic.root_validators(pre=True)`](https://pydantic-docs.helpmanual.io/usage/validators/#root-validators) to change all the keys into lower cases or upper cases. +You can use [`pydantic.model_validator(mode="before")`](https://docs.pydantic.dev/dev/concepts/validators/#model-validators) to change all the keys into lower cases or upper cases. > ValidationError: value is not a valid list for the query -Since there is no standard for HTTP queries with multiple values, it's hard to find a way to handle this for different web frameworks. So I suggest not to use list type in query until I find a suitable way to fix it. +Since there is no standard for HTTP queries with multiple values, it's hard to find a way to handle this for different web frameworks. diff --git a/docs/source/conf.py b/docs/source/conf.py index a57376b0..df53220e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,7 +47,7 @@ source_suffix = [".rst", ".md"] language = "en" html_baseurl = "https://0b01001001.github.io/spectree/" -html_extra_path = ['robots.txt'] +html_extra_path = ["robots.txt"] # -- Options for HTML output ------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index dbff1343..9a11405f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spectree" -version = "1.4.2" +version = "1.4.3" dynamic = [] description = "generate OpenAPI document and validate request&response with Python annotations." readme = "README.md" diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 5f2be2a3..c3535176 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -214,7 +214,11 @@ def validate( except (InternalValidationError, ValidationError) as err: req_validation_error = err _resp.status = f"{validation_error_status} Validation Error" - _resp.media = err.errors() + _resp.media = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) if self.config.annotations: annotations = get_type_hints(func) @@ -238,7 +242,11 @@ def validate( except (InternalValidationError, ValidationError) as err: resp_validation_error = err _resp.status = HTTP_500 - _resp.media = err.errors() + _resp.media = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) else: _resp.media = response_validation_result.payload @@ -320,7 +328,11 @@ async def validate( except (InternalValidationError, ValidationError) as err: req_validation_error = err _resp.status = f"{validation_error_status} Validation Error" - _resp.media = err.errors() + _resp.media = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) if self.config.annotations: annotations = get_type_hints(func) @@ -348,7 +360,11 @@ async def validate( except (InternalValidationError, ValidationError) as err: resp_validation_error = err _resp.status = HTTP_500 - _resp.media = err.errors() + _resp.media = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) else: _resp.media = response_validation_result.payload diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 90207575..09103a8d 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -187,7 +187,11 @@ def validate( self.request_validation(request, query, json, form, headers, cookies) except (InternalValidationError, ValidationError) as err: req_validation_error = err - errors = err.errors() if isinstance(err, InternalValidationError) else err.errors(include_context=False) + errors = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) response = make_response(jsonify(errors), validation_error_status) if self.config.annotations: @@ -225,7 +229,11 @@ def validate( response_payload=payload, ) except (InternalValidationError, ValidationError) as err: - errors = err.errors() if isinstance(err, InternalValidationError) else err.errors(include_context=False) + errors = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) response = make_response(errors, 500) resp_validation_error = err else: diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index fe82a6c5..8f0cfffe 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -197,9 +197,12 @@ async def validate( ) except (InternalValidationError, ValidationError) as err: req_validation_error = err - response = await make_response( - jsonify(err.errors()), validation_error_status + errors = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) ) + response = await make_response(jsonify(errors), validation_error_status) if self.config.annotations: annotations = get_type_hints(func) @@ -241,7 +244,12 @@ async def validate( response_payload=payload, ) except (InternalValidationError, ValidationError) as err: - response = await make_response(err.errors(), 500) + errors = ( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False) + ) + response = await make_response(errors, 500) resp_validation_error = err else: response = await make_response( diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index dd8e1a6e..7e7000e2 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -117,7 +117,12 @@ async def validate( ) except (InternalValidationError, ValidationError) as err: req_validation_error = err - response = JSONResponse(err.errors(), validation_error_status) + response = JSONResponse( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False), + validation_error_status, + ) except JSONDecodeError as err: json_decode_error = err self.logger.info( @@ -162,7 +167,12 @@ async def validate( response_payload=RawResponsePayload(payload=response.body), ) except (InternalValidationError, ValidationError) as err: - response = JSONResponse(err.errors(), 500) + response = JSONResponse( + err.errors() + if isinstance(err, InternalValidationError) + else err.errors(include_context=False), + 500, + ) resp_validation_error = err after(request, response, resp_validation_error, instance) diff --git a/spectree/utils.py b/spectree/utils.py index 25de5a2d..8bfd2fed 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -393,7 +393,7 @@ def flask_response_unpack(resp: Any) -> Tuple[Any, int, Dict[str, Any]]: f"Invalid return tuple: {resp}, expect (body,), (body, status), " "(body, headers), or (body, status, headers)." ) - return payload, status, headers + return payload, status, dict(headers) def parse_resp(func: Any, naming_strategy: NamingStrategy = get_model_key): diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json index 722cbe44..fbb56f45 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json @@ -14,6 +14,19 @@ "title": "Cookies", "type": "object" }, + "CustomError.7068f62": { + "properties": { + "foo": { + "title": "Foo", + "type": "string" + } + }, + "required": [ + "foo" + ], + "title": "CustomError", + "type": "object" + }, "FormFileUpload.7068f62": { "properties": { "file": { @@ -167,6 +180,46 @@ }, "openapi": "3.1.0", "paths": { + "/api/custom_error": { + "post": { + "description": "", + "operationId": "post__api_custom_error", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "on_post ", + "tags": [] + } + }, "/api/custom_serializer": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json index e0a6fe76..288cc51c 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json @@ -14,6 +14,19 @@ "title": "Cookies", "type": "object" }, + "CustomError.7068f62": { + "properties": { + "foo": { + "title": "Foo", + "type": "string" + } + }, + "required": [ + "foo" + ], + "title": "CustomError", + "type": "object" + }, "Form.7068f62": { "properties": { "limit": { @@ -201,6 +214,46 @@ }, "openapi": "3.1.0", "paths": { + "/api/custom_error": { + "post": { + "description": "", + "operationId": "post__api_custom_error", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "custom_error ", + "tags": [] + } + }, "/api/file_upload": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json index b82ac180..efd0d4ef 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json @@ -14,6 +14,19 @@ "title": "Cookies", "type": "object" }, + "CustomError.7068f62": { + "properties": { + "foo": { + "title": "Foo", + "type": "string" + } + }, + "required": [ + "foo" + ], + "title": "CustomError", + "type": "object" + }, "Form.7068f62": { "properties": { "limit": { @@ -201,6 +214,46 @@ }, "openapi": "3.1.0", "paths": { + "/api/custom_error": { + "post": { + "description": "", + "operationId": "post__api_custom_error", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "custom_error ", + "tags": [] + } + }, "/api/file_upload": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index 5506ee01..69de0fe3 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -14,6 +14,19 @@ "title": "Cookies", "type": "object" }, + "CustomError.7068f62": { + "properties": { + "foo": { + "title": "Foo", + "type": "string" + } + }, + "required": [ + "foo" + ], + "title": "CustomError", + "type": "object" + }, "Form.7068f62": { "properties": { "limit": { @@ -201,6 +214,46 @@ }, "openapi": "3.1.0", "paths": { + "/api/custom_error": { + "post": { + "description": "", + "operationId": "post__api_custom_error", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "post ", + "tags": [] + } + }, "/api/file_upload": { "post": { "description": "", @@ -475,6 +528,37 @@ "tags": [] } }, + "/api/return_root": { + "get": { + "description": "", + "operationId": "get__api_return_root", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RootResp.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "get ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json index 3924355f..cfc3f2cf 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json @@ -14,6 +14,19 @@ "title": "Cookies", "type": "object" }, + "CustomError.7068f62": { + "properties": { + "foo": { + "title": "Foo", + "type": "string" + } + }, + "required": [ + "foo" + ], + "title": "CustomError", + "type": "object" + }, "FormFileUpload.7068f62": { "properties": { "file": { @@ -167,6 +180,46 @@ }, "openapi": "3.1.0", "paths": { + "/api/custom_error": { + "post": { + "description": "", + "operationId": "post__api_custom_error", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CustomError.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Content" + } + }, + "summary": "custom_error ", + "tags": [] + } + }, "/api/file_upload": { "post": { "description": "", diff --git a/tests/common.py b/tests/common.py index 91f9d625..3208a36b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,15 +4,16 @@ from enum import Enum, IntEnum from typing import Any, Dict, List, Optional, Union, cast -from pydantic import BaseModel, Field, root_validator +# legacy code of: from pydantic import model_validator, field_validator +from pydantic import BaseModel, Field, root_validator, validator from spectree import BaseFile, ExternalDocs, SecurityScheme, SecuritySchemeData, Tag from spectree._pydantic import generate_root_model from spectree.utils import hash_module_path # suppress warnings -warnings.filterwarnings("ignore", category=UserWarning) -warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning, append=True) +warnings.filterwarnings("ignore", category=DeprecationWarning, append=True) api_tag = Tag( name="API", description="🐱", externalDocs=ExternalDocs(url="https://pypi.org") @@ -77,6 +78,11 @@ class Language(str, Enum): class Headers(BaseModel): lang: Language + # @model_validator(mode="before") + # @classmethod + # def lower_keys(cls, data: Any): + # return {key.lower(): value for key, value in data.items()} + @root_validator(pre=True) def lower_keys(cls, values): return {key.lower(): value for key, value in values.items()} @@ -97,6 +103,18 @@ class DemoQuery(BaseModel): names2: List[str] = Field(..., style="matrix", explode=True, non_keyword="dummy") # type: ignore +class CustomError(BaseModel): + foo: str + + # @field_validator("foo") + @validator("foo") + def value_must_be_foo(cls, value): + if value != "foo": + # this is not JSON serializable if included in the error context + raise ValueError("value must be foo") + return value + + def get_paths(spec): paths = [] for path in spec["paths"]: diff --git a/tests/flask_imports/__init__.py b/tests/flask_imports/__init__.py index b1a8d1bd..b17d8f07 100644 --- a/tests/flask_imports/__init__.py +++ b/tests/flask_imports/__init__.py @@ -1,29 +1,37 @@ from .dry_plugin_flask import ( + test_flask_custom_error, test_flask_doc, test_flask_list_json_request, test_flask_make_response_get, test_flask_make_response_post, test_flask_no_response, test_flask_optional_alias_response, + test_flask_query_list, test_flask_return_list_request, test_flask_return_model, + test_flask_return_root_request, test_flask_skip_validation, test_flask_upload_file, + test_flask_validate_basic, test_flask_validate_post_data, test_flask_validation_error_response_status_code, ) __all__ = [ + "test_flask_custom_error", "test_flask_doc", "test_flask_list_json_request", "test_flask_make_response_get", "test_flask_make_response_post", "test_flask_no_response", "test_flask_optional_alias_response", + "test_flask_query_list", "test_flask_return_list_request", "test_flask_return_model", + "test_flask_return_root_request", "test_flask_skip_validation", "test_flask_upload_file", + "test_flask_validate_basic", "test_flask_validate_post_data", "test_flask_validation_error_response_status_code", ] diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index 9b720c03..1ef75fc2 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -123,7 +123,7 @@ def test_flask_validate_basic(client): assert resp.json == {"msg": "pong"} assert resp.headers.get("X-Error") is None assert resp.headers.get("X-Validation") == "Pass" - assert resp.headers.get("lang") == "en-US" + assert resp.headers.get("lang") == "en-US", resp.headers resp = client.post("api/user/flask") assert resp.status_code == 422 @@ -270,3 +270,13 @@ def test_flask_optional_alias_response(client): def test_flask_query_list(client): resp = client.get("/api/query_list?ids=1&ids=2&ids=3") assert resp.status_code == 200 + + +def test_flask_custom_error(client): + # request error + resp = client.post("/api/custom_error", json={"foo": "bar"}) + assert resp.status_code == 422 + + # response error + resp = client.post("/api/custom_error", json={"foo": "foo"}) + assert resp.status_code == 500 diff --git a/tests/quart_imports/__init__.py b/tests/quart_imports/__init__.py index a5d6c093..0b8ccd73 100644 --- a/tests/quart_imports/__init__.py +++ b/tests/quart_imports/__init__.py @@ -1,6 +1,9 @@ from .dry_plugin_quart import ( + test_quart_custom_error, test_quart_doc, + test_quart_list_json_request, test_quart_no_response, + test_quart_return_list_request, test_quart_return_model, test_quart_skip_validation, test_quart_validate, @@ -8,8 +11,11 @@ ) __all__ = [ + "test_quart_custom_error", "test_quart_doc", + "test_quart_list_json_request", "test_quart_no_response", + "test_quart_return_list_request", "test_quart_return_model", "test_quart_skip_validation", "test_quart_validate", diff --git a/tests/quart_imports/dry_plugin_quart.py b/tests/quart_imports/dry_plugin_quart.py index 531aebdb..a8b2dc19 100644 --- a/tests/quart_imports/dry_plugin_quart.py +++ b/tests/quart_imports/dry_plugin_quart.py @@ -190,3 +190,13 @@ def test_quart_return_list_request(client, pre_serialize: bool): {"name": "user1", "limit": 1}, {"name": "user2", "limit": 2}, ] + + +def test_quart_custom_error(client): + # request error + resp = asyncio.run(client.post("/api/custom_error", json={"foo": "bar"})) + assert resp.status_code == 422 + + # response error + resp = asyncio.run(client.post("/api/custom_error", json={"foo": "foo"})) + assert resp.status_code == 500 diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index 88f598c9..dcd6bbe6 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -15,6 +15,7 @@ from .common import ( JSON, Cookies, + CustomError, FormFileUpload, Headers, ListJSON, @@ -234,6 +235,15 @@ def on_get(self, req, resp): ) +class CustomErrorView: + name = "custom error view" + + @api.validate(resp=Response(HTTP_200=CustomError)) + def on_post(self, req, resp, json: CustomError): + resp.media = {"foo": "bar"} + resp.status = falcon.HTTP_422 + + class ReturnOptionalAliasView: @api.validate(resp=Response(HTTP_200=OptionalAliasResp)) def on_get(self, req, resp): @@ -270,6 +280,7 @@ def on_post(self, req, resp): app.add_route("/api/return_root", ReturnRootView()) app.add_route("/api/return_optional_alias", ReturnOptionalAliasView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) +app.add_route("/api/custom_error", CustomErrorView()) api.register(app) @@ -587,3 +598,13 @@ def on_post_v2(self, req, resp, json: V2): "POST", "/api/compatibility/v2", json={"value": "invalid"} ) assert resp.status_code == 422 + + +def test_custom_error(client): + # error in request + resp = client.simulate_post("/api/custom_error", json={"foo": "bar"}) + assert resp.status_code == 422 + + # error in response + resp = client.simulate_post("/api/custom_error", json={"foo": "foo"}) + assert resp.status_code == 500 diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index 63d61f6d..91e5d85e 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -10,6 +10,7 @@ JSON, SECURITY_SCHEMAS, Cookies, + CustomError, Form, FormFileUpload, Headers, @@ -62,12 +63,12 @@ def api_after_handler(req, resp, err, _): @app.route("/ping") -@api.validate(headers=Headers, resp=Response(HTTP_202=StrDict), tags=["test", "health"]) -def ping(): +@api.validate(resp=Response(HTTP_202=StrDict), tags=["test", "health"]) +def ping(headers: Headers): """summary description""" - return jsonify(msg="pong"), 202 + return jsonify(msg="pong"), 202, headers @app.route("/api/file_upload", methods=["POST"]) @@ -251,6 +252,12 @@ def return_optional_alias_resp(): return {"schema": "test"} +@app.route("/api/custom_error", methods=["POST"]) +@api.validate(resp=Response(HTTP_200=CustomError)) +def custom_error(json: CustomError): + return {"foo": "bar"} + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 315dbadf..0087d5f5 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -10,6 +10,7 @@ from .common import ( JSON, Cookies, + CustomError, Form, FormFileUpload, Headers, @@ -238,6 +239,12 @@ def return_optional_alias(): return {"schema": "test"} +@app.route("/api/custom_error", methods=["POST"]) +@api.validate(resp=Response(HTTP_200=CustomError)) +def custom_error(json: CustomError): + return {"foo": "bar"} + + api.register(app) flask_app = Flask(__name__) diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index 82e8c150..8118327c 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -10,6 +10,7 @@ from .common import ( JSON, Cookies, + CustomError, Form, FormFileUpload, Headers, @@ -253,6 +254,12 @@ def get(self): return {"schema": "test"} +class CustomErrorView(MethodView): + @api.validate(resp=Response(HTTP_200=CustomError)) + def post(self, json: CustomError): + return jsonify(foo="bar") + + app.add_url_rule("/ping", view_func=Ping.as_view("ping")) app.add_url_rule("/api/user/", view_func=User.as_view("user"), methods=["POST"]) app.add_url_rule( @@ -303,6 +310,14 @@ def get(self): "/api/return_optional_alias", view_func=ReturnOptionalAlias.as_view("return_optional_alias"), ) +app.add_url_rule( + "/api/return_root", + view_func=ReturnRootView.as_view("return_root_view"), +) +app.add_url_rule( + "/api/custom_error", + view_func=CustomErrorView.as_view("custom_error_view"), +) # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index 3a4b4d62..dddca6cd 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -10,6 +10,7 @@ JSON, SECURITY_SCHEMAS, Cookies, + CustomError, Headers, ListJSON, Order, @@ -179,6 +180,12 @@ def return_root(): ) +@app.route("/api/custom_error", methods=["POST"]) +@api.validate(resp=Response(HTTP_200=CustomError)) +def custom_error(json: CustomError): + return jsonify(foo="bar") + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index 5ea70a11..70775cb9 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -17,6 +17,7 @@ from .common import ( JSON, Cookies, + CustomError, FormFileUpload, Headers, ListJSON, @@ -176,6 +177,11 @@ async def return_optional_alias(request): return JSONResponse({"schema": "test"}) +@api.validate(resp=Response(HTTP_200=CustomError)) +async def custom_error(request, json: CustomError): + return JSONResponse({"foo": "bar"}) + + app = Starlette( routes=[ Route("/ping", Ping), @@ -212,6 +218,7 @@ async def return_optional_alias(request): Route("/return_list", return_list, methods=["GET"]), Route("/return_root", return_root, methods=["GET"]), Route("/return_optional_alias", return_optional_alias, methods=["GET"]), + Route("/custom_error", custom_error, methods=["POST"]), ], ), Mount("/static", app=StaticFiles(directory="docs"), name="static"), @@ -456,3 +463,13 @@ def test_starlette_return_optional_alias(client): resp = client.get("/api/return_optional_alias") assert resp.status_code == 200 assert resp.json() == {"schema": "test"} + + +def test_custom_error(client): + # request error + resp = client.post("/api/custom_error", json={"foo": "bar"}) + assert resp.status_code == 422 + + # response error + resp = client.post("/api/custom_error", json={"foo": "foo"}) + assert resp.status_code == 500