Merge pull request #43 from gismart/recipes
Add dict_merge func and default secret id const
maxim-lisovsky-gismart authored Jan 15, 2024
2 parents 7e31719 + 20448c6 commit 279d640
Showing 8 changed files with 51 additions and 17 deletions.
README.md: 1 addition, 1 deletion
@@ -15,7 +15,7 @@ Add `--upgrade` option to update existing package to a new version
Specify package link in your `requirements.txt`:

```txt
- git+https://github.com/gismart/bi-utils@0.15.3#egg=bi-utils-gismart
+ git+https://github.com/gismart/bi-utils@0.16.0#egg=bi-utils-gismart
```

### Usage
bi_utils/__init__.py: 1 addition, 0 deletions
@@ -5,6 +5,7 @@
from .queue_exporter import QueueExporter
from .aws import db, s3, connection
from .files import data_filename
+ from .recipes import dict_merge
from .decorators import retry
from .sql import get_query
from .qa import df_test
bi_utils/aws/connection.py: 5 additions, 4 deletions
@@ -8,11 +8,12 @@
from .locopy import locopy


+ DEFAULT_SECRET_ID = "prod/redshift/bi_utils"
logger = logging.getLogger(__name__)
cached_creds: Dict[str, dict] = {}


- def get_creds(secret_id: str = "prod/redshift/analytics") -> dict:
+ def get_creds(secret_id: str = DEFAULT_SECRET_ID) -> dict:
"""Get AWS credentials"""
creds = cached_creds.get(secret_id)
if not creds:
@@ -27,7 +28,7 @@ def get_creds(secret_id: str = "prod/redshift/analytics") -> dict:


def get_redshift(
secret_id: str = "prod/redshift/analytics",
secret_id: str = DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
) -> locopy.Redshift:
@@ -48,7 +49,7 @@


def create_engine(
secret_id: str = "prod/redshift/analytics",
secret_id: str = DEFAULT_SECRET_ID,
drivername: str = "postgresql+psycopg2",
database: Optional[str] = None,
host: Optional[str] = None,
@@ -72,7 +73,7 @@

def connect(
schema: Optional[str] = None,
secret_id: str = "prod/redshift/analytics",
secret_id: str = DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
) -> psycopg2.extensions.connection:
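Since every helper in `connection` now shares the module-level constant, the default secret can be changed in one place. A minimal usage sketch, assuming AWS credentials are already configured in the environment:

```python
from bi_utils.aws import connection

# Resolves to the new default secret id, "prod/redshift/bi_utils"
creds = connection.get_creds()

# Passing secret_id explicitly still overrides the shared default
engine = connection.create_engine(secret_id="prod/redshift/analytics")
```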
bi_utils/aws/db.py: 7 additions, 7 deletions
@@ -26,7 +26,7 @@ def upload_file(
bucket_dir: str = "dwh/temp",
columns: Optional[Sequence] = None,
delete_s3_after: bool = True,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
retries: int = 0,
@@ -87,7 +87,7 @@ def download_files(
bucket_dir: str = "dwh/temp",
delete_s3_before: bool = False,
delete_s3_after: bool = True,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
retries: int = 0,
@@ -155,7 +155,7 @@ def upload_data(
bucket_dir: str = "dwh/temp",
columns: Optional[Sequence] = None,
remove_file: bool = False,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
retries: int = 0,
@@ -211,7 +211,7 @@ def download_data(
parse_bools: Optional[Sequence[str]] = None,
dtype: Optional[dict] = None,
chunking: bool = False,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
retries: int = 0,
@@ -308,7 +308,7 @@ def update(
schema: str,
params_set: dict,
params_where: Optional[dict],
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
) -> None:
@@ -330,7 +330,7 @@ def update(
def delete(
table: str,
schema: str,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
**conditions: Any,
@@ -353,7 +353,7 @@ def delete(
def get_columns(
table: str,
schema: str,
secret_id: str = "prod/redshift/analytics",
secret_id: str = connection.DEFAULT_SECRET_ID,
database: Optional[str] = None,
host: Optional[str] = None,
) -> Sequence[str]:
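The `db` helpers pick up the same constant through `connection.DEFAULT_SECRET_ID`, so callers that relied on the implicit default keep working and simply target the new secret. A short sketch; the table, schema, and condition names are illustrative only:

```python
from bi_utils.aws import db

# Both calls fall back to connection.DEFAULT_SECRET_ID
columns = db.get_columns(table="events", schema="analytics")

# Keyword arguments beyond the named parameters become delete conditions
db.delete(table="events", schema="analytics", event_date="2024-01-01")
```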
bi_utils/queue_exporter.py: 4 additions, 4 deletions
@@ -77,7 +77,7 @@ def export_df(
delete_file_after: bool = False,
delete_s3_after: bool = False,
partition_cols: Optional[Sequence] = None,
secret_id: str = "prod/redshift/analytics",
secret_id: str = aws.connection.DEFAULT_SECRET_ID,
) -> None:
"""
Save dataframe to `filepath` if `s3_bucket`, `s3_bucket_dir`, `schema`, `table` not passed
@@ -121,7 +121,7 @@ def export_file(
table: Optional[str] = None,
delete_file_after: bool = False,
delete_s3_after: bool = False,
secret_id: str = "prod/redshift/analytics",
secret_id: str = aws.connection.DEFAULT_SECRET_ID,
) -> None:
"""
Export file to S3 if `s3_bucket` and `s3_bucket_dir` passed
@@ -161,7 +161,7 @@ def _export_df(
delete_file_after: bool = False,
delete_s3_after: bool = False,
partition_cols: Optional[Sequence] = None,
secret_id: str = "prod/redshift/analytics",
secret_id: str = aws.connection.DEFAULT_SECRET_ID,
) -> None:
filename = os.path.basename(file_path)
if columns:
@@ -211,7 +211,7 @@ def _export_file(
table: Optional[str] = None,
delete_file_after: bool = False,
delete_s3_after: bool = False,
secret_id: str = "prod/redshift/analytics",
secret_id: str = aws.connection.DEFAULT_SECRET_ID,
) -> None:
if schema and table and (".csv" in file_path.lower() or ".parquet" in file_path.lower()):
aws.db.upload_file(
bi_utils/recipes.py: 9 additions, 0 deletions
@@ -0,0 +1,9 @@
def dict_merge(receiver: dict, updater: dict) -> dict:
    """Update receiver dict with updater dict's values recursively"""
    receiver = receiver.copy()  # shallow copy keeps the caller's dict intact
    for k in updater:
        if k in receiver and isinstance(receiver[k], dict) and isinstance(updater[k], dict):
            # both values are dicts: merge them recursively
            receiver[k] = dict_merge(receiver[k], updater[k])
        else:
            # otherwise the updater's value wins
            receiver[k] = updater[k]
    return receiver
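A quick sketch of how the new helper behaves (values are illustrative): nested keys missing from the updater survive the merge, and the receiver itself is not mutated:

```python
from bi_utils.recipes import dict_merge

base = {"db": {"host": "localhost", "port": 5439}, "retries": 0}
override = {"db": {"host": "redshift.internal"}}

merged = dict_merge(base, override)
# {'db': {'host': 'redshift.internal', 'port': 5439}, 'retries': 0}
assert base["db"]["host"] == "localhost"  # the original dict is untouched
```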
setup.py: 1 addition, 1 deletion
@@ -9,7 +9,7 @@

setuptools.setup(
name="bi-utils-gismart",
version="0.15.3",
version="0.16.0",
author="gismart",
author_email="[email protected]",
description="Utils for BI team",
tests/test_recipes.py: 23 additions, 0 deletions
@@ -0,0 +1,23 @@
import pytest

from bi_utils import recipes


@pytest.mark.parametrize(
    "receiver, updater, expected_dict",
    [
        (
            {"a": {"one": 1, "two": [2, 2]}, "b": 3},
            {"a": {"one": 11}},
            {"a": {"one": 11, "two": [2, 2]}, "b": 3},
        ),
        (
            {"a": {"one": 1, "two": [2, 2]}, "b": 3},
            {"a": {"two": [2, 2, 2]}},
            {"a": {"one": 1, "two": [2, 2, 2]}, "b": 3},
        ),
    ],
)
def test_dict_merge(receiver, updater, expected_dict):
    result = recipes.dict_merge(receiver, updater)
    assert result == expected_dict
