diff --git a/inventree/part.py b/inventree/part.py index eb57aa2..c8f3f19 100644 --- a/inventree/part.py +++ b/inventree/part.py @@ -12,6 +12,20 @@ logger = logging.getLogger('inventree') +class PartCategoryParameterTemplate(inventree.base.InventreeObject): + """A model which link a ParameterTemplate to a PartCategory""" + + URL = 'part/category/parameters' + + def getCategory(self): + """Return the referenced PartCategory instance""" + return PartCategory(self._api, self.category) + + def getTemplate(self): + """Return the referenced ParameterTemplate instance""" + return ParameterTemplate(self._api, self.parameter_template) + + class PartCategory(inventree.base.MetadataMixin, inventree.base.InventreeObject): """ Class representing the PartCategory database model """ @@ -29,16 +43,18 @@ def getParentCategory(self): def getChildCategories(self, **kwargs): return PartCategory.list(self._api, parent=self.pk, **kwargs) - def get_category_parameter_templates(self, fetch_parent=True): - """ - fetch_parent: enable to fetch templates for parent categories - """ + def getCategoryParameterTemplates(self, fetch_parent: bool = True) -> list: + """Fetch a list of default parameter templates associated with this category - parameters_url = f'part/category/{self.pk}/parameters' + Arguments: + fetch_parent: If True (default) include templates for parents also + """ - return self.list(self._api, - url=parameters_url, - fetch_parent=fetch_parent) + return PartCategoryParameterTemplate.list( + self._api, + category=self.pk, + fetch_parent=fetch_parent + ) class Part(inventree.base.MetadataMixin, inventree.base.ImageMixin, inventree.base.InventreeObject): diff --git a/test/test_part.py b/test/test_part.py index 26fbe17..ada3b10 100644 --- a/test/test_part.py +++ b/test/test_part.py @@ -15,67 +15,12 @@ from test_api import InvenTreeTestCase # noqa: E402 -from inventree.part import Part, PartAttachment, PartCategory, Parameter, ParameterTemplate # noqa: E402 +from inventree.part import Part, PartAttachment, PartCategory, PartCategoryParameterTemplate, Parameter, ParameterTemplate # noqa: E402 from inventree.part import InternalPrice # noqa: E402 -class PartTest(InvenTreeTestCase): - """ - Test for PartCategory and Part objects. - """ - - def test_access_erors(self): - """ - Test that errors are flagged when we try to access an invalid part - """ - - with self.assertRaises(TypeError): - Part(self.api, 'hello') - - with self.assertRaises(ValueError): - Part(self.api, -1) - - # Try to access a Part which does not exist - with self.assertRaises(requests.exceptions.HTTPError): - Part(self.api, 9999999999999) - - def test_fields(self): - """ - Test field names via OPTIONS request - """ - - field_names = Part.fieldNames(self.api) - - self.assertIn('active', field_names) - self.assertIn('revision', field_names) - self.assertIn('full_name', field_names) - self.assertIn('IPN', field_names) - - def test_options(self): - """Extends tests for OPTIONS model metadata""" - - # Check for field which does not exist - with self.assertLogs(): - Part.fieldInfo('abcde', self.api) - - active = Part.fieldInfo('active', self.api) - - self.assertEqual(active['type'], 'boolean') - self.assertEqual(active['required'], True) - self.assertEqual(active['label'], 'Active') - self.assertEqual(active['default'], True) - - for field_name in [ - 'name', - 'description', - 'component', - 'assembly', - ]: - field = Part.fieldInfo(field_name, self.api) - - # Check required field attributes - for attr in ['type', 'required', 'read_only', 'label', 'help_text']: - self.assertIn(attr, field) +class PartCategoryTest(InvenTreeTestCase): + """Tests for PartCategory models""" def test_part_cats(self): """ @@ -173,6 +118,105 @@ def test_caps(self): self.assertEqual(len(parts), n_parts + 10) + def test_part_category_parameter_templates(self): + """Unit tests for the PartCategoryParameterTemplate model""" + + electronics = PartCategory(self.api, pk=3) + + # Ensure there are some parameter templates associated with this category + templates = electronics.getCategoryParameterTemplates(fetch_parent=False) + + if len(templates) == 0: + for name in ['wodth', 'lungth', 'herght']: + template = ParameterTemplate.create(self.api, data={ + 'name': name, + 'units': 'uu', + }) + + pcpt = PartCategoryParameterTemplate.create( + self.api, + data={ + 'category': electronics.pk, + 'parameter_template': template.pk, + 'default_value': name, + } + ) + + # Check that model lookup functions work + self.assertEqual(pcpt.getCategory().pk, electronics.pk) + self.assertEqual(pcpt.getTemplate().pk, template.pk) + + # Reload + templates = electronics.getCategoryParameterTemplates(fetch_parent=False) + + self.assertTrue(len(templates) >= 3) + + # Check child categories + childs = electronics.getChildCategories() + + self.assertTrue(len(childs) > 0) + + for child in childs: + child_templates = child.getCategoryParameterTemplates(fetch_parent=True) + self.assertTrue(len(child_templates) >= 3) + + +class PartTest(InvenTreeTestCase): + """Tests for Part models""" + + def test_access_erors(self): + """ + Test that errors are flagged when we try to access an invalid part + """ + + with self.assertRaises(TypeError): + Part(self.api, 'hello') + + with self.assertRaises(ValueError): + Part(self.api, -1) + + # Try to access a Part which does not exist + with self.assertRaises(requests.exceptions.HTTPError): + Part(self.api, 9999999999999) + + def test_fields(self): + """ + Test field names via OPTIONS request + """ + + field_names = Part.fieldNames(self.api) + + self.assertIn('active', field_names) + self.assertIn('revision', field_names) + self.assertIn('full_name', field_names) + self.assertIn('IPN', field_names) + + def test_options(self): + """Extends tests for OPTIONS model metadata""" + + # Check for field which does not exist + with self.assertLogs(): + Part.fieldInfo('abcde', self.api) + + active = Part.fieldInfo('active', self.api) + + self.assertEqual(active['type'], 'boolean') + self.assertEqual(active['required'], True) + self.assertEqual(active['label'], 'Active') + self.assertEqual(active['default'], True) + + for field_name in [ + 'name', + 'description', + 'component', + 'assembly', + ]: + field = Part.fieldInfo(field_name, self.api) + + # Check required field attributes + for attr in ['type', 'required', 'read_only', 'label', 'help_text']: + self.assertIn(attr, field) + def test_pagination(self): """ Test that we can paginate the queryset by specifying a 'limit' parameter"""