Skip to content

Commit

Permalink
Merge pull request #1166 from elementary-data/revert-1143-ele-1696-up…
Browse files Browse the repository at this point in the history
…grade-pydantic

Revert "Ele 1696 upgrade pydantic"
  • Loading branch information
elongl authored Sep 18, 2023
2 parents 067c366 + 54d837d commit fb01967
Show file tree
Hide file tree
Showing 23 changed files with 92 additions and 103 deletions.
4 changes: 2 additions & 2 deletions elementary/clients/dbt/slim_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dbt.parser.manifest import ManifestLoader
from dbt.tracking import disable_tracking
from dbt.version import __version__ as dbt_version_string
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, validator

from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
from elementary.utils.log import get_logger
Expand Down Expand Up @@ -65,7 +65,7 @@ class ConfigArgs(BaseModel):
threads: Optional[int] = 1
vars: Optional[Union[str, Dict[str, Any]]] = DEFAULT_VARS

@field_validator("vars", mode="before")
@validator("vars", pre=True)
def validate_vars(cls, vars):
if not vars:
return DEFAULT_VARS
Expand Down
4 changes: 2 additions & 2 deletions elementary/clients/slack/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, validator

from elementary.utils.log import get_logger

Expand All @@ -17,7 +17,7 @@ class SlackMessageSchema(BaseModel):
attachments: Optional[list] = None
blocks: Optional[list] = None

