diff --git a/src/ralph/assets/api/views.py b/src/ralph/assets/api/views.py index b730ac411e..e1b2227fb6 100644 --- a/src/ralph/assets/api/views.py +++ b/src/ralph/assets/api/views.py @@ -234,7 +234,8 @@ class DCHostViewSet(BaseObjectViewSetMixin, RalphAPIViewSet): ] select_related = [ 'service_env', 'service_env__service', 'service_env__environment', - 'configuration_path', 'configuration_path__module' + 'configuration_path', 'configuration_path__module', + 'parent__cloudproject', ] prefetch_related = [ 'tags', diff --git a/src/ralph/lib/custom_fields/fields.py b/src/ralph/lib/custom_fields/fields.py index cbbdde7d6a..32dde965ef 100644 --- a/src/ralph/lib/custom_fields/fields.py +++ b/src/ralph/lib/custom_fields/fields.py @@ -9,6 +9,7 @@ ) from django.contrib.contenttypes.models import ContentType from django.db import connection, models +from django.db.models.fields.related import OneToOneRel from ralph.admin.helpers import get_field_by_relation_path, getattr_dunder @@ -55,6 +56,19 @@ def contribute_to_class(self, cls, name, **kwargs): ) +def _get_content_type_from_field_path(model, field_path): + # TODO: add some validator for it + # TODO: store fields in some field in meta, not "calculate" + # it every time + field = get_field_by_relation_path(model, field_path) + if isinstance(field, OneToOneRel): + related_model = field.related_model + else: + related_model = field.rel.to + content_type = ContentType.objects.get_for_model(related_model) + return content_type + + def _prioritize_custom_field_values(objects, model, content_type): """ Sort custom field values by priorities and leave the ones with @@ -69,8 +83,7 @@ def _prioritize_custom_field_values(objects, model, content_type): """ ct_priority = [content_type.id] for field_path in model.custom_fields_inheritance: - field = get_field_by_relation_path(model, field_path) - content_type = ContentType.objects.get_for_model(field.rel.to) + content_type = _get_content_type_from_field_path(model, field_path) ct_priority.append(content_type.id) ct_priority = { ct_id: index for (index, ct_id) in enumerate(ct_priority) @@ -199,11 +212,9 @@ def _get_inheritance_filters_for_single_instance(self): # for each related field (foreign key), add it's content_type # and object_id to queryset filter for field_path in self.instance.custom_fields_inheritance: - # TODO: add some validator for it - # TODO: store fields in some field in meta, not "calculate" - # it every time - field = get_field_by_relation_path(self.instance, field_path) - content_type = ContentType.objects.get_for_model(field.rel.to) + content_type = _get_content_type_from_field_path( + self.instance, field_path + ) value = getattr_dunder(self.instance, field_path) # filter only if related field has some value if value: @@ -260,8 +271,9 @@ def get_prefetch_queryset(self, instances, queryset=None): # process each dependent field from `custom_fields_inheritance` for field_path in self.instance.custom_fields_inheritance: # assume that field is foreign key - field = get_field_by_relation_path(self.instance, field_path) - content_type = ContentType.objects.get_for_model(field.rel.to) + content_type = _get_content_type_from_field_path( + self.instance, field_path + ) content_types.add(content_type) # for each instance, get value of this dependent field for instance in instances: diff --git a/src/ralph/virtual/models.py b/src/ralph/virtual/models.py index 858be98d8e..9f1d1b4acf 100644 --- a/src/ralph/virtual/models.py +++ b/src/ralph/virtual/models.py @@ -134,6 +134,9 @@ def disk(self, new_disk): class CloudProject(PreviousStateMixin, AdminAbsoluteUrlMixin, BaseObject): cloudprovider = models.ForeignKey(CloudProvider) cloudprovider._autocomplete = False + custom_fields_inheritance = [ + 'service_env', + ] project_id = models.CharField( verbose_name=_('project ID'), @@ -155,6 +158,12 @@ def update_service_env_on_cloudproject_save(sender, instance, **kwargs): class CloudHost(PreviousStateMixin, AdminAbsoluteUrlMixin, BaseObject): previous_dc_host_update_fields = ['hostname'] + custom_fields_inheritance = [ + 'parent__cloudproject', + 'configuration_path', + 'configuration_path__module', + 'service_env', + ] def save(self, *args, **kwargs): try: diff --git a/src/ralph/virtual/tests/test_models.py b/src/ralph/virtual/tests/test_models.py index a0fbc6ff17..61ce4397bf 100644 --- a/src/ralph/virtual/tests/test_models.py +++ b/src/ralph/virtual/tests/test_models.py @@ -3,7 +3,16 @@ from ralph.assets.models.assets import ServiceEnvironment from ralph.assets.models.choices import ComponentType from ralph.assets.models.components import ComponentModel -from ralph.assets.tests.factories import EnvironmentFactory, ServiceFactory +from ralph.assets.tests.factories import ( + EnvironmentFactory, + ServiceEnvironmentFactory, + ServiceFactory +) +from ralph.lib.custom_fields.models import ( + CustomField, + CustomFieldTypes, + CustomFieldValue +) from ralph.networks.models import IPAddress from ralph.tests import RalphTestCase from ralph.virtual.models import CloudHost, VirtualComponent @@ -136,3 +145,46 @@ def test_service_env_inheritance_on_host_creation(self): parent=self.cloud_project ) self.assertEqual(new_host.service_env, self.service_env[1]) + + +class CloudHostTestCase(RalphTestCase): + def setUp(self): + self.service = ServiceFactory() + self.service_env = ServiceEnvironmentFactory(service=self.service) + self.cloud_project = CloudProjectFactory() + self.cloud_host = CloudHostFactory(parent=self.cloud_project) + + self.custom_field_str = CustomField.objects.create( + name='test str', type=CustomFieldTypes.STRING, default_value='xyz' + ) + + def test_if_custom_fields_are_inherited_from_cloud_project(self): + self.assertEqual(self.cloud_host.custom_fields_as_dict, {}) + CustomFieldValue.objects.create( + object=self.cloud_project, + custom_field=self.custom_field_str, + value='sample_value', + ) + self.assertEqual( + self.cloud_host.custom_fields_as_dict, + {'test str': 'sample_value'} + ) + + def test_if_custom_fields_are_inherited_and_overwrited_from_cloud_project( + self + ): + self.assertEqual(self.cloud_host.custom_fields_as_dict, {}) + CustomFieldValue.objects.create( + object=self.cloud_project, + custom_field=self.custom_field_str, + value='sample_value', + ) + CustomFieldValue.objects.create( + object=self.cloud_host, + custom_field=self.custom_field_str, + value='sample_value22', + ) + self.assertEqual( + self.cloud_host.custom_fields_as_dict, + {'test str': 'sample_value22'} + )