Skip to content

Commit

Permalink
refactor: default to df output
Browse files Browse the repository at this point in the history
* remove format_type, always classic
* if only_classic, return classic, else df
  • Loading branch information
dshemetov committed Jul 11, 2024
1 parent c1ffe1c commit b951654
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 46 deletions.
18 changes: 2 additions & 16 deletions epidatpy/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,6 @@ def __str__(self) -> str:
return f"{format_date(self.start)}-{format_date(self.end)}"


class EpiDataFormatType(str, Enum):
"""
possible formatting options for API calls
"""

json = "json"
classic = "classic"


class InvalidArgumentException(Exception):
"""
exception for an invalid argument
Expand Down Expand Up @@ -174,41 +165,36 @@ def _verify_parameters(self) -> None:

def _formatted_parameters(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
) -> Mapping[str, str]:
"""
format this call into a [URL, Params] tuple
"""
all_params = dict(self._params)
if format_type and format_type != EpiDataFormatType.classic:
all_params["format"] = format_type
if fields:
all_params["fields"] = fields
return {k: format_list(v) for k, v in all_params.items() if v is not None}

def request_arguments(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
) -> Tuple[str, Mapping[str, str]]:
"""
format this call into a [URL, Params] tuple
"""
formatted_params = self._formatted_parameters(format_type, fields)
formatted_params = self._formatted_parameters(fields)
full_url = add_endpoint_to_url(self._base_url, self._endpoint)
return full_url, formatted_params

def request_url(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
) -> str:
"""
format this call into a full HTTP request url with encoded parameters
"""
self._verify_parameters()
u, p = self.request_arguments(format_type, fields)
u, p = self.request_arguments(fields)
query = urlencode(p)
if query:
return f"{u}?{query}"
Expand Down
32 changes: 8 additions & 24 deletions epidatpy/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import date
from typing import (
Any,
Dict,
Expand All @@ -24,7 +23,6 @@
AEpiDataCall,
EpidataFieldInfo,
EpidataFieldType,
EpiDataFormatType,
EpiDataResponse,
EpiRange,
EpiRangeParam,
Expand Down Expand Up @@ -87,11 +85,10 @@ def with_session(self, session: Session) -> "EpiDataCall":

def _call(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
stream: bool = False,
) -> Response:
url, params = self.request_arguments(format_type, fields)
url, params = self.request_arguments(fields)
return _request_with_retry(url, params, self._session, stream)

def classic(
Expand All @@ -102,7 +99,7 @@ def classic(
"""Request and parse epidata in CLASSIC message format."""
self._verify_parameters()
try:
response = self._call(None, fields)
response = self._call(fields)
r = cast(EpiDataResponse, response.json())
epidata = r.get("epidata")
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
Expand All @@ -115,25 +112,11 @@ def __call__(
self,
fields: Optional[Sequence[str]] = None,
disable_date_parsing: Optional[bool] = False,
) -> EpiDataResponse:
"""Request and parse epidata in CLASSIC message format."""
return self.classic(fields, disable_date_parsing=disable_date_parsing)

def json(
self,
fields: Optional[Sequence[str]] = None,
disable_date_parsing: Optional[bool] = False,
) -> List[Mapping[str, Union[str, int, float, date, None]]]:
"""Request and parse epidata in JSON format"""
) -> Union[EpiDataResponse, DataFrame]:
"""Request and parse epidata in df message format."""
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
response = self._call(EpiDataFormatType.json, fields)
response.raise_for_status()
return [
self._parse_row(row, disable_date_parsing=disable_date_parsing)
for row in cast(List[Mapping[str, Union[str, int, float, None]]], response.json())
]
return self.classic(fields, disable_date_parsing=disable_date_parsing)
return self.df(fields, disable_date_parsing=disable_date_parsing)

def df(
self,
Expand All @@ -144,7 +127,8 @@ def df(
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
rows = self.json(fields, disable_date_parsing=disable_date_parsing)
json = self.classic(fields, disable_date_parsing=disable_date_parsing)
rows = json.get("epidata", [])
pred = fields_to_predicate(fields)
columns: List[str] = [info.name for info in self.meta if pred(info.name)]
df = DataFrame(rows, columns=columns or None)
Expand Down
6 changes: 0 additions & 6 deletions smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
classic = apicall.classic()
print(classic)

data = apicall.json()
print(data[0])

df = apicall.df()
print(df.columns)
print(df.dtypes)
Expand Down Expand Up @@ -47,9 +44,6 @@
classic = apicall.classic()
print(classic)

data = apicall.json()
print(data[0])

df = apicall.df()
print(df.columns)
print(df.dtypes)
Expand Down

0 comments on commit b951654

Please sign in to comment.