Ensures files also close
Zeerak Waseem committed Sep 28, 2020
1 parent 8707294 commit 0c9edd1
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions mlearn/data/dataset.py
@@ -106,20 +106,24 @@ def load(self, dataset: str = 'train', skip_header = True, line_count: dict = No
                             disable = os.environ.get('TQDM_DISABLE', False), total = num_docs):

            data_line, datapoint = {}, base.Datapoint()  # TODO Look at moving all load processing into datapoint class.
            data_line['fields'] = []

            for field in self.train_fields:
                data_line['fields'].append(field.name)
                idx = field.index if self.ftype in ['CSV', 'TSV'] else field.cname
                data_line[field.name] = self.process_doc(line[idx].rstrip().replace('\n', '').replace('\r', '').strip(),
                                                         **kwargs)
                data_line['original'] = line[idx].rstrip().replace('\n', '').replace('\r', '').strip()

            for field in self.label_fields:
                data_line['fields'].append(field.name)
                idx = field.index if self.ftype in ['CSV', 'TSV'] else field.cname
                if self.label_preprocessor:
                    data_line[field.name] = self.label_preprocessor(line[idx].rstrip())
                else:
                    data_line[field.name] = line[idx].rstrip()

            data_line['fields'] = set(data_line['fields'])
            for key, val in data_line.items():
                setattr(datapoint, key, val)
            data.append(datapoint)
@@ -173,6 +177,43 @@ def load_labels(self, dataset: str, label_name: str, label_path: str = None, fty
        for label, doc in zip(labels, data):
            setattr(doc, label_name, label)

    def dump(self, data: str, write_path: str, format: str = 'json') -> None:
        """
        Dump processed data to a file.
        :data (str): The data slice to dump ('train', 'dev', or 'test').
        :write_path (str): Output path for the dataset, not including the file name.
        :format (str, default = 'json'): Format of the output file (JSON/TSV).
        """
        if format not in ['json', 'tsv']:
            tqdm.write("Incorrect format selected. Defaulting to JSON.")
            format = 'json'  # Fall back so the writing logic below matches the warning.

        if data == 'train':
            data_out = self.data
        elif data == 'dev':
            data_out = self.dev
        else:
            data_out = self.test

        file_path = os.path.abspath(write_path) if '~' not in write_path else os.path.expanduser(write_path)
        writepath = os.path.join(file_path, f"{data}.{format}")
        if os.path.exists(writepath):
            tqdm.write(f"Path {writepath} already exists. Creating {writepath}.dump")
            writepath = f"{writepath}.dump"

        with open(writepath, 'a', encoding = 'utf-8') as outf:  # Context manager ensures the file is closed.
            if format == 'tsv':
                filewriter = csv.writer(outf, delimiter = '\t')
                filewriter.writerow(data_out[0].fields)  # Header row from the first datapoint's fields.
            else:
                filewriter = outf

            for datapoint in tqdm(data_out, desc = f"Dumping {data} to {writepath}"):
                if format == 'json':
                    filewriter.write(json.dumps(datapoint) + '\n')
                else:
                    filewriter.writerow([datapoint[key] for key in data_out[0].fields])

    def set_labels(self, data: base.DataType, labels: base.DataType) -> None:
        """
        Set labels for documents.
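For reference, a minimal usage sketch of the new dump method follows. The dataset instance, how its splits were loaded, and the output directory are assumptions for illustration only; nothing below comes from this commit beyond the dump signature shown in the diff.

    # Sketch only: `dataset` is assumed to be an already-constructed instance of the
    # dataset class in mlearn/data/dataset.py, with its splits loaded so that
    # self.data, self.dev and self.test are populated.
    dataset.dump('train', '~/experiments/output', format = 'json')  # writes train.json, one JSON object per line
    dataset.dump('dev', '~/experiments/output', format = 'tsv')     # writes dev.tsv with a header row from the first datapoint's fields
    dataset.dump('test', '~/experiments/output', format = 'xml')    # unsupported format: warns and falls back to JSON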
