diff --git a/osiris/app.py b/osiris/app.py
index b2a9f12..e05a35d 100644
--- a/osiris/app.py
+++ b/osiris/app.py
@@ -1,7 +1,8 @@
+from enum import Enum
 import os
 
+import pandas as pd
 import numpy as np
-import polars as pl
 import typer
 
 from osiris.cairo.data_converter.data_converter import convert_to_cairo
@@ -14,17 +15,24 @@
 app = typer.Typer()
 
 
+class InputFormat(Enum):
+    CSV = 'CSV'
+    PARQUET = 'Parquet'
+    NUMPY = 'NumPy'
+    UNKNOWN = 'Unknown'
+
+
 def check_file_format(file_path):
     _, file_extension = os.path.splitext(file_path)
 
-    if file_extension in ['.csv']:
-        return 'CSV'
-    elif file_extension in ['.parquet']:
-        return 'Parquet'
-    elif file_extension in ['.npy']:
-        return 'NumPy'
+    if file_extension == '.csv':
+        return InputFormat.CSV
+    elif file_extension == '.parquet':
+        return InputFormat.PARQUET
+    elif file_extension == '.npy':
+        return InputFormat.NUMPY
     else:
-        return 'Unknown'
+        return InputFormat.UNKNOWN
 
 
 def load_data(input_file: str):
@@ -42,13 +50,13 @@ def load_data(input_file: str):
     input_format = check_file_format(input_file)
     match input_format:
         case InputFormat.CSV:
-            return pl.read_csv(input_file)
+            return pd.read_csv(input_file, header=None)
         case InputFormat.PARQUET:
-            return pl.read_parquet(input_file)
+            return pd.read_parquet(input_file)
         case InputFormat.NUMPY:
             return np.load(input_file)
         case _:
-            raise ValueError(f"Unsupported input format: {input_format}")
+            raise ValueError(f"Unsupported input format: {input_format.value}")
 
 
 def convert_to_numpy(data):
diff --git a/pyproject.toml b/pyproject.toml
index 0792dfa..d5fb72c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "giza-osiris"
-version = "0.2.0"
+version = "0.2.1"
 description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
 authors = ["Fran Algaba "]
 readme = "README.md"
diff --git a/tests/data/simple_tensor.csv b/tests/data/simple_tensor.csv
new file mode 100644
index 0000000..ebb6763
--- /dev/null
+++ b/tests/data/simple_tensor.csv
@@ -0,0 +1,2 @@
+1,2
+3,4
diff --git a/tests/data/simple_tensor.npy b/tests/data/simple_tensor.npy
new file mode 100644
index 0000000..fbbcc39
Binary files /dev/null and b/tests/data/simple_tensor.npy differ
diff --git a/tests/data/simple_tensor.parquet b/tests/data/simple_tensor.parquet
new file mode 100644
index 0000000..f0ca502
Binary files /dev/null and b/tests/data/simple_tensor.parquet differ
diff --git a/tests/test_convert_to_numpy.py b/tests/test_convert_to_numpy.py
new file mode 100644
index 0000000..378ffb2
--- /dev/null
+++ b/tests/test_convert_to_numpy.py
@@ -0,0 +1,25 @@
+from osiris.app import convert_to_numpy, load_data
+from osiris.dtypes.input_output_formats import InputFormat
+
+import numpy as np
+
+
+def test_convert_to_numpy_from_csv():
+    data = load_data("tests/data/simple_tensor.csv")
+    numpy_array = convert_to_numpy(data)
+    assert np.array_equal(numpy_array, np.array(
+        [[1, 2], [3, 4]], dtype=np.uint32))
+
+
+def test_convert_to_numpy_from_parquet():
+    data = load_data("tests/data/simple_tensor.parquet")
+    numpy_array = convert_to_numpy(data)
+    assert np.array_equal(numpy_array, np.array(
+        [[1, 2], [3, 4]], dtype=np.uint32))
+
+
+def test_convert_to_numpy_from_npy():
+    data = load_data("tests/data/simple_tensor.npy")
+    numpy_array = convert_to_numpy(data)
+    assert np.array_equal(numpy_array, np.array(
+        [[1, 2], [3, 4]], dtype=np.uint32))