feat: [#19] schema validation #33

Open · wants to merge 2 commits into base: devlopment
2 changes: 2 additions & 0 deletions Dockerfile
@@ -13,5 +13,7 @@ ENV PYTHONDONTWRITEBYTECODE 1
# Don't buffer stdout and stderr
ENV PYTHONUNBUFFERED 1
COPY . .
COPY Pipfile Pipfile.lock /app/
COPY .env /app/.env
RUN pip install --upgrade pip pipenv
RUN pipenv install --deploy --system
1 change: 1 addition & 0 deletions Pipfile
@@ -27,6 +27,7 @@ sphinx-autodoc-typehints = "*"
sphinx-rtd-theme = "*"
autodoc-pydantic = "*"
sentry-sdk = {version = "*", extras = ["fastapi"]}
email-validator = "*"

[dev-packages]
pytest = "*"
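Context for the new dependency: email-validator is the optional package Pydantic uses for its EmailStr type, so adding it suggests the new schemas validate e-mail fields. A minimal, self-contained sketch; the Contact model and its field are illustrative, not taken from this PR:

from pydantic import BaseModel, EmailStr, ValidationError

class Contact(BaseModel):
    # EmailStr delegates syntax checking to the email-validator package
    email: EmailStr

try:
    Contact(email="not-an-email")
except ValidationError as err:
    # Same kind of error payload the validators below collect via e.json()
    print(err.json())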
2,768 changes: 1,424 additions & 1,344 deletions Pipfile.lock

Large diffs are not rendered by default.

157 changes: 60 additions & 97 deletions app/services/solr/validate/schema/validate.py
@@ -1,125 +1,88 @@
# pylint: disable=line-too-long, invalid-name, logging-fstring-interpolation
"""Validate transformation"""
import logging
from itertools import zip_longest
from typing import KeysView, Literal
from typing import Literal

from pyspark.sql import DataFrame
import pandas as pd
from pydantic import BaseModel, ValidationError
from pyspark.sql import DataFrame as SparkDataFrame

logger = logging.getLogger(__name__)


def validate_schema(
df: DataFrame,
expected_schema: dict[str, list[str] | str],
df: SparkDataFrame | pd.DataFrame,
pydantic_model: BaseModel,
collection: str,
source: Literal["input", "output"],
) -> None:
"""Check whether pyspark dataframe data schema is the same as expected"""
# Assumption: schemas are sorted alphabetically
df = df.select(*sorted(df.columns)) # Sort pyspark dataframe
expected_schema = sort_dict_schemas(expected_schema) # Sort expected schema

validate_column_names(df.columns, expected_schema.keys(), collection, source)
validate_column_types(df, expected_schema, collection, source)
"""Validate DataFrame against the Pydantic model schema."""
if isinstance(df, SparkDataFrame):
validate_spark_schema(df, pydantic_model, collection, source)
elif isinstance(df, pd.DataFrame):
validate_pandas_schema(df, pydantic_model, collection, source)
else:
logger.warning("Unsupported DataFrame type: %s", type(df))


def validate_column_names(
actual_columns: list[str],
expected_columns: KeysView[str],
collection: str,
source: Literal["input", "output"],
) -> None:
"""Validate that column names match"""
cols_diff = set(actual_columns) ^ set(expected_columns)
if cols_diff:
logger.warning(
f"{collection} - {source} schema validation failure. Column names mismatch. Difference: {cols_diff}"
)
def validate_partition(partition, pydantic_model: BaseModel):
"""Validate a partition of data against the Pydantic model."""
errors = set()
for row in partition:
try:
pydantic_model.parse_obj(row.asDict())
except ValidationError as e:
error_json = e.json()
errors.add(error_json)
return list(errors)


def validate_column_types(
df: DataFrame,
expected_schema: dict[str, list[str] | str],
def validate_spark_schema(
df: SparkDataFrame,
pydantic_model: BaseModel,
collection: str,
source: Literal["input", "output"],
) -> None:
"""Validate that column types match"""
actual_types = [column.dataType.simpleString() for column in df.schema.fields]
differences = get_schema_differences(df.columns, actual_types, expected_schema)
if differences:
"""Validate Spark DataFrame against Pydantic model schema."""
errors = df.rdd.mapPartitions(
lambda partition: validate_partition(partition, pydantic_model)
).collect()

# mapPartitions already yields individual error strings, so collect()
# returns a flat list of JSON strings; deduplicate it directly.
unique_errors = list(set(errors))

if unique_errors:
logger.warning(
f"{collection} - {source} schema validation failure. Column types mismatch. Differences: {differences}"
"%s - %s schema validation failure. Distinct validation errors: %s",
collection,
source,
unique_errors,
)


def validate_pd_schema(
df: DataFrame,
expected_schema: dict[str, list[str] | str],
collection: str,
source: Literal["input", "output"],
) -> None:
"""Validate Pandas schema"""
# Assumption: schemas are sorted alphabetically
df = df.reindex(sorted(df.columns), axis=1) # Sort pandas df
expected_schema = sort_dict_schemas(expected_schema) # Sort expected schema

validate_column_names(df.columns, expected_schema.keys(), collection, source)
validate_pd_column_types(df, expected_schema, collection, source)


def validate_pd_column_types(
df: DataFrame,
expected_schema: dict[str, list[str] | str],
def validate_pandas_schema(
df: pd.DataFrame,
pydantic_model: BaseModel,
collection: str,
source: Literal["input", "output"],
) -> None:
"""Validate Pandas column types"""
actual_schema = get_pd_df_schema(df)
differences = get_schema_differences(
list(actual_schema.keys()), list(actual_schema.values()), expected_schema
)
if differences:
"""Validate Pandas DataFrame against Pydantic model schema."""
errors = set()
for _, row in df.iterrows():
try:
pydantic_model.parse_obj(row.to_dict())  # pandas Series exposes to_dict(), not asDict()
except ValidationError as e:
errors.add(e.json())

if errors:
logger.warning(
f"{collection} - {source} schema validation failure. Column types mismatch. Differences: {differences}"
"%s - %s schema validation failure. Validation errors: %s",
collection,
source,
errors,
)


def get_pd_df_schema(df: DataFrame) -> dict:
"""Get Pandas data schema"""

def get_column_type(column):
for val in df[column]:
if val is not None:
return type(val).__name__
return "NoneType" # If all values are None

return {col: get_column_type(col) for col in df.columns}


def is_type_match(actual_type: str, expected_type: list[str] | str) -> bool:
"""Check if the actual type matches the expected type(s)"""
if isinstance(expected_type, list):
return actual_type in expected_type or actual_type == "void"
else:
return actual_type == expected_type or actual_type == "void"


def get_schema_differences(
columns: list[str],
actual_schema: list[str],
expected_schema: dict[str, list[str] | str],
) -> dict[str, dict[str, list[str] | str]]:
"""Print differences between schemas types"""
differences = {}

for col, actual_sch in zip_longest(columns, actual_schema):
expected_sch = expected_schema[col]
if not is_type_match(actual_sch, expected_sch):
differences[col] = {"ACTUAL": actual_sch, "EXPECTED": expected_sch}
return differences


def sort_dict_schemas(expected_schema: dict) -> dict:
"""Sort expected dict schema"""
return {k: expected_schema[k] for k in sorted(expected_schema)}
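For reference, a minimal usage sketch of the new validate_schema entry point, assuming a Pydantic model class and either a Spark or pandas DataFrame. RecordSchema and its fields are hypothetical, and the import path simply mirrors the file location above:

import pandas as pd
from pydantic import BaseModel

from app.services.solr.validate.schema.validate import validate_schema

class RecordSchema(BaseModel):
    # Hypothetical stand-in for one of the real collection schemas
    id: str
    title: str

df = pd.DataFrame([{"id": "1", "title": "ok"}, {"id": None, "title": "missing id"}])

# Invalid rows are only logged as warnings; nothing is raised and no rows are dropped.
validate_schema(df, RecordSchema, collection="dataset", source="input")

A Spark DataFrame goes through the same call, with per-partition validation distributed via mapPartitions.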
60 changes: 30 additions & 30 deletions app/settings.py
@@ -8,8 +8,8 @@
from pydantic import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

from schemas.old.input import *
from schemas.old.output import *
from schemas.db import *
from schemas.input import *

logger = logging.getLogger(__name__)
EnvironmentType = Literal["dev", "test", "production"]
@@ -210,73 +210,73 @@ def get_collections_config(self) -> dict:
collections = {
self.SOFTWARE: {
PATH: self.SOFTWARE_PATH,
OUTPUT_SCHEMA: software_output_schema,
INPUT_SCHEMA: software_input_schema,
OUTPUT_SCHEMA: SoftwareDBSchema,
INPUT_SCHEMA: SoftwareInputSchema,
},
self.OTHER_RP: {
PATH: self.OTHER_RP_PATH,
OUTPUT_SCHEMA: other_rp_output_schema,
INPUT_SCHEMA: other_rp_input_schema,
OUTPUT_SCHEMA: OtherRPDBSchema,
INPUT_SCHEMA: OtherRPInputSchema,
},
self.DATASET: {
PATH: self.DATASET_PATH,
OUTPUT_SCHEMA: dataset_output_schema,
INPUT_SCHEMA: dataset_input_schema,
OUTPUT_SCHEMA: DatasetDBSchema,
INPUT_SCHEMA: DatasetInputSchema,
},
self.PUBLICATION: {
PATH: self.PUBLICATION_PATH,
OUTPUT_SCHEMA: publication_output_schema,
INPUT_SCHEMA: publication_input_schema,
OUTPUT_SCHEMA: PublicationDBSchema,
INPUT_SCHEMA: PublicationInputSchema,
},
self.ORGANISATION: {
PATH: self.ORGANISATION_PATH,
OUTPUT_SCHEMA: organisation_output_schema,
INPUT_SCHEMA: organisation_input_schema,
OUTPUT_SCHEMA: OrganisationDBSchema,
INPUT_SCHEMA: OrganisationInputSchema,
},
self.PROJECT: {
PATH: self.PROJECT_PATH,
OUTPUT_SCHEMA: project_output_schema,
INPUT_SCHEMA: project_input_schema,
OUTPUT_SCHEMA: ProjectDBSchema,
INPUT_SCHEMA: ProjectInputSchema,
},
self.SERVICE: {
ADDRESS: mp_api + "services",
OUTPUT_SCHEMA: service_output_schema,
INPUT_SCHEMA: service_input_schema,
OUTPUT_SCHEMA: ServiceDBSchema,
INPUT_SCHEMA: ServiceInputSchema,
},
self.DATASOURCE: {
ADDRESS: mp_api + "datasources",
OUTPUT_SCHEMA: data_source_output_schema,
INPUT_SCHEMA: data_source_input_schema,
OUTPUT_SCHEMA: DataSourceDBSchema,
INPUT_SCHEMA: DataSourceInputSchema,
},
self.BUNDLE: {
ADDRESS: mp_api + "bundles",
OUTPUT_SCHEMA: bundle_output_schema,
INPUT_SCHEMA: bundle_input_schema,
OUTPUT_SCHEMA: BundleDBSchema,
INPUT_SCHEMA: BundleInputSchema,
},
self.GUIDELINE: {
ADDRESS: str(self.GUIDELINE_ADDRESS),
OUTPUT_SCHEMA: guideline_output_schema,
INPUT_SCHEMA: guideline_input_schema,
OUTPUT_SCHEMA: GuidelineDBSchema,
INPUT_SCHEMA: GuidelineInputSchema,
},
self.TRAINING: {
ADDRESS: str(self.TRAINING_ADDRESS),
OUTPUT_SCHEMA: training_output_schema,
INPUT_SCHEMA: training_input_schema,
OUTPUT_SCHEMA: TrainingDBSchema,
INPUT_SCHEMA: TrainingInputSchema,
},
self.PROVIDER: {
ADDRESS: mp_api + "providers",
OUTPUT_SCHEMA: provider_output_schema,
INPUT_SCHEMA: provider_input_schema,
OUTPUT_SCHEMA: ProviderDBSchema,
INPUT_SCHEMA: ProviderInputSchema,
},
self.OFFER: {
ADDRESS: mp_api + "offers",
OUTPUT_SCHEMA: offer_output_schema,
INPUT_SCHEMA: offer_input_schema,
OUTPUT_SCHEMA: OfferDBSchema,
INPUT_SCHEMA: OfferInputSchema,
},
self.CATALOGUE: {
ADDRESS: mp_api + "catalogues",
OUTPUT_SCHEMA: catalogue_output_schema,
INPUT_SCHEMA: catalogue_input_schema,
OUTPUT_SCHEMA: CatalogueDBSchema,
INPUT_SCHEMA: CatalogueInputSchema,
},
}

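The collections config now wires Pydantic model classes from schemas.db and schemas.input instead of plain dict schemas. The class definitions themselves are not part of this diff; the sketch below only illustrates the shape downstream code can now rely on. The class names appear in the diff, but every field in this sketch is hypothetical:

from pydantic import BaseModel

class DatasetInputSchema(BaseModel):
    # Illustrative fields only; the real model lives in schemas/input
    id: str
    title: str

class DatasetDBSchema(BaseModel):
    # Illustrative fields only; the real model lives in schemas/db
    id: str
    title: str
    popularity: int = 0

# get_collections_config() stores the classes themselves, so consumers can call
# DatasetInputSchema.parse_obj(...) or inspect DatasetDBSchema.model_fields.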
3 changes: 2 additions & 1 deletion app/tasks/transform/dump/transform.py
@@ -4,6 +4,7 @@
from typing import Optional

import boto3
from pydantic import BaseModel
from pyspark.sql import DataFrame, SparkSession
from tqdm import tqdm

@@ -138,7 +139,7 @@ def transform_file(
spark: SparkSession,
s3_client: boto3.client,
collection_name: str,
input_schema: dict,
input_schema: BaseModel,
error_log: dict,
) -> Optional[DataFrame]:
"""
7 changes: 5 additions & 2 deletions app/transform/transformers/base/base.py
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from logging import getLogger

from pydantic import BaseModel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, StructType
@@ -28,7 +29,7 @@ def __init__(
cols_to_add: tuple[str, ...] | None,
cols_to_drop: tuple[str, ...] | None,
cols_to_rename: dict[str, str] | None,
exp_output_schema: dict,
exp_output_schema: BaseModel,
spark: SparkSession,
):
self.type = desired_type
@@ -87,7 +88,9 @@ def filter_columns(self, df: DataFrame) -> DataFrame:
In that manner, if any column was added additionally it won't be included in output data
"""
expected_columns = [
col for col in df.columns if col in self._exp_output_schema.keys()
col
for col in df.columns
if col in self._exp_output_schema.model_fields.keys()
]
return df.select(*expected_columns)

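filter_columns now reads the allowed column set from the schema class's model_fields (a Pydantic v2 class attribute) rather than from dict keys. A standalone sketch of the same selection logic, with a hypothetical schema and column list:

from pydantic import BaseModel

class OutputSchema(BaseModel):
    # Hypothetical expected output schema
    id: str
    title: str

df_columns = ["id", "title", "internal_tmp_col"]

# Keep only the columns the schema declares, mirroring filter_columns.
expected_columns = [c for c in df_columns if c in OutputSchema.model_fields]
print(expected_columns)  # ['id', 'title']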
3 changes: 2 additions & 1 deletion app/transform/transformers/base/oag.py
@@ -2,6 +2,7 @@
"""Transform OAG resources"""
from abc import abstractmethod

from pydantic import BaseModel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import array, col, lit, year
from pyspark.sql.types import (
@@ -53,7 +54,7 @@ def __init__(
desired_type: str,
cols_to_add: tuple[str, ...] | None,
cols_to_drop: tuple[str, ...] | None,
exp_output_schema: dict,
exp_output_schema: BaseModel,
spark: SparkSession,
):
super().__init__(
4 changes: 2 additions & 2 deletions app/transform/transformers/bundle.py
@@ -14,7 +14,7 @@
from app.transform.transformers.base.base import BaseTransformer
from app.transform.utils.common import harvest_popularity
from app.transform.utils.utils import sort_schema
from schemas.old.output.bundle import bundle_output_schema
from schemas.db.bundle import BundleDBSchema
from schemas.properties.data import ID, POPULARITY, TYPE


@@ -24,7 +24,7 @@ class BundleTransformer(BaseTransformer):
def __init__(self, spark: SparkSession):
self.type = settings.BUNDLE
self.id_increment = settings.BUNDLE_IDS_INCREMENTOR
self.exp_output_schema = bundle_output_schema
self.exp_output_schema = BundleDBSchema

super().__init__(
self.type,