From d3ad1a05ad8aaa470bde0fca1b4a0795bb7aba64 Mon Sep 17 00:00:00 2001
From: Roman Lutz
Date: Mon, 13 May 2024 13:54:58 -0700
Subject: [PATCH] FEAT add CSV support (#197)

---
 pyrit/memory/memory_exporter.py      | 42 ++++++++++++++++-
 tests/memory/test_memory_exporter.py | 80 ++++++++++++++++++----------
 2 files changed, 92 insertions(+), 30 deletions(-)

diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py
index 7147897b5..b31bdfd74 100644
--- a/pyrit/memory/memory_exporter.py
+++ b/pyrit/memory/memory_exporter.py
@@ -1,10 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import csv
 import json
+from typing import Any
 import uuid
 from datetime import datetime
 from pathlib import Path
+from collections.abc import MutableMapping
 
 from sqlalchemy.inspection import inspect
 
@@ -21,7 +24,8 @@ def __init__(self):
         # Using strategy design pattern for export functionality.
         self.export_strategies = {
             "json": self.export_to_json,
-            # Future formats can be added here, e.g., "csv": self._export_to_csv
+            "csv": self.export_to_csv,
+            # Future formats can be added here
         }
 
     def export_data(self, data: list[Base], *, file_path: Path = None, export_type: str = "json"):  # type: ignore
@@ -65,6 +69,31 @@ def export_to_json(self, data: list[Base], file_path: Path = None) -> None:  # t
         with open(file_path, "w") as f:
             json.dump(export_data, f, indent=4)
 
+    def export_to_csv(self, data: list[Base], file_path: Path = None) -> None:  # type: ignore
+        """
+        Exports the provided data to a CSV file at the specified file path.
+        Each item in the data list, representing a row from the table,
+        is converted to a dictionary before being written to the file.
+
+        Args:
+            data (list[Base]): The data to be exported, as a list of SQLAlchemy model instances.
+            file_path (Path): The full path, including the file name, where the data will be exported.
+
+        Raises:
+            ValueError: If no file_path is provided.
+        """
+        if not file_path:
+            raise ValueError("Please provide a valid file path for exporting data.")
+        if not data:
+            raise ValueError("No data to export.")
+
+        export_data = [_flatten_dict(self.model_to_dict(instance)) for instance in data]
+        fieldnames = list(export_data[0].keys())
+        with open(file_path, "w", newline="") as f:
+            writer = csv.DictWriter(f, fieldnames=fieldnames)
+            writer.writeheader()
+            writer.writerows(export_data)
+
     def model_to_dict(self, model_instance: Base):  # type: ignore
         """
         Converts an SQLAlchemy model instance into a dictionary, serializing
@@ -89,3 +118,14 @@ def model_to_dict(self, model_instance: Base):  # type: ignore
         else:
             model_dict[column.name] = value
         return model_dict
+
+
+def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
+    items: list[tuple[Any, Any]] = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, MutableMapping):
+            items.extend(_flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py
index 380334460..a565c8b20 100644
--- a/tests/memory/test_memory_exporter.py
+++ b/tests/memory/test_memory_exporter.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import csv
 import json
 
 import pytest
@@ -21,29 +22,51 @@ def model_to_dict(instance):
     return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs}
 
 
-def test_export_to_json_creates_file(tmp_path, sample_conversation_entries):
+def read_file(file_path, export_type):
+    if export_type == "json":
+        with open(file_path, "r") as f:
+            return json.load(f)
+    elif export_type == "csv":
+        with open(file_path, "r", newline="") as f:
+            reader = csv.DictReader(f)
+            return [row for row in reader]
+    else:
+        raise ValueError(f"Invalid export type: {export_type}")
+
+
+def export(export_type, exporter, data, file_path):
+    if export_type == "json":
+        exporter.export_to_json(data, file_path)
+    elif export_type == "csv":
+        exporter.export_to_csv(data, file_path)
+    else:
+        raise ValueError(f"Invalid export type: {export_type}")
+
+
+@pytest.mark.parametrize("export_type", ["json", "csv"])
+def test_export_to_json_creates_file(tmp_path, sample_conversation_entries, export_type):
     exporter = MemoryExporter()
-    file_path = tmp_path / "conversations.json"
+    file_path = tmp_path / f"conversations.{export_type}"
 
-    exporter.export_to_json(sample_conversation_entries, file_path)
+    export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)
     assert file_path.exists()  # Check that the file was created
-    with open(file_path, "r") as f:
-        content = json.load(f)
-        # Perform more detailed checks on content if necessary
-        assert len(content) == 3  # Simple check for the number of items
-        # Convert each ConversationStore instance to a dictionary
-        expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]
-
-        for expected, actual in zip(expected_content, content):
-            assert expected["role"] == actual["role"]
-            assert expected["converted_value"] == actual["converted_value"]
-            assert expected["conversation_id"] == actual["conversation_id"]
-            assert expected["original_value_data_type"] == actual["original_value_data_type"]
-            assert expected["original_value"] == actual["original_value"]
-
-
-def test_export_data_with_conversations(tmp_path, sample_conversation_entries):
+    content = read_file(file_path=file_path, export_type=export_type)
+    # Perform more detailed checks on content if necessary
+    assert len(content) == 3  # Simple check for the number of items
+    # Convert each ConversationStore instance to a dictionary
+    expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]
+
+    for expected, actual in zip(expected_content, content):
+        assert expected["role"] == actual["role"]
+        assert expected["converted_value"] == actual["converted_value"]
+        assert expected["conversation_id"] == actual["conversation_id"]
+        assert expected["original_value_data_type"] == actual["original_value_data_type"]
+        assert expected["original_value"] == actual["original_value"]
+
+
+@pytest.mark.parametrize("export_type", ["json", "csv"])
+def test_export_to_json_data_with_conversations(tmp_path, sample_conversation_entries, export_type):
    exporter = MemoryExporter()
    conversation_id = sample_conversation_entries[0].conversation_id
 
@@ -51,18 +74,17 @@
     file_path = tmp_path / "exported_conversations.json"
 
     # Call the method under test
-    exporter.export_data(sample_conversation_entries, file_path=file_path, export_type="json")
+    export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)
 
     # Verify the file was created
     assert file_path.exists()
 
     # Read the file and verify its contents
-    with open(file_path, "r") as f:
-        content = json.load(f)
-        assert len(content) == 3  # Check for the expected number of items
-        assert content[0]["role"] == "user"
-        assert content[0]["converted_value"] == "Hello, how are you?"
-        assert content[0]["conversation_id"] == conversation_id
-        assert content[1]["role"] == "assistant"
-        assert content[1]["converted_value"] == "I'm fine, thank you!"
-        assert content[1]["conversation_id"] == conversation_id
+    content = read_file(file_path=file_path, export_type=export_type)
+    assert len(content) == 3  # Check for the expected number of items
+    assert content[0]["role"] == "user"
+    assert content[0]["converted_value"] == "Hello, how are you?"
+    assert content[0]["conversation_id"] == conversation_id
+    assert content[1]["role"] == "assistant"
+    assert content[1]["converted_value"] == "I'm fine, thank you!"
+    assert content[1]["conversation_id"] == conversation_id
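
A minimal usage sketch of the new CSV path (not part of the patch). It assumes export_data dispatches through self.export_strategies, as the strategy dict above suggests; the entries list is a hypothetical placeholder for whatever SQLAlchemy model instances the caller has pulled from memory.

    from pathlib import Path

    from pyrit.memory.memory_exporter import MemoryExporter

    exporter = MemoryExporter()
    entries = [...]  # hypothetical: SQLAlchemy model instances, e.g. rows read from the memory database

    # "csv" resolves to export_to_csv (assumed to go through exporter.export_strategies);
    # nested dict columns are flattened by _flatten_dict into dotted header names,
    # e.g. a column value of {"labels": {"source": "unit_test"}} yields a "labels.source" field.
    exporter.export_data(entries, file_path=Path("conversations.csv"), export_type="csv")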