Skip to content

Commit

Permalink
Merge pull request #132 from inventree/category-template-fix
Browse files Browse the repository at this point in the history
Category template fix
  • Loading branch information
SchrodingersGat authored Aug 2, 2022
2 parents 44ea1fb + 1d0edf5 commit 05ea9ff
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 66 deletions.
32 changes: 24 additions & 8 deletions inventree/part.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@
logger = logging.getLogger('inventree')


class PartCategoryParameterTemplate(inventree.base.InventreeObject):
"""A model which link a ParameterTemplate to a PartCategory"""

URL = 'part/category/parameters'

def getCategory(self):
"""Return the referenced PartCategory instance"""
return PartCategory(self._api, self.category)

def getTemplate(self):
"""Return the referenced ParameterTemplate instance"""
return ParameterTemplate(self._api, self.parameter_template)


class PartCategory(inventree.base.MetadataMixin, inventree.base.InventreeObject):
""" Class representing the PartCategory database model """

Expand All @@ -29,16 +43,18 @@ def getParentCategory(self):
def getChildCategories(self, **kwargs):
return PartCategory.list(self._api, parent=self.pk, **kwargs)

def get_category_parameter_templates(self, fetch_parent=True):
"""
fetch_parent: enable to fetch templates for parent categories
"""
def getCategoryParameterTemplates(self, fetch_parent: bool = True) -> list:
"""Fetch a list of default parameter templates associated with this category
parameters_url = f'part/category/{self.pk}/parameters'
Arguments:
fetch_parent: If True (default) include templates for parents also
"""

return self.list(self._api,
url=parameters_url,
fetch_parent=fetch_parent)
return PartCategoryParameterTemplate.list(
self._api,
category=self.pk,
fetch_parent=fetch_parent
)


class Part(inventree.base.MetadataMixin, inventree.base.ImageMixin, inventree.base.InventreeObject):
Expand Down
160 changes: 102 additions & 58 deletions test/test_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,12 @@

from test_api import InvenTreeTestCase # noqa: E402

from inventree.part import Part, PartAttachment, PartCategory, Parameter, ParameterTemplate # noqa: E402
from inventree.part import Part, PartAttachment, PartCategory, PartCategoryParameterTemplate, Parameter, ParameterTemplate # noqa: E402
from inventree.part import InternalPrice # noqa: E402


class PartTest(InvenTreeTestCase):
"""
Test for PartCategory and Part objects.
"""

def test_access_erors(self):
"""
Test that errors are flagged when we try to access an invalid part
"""

with self.assertRaises(TypeError):
Part(self.api, 'hello')

with self.assertRaises(ValueError):
Part(self.api, -1)

# Try to access a Part which does not exist
with self.assertRaises(requests.exceptions.HTTPError):
Part(self.api, 9999999999999)

def test_fields(self):
"""
Test field names via OPTIONS request
"""

field_names = Part.fieldNames(self.api)

self.assertIn('active', field_names)
self.assertIn('revision', field_names)
self.assertIn('full_name', field_names)
self.assertIn('IPN', field_names)

def test_options(self):
"""Extends tests for OPTIONS model metadata"""

# Check for field which does not exist
with self.assertLogs():
Part.fieldInfo('abcde', self.api)

active = Part.fieldInfo('active', self.api)

self.assertEqual(active['type'], 'boolean')
self.assertEqual(active['required'], True)
self.assertEqual(active['label'], 'Active')
self.assertEqual(active['default'], True)

for field_name in [
'name',
'description',
'component',
'assembly',
]:
field = Part.fieldInfo(field_name, self.api)

# Check required field attributes
for attr in ['type', 'required', 'read_only', 'label', 'help_text']:
self.assertIn(attr, field)
class PartCategoryTest(InvenTreeTestCase):
"""Tests for PartCategory models"""

def test_part_cats(self):
"""
Expand Down Expand Up @@ -173,6 +118,105 @@ def test_caps(self):

self.assertEqual(len(parts), n_parts + 10)

def test_part_category_parameter_templates(self):
"""Unit tests for the PartCategoryParameterTemplate model"""

electronics = PartCategory(self.api, pk=3)

# Ensure there are some parameter templates associated with this category
templates = electronics.getCategoryParameterTemplates(fetch_parent=False)

if len(templates) == 0:
for name in ['wodth', 'lungth', 'herght']:
template = ParameterTemplate.create(self.api, data={
'name': name,
'units': 'uu',
})

pcpt = PartCategoryParameterTemplate.create(
self.api,
data={
'category': electronics.pk,
'parameter_template': template.pk,
'default_value': name,
}
)

# Check that model lookup functions work
self.assertEqual(pcpt.getCategory().pk, electronics.pk)
self.assertEqual(pcpt.getTemplate().pk, template.pk)

# Reload
templates = electronics.getCategoryParameterTemplates(fetch_parent=False)

self.assertTrue(len(templates) >= 3)

# Check child categories
childs = electronics.getChildCategories()

self.assertTrue(len(childs) > 0)

for child in childs:
child_templates = child.getCategoryParameterTemplates(fetch_parent=True)
self.assertTrue(len(child_templates) >= 3)


class PartTest(InvenTreeTestCase):
"""Tests for Part models"""

def test_access_erors(self):
"""
Test that errors are flagged when we try to access an invalid part
"""

with self.assertRaises(TypeError):
Part(self.api, 'hello')

with self.assertRaises(ValueError):
Part(self.api, -1)

# Try to access a Part which does not exist
with self.assertRaises(requests.exceptions.HTTPError):
Part(self.api, 9999999999999)

def test_fields(self):
"""
Test field names via OPTIONS request
"""

field_names = Part.fieldNames(self.api)

self.assertIn('active', field_names)
self.assertIn('revision', field_names)
self.assertIn('full_name', field_names)
self.assertIn('IPN', field_names)

def test_options(self):
"""Extends tests for OPTIONS model metadata"""

# Check for field which does not exist
with self.assertLogs():
Part.fieldInfo('abcde', self.api)

active = Part.fieldInfo('active', self.api)

self.assertEqual(active['type'], 'boolean')
self.assertEqual(active['required'], True)
self.assertEqual(active['label'], 'Active')
self.assertEqual(active['default'], True)

for field_name in [
'name',
'description',
'component',
'assembly',
]:
field = Part.fieldInfo(field_name, self.api)

# Check required field attributes
for attr in ['type', 'required', 'read_only', 'label', 'help_text']:
self.assertIn(attr, field)

def test_pagination(self):
""" Test that we can paginate the queryset by specifying a 'limit' parameter"""

Expand Down

0 comments on commit 05ea9ff

Please sign in to comment.