Skip to content

Commit

Permalink
[BugFix] Fix ReferenceGenerator Unions and Choices (#6599)
Browse files Browse the repository at this point in the history
* fix duplicated unions in reference.json

* missing static files

* fix choices

* black

* mypy

* more mypy

* line too long

* fix typo

* fix another typo

* static files

* tradingeconomics country choices

* financial statement period choices

* black

* mypy

* test params

* price-historical interval choices

* test params

* test cassettes

* new test cassettes again

---------

Co-authored-by: Henrique Joaquim <[email protected]>
Co-authored-by: Igor Radovanovic <[email protected]>
  • Loading branch information
3 people authored Aug 5, 2024
1 parent 4500132 commit dd49c5f
Show file tree
Hide file tree
Showing 96 changed files with 4,061 additions and 2,348 deletions.
19 changes: 13 additions & 6 deletions openbb_platform/core/openbb_core/app/provider_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,22 @@ def _create_field(
annotation = field.annotation

additional_description = ""
choices: Dict = {}
if extra := field.json_schema_extra:
providers = []
for p, v in extra.items(): # type: ignore[union-attr]
providers: List = []
for p, v in extra.items(): # type: ignore
if isinstance(v, dict) and v.get("multiple_items_allowed"):
providers.append(p)
choices[p] = {"multiple_items_allowed": True, "choices": v.get("choices")} # type: ignore
elif isinstance(v, list) and "multiple_items_allowed" in v:
# For backwards compatibility, before this was a list
providers.append(p)
choices[p] = {"multiple_items_allowed": True, "choices": None} # type: ignore
elif isinstance(v, dict) and v.get("choices"):
choices[p] = {
"multiple_items_allowed": False,
"choices": v.get("choices"),
}

if providers:
if provider_name:
Expand All @@ -271,7 +279,6 @@ def _create_field(
+ ", ".join(providers) # type: ignore[arg-type]
+ "."
)

provider_field = (
f"(provider: {provider_name})" if provider_name != "openbb" else ""
)
Expand Down Expand Up @@ -303,7 +310,7 @@ def _create_field(
title=provider_name,
description=description,
alias=field.alias or None,
json_schema_extra=getattr(field, "json_schema_extra", None),
json_schema_extra=choices,
),
)

Expand All @@ -318,7 +325,7 @@ def _create_field(
title=provider_name,
description=description,
alias=field.alias or None,
json_schema_extra=getattr(field, "json_schema_extra", None),
json_schema_extra=choices,
),
)
if provider_name:
Expand All @@ -329,7 +336,7 @@ def _create_field(
default=default or None,
title=provider_name,
description=description,
json_schema_extra=field.json_schema_extra,
json_schema_extra=choices,
),
)

