diff --git a/chalice_spec/docs.py b/chalice_spec/docs.py index 50d1c3e..a856708 100644 --- a/chalice_spec/docs.py +++ b/chalice_spec/docs.py @@ -6,6 +6,7 @@ DEFAULT_DESCRIPTION = "Success" DEFAULT_CODE = 200 +DEFAULT_CONTENT_TYPE = "application/json" class Response: @@ -15,10 +16,17 @@ class Response: and an optional description. """ - def __init__(self, model: type, code: int = 200, description: str = "Success"): + def __init__( + self, + model: type, + code: int = 200, + description: str = "Success", + content_type: str = DEFAULT_CONTENT_TYPE, + ): self.model = model self.code = code self.description = description + self.content_type = content_type class Operation: @@ -59,19 +67,23 @@ def __init__( def _populate_response(self, response: Union[Response, type]): if isinstance(response, Response): # If this is a Response object, we can track it as-is. - self.responses = {response.code: response} + self.responses = {response.code: {DEFAULT_CONTENT_TYPE: response}} else: # If not, we will use sensible defaults - self.responses = {DEFAULT_CODE: Response(model=response)} + self.responses = { + DEFAULT_CODE: {DEFAULT_CONTENT_TYPE: Response(model=response)} + } def _populate_responses(self, responses: List[Response]): self.responses = {} for response in responses: - if response.code in self.responses: + if response.code not in self.responses: + self.responses[response.code] = {} + if response.content_type in self.responses[response.code]: raise TypeError( - "You must only specify one response per HTTP status code" + f"Multiple responses defined for {response.code} — {response.content_type}" ) - self.responses[response.code] = response + self.responses[response.code][response.content_type] = response Method = Union[Type[BaseModel], Operation] @@ -139,7 +151,7 @@ def _build_operation_from_operation( if method.content_types else content_types[0] if content_types - else "application/json" + else DEFAULT_CONTENT_TYPE ) operation["requestBody"] = { "content": { @@ -152,21 +164,22 @@ def _build_operation_from_operation( if method.responses: responses = {} - for code, response in method.responses.items(): - if response.model.__name__ not in spec.components.schemas: - spec.components.schema( - response.model.__name__, - model=response.model, - spec=spec, - ) - responses[code] = { - "description": response.description, - "content": { - "application/json": { - "schema": response.model.__name__, + for code, response_contents in method.responses.items(): + for content_type, response in response_contents.items(): + if response.model.__name__ not in spec.components.schemas: + spec.components.schema( + response.model.__name__, + model=response.model, + spec=spec, + ) + if code not in responses: + responses[code] = { + "description": response.description, + "content": {}, } - }, - } + responses[code]["content"][content_type] = { + "schema": response.model.__name__ + } operation["responses"] = responses diff --git a/tests/test_chalice.py b/tests/test_chalice.py index bf0de42..db20b5a 100644 --- a/tests/test_chalice.py +++ b/tests/test_chalice.py @@ -426,11 +426,20 @@ def get_post(): def test_content_types(): app, spec = setup_test() + with pytest.raises(TypeError): + Op(responses=[Resp(model=TestSchema), Resp(model=TestSchema)]) + @app.route( "/posts", methods=["POST"], content_types=["multipart/form-data"], - docs=Docs(request=TestSchema, response=AnotherSchema), + docs=Docs( + request=TestSchema, + responses=[ + Resp(model=AnotherSchema, content_type="application/json"), + Resp(model=AnotherSchema, content_type="application/xml"), + ], + ), ) def get_post(): pass @@ -454,7 +463,12 @@ def get_post(): "schema": { "$ref": "#/components/schemas/AnotherSchema" } - } + }, + "application/xml": { + "schema": { + "$ref": "#/components/schemas/AnotherSchema" + } + }, }, } },