feat: update multitable interface & datasources information (#136)
* feat: update multitable interface

* fix(linting): code formatting

* fix: remove typeguard Datasource type validation

This does not work for multitable datasets. Needs to be revisited later on.

* fix(linting): code formatting

* chore: update links

* chore: fix linter messages

* chore: fix linting error.

* fix(linting): code formatting

---------

Co-authored-by: Azory YData Bot <[email protected]>
fabclmnt and azory-ydata authored Jan 29, 2025
1 parent 7e1363c commit 7b22974
Showing 7 changed files with 158 additions and 44 deletions.
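
For context, a minimal usage sketch of the multitable flow this commit updates. The UIDs are placeholders, and it assumes a multitable datasource and a write connector already exist in the Fabric project:

```python
# Minimal sketch of the updated multitable flow (placeholder UIDs; assumes an
# existing multitable DataSource and a write Connector in the Fabric project).
from ydata.sdk.datasources import DataSource
from ydata.sdk.synthesizers import MultiTableSynthesizer

X = DataSource.get(uid="<datasource-uid>")                 # multitable datasource
synth = MultiTableSynthesizer(write_connector="<connector-uid>")
synth.fit(X)                                               # datatype is forced to MULTITABLE internally
synth.sample(frac=1.0)                                     # samples are written through the connector
```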
2 changes: 2 additions & 0 deletions src/ydata/sdk/datasources/_models/datasource.py
@@ -16,6 +16,8 @@ class DataSource:
datatype: Optional[DataSourceType] = None
metadata: Optional[Metadata] = None
status: Optional[Status] = None
connector_ref: Optional[str] = None
connector_type: Optional[str] = None

def __post_init__(self):
if self.metadata is not None:
1 change: 1 addition & 0 deletions src/ydata/sdk/datasources/_models/datasources/mysql.py
@@ -7,6 +7,7 @@
class MySQLDataSource(DataSource):

query: str = None
tables: dict = None

def to_payload(self):
self.dict()
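
A minimal sketch of what the new `tables` field can carry for a multitable dataset. Constructing the model directly as a plain dataclass, the table names, and the per-table dict layout are all assumptions for illustration:

```python
# Hypothetical construction of the model with the new `tables` field
# (table names and the per-table dict layout are assumptions).
from ydata.sdk.datasources._models.datasources.mysql import MySQLDataSource

ds = MySQLDataSource(
    query=None,                               # query is not used for multitable datasets
    tables={"customers": {}, "orders": {}},   # one entry per table in the source schema
)
```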
5 changes: 3 additions & 2 deletions src/ydata/sdk/datasources/datasource.py
@@ -127,8 +127,7 @@ def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] =
data: list = response.json()
datasource_type = CONNECTOR_TO_DATASOURCE.get(
ConnectorType(data['connector']['type']))
model = DataSource._model_from_api(data, datasource_type)
datasource = DataSource._init_from_model_data(model)
datasource = DataSource._model_from_api(data, datasource_type)
datasource._project = project
return datasource

@@ -211,6 +210,8 @@ def _wait_for_metadata(datasource):
@staticmethod
def _model_from_api(data: Dict, datasource_type: Type[mDataSource]) -> mDataSource:
data['datatype'] = data.pop('dataType', None)
data['connector_ref'] = data['connector']['uid']
data['connector_type'] = data['connector']['type']
data = filter_dict(datasource_type, data)
model = datasource_type(**data)
return model
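
For illustration, a sketch of the API payload shape this mapping consumes. Only the keys referenced in the code are shown; all values are placeholders:

```python
# Placeholder payload; only the keys used by _model_from_api are shown.
data = {
    "dataType": "multiTable",                          # popped into model.datatype
    "connector": {"uid": "c-1234", "type": "mysql"},   # copied to connector_ref / connector_type
    "status": "available",
}
```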
18 changes: 4 additions & 14 deletions src/ydata/sdk/synthesizers/multitable.py
@@ -1,15 +1,13 @@
from __future__ import annotations

from time import sleep
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union

from ydata.datascience.common import PrivacyLevel
from ydata.sdk.common.client import Client
from ydata.sdk.common.config import BACKOFF
from ydata.sdk.common.exceptions import ConnectorError, InputError
from ydata.sdk.common.types import UID, Project
from ydata.sdk.connectors.connector import Connector, ConnectorType
from ydata.sdk.datasources import DataSource
from ydata.sdk.datasources._models.datatype import DataSourceType
from ydata.sdk.datasources._models.metadata.data_types import DataType
from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer
@@ -43,17 +41,10 @@ def __init__(
connector = self._check_or_fetch_connector(write_connector)
self.__write_connector = connector.uid

def fit(self, X: DataSource,
privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
def fit(self, X,
datatype: Optional[Union[DataSourceType, str]] = None,
sortbykey: Optional[Union[str, List[str]]] = None,
entities: Optional[Union[str, List[str]]] = None,
generate_cols: Optional[List[str]] = None,
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
anonymize: Optional[dict] = None) -> None:
"""Fit the synthesizer.
The synthesizer accepts as training dataset a YData [`DataSource`][ydata.sdk.datasources.DataSource].
@@ -62,8 +53,7 @@ def fit(self, X: DataSource,
Arguments:
X (DataSource): DataSource to Train
"""

self._fit_from_datasource(X, datatype=DataSourceType.MULTITABLE)
super().fit(X, datatype=DataSourceType.MULTITABLE)

def sample(self, frac: Union[int, float] = 1, write_connector: Optional[Union[Connector, UID]] = None) -> None:
"""Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer]
3 changes: 1 addition & 2 deletions src/ydata/sdk/synthesizers/regular.py
@@ -4,7 +4,6 @@

from ydata.datascience.common import PrivacyLevel
from ydata.sdk.common.exceptions import InputError
from ydata.sdk.datasources import DataSource
from ydata.sdk.datasources._models.datatype import DataSourceType
from ydata.sdk.datasources._models.metadata.data_types import DataType
from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer
@@ -33,7 +32,7 @@ def sample(self, n_samples: int = 1, condition_on: Optional[dict] = None) -> pdD
}
return self._sample(payload=payload)

def fit(self, X: Union[DataSource, pdDataFrame],
def fit(self, X,
privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
entities: Optional[Union[str, List[str]]] = None,
generate_cols: Optional[List[str]] = None,
68 changes: 42 additions & 26 deletions src/ydata/sdk/synthesizers/synthesizer.py
@@ -17,7 +17,6 @@
from ydata.sdk.common.logger import create_logger
from ydata.sdk.common.types import UID, Project
from ydata.sdk.connectors import LocalConnector
from ydata.sdk.datasources import DataSource, LocalDataSource
from ydata.sdk.datasources._models.attributes import DataSourceAttrs
from ydata.sdk.datasources._models.datatype import DataSourceType
from ydata.sdk.datasources._models.metadata.data_types import DataType
@@ -27,8 +26,11 @@
from ydata.sdk.synthesizers._models.synthesizer import Synthesizer as mSynthesizer
from ydata.sdk.synthesizers._models.synthesizers_list import SynthesizersList
from ydata.sdk.synthesizers.anonymizer import build_and_validate_anonimization
from ydata.sdk.utils.logger import SDKLogger
from ydata.sdk.utils.model_mixin import ModelFactoryMixin

logger = SDKLogger(name="SynthesizersLogger")


@typechecked
class BaseSynthesizer(ABC, ModelFactoryMixin):
@@ -65,7 +67,7 @@ def _init_common(self, client: Optional[Client] = None):
def project(self) -> Project:
return self._project or self._client.project

def fit(self, X: Union[DataSource, pdDataFrame],
def fit(self, X,
privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
datatype: Optional[Union[DataSourceType, str]] = None,
sortbykey: Optional[Union[str, List[str]]] = None,
@@ -100,17 +102,24 @@ def fit(self, X: Union[DataSource, pdDataFrame],
anonymize (Optional[str]): (optional) fields to anonymize and the anonymization strategy
condition_on: (Optional[List[str]]): (optional) list of features to condition upon
"""

logger.info(dataframe=X,
datatype=datatype.value,
method='synthesizer')

if self._already_fitted():
raise AlreadyFittedError()

datatype = DataSourceType(datatype)

dataset_attrs = self._init_datasource_attributes(
sortbykey, entities, generate_cols, exclude_cols, dtypes)

self._validate_datasource_attributes(X, dataset_attrs, datatype, target)

# If the training data is a pandas dataframe, we first need to create a data source and then the instance
if isinstance(X, pdDataFrame):
from ydata.sdk.datasources import LocalDataSource
if X.empty:
raise EmptyDataError("The DataFrame is empty")
self._logger.info('creating local connector with pandas dataframe')
@@ -131,9 +140,12 @@ def fit(self, X: Union[DataSource, pdDataFrame],
if isinstance(dataset_attrs, dict):
dataset_attrs = DataSourceAttrs(**dataset_attrs)

self._fit_from_datasource(
X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target,
anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)
if datatype == DataSourceType.MULTITABLE:
self._fit_from_datasource(_X, datatype=DataSourceType.MULTITABLE)
else:
self._fit_from_datasource(
X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target,
anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)

@staticmethod
def _init_datasource_attributes(
@@ -152,37 +164,41 @@ def _init_datasource_attributes(
return DataSourceAttrs(**dataset_attrs)

@staticmethod
def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_attrs: DataSourceAttrs, datatype: DataSourceType, target: Optional[str]):
def _validate_datasource_attributes(X, dataset_attrs: DataSourceAttrs, datatype: DataSourceType, target: Optional[str]):
columns = []
if isinstance(X, pdDataFrame):
columns = X.columns
if datatype is None:
raise DataTypeMissingError(
"Argument `datatype` is mandatory for pandas.DataFrame training data")
elif datatype == DataSourceType.MULTITABLE:
tables = [t for t in X.tables.keys()] # noqa: F841
# Does it make sense to add more validations here?
else:
columns = [c.name for c in X.metadata.columns]

if target is not None and target not in columns:
raise DataSourceAttrsError(
"Invalid target: column '{target}' does not exist")

if datatype == DataSourceType.TIMESERIES:
if not dataset_attrs.sortbykey:
if datatype != DataSourceType.MULTITABLE:
if target is not None and target not in columns:
raise DataSourceAttrsError(
"The argument `sortbykey` is mandatory for timeseries datasource.")

invalid_fields = {}
for field, v in dataset_attrs.dict().items():
field_columns = v if field != 'dtypes' else v.keys()
not_in_cols = [c for c in field_columns if c not in columns]
if len(not_in_cols) > 0:
invalid_fields[field] = not_in_cols
"Invalid target: column '{target}' does not exist")

if len(invalid_fields) > 0:
error_msgs = ["\t- Field '{}': columns {} do not exist".format(
f, ', '.join(v)) for f, v in invalid_fields.items()]
raise DataSourceAttrsError(
"The dataset attributes are invalid:\n {}".format('\n'.join(error_msgs)))
if datatype == DataSourceType.TIMESERIES:
if not dataset_attrs.sortbykey:
raise DataSourceAttrsError(
"The argument `sortbykey` is mandatory for timeseries datasource.")

invalid_fields = {}
for field, v in dataset_attrs.dict().items():
field_columns = v if field != 'dtypes' else v.keys()
not_in_cols = [c for c in field_columns if c not in columns]
if len(not_in_cols) > 0:
invalid_fields[field] = not_in_cols

if len(invalid_fields) > 0:
error_msgs = ["\t- Field '{}': columns {} do not exist".format(
f, ', '.join(v)) for f, v in invalid_fields.items()]
raise DataSourceAttrsError(
"The dataset attributes are invalid:\n {}".format('\n'.join(error_msgs)))

@staticmethod
def _metadata_to_payload(
@@ -225,7 +241,7 @@ def _metadata_to_payload(

def _fit_from_datasource(
self,
X: DataSource,
X,
datatype: DataSourceType,
privacy_level: Optional[PrivacyLevel] = None,
dataset_attrs: Optional[DataSourceAttrs] = None,
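
With this change, the base `fit` routes multitable datasources straight to `_fit_from_datasource` and skips the tabular attribute checks. Below is a simplified, self-contained restatement of the reworked validation rules; the datatype strings and exception types are illustrative stand-ins, not the SDK's own:

```python
from typing import Optional, Sequence

def sketch_validate(columns: Sequence[str], datatype: str,
                    target: Optional[str] = None,
                    sortbykey: Optional[str] = None) -> None:
    """Illustrative restatement of _validate_datasource_attributes after this commit."""
    if datatype == "multiTable":
        return                                   # multitable inputs skip all column checks
    if target is not None and target not in columns:
        raise ValueError(f"Invalid target: column '{target}' does not exist")
    if datatype == "timeSeries" and not sortbykey:
        raise ValueError("The argument `sortbykey` is mandatory for timeseries datasource.")

sketch_validate(["date", "value"], "timeSeries", sortbykey="date")   # passes
sketch_validate(["a", "b"], "multiTable")                            # passes: no checks for multitable
```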
105 changes: 105 additions & 0 deletions src/ydata/sdk/utils/logger.py
@@ -0,0 +1,105 @@
"""
This file contains the logic for the logger and the decorator function.
"""
import contextlib
import logging
import os
import platform
import subprocess

import pandas as pd
import requests

from ydata.sdk import __version__
from ydata.sdk.datasources._models.datatype import DataSourceType


def is_running_in_databricks():
mask = "DATABRICKS_RUNTIME_VERSION" in os.environ
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
return os.environ["DATABRICKS_RUNTIME_VERSION"]
else:
return str(mask)


def get_datasource_info(dataframe, datatype):
"""
Compute the datasource info required for analytics: connector, nrows, ncols, ntables.
"""
if isinstance(dataframe, pd.DataFrame):
connector = 'csv'
nrows, ncols = dataframe.shape[0], dataframe.shape[1]
ntables = None  # not applicable for a single DataFrame
else:
connector = dataframe.connector_type
if DataSourceType(datatype) != DataSourceType.MULTITABLE:
nrows = dataframe.metadata.number_of_rows
ncols = len(dataframe.metadata.columns)
ntables = 1
else:
nrows = 0
ncols = 0
ntables = len(dataframe.tables.keys())
return connector, nrows, ncols, ntables


def analytics_features(datatype: str, connector: str, nrows: int, ncols: int, ntables: int, method: str, dbx: str) -> None:
"""
Send usage metrics and analytics for ydata-fabric-sdk.
"""
endpoint = "https://packages.ydata.ai/ydata-fabric-sdk?"
package_version = __version__

if (
bool(os.getenv("YDATA_FABRIC_SDK_NO_ANALYTICS")
) is not True and package_version != "0.0.dev0"
):
try:
subprocess.check_output("nvidia-smi")
gpu_present = True
except Exception:
gpu_present = False

python_version = ".".join(platform.python_version().split(".")[:2])

with contextlib.suppress(Exception):
request_message = (
f"{endpoint}python_version={python_version}"
f"&datatype={datatype}"
f"&connector={connector}"
f"&ncols={ncols}"
f"&nrows={nrows}"
f"&ntables={ntables}"
f"&method={method}"
f"&os={platform.system()}"
f"&gpu={str(gpu_present)}"
f"&dbx={dbx}"
)

requests.get(request_message)


class SDKLogger(logging.Logger):
def __init__(self, name: str, level: int = logging.INFO):
super().__init__(name, level)

def info(self, dataframe, datatype: str, method: str) -> None: # noqa: ANN001

dbx = is_running_in_databricks()

connector, nrows, ncols, ntables = get_datasource_info(dataframe, datatype)

analytics_features(
datatype=datatype,
method=method,
connector=connector,
nrows=nrows,
ncols=ncols,
ntables=ntables,
dbx=dbx
)

super().info(
f"[PROFILING] Calculating profile with the following characteristics "
f"- {datatype} | {method} | {connector}."
)
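
The logger is instantiated once at import time in synthesizer.py and called at the start of `fit`. A short usage sketch; the `datatype` literal below is an assumption (synthesizer.py passes `DataSourceType(...).value`), and the environment variable is the opt-out checked in `analytics_features`:

```python
import os
# Opt out of the usage ping before calling the logger (env var checked in analytics_features).
os.environ["YDATA_FABRIC_SDK_NO_ANALYTICS"] = "1"

import pandas as pd
from ydata.sdk.utils.logger import SDKLogger

logger = SDKLogger(name="SynthesizersLogger")
df = pd.DataFrame({"age": [25, 31, 47], "income": [40_000, 52_000, 61_000]})
# For a DataFrame input, get_datasource_info reports connector='csv' and the frame's shape.
logger.info(dataframe=df, datatype="tabular", method="synthesizer")
```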
