FEAT add CSV support (Azure#197)
romanlutz authored May 13, 2024
1 parent 32c8b19 commit d3ad1a0
Showing 2 changed files with 92 additions and 30 deletions.
42 changes: 41 additions & 1 deletion pyrit/memory/memory_exporter.py
@@ -1,10 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import csv
import json
from typing import Any
import uuid
from datetime import datetime
from pathlib import Path
from collections.abc import MutableMapping

from sqlalchemy.inspection import inspect

@@ -21,7 +24,8 @@ def __init__(self):
# Using strategy design pattern for export functionality.
self.export_strategies = {
"json": self.export_to_json,
# Future formats can be added here, e.g., "csv": self._export_to_csv
"csv": self.export_to_csv,
# Future formats can be added here
}

def export_data(self, data: list[Base], *, file_path: Path = None, export_type: str = "json"): # type: ignore
@@ -65,6 +69,31 @@ def export_to_json(self, data: list[Base], file_path: Path = None) -> None: # t
with open(file_path, "w") as f:
json.dump(export_data, f, indent=4)

def export_to_csv(self, data: list[Base], file_path: Path = None) -> None: # type: ignore
"""
Exports the provided data to a CSV file at the specified file path.
Each item in the data list, representing a row from the table,
is converted to a dictionary before being written to the file.
Args:
data (list[Base]): The data to be exported, as a list of SQLAlchemy model instances.
file_path (Path): The full path, including the file name, where the data will be exported.
Raises:
ValueError: If no file_path is provided.
"""
if not file_path:
raise ValueError("Please provide a valid file path for exporting data.")
if not data:
raise ValueError("No data to export.")

export_data = [_flatten_dict(self.model_to_dict(instance)) for instance in data]
fieldnames = list(export_data[0].keys())
with open(file_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(export_data)

def model_to_dict(self, model_instance: Base): # type: ignore
"""
Converts an SQLAlchemy model instance into a dictionary, serializing
@@ -89,3 +118,14 @@ def model_to_dict(self, model_instance: Base): # type: ignore
else:
model_dict[column.name] = value
return model_dict


def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
items: list[tuple[Any, Any]] = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, MutableMapping):
items.extend(_flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
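
For reference, a minimal, self-contained sketch of how the new dot-separated flattening behaves. The helper mirrors the `_flatten_dict` added above; the sample row and its nested "labels" mapping are made up purely for illustration:

from collections.abc import MutableMapping
from typing import Any


def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
    # Same logic as the helper added in pyrit/memory/memory_exporter.py above:
    # nested mappings are collapsed into dot-separated keys so csv.DictWriter
    # can emit one flat column per leaf value.
    items: list[tuple[Any, Any]] = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(_flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


# Hypothetical nested row, e.g. a model dict with a nested "labels" mapping.
row = {"role": "user", "labels": {"source": "unit-test", "stage": "converted"}}
print(_flatten_dict(row))
# {'role': 'user', 'labels.source': 'unit-test', 'labels.stage': 'converted'}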
80 changes: 51 additions & 29 deletions tests/memory/test_memory_exporter.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import csv
import json
import pytest

@@ -21,48 +22,69 @@ def model_to_dict(instance):
return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs}


def test_export_to_json_creates_file(tmp_path, sample_conversation_entries):
def read_file(file_path, export_type):
if export_type == "json":
with open(file_path, "r") as f:
return json.load(f)
elif export_type == "csv":
with open(file_path, "r", newline="") as f:
reader = csv.DictReader(f)
return [row for row in reader]
else:
raise ValueError(f"Invalid export type: {export_type}")


def export(export_type, exporter, data, file_path):
if export_type == "json":
exporter.export_to_json(data, file_path)
elif export_type == "csv":
exporter.export_to_csv(data, file_path)
else:
raise ValueError(f"Invalid export type: {export_type}")


@pytest.mark.parametrize("export_type", ["json", "csv"])
def test_export_to_json_creates_file(tmp_path, sample_conversation_entries, export_type):
exporter = MemoryExporter()
file_path = tmp_path / "conversations.json"
file_path = tmp_path / f"conversations.{export_type}"

exporter.export_to_json(sample_conversation_entries, file_path)
export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)

assert file_path.exists() # Check that the file was created
with open(file_path, "r") as f:
content = json.load(f)
# Perform more detailed checks on content if necessary
assert len(content) == 3 # Simple check for the number of items
# Convert each ConversationStore instance to a dictionary
expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]

for expected, actual in zip(expected_content, content):
assert expected["role"] == actual["role"]
assert expected["converted_value"] == actual["converted_value"]
assert expected["conversation_id"] == actual["conversation_id"]
assert expected["original_value_data_type"] == actual["original_value_data_type"]
assert expected["original_value"] == actual["original_value"]


def test_export_data_with_conversations(tmp_path, sample_conversation_entries):
content = read_file(file_path=file_path, export_type=export_type)
# Perform more detailed checks on content if necessary
assert len(content) == 3 # Simple check for the number of items
# Convert each ConversationStore instance to a dictionary
expected_content = [model_to_dict(conv) for conv in sample_conversation_entries]

for expected, actual in zip(expected_content, content):
assert expected["role"] == actual["role"]
assert expected["converted_value"] == actual["converted_value"]
assert expected["conversation_id"] == actual["conversation_id"]
assert expected["original_value_data_type"] == actual["original_value_data_type"]
assert expected["original_value"] == actual["original_value"]


@pytest.mark.parametrize("export_type", ["json", "csv"])
def test_export_to_json_data_with_conversations(tmp_path, sample_conversation_entries, export_type):
exporter = MemoryExporter()
conversation_id = sample_conversation_entries[0].conversation_id

# Define the file path using tmp_path
file_path = tmp_path / "exported_conversations.json"

# Call the method under test
exporter.export_data(sample_conversation_entries, file_path=file_path, export_type="json")
export(export_type=export_type, exporter=exporter, data=sample_conversation_entries, file_path=file_path)

# Verify the file was created
assert file_path.exists()

# Read the file and verify its contents
with open(file_path, "r") as f:
content = json.load(f)
assert len(content) == 3 # Check for the expected number of items
assert content[0]["role"] == "user"
assert content[0]["converted_value"] == "Hello, how are you?"
assert content[0]["conversation_id"] == conversation_id
assert content[1]["role"] == "assistant"
assert content[1]["converted_value"] == "I'm fine, thank you!"
assert content[1]["conversation_id"] == conversation_id
content = read_file(file_path=file_path, export_type=export_type)
assert len(content) == 3 # Check for the expected number of items
assert content[0]["role"] == "user"
assert content[0]["converted_value"] == "Hello, how are you?"
assert content[0]["conversation_id"] == conversation_id
assert content[1]["role"] == "assistant"
assert content[1]["converted_value"] == "I'm fine, thank you!"
assert content[1]["conversation_id"] == conversation_id

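One standard-library detail worth keeping in mind when reading the parametrized tests: csv.DictReader yields every field as a string, so a DictWriter/DictReader round trip only compares cleanly for string-valued columns such as role, converted_value, and conversation_id. A minimal sketch of that round trip (the field names here are illustrative):

import csv
import io

# Write one row with DictWriter, then read it back with DictReader.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["role", "converted_value"])
writer.writeheader()
writer.writerow({"role": "user", "converted_value": "Hello, how are you?"})

buffer.seek(0)
rows = list(csv.DictReader(buffer))
print(rows[0])  # {'role': 'user', 'converted_value': 'Hello, how are you?'} -- all values are str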