Skip to content

Commit

Permalink
Add common data science functions to ds_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nrccua-timr committed Nov 3, 2023
1 parent d6c254a commit 1d532c9
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default_language_version:
python: python3.11
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-ast
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[FORMAT]
max-line-length=140
max-line-length=150
15 changes: 15 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@ History
=======


v0.19.4 (2023-11-03)

* Add common data science functions to ds_utils.py.
* Update cython==3.0.5.
* Update httpx==0.25.1.
* Update pandas==2.1.2.
* Update pylint==3.0.2.
* Update pytest==7.4.3.
* Update wheel==0.41.3.
* Add python library haversine==2.8.0.
* Add python library polars==0.19.12.
* Add python library pyarrow==13.0.0.
* Add python library pyspark==3.4.1.


v0.19.3 (2023-10-20)

* Add TrustServerCertificate option in sqlserver connection string, enabling use of driver {ODBC Driver 18 for SQL Server}.
Expand Down
286 changes: 236 additions & 50 deletions aioradio/ds_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""utils.py."""

# pylint: disable=broad-except
# pylint: disable=import-outside-toplevel
# pylint: disable=invalid-name
# pylint: disable=logging-fstring-interpolation
# pylint: disable=no-member
# pylint: disable=protected-access
# pylint: disable=too-many-arguments
# pylint: disable=too-many-boolean-expressions
# pylint: disable=unnecessary-comprehension
Expand All @@ -16,12 +18,18 @@
import os
import pickle
import warnings
from math import cos, degrees, radians, sin
from platform import system
from tempfile import NamedTemporaryFile
from time import sleep, time

import boto3
import numpy as np
import pyarrow as pa
import pandas as pd
import polars as pl
from haversine import haversine, Unit
from pyspark.sql import SparkSession
from smb.SMBConnection import SMBConnection

warnings.simplefilter(action='ignore', category=UserWarning)
Expand All @@ -36,6 +44,158 @@
c_handler.setFormatter(c_format)
logger.addHandler(c_handler)

# Module-level Spark session used by the Databricks helpers in this file.
# getOrCreate() runs when the module is imported.
spark = SparkSession.builder.getOrCreate()


############################### Databricks functions ################################


def db_catalog(env):
    """Return the Databricks catalog name for the given environment.

    Args:
        env: Environment name; 'sandbox' and 'prod' are recognised.

    Returns:
        'dsc_sbx' for 'sandbox', 'dsc_prd' for 'prod', otherwise an empty
        string.
    """

    if env == 'sandbox':
        return 'dsc_sbx'
    if env == 'prod':
        return 'dsc_prd'
    return ''


def sql_to_polars_df(sql):
    """Run a SQL query on the module's spark session and return a polars DataFrame.

    The result is collected as Arrow record batches and handed to polars
    without going through pandas.

    Args:
        sql: SQL text passed to ``spark.sql``.

    Returns:
        The value of ``pl.from_arrow`` for the collected Arrow table. When the
        query yields no record batches, an empty table built from the query's
        Spark schema is converted instead.
    """

    from pyspark.sql.pandas.types import to_arrow_schema

    spark_df = spark.sql(sql)
    batches = spark_df._collect_as_arrow()

    if batches:
        arrow_table = pa.Table.from_batches(batches)
    else:
        # pa.Table.from_batches([]) raises without an explicit schema, so a
        # zero-row result needs the schema taken from the Spark DataFrame.
        arrow_table = to_arrow_schema(spark_df.schema).empty_table()

    return pl.from_arrow(arrow_table)


def does_db_table_exists(name):
    """Report whether ``describe formatted <name>`` succeeds in Databricks.

    Args:
        name: Table name interpolated into the ``describe formatted`` query.

    Returns:
        True when the query runs without raising, False when it raises any
        ``Exception`` (every failure is treated as "table does not exist").
    """

    try:
        spark.sql(f"describe formatted {name}")
    except Exception:
        return False
    return True


def merge_spark_df_in_db(df, target, on, partition_by=None):
    """Merge a spark DataFrame into a Databricks target table.

    When ``target`` does not exist yet, ``df`` is saved directly as that table.
    Otherwise ``df`` is written (overwrite) to the staging table
    ``{target}_stage`` and merged into ``target``: rows matching on every
    ``on`` column have all columns except ``CREATED_DATETIME`` updated, and
    unmatched rows are inserted. The staging table is dropped afterwards
    whether or not the MERGE succeeded.

    Args:
        df: Spark DataFrame holding the rows to merge.
        target: Name of the target table.
        on: Column name, or iterable of column names, forming the merge key.
        partition_by: Optional column(s) passed to ``partitionBy`` on write.

    Raises:
        Exception: Anything raised by the write, MERGE or DROP statements is
            propagated to the caller.
    """

    stage = f"{target}_stage"

    # A bare string would otherwise be iterated character by character when
    # the ON clause is built.
    if isinstance(on, str):
        on = [on]

    writer = df.write.option("delta.columnMapping.mode", "name")

    if not does_db_table_exists(target):
        # First load: the DataFrame simply becomes the target table.
        if partition_by is not None:
            writer = writer.partitionBy(partition_by)
        writer.saveAsTable(target)
        return

    writer = writer.mode('overwrite')
    if partition_by is not None:
        writer = writer.partitionBy(partition_by)
    writer.saveAsTable(stage)

    on_clause = ' AND '.join(f'{target}.{col} = {stage}.{col}' for col in on)
    match_clause = ', '.join(
        f'{target}.{col} = {stage}.{col}' for col in df.columns if col != 'CREATED_DATETIME'
    )
    merge_sql = (
        f'MERGE INTO {target} USING {stage} ON {on_clause} '
        f'WHEN MATCHED THEN UPDATE SET {match_clause} '
        'WHEN NOT MATCHED THEN INSERT *'
    )

    try:
        spark.sql(merge_sql)
    finally:
        # Single cleanup path: the staging table is removed exactly once, on
        # success and on failure alike.
        spark.sql(f'DROP TABLE {stage}')


def merge_pandas_df_in_db(df, target, on, partition_by=None):
    """Merge a pandas DataFrame into a Databricks target table.

    Object columns are first cast to ``string[pyarrow]`` and missing values in
    object/string columns are replaced with empty strings; these changes are
    assigned back onto the passed-in ``df``. The frame is then converted with
    ``spark.createDataFrame``. When ``target`` does not exist it is created
    from the data; otherwise the data is written (overwrite) to
    ``{target}_stage`` and merged into ``target``: matching rows have all
    columns except ``CREATED_DATETIME`` updated, and unmatched rows are
    inserted. The staging table is dropped afterwards whether or not the MERGE
    succeeded.

    Args:
        df: pandas DataFrame holding the rows to merge (modified in place).
        target: Name of the target table.
        on: Column name, or iterable of column names, forming the merge key.
        partition_by: Optional column(s) passed to ``partitionBy`` on write.

    Raises:
        Exception: Anything raised by the conversion, write, MERGE or DROP
            statements is propagated to the caller.
    """

    stage = f"{target}_stage"

    # A bare string would otherwise be iterated character by character when
    # the ON clause is built.
    if isinstance(on, str):
        on = [on]

    # Assign results back explicitly: a chained ``df[col].mask(..., inplace=True)``
    # acts on a temporary under pandas Copy-on-Write and would leave df untouched.
    for col, dtype in df.dtypes.items():
        if dtype.name == 'object':
            df[col] = df[col].astype('string[pyarrow]').fillna('')
        elif dtype.name == 'string':
            # pyspark will throw an exception if strings are set to <NA> so convert to empty string
            df[col] = df[col].fillna('')

    table_exists = does_db_table_exists(target)
    writer = spark.createDataFrame(df).write.option("delta.columnMapping.mode", "name")

    if not table_exists:
        # First load: the data simply becomes the target table.
        if partition_by is not None:
            writer = writer.partitionBy(partition_by)
        writer.saveAsTable(target)
        return

    writer = writer.mode('overwrite')
    if partition_by is not None:
        writer = writer.partitionBy(partition_by)
    writer.saveAsTable(stage)

    on_clause = ' AND '.join(f'{target}.{col} = {stage}.{col}' for col in on)
    match_clause = ', '.join(
        f'{target}.{col} = {stage}.{col}' for col in df.columns if col != 'CREATED_DATETIME'
    )
    merge_sql = (
        f'MERGE INTO {target} USING {stage} ON {on_clause} '
        f'WHEN MATCHED THEN UPDATE SET {match_clause} '
        'WHEN NOT MATCHED THEN INSERT *'
    )

    try:
        spark.sql(merge_sql)
    finally:
        # Single cleanup path: the staging table is removed exactly once, on
        # success and on failure alike.
        spark.sql(f'DROP TABLE {stage}')


################################## DataFrame functions ####################################


def convert_pyspark_dtypes_to_pandas(df):
    """Normalise the dtypes of a DataFrame produced by pyspark's ``toPandas``.

    * ``object`` columns and columns whose dtype name contains ``string`` are
      cast to ``string[pyarrow]`` and empty strings become ``pd.NA``.
    * Columns whose dtype name starts with ``int`` or ``float`` and is not
      already pyarrow-backed are cast to the matching ``<dtype>[pyarrow]``.
    * Every other column is left untouched.

    Args:
        df: pandas DataFrame to convert; its columns are reassigned in place.

    Returns:
        The same ``df`` object, after conversion.
    """

    for col, dtype in df.dtypes.apply(lambda x: x.name).to_dict().items():

        if dtype == 'object' or 'string' in dtype:
            as_text = df[col].astype('string[pyarrow]')
            # Assign the masked result back explicitly: a chained
            # ``df[col].mask(..., inplace=True)`` acts on a temporary under
            # pandas Copy-on-Write and would leave df untouched.
            df[col] = as_text.mask(as_text == '', pd.NA)
        elif dtype.startswith(('int', 'float')) and not dtype.endswith('[pyarrow]'):
            df[col] = df[col].astype(f'{dtype}[pyarrow]')

    return df


def remove_pyarrow_dtypes(df):
    """Cast every column to its dtype name with the '[pyarrow]' suffix removed.

    For example a column of dtype ``int8[pyarrow]`` is cast to ``int8``.

    Args:
        df: pandas DataFrame to convert.

    Returns:
        The DataFrame returned by ``df.astype`` with the stripped dtype names.
    """

    plain_dtypes = {}
    for column, dtype in df.dtypes.items():
        plain_dtypes[column] = dtype.name.replace('[pyarrow]', '')

    return df.astype(plain_dtypes)


################################## AWS functions ####################################


def get_boto3_session(env):
    """Build a boto3 Session for the given environment.

    Any ``AWS_PROFILE`` environment variable is removed first, then a session
    is created from the credentials returned by ``get_aws_creds(env)``. If a
    ``ValueError`` is raised along the way, ``AWS_PROFILE`` is restored to its
    previous value and a default ``boto3.Session()`` is returned instead.

    Args:
        env: Environment name forwarded to ``get_aws_creds``.

    Returns:
        The created ``boto3.Session``.
    """

    saved_profile = os.getenv('AWS_PROFILE')
    had_profile = saved_profile is not None

    try:
        if had_profile:
            del os.environ['AWS_PROFILE']
        credentials = get_aws_creds(env)
        return boto3.Session(**credentials)
    except ValueError:
        # Fall back to the default credential chain, with the profile restored.
        if had_profile:
            os.environ["AWS_PROFILE"] = saved_profile
        return boto3.Session()


def file_to_s3(s3_client, local_filepath, s3_bucket, key):
"""Write file to s3."""
Expand Down Expand Up @@ -82,20 +242,6 @@ def delete_s3_object(s3_client, bucket, s3_prefix):
return s3_client.delete_object(Bucket=bucket, Key=s3_prefix)


def get_fice_institutions_map(db_config):
    """Build a FICE -> Institution mapping from the EESFileuploadAssignments table.

    Only rows with ``FileCategory = 'EnrollmentLens'`` are selected.

    Args:
        db_config: Configuration handed to ``DbInfo`` to open the connection.

    Returns:
        Dict keyed by FICE with the Institution as value.
    """

    from aioradio.pyodbc import pyodbc_query_fetchall

    sql = "SELECT FICE, Institution FROM EESFileuploadAssignments WHERE FileCategory = 'EnrollmentLens'"

    fice_to_institution = {}
    with DbInfo(db_config) as database:
        records = pyodbc_query_fetchall(conn=database.conn, query=sql)
        fice_to_institution = {fice: institution for fice, institution in records}

    return fice_to_institution


def bytes_to_s3(s3_client, s3_bucket, key, body):
"""Write data in bytes to s3."""

Expand Down Expand Up @@ -185,24 +331,6 @@ def get_s3_pickle_to_object(s3_client, s3_bucket, key):
return data


def get_ftp_connection(secret_id, port=139, is_direct_tcp=False, env='sandbox'):
    """Open an SMB connection using credentials stored in AWS Secrets Manager.

    The secret's ``SecretString`` is parsed as JSON and its 'user', 'password'
    and 'server' entries are used to build and connect an ``SMBConnection``.

    Args:
        secret_id: Secrets Manager secret id; also passed as the third
            positional argument of ``SMBConnection``.
        port: Port passed to ``connect``.
        is_direct_tcp: Forwarded to ``SMBConnection``.
        env: Environment name forwarded to ``get_boto3_session``.

    Returns:
        The ``SMBConnection`` after ``connect`` has been called.
    """

    session = get_boto3_session(env)
    secrets_manager = session.client("secretsmanager", region_name='us-east-1')
    secret = secrets_manager.get_secret_value(SecretId=secret_id)
    creds = json.loads(secret['SecretString'])

    user = creds['user']
    password = creds['password']
    server = creds['server']

    conn = SMBConnection(user, password, secret_id, server, use_ntlm_v2=True, is_direct_tcp=is_direct_tcp)
    conn.connect(creds['server'], port)

    return conn


def get_aws_creds(env):
"""Get AWS credentials from environment variables."""

Expand All @@ -223,6 +351,79 @@ def get_aws_creds(env):
return aws_creds


# Shorter aliases bound to the "large" S3 readers
# (presumably defined earlier in this module; not visible in this view).
get_s3_csv_to_df = get_large_s3_csv_to_df
get_s3_parquet_to_df = get_large_s3_parquet_to_df


################################# Misc functions ####################################


def bearing(slat, elat, slon, elon):
    """Return the bearing between two coordinates, in degrees within [0, 360).

    All four inputs are given in degrees and converted to radians internally.

    Args:
        slat: Latitude of the first point.
        elat: Latitude of the second point.
        slon: Longitude of the first point.
        elon: Longitude of the second point.

    Returns:
        ``atan2(x, y)`` converted to degrees and normalised into [0, 360),
        where ``x = cos(elat) * sin(elon - slon)`` and
        ``y = cos(slat) * sin(elat) - sin(slat) * cos(elat) * cos(elon - slon)``.
    """

    lat_a, lat_b, lon_a, lon_b = (radians(value) for value in (slat, elat, slon, elon))
    delta_lon = lon_b - lon_a

    x_term = cos(lat_b) * sin(delta_lon)
    y_term = cos(lat_a) * sin(lat_b) - sin(lat_a) * cos(lat_b) * cos(delta_lon)

    angle = degrees(np.arctan2(x_term, y_term))
    return (angle + 360) % 360


def apply_bearing(dataframe, latitude, longitude):
    """Compute ``bearing`` row by row against a fixed coordinate.

    Args:
        dataframe: DataFrame whose rows expose ``LATITUDE`` and ``LONGITUDE``.
        latitude: Latitude passed as the ``elat`` argument of ``bearing``.
        longitude: Longitude passed as the ``elon`` argument of ``bearing``.

    Returns:
        The result of ``dataframe.apply(..., axis=1)``.
    """

    def row_bearing(row):
        return bearing(row.LATITUDE, latitude, row.LONGITUDE, longitude)

    return dataframe.apply(row_bearing, axis=1)


def apply_haversine(dataframe, latitude, longitude):
    """Compute the haversine distance in miles row by row to a fixed coordinate.

    Args:
        dataframe: DataFrame whose rows expose ``LATITUDE`` and ``LONGITUDE``.
        latitude: Latitude of the fixed point.
        longitude: Longitude of the fixed point.

    Returns:
        The result of ``dataframe.apply(..., axis=1)``.
    """

    fixed_point = (latitude, longitude)

    def row_distance(row):
        return haversine((row.LATITUDE, row.LONGITUDE), fixed_point, unit=Unit.MILES)

    return dataframe.apply(row_distance, axis=1)


def logit(x, a, b, c, d):
    """Evaluate the logistic curve ``a / (1 + exp(-c * (x - d))) + b``.

    Args:
        x: Input value(s).
        a: Numerator of the fraction.
        b: Constant added to the fraction.
        c: Factor multiplying ``(x - d)`` inside the exponential.
        d: Value subtracted from ``x`` inside the exponential.

    Returns:
        The curve evaluated at ``x``.
    """

    exponent = -c * (x - d)
    denominator = 1 + np.exp(exponent)
    return a / denominator + b


def apply_logit(dataframe, a, b, c, d):
    """Apply ``logit`` with fixed parameters through ``dataframe.apply``.

    Args:
        dataframe: Object whose ``apply`` method receives the function.
        a: Forwarded to ``logit``.
        b: Forwarded to ``logit``.
        c: Forwarded to ``logit``.
        d: Forwarded to ``logit``.

    Returns:
        The result of ``dataframe.apply``.
    """

    def fixed_logit(values):
        return logit(values, a, b, c, d)

    return dataframe.apply(fixed_logit)


def get_fice_institutions_map(db_config):
    """Build a FICE -> Institution mapping from the EESFileuploadAssignments table.

    Only rows with ``FileCategory = 'EnrollmentLens'`` are selected.

    Args:
        db_config: Configuration handed to ``DbInfo`` to open the connection.

    Returns:
        Dict keyed by FICE with the Institution as value.
    """

    from aioradio.pyodbc import pyodbc_query_fetchall

    sql = "SELECT FICE, Institution FROM EESFileuploadAssignments WHERE FileCategory = 'EnrollmentLens'"

    fice_to_institution = {}
    with DbInfo(db_config) as database:
        records = pyodbc_query_fetchall(conn=database.conn, query=sql)
        fice_to_institution = {fice: institution for fice, institution in records}

    return fice_to_institution


def get_ftp_connection(secret_id, port=139, is_direct_tcp=False, env='sandbox'):
    """Open an SMB connection using credentials stored in AWS Secrets Manager.

    The secret's ``SecretString`` is parsed as JSON and its 'user', 'password'
    and 'server' entries are used to build and connect an ``SMBConnection``.

    Args:
        secret_id: Secrets Manager secret id; also passed as the third
            positional argument of ``SMBConnection``.
        port: Port passed to ``connect``.
        is_direct_tcp: Forwarded to ``SMBConnection``.
        env: Environment name forwarded to ``get_boto3_session``.

    Returns:
        The ``SMBConnection`` after ``connect`` has been called.
    """

    session = get_boto3_session(env)
    secrets_manager = session.client("secretsmanager", region_name='us-east-1')
    secret = secrets_manager.get_secret_value(SecretId=secret_id)
    creds = json.loads(secret['SecretString'])

    user = creds['user']
    password = creds['password']
    server = creds['server']

    conn = SMBConnection(user, password, secret_id, server, use_ntlm_v2=True, is_direct_tcp=is_direct_tcp)
    conn.connect(creds['server'], port)

    return conn


def monitor_domino_run(domino, run_id, sleep_time=10):
"""Monitor domino job run and return True/False depending if job was
successful."""
Expand All @@ -241,24 +442,6 @@ def monitor_domino_run(domino, run_id, sleep_time=10):
return status


def get_boto3_session(env):
    """Build a boto3 Session for the given environment.

    Any ``AWS_PROFILE`` environment variable is removed first, then a session
    is created from the credentials returned by ``get_aws_creds(env)``. If a
    ``ValueError`` is raised along the way, ``AWS_PROFILE`` is restored to its
    previous value and a default ``boto3.Session()`` is returned instead.

    Args:
        env: Environment name forwarded to ``get_aws_creds``.

    Returns:
        The created ``boto3.Session``.
    """

    saved_profile = os.getenv('AWS_PROFILE')
    had_profile = saved_profile is not None

    try:
        if had_profile:
            del os.environ['AWS_PROFILE']
        credentials = get_aws_creds(env)
        return boto3.Session(**credentials)
    except ValueError:
        # Fall back to the default credential chain, with the profile restored.
        if had_profile:
            os.environ["AWS_PROFILE"] = saved_profile
        return boto3.Session()


def get_domino_connection(secret_id, project, host, env='sandbox'):
"""Get domino connection."""

Expand All @@ -268,6 +451,9 @@ def get_domino_connection(secret_id, project, host, env='sandbox'):
return Domino(project=project, api_key=api_key, host=host)


######################## Postgres or MSSQL Connection Classes #######################


class DB_CONNECT():
"""[Class for database connection]
Expand Down
Loading

0 comments on commit 1d532c9

Please sign in to comment.