diff --git a/entity/models.py b/entity/models.py index b5dd473..df1da97 100644 --- a/entity/models.py +++ b/entity/models.py @@ -500,6 +500,8 @@ def get_all_entities(self, membership_cache=None, entities_by_kind=None, return_ memberships = membership_cache.get(self.id) if memberships: if self.logic_string: + entity_kind_id = memberships[0][1] + full_set = set(entities_by_kind[entity_kind_id]['all']) try: filter_tree = ast.parse(self.logic_string.lower()) except: @@ -523,7 +525,7 @@ def get_all_entities(self, membership_cache=None, entities_by_kind=None, return_ self.validate_filter_indices(indices, expanded_memberships) kmatch = self._node_to_kmatch(filter_tree.body[0].value) kmatch = self._map_kmatch_values(kmatch, expanded_memberships) - entity_ids = self._process_kmatch(kmatch, full_set=expanded_memberships[-1]) + entity_ids = self._process_kmatch(kmatch, full_set=full_set) else: # Loop over each membership in this group @@ -693,6 +695,12 @@ def get_entities_by_kind(membership_cache=None, is_active=True): kinds_with_supers = set() super_ids = set() + # Determine if we need to include the "universal set" aka all for a kind based on the presence of a logic_string + group_ids_with_logic_string = set(EntityGroup.objects.filter( + id__in=membership_cache.keys(), + logic_string__isnull=False, + ).values_list('id', flat=True)) + # Loop over each group for group_id, memberships in membership_cache.items(): @@ -705,6 +713,11 @@ def get_entities_by_kind(membership_cache=None, is_active=True): # Make sure a dict exists for this kind entities_by_kind.setdefault(entity_kind_id, {}) + # Always include all if there is a logic string + if group_id in group_ids_with_logic_string: + entities_by_kind[entity_kind_id]['all'] = [] + kinds_with_all.add(entity_kind_id) + # Check if this is all entities of a kind under a specific entity if entity_id: entities_by_kind[entity_kind_id][entity_id] = [] diff --git a/entity/tests/model_tests.py b/entity/tests/model_tests.py index e241171..5e4e297 100644 --- a/entity/tests/model_tests.py +++ b/entity/tests/model_tests.py @@ -811,7 +811,7 @@ def test_logic_string(self): EntityRelationship.objects.bulk_create(relationships) # Create the entity group - entity_group = G(EntityGroup, logic_string='(((1 AND 2) OR (3 AND 4)) AND NOT(5) OR 6) AND 7') + entity_group = G(EntityGroup, logic_string='((1 AND 2) OR (3 AND 4)) AND NOT(5) OR 6') # Create the memberships -- two memberships of all subs under a kind G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_a) @@ -820,7 +820,6 @@ def test_logic_string(self): G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_d) G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=None, entity=sub_entities[1]) G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=None, entity=sub_entities[9]) - G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=None) entity_ids = entity_group.get_all_entities() self.assertEqual(entity_ids, set([ @@ -829,6 +828,51 @@ def test_logic_string(self): sub_entities[9].id, ])) + def test_logic_string_not(self): + """ + Verifies that the universal set is properly fetched and used to NOT a set + Group A: 0, 1, 2 + NOT(A) = 3, 4, 5, 6, 7, 8 + + Memberships: + 1. User in Group A + + Logic: NOT(1) + (3, 4, 5, 6, 7, 8) + """ + super_entity_kind = G(EntityKind) + sub_entity_kind = G(EntityKind) + super_entity_a = G(Entity, entity_kind=super_entity_kind) + sub_entities = [ + G(Entity, entity_kind=sub_entity_kind) + for _ in range(10) + ] + + # Create the relationships + relationships = [ + EntityRelationship(sub_entity=sub_entities[0], super_entity=super_entity_a), + EntityRelationship(sub_entity=sub_entities[1], super_entity=super_entity_a), + EntityRelationship(sub_entity=sub_entities[2], super_entity=super_entity_a), + ] + EntityRelationship.objects.bulk_create(relationships) + + # Create the entity group + entity_group = G(EntityGroup, logic_string='NOT(1)') + + # Create the membership + G(EntityGroupMembership, entity_group=entity_group, sub_entity_kind=sub_entity_kind, entity=super_entity_a) + + entity_ids = entity_group.get_all_entities() + self.assertEqual(entity_ids, set([ + sub_entities[3].id, + sub_entities[4].id, + sub_entities[5].id, + sub_entities[6].id, + sub_entities[7].id, + sub_entities[8].id, + sub_entities[9].id, + ])) + def test_individual_entities_returned(self): e = self.super_entities[0] G(EntityGroupMembership, entity_group=self.group, entity=e, sub_entity_kind=None) @@ -904,7 +948,7 @@ def test_number_of_queries(self): G(EntityGroupMembership, entity_group=self.group, entity=e2, sub_entity_kind=self.kind2) - with self.assertNumQueries(3): + with self.assertNumQueries(4): list(self.group.all_entities()) @@ -1081,14 +1125,14 @@ def test_get_all_entities(self): [None, account_kind], ]) - with self.assertNumQueries(3): + with self.assertNumQueries(4): membership_cache = EntityGroup.objects.get_membership_cache() entities_by_kind = get_entities_by_kind(membership_cache=membership_cache) for entity_group in entity_groups: entity_group.get_all_entities(membership_cache, entities_by_kind) - with self.assertNumQueries(3): + with self.assertNumQueries(4): get_entities_by_kind() # Make sure to hit the no group cache case