Skip to content

Commit

Permalink
Make failure to close a stream an error, as it would be by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Nov 26, 2021
1 parent ca96441 commit 312b9b8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
13 changes: 3 additions & 10 deletions snowfakery/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,9 @@ def configure_output_stream(
try:
yield output_stream
finally:
try:
messages = output_stream.close()
except Exception as e:
messages = None
parent_application.echo(
f"Could not close {output_stream}: {str(e)}", err=True
)
if messages:
for message in messages:
parent_application.echo(message)
messages = output_stream.close() or []
for message in messages:
parent_application.echo(message)


@contextmanager
Expand Down
16 changes: 14 additions & 2 deletions snowfakery/output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def close(self) -> Optional[Sequence[str]]:
Return a list of messages to print out.
"""
return super().close()
raise NotImplementedError()

def __enter__(self, *args):
return self
Expand Down Expand Up @@ -578,8 +578,20 @@ def write_row(self, tablename: str, row_with_references: Dict) -> None:
stream.write_row(tablename, row_with_references)

def close(self) -> Optional[Sequence[str]]:
all_messages = []
closing_errors = []
for stream in self.outputstreams:
stream.close()
try:
messages = stream.close() or []
all_messages.extend(messages)
except Exception as e:
closing_errors.append(e)

if len(closing_errors) == 1:
raise closing_errors[1]
elif closing_errors:
raise IOError(f"Could not close streams: {closing_errors}")
return all_messages

def write_single_row(self, tablename: str, row: Dict) -> None:
return super().write_single_row(tablename, row)
21 changes: 8 additions & 13 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_continuation_as_open_file(self):
with mapping_file.open() as f:
assert yaml.safe_load(f)

def test_parent_application__echo(self):
def test_parent_application__exception_raised(self):
called = False

class MyEmbedder(SnowfakeryApplication):
Expand All @@ -74,10 +74,10 @@ def echo(self, *args, **kwargs):
meth = "snowfakery.output_streams.DebugOutputStream.close"
with mock.patch(meth) as close:
close.side_effect = AssertionError
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)
assert called
with pytest.raises(AssertionError):
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)

def test_parent_application__early_finish(self, generated_rows):
class MyEmbedder(SnowfakeryApplication):
Expand All @@ -89,14 +89,9 @@ def check_if_finished(self, idmanager):
assert self.__class__.count < 100, "Runaway recipe!"
return idmanager["Employee"] >= 10

meth = "snowfakery.output_streams.DebugOutputStream.close"
with mock.patch(meth) as close:
close.side_effect = AssertionError
generate_data(
yaml_file="examples/company.yml", parent_application=MyEmbedder()
)
# called 5 times, after generating 2 employees each
assert MyEmbedder.count == 5
generate_data(yaml_file="examples/company.yml", parent_application=MyEmbedder())
# called 5 times, after generating 2 employees each
assert MyEmbedder.count == 5

def test_embedding__cannot_infer_output_format(self):
with pytest.raises(exc.DataGenError, match="No format"):
Expand Down

0 comments on commit 312b9b8

Please sign in to comment.