Expand Down
63 changes: 49 additions & 14 deletions openbb_platform/core/openbb_core/app/static/package_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,15 @@ def generate_model_docstring(
def format_type(type_: str, char_limit: Optional[int] = None) -> str:
"""Format type in docstrings."""
type_str = str(type_)
type_str = type_str.replace("NoneType", "None")
type_str = (
type_str.replace("<class '", "")
.replace("'>", "")
.replace("typing.", "")
.replace("pydantic.types.", "")
.replace("NoneType", "None")
.replace("datetime.date", "date")
.replace("datetime.datetime", "datetime")
)
if char_limit:
type_str = type_str[:char_limit] + (
"..." if len(str(type_str)) > char_limit else ""
Expand Down Expand Up @@ -1109,7 +1117,7 @@ def get_param_info(parameter: Optional[Parameter]) -> Tuple[str, str]:
# Explicit parameters
for param_name, param in explicit_params.items():
type_, description = get_param_info(param)
type_str = format_type(str(type_), char_limit=79)
type_str = format_type(str(type_), char_limit=86)
docstring += f"{create_indent(2)}{param_name} : {type_str}\n"
docstring += f"{create_indent(3)}{format_description(description)}\n"

Expand Down Expand Up @@ -1484,28 +1492,26 @@ def _get_provider_field_params(
field_type, is_required, "website"
)

if params_type == "QueryParams" and field in expanded_types:
expanded_type = DocstringGenerator.get_field_type(
expanded_types[field], is_required, "website"
)
field_type = f"Union[{field_type}, {expanded_type}]"
cleaned_description = (
str(field_info.description)
.strip().replace("\n", " ").replace(" ", " ").replace('"', "'")
.strip().replace('"', "'")
) # fmt: skip

extra = field_info.json_schema_extra or {}
choices = extra.get("choices")

# Add information for the providers supporting multiple symbols
if params_type == "QueryParams" and extra:

providers = []
providers: List = []
for p, v in extra.items(): # type: ignore[union-attr]
if isinstance(v, dict) and v.get("multiple_items_allowed"):
providers.append(p)
choices = v.get("choices") # type: ignore
elif isinstance(v, list) and "multiple_items_allowed" in v:
# For backwards compatibility, before this was a list
providers.append(p)
elif isinstance(v, dict) and "choices" in v:
choices = v.get("choices")

if providers:
multiple_items = ", ".join(providers)
Expand All @@ -1515,6 +1521,12 @@ def _get_provider_field_params(
# Manually setting to List[<field_type>] for multiple items
# Should be removed if TYPE_EXPANSION is updated to include this
field_type = f"Union[{field_type}, List[{field_type}]]"
elif field in expanded_types:
expanded_type = DocstringGenerator.get_field_type(
expanded_types[field], is_required, "website"
)
field_type = f"Union[{field_type}, {expanded_type}]"

default_value = "" if field_info.default is PydanticUndefined else field_info.default # fmt: skip

provider_field_params.append(
Expand All @@ -1524,7 +1536,7 @@ def _get_provider_field_params(
"description": cleaned_description,
"default": default_value,
"optional": not is_required,
"choices": extra.get("choices"),
"choices": choices,
}
)

Expand Down Expand Up @@ -1742,25 +1754,48 @@ def get_paths(cls, route_map: Dict[str, BaseRoute]) -> Dict[str, Dict[str, Any]]
provider_parameter_fields = cls._get_provider_parameter_info(
standard_model
)
reference[path]["parameters"]["standard"].append(
provider_parameter_fields
)

# Add endpoint data fields for standard provider
reference[path]["data"]["standard"] = (
cls._get_provider_field_params(standard_model, "Data")
)
continue

# Adds provider specific parameter fields to the reference
reference[path]["parameters"][provider] = (
cls._get_provider_field_params(
standard_model, "QueryParams", provider
)
)

# Adds provider specific data fields to the reference
reference[path]["data"][provider] = cls._get_provider_field_params(
standard_model, "Data", provider
)

# Remove choices from 'standard' if choices for a parameter exist
# for both standard and provider, and are the same
standard = [
{d["name"]: d["choices"]}
for d in reference[path]["parameters"]["standard"]
if d.get("choices")
]
standard = standard[0] if standard else [] # type: ignore
_provider = [
{d["name"]: d["choices"]}
for d in reference[path]["parameters"][provider]
if d.get("choices")
]
_provider = _provider[0] if _provider else [] # type: ignore
if standard and _provider and standard == _provider:
for i, d in enumerate(
reference[path]["parameters"]["standard"]
):
if d.get("name") in standard:
reference[path]["parameters"]["standard"][i][
"choices"
] = None

# Add endpoint returns data
# Currently only OBBject object is returned
providers = provider_parameter_fields["type"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ class BalanceSheetQueryParams(QueryParams):
"""Balance Sheet Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[NonNegativeInt] = Field(
default=5, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -30,12 +26,6 @@ def to_upper(cls, v: str):
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class BalanceSheetData(Data):
"""Balance Sheet Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class BalanceSheetGrowthQueryParams(QueryParams):
"""Balance Sheet Statement Growth Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=10, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -28,12 +24,6 @@ def to_upper(cls, v: str):
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class BalanceSheetGrowthData(Data):
"""Balance Sheet Statement Growth Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class CashFlowStatementQueryParams(QueryParams):
"""Cash Flow Statement Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[NonNegativeInt] = Field(
default=5, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -28,12 +24,6 @@ def to_upper(cls, v: str):
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class CashFlowStatementData(Data):
"""Cash Flow Statement Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class CashFlowStatementGrowthQueryParams(QueryParams):
"""Cash Flow Statement Growth Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=10, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -28,12 +24,6 @@ def to_upper(cls, v: str) -> str:
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class CashFlowStatementGrowthData(Data):
"""Cash Flow Statement Growth Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ class EquityHistoricalQueryParams(QueryParams):
"""Equity Historical Price Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
interval: Optional[str] = Field(
default="1d",
description=QUERY_DESCRIPTIONS.get("interval", ""),
)
start_date: Optional[dateType] = Field(
default=None,
description=QUERY_DESCRIPTIONS.get("start_date", ""),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ class FinancialRatiosQueryParams(QueryParams):
"""Financial Ratios Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual", description=QUERY_DESCRIPTIONS.get("period", "")
)
limit: NonNegativeInt = Field(
default=12, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -29,12 +26,6 @@ def to_upper(cls, v: str):
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class FinancialRatiosData(Data):
"""Financial Ratios Standard Model."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ class IncomeStatementQueryParams(QueryParams):
"""Income Statement Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[NonNegativeInt] = Field(
default=5, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -30,12 +26,6 @@ def to_upper(cls, v: str):
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class IncomeStatementData(Data):
"""Income Statement Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class IncomeStatementGrowthQueryParams(QueryParams):
"""Income Statement Growth Query."""

symbol: str = Field(description=QUERY_DESCRIPTIONS.get("symbol", ""))
period: str = Field(
default="annual",
description=QUERY_DESCRIPTIONS.get("period", ""),
)
limit: Optional[int] = Field(
default=10, description=QUERY_DESCRIPTIONS.get("limit", "")
)
Expand All @@ -28,12 +24,6 @@ def to_upper(cls, v: str) -> str:
"""Convert field to uppercase."""
return v.upper()

@field_validator("period", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class IncomeStatementGrowthData(Data):
"""Income Statement Growth Data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,13 @@ class IndexHistoricalQueryParams(QueryParams):
end_date: Optional[dateType] = Field(
description=QUERY_DESCRIPTIONS.get("end_date", ""), default=None
)
interval: Optional[str] = Field(
default="1d",
description=QUERY_DESCRIPTIONS.get("interval", ""),
)

@field_validator("symbol", mode="before", check_fields=False)
@classmethod
def to_upper(cls, v: str) -> str:
"""Convert field to uppercase."""
return v.upper()

@field_validator("sort", mode="before", check_fields=False)
@classmethod
def to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Convert field to lowercase."""
return v.lower() if v else v


class IndexHistoricalData(Data):
"""Index Historical Data."""
Expand Down
Loading

0 comments on commit dd49c5f

Please sign in to comment.