Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

importers.csvbase: Allow to specifying a default value #153

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions beangulp/importers/csvbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

EMPTY = frozenset()

NA = object()
"""Marker to indicate that a value was not specified."""


def _resolve(spec, names):
"""Resolve column specification into column index.
Expand Down Expand Up @@ -40,10 +43,12 @@ class Column:

Args:
name: Column name or index.
default: Value to return if the field is empty.
"""
default: Value to return if the field is empty. When a default
value is not provided, emty fields are passed to the parser
to generate a value.

def __init__(self, *names, default=None):
"""
def __init__(self, *names, default=NA):
self.names = names
self.default = default

Expand All @@ -70,7 +75,7 @@ def getter(self, names):
idxs = [_resolve(x, names) for x in self.names]
def func(obj):
value = tuple(obj[i] for i in idxs)
if not all(value) and self.default:
if self.default is not NA and not any(value):
return self.default
return self.parse(*value)
return func
Expand All @@ -94,10 +99,11 @@ class Columns(Column):
Args:
name: Column names or indexes.
sep: Separator to use to join columns.
default: Value to return all the fields are empty, if specified.

"""
def __init__(self, *names, sep=' '):
super().__init__(*names)
def __init__(self, *names, sep=' ', default=NA):
super().__init__(*names, default=default)
self.sep = sep

def parse(self, *values):
Expand All @@ -114,10 +120,11 @@ class Date(Column):
Args:
name: Column name or index.
frmt: Date format specification.
default: Value to return if the field is empty, if specified.

"""
def __init__(self, name, frmt='%Y-%m-%d'):
super().__init__(name)
def __init__(self, name, frmt='%Y-%m-%d', default=NA):
super().__init__(name, default=default)
self.frmt = frmt

def parse(self, value):
Expand All @@ -137,11 +144,11 @@ class Amount(Column):
subs: Dictionary mapping regular expression patterns to
replacement strings. Substitutions are performed with
re.sub() in the order they are specified.
default: Value to return if the field is empty, if specified.

"""

def __init__(self, name, subs=None):
super().__init__(name)
def __init__(self, name, subs=None, default=NA):
super().__init__(name, default=default)
self.subs = subs if subs is not None else {}

def parse(self, value):
Expand Down
50 changes: 49 additions & 1 deletion beangulp/importers/csvbase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ def test_strip(self):
value = func((' value ', ))
self.assertEqual(value, 'value')

def test_default(self):
def test_default_value(self):
column = Column(0, default=42)
func = column.getter(None)
value = func(('', ))
self.assertEqual(value, 42)

def test_default_value_none(self):
column = Column(0, default=None)
func = column.getter(None)
value = func(('', ))
self.assertIsNone(value)


class TestDateColumn(unittest.TestCase):

Expand All @@ -57,6 +63,18 @@ def test_custom_format(self):
value = func(('16.05.2021', ))
self.assertEqual(value, datetime.date(2021, 5, 16))

def test_default_value(self):
column = Date(0, default=datetime.date.today())
func = column.getter(None)
value = func(('', ))
self.assertEqual(value, datetime.date.today())

def test_default_value_none(self):
column = Date(0, default=None)
func = column.getter(None)
value = func(('', ))
self.assertIsNone(value)


class TestColumnsColumn(unittest.TestCase):

Expand All @@ -72,6 +90,24 @@ def test_custom_sep(self):
value = func(('0', '1', '2', '3', ))
self.assertEqual(value, '0: 1')

def test_default_value(self):
column = Columns(0, 1, default='something')
func = column.getter(None)
value = func(('', '', ))
self.assertEqual(value, 'something')

def test_default_value_none(self):
column = Columns(0, 1, default=None)
func = column.getter(None)
value = func(('', '', ))
self.assertIsNone(value)

def test_some_empty(self):
column = Columns(0, 1, default=None)
func = column.getter(None)
value = func(('this', '', ))
self.assertEqual(value, 'this')


class TestAmountColumn(unittest.TestCase):

Expand Down Expand Up @@ -103,6 +139,18 @@ def test_parse_subs_currency(self):
self.assertIsInstance(value, decimal.Decimal)
self.assertEqual(value, decimal.Decimal('1000.00'))

def test_default_value(self):
column = Amount(0, default=decimal.Decimal(42))
func = column.getter(None)
value = func(('', ))
self.assertEqual(value, decimal.Decimal(42))

def test_default_value_none(self):
column = Amount(0, default=None)
func = column.getter(None)
value = func(('', ))
self.assertIsNone(value)


class TestCSVMeta(unittest.TestCase):

Expand Down
Loading