Make Data objects compatible with Path objects
enekomartinmartinez committed Dec 13, 2021
1 parent bf03b16 commit 9f6da8a
Showing 9 changed files with 365 additions and 315 deletions.
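In practical terms, the data-loading entry points now accept pathlib.Path objects interchangeably with plain strings. A minimal usage sketch of the intent, assuming a previously saved run file ("results.tab" is a placeholder name, not part of this commit):

from pathlib import Path

from pysd.py_backend.data import Columns
from pysd.py_backend.utils import load_outputs

# str arguments are normalised to Path internally, so both forms behave
# the same; "results.tab" stands in for any existing output file.
df_from_str = load_outputs("results.tab")
df_from_path = load_outputs(Path("results.tab"))

# Column inspection accepts either type as well.
cols, transposed = Columns.get_columns(Path("results.tab"))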
19 changes: 11 additions & 8 deletions pysd/py_backend/data.py
@@ -1,5 +1,6 @@
import warnings
import re
from pathlib import Path

import numpy as np
import xarray as xr
@@ -19,6 +20,8 @@ def read(cls, file_name, encoding=None):
"""
Read the columns from the data file or return the previously read ones
"""
if isinstance(file_name, str):
file_name = Path(file_name)
if file_name in cls._files:
return cls._files[file_name]
else:
@@ -50,7 +53,7 @@ def read_file(cls, file_name, encoding=None):
out = cls.read_line(file_name, encoding)
if out is None:
raise ValueError(
f"\nNot able to read '{file_name}'. "
f"\nNot able to read '{str(file_name)}'. "
+ "Only '.csv', '.tab' files are accepted.")

transpose = False
@@ -64,21 +67,21 @@ def read_file(cls, file_name, encoding=None):
return out, transpose
else:
raise ValueError(
f"Invalid file format '{file_name}'... varible names "
f"Invalid file format '{str(file_name)}'... varible names "
"should appear in the first row or in the first column...")

@classmethod
def read_line(cls, file_name, encoding=None):
"""
Read the first row and return a set of it.
"""
if file_name.lower().endswith(".tab"):
if file_name.suffix.lower() == ".tab":
return set(pd.read_table(file_name,
nrows=0,
encoding=encoding,
dtype=str,
header=0).iloc[:, 1:])
elif file_name.lower().endswith(".csv"):
elif file_name.suffix.lower() == ".csv":
return set(pd.read_csv(file_name,
nrows=0,
encoding=encoding,
@@ -92,12 +95,12 @@ def read_row(cls, file_name, encoding=None):
"""
Read the first column and return a set of it.
"""
if file_name.lower().endswith(".tab"):
if file_name.suffix.lower() == ".tab":
return set(pd.read_table(file_name,
usecols=[0],
encoding=encoding,
dtype=str).iloc[:, 0].to_list())
elif file_name.lower().endswith(".csv"):
elif file_name.suffix.lower() == ".csv":
return set(pd.read_csv(file_name,
usecols=[0],
encoding=encoding,
@@ -236,7 +239,7 @@ def load_data(self, file_names):
Resulting data array with the time in the first dimension.
"""
if isinstance(file_names, str):
if isinstance(file_names, (str, Path)):
file_names = [file_names]

for file_name in file_names:
@@ -248,7 +251,7 @@ def load_data(self, file_names):
raise ValueError(
f"_data_{self.py_name}\n"
f"Data for {self.real_name} not found in "
f"{', '.join(file_names)}")
f"{', '.join([str(file_name) for file_name in file_names])}")

def _load_data(self, file_name):
"""
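Because Columns.read now converts string arguments to Path before consulting its singleton cache, the same file referenced once as a string and once as a Path resolves to a single cache entry. A small sketch of that behaviour, assuming the file exists (the path below is a placeholder):

from pathlib import Path

from pysd.py_backend.data import Columns

out_file = Path("data/out_teacup.csv")  # placeholder for an existing output file

Columns.clean()                # start from an empty cache
Columns.read(str(out_file))    # the str is converted to Path before caching
Columns.read(out_file)         # returns the entry cached by the previous call
assert list(Columns._files) == [out_file]
Columns.clean()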
6 changes: 5 additions & 1 deletion pysd/py_backend/utils.py
@@ -6,6 +6,7 @@

import os
import json
from pathlib import Path
from chardet.universaldetector import UniversalDetector

import regex as re
@@ -452,13 +453,16 @@ def load_outputs(file_name, transpose=False, columns=None, encoding=None):
"""
read_func = {'.csv': pd.read_csv, '.tab': pd.read_table}

if isinstance(file_name, str):
file_name = Path(file_name)

if columns:
columns = set(columns)
if not transpose:
columns.add("Time")

for end, func in read_func.items():
if file_name.lower().endswith(end):
if file_name.suffix.lower() == end:
if transpose:
out = func(file_name,
encoding=encoding,
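load_outputs applies the same normalisation before the suffix check, and the hunk also shows that when a columns subset is requested on a non-transposed file, "Time" is added to the selection automatically. A hedged sketch of such a call (file and variable names are placeholders):

from pathlib import Path

from pysd.py_backend.utils import load_outputs

run_file = Path("run_output.csv")   # placeholder for a saved simulation output

# "Time" does not need to be listed explicitly; it is added to the
# requested columns internally when transpose is False.
subset = load_outputs(run_file, columns=["Teacup Temperature"])
full = load_outputs(str(run_file))  # a plain str works identically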
12 changes: 7 additions & 5 deletions tests/pytest_pysd/user_interaction/pytest_select_submodel.py
@@ -10,15 +10,15 @@
@pytest.mark.parametrize(
"model_path,subview_sep,variables,modules,n_deps,dep_vars",
[
(
( # split_views
Path("more-tests/split_model/test_split_model.mdl"),
[],
["stock"],
[],
(6, 1, 2, 0, 1),
{"rate1": 4, "initial_stock": 2, "initial_stock_correction": 0}
),
(
( # split_subviews
Path("more-tests/split_model/test_split_model_subviews.mdl"),
["."],
[],
@@ -28,7 +28,7 @@
}
),
(
( # split_sub_subviews
Path("more-tests/split_model/test_split_model_sub_subviews.mdl"),
[".", "-"],
["variablex"],
@@ -37,6 +37,7 @@
{"another_var": 5, "look_up_definition": 3}
)
],
ids=["split_views", "split_subviews", "split_sub_subviews"]
)
class TestSubmodel:
"""Submodel selecting class"""
@@ -174,15 +175,15 @@ def test_select_submodel(self, model, variables, modules,
@pytest.mark.parametrize(
"model_path,split_views,module,raise_type,error_message",
[
(
( # module_not_found
Path("more-tests/split_model/test_split_model.mdl"),
True,
"view_4",
NameError,
"Module or submodule 'view_4' not found..."
),
(
( # not_modularized_model
Path("more-tests/split_model/test_split_model.mdl"),
False,
"view_1",
@@ -191,6 +192,7 @@ def test_select_submodel(self, model, variables, modules,
)
],
ids=["module_not_found", "not_modularized_model"]
)
class TestGetVarsInModuleErrors:
@pytest.fixture
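The test-suite changes are mostly cosmetic: each parametrised case gains an inline comment and a matching entry in ids, so pytest reports failures under readable names such as split_views or module_not_found instead of auto-generated parameter reprs. A generic illustration of the mechanism, unrelated to the pysd fixtures:

import pytest

@pytest.mark.parametrize(
    "value,expected",
    [
        (2, 4),    # square_small
        (10, 100)  # square_large
    ],
    ids=["square_small", "square_large"]
)
def test_square(value, expected):
    # Collected as test_square[square_small] and test_square[square_large].
    assert value ** 2 == expected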
17 changes: 10 additions & 7 deletions tests/pytest_translation/vensim2py/pytest_split_views.py
@@ -11,7 +11,7 @@
"model_path,subview_sep,modules,macros,original_vars,py_vars,"
+ "stateful_objs",
[
(
( # split_views
Path("more-tests/split_model/test_split_model.mdl"),
[],
["view_1", "view2", "view_3"],
@@ -20,7 +20,7 @@
["another_var", "rate1", "varn", "variablex", "stock"],
["_integ_stock"]
),
(
( # split_subviews
Path("more-tests/split_model/test_split_model_subviews.mdl"),
["."],
["view_1/submodule_1", "view_1/submodule_2", "view_2"],
@@ -29,7 +29,7 @@
["another_var", "rate1", "varn", "variablex", "stock"],
["_integ_stock"]
),
(
( # split_sub_subviews
Path("more-tests/split_model/test_split_model_sub_subviews.mdl"),
[".", "-"],
[
@@ -43,7 +43,7 @@
"interesting_var_2", "great_var"],
["_integ_stock"]
),
(
( # split_macro
Path("more-tests/split_model_with_macro/"
+ "test_split_model_with_macro.mdl"),
[".", "-"],
@@ -53,7 +53,7 @@
["new_var"],
["_macro_macro_output"]
),
(
( # split_vensim_8_2_1
Path("more-tests/split_model_vensim_8_2_1/"
+ "test_split_model_vensim_8_2_1.mdl"),
[],
@@ -64,6 +64,8 @@
["integ_teacup_temperature", "integ_cream_temperature"]
)
],
ids=["split_views", "split_subviews", "split_sub_subviews", "split_macro",
"split_vensim_8_2_1"]
)
class TestSplitViews:
"""
@@ -152,18 +154,19 @@ def test_read_vensim_split_model(self, model_file, subview_sep,
@pytest.mark.parametrize(
"model_path,subview_sep,warning_message",
[
(
( # warning_noviews
Path("test-models/samples/teacup/teacup.mdl"),
[],
"Only a single view with no subviews was detected. The model"
+ " will be built in a single file."
),
(
( # not_match_separator
Path("more-tests/split_model/test_split_model_sub_subviews.mdl"),
["a"],
"The given subview separators were not matched in any view name."
),
],
ids=["warning_noviews", "not_match_separator"]
)
class TestSplitViewsWarnings:
"""
137 changes: 137 additions & 0 deletions tests/pytest_types/data/pytest_columns.py
@@ -0,0 +1,137 @@
import pytest
import itertools

from pysd.py_backend.data import Columns


class TestColumns:
@pytest.fixture(scope="class")
def out_teacup(self, _root):
return _root.joinpath("data/out_teacup.csv")

@pytest.fixture(scope="class")
def out_teacup_transposed(self, _root):
return _root.joinpath("data/out_teacup_transposed.csv")

def test_clean_columns(self, out_teacup):
# test the singleton works well for lazy loading
Columns.clean()
assert Columns._files == {}
Columns.read(out_teacup)
assert Columns._files != {}
assert out_teacup in Columns._files
Columns.clean()
assert Columns._files == {}

def test_transposed_frame(self, out_teacup, out_teacup_transposed):
# test loading transposed frames
cols1, trans1 = Columns.get_columns(out_teacup)
cols2, trans2 = Columns.get_columns(out_teacup_transposed)
Columns.clean()

assert cols1 == cols2
assert not trans1
assert trans2

def test_get_columns(self, out_teacup, out_teacup_transposed):
# test getting specific columns by name
cols0, trans0 = Columns.get_columns(out_teacup)

cols1, trans1 = Columns.get_columns(
out_teacup,
vars=["Room Temperature", "Teacup Temperature"])

cols2, trans2 = Columns.get_columns(
out_teacup_transposed,
vars=["Heat Loss to Room"])

cols3 = Columns.get_columns(
out_teacup_transposed,
vars=["No column"])[0]

Columns.clean()

assert cols1.issubset(cols0)
assert cols1 == set(["Room Temperature", "Teacup Temperature"])

assert cols2.issubset(cols0)
assert cols2 == set(["Heat Loss to Room"])

assert cols3 == set()

assert not trans0
assert not trans1
assert trans2

def test_get_columns_subscripted(self, _root):
# test get subscripted columns
data_file = _root.joinpath(
"test-models/tests/subscript_3d_arrays_widthwise/output.tab"
)

data_file2 = _root.joinpath(
"test-models/tests/subscript_2d_arrays/output.tab"
)

subsd = {
"d3": ["Depth 1", "Depth 2"],
"d2": ["Column 1", "Column 2"],
"d1": ["Entry 1", "Entry 2", "Entry 3"]
}

cols1 = Columns.get_columns(
data_file,
vars=["Three Dimensional Constant"])[0]

expected = {
"Three Dimensional Constant[" + ",".join(el) + "]"
for el in itertools.product(subsd["d1"], subsd["d2"], subsd["d3"])
}

assert cols1 == expected

cols2 = Columns.get_columns(
data_file2,
vars=["Rate A", "Stock A"])[0]

subs = list(itertools.product(subsd["d1"], subsd["d2"]))
expected = {
"Rate A[" + ",".join(el) + "]"
for el in subs
}

expected.update({
"Stock A[" + ",".join(el) + "]"
for el in subs
})

assert cols2 == expected


@pytest.mark.parametrize(
"file,raise_type,error_message",
[
( # invalid_file_type
"more-tests/not_vensim/test_not_vensim.txt",
ValueError,
"Not able to read '%s'"
),
( # invalid_file_format
"data/out_teacup_no_head.csv",
ValueError,
"Invalid file format '%s'... varible names should appear"
+ " in the first row or in the first column..."
)
],
ids=["invalid_file_type", "invalid_file_format"]
)
class TestColumnsErrors:
# Test errors associated with Columns class

@pytest.fixture
def file_path(self, _root, file):
return _root.joinpath(file)

def test_columns_errors(self, file_path, raise_type, error_message):
with pytest.raises(raise_type, match=error_message % str(file_path)):
Columns.read_file(file_path)
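The _root fixture used above is not part of this diff; presumably it is defined in a shared conftest.py and returns a pathlib.Path pointing at the tests directory, which is what makes the joinpath calls work. A minimal sketch of such a fixture, under that assumption:

# conftest.py (sketch) -- the real fixture lives elsewhere in the repository.
from pathlib import Path

import pytest

@pytest.fixture(scope="session")
def _root():
    # Directory containing this conftest.py; data files such as
    # "data/out_teacup.csv" are resolved relative to it.
    return Path(__file__).parent.resolve()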
