From 473667111c74c53a686e5371a5efd316a6aa0ca9 Mon Sep 17 00:00:00 2001 From: Stefan Garlonta Date: Tue, 26 Sep 2023 17:02:13 +0200 Subject: [PATCH] :zap: Fix CSV loader very slow Improve speed of CSV loader by adding parameter to disable type casting --- pystreamapi/loaders/__csv_loader.py | 11 +++++++---- tests/test_loaders.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/pystreamapi/loaders/__csv_loader.py b/pystreamapi/loaders/__csv_loader.py index 3f19274..b585833 100644 --- a/pystreamapi/loaders/__csv_loader.py +++ b/pystreamapi/loaders/__csv_loader.py @@ -6,21 +6,22 @@ from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable -def csv(file_path: str, delimiter=',', encoding="utf-8") -> LazyFileIterable: +def csv(file_path: str, cast_types=True, delimiter=',', encoding="utf-8") -> LazyFileIterable: """ Loads a CSV file and converts it into a list of namedtuples. Returns: list: A list of namedtuples, where each namedtuple represents a row in the CSV. + :param cast_types: Set as False to disable casting of values to int, bool or float. :param encoding: The encoding of the CSV file. :param file_path: The path to the CSV file. :param delimiter: The delimiter used in the CSV file. """ file_path = __validate_path(file_path) - return LazyFileIterable(lambda: __load_csv(file_path, delimiter, encoding)) + return LazyFileIterable(lambda: __load_csv(file_path, cast_types, delimiter, encoding)) -def __load_csv(file_path, delimiter, encoding): +def __load_csv(file_path, cast, delimiter, encoding): """Load a CSV file and convert it into a list of namedtuples""" # skipcq: PTC-W6004 with open(file_path, mode='r', newline='', encoding=encoding) as csvfile: @@ -29,8 +30,10 @@ def __load_csv(file_path, delimiter, encoding): # Create a namedtuple type, casting the header values to int or float if possible Row = namedtuple('Row', list(next(csvreader, []))) + mapper = __try_cast if cast else lambda x: x + # Process the data, casting values to int or float if possible - data = [Row(*[__try_cast(value) for value in row]) for row in csvreader] + data = [Row(*[mapper(value) for value in row]) for row in csvreader] return data diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 674d248..32730b1 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -20,6 +20,16 @@ def test_csv_loader(self): self.assertEqual(data[1].attr1, 'a') self.assertIsInstance(data[1].attr1, str) + def test_csv_loader_with_casting_disabled(self): + data = csv(f'{self.path}/data.csv', cast_types=False) + self.assertEqual(len(data), 2) + self.assertEqual(data[0].attr1, '1') + self.assertIsInstance(data[0].attr1, str) + self.assertEqual(data[0].attr2, '2.0') + self.assertIsInstance(data[0].attr2, str) + self.assertEqual(data[1].attr1, 'a') + self.assertIsInstance(data[1].attr1, str) + def test_csv_loader_is_iterable(self): data = csv(f'{self.path}/data.csv') self.assertEqual(len(list(iter(data))), 2) @@ -38,6 +48,6 @@ def test_csv_loader_with_invalid_path(self): with self.assertRaises(FileNotFoundError): csv(f'{self.path}/invalid.csv') - def test_csv_loader_with_non_file(self): + def test_csv_loader_with_no_file(self): with self.assertRaises(ValueError): csv(f'{self.path}/')