From bd257abd59f7d7c279e562e431c638293d9468dd Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Fri, 25 Oct 2019 13:49:11 +1100 Subject: [PATCH] Add include_data=True option for including all relationships --- marshmallow_jsonapi/schema.py | 14 +++++++++++++- tests/test_schema.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/marshmallow_jsonapi/schema.py b/marshmallow_jsonapi/schema.py index b8de86a..252b79c 100644 --- a/marshmallow_jsonapi/schema.py +++ b/marshmallow_jsonapi/schema.py @@ -74,7 +74,10 @@ class Meta: def __init__(self, *args, **kwargs): self.include_data = kwargs.pop("include_data", ()) super().__init__(*args, **kwargs) - if self.include_data: + + if self.include_data is True: + self.include_all_data() + elif self.include_data: self.check_relations(self.include_data) if not self.opts.type_: @@ -93,6 +96,15 @@ def __init__(self, *args, **kwargs): OPTIONS_CLASS = SchemaOpts + def include_all_data(self): + """ + Recursively set include_data for all relationships to this schema + """ + for field in self.fields.values(): + if isinstance(field, BaseRelationship): + field.include_data = True + field.schema.include_all_data() + def check_relations(self, relations): """Recursive function which checks if a relation is valid.""" for rel in relations: diff --git a/tests/test_schema.py b/tests/test_schema.py index 4ca16ea..cc7b64c 100755 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -149,6 +149,17 @@ def test_include_data_with_all_relations(self, post): } assert included_comments_author_ids == expected_comments_author_ids + def test_include_data_auto_all(self, post): + """ + Test that we can use include_data=True to include all relations recursively + """ + data = unpack(PostSchema(include_data=True).dump(post)) + assert "included" in data + assert len(data["included"]) == 8 + for included in data["included"]: + assert included["id"] + assert included["type"] in ("people", "comments", "keywords") + def test_include_no_data(self, post): data = unpack(PostSchema(include_data=()).dump(post)) assert "included" not in data