@field_validator("attachments", mode="before")
@validator("attachments", pre=True)
def validate_attachments(cls, attachments):
if (
isinstance(attachments, list)
Expand Down
6 changes: 1 addition & 5 deletions elementary/monitor/api/filters/schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel


class FilterSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

name: str
display_name: str
model_unique_ids: List[Optional[str]] = []
Expand All @@ -16,8 +14,6 @@ def add_model_unique_id(self, model_unique_id: Optional[str]):


class FiltersSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

test_results: List[FilterSchema] = list()
test_runs: List[FilterSchema] = list()
model_runs: List[FilterSchema] = list()
4 changes: 2 additions & 2 deletions elementary/monitor/api/groups/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class GroupItemSchema(BaseModel):
node_id: Optional[str] = None
resource_type: Optional[str] = None
node_id: Optional[str]
resource_type: Optional[str]


DbtGroupSchema = Dict[str, dict]
Expand Down
9 changes: 5 additions & 4 deletions elementary/monitor/api/lineage/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Literal, Optional, Tuple
from typing import List, Optional, Tuple

import networkx as nx
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, validator
from pydantic.typing import Literal

NodeUniqueIdType = str
NodeType = Literal["model", "source", "exposure"]
Expand All @@ -16,11 +17,11 @@ class LineageSchema(BaseModel):
nodes: Optional[List[LineageNodeSchema]] = None
edges: Optional[List[Tuple[NodeUniqueIdType, NodeUniqueIdType]]] = None

@field_validator("nodes", mode="before")
@validator("nodes", pre=True, always=True)
def set_nodes(cls, nodes):
return nodes or []

@field_validator("edges", mode="before")
@validator("edges", pre=True, always=True)
def set_edges(cls, edges):
return edges or []

Expand Down
3 changes: 2 additions & 1 deletion elementary/monitor/api/models/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import statistics
from collections import defaultdict
Expand Down Expand Up @@ -202,7 +203,7 @@ def _normalize_dbt_artifact_dict(
SourceSchema: NormalizedSourceSchema,
}
artifact_name = artifact.name
normalized_artifact = artifact.model_dump()
normalized_artifact = json.loads(artifact.json())
normalized_artifact["model_name"] = artifact_name
normalized_artifact["normalized_full_path"] = self._normalize_artifact_path(
artifact
Expand Down
24 changes: 11 additions & 13 deletions elementary/monitor/api/models/schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import posixpath
from typing import Dict, List, Literal, Optional
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, validator

from elementary.monitor.api.totals_schema import TotalsSchema
from elementary.monitor.fetchers.models.schema import (
Expand All @@ -15,8 +15,6 @@


class NormalizedArtifactSchema(ExtendedBaseModel):
model_config = ConfigDict(protected_namespaces=())

owners: Optional[List[str]] = []
tags: Optional[List[str]] = []
# Should be changed to artifact_name.
Expand All @@ -25,32 +23,32 @@ class NormalizedArtifactSchema(ExtendedBaseModel):
normalized_full_path: str
fqn: str

@field_validator("tags", mode="before")
@validator("tags", pre=True)
def load_tags(cls, tags):
return cls._load_var_to_list(tags)

@field_validator("owners", mode="before")
@validator("owners", pre=True)
def load_owners(cls, owners):
return cls._load_var_to_list(owners)

@field_validator("normalized_full_path", mode="before")
@validator("normalized_full_path", pre=True)
def format_normalized_full_path_sep(cls, normalized_full_path: str) -> str:
return posixpath.sep.join(normalized_full_path.split(os.path.sep))


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedModelSchema(NormalizedArtifactSchema, ModelSchema):
artifact_type: Literal["model"] = "model"
artifact_type: str = Field("model", const=True)


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedSourceSchema(NormalizedArtifactSchema, SourceSchema):
artifact_type: Literal["source"] = "source"
artifact_type: str = Field("source", const=True)


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedExposureSchema(NormalizedArtifactSchema, ExposureSchema):
artifact_type: Literal["exposure"] = "exposure"
artifact_type: str = Field("exposure", const=True)


class ModelCoverageSchema(BaseModel):
Expand All @@ -62,11 +60,11 @@ class ModelRunSchema(BaseModel):
id: str
time_utc: str
status: str
full_refresh: Optional[bool] = None
materialization: Optional[str] = None
full_refresh: Optional[bool]
materialization: Optional[str]
execution_time: float

@field_validator("time_utc", mode="before")
@validator("time_utc", pre=True)
def format_time_utc(cls, time_utc):
return convert_partial_iso_format_to_full_iso_format(time_utc)

Expand Down
18 changes: 9 additions & 9 deletions elementary/monitor/api/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def get_report_data(
test_results.totals, test_runs.totals, models, sources, models_runs.runs
)

serializable_groups = groups.model_dump()
serializable_groups = groups.dict()
serializable_models = self._serialize_models(models, sources, exposures)
serializable_model_runs = self._serialize_models_runs(models_runs.runs)
serializable_model_runs_totals = models_runs.model_dump(include={"totals"})[
serializable_model_runs_totals = models_runs.dict(include={"totals"})[
"totals"
]
serializable_models_coverages = self._serialize_coverages(coverages)
Expand All @@ -86,9 +86,9 @@ def get_report_data(
)
serializable_test_runs = self._serialize_test_runs(test_runs.runs)
serializable_test_runs_totals = self._serialize_totals(test_runs.totals)
serializable_invocation = test_results.invocation.model_dump()
serializable_filters = filters.model_dump()
serializable_lineage = lineage.model_dump()
serializable_invocation = test_results.invocation.dict()
serializable_filters = filters.dict()
serializable_lineage = lineage.dict()

models_latest_invocation = invocations_api.get_models_latest_invocation()
invocations = invocations_api.get_models_latest_invocations_data()
Expand Down Expand Up @@ -143,15 +143,15 @@ def _serialize_coverages(
return {model_id: dict(coverage) for model_id, coverage in coverages.items()}

def _serialize_models_runs(self, models_runs: List[ModelRunsSchema]) -> List[dict]:
return [model_runs.model_dump(by_alias=True) for model_runs in models_runs]
return [model_runs.dict(by_alias=True) for model_runs in models_runs]

def _serialize_test_results(
self, test_results: Dict[Optional[str], List[TestResultSchema]]
) -> Dict[Optional[str], List[dict]]:
serializable_test_results = defaultdict(list)
for model_unique_id, test_result in test_results.items():
serializable_test_results[model_unique_id].extend(
[result.model_dump() for result in test_result]
[result.dict() for result in test_result]
)
return serializable_test_results

Expand All @@ -161,7 +161,7 @@ def _serialize_test_runs(
serializable_test_runs = defaultdict(list)
for model_unique_id, test_run in test_runs.items():
serializable_test_runs[model_unique_id].extend(
[run.model_dump() for run in test_run]
[run.dict() for run in test_run]
)
return serializable_test_runs

Expand All @@ -170,5 +170,5 @@ def _serialize_totals(
) -> Dict[Optional[str], dict]:
serialized_totals = dict()
for model_unique_id, total in totals.items():
serialized_totals[model_unique_id] = total.model_dump()
serialized_totals[model_unique_id] = total.dict()
return serialized_totals
4 changes: 1 addition & 3 deletions elementary/monitor/api/report/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel


class ReportDataEnvSchema(BaseModel):
Expand All @@ -10,8 +10,6 @@ class ReportDataEnvSchema(BaseModel):


class ReportDataSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

creation_time: Optional[str] = None
days_back: Optional[int] = None
models: dict = dict()
Expand Down
2 changes: 0 additions & 2 deletions elementary/monitor/api/test_management/test_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@


class TestManagementAPI(APIClient):
__test__ = False

def __init__(
self,
dbt_runner: BaseDbtRunner,
Expand Down
16 changes: 10 additions & 6 deletions elementary/monitor/api/tests/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, validator

from elementary.monitor.api.totals_schema import TotalsSchema
from elementary.monitor.fetchers.invocations.schema import DbtInvocationSchema
Expand All @@ -12,6 +12,9 @@ class ElementaryTestResultSchema(BaseModel):
metrics: Optional[Union[list, dict]] = None
result_description: Optional[str] = None

class Config:
smart_union = True


class DbtTestResultSchema(BaseModel):
display_name: Optional[str] = None
Expand All @@ -21,12 +24,12 @@ class DbtTestResultSchema(BaseModel):


class InvocationSchema(BaseModel):
affected_rows: Optional[int] = None
affected_rows: Optional[int]
time_utc: str
id: str
status: str

@field_validator("time_utc", mode="before")
@validator("time_utc", pre=True)
def format_time_utc(cls, time_utc):
return convert_partial_iso_format_to_full_iso_format(time_utc)

Expand All @@ -39,8 +42,6 @@ class InvocationsSchema(BaseModel):


class TestMetadataSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

test_unique_id: str
elementary_unique_id: str
database_name: Optional[str] = None
Expand Down Expand Up @@ -68,6 +69,9 @@ class TestResultSchema(BaseModel):
metadata: TestMetadataSchema
test_results: Union[DbtTestResultSchema, ElementaryTestResultSchema]

class Config:
smart_union = True


class TestResultsWithTotalsSchema(BaseModel):
results: Dict[Optional[str], List[TestResultSchema]] = dict()
Expand All @@ -77,7 +81,7 @@ class TestResultsWithTotalsSchema(BaseModel):

class TestRunSchema(BaseModel):
metadata: TestMetadataSchema
test_runs: Optional[InvocationsSchema] = None
test_runs: Optional[InvocationsSchema]


class TestRunsWithTotalsSchema(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion elementary/monitor/data_monitoring/data_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
tracking.register_group(
"warehouse",
self.warehouse_info.id,
self.warehouse_info.model_dump(),
self.warehouse_info.dict(),
)
tracking.set_env("target_name", latest_invocation.get("target_name"))
tracking.set_env("dbt_orchestrator", latest_invocation.get("orchestrator"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_report_data(
)
self.success = False

report_data_dict = report_data.model_dump()
report_data_dict = report_data.dict()
return report_data_dict

def _add_report_tracking(
Expand Down
4 changes: 2 additions & 2 deletions elementary/monitor/data_monitoring/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, validator

from elementary.monitor.alerts.model import ModelAlert
from elementary.monitor.alerts.source_freshness import SourceFreshnessAlert
Expand Down Expand Up @@ -62,7 +62,7 @@ class SelectorFilterSchema(BaseModel):
resource_types: Optional[List[ResourceType]] = None
node_names: Optional[List[str]] = None

@field_validator("invocation_time", mode="before")
@validator("invocation_time", pre=True)
def format_invocation_time(cls, invocation_time):
if invocation_time:
try:
Expand Down
6 changes: 3 additions & 3 deletions elementary/monitor/fetchers/invocations/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, validator

from elementary.utils.json_utils import try_load_json
from elementary.utils.time import convert_partial_iso_format_to_full_iso_format
Expand All @@ -17,11 +17,11 @@ class DbtInvocationSchema(BaseModel):
job_id: Optional[str] = None
orchestrator: Optional[str] = None

@field_validator("detected_at", mode="before")
@validator("detected_at", pre=True)
def format_detected_at(cls, detected_at):
return convert_partial_iso_format_to_full_iso_format(detected_at)

@field_validator("selected", mode="before")
@validator("selected", pre=True)
def format_selected(cls, selected):
selected_list = try_load_json(selected) or []
return " ".join(selected_list)
Loading

0 comments on commit fb01967

Please sign in to comment.