diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index d4f3107..2c1fd23 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1855,3 +1855,18 @@ def test_create_table_guess_types(self, filename): self.assertEqual(names, ['id', 'name', 'check', 'date', 'value']) types = [column.dtype for column in self.conn.tables['test'].columns.values()] self.assertEqual(types, [int, str, bool, datetime.date, Decimal]) + + +class TestCSVSource(unittest.TestCase): + + @docfile + def test_csv_source(self, filename): + '''\ + id, name, check, date, value + 1234, one, true, 2025-01-01, 1.234 + ''' + conn = beanquery.connect(f'csv:{filename}?name=test') + names = list(conn.tables['test'].columns.keys()) + self.assertEqual(names, ['id', 'name', 'check', 'date', 'value']) + types = [column.dtype for column in conn.tables['test'].columns.values()] + self.assertEqual(types, [int, str, bool, datetime.date, Decimal]) diff --git a/beanquery/sources/csv.py b/beanquery/sources/csv.py index 43c086b..ba1ca97 100644 --- a/beanquery/sources/csv.py +++ b/beanquery/sources/csv.py @@ -1,6 +1,7 @@ import csv import datetime +from os import path from urllib.parse import urlparse, parse_qsl from beanquery import tables @@ -91,3 +92,14 @@ def create(name, columns, using): if filename: data = open(filename, encoding=encoding) return Table(name, columns, data, header=header, **params) + + +def attach(context, dsn, *, data=None): + parts = urlparse(dsn) + filename = parts.path + params = dict(parse_qsl(parts.query)) + encoding = params.pop('encoding', None) + if filename: + data = open(filename, encoding=encoding) + name = params.pop('name', None) or path.splitext(path.basename(filename))[0] or 'csv' + context.tables[name] = Table(name, None, data, header=True, **params)