From e4bef02c6982ccd750c9002e7542ff800ad94dd0 Mon Sep 17 00:00:00 2001
From: Paul Prescod
Date: Fri, 5 Apr 2024 08:45:04 -0700
Subject: [PATCH] Clean up temporary directories after closing iterator

---
 snowfakery/standard_plugins/Salesforce.py | 20 ++++----------------
 snowfakery/standard_plugins/datasets.py   | 14 +++++++++++---
 tests/multiple-datasets.yml               | 17 +++++++++++++++++
 3 files changed, 32 insertions(+), 19 deletions(-)
 create mode 100644 tests/multiple-datasets.yml

diff --git a/snowfakery/standard_plugins/Salesforce.py b/snowfakery/standard_plugins/Salesforce.py
index 5fa15819..8dac7a1f 100644
--- a/snowfakery/standard_plugins/Salesforce.py
+++ b/snowfakery/standard_plugins/Salesforce.py
@@ -330,26 +330,14 @@ def _load_dataset(self, iteration_mode, rootpath, kwargs):
                 f"Unable to query records for {query}: {','.join(qs.job_result.job_errors)}"
             )
 
-        self.tempdir, self.iterator = create_tempfile_sql_db_iterator(
+        tempdir, iterator = create_tempfile_sql_db_iterator(
             iteration_mode, fieldnames, qs.get_results()
         )
-        return self.iterator
+        iterator.cleanup.push(tempdir)
+        return iterator
 
     def close(self):
-        if self.iterator:
-            self.iterator.close()
-            self.iterator = None
-
-        if self.tempdir:
-            self.tempdir.cleanup()
-            self.tempdir = None
-
-    def __del__(self):
-        # in case close was not called
-        # properly, try to do an orderly
-        # cleanup
-        self.close()
-
+        pass
 
 def create_tempfile_sql_db_iterator(mode, fieldnames, results):
     tempdir, db_url = _create_db(fieldnames, results)
diff --git a/snowfakery/standard_plugins/datasets.py b/snowfakery/standard_plugins/datasets.py
index ed10530f..7bdfb5fd 100644
--- a/snowfakery/standard_plugins/datasets.py
+++ b/snowfakery/standard_plugins/datasets.py
@@ -66,9 +66,17 @@ class DatasetIteratorBase(PluginResultIterator):
     Subclasses should implement 'self.restart' which puts an iterator into 'self.results'
     """
 
+    def __init__(self, repeat):
+        # subclasses can register stuff to be cleaned up here.
+        self.cleanup = ExitStack()
+        super().__init__(repeat)
+
     def next_result(self):
         return next(self.results)
 
+    def close(self):
+        self.cleanup.close()
+
 
 class SQLDatasetIterator(DatasetIteratorBase):
     def __init__(self, engine, table, repeat):
@@ -86,6 +94,7 @@ def start(self):
     def close(self):
         self.results = None
         self.connection.close()
+        super().close()
 
     def query(self):
         "Return a SQL Alchemy SELECT statement"
@@ -108,14 +117,13 @@ def query(self):
 
 class CSVDatasetLinearIterator(DatasetIteratorBase):
     def __init__(self, datasource: FileLike, repeat: bool):
-        self.cleanup = ExitStack()
+        super().__init__(repeat)
         # utf-8-sig and newline="" are for Windows
         self.path, self.file = self.cleanup.enter_context(
             open_file_like(datasource, "r", newline="", encoding="utf-8-sig")
         )
 
         self.start()
-        super().__init__(repeat)
 
     def start(self):
         assert self.file
@@ -127,7 +135,7 @@ def start(self):
 
     def close(self):
         self.results = None
-        self.cleanup.close()
+        super().close()
 
     def plugin_result(self, row):
         if None in row:
diff --git a/tests/multiple-datasets.yml b/tests/multiple-datasets.yml
new file mode 100644
index 00000000..79187f7a
--- /dev/null
+++ b/tests/multiple-datasets.yml
@@ -0,0 +1,17 @@
+- plugin: snowfakery.standard_plugins.Salesforce.SOQLDataset
+- object: Contact
+  count: 10
+  fields:
+    __users_from_salesforce:
+      SOQLDataset.shuffle:
+        fields: Id, FirstName, LastName
+        from: User
+    __Account_from_Salesforce:
+      SOQLDataset.shuffle:
+        fields: Id
+        from: Account
+    # The next line depends on the users having particular
+    # permissions.
+    FirstName: ${{__users_from_salesforce.FirstName}}
+    LastName: ${{__users_from_salesforce.LastName}}
+    AccountId: ${{__Account_from_Salesforce.Id}}