Skip to content

Commit

Permalink
⚡ Fix CSV loader very slow
Browse files Browse the repository at this point in the history
Improve speed of CSV loader by adding parameter to disable type casting
  • Loading branch information
garlontas committed Sep 26, 2023
1 parent dd00c9e commit 4736671
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
11 changes: 7 additions & 4 deletions pystreamapi/loaders/__csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@
from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable


def csv(file_path: str, delimiter=',', encoding="utf-8") -> LazyFileIterable:
def csv(file_path: str, cast_types=True, delimiter=',', encoding="utf-8") -> LazyFileIterable:
"""
Loads a CSV file and converts it into a list of namedtuples.
Returns:
list: A list of namedtuples, where each namedtuple represents a row in the CSV.
:param cast_types: Set as False to disable casting of values to int, bool or float.
:param encoding: The encoding of the CSV file.
:param file_path: The path to the CSV file.
:param delimiter: The delimiter used in the CSV file.
"""
file_path = __validate_path(file_path)
return LazyFileIterable(lambda: __load_csv(file_path, delimiter, encoding))
return LazyFileIterable(lambda: __load_csv(file_path, cast_types, delimiter, encoding))


def __load_csv(file_path, delimiter, encoding):
def __load_csv(file_path, cast, delimiter, encoding):
"""Load a CSV file and convert it into a list of namedtuples"""
# skipcq: PTC-W6004
with open(file_path, mode='r', newline='', encoding=encoding) as csvfile:
Expand All @@ -29,8 +30,10 @@ def __load_csv(file_path, delimiter, encoding):
# Create a namedtuple type, casting the header values to int or float if possible
Row = namedtuple('Row', list(next(csvreader, [])))

mapper = __try_cast if cast else lambda x: x

# Process the data, casting values to int or float if possible
data = [Row(*[__try_cast(value) for value in row]) for row in csvreader]
data = [Row(*[mapper(value) for value in row]) for row in csvreader]
return data


Expand Down
12 changes: 11 additions & 1 deletion tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def test_csv_loader(self):
self.assertEqual(data[1].attr1, 'a')
self.assertIsInstance(data[1].attr1, str)

def test_csv_loader_with_casting_disabled(self):
data = csv(f'{self.path}/data.csv', cast_types=False)
self.assertEqual(len(data), 2)
self.assertEqual(data[0].attr1, '1')
self.assertIsInstance(data[0].attr1, str)
self.assertEqual(data[0].attr2, '2.0')
self.assertIsInstance(data[0].attr2, str)
self.assertEqual(data[1].attr1, 'a')
self.assertIsInstance(data[1].attr1, str)

def test_csv_loader_is_iterable(self):
data = csv(f'{self.path}/data.csv')
self.assertEqual(len(list(iter(data))), 2)
Expand All @@ -38,6 +48,6 @@ def test_csv_loader_with_invalid_path(self):
with self.assertRaises(FileNotFoundError):
csv(f'{self.path}/invalid.csv')

def test_csv_loader_with_non_file(self):
def test_csv_loader_with_no_file(self):
with self.assertRaises(ValueError):
csv(f'{self.path}/')

0 comments on commit 4736671

Please sign in to comment.