Skip to content

Commit

Permalink
Fix CSVTarget with inference after flow restart (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper authored Dec 17, 2023
1 parent c343a67 commit fac7787
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
7 changes: 4 additions & 3 deletions storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def _init(self):
self._col_to_index[col] = index
self._index_cols = copy.copy(self._initial_index_cols)
self._init_partition_col_indices()
self._still_need_to_infer_columns = self._infer_columns_from_data

def _init_partition_col_indices(self):
self._partition_col_to_index = {}
Expand Down Expand Up @@ -295,10 +296,10 @@ def _get_column_data_from_list(new_data, event, original_data, columns, metadata
def _event_to_writer_entry(self, event):
data = event.body
if isinstance(data, dict):
if self._infer_columns_from_data:
if self._still_need_to_infer_columns:
self._columns.extend(data.keys() - self._index_cols)
self._columns.sort()
self._infer_columns_from_data = False
self._still_need_to_infer_columns = False
self._init_partition_col_indices()
data = {} if self._retain_dict else []
self._get_column_data_from_dict(
Expand Down Expand Up @@ -326,7 +327,7 @@ def _event_to_writer_entry(self, event):
elif isinstance(data, list):
for index in self._partition_col_indices:
del data[index]
if self._infer_columns_from_data:
if self._still_need_to_infer_columns:
raise TypeError(
"Cannot infer_columns_from_data when event type is list. Inference is only possible from dict."
)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,25 @@ def test_write_csv_infer_columns(tmpdir):
assert result == expected


# ML-5298
def test_write_csv_infer_columns_after_flow_restart(tmpdir):
file_path = f"{tmpdir}/test_write_csv_infer_columns_after_flow_restart.csv"
flow = build_flow([SyncEmitSource(), CSVTarget(file_path, header=True)])

for r in [range(3), range(3, 6), range(6, 10)]:
controller = flow.run()
for i in r:
controller.emit({"n": i, "n*10": 10 * i})
controller.terminate()
controller.await_termination()

with open(file_path) as file:
result = file.read()

expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n"
assert result == expected


def test_write_csv_infer_columns_without_header(tmpdir):
file_path = f"{tmpdir}/test_write_csv_infer_columns_without_header.csv"
controller = build_flow([SyncEmitSource(), CSVTarget(file_path)]).run()
Expand Down

0 comments on commit fac7787

Please sign in to comment.