From 0c9edd15fea3ba414f326285d784ea881920e3e8 Mon Sep 17 00:00:00 2001 From: Zeerak Waseem Date: Mon, 28 Sep 2020 08:21:42 +0100 Subject: [PATCH] Ensures files also close --- mlearn/data/dataset.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mlearn/data/dataset.py b/mlearn/data/dataset.py index 10adb26..5a86688 100644 --- a/mlearn/data/dataset.py +++ b/mlearn/data/dataset.py @@ -106,20 +106,24 @@ def load(self, dataset: str = 'train', skip_header = True, line_count: dict = No disable = os.environ.get('TQDM_DISABLE', False), total = num_docs): data_line, datapoint = {}, base.Datapoint() # TODO Look at moving all load processing into datapoint class. + data_line['fields'] = [] for field in self.train_fields: + data_line['fields'].append(field.name) idx = field.index if self.ftype in ['CSV', 'TSV'] else field.cname data_line[field.name] = self.process_doc(line[idx].rstrip().replace('\n', '').replace('\r', '').strip(), **kwargs) data_line['original'] = line[idx].rstrip().replace('\n', '').replace('\r', '').strip() for field in self.label_fields: + data_line['fields'].append(field.name) idx = field.index if self.ftype in ['CSV', 'TSV'] else field.cname if self.label_preprocessor: data_line[field.name] = self.label_preprocessor(line[idx].rstrip()) else: data_line[field.name] = line[idx].rstrip() + data_line['fields'] = set(data_line['fields']) for key, val in data_line.items(): setattr(datapoint, key, val) data.append(datapoint) @@ -173,6 +177,43 @@ def load_labels(self, dataset: str, label_name: str, label_path: str = None, fty for label, doc in zip(labels, data): setattr(doc, label_name, label) + def dump(self, data: str, write_path: str, format: str = 'json') -> None: + """ + Dump processed data to a file. + + :data (str): The data slice to dump. + :write_path (str): Output path for the dataset, not including name. + :format (str, default = 'json'): Format of the output file (JSON/TSV). 
+ """ + if format not in ['json', 'tsv']: + tqdm.write("Incorrect format selected. Defaulting to JSON") + + if data == 'train': + data_out = self.data + elif data == 'dev': + data_out = self.dev + else: + data_out = self.test + + file_path = os.path.abspath(write_path) if '~' not in write_path else os.path.expanduser(write_path) + writepath = os.path.join(file_path, f"{data}.{format}") + if os.path.exists(writepath): + tqdm.write(f"Path {writepath} already exists. creating {writepath}.dump") + writepath = f"{writepath}.dump" + + with open(writepath, 'a', encoding = 'utf-8') as outf: + if format == 'tsv': + filewriter = csv.writer(outf, delimiter = '\t') + filewriter.writerow(data_out[0].fields) + else: + filewriter = outf + + for datapoint in tqdm(data_out, desc = f"Dumping {data} to {writepath}"): + if format == 'json': + filewriter.write(json.dumps(datapoint) + '\n') + else: + filewriter.writerow([datapoint[key] for key in data_out[0].fields]) + def set_labels(self, data: base.DataType, labels: base.DataType) -> None: """ Set labels for documents.