From f3510a7a0e4db454e8b3f8f60659a3ab4cd40e48 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 22 Jan 2024 17:58:05 +0200 Subject: [PATCH] fix convert_to_numpy --- osiris/app.py | 30 +++++++++++++++++++----------- pyproject.toml | 2 +- tests/data/simple_tensor.csv | 2 ++ tests/data/simple_tensor.npy | Bin 0 -> 144 bytes tests/data/simple_tensor.parquet | Bin 0 -> 2128 bytes tests/test_convert_to_numpy.py | 25 +++++++++++++++++++++++++ 6 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 tests/data/simple_tensor.csv create mode 100644 tests/data/simple_tensor.npy create mode 100644 tests/data/simple_tensor.parquet create mode 100644 tests/test_convert_to_numpy.py diff --git a/osiris/app.py b/osiris/app.py index b2a9f12..e05a35d 100644 --- a/osiris/app.py +++ b/osiris/app.py @@ -1,7 +1,8 @@ +from enum import Enum import os +import pandas as pd import numpy as np -import polars as pl import typer from osiris.cairo.data_converter.data_converter import convert_to_cairo @@ -14,17 +15,24 @@ app = typer.Typer() +class InputFormat(Enum): + CSV = 'CSV' + PARQUET = 'Parquet' + NUMPY = 'NumPy' + UNKNOWN = 'Unknown' + + def check_file_format(file_path): _, file_extension = os.path.splitext(file_path) - if file_extension in ['.csv']: - return 'CSV' - elif file_extension in ['.parquet']: - return 'Parquet' - elif file_extension in ['.npy']: - return 'NumPy' + if file_extension == '.csv': + return InputFormat.CSV + elif file_extension == '.parquet': + return InputFormat.PARQUET + elif file_extension == '.npy': + return InputFormat.NUMPY else: - return 'Unknown' + return InputFormat.UNKNOWN def load_data(input_file: str): @@ -42,13 +50,13 @@ def load_data(input_file: str): input_format = check_file_format(input_file) match input_format: case InputFormat.CSV: - return pl.read_csv(input_file) + return pd.read_csv(input_file, header=None) case InputFormat.PARQUET: - return pl.read_parquet(input_file) + return pd.read_parquet(input_file) case InputFormat.NUMPY: return np.load(input_file) case _: - raise ValueError(f"Unsupported input format: {input_format}") + raise ValueError(f"Unsupported input format: {input_format.value}") def convert_to_numpy(data): diff --git a/pyproject.toml b/pyproject.toml index 0792dfa..d5fb72c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "giza-osiris" -version = "0.2.0" +version = "0.2.1" description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs" authors = ["Fran Algaba "] readme = "README.md" diff --git a/tests/data/simple_tensor.csv b/tests/data/simple_tensor.csv new file mode 100644 index 0000000..ebb6763 --- /dev/null +++ b/tests/data/simple_tensor.csv @@ -0,0 +1,2 @@ +1,2 +3,4 diff --git a/tests/data/simple_tensor.npy b/tests/data/simple_tensor.npy new file mode 100644 index 0000000000000000000000000000000000000000..fbbcc395d645458212e63907e237149ac4399339 GIT binary patch literal 144 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZQ);5FqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= jXCxM+0{I$7ItoUbItsN4WCJcn1_lNuAZ7+)79a)y6m=Vf literal 0 HcmV?d00001 diff --git a/tests/data/simple_tensor.parquet b/tests/data/simple_tensor.parquet new file mode 100644 index 0000000000000000000000000000000000000000..f0ca502aaeca851278e016f5d54a1619670707cb GIT binary patch literal 2128 zcmb_eOK;mo5T<0pb$}j>G!&%32Z2C@9%NUfO*aa9S?WpEa$-6rDFubWmnF(1#qwdA zG7R|GV-CF(=uha+=nw1cidH1oJ*5jtJ3EhWz8%gG`haAPp_yIg)d_6WON!drD^L`* z1z(0@UNWURF>f>5d&kuGG}f1C>;daVwFtj-p|Dk~eJ<=V#fq>?ZCA>)Id?gxSfWc0 zh0;Y??#h+W7Eci7Re-~@N7YUEG z@aK3`eiEqic4SU%PiJT#`KGwM(wo;Oe|=Yl8hz36Uvwq9vRA2peIaW9!zD@&Q2p9Y z>2vWFMOEtM?LhY}Ju3e}-?5Hw*|(SQh!W3_*mu=WckDIjKw?APpV(}@%KExzgThZ- z7fw;EhcO%*M4ksA@fkHQ<`c;uObyR+B`2Loo-j4k3+$6v>TPt=LXLST zPfh;}bSTZwH)cyZmfBJzwFV2Tr=%EXsBspykN9R}sa&G*%iLpal`|QsoQr)8Ux|zs z&NmLm>g~10Ka;oD1@7I2b|?#G4|76}r3wMyUysj8Zs5qJw$g0ann-XZuj%NXGP7F5 z9*;G73K)&iIq>aA&R7*@1}_JOXA(R`k$Wc}C(hLq?q&x5-5Bck$ZL78;WDqMt4cKP zHdptNJg9HGh9M3}$fbz`{uG1MLI zk>+?$YnXF&NDcCyIK8s^Zkoq(@a!AIg+%zmvxv>U``>m|X^d2PYV|s+zPM6=4{$I& z1@z4f^Msy?bajx8w1I%nk}nD%XU3sP@>rgdnY?>rIGDe5J8T`LC$o-obq;;ti_nA6 zXIzsXYlDTZHWPv|>m;Y=M+?bs;8w>ofn@wchN69eM-2`lni3@*0z4m&i_8!1#~~S) zb`SfLp#ditMi56ZH3$ZV%+ce#KH`~J|12JO!aPI#8Qybd*c&0LTn+T_*TjzZ%^=t( cFP3Wl;;1fPDV6#jMZt$U{)VDHz<-H<0PlA;!~g&Q literal 0 HcmV?d00001 diff --git a/tests/test_convert_to_numpy.py b/tests/test_convert_to_numpy.py new file mode 100644 index 0000000..378ffb2 --- /dev/null +++ b/tests/test_convert_to_numpy.py @@ -0,0 +1,25 @@ +from osiris.app import convert_to_numpy, load_data +from osiris.dtypes.input_output_formats import InputFormat + +import numpy as np + + +def test_convert_to_numpy_from_csv(): + data = load_data("tests/data/simple_tensor.csv") + numpy_array = convert_to_numpy(data) + assert np.array_equal(numpy_array, np.array( + [[1, 2], [3, 4]], dtype=np.uint32)) + + +def test_convert_to_numpy_from_parquet(): + data = load_data("tests/data/simple_tensor.parquet") + numpy_array = convert_to_numpy(data) + assert np.array_equal(numpy_array, np.array( + [[1, 2], [3, 4]], dtype=np.uint32)) + + +def test_convert_to_numpy_from_npy(): + data = load_data("tests/data/simple_tensor.npy") + numpy_array = convert_to_numpy(data) + assert np.array_equal(numpy_array, np.array( + [[1, 2], [3, 4]], dtype=np.uint32))