Skip to content

Commit

Permalink
Merge pull request #1029 from prescod/pprescod/close-tempdirs-first
Browse files: browse the repository at this point in the history
Clean up temporary directories after closing iterator
  • Loading branch information
jstvz authored Apr 8, 2024
2 parents c3b64ee + e4bef02 commit 7be154d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
20 changes: 4 additions & 16 deletions snowfakery/standard_plugins/Salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,26 +330,14 @@ def _load_dataset(self, iteration_mode, rootpath, kwargs):
f"Unable to query records for {query}: {','.join(qs.job_result.job_errors)}"
)

self.tempdir, self.iterator = create_tempfile_sql_db_iterator(
tempdir, iterator = create_tempfile_sql_db_iterator(
iteration_mode, fieldnames, qs.get_results()
)
return self.iterator
iterator.cleanup.push(tempdir)
return iterator

def close(self):
if self.iterator:
self.iterator.close()
self.iterator = None

if self.tempdir:
self.tempdir.cleanup()
self.tempdir = None

def __del__(self):
# in case close was not called
# properly, try to do an orderly
# cleanup
self.close()

pass

def create_tempfile_sql_db_iterator(mode, fieldnames, results):
tempdir, db_url = _create_db(fieldnames, results)
Expand Down
14 changes: 11 additions & 3 deletions snowfakery/standard_plugins/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ class DatasetIteratorBase(PluginResultIterator):
Subclasses should implement 'self.restart' which puts an iterator into 'self.results'
"""

def __init__(self, repeat):
# subclasses can register stuff to be cleaned up here.
self.cleanup = ExitStack()
super().__init__(repeat)

def next_result(self):
return next(self.results)

def close(self):
self.cleanup.close()


class SQLDatasetIterator(DatasetIteratorBase):
def __init__(self, engine, table, repeat):
Expand All @@ -86,6 +94,7 @@ def start(self):
def close(self):
self.results = None
self.connection.close()
super().close()

def query(self):
"Return a SQL Alchemy SELECT statement"
Expand All @@ -108,14 +117,13 @@ def query(self):

class CSVDatasetLinearIterator(DatasetIteratorBase):
def __init__(self, datasource: FileLike, repeat: bool):
self.cleanup = ExitStack()
super().__init__(repeat)
# utf-8-sig and newline="" are for Windows
self.path, self.file = self.cleanup.enter_context(
open_file_like(datasource, "r", newline="", encoding="utf-8-sig")
)

self.start()
super().__init__(repeat)

def start(self):
assert self.file
Expand All @@ -127,7 +135,7 @@ def start(self):

def close(self):
self.results = None
self.cleanup.close()
super().close()

def plugin_result(self, row):
if None in row:
Expand Down
17 changes: 17 additions & 0 deletions tests/multiple-datasets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
- plugin: snowfakery.standard_plugins.Salesforce.SOQLDataset
- object: Contact
count: 10
fields:
__users_from_salesforce:
SOQLDataset.shuffle:
fields: Id, FirstName, LastName
from: User
__Account_from_Salesforce:
SOQLDataset.shuffle:
fields: Id
from: Account
# The next line depends on the users having particular
# permissions.
FirstName: ${{__users_from_salesforce.FirstName}}
LastName: ${{__users_from_salesforce.LastName}}
AccountId: ${{__Account_from_Salesforce.Id}}

0 comments on commit 7be154d

Please sign in to comment.