diff --git a/django_tables2/columns/templatecolumn.py b/django_tables2/columns/templatecolumn.py index ddddc24a..4c97666b 100644 --- a/django_tables2/columns/templatecolumn.py +++ b/django_tables2/columns/templatecolumn.py @@ -2,6 +2,8 @@ from django.template.loader import get_template from django.utils.html import strip_tags +from django_tables2.utils import call_with_appropriate + from .base import Column, library @@ -40,11 +42,19 @@ class ExampleTable(tables.Table): empty_values = () - def __init__(self, template_code=None, template_name=None, extra_context=None, **extra): + def __init__( + self, + template_code=None, + template_name=None, + context_object_name="record", + extra_context=None, + **extra + ): super().__init__(**extra) self.template_code = template_code self.template_name = template_name self.extra_context = extra_context or {} + self.context_object_name = context_object_name if not self.template_code and not self.template_name: raise ValueError("A template must be provided") @@ -56,11 +66,18 @@ def render(self, record, table, value, bound_column, **kwargs): additional_context = { "default": bound_column.default, "column": bound_column, - "record": record, + self.context_object_name: record, "value": value, "row_counter": kwargs["bound_row"].row_counter, } - additional_context.update(self.extra_context) + + extra_context = self.extra_context + if callable(extra_context): + extra_context = call_with_appropriate( + extra_context, + {"record": record, "table": table, "value": value, "bound_column": bound_column}, + ) + additional_context.update(extra_context) with context.update(additional_context): if self.template_code: return Template(self.template_code).render(context) @@ -75,3 +92,6 @@ def value(self, **kwargs): """ html = super().value(**kwargs) return strip_tags(html).strip() if isinstance(html, str) else html + + def get_context_data(self, **kwargs): + return diff --git a/tests/columns/test_templatecolumn.py b/tests/columns/test_templatecolumn.py index 643ab3cb..414fead7 100644 --- a/tests/columns/test_templatecolumn.py +++ b/tests/columns/test_templatecolumn.py @@ -115,3 +115,19 @@ class Table(tables.Table): table = Table([{"track": "Space Oddity"}]) self.assertEqual(list(table.as_values()), [["Track"], ["Space Oddity"]]) + + def test_context_object_name(self): + class Table(tables.Table): + name = tables.TemplateColumn("{{ user.name }}", context_object_name="user") + + table = Table([{"name": "Bob"}]) + self.assertEqual(list(table.as_values()), [["Name"], ["Bob"]]) + + def test_extra_context_callable(self): + class Table(tables.Table): + size = tables.TemplateColumn( + "{{ size }}", extra_context=lambda record: {"size": record["clothes"]["size"]} + ) + + table = Table([{"clothes": {"size": "XL"}}]) + self.assertEqual(list(table.as_values()), [["Size"], ["XL"]])