|
4 | 4 | import quart |
5 | 5 | from quart import Blueprint, abort, current_app, jsonify, make_response, request |
6 | 6 |
|
7 | | -from spectree._pydantic import InternalValidationError, ValidationError |
| 7 | +from spectree._pydantic import ( |
| 8 | + InternalValidationError, |
| 9 | + SerializedPydanticResponse, |
| 10 | + ValidationError, |
| 11 | + is_partial_base_model_instance, |
| 12 | + serialize_model_instance, |
| 13 | +) |
8 | 14 | from spectree._types import ModelType |
9 | | -from spectree.plugins.base import Context |
10 | | -from spectree.plugins.werkzeug_utils import WerkzeugPlugin |
| 15 | +from spectree.plugins.base import Context, validate_response |
| 16 | +from spectree.plugins.werkzeug_utils import WerkzeugPlugin, flask_response_unpack |
11 | 17 | from spectree.response import Response |
12 | 18 | from spectree.utils import cached_type_hints, get_multidict_items |
13 | 19 |
|
@@ -54,6 +60,64 @@ async def request_validation(self, request, query, json, form, headers, cookies) |
54 | 60 | cookies.parse_obj(req_cookies) if cookies else None, |
55 | 61 | ) |
56 | 62 |
|
| 63 | + async def validate_response( |
| 64 | + self, |
| 65 | + resp, |
| 66 | + resp_model: Optional[Response], |
| 67 | + skip_validation: bool, |
| 68 | + ): |
| 69 | + resp_validation_error = None |
| 70 | + payload, status, additional_headers = flask_response_unpack(resp) |
| 71 | + |
| 72 | + if self.is_app_response(payload): |
| 73 | + resp_status, resp_headers = payload.status_code, payload.headers |
| 74 | + payload = await payload.get_data() |
| 75 | + # the inner flask.Response.status_code only takes effect when there is |
| 76 | + # no other status code |
| 77 | + if status == 200: |
| 78 | + status = resp_status |
| 79 | + # use the `Header` object to avoid deduplicated by `make_response` |
| 80 | + resp_headers.extend(additional_headers) |
| 81 | + additional_headers = resp_headers |
| 82 | + |
| 83 | + if not skip_validation and resp_model: |
| 84 | + try: |
| 85 | + response_validation_result = validate_response( |
| 86 | + validation_model=resp_model.find_model(status), |
| 87 | + response_payload=payload, |
| 88 | + ) |
| 89 | + except (InternalValidationError, ValidationError) as err: |
| 90 | + errors = ( |
| 91 | + err.errors() |
| 92 | + if isinstance(err, InternalValidationError) |
| 93 | + else err.errors(include_context=False) |
| 94 | + ) |
| 95 | + response = await make_response(errors, 500) |
| 96 | + resp_validation_error = err |
| 97 | + else: |
| 98 | + response = await make_response( |
| 99 | + self.get_current_app().response_class( |
| 100 | + response_validation_result.payload.data, |
| 101 | + mimetype="application/json", |
| 102 | + ) |
| 103 | + if isinstance( |
| 104 | + response_validation_result.payload, |
| 105 | + SerializedPydanticResponse, |
| 106 | + ) |
| 107 | + else response_validation_result.payload, |
| 108 | + status, |
| 109 | + additional_headers, |
| 110 | + ) |
| 111 | + else: |
| 112 | + if is_partial_base_model_instance(payload): |
| 113 | + payload = self.get_current_app().response_class( |
| 114 | + serialize_model_instance(payload).data, |
| 115 | + mimetype="application/json", |
| 116 | + ) |
| 117 | + response = await make_response(payload, status, additional_headers) |
| 118 | + |
| 119 | + return response, resp_validation_error |
| 120 | + |
57 | 121 | async def validate( |
58 | 122 | self, |
59 | 123 | func: Callable, |
@@ -104,7 +168,7 @@ async def validate( |
104 | 168 | else func(*args, **kwargs) |
105 | 169 | ) |
106 | 170 |
|
107 | | - response, resp_validation_error = self.validate_response( |
| 171 | + response, resp_validation_error = await self.validate_response( |
108 | 172 | result, resp, skip_validation |
109 | 173 | ) |
110 | 174 | after(request, response, resp_validation_error, None) |
|
0 commit comments