Skip to content

Commit

Permalink
Merge pull request #5 from gizatechxyz/fix-numpy-converter
Browse files Browse the repository at this point in the history
Fix Numpy converter
  • Loading branch information
raphaelDkhn authored Jan 22, 2024
2 parents 87bd063 + f3510a7 commit 2f2fdca
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 12 deletions.
30 changes: 19 additions & 11 deletions osiris/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
import os

import pandas as pd
import numpy as np
import polars as pl
import typer

from osiris.cairo.data_converter.data_converter import convert_to_cairo
Expand All @@ -14,17 +15,24 @@
app = typer.Typer()


class InputFormat(Enum):
CSV = 'CSV'
PARQUET = 'Parquet'
NUMPY = 'NumPy'
UNKNOWN = 'Unknown'


def check_file_format(file_path):
_, file_extension = os.path.splitext(file_path)

if file_extension in ['.csv']:
return 'CSV'
elif file_extension in ['.parquet']:
return 'Parquet'
elif file_extension in ['.npy']:
return 'NumPy'
if file_extension == '.csv':
return InputFormat.CSV
elif file_extension == '.parquet':
return InputFormat.PARQUET
elif file_extension == '.npy':
return InputFormat.NUMPY
else:
return 'Unknown'
return InputFormat.UNKNOWN


def load_data(input_file: str):
Expand All @@ -42,13 +50,13 @@ def load_data(input_file: str):
input_format = check_file_format(input_file)
match input_format:
case InputFormat.CSV:
return pl.read_csv(input_file)
return pd.read_csv(input_file, header=None)
case InputFormat.PARQUET:
return pl.read_parquet(input_file)
return pd.read_parquet(input_file)
case InputFormat.NUMPY:
return np.load(input_file)
case _:
raise ValueError(f"Unsupported input format: {input_format}")
raise ValueError(f"Unsupported input format: {input_format.value}")


def convert_to_numpy(data):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-osiris"
version = "0.2.0"
version = "0.2.1"
description = "Osiris is a Python library designed for efficient data conversion and management, primarily transforming data into Cairo programs"
authors = ["Fran Algaba <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions tests/data/simple_tensor.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1,2
3,4
Binary file added tests/data/simple_tensor.npy
Binary file not shown.
Binary file added tests/data/simple_tensor.parquet
Binary file not shown.
25 changes: 25 additions & 0 deletions tests/test_convert_to_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from osiris.app import convert_to_numpy, load_data
from osiris.dtypes.input_output_formats import InputFormat

import numpy as np


def test_convert_to_numpy_from_csv():
data = load_data("tests/data/simple_tensor.csv")
numpy_array = convert_to_numpy(data)
assert np.array_equal(numpy_array, np.array(
[[1, 2], [3, 4]], dtype=np.uint32))


def test_convert_to_numpy_from_parquet():
data = load_data("tests/data/simple_tensor.parquet")
numpy_array = convert_to_numpy(data)
assert np.array_equal(numpy_array, np.array(
[[1, 2], [3, 4]], dtype=np.uint32))


def test_convert_to_numpy_from_npy():
data = load_data("tests/data/simple_tensor.npy")
numpy_array = convert_to_numpy(data)
assert np.array_equal(numpy_array, np.array(
[[1, 2], [3, 4]], dtype=np.uint32))

0 comments on commit 2f2fdca

Please sign in to comment.