From 3deae1a4ac54a5111a74fb2b41832b158a6642d2 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:14:54 +0100 Subject: [PATCH 01/33] feature: add support for Spark Connect (#63) Refactor code to support Spark Connect --------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- .github/workflows/test.yml | 11 +- makefile | 6 +- pyproject.toml | 68 +- src/koheesio/__about__.py | 4 +- src/koheesio/__init__.py | 2 +- src/koheesio/asyncio/__init__.py | 17 +- src/koheesio/asyncio/http.py | 99 +- src/koheesio/context.py | 28 +- src/koheesio/integrations/__init__.py | 3 - src/koheesio/integrations/box.py | 99 +- .../integrations/snowflake/__init__.py | 551 ++++++++ .../integrations/snowflake/test_utils.py | 69 + .../spark/dq/spark_expectations.py | 5 +- src/koheesio/integrations/spark/sftp.py | 36 +- src/koheesio/integrations/spark/snowflake.py | 1106 +++++++++++++++++ .../integrations/spark/tableau/hyper.py | 31 +- .../integrations/spark/tableau/server.py | 32 +- src/koheesio/logger.py | 12 +- src/koheesio/models/__init__.py | 86 +- src/koheesio/models/reader.py | 2 +- src/koheesio/models/sql.py | 26 +- src/koheesio/notifications/slack.py | 15 +- src/koheesio/pandas/__init__.py | 5 +- src/koheesio/pandas/readers/excel.py | 2 +- src/koheesio/secrets/__init__.py | 7 +- src/koheesio/spark/__init__.py | 69 +- src/koheesio/spark/delta.py | 32 +- src/koheesio/spark/etl_task.py | 14 +- src/koheesio/spark/functions/__init__.py | 11 + .../spark/readers/databricks/autoloader.py | 10 +- src/koheesio/spark/readers/delta.py | 34 +- src/koheesio/spark/readers/dummy.py | 2 +- src/koheesio/spark/readers/excel.py | 2 +- src/koheesio/spark/readers/file_loader.py | 12 +- src/koheesio/spark/readers/hana.py | 4 +- src/koheesio/spark/readers/jdbc.py | 8 +- src/koheesio/spark/readers/kafka.py | 17 +- src/koheesio/spark/readers/memory.py | 70 +- src/koheesio/spark/readers/metastore.py | 2 +- src/koheesio/spark/readers/rest_api.py | 6 +- .../spark/readers/spark_sql_reader.py | 2 +- src/koheesio/spark/readers/teradata.py | 2 +- src/koheesio/spark/snowflake.py | 131 +- .../spark/transformations/__init__.py | 54 +- src/koheesio/spark/transformations/arrays.py | 83 +- .../spark/transformations/camel_to_snake.py | 6 +- .../spark/transformations/cast_to_datatype.py | 9 +- .../transformations/date_time/__init__.py | 10 +- .../transformations/date_time/interval.py | 65 +- .../spark/transformations/drop_column.py | 2 +- src/koheesio/spark/transformations/dummy.py | 2 +- .../spark/transformations/get_item.py | 2 +- src/koheesio/spark/transformations/hash.py | 10 +- src/koheesio/spark/transformations/lookup.py | 21 +- .../spark/transformations/repartition.py | 16 +- src/koheesio/spark/transformations/replace.py | 4 +- .../spark/transformations/row_number_dedup.py | 12 +- .../spark/transformations/sql_transform.py | 15 +- .../transformations/strings/change_case.py | 6 +- .../spark/transformations/strings/concat.py | 7 +- .../spark/transformations/strings/pad.py | 2 +- .../spark/transformations/strings/regexp.py | 8 +- .../spark/transformations/strings/replace.py | 4 +- .../spark/transformations/strings/split.py | 6 +- .../transformations/strings/substring.py | 8 +- .../spark/transformations/strings/trim.py | 2 +- .../spark/transformations/transform.py | 13 +- src/koheesio/spark/transformations/uuid5.py | 14 +- src/koheesio/spark/utils.py | 203 --- src/koheesio/spark/utils/__init__.py | 27 + src/koheesio/spark/utils/common.py | 
382 ++++++ src/koheesio/spark/utils/connect.py | 19 + src/koheesio/spark/writers/__init__.py | 11 +- src/koheesio/spark/writers/buffer.py | 56 +- src/koheesio/spark/writers/delta/batch.py | 40 +- src/koheesio/spark/writers/delta/scd.py | 64 +- src/koheesio/spark/writers/delta/stream.py | 4 +- src/koheesio/spark/writers/delta/utils.py | 6 +- src/koheesio/spark/writers/dummy.py | 11 +- src/koheesio/spark/writers/file_writer.py | 12 +- src/koheesio/spark/writers/kafka.py | 12 +- src/koheesio/spark/writers/stream.py | 51 +- src/koheesio/sso/okta.py | 16 +- src/koheesio/steps/__init__.py | 77 +- src/koheesio/steps/dummy.py | 2 +- src/koheesio/steps/http.py | 33 +- src/koheesio/utils.py | 17 +- tests/asyncio/test_asyncio_http.py | 1 + tests/snowflake/test_snowflake.py | 255 ++++ tests/spark/conftest.py | 138 +- .../integrations/snowflake/test_snowflake.py | 375 ------ .../snowflake/test_spark_snowflake.py | 296 +++++ .../integrations/snowflake/test_sync_task.py | 122 +- .../spark/integrations/tableau/test_hyper.py | 1 + tests/spark/readers/test_delta_reader.py | 6 +- tests/spark/readers/test_hana.py | 8 - tests/spark/readers/test_jdbc.py | 26 +- tests/spark/readers/test_memory.py | 5 +- tests/spark/readers/test_metastore_reader.py | 3 +- tests/spark/readers/test_rest_api.py | 6 +- tests/spark/readers/test_teradata.py | 11 +- tests/spark/tasks/test_etl_task.py | 40 +- tests/spark/test_spark.py | 13 +- tests/spark/test_spark_utils.py | 17 +- .../date_time/test_interval.py | 10 +- .../strings/test_change_case.py | 3 +- .../transformations/strings/test_concat.py | 4 +- .../spark/transformations/strings/test_pad.py | 3 +- .../transformations/strings/test_regexp.py | 5 +- .../transformations/strings/test_split.py | 5 +- .../strings/test_string_replace.py | 4 +- .../transformations/strings/test_substring.py | 3 +- .../transformations/strings/test_trim.py | 3 +- tests/spark/transformations/test_arrays.py | 2 +- .../transformations/test_cast_to_datatype.py | 16 +- tests/spark/transformations/test_get_item.py | 5 +- .../spark/transformations/test_repartition.py | 6 +- tests/spark/transformations/test_replace.py | 5 +- .../transformations/test_row_number_dedup.py | 8 +- tests/spark/transformations/test_transform.py | 2 +- .../spark/writers/delta/test_delta_writer.py | 25 +- tests/spark/writers/delta/test_scd.py | 16 +- tests/spark/writers/test_file_writer.py | 16 +- tests/steps/test_http.py | 2 + tests/steps/test_steps.py | 10 - tests/utils/test_utils.py | 8 - 126 files changed, 4093 insertions(+), 1565 deletions(-) create mode 100644 src/koheesio/integrations/snowflake/__init__.py create mode 100644 src/koheesio/integrations/snowflake/test_utils.py create mode 100644 src/koheesio/integrations/spark/snowflake.py create mode 100644 src/koheesio/spark/functions/__init__.py delete mode 100644 src/koheesio/spark/utils.py create mode 100644 src/koheesio/spark/utils/__init__.py create mode 100644 src/koheesio/spark/utils/common.py create mode 100644 src/koheesio/spark/utils/connect.py create mode 100644 tests/snowflake/test_snowflake.py delete mode 100644 tests/spark/integrations/snowflake/test_snowflake.py create mode 100644 tests/spark/integrations/snowflake/test_spark_snowflake.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1f724466..1d6e640b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,6 +6,7 @@ on: pull_request: branches: - main + - release/* workflow_dispatch: inputs: logLevel: @@ -40,8 +41,8 @@ jobs: fetch-depth: 0 ref: ${{ 
github.event.pull_request.head.ref }} repository: ${{ github.event.pull_request.head.repo.full_name }} - - name: Fetch main branch - run: git fetch origin main:main + - name: Fetch target branch + run: git fetch origin ${{ github.event.pull_request.base.ref || 'main'}}:${{ github.event.pull_request.base.ref || 'main'}} - name: Check changes id: check run: | @@ -71,10 +72,12 @@ jobs: # os: [ubuntu-latest, windows-latest, macos-latest] # FIXME: Add Windows and macOS os: [ubuntu-latest] python-version: ['3.9', '3.10', '3.11', '3.12'] - pyspark-version: ['33', '34', '35'] + pyspark-version: ['33', '34', '35', '35r'] exclude: - python-version: '3.9' pyspark-version: '35' + - python-version: '3.9' + pyspark-version: '35r' - python-version: '3.11' pyspark-version: '33' - python-version: '3.11' @@ -100,7 +103,7 @@ jobs: # hatch fmt --check --python=${{ matrix.python-version }} - name: Run tests - run: hatch test --python=${{ matrix.python-version }} -i version=pyspark${{ matrix.pyspark-version }} + run: hatch test --python=${{ matrix.python-version }} -i version=pyspark${{ matrix.pyspark-version }} --verbose # https://github.com/marketplace/actions/alls-green#why final_check: # This job does nothing and is only used for the branch protection diff --git a/makefile b/makefile index 54da8d05..584ce49a 100644 --- a/makefile +++ b/makefile @@ -105,16 +105,16 @@ coverage: cov all-tests: @echo "\033[1mRunning all tests:\033[0m\n\033[35m This will run the full test suite\033[0m" @echo "\033[1;31mWARNING:\033[0;33m This may take upward of 20-30 minutes to complete!\033[0m" - @hatch test --no-header --no-summary + @hatch test --no-header .PHONY: spark-tests ## testing - Run SPARK tests in ALL environments spark-tests: @echo "\033[1mRunning Spark tests:\033[0m\n\033[35m This will run the Spark test suite against all specified environments\033[0m" @echo "\033[1;31mWARNING:\033[0;33m This may take upward of 20-30 minutes to complete!\033[0m" - @hatch test -m spark --no-header --no-summary + @hatch test -m spark --no-header .PHONY: non-spark-tests ## testing - Run non-spark tests in ALL environments non-spark-tests: @echo "\033[1mRunning non-Spark tests:\033[0m\n\033[35m This will run the non-Spark test suite against all specified environments\033[0m" - @hatch test -m "not spark" --no-header --no-summary + @hatch test -m "not spark" --no-header .PHONY: dev-test ## testing - Run pytest, with all tests in the dev environment dev-test: diff --git a/pyproject.toml b/pyproject.toml index be0a5496..3edd7eca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,19 +57,30 @@ async_http = [ "nest-asyncio>=1.6.0", ] box = ["boxsdk[jwt]==3.8.1"] -pandas = ["pandas>=1.3", "setuptools", "numpy<2.0.0"] +pandas = ["pandas>=1.3", "setuptools", "numpy<2.0.0", "pandas-stubs"] pyspark = ["pyspark>=3.2.0", "pyarrow>13"] +pyspark_connect = ["pyspark[connect]>=3.5"] se = ["spark-expectations>=2.1.0"] # SFTP dependencies in to_csv line_iterator sftp = ["paramiko>=2.6.0"] delta = ["delta-spark>=2.2"] excel = ["openpyxl>=3.0.0"] # Tableau dependencies -tableau = [ - "tableauhyperapi>=0.0.19484", - "tableauserverclient>=0.25", +tableau = ["tableauhyperapi>=0.0.19484", "tableauserverclient>=0.25"] +# Snowflake dependencies +snowflake = ["snowflake-connector-python>=3.12.0"] +# Development dependencies +dev = [ + "black", + "isort", + "ruff", + "mypy", + "pylint", + "colorama", + "types-PyYAML", + "types-requests", + ] -dev = ["black", "isort", "ruff", "mypy", "pylint", "colorama", "types-PyYAML"] test = [ "chispa", "coverage[toml]", @@ 
-186,6 +197,7 @@ features = [ "excel", "se", "box", + "snowflake", "tableau", "dev", ] @@ -216,6 +228,7 @@ lint = ["- ruff-fmt", "- mypy-check", "pylint-check"] log-versions = "python --version && {env:HATCH_UV} pip freeze | grep pyspark" test = "- pytest{env:HATCH_TEST_ARGS:} {args} -n 2" spark-tests = "test -m spark" +spark-remote-tests = "test -m spark -m \"not skip_on_remote_session\"" non-spark-tests = "test -m \"not spark\"" # scripts.run = "echo bla {env:HATCH_TEST_ARGS:} {args}" @@ -251,6 +264,7 @@ features = [ "sftp", "delta", "excel", + "snowflake", "tableau", "dev", "test", @@ -258,10 +272,10 @@ features = [ parallel = true retries = 2 -retry-delay = 1 +retry-delay = 3 [tool.hatch.envs.hatch-test.scripts] -run = "pytest{env:HATCH_TEST_ARGS:} {args} -n auto" +run = "pytest{env:HATCH_TEST_ARGS:} {args}" run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" cov-combine = "coverage combine" cov-report = "coverage report" @@ -273,11 +287,11 @@ version = ["pyspark33", "pyspark34"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.10"] -version = ["pyspark33", "pyspark34", "pyspark35"] +version = ["pyspark33", "pyspark34", "pyspark35", "pyspark35r"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.11", "3.12"] -version = ["pyspark35"] +version = ["pyspark35", "pyspark35r"] [tool.hatch.envs.hatch-test.overrides] matrix.version.extra-dependencies = [ @@ -299,6 +313,9 @@ matrix.version.extra-dependencies = [ { value = "pyspark>=3.5,<3.6", if = [ "pyspark35", ] }, + { value = "pyspark[connect]>=3.5,<3.6", if = [ + "pyspark35r", + ] }, ] name.".*".env-vars = [ @@ -308,10 +325,19 @@ name.".*".env-vars = [ { key = "KOHEESIO__PRINT_LOGO", value = "False" }, ] +name.".*(pyspark35r).*".env-vars = [ + # enable soark connect, setting to local as it will trigger + # spark to start local spark server and enbale remote session + { key = "SPARK_REMOTE", value = "local" }, + { key = "SPARK_TESTING", value = "True" }, +] + + [tool.pytest.ini_options] addopts = "-q --color=yes --order-scope=module" log_level = "CRITICAL" testpaths = ["tests"] +asyncio_default_fixture_loop_scope = "scope" markers = [ "default: added to all tests by default if no other marker expect of standard pytest markers is present", "spark: mark a test as a Spark test", @@ -325,10 +351,18 @@ filterwarnings = [ # pyspark.pandas warnings "ignore:distutils.*:DeprecationWarning:pyspark.pandas.*", "ignore:'PYARROW_IGNORE_TIMEZONE'.*:UserWarning:pyspark.pandas.*", + # pydantic warnings + "ignore:A custom validator is returning a value other than `self`.*.*:UserWarning:pydantic.main.*", + # pyspark.sql.connect warnings + "ignore:is_datetime64tz_dtype.*:DeprecationWarning:pyspark.sql.connect.*", + "ignore:distutils.*:DeprecationWarning:pyspark.sql.connect.*", + # pyspark.sql.pandas warnings "ignore:distutils.*:DeprecationWarning:pyspark.sql.pandas.*", "ignore:is_datetime64tz_dtype.*:DeprecationWarning:pyspark.sql.pandas.*", + "ignore:is_categorical_dtype.*:DeprecationWarning:pyspark.sql.pandas.*", + "ignore:iteritems.*:FutureWarning:pyspark.sql.pandas.*", # Koheesio warnings - "ignore:DayTimeIntervalType .*:UserWarning:koheesio.spark.snowflake.*", + "ignore:DayTimeIntervalType.*:UserWarning:koheesio.spark.snowflake.*", ] [tool.coverage.run] @@ -403,8 +437,9 @@ features = [ "box", "pandas", "pyspark", - "se", + # "se", "sftp", + "snowflake", "delta", "excel", "tableau", @@ -412,7 +447,6 @@ features = [ "test", "docs", ] -extra-dependencies = ["pyspark==3.4.*"] ### ~~~~~~~~~~~~~~~~~~ ### @@ -569,10 +603,18 @@ unfixable = [] 
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.mypy] +python_version = "3.10" +files = ["koheesio/**/*.py"] +plugins = ["pydantic.mypy"] +pretty = true +warn_unused_configs = true check_untyped_defs = false disallow_untyped_calls = false disallow_untyped_defs = true -files = ["koheesio/**/*.py"] +warn_no_return = false +implicit_optional = true +allow_untyped_globals = true +disable_error_code = ["attr-defined", "return-value", "union-attr", "override"] [tool.pylint.main] fail-under = 9.5 diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 81ddfde4..ff52467e 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.8.1" +__version__ = "0.9.0rc0" __logo__ = ( 75, ( @@ -32,7 +32,7 @@ # fmt: off -def _about(): # pragma: no cover +def _about() -> str: # pragma: no cover """Return the Koheesio logo and version/about information as a string Note: this code is not meant to be readable, instead it is written to be as compact as possible """ diff --git a/src/koheesio/__init__.py b/src/koheesio/__init__.py index 6404e85a..347adde8 100644 --- a/src/koheesio/__init__.py +++ b/src/koheesio/__init__.py @@ -25,7 +25,7 @@ ] -def print_logo(): +def print_logo() -> None: global _logo_printed global _koheesio_print_logo diff --git a/src/koheesio/asyncio/__init__.py b/src/koheesio/asyncio/__init__.py index a71a8182..093c4a09 100644 --- a/src/koheesio/asyncio/__init__.py +++ b/src/koheesio/asyncio/__init__.py @@ -16,15 +16,13 @@ class AsyncStepMetaClass(StepMetaClass): It inherits from the StepMetaClass and provides additional functionality for executing asynchronous steps. - Attributes: - None - - Methods: - _execute_wrapper: Wrapper method for executing asynchronous steps. + Methods + ------- + _execute_wrapper: Wrapper method for executing asynchronous steps. """ - def _execute_wrapper(cls, *args, **kwargs): + def _execute_wrapper(cls, *args, **kwargs): # type: ignore[no-untyped-def] """Wrapper method for executing asynchronous steps. This method is called when an asynchronous step is executed. It wraps the @@ -60,16 +58,14 @@ class AsyncStepOutput(Step.Output): Merge key-value map with self. """ - def merge(self, other: Union[Dict, StepOutput]): + def merge(self, other: Union[Dict, StepOutput]) -> "AsyncStepOutput": """Merge key,value map with self Examples -------- ```python step_output = StepOutput(foo="bar") - step_output.merge( - {"lorem": "ipsum"} - ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Functionally similar to adding two dicts together; like running `{**dict_a, **dict_b}`. @@ -89,6 +85,7 @@ def merge(self, other: Union[Dict, StepOutput]): return self +# noinspection PyUnresolvedReferences class AsyncStep(Step, ABC, metaclass=AsyncStepMetaClass): """ Asynchronous step class that inherits from Step and uses the AsyncStepMetaClass metaclass. 
diff --git a/src/koheesio/asyncio/http.py b/src/koheesio/asyncio/http.py index 119bf197..ece14f15 100644 --- a/src/koheesio/asyncio/http.py +++ b/src/koheesio/asyncio/http.py @@ -8,7 +8,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Union -import nest_asyncio +import nest_asyncio # type: ignore[import-untyped] import yarl from aiohttp import BaseConnector, ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry, RetryClient, RetryOptionsBase @@ -20,6 +20,7 @@ from koheesio.steps.http import HttpMethod +# noinspection PyUnresolvedReferences class AsyncHttpStep(AsyncStep, ExtraParamsMixin): """ Asynchronous HTTP step for making HTTP requests using aiohttp. @@ -45,42 +46,42 @@ class AsyncHttpStep(AsyncStep, ExtraParamsMixin): Examples -------- ```python - >>> import asyncio - >>> from aiohttp import ClientSession - >>> from aiohttp.connector import TCPConnector - >>> from aiohttp_retry import ExponentialRetry - >>> from koheesio.steps.async.http import AsyncHttpStep - >>> from yarl import URL - >>> from typing import Dict, Any, Union, List, Tuple - >>> - >>> # Initialize the AsyncHttpStep - >>> async def main(): - >>> session = ClientSession() - >>> urls = [URL('https://example.com/api/1'), URL('https://example.com/api/2')] - >>> retry_options = ExponentialRetry() - >>> connector = TCPConnector(limit=10) - >>> headers = {'Content-Type': 'application/json'} - >>> step = AsyncHttpStep( - >>> client_session=session, - >>> url=urls, - >>> retry_options=retry_options, - >>> connector=connector, - >>> headers=headers - >>> ) - >>> - >>> # Execute the step - >>> responses_urls= await step.get() - >>> - >>> return responses_urls - >>> - >>> # Run the main function - >>> responses_urls = asyncio.run(main()) + import asyncio + from aiohttp import ClientSession + from aiohttp.connector import TCPConnector + from aiohttp_retry import ExponentialRetry + from koheesio.asyncio.http import AsyncHttpStep + from yarl import URL + from typing import Dict, Any, Union, List, Tuple + + # Initialize the AsyncHttpStep + async def main(): + session = ClientSession() + urls = [URL('https://example.com/api/1'), URL('https://example.com/api/2')] + retry_options = ExponentialRetry() + connector = TCPConnector(limit=10) + headers = {'Content-Type': 'application/json'} + step = AsyncHttpStep( + client_session=session, + url=urls, + retry_options=retry_options, + connector=connector, + headers=headers + ) + + # Execute the step + responses_urls= await step.get() + + return responses_urls + + # Run the main function + responses_urls = asyncio.run(main()) ``` """ client_session: Optional[ClientSession] = Field(default=None, description="Aiohttp ClientSession", exclude=True) url: List[yarl.URL] = Field( - default=None, + default_factory=list, alias="urls", description="""Expecting list, as there is no value in executing async request for one value. yarl.URL is preferable, because params/data can be injected into URL instance""", @@ -113,7 +114,7 @@ class Output(AsyncStepOutput): default=None, description="List of responses from the API and request URL", repr=False ) - def __tasks_generator(self, method) -> List[asyncio.Task]: + def __tasks_generator(self, method: HttpMethod) -> List[asyncio.Task]: """ Generate a list of tasks for making HTTP requests. 
@@ -141,7 +142,7 @@ def __tasks_generator(self, method) -> List[asyncio.Task]: return tasks @model_validator(mode="after") - def _move_extra_params_to_params(self): + def _move_extra_params_to_params(self) -> AsyncHttpStep: """ Move extra_params to params dict. @@ -170,12 +171,13 @@ async def _execute(self, tasks: List[asyncio.Task]) -> List[Tuple[Dict[str, Any] try: responses_urls = await asyncio.gather(*tasks) finally: - await self.client_session.close() + if self.client_session: + await self.client_session.close() await self.__retry_client.close() return responses_urls - def _init_session(self): + def _init_session(self) -> None: """ Initialize the aiohttp session and retry client. """ @@ -189,13 +191,13 @@ def _init_session(self): ) @field_validator("timeout") - def validate_timeout(cls, timeout): + def validate_timeout(cls, timeout: Any) -> None: """ - Validate the 'data' field. + Validate the 'timeout' field. Parameters ---------- - data : Any + timeout : Any The value of the 'timeout' field. Raises @@ -206,7 +208,7 @@ def validate_timeout(cls, timeout): if timeout: raise ValueError("timeout is not allowed in AsyncHttpStep. Provide timeout through retry_options.") - def get_headers(self): + def get_headers(self) -> Union[None, dict]: """ Get the request headers. @@ -226,7 +228,8 @@ def get_headers(self): return _headers or self.headers - def set_outputs(self, response): + # noinspection PyUnusedLocal,PyMethodMayBeStatic + def set_outputs(self, response) -> None: # type: ignore[no-untyped-def] """ Set the outputs of the step. @@ -237,7 +240,8 @@ def set_outputs(self, response): """ warnings.warn("set outputs is not implemented in AsyncHttpStep.") - def get_options(self): + # noinspection PyMethodMayBeStatic + def get_options(self) -> None: """ Get the options of the step. """ @@ -245,7 +249,7 @@ def get_options(self): # Disable pylint warning: method was expected to be 'non-async' # pylint: disable=W0236 - async def request( + async def request( # type: ignore[no-untyped-def] self, method: HttpMethod, url: yarl.URL, @@ -271,10 +275,11 @@ async def request( async with self.__retry_client.request(method=method, url=url, **kwargs) as response: res = await response.json() - return (res, response.request_info.url) + return res, response.request_info.url # Disable pylint warning: method was expected to be 'non-async' # pylint: disable=W0236 + # noinspection PyMethodOverriding async def get(self) -> List[Tuple[Dict[str, Any], yarl.URL]]: """ Make GET requests. @@ -337,7 +342,7 @@ async def delete(self) -> List[Tuple[Dict[str, Any], yarl.URL]]: return responses_urls - def execute(self) -> AsyncHttpStep.Output: + def execute(self) -> None: """ Execute the step. 
@@ -364,9 +369,7 @@ def execute(self) -> AsyncHttpStep.Output: if self.method not in map_method_func: raise ValueError(f"Method {self.method} not implemented in AsyncHttpStep.") - self.output.responses_urls = asyncio.run(map_method_func[self.method]()) - - return self.output + self.output.responses_urls = asyncio.run(map_method_func[self.method]()) # type: ignore[index] class AsyncHttpGetStep(AsyncHttpStep): diff --git a/src/koheesio/context.py b/src/koheesio/context.py index 0f4b69e3..e0b818a4 100644 --- a/src/koheesio/context.py +++ b/src/koheesio/context.py @@ -14,11 +14,11 @@ from __future__ import annotations import re -from typing import Any, Dict, Union from collections.abc import Mapping from pathlib import Path +from typing import Any, Dict, Iterator, Union -import jsonpickle +import jsonpickle # type: ignore[import-untyped] import tomli import yaml @@ -79,7 +79,7 @@ class Context(Mapping): - `values()`: Returns all values of the Context. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] """Initializes the Context object with given arguments.""" for arg in args: if isinstance(arg, dict): @@ -87,32 +87,33 @@ def __init__(self, *args, **kwargs): if isinstance(arg, Context): kwargs = kwargs.update(arg.to_dict()) - for key, value in kwargs.items(): - self.__dict__[key] = self.process_value(value) + if kwargs: + for key, value in kwargs.items(): + self.__dict__[key] = self.process_value(value) - def __str__(self): + def __str__(self) -> str: """Returns a string representation of the Context.""" return str(dict(self.__dict__)) - def __repr__(self): + def __repr__(self) -> str: """Returns a string representation of the Context.""" return self.__str__() - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Allows for iteration across a Context""" return self.to_dict().__iter__() - def __len__(self): + def __len__(self) -> int: """Returns the length of the Context""" return self.to_dict().__len__() - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: try: return self.get(item, safe=False) except KeyError as e: raise AttributeError(item) from e - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: """Makes class subscriptable""" return self.get(item, safe=False) @@ -248,11 +249,12 @@ def from_toml(cls, toml_file_or_str: Union[str, Path]) -> Context: ------- Context """ - toml_str = toml_file_or_str # check if toml_str is pathlike if (toml_file := Path(toml_file_or_str)).exists(): toml_str = toml_file.read_text(encoding="utf-8") + else: + toml_str = str(toml_file_or_str) toml_dict = tomli.loads(toml_str) return cls.from_dict(toml_dict) @@ -421,7 +423,7 @@ def to_dict(self) -> Dict[str, Any]: if isinstance(value, Context): result[key] = value.to_dict() elif isinstance(value, list): - result[key] = [e.to_dict() if isinstance(e, Context) else e for e in value] + result[key] = [e.to_dict() if isinstance(e, Context) else e for e in value] # type: ignore[assignment] else: result[key] = value diff --git a/src/koheesio/integrations/__init__.py b/src/koheesio/integrations/__init__.py index e3dfb266..e69de29b 100644 --- a/src/koheesio/integrations/__init__.py +++ b/src/koheesio/integrations/__init__.py @@ -1,3 +0,0 @@ -""" -Nothing to see here, move along. 
-""" diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index 843601f9..cd5baab9 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -10,18 +10,18 @@ * Application is authorized for the enterprise (Developer Portal - MyApp - Authorization) """ +import datetime import re from typing import Any, Dict, Optional, Union -from abc import ABC, abstractmethod -from datetime import datetime -from io import BytesIO +from abc import ABC +from io import BytesIO, StringIO from pathlib import PurePath +import pandas as pd from boxsdk import Client, JWTAuth from boxsdk.object.file import File from boxsdk.object.folder import Folder -from pyspark.sql import DataFrame from pyspark.sql.functions import expr, lit from pyspark.sql.types import StructType @@ -36,6 +36,7 @@ model_validator, ) from koheesio.spark.readers import Reader +from koheesio.utils import utc_now class BoxFolderNotFoundError(Exception): @@ -105,15 +106,15 @@ class Box(Step, ABC): description="Private key passphrase generated in the app management console.", ) - client: SkipValidation[Client] = None + client: SkipValidation[Client] = None # type: ignore - def init_client(self): + def init_client(self) -> None: """Set up the Box client.""" if not self.client: self.client = Client(JWTAuth(**self.auth_options)) @property - def auth_options(self): + def auth_options(self) -> Dict[str, Any]: """ Get a dictionary of authentication options, that can be handily used in the child classes """ @@ -126,11 +127,11 @@ def auth_options(self): "rsa_private_key_passphrase": self.rsa_private_key_passphrase.get_secret_value(), } - def __init__(self, **data): + def __init__(self, **data: dict): super().__init__(**data) self.init_client() - def execute(self): + def execute(self) -> Step.Output: # type: ignore # Plug to be able to unit test ABC pass @@ -167,7 +168,7 @@ class Output(StepOutput): folder: Optional[Folder] = Field(default=None, description="Box folder object") @model_validator(mode="after") - def validate_folder_or_path(self): + def validate_folder_or_path(self) -> "BoxFolderBase": """ Validations for 'folder' and 'path' parameter usage """ @@ -183,13 +184,13 @@ def validate_folder_or_path(self): return self @property - def _obj_from_id(self): + def _obj_from_id(self) -> Folder: """ Get folder object from identifier """ return self.client.folder(folder_id=self.folder).get() if isinstance(self.folder, str) else self.folder - def action(self): + def action(self) -> Optional[Folder]: """ Placeholder for 'action' method, that should be implemented in the child classes @@ -223,7 +224,7 @@ class BoxFolderGet(BoxFolderBase): False, description="Create sub-folders recursively if the path does not exist." ) - def _get_or_create_folder(self, current_folder_object, next_folder_name): + def _get_or_create_folder(self, current_folder_object: Folder, next_folder_name: str) -> Folder: """ Get or create a folder. @@ -238,9 +239,16 @@ def _get_or_create_folder(self, current_folder_object, next_folder_name): ------- next_folder_object: Folder Next folder object. + + Raises + ------ + BoxFolderNotFoundError + If the folder does not exist and 'create_sub_folders' is set to False. 
""" for item in current_folder_object.get_items(): + # noinspection PyUnresolvedReferences if item.type == "folder" and item.name == next_folder_name: + # noinspection PyTypeChecker return item if self.create_sub_folders: @@ -251,13 +259,13 @@ def _get_or_create_folder(self, current_folder_object, next_folder_name): "to create required directory structure automatically." ) - def action(self): + def action(self) -> Optional[Folder]: """ Get folder action Returns ------- - folder: Folder + folder: Optional[Folder] Box Folder object as specified in Box SDK """ current_folder_object = None @@ -267,7 +275,9 @@ def action(self): if self.path: cleaned_path_parts = [p for p in PurePath(self.path).parts if p.strip() not in [None, "", " ", "/"]] - current_folder_object = self.client.folder(folder_id=self.root) if isinstance(self.root, str) else self.root + current_folder_object: Union[Folder, str] = ( + self.client.folder(folder_id=self.root) if isinstance(self.root, str) else self.root + ) for next_folder_name in cleaned_path_parts: current_folder_object = self._get_or_create_folder(current_folder_object, next_folder_name) @@ -295,7 +305,7 @@ class BoxFolderCreate(BoxFolderGet): ) @field_validator("folder") - def validate_folder(cls, folder): + def validate_folder(cls, folder: Any) -> None: """ Validate 'folder' parameter """ @@ -322,7 +332,7 @@ class BoxFolderDelete(BoxFolderBase): ``` """ - def action(self): + def action(self) -> None: """ Delete folder action @@ -345,7 +355,7 @@ class BoxReaderBase(Box, Reader, ABC): """ schema_: Optional[StructType] = Field( - None, + default=None, alias="schema", description="[Optional] Schema that will be applied during the creation of Spark DataFrame", ) @@ -354,15 +364,6 @@ class BoxReaderBase(Box, Reader, ABC): description="[Optional] Set of extra parameters that should be passed to the Spark reader.", ) - class Output(StepOutput): - """Make default reader output optional to gracefully handle 'no-files / folder' cases.""" - - df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame") - - @abstractmethod - def execute(self) -> Output: - raise NotImplementedError - class BoxCsvFileReader(BoxReaderBase): """ @@ -397,7 +398,7 @@ class BoxCsvFileReader(BoxReaderBase): file: Union[str, list[str]] = Field(default=..., description="ID or list of IDs for the files to read.") - def execute(self): + def execute(self) -> BoxReaderBase.Output: """ Loop through the list of provided file identifiers and load data into dataframe. 
For traceability purposes the following columns will be added to the dataframe: @@ -412,9 +413,14 @@ def execute(self): for f in self.file: self.log.debug(f"Reading contents of file with the ID '{f}' into Spark DataFrame") file = self.client.file(file_id=f) - data = file.content().decode("utf-8").splitlines() - rdd = self.spark.sparkContext.parallelize(data) - temp_df = self.spark.read.csv(rdd, header=True, schema=self.schema_, **self.params) + data = file.content().decode("utf-8") + + data_buffer = StringIO(data) + temp_df_pandas = pd.read_csv(data_buffer, header=0, dtype=str if not self.schema_ else None, **self.params) # type: ignore + temp_df = self.spark.createDataFrame(temp_df_pandas, schema=self.schema_) + + # type: ignore + # noinspection PyUnresolvedReferences temp_df = ( temp_df # fmt: off @@ -456,9 +462,9 @@ class BoxCsvPathReader(BoxReaderBase): """ path: str = Field(default=..., description="Box path") - filter: Optional[str] = Field(default=r".csv|.txt$", description="[Optional] Regexp to filter folder contents") + filter: str = Field(default=r".csv|.txt$", description="[Optional] Regexp to filter folder contents") - def execute(self): + def execute(self) -> BoxReaderBase.Output: """ Identify the list of files from the source Box path that match desired filter and load them into Dataframe """ @@ -507,13 +513,13 @@ class BoxFileBase(Box): ) path: Optional[str] = Field(default=None, description="Path to the Box folder, for example: `folder/sub-folder/lz") - def action(self, file: File, folder: Folder): + def action(self, file: File, folder: Folder) -> None: """ Abstract class for File level actions. """ raise NotImplementedError - def execute(self): + def execute(self) -> Box.Output: """ Generic execute method for all BoxToBox interactions. 
Deals with getting the correct folder and file objects from various parameter inputs @@ -547,20 +553,20 @@ class BoxToBoxFileCopy(BoxFileBase): ``` """ - def action(self, file: File, folder: Folder): + def action(self, file: File, folder: Folder) -> None: """ Copy file to the desired destination and extend file description with the processing info Parameters ---------- - file: File + file : File File object as specified in Box SDK - folder: Folder + folder : Folder Folder object as specified in Box SDK """ self.log.info(f"Copying '{file.get()}' to '{folder.get()}'...") file.copy(parent_folder=folder).update_info( - data={"description": "\n".join([f"File processed on {datetime.utcnow()}", file.get()["description"]])} + data={"description": "\n".join([f"File processed on {utc_now()}", file.get()["description"]])} ) @@ -583,20 +589,20 @@ class BoxToBoxFileMove(BoxFileBase): ``` """ - def action(self, file: File, folder: Folder): + def action(self, file: File, folder: Folder) -> None: """ Move file to the desired destination and extend file description with the processing info Parameters ---------- - file: File + file : File File object as specified in Box SDK - folder: Folder + folder : Folder Folder object as specified in Box SDK """ self.log.info(f"Moving '{file.get()}' to '{folder.get()}'...") file.move(parent_folder=folder).update_info( - data={"description": "\n".join([f"File processed on {datetime.utcnow()}", file.get()["description"]])} + data={"description": "\n".join([f"File processed on {utc_now()}", file.get()["description"]])} ) @@ -638,7 +644,7 @@ class Output(StepOutput): shared_link: str = Field(default=..., description="Shared link for the Box file") @model_validator(mode="before") - def validate_name_for_binary_data(cls, values): + def validate_name_for_binary_data(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate 'file_name' parameter when providing a binary input for 'file'.""" file, file_name = values.get("file"), values.get("file_name") if not isinstance(file, str) and not file_name: @@ -646,7 +652,7 @@ def validate_name_for_binary_data(cls, values): return values - def action(self): + def action(self) -> None: _file = self.file _name = self.file_name @@ -658,6 +664,7 @@ def action(self): folder: Folder = BoxFolderGet.from_step(self, create_sub_folders=True).execute().folder folder.preflight_check(size=0, name=_name) + # noinspection PyUnresolvedReferences self.log.info(f"Uploading file '{_name}' to Box folder '{folder.get().name}'...") _box_file: File = folder.upload_stream(file_stream=_file, file_name=_name, file_description=self.description) diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py new file mode 100644 index 00000000..d461ffff --- /dev/null +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -0,0 +1,551 @@ +# noinspection PyUnresolvedReferences +""" +Snowflake steps and tasks for Koheesio + +Every class in this module is a subclass of `Step` or `Task` and is used to perform operations on Snowflake. + +Notes +----- +Every Step in this module is based on [SnowflakeBaseModel](./snowflake.md#koheesio.spark.snowflake.SnowflakeBaseModel). +The following parameters are available for every Step. + +Parameters +---------- +url : str + Hostname for the Snowflake account, e.g. .snowflakecomputing.com. + Alias for `sfURL`. +user : str + Login name for the Snowflake user. + Alias for `sfUser`. +password : SecretStr + Password for the Snowflake user. + Alias for `sfPassword`. 
+database : str + The database to use for the session after connecting. + Alias for `sfDatabase`. +sfSchema : str + The schema to use for the session after connecting. + Alias for `schema` ("schema" is a reserved name in Pydantic, so we use `sfSchema` as main name instead). +role : str + The default security role to use for the session after connecting. + Alias for `sfRole`. +warehouse : str + The default virtual warehouse to use for the session after connecting. + Alias for `sfWarehouse`. +authenticator : Optional[str], optional, default=None + Authenticator for the Snowflake user. Example: "okta.com". +options : Optional[Dict[str, Any]], optional, default={"sfCompress": "on", "continue_on_error": "off"} + Extra options to pass to the Snowflake connector. +format : str, optional, default="snowflake" + The default `snowflake` format can be used natively in Databricks, use `net.snowflake.spark.snowflake` in other + environments and make sure to install required JARs. +""" + +from __future__ import annotations + +from typing import Any, Dict, Generator, List, Optional, Set, Union +from abc import ABC +from contextlib import contextmanager +from types import ModuleType + +from koheesio import Step +from koheesio.logger import warn +from koheesio.models import ( + BaseModel, + ExtraParamsMixin, + Field, + PrivateAttr, + SecretStr, + conlist, + field_validator, + model_validator, +) + +__all__ = [ + "GrantPrivilegesOnFullyQualifiedObject", + "GrantPrivilegesOnObject", + "GrantPrivilegesOnTable", + "GrantPrivilegesOnView", + "SnowflakeRunQueryPython", + "SnowflakeBaseModel", + "SnowflakeStep", + "SnowflakeTableStep", + "safe_import_snowflake_connector", +] + +# pylint: disable=inconsistent-mro, too-many-lines +# Turning off inconsistent-mro because we are using ABCs and Pydantic models and Tasks together in the same class +# Turning off too-many-lines because we are defining a lot of classes in this file + + +def safe_import_snowflake_connector() -> Optional[ModuleType]: + """Validate that the Snowflake connector is installed + + Returns + ------- + Optional[ModuleType] + The Snowflake connector module if it is installed, otherwise None + """ + try: + from snowflake import connector as snowflake_connector + + return snowflake_connector + except (ImportError, ModuleNotFoundError): + warn( + "You need to have the `snowflake-connector-python` package installed to use the Snowflake steps that are" + "based around SnowflakeRunQueryPython. You can install this in Koheesio by adding `koheesio[snowflake]` to " + "your package dependencies.", + UserWarning, + ) + return None + + +class SnowflakeBaseModel(BaseModel, ExtraParamsMixin, ABC): # type: ignore[misc] + """ + BaseModel for setting up Snowflake Driver options. + + Notes + ----- + * Snowflake is supported natively in Databricks 4.2 and newer: + https://docs.snowflake.com/en/user-guide/spark-connector-databricks + * Refer to Snowflake docs for the installation instructions for non-Databricks environments: + https://docs.snowflake.com/en/user-guide/spark-connector-install + * Refer to Snowflake docs for connection options: + https://docs.snowflake.com/en/user-guide/spark-connector-use#setting-configuration-options-for-the-connector + + Parameters + ---------- + url : str + Hostname for the Snowflake account, e.g. .snowflakecomputing.com. + Alias for `sfURL`. + user : str + Login name for the Snowflake user. + Alias for `sfUser`. + password : SecretStr + Password for the Snowflake user. + Alias for `sfPassword`. 
+ role : str + The default security role to use for the session after connecting. + Alias for `sfRole`. + warehouse : str + The default virtual warehouse to use for the session after connecting. + Alias for `sfWarehouse`. + authenticator : Optional[str], optional, default=None + Authenticator for the Snowflake user. Example: "okta.com". + database : Optional[str], optional, default=None + The database to use for the session after connecting. + Alias for `sfDatabase`. + sfSchema : Optional[str], optional, default=None + The schema to use for the session after connecting. + Alias for `schema` ("schema" is a reserved name in Pydantic, so we use `sfSchema` as main name instead). + options : Optional[Dict[str, Any]], optional, default={"sfCompress": "on", "continue_on_error": "off"} + Extra options to pass to the Snowflake connector. + """ + + url: str = Field( + default=..., + alias="sfURL", + description="Hostname for the Snowflake account, e.g. .snowflakecomputing.com", + examples=["example.snowflakecomputing.com"], + ) + user: str = Field(default=..., alias="sfUser", description="Login name for the Snowflake user") + password: SecretStr = Field(default=..., alias="sfPassword", description="Password for the Snowflake user") + role: str = Field( + default=..., alias="sfRole", description="The default security role to use for the session after connecting" + ) + warehouse: str = Field( + default=..., + alias="sfWarehouse", + description="The default virtual warehouse to use for the session after connecting", + ) + authenticator: Optional[str] = Field( + default=None, + description="Authenticator for the Snowflake user", + examples=["okta.com"], + ) + database: Optional[str] = Field( + default=None, alias="sfDatabase", description="The database to use for the session after connecting" + ) + sfSchema: Optional[str] = Field( + default=..., alias="schema", description="The schema to use for the session after connecting" + ) + options: Optional[Dict[str, Any]] = Field( + default={"sfCompress": "on", "continue_on_error": "off"}, + description="Extra options to pass to the Snowflake connector", + ) + + def get_options(self, by_alias: bool = True, include: Optional[Set[str]] = None) -> Dict[str, Any]: + """Get the sfOptions as a dictionary. + + Note + ---- + - Any parameters that are `None` are excluded from the output dictionary. + - `sfSchema` and `password` are handled separately. + - The values from both 'options' and 'params' (kwargs / extra params) are included as is. + - Koheesio specific fields are excluded by default (i.e. `name`, `description`, `format`). + + Parameters + ---------- + by_alias : bool, optional, default=True + Whether to use the alias names or not. E.g. `sfURL` instead of `url` + include : Optional[Set[str]], optional, default=None + Set of keys to include in the output dictionary. When None is provided, all fields will be returned. + Note: be sure to include all the keys you need. 
+ """ + exclude_set = { + # Exclude koheesio specific fields + "name", + "description", + # options and params are separately implemented + "params", + "options", + # schema and password have to be handled separately + "sfSchema", + "password", + } - (include or set()) + + fields = self.model_dump( + by_alias=by_alias, + exclude_none=True, + exclude=exclude_set, + ) + + # handle schema and password + fields.update( + { + "sfSchema" if by_alias else "schema": self.sfSchema, + "sfPassword" if by_alias else "password": self.password.get_secret_value(), + } + ) + + # handle include + if include: + # user specified filter + fields = {key: value for key, value in fields.items() if key in include} + else: + # default filter + include = {"options", "params"} + + # handle options + if "options" in include: + options = fields.pop("options", self.options) + fields.update(**options) + + # handle params + if "params" in include: + params = fields.pop("params", self.params) + fields.update(**params) + + return {key: value for key, value in fields.items() if value} + + +class SnowflakeStep(SnowflakeBaseModel, Step, ABC): + """Expands the SnowflakeBaseModel so that it can be used as a Step""" + + +class SnowflakeTableStep(SnowflakeStep, ABC): + """Expands the SnowflakeStep, adding a 'table' parameter""" + + table: str = Field(default=..., description="The name of the table") + + @property + def full_name(self) -> str: + """ + Returns the fullname of snowflake table based on schema and database parameters. + + Returns + ------- + str + Snowflake Complete table name (database.schema.table) + """ + return f"{self.database}.{self.sfSchema}.{self.table}" + + +class SnowflakeRunQueryPython(SnowflakeStep): + """ + Run a query on Snowflake using the Python connector + + Example + ------- + ```python + RunQueryPython( + database="MY_DB", + schema="MY_SCHEMA", + warehouse="MY_WH", + user="account", + password="***", + role="APPLICATION.SNOWFLAKE.ADMIN", + query="CREATE TABLE test (col1 string)", + ).execute() + ``` + """ + + query: str = Field(default=..., description="The query to run", alias="sql", serialization_alias="query") + account: str = Field(default=..., description="Snowflake Account Name", alias="account") + + # for internal use + _snowflake_connector: Optional[ModuleType] = PrivateAttr(default_factory=safe_import_snowflake_connector) + + class Output(SnowflakeStep.Output): + """Output class for RunQueryPython""" + + results: List = Field(default_factory=list, description="The results of the query") + + @field_validator("query") + def validate_query(cls, query: str) -> str: + """Replace escape characters, strip whitespace, ensure it is not empty""" + query = query.replace("\\n", "\n").replace("\\t", "\t").strip() + if not query: + raise ValueError("Query cannot be empty") + return query + + def get_options(self, by_alias: bool = False, include: Optional[Set[str]] = None) -> Dict[str, Any]: + if include is None: + include = { + "account", + "url", + "authenticator", + "user", + "role", + "warehouse", + "database", + "schema", + "password", + } + return super().get_options(by_alias=by_alias, include=include) + + @property + @contextmanager + def conn(self) -> Generator: + if not self._snowflake_connector: + raise RuntimeError("Snowflake connector is not installed. 
Please install `snowflake-connector-python`.") + + sf_options = self.get_options() + _conn = self._snowflake_connector.connect(**sf_options) + self.log.info(f"Connected to Snowflake account: {sf_options['account']}") + + try: + yield _conn + finally: + if _conn: + _conn.close() + + def get_query(self) -> str: + """allows to customize the query""" + return self.query + + def execute(self) -> None: + """Execute the query""" + with self.conn as conn: + cursors = conn.execute_string(self.get_query()) + for cursor in cursors: + self.log.debug(f"Cursor executed: {cursor}") + self.output.results.extend(cursor.fetchall()) + + +class GrantPrivilegesOnObject(SnowflakeRunQueryPython): + """ + A wrapper on Snowflake GRANT privileges + + With this Step, you can grant Snowflake privileges to a set of roles on a table, a view, or an object + + See Also + -------- + https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html + + Parameters + ---------- + account : str + Snowflake Account Name. + warehouse : str + The name of the warehouse. Alias for `sfWarehouse` + user : str + The username. Alias for `sfUser` + password : SecretStr + The password. Alias for `sfPassword` + role : str + The role name + object : str + The name of the object to grant privileges on + type : str + The type of object to grant privileges on, e.g. TABLE, VIEW + privileges : Union[conlist(str, min_length=1), str] + The Privilege/Permission or list of Privileges/Permissions to grant on the given object. + roles : Union[conlist(str, min_length=1), str] + The Role or list of Roles to grant the privileges to + + Example + ------- + ```python + GrantPermissionsOnTable( + object="MY_TABLE", + type="TABLE", + warehouse="MY_WH", + user="gid.account@nike.com", + password=Secret("super-secret-password"), + role="APPLICATION.SNOWFLAKE.ADMIN", + permissions=["SELECT", "INSERT"], + ).execute() + ``` + + In this example, the `APPLICATION.SNOWFLAKE.ADMIN` role will be granted `SELECT` and `INSERT` privileges on + the `MY_TABLE` table using the `MY_WH` warehouse. + """ + + object: str = Field(default=..., description="The name of the object to grant privileges on") + type: str = Field(default=..., description="The type of object to grant privileges on, e.g. TABLE, VIEW") + + privileges: Union[conlist(str, min_length=1), str] = Field( # type: ignore[valid-type] + default=..., + alias="permissions", + description="The Privilege/Permission or list of Privileges/Permissions to grant on the given object. 
" + "See https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html", + ) + roles: Union[conlist(str, min_length=1), str] = Field( # type: ignore[valid-type] + default=..., + alias="role", + validation_alias="roles", + description="The Role or list of Roles to grant the privileges to", + ) + query: str = "GRANT {privileges} ON {type} {object} TO ROLE {role}" + + class Output(SnowflakeRunQueryPython.Output): + """Output class for GrantPrivilegesOnObject""" + + query: conlist(str, min_length=1) = Field( # type: ignore[valid-type] + default=..., description="Query that was executed to grant privileges", validate_default=False + ) + + @model_validator(mode="before") + def set_roles_privileges(cls, values: dict) -> dict: + """Coerce roles and privileges to be lists if they are not already.""" + roles_value = values.get("roles") or values.get("role") + privileges_value = values.get("privileges") + + if not (roles_value and privileges_value): + raise ValueError("You have to specify roles AND privileges when using 'GrantPrivilegesOnObject'.") + + # coerce values to be lists + values["roles"] = [roles_value] if isinstance(roles_value, str) else roles_value + values["role"] = values["roles"][0] # hack to keep the validator happy + values["privileges"] = [privileges_value] if isinstance(privileges_value, str) else privileges_value + + return values + + @model_validator(mode="after") + def validate_object_and_object_type(self) -> "GrantPrivilegesOnObject": + """Validate that the object and type are set.""" + object_value = self.object + if not object_value: + raise ValueError("You must provide an `object`, this should be the name of the object. ") + + object_type = self.type + if not object_type: + raise ValueError( + "You must provide a `type`, e.g. TABLE, VIEW, DATABASE. " + "See https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html" + ) + + return self + + # noinspection PyMethodOverriding + def get_query(self, role: str) -> str: + """Build the GRANT query + + Parameters + ---------- + role: str + The role name + + Returns + ------- + query : str + The Query that performs the grant + """ + query = self.query.format( + privileges=",".join(self.privileges), + type=self.type, + object=self.object, + role=role, + ) + return query + + def execute(self) -> None: + self.output.query = [] + roles = self.roles + + for role in roles: + query = self.get_query(role) + self.output.query.append(query) + + # Create a new instance of SnowflakeRunQueryPython with the current query + instance = SnowflakeRunQueryPython.from_step(self, query=query) + instance.execute() + print(f"{instance.output = }") + self.output.results.extend(instance.output.results) + + +class GrantPrivilegesOnFullyQualifiedObject(GrantPrivilegesOnObject): + """Grant Snowflake privileges to a set of roles on a fully qualified object, i.e. `database.schema.object_name` + + This class is a subclass of `GrantPrivilegesOnObject` and is used to grant privileges on a fully qualified object. + The advantage of using this class is that it sets the object name to be fully qualified, i.e. + `database.schema.object_name`. + + Meaning, you can set the `database`, `schema` and `object` separately and the object name will be set to be fully + qualified, i.e. `database.schema.object_name`. + + Example + ------- + ```python + GrantPrivilegesOnFullyQualifiedObject( + database="MY_DB", + schema="MY_SCHEMA", + warehouse="MY_WH", + ... + object="MY_TABLE", + type="TABLE", + ... 
+ ) + ``` + + In this example, the object name will be set to be fully qualified, i.e. `MY_DB.MY_SCHEMA.MY_TABLE`. + If you were to use `GrantPrivilegesOnObject` instead, you would have to set the object name to be fully qualified + yourself. + """ + + @model_validator(mode="after") + def set_object_name(self) -> "GrantPrivilegesOnFullyQualifiedObject": + """Set the object name to be fully qualified, i.e. database.schema.object_name""" + # database, schema, obj_name + db = self.database + schema = self.model_dump()["sfSchema"] # since "schema" is a reserved name + obj_name = self.object + + self.object = f"{db}.{schema}.{obj_name}" + + return self + + +class GrantPrivilegesOnTable(GrantPrivilegesOnFullyQualifiedObject): + """Grant Snowflake privileges to a set of roles on a table""" + + type: str = "TABLE" + object: str = Field( + default=..., + alias="table", + description="The name of the Table to grant Privileges on. This should be just the name of the table; so " + "without Database and Schema, use sfDatabase/database and sfSchema/schema to set those instead.", + ) + + +class GrantPrivilegesOnView(GrantPrivilegesOnFullyQualifiedObject): + """Grant Snowflake privileges to a set of roles on a view""" + + type: str = "VIEW" + object: str = Field( + default=..., + alias="view", + description="The name of the View to grant Privileges on. This should be just the name of the view; so " + "without Database and Schema, use sfDatabase/database and sfSchema/schema to set those instead.", + ) diff --git a/src/koheesio/integrations/snowflake/test_utils.py b/src/koheesio/integrations/snowflake/test_utils.py new file mode 100644 index 00000000..8b85e97d --- /dev/null +++ b/src/koheesio/integrations/snowflake/test_utils.py @@ -0,0 +1,69 @@ +"""Module holding re-usable test utilities for Snowflake modules""" + +from typing import Generator +from unittest.mock import MagicMock, patch + +# safe import pytest fixture +try: + import pytest +except (ImportError, ModuleNotFoundError): + pytest = MagicMock() + + +@pytest.fixture(scope="function") +def mock_query() -> Generator: + """Mock the query execution for SnowflakeRunQueryPython + + This can be used to test the query execution without actually connecting to Snowflake. + + Example + ------- + ```python + def test_execute(self, mock_query): + # Arrange + query = "SELECT * FROM two_row_table" + mock_query.expected_data = [("row1",), ("row2",)] + + # Act + instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query=query, account="42") + instance.execute() + + # Assert + mock_query.assert_called_with(query) + assert instance.output.results == mock_query.expected_data + ``` + + In this example, we are using the mock_query fixture to test the execution of a query. + - We set the expected data to a known value by setting `mock_query.expected_data`, + - Then, we execute the query. + - We then assert that the query was called with the expected query by using `mock_query.assert_called_with` and + that the results are as expected. 
+ """ + with patch("koheesio.integrations.snowflake.SnowflakeRunQueryPython.conn", new_callable=MagicMock) as mock_conn: + mock_cursor = MagicMock() + mock_conn.__enter__.return_value.execute_string.return_value = [mock_cursor] + + class MockQuery: + def __init__(self) -> None: + self.mock_conn = mock_conn + self.mock_cursor = mock_cursor + self._expected_data: list = [] + + def assert_called_with(self, query: str) -> None: + self.mock_conn.__enter__.return_value.execute_string.assert_called_once_with(query) + self.mock_cursor.fetchall.return_value = self.expected_data + + @property + def expected_data(self) -> list: + return self._expected_data + + @expected_data.setter + def expected_data(self, data: list) -> None: + self._expected_data = data + self.set_expected_data() + + def set_expected_data(self) -> None: + self.mock_cursor.fetchall.return_value = self.expected_data + + mock_query_instance = MockQuery() + yield mock_query_instance diff --git a/src/koheesio/integrations/spark/dq/spark_expectations.py b/src/koheesio/integrations/spark/dq/spark_expectations.py index 8766a8e4..2f90b8f3 100644 --- a/src/koheesio/integrations/spark/dq/spark_expectations.py +++ b/src/koheesio/integrations/spark/dq/spark_expectations.py @@ -4,7 +4,10 @@ from typing import Any, Dict, Optional, Union +# noinspection PyUnresolvedReferences,PyPep8Naming from spark_expectations.config.user_config import Constants as user_config + +# noinspection PyUnresolvedReferences from spark_expectations.core.expectations import ( SparkExpectations, WrappedDataFrameWriter, @@ -13,8 +16,8 @@ from pydantic import Field import pyspark -from pyspark.sql import DataFrame +from koheesio.spark import DataFrame from koheesio.spark.transformations import Transformation from koheesio.spark.writers import BatchOutputMode diff --git a/src/koheesio/integrations/spark/sftp.py b/src/koheesio/integrations/spark/sftp.py index 672fdfd7..90812b8d 100644 --- a/src/koheesio/integrations/spark/sftp.py +++ b/src/koheesio/integrations/spark/sftp.py @@ -79,12 +79,12 @@ class SFTPWriteMode(str, Enum): UPDATE = "update" @classmethod - def from_string(cls, mode: str): + def from_string(cls, mode: str) -> "SFTPWriteMode": """Return the SFTPWriteMode for the given string.""" return cls[mode.upper()] @property - def write_mode(self): + def write_mode(self) -> str: """Return the write mode for the given SFTPWriteMode.""" if self in {SFTPWriteMode.OVERWRITE, SFTPWriteMode.BACKUP, SFTPWriteMode.EXCLUSIVE, SFTPWriteMode.UPDATE}: return "wb" # Overwrite, Backup, Exclusive, Update modes set the file to be written from the beginning @@ -148,7 +148,7 @@ class SFTPWriter(Writer): mode: SFTPWriteMode = Field( default=SFTPWriteMode.OVERWRITE, - description="Write mode: overwrite, append, ignore, exclusive, backup, or update." + SFTPWriteMode.__doc__, + description="Write mode: overwrite, append, ignore, exclusive, backup, or update." 
+ SFTPWriteMode.__doc__, # type: ignore ) # private attrs @@ -179,26 +179,26 @@ def validate_path_and_file_name(cls, data: dict) -> dict: return data @field_validator("host") - def validate_sftp_host(cls, v) -> str: + def validate_sftp_host(cls, host: str) -> str: """Validate the host""" # remove the sftp:// prefix if present - if v.startswith("sftp://"): - v = v.replace("sftp://", "") + if host.startswith("sftp://"): + host = host.replace("sftp://", "") # remove the trailing slash if present - if v.endswith("/"): - v = v[:-1] + if host.endswith("/"): + host = host[:-1] - return v + return host @property - def write_mode(self): + def write_mode(self) -> str: """Return the write mode for the given SFTPWriteMode.""" mode = SFTPWriteMode.from_string(self.mode) # Convert string to SFTPWriteMode return mode.write_mode @property - def transport(self): + def transport(self) -> Transport: """Return the transport for the SFTP connection. If it doesn't exist, create it. If the username and password are provided, use them to connect to the SFTP server. @@ -224,14 +224,14 @@ def client(self) -> SFTPClient: raise e return self.__client__ - def _close_client(self): + def _close_client(self) -> None: """Close the SFTP client and transport.""" if self.client: self.client.close() if self.transport: self.transport.close() - def write_file(self, file_path: str, buffer_output: InstanceOf[BufferWriter.Output]): + def write_file(self, file_path: str, buffer_output: InstanceOf[BufferWriter.Output]) -> None: """ Using Paramiko, write the data in the buffer to SFTP. """ @@ -292,7 +292,7 @@ def _handle_write_mode(self, file_path: str, buffer_output: InstanceOf[BufferWri # Then overwrite the file self.write_file(file_path, buffer_output) - def execute(self): + def execute(self) -> Writer.Output: buffer_output: InstanceOf[BufferWriter.Output] = self.buffer_writer.write(self.df) # write buffer to the SFTP server @@ -377,7 +377,7 @@ class SendCsvToSftp(PandasCsvBufferWriter, SFTPWriter): For more details on the CSV parameters, refer to the PandasCsvBufferWriter class documentation. """ - buffer_writer: PandasCsvBufferWriter = Field(default=None, validate_default=False) + buffer_writer: Optional[PandasCsvBufferWriter] = Field(default=None, validate_default=False) @model_validator(mode="after") def set_up_buffer_writer(self) -> "SendCsvToSftp": @@ -385,7 +385,7 @@ def set_up_buffer_writer(self) -> "SendCsvToSftp": self.buffer_writer = PandasCsvBufferWriter(**self.get_options(options_type="koheesio_pandas_buffer_writer")) return self - def execute(self): + def execute(self) -> SFTPWriter.Output: SFTPWriter.execute(self) @@ -459,7 +459,7 @@ class SendJsonToSftp(PandasJsonBufferWriter, SFTPWriter): For more details on the JSON parameters, refer to the PandasJsonBufferWriter class documentation. 
""" - buffer_writer: PandasJsonBufferWriter = Field(default=None, validate_default=False) + buffer_writer: Optional[PandasJsonBufferWriter] = Field(default=None, validate_default=False) @model_validator(mode="after") def set_up_buffer_writer(self) -> "SendJsonToSftp": @@ -469,5 +469,5 @@ def set_up_buffer_writer(self) -> "SendJsonToSftp": ) return self - def execute(self): + def execute(self) -> SFTPWriter.Output: SFTPWriter.execute(self) diff --git a/src/koheesio/integrations/spark/snowflake.py b/src/koheesio/integrations/spark/snowflake.py new file mode 100644 index 00000000..59686d93 --- /dev/null +++ b/src/koheesio/integrations/spark/snowflake.py @@ -0,0 +1,1106 @@ +# noinspection PyUnresolvedReferences +""" +Snowflake steps and tasks for Koheesio + +Every class in this module is a subclass of `Step` or `Task` and is used to perform operations on Snowflake. + +Notes +----- +Every Step in this module is based on [SnowflakeBaseModel](./snowflake.md#koheesio.integrations.snowflake.SnowflakeBaseModel). +The following parameters are available for every Step. + +Parameters +---------- +url : str + Hostname for the Snowflake account, e.g. .snowflakecomputing.com. + Alias for `sfURL`. +user : str + Login name for the Snowflake user. + Alias for `sfUser`. +password : SecretStr + Password for the Snowflake user. + Alias for `sfPassword`. +database : str + The database to use for the session after connecting. + Alias for `sfDatabase`. +sfSchema : str + The schema to use for the session after connecting. + Alias for `schema` ("schema" is a reserved name in Pydantic, so we use `sfSchema` as main name instead). +role : str + The default security role to use for the session after connecting. + Alias for `sfRole`. +warehouse : str + The default virtual warehouse to use for the session after connecting. + Alias for `sfWarehouse`. +authenticator : Optional[str], optional, default=None + Authenticator for the Snowflake user. Example: "okta.com". +options : Optional[Dict[str, Any]], optional, default={"sfCompress": "on", "continue_on_error": "off"} + Extra options to pass to the Snowflake connector. +format : str, optional, default="snowflake" + The default `snowflake` format can be used natively in Databricks, use `net.snowflake.spark.snowflake` in other + environments and make sure to install required JARs. 
+""" + +from __future__ import annotations + +import json +from typing import Any, Callable, Dict, List, Optional, Set, Union +from abc import ABC +from copy import deepcopy +from textwrap import dedent + +from pyspark.sql import Window +from pyspark.sql import functions as f +from pyspark.sql import types as t + +from koheesio import Step, StepOutput +from koheesio.integrations.snowflake import * +from koheesio.logger import LoggingFactory, warn +from koheesio.models import ExtraParamsMixin, Field, field_validator, model_validator +from koheesio.spark import DataFrame, DataType, SparkStep +from koheesio.spark.delta import DeltaTableStep +from koheesio.spark.readers.delta import DeltaTableReader, DeltaTableStreamReader +from koheesio.spark.readers.jdbc import JdbcReader +from koheesio.spark.transformations import Transformation +from koheesio.spark.writers import BatchOutputMode, Writer +from koheesio.spark.writers.stream import ( + ForEachBatchStreamWriter, + writer_to_foreachbatch, +) + +__all__ = [ + "AddColumn", + "CreateOrReplaceTableFromDataFrame", + "DbTableQuery", + "GetTableSchema", + "GrantPrivilegesOnFullyQualifiedObject", + "GrantPrivilegesOnObject", + "GrantPrivilegesOnTable", + "GrantPrivilegesOnView", + "Query", + "RunQuery", + "SnowflakeBaseModel", + "SnowflakeReader", + "SnowflakeStep", + "SnowflakeTableStep", + "SnowflakeTransformation", + "SnowflakeWriter", + "SyncTableAndDataFrameSchema", + "SynchronizeDeltaToSnowflakeTask", + "TableExists", + "TagSnowflakeQuery", + "map_spark_type", +] + +# pylint: disable=inconsistent-mro, too-many-lines +# Turning off inconsistent-mro because we are using ABCs and Pydantic models and Tasks together in the same class +# Turning off too-many-lines because we are defining a lot of classes in this file + + +def map_spark_type(spark_type: t.DataType) -> str: + """ + Translates Spark DataFrame Schema type to SnowFlake type + + | Basic Types | Snowflake Type | + |-------------------|----------------| + | StringType | STRING | + | NullType | STRING | + | BooleanType | BOOLEAN | + + | Numeric Types | Snowflake Type | + |-------------------|----------------| + | LongType | BIGINT | + | IntegerType | INT | + | ShortType | SMALLINT | + | DoubleType | DOUBLE | + | FloatType | FLOAT | + | NumericType | FLOAT | + | ByteType | BINARY | + + | Date / Time Types | Snowflake Type | + |-------------------|----------------| + | DateType | DATE | + | TimestampType | TIMESTAMP | + + | Advanced Types | Snowflake Type | + |-------------------|----------------| + | DecimalType | DECIMAL | + | MapType | VARIANT | + | ArrayType | VARIANT | + | StructType | VARIANT | + + References + ---------- + - Spark SQL DataTypes: https://spark.apache.org/docs/latest/sql-ref-datatypes.html + - Snowflake DataTypes: https://docs.snowflake.com/en/sql-reference/data-types.html + + Parameters + ---------- + spark_type : pyspark.sql.types.DataType + DataType taken out of the StructField + + Returns + ------- + str + The Snowflake data type + """ + # StructField means that the entire Field was passed, we need to extract just the dataType before continuing + if isinstance(spark_type, t.StructField): + spark_type = spark_type.dataType + + # Check if the type is DayTimeIntervalType + if isinstance(spark_type, t.DayTimeIntervalType): + warn( + "DayTimeIntervalType is being converted to STRING. " + "Consider converting to a more supported date/time/timestamp type in Snowflake." 
+ ) + + # fmt: off + # noinspection PyUnresolvedReferences + data_type_map = { + # Basic Types + t.StringType: "STRING", + t.NullType: "STRING", + t.BooleanType: "BOOLEAN", + + # Numeric Types + t.LongType: "BIGINT", + t.IntegerType: "INT", + t.ShortType: "SMALLINT", + t.DoubleType: "DOUBLE", + t.FloatType: "FLOAT", + t.NumericType: "FLOAT", + t.ByteType: "BINARY", + t.BinaryType: "VARBINARY", + + # Date / Time Types + t.DateType: "DATE", + t.TimestampType: "TIMESTAMP", + t.DayTimeIntervalType: "STRING", + + # Advanced Types + t.DecimalType: + f"DECIMAL({spark_type.precision},{spark_type.scale})" # pylint: disable=no-member + if isinstance(spark_type, t.DecimalType) else "DECIMAL(38,0)", + t.MapType: "VARIANT", + t.ArrayType: "VARIANT", + t.StructType: "VARIANT", + } + return data_type_map.get(type(spark_type), 'STRING') + # fmt: on + + +class SnowflakeSparkStep(SparkStep, SnowflakeBaseModel, ABC): + """Expands the SnowflakeBaseModel so that it can be used as a SparkStep""" + + +class SnowflakeReader(SnowflakeBaseModel, JdbcReader, SparkStep): + """ + Wrapper around JdbcReader for Snowflake. + + Example + ------- + ```python + sr = SnowflakeReader( + url="foo.snowflakecomputing.com", + user="YOUR_USERNAME", + password="***", + database="db", + schema="schema", + ) + df = sr.read() + ``` + + Notes + ----- + * Snowflake is supported natively in Databricks 4.2 and newer: + https://docs.snowflake.com/en/user-guide/spark-connector-databricks + * Refer to Snowflake docs for the installation instructions for non-Databricks environments: + https://docs.snowflake.com/en/user-guide/spark-connector-install + * Refer to Snowflake docs for connection options: + https://docs.snowflake.com/en/user-guide/spark-connector-use#setting-configuration-options-for-the-connector + """ + + format: str = Field(default="snowflake", description="The format to use when writing to Snowflake") + # overriding `driver` property of JdbcReader, because it is not required by Snowflake + driver: Optional[str] = None # type: ignore + + def execute(self) -> SparkStep.Output: + """Read from Snowflake""" + super().execute() + + +class SnowflakeTransformation(SnowflakeBaseModel, Transformation, ABC): + """Adds Snowflake parameters to the Transformation class""" + + +class RunQuery(SnowflakeSparkStep): + """ + Run a query on Snowflake that does not return a result, e.g. create table statement + + This is a wrapper around 'net.snowflake.spark.snowflake.Utils.runQuery' on the JVM + + Example + ------- + ```python + RunQuery( + database="MY_DB", + schema="MY_SCHEMA", + warehouse="MY_WH", + user="account", + password="***", + role="APPLICATION.SNOWFLAKE.ADMIN", + query="CREATE TABLE test (col1 string)", + ).execute() + ``` + """ + + query: str = Field(default=..., description="The query to run", alias="sql") + + @model_validator(mode="after") + def validate_spark_and_deprecate(self) -> RunQuery: + """If we do not have a spark session with a JVM, we can not use spark to run the query""" + warn( + "The RunQuery class is deprecated and will be removed in a future release. " + "Please use the Python connector for Snowflake instead.", + DeprecationWarning, + stacklevel=2, + ) + if not hasattr(self.spark, "_jvm"): + raise RuntimeError( + "Your Spark session does not have a JVM and cannot run Snowflake query using RunQuery implementation. " + "Please update your code to use python connector for Snowflake." 
+ ) + return self + + @field_validator("query") + def validate_query(cls, query: str) -> str: + """Replace escape characters, strip whitespace, ensure it is not empty""" + query = query.replace("\\n", "\n").replace("\\t", "\t").strip() + if not query: + raise ValueError("Query cannot be empty") + return query + + def execute(self) -> RunQuery.Output: + # Executing the RunQuery without `host` option raises the following error: + # An error occurred while calling z:net.snowflake.spark.snowflake.Utils.runQuery. + # : java.util.NoSuchElementException: key not found: host + options = self.get_options() + options["host"] = self.url + # noinspection PyProtectedMember + self.spark._jvm.net.snowflake.spark.snowflake.Utils.runQuery(self.get_options(), self.query) + + +class Query(SnowflakeReader): + """ + Query data from Snowflake and return the result as a DataFrame + + Example + ------- + ```python + Query( + database="MY_DB", + schema_="MY_SCHEMA", + warehouse="MY_WH", + user="gid.account@nike.com", + password=Secret("super-secret-password"), + role="APPLICATION.SNOWFLAKE.ADMIN", + query="SELECT * FROM MY_TABLE", + ).execute().df + ``` + """ + + query: str = Field(default=..., description="The query to run") + + @field_validator("query") + def validate_query(cls, query: str) -> str: + """Replace escape characters""" + query = query.replace("\\n", "\n").replace("\\t", "\t").strip() + return query + + def get_options(self, by_alias: bool = True, include: Set[str] = None) -> Dict[str, Any]: + """add query to options""" + options = super().get_options(by_alias) + options["query"] = self.query + return options + + +class DbTableQuery(SnowflakeReader): + """ + Read table from Snowflake using the `dbtable` option instead of `query` + + Example + ------- + ```python + DbTableQuery( + database="MY_DB", + schema_="MY_SCHEMA", + warehouse="MY_WH", + user="user", + password=Secret("super-secret-password"), + role="APPLICATION.SNOWFLAKE.ADMIN", + table="db.schema.table", + ).execute().df + ``` + """ + + dbtable: str = Field(default=..., alias="table", description="The name of the table") + + +class TableExists(SnowflakeTableStep): + """ + Check if the table exists in Snowflake by using INFORMATION_SCHEMA. + + Example + ------- + ```python + k = TableExists( + url="foo.snowflakecomputing.com", + user="YOUR_USERNAME", + password="***", + database="db", + schema="schema", + table="table", + ) + ``` + """ + + class Output(StepOutput): + """Output class for TableExists""" + + exists: bool = Field(default=..., description="Whether or not the table exists") + + def execute(self) -> Output: + query = ( + dedent( + # Force upper case, due to case-sensitivity of where clause + f""" + SELECT * + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_CATALOG = '{self.database}' + AND TABLE_SCHEMA = '{self.sfSchema}' + AND TABLE_TYPE = 'BASE TABLE' + AND upper(TABLE_NAME) = '{self.table.upper()}' + """ # nosec B608: hardcoded_sql_expressions + ) + .upper() + .strip() + ) + + self.log.debug(f"Query that was executed to check if the table exists:\n{query}") + + df = Query(**self.get_options(), query=query).read() + + exists = df.count() > 0 + self.log.info( + f"Table '{self.database}.{self.sfSchema}.{self.table}' {'exists' if exists else 'does not exist'}" + ) + self.output.exists = exists + + +class CreateOrReplaceTableFromDataFrame(SnowflakeTransformation): + """ + Create (or Replace) a Snowflake table which has the same schema as a Spark DataFrame + + Can be used as any Transformation. 
The DataFrame is however left unchanged, and only used for determining the + schema of the Snowflake Table that is to be created (or replaced). + + Example + ------- + ```python + CreateOrReplaceTableFromDataFrame( + database="MY_DB", + schema="MY_SCHEMA", + warehouse="MY_WH", + user="gid.account@nike.com", + password="super-secret-password", + role="APPLICATION.SNOWFLAKE.ADMIN", + table="MY_TABLE", + df=df, + ).execute() + ``` + + Or, as a Transformation: + ```python + CreateOrReplaceTableFromDataFrame( + ... + table="MY_TABLE", + ).transform(df) + ``` + + """ + + account: str = Field(default=..., description="The Snowflake account") + table: str = Field(default=..., alias="table_name", description="The name of the (new) table") + + class Output(SnowflakeTransformation.Output): + """Output class for CreateOrReplaceTableFromDataFrame""" + + input_schema: t.StructType = Field(default=..., description="The original schema from the input DataFrame") + snowflake_schema: str = Field( + default=..., description="Derived Snowflake table schema based on the input DataFrame" + ) + query: str = Field(default=..., description="Query that was executed to create the table") + + def execute(self) -> Output: + self.output.df = self.df + + input_schema = self.df.schema + self.output.input_schema = input_schema + + snowflake_schema = ", ".join([f"{c.name} {map_spark_type(c.dataType)}" for c in input_schema]) + self.output.snowflake_schema = snowflake_schema + + table_name = f"{self.database}.{self.sfSchema}.{self.table}" + query = f"CREATE OR REPLACE TABLE {table_name} ({snowflake_schema})" + self.output.query = query + + SnowflakeRunQueryPython(**self.get_options(), query=query).execute() + + +class GetTableSchema(SnowflakeStep): + """ + Get the schema from a Snowflake table as a Spark Schema + + Notes + ----- + * This Step will execute a `SELECT * FROM LIMIT 1` query to get the schema of the table. + * The schema will be stored in the `table_schema` attribute of the output. + * `table_schema` is used as the attribute name to avoid conflicts with the `schema` attribute of Pydantic's + BaseModel. 
+ + Example + ------- + ```python + schema = ( + GetTableSchema( + database="MY_DB", + schema_="MY_SCHEMA", + warehouse="MY_WH", + user="gid.account@nike.com", + password="super-secret-password", + role="APPLICATION.SNOWFLAKE.ADMIN", + table="MY_TABLE", + ) + .execute() + .table_schema + ) + ``` + """ + + table: str = Field(default=..., description="The Snowflake table name") + + class Output(StepOutput): + """Output class for GetTableSchema""" + + table_schema: t.StructType = Field(default=..., serialization_alias="schema", description="The Spark Schema") + + def execute(self) -> Output: + query = f"SELECT * FROM {self.table} LIMIT 1" # nosec B608: hardcoded_sql_expressions + df = Query(**self.get_options(), query=query).execute().df + self.output.table_schema = df.schema + + +class AddColumn(SnowflakeStep): + """ + Add an empty column to a Snowflake table with given name and DataType + + Example + ------- + ```python + AddColumn( + database="MY_DB", + schema_="MY_SCHEMA", + warehouse="MY_WH", + user="gid.account@nike.com", + password=Secret("super-secret-password"), + role="APPLICATION.SNOWFLAKE.ADMIN", + table="MY_TABLE", + col="MY_COL", + dataType=StringType(), + ).execute() + ``` + """ + + table: str = Field(default=..., description="The name of the Snowflake table") + column: str = Field(default=..., description="The name of the new column") + type: DataType = Field(default=..., description="The DataType represented as a Spark DataType") # type: ignore + account: str = Field(default=..., description="The Snowflake account") + + class Output(SnowflakeStep.Output): + """Output class for AddColumn""" + + query: str = Field(default=..., description="Query that was executed to add the column") + + def execute(self) -> Output: + query = f"ALTER TABLE {self.table} ADD COLUMN {self.column} {map_spark_type(self.type)}".upper() + self.output.query = query + SnowflakeRunQueryPython(**self.get_options(), query=query).execute() + + +class SyncTableAndDataFrameSchema(SnowflakeStep, SnowflakeTransformation): + """ + Sync the schema's of a Snowflake table and a DataFrame. This will add NULL columns for the columns that are not in + both and perform type casts where needed. + + The Snowflake table will take priority in case of type conflicts. 
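+
+    Example
+    -------
+    A minimal sketch of how this Step can be used (connection values are placeholders):
+    ```python
+    SyncTableAndDataFrameSchema(
+        database="MY_DB",
+        schema="MY_SCHEMA",
+        warehouse="MY_WH",
+        user="user",
+        password="super-secret-password",
+        role="APPLICATION.SNOWFLAKE.ADMIN",
+        table="MY_TABLE",
+        df=df,
+        dry_run=True,  # only log the schema differences, do not alter anything
+    ).execute()
+    ```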
+ """ + + df: DataFrame = Field(default=..., description="The Spark DataFrame") + table: str = Field(default=..., description="The table name") + dry_run: bool = Field(default=False, description="Only show schema differences, do not apply changes") + + class Output(SparkStep.Output): + """Output class for SyncTableAndDataFrameSchema""" + + original_df_schema: t.StructType = Field(default=..., description="Original DataFrame schema") + original_sf_schema: t.StructType = Field(default=..., description="Original Snowflake schema") + new_df_schema: t.StructType = Field(default=..., description="New DataFrame schema") + new_sf_schema: t.StructType = Field(default=..., description="New Snowflake schema") + sf_table_altered: bool = Field( + default=False, description="Flag to indicate whether Snowflake schema has been altered" + ) + + def execute(self) -> Output: + self.log.warning("Snowflake table will always take a priority in case of data type conflicts!") + + # spark side + df_schema = self.df.schema + self.output.original_df_schema = deepcopy(df_schema) # using deepcopy to avoid storing in place changes + df_cols = {c.name.lower() for c in df_schema} + + # snowflake side + _options = {**self.get_options(), "table": self.table} + sf_schema = GetTableSchema(**_options).execute().table_schema + self.output.original_sf_schema = sf_schema + sf_cols = {c.name.lower() for c in sf_schema} + + if self.dry_run: + # Display differences between Spark DataFrame and Snowflake schemas + # and provide dummy values that are expected as class outputs. + _sf_diff = df_cols - sf_cols + self.log.warning(f"Columns to be added to Snowflake table: {set(df_cols) - set(sf_cols)}") + _df_diff = sf_cols - df_cols + self.log.warning(f"Columns to be added to Spark DataFrame: {set(sf_cols) - set(df_cols)}") + + self.output.new_df_schema = t.StructType() + self.output.new_sf_schema = t.StructType() + self.output.df = self.df + self.output.sf_table_altered = False + + else: + # Add columns to SnowFlake table that exist in DataFrame + for df_column in df_schema: + if df_column.name.lower() not in sf_cols: + AddColumn( + **self.get_options(), + table=self.table, + column=df_column.name, + type=df_column.dataType, + ).execute() + self.output.sf_table_altered = True + + if self.output.sf_table_altered: + sf_schema = GetTableSchema(**self.get_options(), table=self.table).execute().table_schema + sf_cols = {c.name.lower() for c in sf_schema} + + self.output.new_sf_schema = sf_schema + + # Add NULL columns to the DataFrame if they exist in SnowFlake but not in the df + df = self.df + for sf_col in self.output.original_sf_schema: + sf_col_name = sf_col.name.lower() + if sf_col_name not in df_cols: + sf_col_type = sf_col.dataType + df = df.withColumn(sf_col_name, f.lit(None).cast(sf_col_type)) # type: ignore + + # Put DataFrame columns in the same order as the Snowflake table + df = df.select(*sf_cols) + + self.output.df = df + self.output.new_df_schema = df.schema + + +class SnowflakeWriter(SnowflakeBaseModel, Writer): + """Class for writing to Snowflake + + See Also + -------- + - [koheesio.spark.writers.Writer](writers/index.md#koheesio.spark.writers.Writer) + - [koheesio.spark.writers.BatchOutputMode](writers/index.md#koheesio.spark.writers.BatchOutputMode) + - [koheesio.spark.writers.StreamingOutputMode](writers/index.md#koheesio.spark.writers.StreamingOutputMode) + """ + + table: str = Field(default=..., description="Target table name") + insert_type: Optional[BatchOutputMode] = Field( + BatchOutputMode.APPEND, alias="mode", 
description="The insertion type, append or overwrite" + ) + format: str = Field("snowflake", description="The format to use when writing to Snowflake") + + def execute(self) -> SnowflakeWriter.Output: + """Write to Snowflake""" + self.log.debug(f"writing to {self.table} with mode {self.insert_type}") + self.df.write.format(self.format).options(**self.get_options()).option("dbtable", self.table).mode( + self.insert_type + ).save() + + +class SynchronizeDeltaToSnowflakeTask(SnowflakeSparkStep): + """ + Synchronize a Delta table to a Snowflake table + + * Overwrite - only in batch mode + * Append - supports batch and streaming mode + * Merge - only in streaming mode + + Example + ------- + ```python + SynchronizeDeltaToSnowflakeTask( + account="acme", + url="acme.snowflakecomputing.com", + user="admin", + role="ADMIN", + warehouse="SF_WAREHOUSE", + database="SF_DATABASE", + schema="SF_SCHEMA", + source_table=DeltaTableStep(...), + target_table="my_sf_table", + key_columns=[ + "id", + ], + streaming=False, + ).run() + ``` + """ + + source_table: DeltaTableStep = Field(default=..., description="Source delta table to synchronize") + target_table: str = Field(default=..., description="Target table in snowflake to synchronize to") + synchronisation_mode: BatchOutputMode = Field( + default=BatchOutputMode.MERGE, + description="Determines if synchronisation will 'overwrite' any existing table, 'append' new rows or " + "'merge' with existing rows.", + ) + checkpoint_location: Optional[str] = Field(default=None, description="Checkpoint location to use") + schema_tracking_location: Optional[str] = Field( + default=None, + description="Schema tracking location to use. " + "Info: https://docs.delta.io/latest/delta-streaming.html#-schema-tracking", + ) + staging_table_name: Optional[str] = Field( + default=None, alias="staging_table", description="Optional snowflake staging name", validate_default=False + ) + key_columns: Optional[List[str]] = Field( + default_factory=list, + description="Key columns on which merge statements will be MERGE statement will be applied.", + ) + streaming: bool = Field( + default=False, + description="Should synchronisation happen in streaming or in batch mode. Streaming is supported in 'APPEND' " + "and 'MERGE' mode. Batch is supported in 'OVERWRITE' and 'APPEND' mode.", + ) + persist_staging: bool = Field( + default=False, + description="In case of debugging, set `persist_staging` to True to retain the staging table for inspection " + "after synchronization.", + ) + enable_deletion: bool = Field( + default=False, + description="In case of merge synchronisation_mode add deletion statement in merge query.", + ) + account: Optional[str] = Field( + default=None, + description="The Snowflake account to connect to. " + "If not provided, the `truncate_table` and `drop_table` methods will fail.", + ) + + writer_: Optional[Union[ForEachBatchStreamWriter, SnowflakeWriter]] = None + + @field_validator("staging_table_name") + def _validate_staging_table(cls, staging_table_name: str) -> str: + """Validate the staging table name and return it if it's valid.""" + if "." in staging_table_name: + raise ValueError( + "Custom staging table must not contain '.', it is located in the same Schema as the target table." 
+ ) + return staging_table_name + + @model_validator(mode="before") + def _checkpoint_location_check(cls, values: Dict) -> Dict: + """Give a warning if checkpoint location is given but not expected and vice versa""" + streaming = values.get("streaming") + checkpoint_location = values.get("checkpoint_location") + log = LoggingFactory.get_logger(cls.__name__) + + if streaming is False and checkpoint_location is not None: + log.warning("checkpoint_location is provided but will be ignored in batch mode") + if streaming is True and checkpoint_location is None: + log.warning("checkpoint_location is not provided in streaming mode") + return values + + @model_validator(mode="before") + def _synch_mode_check(cls, values: Dict) -> Dict: + """Validate requirements for various synchronisation modes""" + streaming = values.get("streaming") + synchronisation_mode = values.get("synchronisation_mode") + key_columns = values.get("key_columns") + + allowed_output_modes = [BatchOutputMode.OVERWRITE, BatchOutputMode.MERGE, BatchOutputMode.APPEND] + + if synchronisation_mode not in allowed_output_modes: + raise ValueError( + f"Synchronisation mode should be one of {', '.join([m.value for m in allowed_output_modes])}" + ) + if synchronisation_mode == BatchOutputMode.OVERWRITE and streaming is True: + raise ValueError("Synchronisation mode can't be 'OVERWRITE' with streaming enabled") + if synchronisation_mode == BatchOutputMode.MERGE and streaming is False: + raise ValueError("Synchronisation mode can't be 'MERGE' with streaming disabled") + if synchronisation_mode == BatchOutputMode.MERGE and len(key_columns) < 1: # type: ignore + raise ValueError("MERGE synchronisation mode requires a list of PK columns in `key_columns`.") + + return values + + @property + def non_key_columns(self) -> List[str]: + """Columns of source table that aren't part of the (composite) primary key""" + lowercase_key_columns: Set[str] = {c.lower() for c in self.key_columns} # type: ignore + source_table_columns = self.source_table.columns + non_key_columns: List[str] = [c for c in source_table_columns if c.lower() not in lowercase_key_columns] # type: ignore + return non_key_columns + + @property + def staging_table(self) -> str: + """Intermediate table on snowflake where staging results are stored""" + if stg_tbl_name := self.staging_table_name: + return stg_tbl_name + + return f"{self.source_table.table}_stg" + + @property + def reader(self) -> Union[DeltaTableReader, DeltaTableStreamReader]: + """ + DeltaTable reader + + Returns: + -------- + DeltaTableReader + DeltaTableReader the will yield source delta table + """ + # Wrap in lambda functions to mimic lazy evaluation. 
+ # This ensures the Task doesn't fail if a config isn't provided for a reader/writer that isn't used anyway + map_mode_reader = { + BatchOutputMode.OVERWRITE: lambda: DeltaTableReader( + table=self.source_table, streaming=False, schema_tracking_location=self.schema_tracking_location + ), + BatchOutputMode.APPEND: lambda: DeltaTableReader( + table=self.source_table, + streaming=self.streaming, + schema_tracking_location=self.schema_tracking_location, + ), + BatchOutputMode.MERGE: lambda: DeltaTableStreamReader( + table=self.source_table, read_change_feed=True, schema_tracking_location=self.schema_tracking_location + ), + } + return map_mode_reader[self.synchronisation_mode]() + + def _get_writer(self) -> Union[SnowflakeWriter, ForEachBatchStreamWriter]: + """ + Writer to persist to snowflake + + Depending on configured options, this returns an SnowflakeWriter or ForEachBatchStreamWriter: + - OVERWRITE/APPEND mode yields SnowflakeWriter + - MERGE mode yields ForEachBatchStreamWriter + + Returns + ------- + ForEachBatchStreamWriter | SnowflakeWriter + The right writer for the configured options and mode + """ + # Wrap in lambda functions to mimic lazy evaluation. + # This ensures the Task doesn't fail if a config isn't provided for a reader/writer that isn't used anyway + map_mode_writer = { + (BatchOutputMode.OVERWRITE, False): lambda: SnowflakeWriter( + table=self.target_table, insert_type=BatchOutputMode.OVERWRITE, **self.get_options() + ), + (BatchOutputMode.APPEND, False): lambda: SnowflakeWriter( + table=self.target_table, insert_type=BatchOutputMode.APPEND, **self.get_options() + ), + (BatchOutputMode.APPEND, True): lambda: ForEachBatchStreamWriter( + checkpointLocation=self.checkpoint_location, + batch_function=writer_to_foreachbatch( + SnowflakeWriter(table=self.target_table, insert_type=BatchOutputMode.APPEND, **self.get_options()) + ), + ), + (BatchOutputMode.MERGE, True): lambda: ForEachBatchStreamWriter( + checkpointLocation=self.checkpoint_location, + batch_function=self._merge_batch_write_fn( + key_columns=self.key_columns, # type: ignore + non_key_columns=self.non_key_columns, + staging_table=self.staging_table, + ), + ), + } + return map_mode_writer[(self.synchronisation_mode, self.streaming)]() + + @property + def writer(self) -> Union[ForEachBatchStreamWriter, SnowflakeWriter]: + """ + Writer to persist to snowflake + + Depending on configured options, this returns an SnowflakeWriter or ForEachBatchStreamWriter: + - OVERWRITE/APPEND mode yields SnowflakeWriter + - MERGE mode yields ForEachBatchStreamWriter + + Returns + ------- + Union[ForEachBatchStreamWriter, SnowflakeWriter] + """ + # Cache 'writer' object in memory to ensure same object is used everywhere, this ensures access to underlying + # member objects such as active streaming queries (if any). 
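+        # Note: `writer_` is only assigned here, so every subsequent access returns the same writer instance.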
+ if not self.writer_: + self.writer_ = self._get_writer() + return self.writer_ + + def truncate_table(self, snowflake_table: str) -> None: + """Truncate a given snowflake table""" + truncate_query = f"""TRUNCATE TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions + query_executor = SnowflakeRunQueryPython( + **self.get_options(), + query=truncate_query, + ) + query_executor.execute() + + def drop_table(self, snowflake_table: str) -> None: + """Drop a given snowflake table""" + self.log.warning(f"Dropping table {snowflake_table} from snowflake") + drop_table_query = f"""DROP TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions + query_executor = SnowflakeRunQueryPython(**self.get_options(), query=drop_table_query) + query_executor.execute() + + def _merge_batch_write_fn(self, key_columns: List[str], non_key_columns: List[str], staging_table: str) -> Callable: + """Build a batch write function for merge mode""" + + # pylint: disable=unused-argument + # noinspection PyUnusedLocal,PyPep8Naming + def inner(dataframe: DataFrame, batchId: int): # type: ignore + self._build_staging_table(dataframe, key_columns, non_key_columns, staging_table) + self._merge_staging_table_into_target() + + # pylint: enable=unused-argument + return inner + + @staticmethod + def _compute_latest_changes_per_pk( + dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str] + ) -> DataFrame: + """Compute the latest changes per primary key""" + window_spec = Window.partitionBy(*key_columns).orderBy(f.col("_commit_version").desc()) + ranked_df = ( + dataframe.filter("_change_type != 'update_preimage'") + .withColumn("rank", f.rank().over(window_spec)) # type: ignore + .filter("rank = 1") + .select(*key_columns, *non_key_columns, "_change_type") # discard unused columns + .distinct() + ) + return ranked_df + + def _build_staging_table( + self, dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str], staging_table: str + ) -> None: + """Build snowflake staging table""" + ranked_df = self._compute_latest_changes_per_pk(dataframe, key_columns, non_key_columns) + batch_writer = SnowflakeWriter( + table=staging_table, df=ranked_df, insert_type=BatchOutputMode.APPEND, **self.get_options() + ) + batch_writer.execute() + + def _merge_staging_table_into_target(self) -> None: + """ + Merge snowflake staging table into final snowflake table + """ + merge_query = self._build_sf_merge_query( + target_table=self.target_table, + stage_table=self.staging_table, + pk_columns=[*(self.key_columns or [])], + non_pk_columns=self.non_key_columns, + enable_deletion=self.enable_deletion, + ) # type: ignore + + query_executor = RunQuery( + **self.get_options(), + query=merge_query, + ) + query_executor.execute() + + @staticmethod + def _build_sf_merge_query( + target_table: str, + stage_table: str, + pk_columns: List[str], + non_pk_columns: List[str], + enable_deletion: bool = False, + ) -> str: + """Build a CDF merge query string + + Parameters + ---------- + target_table: Table + Destination table to merge into + stage_table: Table + Temporary table containing updates to be executed + pk_columns: List[str] + Column names used to uniquely identify each row + non_pk_columns: List[str] + Non-key columns that may need to be inserted/updated + enable_deletion: bool + DELETE actions are synced. 
If set to False (default) then sync is non-destructive + + Returns + ------- + str + Query to be executed on the target database + """ + all_fields = [*pk_columns, *non_pk_columns] + key_join_string = " AND ".join(f"target.{k} = temp.{k}" for k in pk_columns) + columns_string = ", ".join(all_fields) + assignment_string = ", ".join(f"{k} = temp.{k}" for k in non_pk_columns) + values_string = ", ".join(f"temp.{k}" for k in all_fields) + + query = dedent( + f""" + MERGE INTO {target_table} target + USING {stage_table} temp ON {key_join_string} + WHEN MATCHED AND temp._change_type = 'update_postimage' + THEN UPDATE SET {assignment_string} + WHEN NOT MATCHED AND temp._change_type != 'delete' + THEN INSERT ({columns_string}) + VALUES ({values_string}) + {"WHEN MATCHED AND temp._change_type = 'delete' THEN DELETE" if enable_deletion else ""}""" + ).strip() # nosec B608: hardcoded_sql_expressions + + return query + + def extract(self) -> DataFrame: + """ + Extract source table + """ + if self.synchronisation_mode == BatchOutputMode.MERGE: + if not self.source_table.is_cdf_active: + raise RuntimeError( + f"Source table {self.source_table.table_name} does not have CDF enabled. " + f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " + f"Current properties = {self.source_table_properties}" + ) + + df = self.reader.read() + self.output.source_df = df + return df + + def load(self, df: DataFrame) -> DataFrame: + """Load source table into snowflake""" + if self.synchronisation_mode == BatchOutputMode.MERGE: + self.log.info(f"Truncating staging table {self.staging_table}") # type: ignore + self.truncate_table(self.staging_table) + self.writer.write(df) + self.output.target_df = df + return df + + def execute(self) -> SynchronizeDeltaToSnowflakeTask.Output: + # extract + df = self.extract() + self.output.source_df = df + + # synchronize + self.output.target_df = df + self.load(df) + if not self.persist_staging: + # If it's a streaming job, await for termination before dropping staging table + if self.streaming: + self.writer.await_termination() # type: ignore + self.drop_table(self.staging_table) + + +class TagSnowflakeQuery(Step, ExtraParamsMixin): + """ + Provides Snowflake query tag pre-action that can be used to easily find queries through SF history search + and further group them for debugging and cost tracking purposes. + + Takes in query tag attributes as kwargs and additional Snowflake options dict that can optionally contain + other set of pre-actions to be applied to a query, in that case existing pre-action aren't dropped, query tag + pre-action will be added to them. + + Passed Snowflake options dictionary is not modified in-place, instead anew dictionary containing updated pre-actions + is returned. + + Notes + ----- + See this article for explanation: https://select.dev/posts/snowflake-query-tags + + Arbitrary tags can be applied, such as team, dataset names, business capability, etc. + + Example + ------- + #### Using `options` parameter + ```python + query_tag = AddQueryTag( + options={"preactions": "ALTER SESSION"}, + task_name="cleanse_task", + pipeline_name="ingestion-pipeline", + etl_date="2022-01-01", + pipeline_execution_time="2022-01-01T00:00:00", + task_execution_time="2022-01-01T01:00:00", + environment="dev", + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", + ).execute().options + ``` + In this example, the query tag pre-action will be added to the Snowflake options. 
+ + #### Using `preactions` parameter + Instead of using `options` parameter, you can also use `preactions` parameter to provide existing preactions. + ```python + query_tag = AddQueryTag( + preactions="ALTER SESSION" + ... + ).execute().options + ``` + + The result will be the same as in the previous example. + + #### Using `get_options` method + The shorthand method `get_options` can be used to get the `options` dictionary. + ```python + query_tag = AddQueryTag(...).get_options() + ``` + """ + + options: Dict = Field( + default_factory=dict, description="Additional Snowflake options, optionally containing additional preactions" + ) + + preactions: Optional[str] = Field(default="", description="Existing preactions from Snowflake options") + + class Output(StepOutput): + """Output class for AddQueryTag""" + + options: Dict = Field(default=..., description="Snowflake options dictionary with added query tag preaction") + + def execute(self) -> Output: + """Add query tag preaction to Snowflake options""" + tag_json = json.dumps(self.extra_params, indent=4, sort_keys=True) + tag_preaction = f"ALTER SESSION SET QUERY_TAG = '{tag_json}';" + preactions = self.options.get("preactions", self.preactions) + # update options with new preactions + self.output.options = {**self.options, "preactions": f"{preactions}\n{tag_preaction}".strip()} + + def get_options(self) -> Dict: + """shorthand method to get the options dictionary + + Functionally equivalent to running `execute().options` + + Returns + ------- + Dict + Snowflake options dictionary with added query tag preaction + """ + return self.execute().options diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py index b3330dbe..992d9f19 100644 --- a/src/koheesio/integrations/spark/tableau/hyper.py +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -19,7 +19,6 @@ from pydantic import Field, conlist -from pyspark.sql import DataFrame from pyspark.sql.functions import col from pyspark.sql.types import ( BooleanType, @@ -36,9 +35,9 @@ TimestampType, ) -from koheesio.spark.readers import SparkStep +from koheesio.spark import DataFrame, SparkStep from koheesio.spark.transformations.cast_to_datatype import CastToDatatype -from koheesio.spark.utils import spark_minor_version +from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.steps import Step, StepOutput @@ -79,7 +78,7 @@ class HyperFileReader(HyperFile, SparkStep): default=..., description="Path to the Hyper file", examples=["PurePath(~/data/my-file.hyper)"] ) - def execute(self): + def execute(self) -> SparkStep.Output: type_mapping = { "date": StringType, "text": StringType, @@ -165,7 +164,7 @@ class Output(StepOutput): hyper_path: PurePath = Field(default=..., description="Path to created Hyper file") @property - def hyper_path(self) -> Connection: + def hyper_path(self) -> PurePath: """ Return full path to the Hyper file. 
""" @@ -176,11 +175,11 @@ def hyper_path(self) -> Connection: self.log.info(f"Destination file: {hyper_path}") return hyper_path - def write(self): + def write(self) -> Output: self.execute() @abstractmethod - def execute(self): + def execute(self) -> Output: pass @@ -225,7 +224,7 @@ class HyperFileListWriter(HyperFileWriter): data: conlist(List[Any], min_length=1) = Field(default=..., description="List of rows to write to the Hyper file") - def execute(self): + def execute(self) -> HyperFileWriter.Output: with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp: with Connection( endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE @@ -288,7 +287,7 @@ class HyperFileParquetWriter(HyperFileWriter): default=..., alias="files", description="One or multiple parquet files to write to the Hyper file" ) - def execute(self): + def execute(self) -> HyperFileWriter.Output: _file = [str(f) for f in self.file] array_files = "'" + "','".join(_file) + "'" @@ -352,7 +351,7 @@ def table_definition_column(column: StructField) -> TableDefinition.Column: # Handling the TimestampNTZType for Spark 3.4+ # Mapping both TimestampType and TimestampNTZType to NTZ type of Hyper - if spark_minor_version >= 3.4: + if SPARK_MINOR_VERSION >= 3.4: from pyspark.sql.types import TimestampNTZType type_mapping[TimestampNTZType()] = SqlType.timestamp @@ -362,7 +361,7 @@ def table_definition_column(column: StructField) -> TableDefinition.Column: type_mapping[TimestampType()] = SqlType.timestamp_tz if column.dataType in type_mapping: - sql_type = type_mapping[column.dataType]() + sql_type = type_mapping[column.dataType]() # type: ignore elif str(column.dataType).startswith("DecimalType"): # Tableau Hyper API limits the precision to 18 decimal places # noinspection PyUnresolvedReferences @@ -407,11 +406,11 @@ def clean_dataframe(self) -> DataFrame: # Handling the TimestampNTZType for Spark 3.4+ # Any TimestampType column will be cast to TimestampNTZType for compatibility with Tableau Hyper API - if spark_minor_version >= 3.4: + if SPARK_MINOR_VERSION >= 3.4: from pyspark.sql.types import TimestampNTZType for t_col in timestamp_cols: - _df = _df.withColumn(t_col, col(t_col).cast(TimestampNTZType())) + _df = _df.withColumn(t_col, col(t_col).cast(TimestampNTZType())) # type: ignore # Replace null and NaN values with 0 if len(integer_cols) > 0: @@ -436,14 +435,14 @@ def clean_dataframe(self) -> DataFrame: if d_col.dataType.precision > 18: # noinspection PyUnresolvedReferences _df = _df.withColumn( - d_col.name, col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)) + d_col.name, col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)) # type: ignore ) if len(decimal_col_names) > 0: _df = _df.na.fill(0.0, decimal_col_names) return _df - def write_parquet(self): + def write_parquet(self) -> List[PurePath]: _path = self.path.joinpath("parquet") ( self.clean_dataframe() @@ -465,7 +464,7 @@ def write_parquet(self): self.log.info("Parquet file created: %s", fp) return [fp] - def execute(self): + def execute(self) -> HyperFileWriter.Output: w = HyperFileParquetWriter( path=self.path, name=self.name, table_definition=self._table_definition, files=self.write_parquet() ) diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py index fc6f958f..7770f627 100644 --- a/src/koheesio/integrations/spark/tableau/server.py +++ b/src/koheesio/integrations/spark/tableau/server.py @@ -1,17 +1,17 @@ 
import os -from typing import ContextManager, Optional, Union +from typing import Any, ContextManager, Optional, Union from enum import Enum from pathlib import PurePath -import urllib3 +import urllib3 # type: ignore from tableauserverclient import ( DatasourceItem, - Pager, PersonalAccessTokenAuth, ProjectItem, - Server, TableauAuth, ) +from tableauserverclient.server.pager import Pager +from tableauserverclient.server.server import Server from pydantic import Field, SecretStr @@ -68,22 +68,22 @@ class TableauServer(Step): description="ID of the project on the Tableau server", ) - def __init__(self, **data): + def __init__(self, **data: Any) -> None: super().__init__(**data) - self.server = None + self.server: Optional[Server] = None @model_validator(mode="after") - def validate_project(cls, data: dict) -> dict: + def validate_project(self) -> "TableauServer": """Validate when project and project_id are provided at the same time.""" - project = data.get("project") - project_id = data.get("project_id") - if project and project_id: + if self.project and self.project_id: raise ValueError("Both 'project' and 'project_id' parameters cannot be provided at the same time.") - if not project and not project_id: + if not self.project_id and not self.project_id: raise ValueError("Either 'project' or 'project_id' parameters should be provided, none is set") + return self + @property def auth(self) -> ContextManager: """ @@ -101,8 +101,10 @@ def auth(self) -> ContextManager: ContextManager for TableauAuth or PersonalAccessTokenAuth authorization object """ # Suppress 'InsecureRequestWarning' + # noinspection PyUnresolvedReferences urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + tableau_auth: Union[TableauAuth, PersonalAccessTokenAuth] tableau_auth = TableauAuth(username=self.user, password=self.password.get_secret_value(), site_id=self.site_id) if self.token_name and self.token_value: @@ -174,7 +176,7 @@ def working_project(self) -> Union[ProjectItem, None]: self.log.info(f"\nWorking project identified:\n\tName: {lim_p[0].name}\n\tID: {lim_p[0].id}") return lim_p[0] - def execute(self): + def execute(self) -> None: raise NotImplementedError("Method `execute` must be implemented in the subclass.") @@ -208,7 +210,7 @@ class Output(StepOutput): default=..., description="DatasourceItem object representing the published datasource" ) - def execute(self): + def execute(self) -> None: # Ensure that the Hyper File exists if not os.path.isfile(self.hyper_path): raise FileNotFoundError(f"Hyper file not found at: {self.hyper_path.as_posix()}") @@ -219,7 +221,7 @@ def execute(self): self.log.debug(f"Create mode: {self.publish_mode}") datasource_item = self.server.datasources.publish( - datasource_item=DatasourceItem(project_id=self.working_project.id, name=self.datasource_name), + datasource_item=DatasourceItem(project_id=str(self.working_project.id), name=self.datasource_name), file=self.hyper_path.as_posix(), mode=self.publish_mode, ) @@ -227,5 +229,5 @@ def execute(self): self.output.datasource_item = datasource_item - def publish(self): + def publish(self) -> None: self.execute() diff --git a/src/koheesio/logger.py b/src/koheesio/logger.py index 9f00d363..cad22138 100644 --- a/src/koheesio/logger.py +++ b/src/koheesio/logger.py @@ -34,7 +34,7 @@ import os import sys from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar -from logging import Formatter, Logger, getLogger +from logging import Formatter, Logger, LogRecord, getLogger from uuid import uuid4 from 
warnings import warn @@ -108,7 +108,7 @@ def __get_validators__(cls) -> Generator: yield cls.validate @classmethod - def validate(cls, v: Any, _values): + def validate(cls, v: Any, _values: Any) -> Masked: """ Validate the input value and return an instance of the class. @@ -165,7 +165,7 @@ class LoggerIDFilter(logging.Filter): LOGGER_ID: str = str(uuid4()) - def filter(self, record): + def filter(self, record: LogRecord) -> bool: record.logger_id = LoggerIDFilter.LOGGER_ID return True @@ -240,11 +240,13 @@ def add_handlers(handlers: List[Tuple[str, Dict]]) -> None: handler_class: logging.Handler = import_class(handler_module_class) handler_level = handler_conf.pop("level") if "level" in handler_conf else "WARNING" # noinspection PyCallingNonCallable - handler = handler_class(**handler_conf) + handler = handler_class(**handler_conf) # type: ignore[operator] handler.setLevel(handler_level) handler.addFilter(LoggingFactory.LOGGER_FILTER) handler.setFormatter(LoggingFactory.LOGGER_FORMATTER) - LoggingFactory.LOGGER.addHandler(handler) + + if LoggingFactory.LOGGER: + LoggingFactory.LOGGER.addHandler(handler) @staticmethod def get_logger(name: str, inherit_from_koheesio: bool = False) -> Logger: diff --git a/src/koheesio/models/__init__.py b/src/koheesio/models/__init__.py index 5ab80f1a..dd0a8c8b 100644 --- a/src/koheesio/models/__init__.py +++ b/src/koheesio/models/__init__.py @@ -9,15 +9,38 @@ Transformation and Reader classes. """ -from typing import Annotated, Any, Dict, List, Optional, Union +from __future__ import annotations + from abc import ABC from functools import cached_property from pathlib import Path +from typing import Annotated, Any, Dict, List, Optional, Union # to ensure that koheesio.models is a drop in replacement for pydantic from pydantic import BaseModel as PydanticBaseModel -from pydantic import * # noqa +from pydantic import ( + BeforeValidator, + ConfigDict, + Field, + InstanceOf, + PositiveInt, + PrivateAttr, + SecretBytes, + SecretStr, + SkipValidation, + ValidationError, + conint, + conlist, + constr, + field_serializer, + field_validator, + model_validator, +) + +# noinspection PyProtectedMember from pydantic._internal._generics import PydanticGenericMetadata + +# noinspection PyProtectedMember from pydantic._internal._model_construction import ModelMetaclass from koheesio.context import Context @@ -28,15 +51,28 @@ "ExtraParamsMixin", "Field", "ListOfColumns", + # Directly from pydantic + "ConfigDict", + "InstanceOf", "ModelMetaclass", + "PositiveInt", + "PrivateAttr", "PydanticGenericMetadata", + "SecretBytes", + "SecretStr", + "SkipValidation", + "ValidationError", + "conint", + "conlist", + "constr", + "field_serializer", "field_validator", "model_validator", ] # pylint: disable=function-redefined -class BaseModel(PydanticBaseModel, ABC): +class BaseModel(PydanticBaseModel, ABC): # type: ignore[no-redef] """ Base model for all models. @@ -154,7 +190,7 @@ class Person(BaseModel): Koheesio specific configuration: ------------------------------- - Koheesio models are configured differently from Pydantic defaults. The following configuration is used: + Koheesio models are configured differently from Pydantic defaults. The configuration looks like this: 1. *extra="allow"*\n This setting allows for extra fields that are not specified in the model definition. If a field is present in @@ -175,8 +211,8 @@ class Person(BaseModel): This setting determines whether the model should be revalidated when the data is changed. 
If set to `True`, every time a field is assigned a new value, the entire model is validated again.\n Pydantic default is (also) `False`, which means that the model is not revalidated when the data is changed. - The default behavior of Pydantic is to validate the data when the model is created. In case the user changes - the data after the model is created, the model is _not_ revalidated. + By default, Pydantic validates the data when creating the model. If the user changes the data after creating + the model, it does _not_ revalidate the model. 5. *revalidate_instances="subclass-instances"*\n This setting determines whether to revalidate models during validation if the instance is a subclass of the @@ -222,7 +258,7 @@ class Person(BaseModel): description: Optional[str] = Field(default=None, description="Description of the Model") @model_validator(mode="after") - def _validate_name_and_description(self): + def _validate_name_and_description(self): # type: ignore[no-untyped-def] """ Validates the 'name' and 'description' of the Model according to the rules outlined in the class docstring. """ @@ -246,7 +282,7 @@ def log(self) -> Logger: return LoggingFactory.get_logger(name=self.__class__.__name__, inherit_from_koheesio=True) @classmethod - def from_basemodel(cls, basemodel: BaseModel, **kwargs) -> InstanceOf[BaseModel]: + def from_basemodel(cls, basemodel: BaseModel, **kwargs) -> InstanceOf[BaseModel]: # type: ignore[no-untyped-def] """Returns a new BaseModel instance based on the data of another BaseModel""" kwargs = {**basemodel.model_dump(), **kwargs} return cls(**kwargs) @@ -354,7 +390,7 @@ def from_yaml(cls, yaml_file_or_str: str) -> BaseModel: return cls.from_context(_context) @classmethod - def lazy(cls): + def lazy(cls): # type: ignore[no-untyped-def] """Constructs the model without doing validation Essentially an alias to BaseModel.construct() @@ -371,9 +407,7 @@ def __add__(self, other: Union[Dict, BaseModel]) -> BaseModel: ```python step_output_1 = StepOutput(foo="bar") step_output_2 = StepOutput(lorem="ipsum") - ( - step_output_1 + step_output_2 - ) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} + (step_output_1 + step_output_2) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters @@ -388,10 +422,10 @@ def __add__(self, other: Union[Dict, BaseModel]) -> BaseModel: """ return self.merge(other) - def __enter__(self): + def __enter__(self): # type: ignore[no-untyped-def] return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def] if exc_type is not None: # An exception occurred. We log it and raise it again. self.log.exception(f"An exception occurred: {exc_val}") @@ -401,7 +435,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.validate() return True - def __getitem__(self, name) -> Any: + def __getitem__(self, name) -> Any: # type: ignore[no-untyped-def] """Get Item dunder method for BaseModel Allows for subscriptable (`class[key]`) type of access to the data. 
@@ -425,7 +459,7 @@ def __getitem__(self, name) -> Any: """ return self.__getattribute__(name) - def __setitem__(self, key: str, value: Any): + def __setitem__(self, key: str, value: Any): # type: ignore[no-untyped-def] """Set Item dunder method for BaseModel Allows for subscribing / assigning to `class[key]` @@ -459,7 +493,7 @@ def hasattr(self, key: str) -> bool: """ return hasattr(self, key) - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Optional[Any] = None) -> Any: """Get an attribute of the model, but don't fail if not present Similar to dict.get() @@ -488,7 +522,7 @@ def get(self, key: str, default: Optional[Any] = None): return self.__getitem__(key) return default - def merge(self, other: Union[Dict, BaseModel]): + def merge(self, other: Union[Dict, BaseModel]) -> BaseModel: """Merge key,value map with self Functionally similar to adding two dicts together; like running `{**dict_a, **dict_b}`. @@ -497,9 +531,7 @@ def merge(self, other: Union[Dict, BaseModel]): -------- ```python step_output = StepOutput(foo="bar") - step_output.merge( - {"lorem": "ipsum"} - ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters @@ -515,7 +547,7 @@ def merge(self, other: Union[Dict, BaseModel]): return self - def set(self, key: str, value: Any): + def set(self, key: str, value: Any) -> None: """Allows for subscribing / assigning to `class[key]`. Examples @@ -552,7 +584,7 @@ def to_dict(self) -> Dict[str, Any]: """ return self.model_dump() - def to_json(self, pretty: bool = False): + def to_json(self, pretty: bool = False) -> str: """Converts the BaseModel instance to a JSON string BaseModel offloads the serialization and deserialization of the JSON string to Context class. Context uses @@ -646,19 +678,19 @@ class ExtraParamsMixin(PydanticBaseModel): params: Dict[str, Any] = Field(default_factory=dict) @cached_property - def extra_params(self) -> Dict[str, Any]: + def extra_params(self) -> Optional[Dict[str, Any]]: """Extract params (passed as arbitrary kwargs) from values and move them to params dict""" # noinspection PyUnresolvedReferences return self.model_extra @model_validator(mode="after") - def _move_extra_params_to_params(self): + def _move_extra_params_to_params(self): # type: ignore[no-untyped-def] """Move extra_params to params dict""" - self.params = {**self.params, **self.extra_params} + self.params = {**self.params, **self.extra_params} # type: ignore[assignment] return self -def _list_of_columns_validation(columns_value): +def _list_of_columns_validation(columns_value: Union[str, list]) -> list: """ Performs validation for ListOfColumns type. Will ensure that there are no duplicate columns, empty strings, etc. In case an individual column is passed, it will coerce it to a list. 
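 
     As a sketch of the intended behaviour (assuming a hypothetical `MyTransformation` model that annotates a
     field as `columns: ListOfColumns = []`):
 
     ```python
     step = MyTransformation(columns="my_column")
     step.columns  # -> ["my_column"]
     ```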
diff --git a/src/koheesio/models/reader.py b/src/koheesio/models/reader.py index c1227940..3f351923 100644 --- a/src/koheesio/models/reader.py +++ b/src/koheesio/models/reader.py @@ -2,8 +2,8 @@ Module for the BaseReader class """ -from typing import Optional, TypeVar from abc import ABC, abstractmethod +from typing import Optional, TypeVar from koheesio import Step diff --git a/src/koheesio/models/sql.py b/src/koheesio/models/sql.py index baa3bc29..f19bc967 100644 --- a/src/koheesio/models/sql.py +++ b/src/koheesio/models/sql.py @@ -1,8 +1,8 @@ """This module contains the base class for SQL steps.""" -from typing import Any, Dict, Optional, Union from abc import ABC from pathlib import Path +from typing import Any, Dict, Optional, Union from koheesio import Step from koheesio.models import ExtraParamsMixin, Field, model_validator @@ -11,7 +11,8 @@ class SqlBaseStep(Step, ExtraParamsMixin, ABC): """Base class for SQL steps - `params` are used as placeholders for templating. These are identified with ${placeholder} in the SQL script. + `params` are used as placeholders for templating. The substitutions are identified by braces ('{' and '}') and can + optionally contain a $-sign - e.g. `${placeholder}` or `{placeholder}`. Parameters ---------- @@ -28,12 +29,12 @@ class SqlBaseStep(Step, ExtraParamsMixin, ABC): sql: Optional[str] = Field(default=None, description="SQL script to apply") params: Dict[str, Any] = Field( default_factory=dict, - description="Placeholders (parameters) for templating. These are identified with ${placeholder} in the SQL " - "script. Note: any arbitrary kwargs passed to the class will be added to params.", + description="Placeholders (parameters) for templating. The substitutions are identified by braces ('{' and '}')" + "and can optionally contain a $-sign. Note: any arbitrary kwargs passed to the class will be added to params.", ) @model_validator(mode="after") - def _validate_sql_and_sql_path(self): + def _validate_sql_and_sql_path(self) -> "SqlBaseStep": """Validate the SQL and SQL path""" sql = self.sql sql_path = self.sql_path @@ -57,12 +58,17 @@ def _validate_sql_and_sql_path(self): return self @property - def query(self): + def query(self) -> str: """Returns the query while performing params replacement""" - query = self.sql - for key, value in self.params.items(): - query = query.replace(f"${{{key}}}", value) + if self.sql: + query = self.sql + + for key, value in self.params.items(): + query = query.replace(f"${{{key}}}", value) + + self.log.debug(f"Generated query: {query}") + else: + query = "" - self.log.debug(f"Generated query: {query}") return query diff --git a/src/koheesio/notifications/slack.py b/src/koheesio/notifications/slack.py index 25bcbb29..423b2f37 100644 --- a/src/koheesio/notifications/slack.py +++ b/src/koheesio/notifications/slack.py @@ -2,14 +2,15 @@ Classes to ease interaction with Slack """ +import datetime import json from typing import Any, Dict, Optional -from datetime import datetime from textwrap import dedent from koheesio.models import ConfigDict, Field from koheesio.notifications import NotificationSeverity from koheesio.steps.http import HttpPostStep +from koheesio.utils import utc_now class SlackNotification(HttpPostStep): @@ -34,7 +35,7 @@ class SlackNotification(HttpPostStep): channel: Optional[str] = Field(default=None, description="Slack channel id") headers: Optional[Dict[str, Any]] = {"Content-type": "application/json"} - def get_payload(self): + def get_payload(self) -> str: """ Generate payload with `Block Kit`. 
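The reworked `query` property above only substitutes the `${placeholder}` form; a small sketch of that replacement loop outside the class, with placeholder names chosen for illustration:

```python
# Sketch of the `${placeholder}` substitution performed by SqlBaseStep.query
sql = "SELECT * FROM ${table_name} WHERE load_date = '${run_date}'"
params = {"table_name": "my_schema.my_table", "run_date": "2024-01-01"}

query = sql
for key, value in params.items():
    query = query.replace(f"${{{key}}}", value)

assert query == "SELECT * FROM my_schema.my_table WHERE load_date = '2024-01-01'"
```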
More details: https://api.slack.com/block-kit @@ -56,11 +57,11 @@ def get_payload(self): } if self.channel: - payload["channel"] = self.channel + payload["channel"] = self.channel # type: ignore[assignment] return json.dumps(payload) - def execute(self): + def execute(self) -> None: """ Generate payload and send post request """ @@ -92,14 +93,14 @@ class SlackNotificationWithSeverity(SlackNotification): environment: str = Field(default=..., description="Environment description, e.g. dev / qa /prod") application: str = Field(default=..., description="Pipeline or application name") timestamp: datetime = Field( - default=datetime.utcnow(), + default_factory=utc_now, alias="execution_timestamp", description="Pipeline or application execution timestamp", ) model_config = ConfigDict(use_enum_values=False) - def get_payload_message(self): + def get_payload_message(self) -> str: """ Generate payload message based on the predefined set of parameters """ @@ -113,7 +114,7 @@ def get_payload_message(self): """ ) - def execute(self): + def execute(self) -> None: """ Generate payload and send post request """ diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py index a9d324aa..c753a8d4 100644 --- a/src/koheesio/pandas/__init__.py +++ b/src/koheesio/pandas/__init__.py @@ -6,12 +6,13 @@ from typing import Optional from abc import ABC +from types import ModuleType from koheesio import Step, StepOutput from koheesio.models import Field from koheesio.spark.utils import import_pandas_based_on_pyspark_version -pandas = import_pandas_based_on_pyspark_version() +pandas: ModuleType = import_pandas_based_on_pyspark_version() class PandasStep(Step, ABC): @@ -24,4 +25,4 @@ class PandasStep(Step, ABC): class Output(StepOutput): """Output class for PandasStep""" - df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") + df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") # type: ignore diff --git a/src/koheesio/pandas/readers/excel.py b/src/koheesio/pandas/readers/excel.py index 1fbfc03d..bab7e525 100644 --- a/src/koheesio/pandas/readers/excel.py +++ b/src/koheesio/pandas/readers/excel.py @@ -45,7 +45,7 @@ class ExcelReader(Reader, ExtraParamsMixin): sheet_name: str = Field(default="Sheet1", description="The name of the sheet to read") header: Optional[Union[int, List[int]]] = Field(default=0, description="Row(s) to use as the column names") - def execute(self): + def execute(self) -> Reader.Output: extra_params = self.params or {} extra_params.pop("spark", None) self.output.df = pd.read_excel(self.path, sheet_name=self.sheet_name, header=self.header, **extra_params) diff --git a/src/koheesio/secrets/__init__.py b/src/koheesio/secrets/__init__.py index 838acd15..caa424b3 100644 --- a/src/koheesio/secrets/__init__.py +++ b/src/koheesio/secrets/__init__.py @@ -37,7 +37,7 @@ class Output(StepOutput): context: Context = Field(default=..., description="Koheesio context") @classmethod - def encode_secret_values(cls, data: dict): + def encode_secret_values(cls, data: dict) -> dict: """Encode secret values in the dictionary. Ensures that all values in the dictionary are wrapped in SecretStr. 
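The switch to `default_factory=utc_now` above fixes a classic pitfall: a plain `default=datetime.utcnow()` is evaluated once, at class-definition time, so every instance would share that stale timestamp. A minimal sketch of both patterns, with `utc_now` stubbed here for illustration:

```python
from datetime import datetime, timezone

from pydantic import BaseModel, Field


def utc_now() -> datetime:
    # illustrative stand-in for koheesio.utils.utc_now
    return datetime.now(timezone.utc)


class EagerDefault(BaseModel):
    # pitfall: the default is computed once, when this class is defined
    ts: datetime = Field(default=datetime.now(timezone.utc))


class FactoryDefault(BaseModel):
    # fix: the factory runs for every new instance
    ts: datetime = Field(default_factory=utc_now)
```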
@@ -47,7 +47,7 @@ def encode_secret_values(cls, data: dict): if isinstance(value, dict): encoded_dict[key] = cls.encode_secret_values(value) else: - encoded_dict[key] = SecretStr(value) + encoded_dict[key] = SecretStr(value) # type: ignore[assignment] return encoded_dict @abstractmethod @@ -57,13 +57,14 @@ def _get_secrets(self) -> dict: """ ... - def execute(self): + def execute(self) -> None: """ Main method to handle secrets protection and context creation with "root-parent-secrets" structure. """ context = Context(self.encode_secret_values(data={self.root: {self.parent: self._get_secrets()}})) self.output.context = self.context.merge(context=context) + # noinspection PyMethodOverriding def get(self) -> Context: """ Convenience method to return context with secrets. diff --git a/src/koheesio/spark/__init__.py b/src/koheesio/spark/__init__.py index 2646ce07..c72cfb0e 100644 --- a/src/koheesio/spark/__init__.py +++ b/src/koheesio/spark/__init__.py @@ -4,29 +4,42 @@ from __future__ import annotations -from typing import Optional +import warnings from abc import ABC +from typing import Optional from pydantic import Field -from pyspark.sql import Column -from pyspark.sql import DataFrame as PySparkSQLDataFrame -from pyspark.sql import SparkSession as OriginalSparkSession -from pyspark.sql import functions as F - -try: - from pyspark.sql.utils import AnalysisException as SparkAnalysisException -except ImportError: - from pyspark.errors.exceptions.base import AnalysisException as SparkAnalysisException - from koheesio import Step, StepOutput from koheesio.models import model_validator - -# TODO: Move to spark/__init__.py after reorganizing the code -# Will be used for typing checks and consistency, specifically for PySpark >=3.5 -DataFrame = PySparkSQLDataFrame -SparkSession = OriginalSparkSession -AnalysisException = SparkAnalysisException +from koheesio.spark.utils.common import ( + AnalysisException, + Column, + DataFrame, + DataFrameReader, + DataFrameWriter, + DataStreamReader, + DataStreamWriter, + DataType, + ParseException, + SparkSession, + StreamingQuery, +) + +__all__ = [ + "SparkStep", + "Column", + "DataFrame", + "ParseException", + "SparkSession", + "AnalysisException", + "DataType", + "DataFrameReader", + "DataStreamReader", + "DataFrameWriter", + "DataStreamWriter", + "StreamingQuery", +] class SparkStep(Step, ABC): @@ -50,17 +63,27 @@ class Output(StepOutput): df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame") @model_validator(mode="after") - def _get_active_spark_session(self): + def _get_active_spark_session(self) -> SparkStep: """Return active SparkSession instance If a user provides a SparkSession instance, it will be returned. Otherwise, an active SparkSession will be attempted to be retrieved. """ if self.spark is None: - self.spark = SparkSession.getActiveSession() + from koheesio.spark.utils.common import get_active_session + + self.spark = get_active_session() return self -# TODO: Move to spark/functions/__init__.py after reorganizing the code -def current_timestamp_utc(spark: SparkSession) -> Column: - """Get the current timestamp in UTC""" - return F.to_utc_timestamp(F.current_timestamp(), spark.conf.get("spark.sql.session.timeZone")) +def current_timestamp_utc(spark): + warnings.warn( + message=( + "The current_timestamp_utc function has been moved to the koheesio.spark.functions module." + "Import it from there instead. Current import path will be deprecated in the future." 
+ ), + category=DeprecationWarning, + stacklevel=2, + ) + from koheesio.spark.functions import current_timestamp_utc as _current_timestamp_utc + + return _current_timestamp_utc(spark) diff --git a/src/koheesio/spark/delta.py b/src/koheesio/spark/delta.py index da015c64..8d252a68 100644 --- a/src/koheesio/spark/delta.py +++ b/src/koheesio/spark/delta.py @@ -7,11 +7,10 @@ from py4j.protocol import Py4JJavaError # type: ignore -from pyspark.sql import DataFrame from pyspark.sql.types import DataType from koheesio.models import Field, field_validator, model_validator -from koheesio.spark import AnalysisException, SparkStep +from koheesio.spark import AnalysisException, DataFrame, SparkStep from koheesio.spark.utils import on_databricks @@ -59,11 +58,6 @@ class DeltaTableStep(SparkStep): max_version_ts_of_last_execution(query_predicate: str = None) -> datetime.datetime Max version timestamp of last execution. If no timestamp is found, returns 1900-01-01 00:00:00. Note: will raise an error if column `VERSION_TIMESTAMP` does not exist. - Properties ---------- @@ -124,7 +118,7 @@ class DeltaTableStep(SparkStep): ) @field_validator("default_create_properties") - def _adjust_default_properties(cls, default_create_properties): + def _adjust_default_properties(cls, default_create_properties: dict) -> dict: """Adjust default properties based on environment.""" if on_databricks(): default_create_properties["delta.autoOptimize.autoCompact"] = True @@ -136,7 +130,7 @@ def _adjust_default_properties(cls, default_create_properties): return default_create_properties @model_validator(mode="after") - def _validate_catalog_database_table(self): + def _validate_catalog_database_table(self) -> "DeltaTableStep": """Validate that catalog, database/schema, and table are correctly set""" database, catalog, table = self.database, self.catalog, self.table @@ -186,7 +180,7 @@ def is_cdf_active(self) -> bool: props = self.get_persisted_properties() return props.get("delta.enableChangeDataFeed", "false") == "true" - def add_property(self, key: str, value: Union[str, int, bool], override: bool = False): + def add_property(self, key: str, value: Union[str, int, bool], override: bool = False) -> None: """Alter table and set table property. Parameters @@ -232,7 +226,7 @@ def _alter_table() -> None: else: self.default_create_properties[key] = v_str - def add_properties(self, properties: Dict[str, Union[str, bool, int]], override: bool = False): + def add_properties(self, properties: Dict[str, Union[str, bool, int]], override: bool = False) -> None: """Alter table and add properties. Parameters @@ -247,7 +241,7 @@ def add_properties(self, properties: Dict[str, Union[str, bool, int]], override: v_str = str(v) if not isinstance(v, bool) else str(v).lower() self.add_property(key=k, value=v_str, override=override) - def execute(self): + def execute(self) -> None: """Nothing to execute on a Table""" @property @@ -274,7 +268,7 @@ def columns(self) -> Optional[List[str]]: return self.dataframe.columns if self.exists else None def get_column_type(self, column: str) -> Optional[DataType]: - """Get the type of a column in the table. + """Get the type of a specific column in the table. 
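The backwards-compatible shim for `current_timestamp_utc` a few hunks above follows a common deprecation pattern: warn, then forward to the new location via a local import (which also sidesteps a circular import at module load time). A generic sketch of the pattern; module and function names here are hypothetical:

```python
import warnings


def old_function(*args, **kwargs):  # kept at the old import path
    warnings.warn(
        "old_function has moved to `mypackage.new_module`; import it from there instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    # local import: the new module may itself import from this one
    from mypackage.new_module import old_function as _impl

    return _impl(*args, **kwargs)
```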
Parameters ---------- @@ -291,7 +285,7 @@ def get_column_type(self, column: str) -> Optional[DataType]: @property def has_change_type(self) -> bool: """Checks if a column named `_change_type` is present in the table""" - return "_change_type" in self.columns + return "_change_type" in self.columns # type: ignore @property def exists(self) -> bool: @@ -300,7 +294,15 @@ def exists(self) -> bool: result = False try: - self.spark.table(self.table_name) + from koheesio.spark.utils.connect import is_remote_session + + _df = self.spark.table(self.table_name) + + if is_remote_session(): + # In Spark remote session it is not enough to call just spark.table(self.table_name) + # as it will not raise an exception, we have to make action call on table to check if it exists + _df.take(1) + result = True except AnalysisException as e: err_msg = str(e).lower() diff --git a/src/koheesio/spark/etl_task.py b/src/koheesio/spark/etl_task.py index 033d04d6..31323fb5 100644 --- a/src/koheesio/spark/etl_task.py +++ b/src/koheesio/spark/etl_task.py @@ -4,15 +4,15 @@ Extract -> Transform -> Load """ -from datetime import datetime - -from pyspark.sql import DataFrame +import datetime from koheesio import Step from koheesio.models import Field, InstanceOf, conlist +from koheesio.spark import DataFrame from koheesio.spark.readers import Reader from koheesio.spark.transformations import Transformation from koheesio.spark.writers import Writer +from koheesio.utils import utc_now class EtlTask(Step): @@ -86,7 +86,7 @@ class EtlTask(Step): # private attrs etl_date: datetime = Field( - default=datetime.utcnow(), + default_factory=utc_now, description="Date time when this object was created as iso format. Example: '2023-01-24T09:39:23.632374'", ) @@ -123,7 +123,7 @@ def load(self, df: DataFrame) -> DataFrame: writer.write(df) return df - def execute(self): + def execute(self) -> Step.Output: """Run the ETL process""" self.log.info(f"Task started at {self.etl_date}") @@ -135,7 +135,3 @@ def execute(self): # load to target self.output.target_df = self.load(self.output.transform_df) - - def run(self): - """alias of execute""" - return self.execute() diff --git a/src/koheesio/spark/functions/__init__.py b/src/koheesio/spark/functions/__init__.py new file mode 100644 index 00000000..5bac8835 --- /dev/null +++ b/src/koheesio/spark/functions/__init__.py @@ -0,0 +1,11 @@ +from pyspark.sql import functions as f + +from koheesio.spark import Column, SparkSession + + +def current_timestamp_utc(spark: SparkSession) -> Column: + """Get the current timestamp in UTC""" + tz_session = spark.conf.get("spark.sql.session.timeZone", "UTC") + tz = tz_session if tz_session else "UTC" + + return f.to_utc_timestamp(f.current_timestamp(), tz) diff --git a/src/koheesio/spark/readers/databricks/autoloader.py b/src/koheesio/spark/readers/databricks/autoloader.py index 8444a548..282b4168 100644 --- a/src/koheesio/spark/readers/databricks/autoloader.py +++ b/src/koheesio/spark/readers/databricks/autoloader.py @@ -7,6 +7,8 @@ from enum import Enum from pyspark.sql.streaming import DataStreamReader + +# noinspection PyProtectedMember from pyspark.sql.types import AtomicType, StructType from koheesio.models import Field, field_validator @@ -97,14 +99,14 @@ class AutoLoader(Reader): ) @field_validator("format") - def validate_format(cls, format_specified): + def validate_format(cls, format_specified: Union[str, AutoLoaderFormat]) -> str: """Validate `format` value""" if isinstance(format_specified, str): if format_specified.upper() in [f.value.upper() for f 
in AutoLoaderFormat]: format_specified = getattr(AutoLoaderFormat, format_specified.upper()) return str(format_specified.value) - def get_options(self): + def get_options(self) -> Dict[str, Any]: """Get the options for the autoloader""" self.options.update( { @@ -118,10 +120,10 @@ def get_options(self): def reader(self) -> DataStreamReader: reader = self.spark.readStream.format("cloudFiles") if self.schema_ is not None: - reader = reader.schema(self.schema_) + reader = reader.schema(self.schema_) # type: ignore reader = reader.options(**self.get_options()) return reader - def execute(self): + def execute(self) -> Reader.Output: """Reads from the given location with the given options using Autoloader""" self.output.df = self.reader().load(self.location) diff --git a/src/koheesio/spark/readers/delta.py b/src/koheesio/spark/readers/delta.py index 54ee7950..8983f1aa 100644 --- a/src/koheesio/spark/readers/delta.py +++ b/src/koheesio/spark/readers/delta.py @@ -8,14 +8,16 @@ Reads data from a Delta table and returns a DataStream """ +from __future__ import annotations + from typing import Any, Dict, Optional, Union -import pyspark.sql.functions as f -from pyspark.sql import Column, DataFrameReader -from pyspark.sql.streaming import DataStreamReader +from pyspark.sql import DataFrameReader +from pyspark.sql import functions as f from koheesio.logger import LoggingFactory from koheesio.models import Field, ListOfColumns, field_validator, model_validator +from koheesio.spark import Column, DataStreamReader from koheesio.spark.delta import DeltaTableStep from koheesio.spark.readers import Reader from koheesio.utils import get_random_string @@ -76,7 +78,7 @@ class DeltaTableReader(Reader): ignoreChanges: re-process updates if files had to be rewritten in the source table due to a data changing operation such as UPDATE, MERGE INTO, DELETE (within partitions), or OVERWRITE. Unchanged rows may still be emitted, therefore your downstream consumers should be able to handle duplicates. Deletes are not propagated - downstream. ignoreChanges subsumes ignoreDeletes. Therefore if you use ignoreChanges, your stream will not be + downstream. ignoreChanges subsumes ignoreDeletes. Therefore, if you use ignoreChanges, your stream will not be disrupted by either deletions or updates to the source table. 
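The `ignoreDeletes`/`ignoreChanges` notes above map directly onto Delta's streaming reader options. A plain-Spark sketch of what `DeltaTableReader` ends up configuring; the session and table name are placeholders:

```python
# spark: an existing SparkSession; the table name is a placeholder
stream_df = (
    spark.readStream.format("delta")
    .option("ignoreChanges", "true")  # re-emit rewritten files instead of failing the stream
    .table("my_catalog.my_schema.my_source_table")
)
```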
""" @@ -160,15 +162,15 @@ class DeltaTableReader(Reader): ) # private attrs - __temp_view_name__ = None + __temp_view_name__: Optional[str] = None @property - def temp_view_name(self): + def temp_view_name(self) -> str: """Get the temporary view name for the dataframe for SQL queries""" return self.__temp_view_name__ @field_validator("table") - def _validate_table_name(cls, tbl: Union[DeltaTableStep, str]): + def _validate_table_name(cls, tbl: Union[DeltaTableStep, str]) -> DeltaTableStep: """Validate the table name provided as a string or a DeltaTableStep instance.""" if isinstance(tbl, str): return DeltaTableStep(table=tbl) @@ -177,7 +179,7 @@ def _validate_table_name(cls, tbl: Union[DeltaTableStep, str]): raise AttributeError(f"Table name provided cannot be processed as a Table : {tbl}") @model_validator(mode="after") - def _validate_starting_version_and_timestamp(self): + def _validate_starting_version_and_timestamp(self) -> "DeltaTableReader": """Validate 'starting_version' and 'starting_timestamp' - Only one of each should be provided""" starting_version = self.starting_version starting_timestamp = self.starting_timestamp @@ -199,7 +201,7 @@ def _validate_starting_version_and_timestamp(self): return self @model_validator(mode="after") - def _validate_ignore_deletes_and_changes_and_skip_commits(self): + def _validate_ignore_deletes_and_changes_and_skip_commits(self) -> "DeltaTableReader": """Validate 'ignore_deletes' and 'ignore_changes' - Only one of each should be provided""" ignore_deletes = self.ignore_deletes ignore_changes = self.ignore_changes @@ -214,7 +216,7 @@ def _validate_ignore_deletes_and_changes_and_skip_commits(self): return self @model_validator(mode="before") - def _warn_on_streaming_options_without_streaming(cls, options: Dict): + def _warn_on_streaming_options_without_streaming(cls, options: Dict) -> Dict: """throws a warning if streaming options were provided, but streaming was not set to true""" streaming_options = [val for opt, val in options.items() if opt in STREAMING_ONLY_OPTIONS] streaming_toggled_on = options.get("streaming") @@ -229,7 +231,7 @@ def _warn_on_streaming_options_without_streaming(cls, options: Dict): return options @model_validator(mode="after") - def set_temp_view_name(self): + def set_temp_view_name(self) -> "DeltaTableReader": """Set a temporary view name for the dataframe for SQL queries""" table_name = self.table.table vw_name = get_random_string(prefix=f"tmp_{table_name}") @@ -237,9 +239,10 @@ def set_temp_view_name(self): return self @property - def view(self): + def view(self) -> str: """Create a temporary view of the dataframe for SQL queries""" temp_view_name = self.temp_view_name + if (output_df := self.output.df) is None: self.log.warning( "Attempting to createTempView without any data being present. Please run .execute() or .read() first. " @@ -247,6 +250,7 @@ def view(self): ) else: output_df.createOrReplaceTempView(temp_view_name) + return temp_view_name def get_options(self) -> Dict[str, Any]: @@ -273,7 +277,7 @@ def get_options(self) -> Dict[str, Any]: else: pass # there are none... 
for now :) - def normalize(v: Union[str, bool]): + def normalize(v: Union[str, bool]) -> str: """normalize values""" # True becomes "true", False becomes "false" v = str(v).lower() if isinstance(v, bool) else v @@ -300,10 +304,10 @@ def reader(self) -> Union[DataStreamReader, DataFrameReader]: reader = reader.option(key, value) return reader - def execute(self): + def execute(self) -> Reader.Output: df = self.reader.table(self.table.table_name) if self.filter_cond is not None: - df = df.filter(f.expr(self.filter_cond) if isinstance(self.filter_cond, str) else self.filter_cond) + df = df.filter(f.expr(self.filter_cond) if isinstance(self.filter_cond, str) else self.filter_cond) # type: ignore if self.columns is not None: df = df.select(*self.columns) self.output.df = df diff --git a/src/koheesio/spark/readers/dummy.py b/src/koheesio/spark/readers/dummy.py index a604b3b3..5097f79c 100644 --- a/src/koheesio/spark/readers/dummy.py +++ b/src/koheesio/spark/readers/dummy.py @@ -40,5 +40,5 @@ class DummyReader(Reader): range: int = Field(default=100, description="How large to make the Dataframe") - def execute(self): + def execute(self) -> Reader.Output: self.output.df = self.spark.range(self.range) diff --git a/src/koheesio/spark/readers/excel.py b/src/koheesio/spark/readers/excel.py index 4b52cc79..4e5abae3 100644 --- a/src/koheesio/spark/readers/excel.py +++ b/src/koheesio/spark/readers/excel.py @@ -35,6 +35,6 @@ class ExcelReader(Reader, PandasExcelReader): The row to use as the column names """ - def execute(self): + def execute(self) -> Reader.Output: pdf: PandasDataFrame = PandasExcelReader.from_step(self).execute().df self.output.df = self.spark.createDataFrame(pdf) diff --git a/src/koheesio/spark/readers/file_loader.py b/src/koheesio/spark/readers/file_loader.py index 9d338063..2bb3cd8b 100644 --- a/src/koheesio/spark/readers/file_loader.py +++ b/src/koheesio/spark/readers/file_loader.py @@ -100,13 +100,13 @@ class FileLoader(Reader, ExtraParamsMixin): streaming: Optional[bool] = Field(default=False, description="Whether to read the files as a Stream or not") @field_validator("path") - def ensure_path_is_str(cls, v): + def ensure_path_is_str(cls, path: Union[Path, str]) -> Union[Path, str]: """Ensure that the path is a string as required by Spark.""" - if isinstance(v, Path): - return str(v.absolute().as_posix()) - return v + if isinstance(path, Path): + return str(path.absolute().as_posix()) + return path - def execute(self): + def execute(self) -> Reader.Output: """Reads the file, in batch or as a stream, using the specified format and schema, while applying any extra parameters.""" reader = self.spark.readStream if self.streaming else self.spark.read reader = reader.format(self.format) @@ -117,7 +117,7 @@ def execute(self): if self.extra_params: reader = reader.options(**self.extra_params) - self.output.df = reader.load(self.path) + self.output.df = reader.load(self.path) # type: ignore class CsvReader(FileLoader): diff --git a/src/koheesio/spark/readers/hana.py b/src/koheesio/spark/readers/hana.py index 92cfbbd1..76168560 100644 --- a/src/koheesio/spark/readers/hana.py +++ b/src/koheesio/spark/readers/hana.py @@ -28,10 +28,10 @@ class HanaReader(JdbcReader): ```python from koheesio.spark.readers.hana import HanaReader jdbc_hana = HanaReader( - url="jdbc:sap://:/? 
+ url="jdbc:sap://:/?", user="YOUR_USERNAME", password="***", - dbtable="schemaname.tablename" + dbtable="schema_name.table_name" ) df = jdbc_hana.read() ``` diff --git a/src/koheesio/spark/readers/jdbc.py b/src/koheesio/spark/readers/jdbc.py index 08b3197f..bbaea071 100644 --- a/src/koheesio/spark/readers/jdbc.py +++ b/src/koheesio/spark/readers/jdbc.py @@ -48,7 +48,7 @@ class JdbcReader(Reader): url="jdbc:sqlserver://10.xxx.xxx.xxx:1433;databaseName=YOUR_DATABASE", user="YOUR_USERNAME", password="***", - dbtable="schemaname.tablename", + dbtable="schema_name.table_name", options={"fetchsize": 100}, ) df = jdbc_mssql.read() @@ -73,7 +73,7 @@ class JdbcReader(Reader): query: Optional[str] = Field(default=None, description="Query") options: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra options to pass to spark reader") - def get_options(self): + def get_options(self) -> Dict[str, Any]: """ Dictionary of options required for the specific JDBC driver. @@ -84,10 +84,10 @@ def get_options(self): "url": self.url, "user": self.user, "password": self.password, - **self.options, + **self.options, # type: ignore } - def execute(self): + def execute(self) -> Reader.Output: """Wrapper around Spark's jdbc read format""" # Can't have both dbtable and query empty diff --git a/src/koheesio/spark/readers/kafka.py b/src/koheesio/spark/readers/kafka.py index 08fed3e5..915dff98 100644 --- a/src/koheesio/spark/readers/kafka.py +++ b/src/koheesio/spark/readers/kafka.py @@ -5,6 +5,7 @@ from typing import Dict, Optional from koheesio.models import ExtraParamsMixin, Field +from koheesio.spark import DataFrameReader, DataStreamReader from koheesio.spark.readers import Reader @@ -73,7 +74,7 @@ class KafkaReader(Reader, ExtraParamsMixin): streaming: Optional[bool] = Field( default=False, description="Whether to read the kafka topic as a stream or not. Defaults to False." ) - params: Optional[Dict[str, str]] = Field( + params: Dict[str, str] = Field( default_factory=dict, alias="kafka_options", description="Arbitrary options to be applied when creating NSP Reader. 
If a user provides values for " @@ -82,24 +83,24 @@ class KafkaReader(Reader, ExtraParamsMixin): ) @property - def stream_reader(self): + def stream_reader(self) -> DataStreamReader: """Returns the Spark readStream object.""" return self.spark.readStream @property - def batch_reader(self): + def batch_reader(self) -> DataFrameReader: """Returns the Spark read object for batch processing.""" return self.spark.read @property - def reader(self): + def reader(self) -> Reader: """Returns the appropriate reader based on the streaming flag.""" if self.streaming: return self.stream_reader return self.batch_reader @property - def options(self): + def options(self) -> Dict[str, str]: """Merge fixed parameters with arbitrary options provided by user.""" return { **self.params, @@ -108,7 +109,7 @@ def options(self): } @property - def logged_option_keys(self): + def logged_option_keys(self) -> set: """Keys that are allowed to be logged for the options.""" return { "kafka.bootstrap.servers", @@ -122,11 +123,11 @@ def logged_option_keys(self): "kafka.group.id", } - def execute(self): + def execute(self) -> Reader.Output: applied_options = {k: v for k, v in self.options.items() if k in self.logged_option_keys} self.log.debug(f"Applying options {applied_options}") - self.output.df = self.reader.format("kafka").options(**self.options).load() + self.output.df = self.reader.format("kafka").options(**self.options).load() # type: ignore class KafkaStreamReader(KafkaReader): diff --git a/src/koheesio/spark/readers/memory.py b/src/koheesio/spark/readers/memory.py index 9b5e95a7..79002058 100644 --- a/src/koheesio/spark/readers/memory.py +++ b/src/koheesio/spark/readers/memory.py @@ -3,15 +3,16 @@ """ import json -from typing import Any, Dict, Optional, Union from enum import Enum from functools import partial +from io import StringIO +from typing import Any, Dict, Optional, Union -from pyspark.rdd import RDD -from pyspark.sql import DataFrame +import pandas as pd from pyspark.sql.types import StructType from koheesio.models import ExtraParamsMixin, Field +from koheesio.spark import DataFrame from koheesio.spark.readers import Reader @@ -66,48 +67,55 @@ class InMemoryDataReader(Reader, ExtraParamsMixin): description="[Optional] Schema that will be applied during the creation of Spark DataFrame", ) - params: Optional[Dict[str, Any]] = Field( + params: Dict[str, Any] = Field( default_factory=dict, description="[Optional] Set of extra parameters that should be passed to the appropriate reader (csv / json)", ) - @property - def _rdd(self) -> RDD: - """ - Read provided data and transform it into Spark RDD - - Returns - ------- - RDD - """ - _data = self.data + def _csv(self) -> DataFrame: + """Method for reading CSV data""" + if isinstance(self.data, list): + csv_data: str = "\n".join(self.data) + else: + csv_data: str = self.data # type: ignore - if isinstance(_data, bytes): - _data = _data.decode("utf-8") + if "header" in self.params and self.params["header"] is True: + self.params["header"] = 0 - if isinstance(_data, dict): - _data = json.dumps(_data) + pandas_df = pd.read_csv(StringIO(csv_data), **self.params) # type: ignore + df = self.spark.createDataFrame(pandas_df, schema=self.schema_) # type: ignore - # 'list' type already compatible with 'parallelize' - if not isinstance(_data, list): - _data = _data.splitlines() + return df - _rdd = self.spark.sparkContext.parallelize(_data) + def _json(self) -> DataFrame: + """Method for reading JSON data""" + if isinstance(self.data, str): + json_data = 
[json.loads(self.data)] + elif isinstance(self.data, list): + if all(isinstance(x, str) for x in self.data): + json_data = [json.loads(x) for x in self.data] + else: + json_data = [self.data] - return _rdd + # Use pyspark.pandas to read the JSON data from the string + # noinspection PyUnboundLocalVariable + pandas_df = pd.read_json(StringIO(json.dumps(json_data)), **self.params) # type: ignore - def _csv(self, rdd: RDD) -> DataFrame: - """Method for reading CSV data""" - return self.spark.read.csv(rdd, schema=self.schema_, **self.params) + # Convert pyspark.pandas DataFrame to Spark DataFrame + df = self.spark.createDataFrame(pandas_df, schema=self.schema_) # type: ignore - def _json(self, rdd: RDD) -> DataFrame: - """Method for reading JSON data""" - return self.spark.read.json(rdd, schema=self.schema_, **self.params) + return df - def execute(self): + def execute(self) -> Reader.Output: """ Execute method appropriate to the specific data format """ + if self.data is None: + raise ValueError("Data is not provided") + + if isinstance(self.data, bytes): + self.data = self.data.decode("utf-8") + _func = getattr(InMemoryDataReader, f"_{self.format}") - _df = partial(_func, self, self._rdd)() + _df = partial(_func, self)() self.output.df = _df diff --git a/src/koheesio/spark/readers/metastore.py b/src/koheesio/spark/readers/metastore.py index ca247770..cb0f4928 100644 --- a/src/koheesio/spark/readers/metastore.py +++ b/src/koheesio/spark/readers/metastore.py @@ -17,5 +17,5 @@ class MetastoreReader(Reader): table: str = Field(default=..., description="Table name in spark metastore") - def execute(self): + def execute(self) -> Reader.Output: self.output.df = self.spark.table(self.table) diff --git a/src/koheesio/spark/readers/rest_api.py b/src/koheesio/spark/readers/rest_api.py index 49de6dbf..bad9036c 100644 --- a/src/koheesio/spark/readers/rest_api.py +++ b/src/koheesio/spark/readers/rest_api.py @@ -13,6 +13,7 @@ from pydantic import Field, InstanceOf +# noinspection PyProtectedMember from pyspark.sql.types import AtomicType, StructType from koheesio.asyncio.http import AsyncHttpGetStep @@ -20,7 +21,9 @@ from koheesio.steps.http import HttpGetStep +# noinspection HttpUrlsUsage class RestApiReader(Reader): + # noinspection HttpUrlsUsage """ A reader class that executes an API call and stores the response in a DataFrame. @@ -61,7 +64,7 @@ class RestApiReader(Reader): session.mount("https://", HTTPAdapter(max_retries=retry_logic)) session.mount("http://", HTTPAdapter(max_retries=retry_logic)) - transport = PaginatedHtppGetStep( + transport = PaginatedHttpGetStep( url="https://api.example.com/data?page={page}", paginate=True, pages=3, @@ -121,6 +124,7 @@ def execute(self) -> Reader.Output: """ raw_data = self.transport.execute() + data = None if isinstance(raw_data, HttpGetStep.Output): data = raw_data.response_json elif isinstance(raw_data, AsyncHttpGetStep.Output): diff --git a/src/koheesio/spark/readers/spark_sql_reader.py b/src/koheesio/spark/readers/spark_sql_reader.py index fe5900e9..c31e6ed3 100644 --- a/src/koheesio/spark/readers/spark_sql_reader.py +++ b/src/koheesio/spark/readers/spark_sql_reader.py @@ -58,5 +58,5 @@ class SparkSqlReader(SqlBaseStep, Reader): Any arbitrary kwargs passed to the class will be added to params. 
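The `InMemoryDataReader` rework above swaps the RDD-based parsing for a pandas round-trip, presumably because `spark.sparkContext` (and thus `parallelize`) is unavailable on Spark Connect sessions. The core of that round-trip, sketched outside koheesio:

```python
from io import StringIO

import pandas as pd

# spark: an existing SparkSession (classic or Connect)
csv_data = "id,name\n1,hello\n2,world"
pandas_df = pd.read_csv(StringIO(csv_data), header=0)
df = spark.createDataFrame(pandas_df)
```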
""" - def execute(self): + def execute(self) -> Reader.Output: self.output.df = self.spark.sql(self.query) diff --git a/src/koheesio/spark/readers/teradata.py b/src/koheesio/spark/readers/teradata.py index b4c81676..5c7d2d23 100644 --- a/src/koheesio/spark/readers/teradata.py +++ b/src/koheesio/spark/readers/teradata.py @@ -37,7 +37,7 @@ class TeradataReader(JdbcReader): url="jdbc:teradata:///logmech=ldap,charset=utf8,database=,type=fastexport, maybenull=on", user="YOUR_USERNAME", password="***", - dbtable="schemaname.tablename", + dbtable="schema_name.table_name", ) ``` diff --git a/src/koheesio/spark/snowflake.py b/src/koheesio/spark/snowflake.py index 466f9123..67cab023 100644 --- a/src/koheesio/spark/snowflake.py +++ b/src/koheesio/spark/snowflake.py @@ -1,3 +1,4 @@ +# noinspection PyUnresolvedReferences """ Snowflake steps and tasks for Koheesio @@ -41,12 +42,12 @@ """ import json -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union from abc import ABC from copy import deepcopy from textwrap import dedent -from pyspark.sql import DataFrame, Window +from pyspark.sql import Window from pyspark.sql import functions as f from pyspark.sql import types as t @@ -61,7 +62,7 @@ field_validator, model_validator, ) -from koheesio.spark import SparkStep +from koheesio.spark import DataFrame, SparkStep from koheesio.spark.delta import DeltaTableStep from koheesio.spark.readers.delta import DeltaTableReader, DeltaTableStreamReader from koheesio.spark.readers.jdbc import JdbcReader @@ -180,7 +181,7 @@ class SnowflakeBaseModel(BaseModel, ExtraParamsMixin, ABC): "`net.snowflake.spark.snowflake` in other environments and make sure to install required JARs.", ) - def get_options(self): + def get_options(self) -> Dict[str, Any]: """Get the sfOptions as a dictionary.""" return { key: value @@ -193,7 +194,7 @@ def get_options(self): "sfSchema": self.sfSchema, "sfRole": self.role, "sfWarehouse": self.warehouse, - **self.options, + **self.options, # type: ignore }.items() if value is not None } @@ -208,7 +209,7 @@ class SnowflakeTableStep(SnowflakeStep, ABC): table: str = Field(default=..., description="The name of the table") - def get_options(self): + def get_options(self) -> Dict[str, Any]: options = super().get_options() options["table"] = self.table return options @@ -241,7 +242,8 @@ class SnowflakeReader(SnowflakeBaseModel, JdbcReader): https://docs.snowflake.com/en/user-guide/spark-connector-use#setting-configuration-options-for-the-connector """ - driver: Optional[str] = None # overriding `driver` property of JdbcReader, because it is not required by Snowflake + # overriding `driver` property of JdbcReader, because it is not required by Snowflake + driver: Optional[str] = None # type: ignore class SnowflakeTransformation(SnowflakeBaseModel, Transformation, ABC): @@ -272,11 +274,11 @@ class RunQuery(SnowflakeStep): query: str = Field(default=..., description="The query to run", alias="sql") @field_validator("query") - def validate_query(cls, query): + def validate_query(cls, query: str) -> str: """Replace escape characters""" return query.replace("\\n", "\n").replace("\\t", "\t").strip() - def get_options(self): + def get_options(self) -> Dict[str, Any]: # Executing the RunQuery without `host` option in Databricks throws: # An error occurred while calling z:net.snowflake.spark.snowflake.Utils.runQuery. 
# : java.util.NoSuchElementException: key not found: host @@ -284,7 +286,7 @@ def get_options(self): options["host"] = options["sfURL"] return options - def execute(self) -> None: + def execute(self) -> SnowflakeStep.Output: if not self.query: self.log.warning("Empty string given as query input, skipping execution") return @@ -314,12 +316,12 @@ class Query(SnowflakeReader): query: str = Field(default=..., description="The query to run") @field_validator("query") - def validate_query(cls, query): + def validate_query(cls, query: str) -> str: """Replace escape characters""" query = query.replace("\\n", "\n").replace("\\t", "\t").strip() return query - def get_options(self): + def get_options(self) -> Dict[str, Any]: """add query to options""" options = super().get_options() options["query"] = self.query @@ -371,7 +373,7 @@ class Output(StepOutput): exists: bool = Field(default=..., description="Whether or not the table exists") - def execute(self): + def execute(self) -> Output: query = ( dedent( # Force upper case, due to case-sensitivity of where clause @@ -397,7 +399,7 @@ def execute(self): self.output.exists = exists -def map_spark_type(spark_type: t.DataType): +def map_spark_type(spark_type: t.DataType) -> str: """ Translates Spark DataFrame Schema type to SnowFlake type @@ -533,7 +535,7 @@ class Output(SnowflakeTransformation.Output): ) query: str = Field(default=..., description="Query that was executed to create the table") - def execute(self): + def execute(self) -> Output: self.output.df = self.df input_schema = self.df.schema @@ -620,7 +622,7 @@ class Output(SnowflakeStep.Output): ) @model_validator(mode="before") - def set_roles_privileges(cls, values): + def set_roles_privileges(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Coerce roles and privileges to be lists if they are not already.""" roles_value = values.get("roles") or values.get("role") privileges_value = values.get("privileges") @@ -636,7 +638,7 @@ def set_roles_privileges(cls, values): return values @model_validator(mode="after") - def validate_object_and_object_type(self): + def validate_object_and_object_type(self) -> "GrantPrivilegesOnObject": """Validate that the object and type are set.""" object_value = self.object if not object_value: @@ -651,7 +653,7 @@ def validate_object_and_object_type(self): return self - def get_query(self, role: str): + def get_query(self, role: str) -> str: """Build the GRANT query Parameters @@ -664,10 +666,12 @@ def get_query(self, role: str): query : str The Query that performs the grant """ - query = f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() + query = ( + f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() + ) # nosec B608: hardcoded_sql_expressions return query - def execute(self): + def execute(self) -> SnowflakeStep.Output: self.output.query = [] roles = self.roles @@ -707,7 +711,7 @@ class GrantPrivilegesOnFullyQualifiedObject(GrantPrivilegesOnObject): """ @model_validator(mode="after") - def set_object_name(self): + def set_object_name(self) -> "GrantPrivilegesOnFullyQualifiedObject": """Set the object name to be fully qualified, i.e. 
database.schema.object_name""" # database, schema, obj_name db = self.database @@ -809,14 +813,14 @@ class AddColumn(SnowflakeStep): table: str = Field(default=..., description="The name of the Snowflake table") column: str = Field(default=..., description="The name of the new column") - type: f.DataType = Field(default=..., description="The DataType represented as a Spark DataType") + type: t.DataType = Field(default=..., description="The DataType represented as a Spark DataType") class Output(SnowflakeStep.Output): """Output class for AddColumn""" query: str = Field(default=..., description="Query that was executed to add the column") - def execute(self): + def execute(self) -> Output: query = f"ALTER TABLE {self.table} ADD COLUMN {self.column} {map_spark_type(self.type)}".upper() self.output.query = query RunQuery(**self.get_options(), query=query).execute() @@ -845,7 +849,7 @@ class Output(SparkStep.Output): default=False, description="Flag to indicate whether Snowflake schema has been altered" ) - def execute(self): + def execute(self) -> Output: self.log.warning("Snowflake table will always take a priority in case of data type conflicts!") # spark side @@ -893,7 +897,7 @@ def execute(self): sf_col_name = sf_col.name.lower() if sf_col_name not in df_cols: sf_col_type = sf_col.dataType - df = df.withColumn(sf_col_name, f.lit(None).cast(sf_col_type)) + df = df.withColumn(sf_col_name, f.lit(None).cast(sf_col_type)) # type: ignore # Put DataFrame columns in the same order as the Snowflake table df = df.select(*sf_cols) @@ -917,7 +921,7 @@ class SnowflakeWriter(SnowflakeBaseModel, Writer): BatchOutputMode.APPEND, alias="mode", description="The insertion type, append or overwrite" ) - def execute(self): + def execute(self) -> Writer.Output: """Write to Snowflake""" self.log.debug(f"writing to {self.table} with mode {self.insert_type}") self.df.write.format(self.format).options(**self.get_options()).option("dbtable", self.table).mode( @@ -954,9 +958,8 @@ class TagSnowflakeQuery(Step, ExtraParamsMixin): pipeline_execution_time="2022-01-01T00:00:00", task_execution_time="2022-01-01T01:00:00", environment="dev", - trace_id="e0fdec43-a045-46e5-9705-acd4f3f96045", - span_id="cb89abea-1c12-471f-8b12-546d2d66f6cb", - ), + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", ).execute().options ``` """ @@ -970,7 +973,7 @@ class Output(StepOutput): options: Dict = Field(default=..., description="Copy of provided SF options, with added query tag preaction") - def execute(self): + def execute(self) -> Output: """Add query tag preaction to Snowflake options""" tag_json = json.dumps(self.extra_params, indent=4, sort_keys=True) tag_preaction = f"ALTER SESSION SET QUERY_TAG = '{tag_json}';" @@ -1025,7 +1028,7 @@ class SynchronizeDeltaToSnowflakeTask(SnowflakeStep): staging_table_name: Optional[str] = Field( default=None, alias="staging_table", description="Optional snowflake staging name", validate_default=False ) - key_columns: Optional[List[str]] = Field( + key_columns: List[str] = Field( default_factory=list, description="Key columns on which merge statements will be MERGE statement will be applied.", ) @@ -1048,7 +1051,7 @@ class SynchronizeDeltaToSnowflakeTask(SnowflakeStep): writer_: Optional[Union[ForEachBatchStreamWriter, SnowflakeWriter]] = None @field_validator("staging_table_name") - def _validate_staging_table(cls, staging_table_name): + def _validate_staging_table(cls, staging_table_name: str) -> str: """Validate the staging table name and return it if it's valid.""" if "." 
in staging_table_name: raise ValueError( @@ -1057,7 +1060,7 @@ def _validate_staging_table(cls, staging_table_name): return staging_table_name @model_validator(mode="before") - def _checkpoint_location_check(cls, values: Dict): + def _checkpoint_location_check(cls, values: Dict) -> Dict: """Give a warning if checkpoint location is given but not expected and vice versa""" streaming = values.get("streaming") checkpoint_location = values.get("checkpoint_location") @@ -1070,7 +1073,7 @@ def _checkpoint_location_check(cls, values: Dict): return values @model_validator(mode="before") - def _synch_mode_check(cls, values: Dict): + def _synch_mode_check(cls, values: Dict) -> Dict: """Validate requirements for various synchronisation modes""" streaming = values.get("streaming") synchronisation_mode = values.get("synchronisation_mode") @@ -1086,7 +1089,7 @@ def _synch_mode_check(cls, values: Dict): raise ValueError("Synchronisation mode can't be 'OVERWRITE' with streaming enabled") if synchronisation_mode == BatchOutputMode.MERGE and streaming is False: raise ValueError("Synchronisation mode can't be 'MERGE' with streaming disabled") - if synchronisation_mode == BatchOutputMode.MERGE and len(key_columns) < 1: + if synchronisation_mode == BatchOutputMode.MERGE and len(key_columns) < 1: # type: ignore raise ValueError("MERGE synchronisation mode requires a list of PK columns in `key_columns`.") return values @@ -1094,13 +1097,13 @@ def _synch_mode_check(cls, values: Dict): @property def non_key_columns(self) -> List[str]: """Columns of source table that aren't part of the (composite) primary key""" - lowercase_key_columns: Set[str] = {c.lower() for c in self.key_columns} + lowercase_key_columns: Set[str] = {c.lower() for c in self.key_columns} # type: ignore source_table_columns = self.source_table.columns - non_key_columns: List[str] = [c for c in source_table_columns if c.lower() not in lowercase_key_columns] + non_key_columns: List[str] = [c for c in source_table_columns if c.lower() not in lowercase_key_columns] # type: ignore return non_key_columns @property - def staging_table(self): + def staging_table(self) -> str: """Intermediate table on snowflake where staging results are stored""" if stg_tbl_name := self.staging_table_name: return stg_tbl_name @@ -1108,13 +1111,14 @@ def staging_table(self): return f"{self.source_table.table}_stg" @property - def reader(self): + def reader(self) -> Union[DeltaTableReader, DeltaTableStreamReader]: """ DeltaTable reader Returns: -------- - DeltaTableReader the will yield source delta table + Union[DeltaTableReader, DeltaTableStreamReader] + DeltaTableReader that will yield source delta table """ # Wrap in lambda functions to mimic lazy evaluation. 
# This ensures the Task doesn't fail if a config isn't provided for a reader/writer that isn't used anyway @@ -1164,13 +1168,13 @@ def _get_writer(self) -> Union[SnowflakeWriter, ForEachBatchStreamWriter]: (BatchOutputMode.MERGE, True): lambda: ForEachBatchStreamWriter( checkpointLocation=self.checkpoint_location, batch_function=self._merge_batch_write_fn( - key_columns=self.key_columns, + key_columns=self.key_columns, # type: ignore non_key_columns=self.non_key_columns, staging_table=self.staging_table, ), ), } - return map_mode_writer[(self.synchronisation_mode, self.streaming)]() + return map_mode_writer[(self.synchronisation_mode, self.streaming)]() # type: ignore @property def writer(self) -> Union[ForEachBatchStreamWriter, SnowflakeWriter]: @@ -1191,27 +1195,28 @@ def writer(self) -> Union[ForEachBatchStreamWriter, SnowflakeWriter]: self.writer_ = self._get_writer() return self.writer_ - def truncate_table(self, snowflake_table): + def truncate_table(self, snowflake_table: str) -> None: """Truncate a given snowflake table""" - truncate_query = f"""TRUNCATE TABLE IF EXISTS {snowflake_table}""" + truncate_query = f"""TRUNCATE TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions query_executor = RunQuery( **self.get_options(), query=truncate_query, ) query_executor.execute() - def drop_table(self, snowflake_table): + def drop_table(self, snowflake_table: str) -> None: """Drop a given snowflake table""" self.log.warning(f"Dropping table {snowflake_table} from snowflake") - drop_table_query = f"""DROP TABLE IF EXISTS {snowflake_table}""" + drop_table_query = f"""DROP TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions query_executor = RunQuery(**self.get_options(), query=drop_table_query) query_executor.execute() - def _merge_batch_write_fn(self, key_columns, non_key_columns, staging_table): + def _merge_batch_write_fn(self, key_columns: List[str], non_key_columns: List[str], staging_table: str) -> Callable: """Build a batch write function for merge mode""" # pylint: disable=unused-argument - def inner(dataframe: DataFrame, batchId: int): + # noinspection PyUnusedLocal,PyPep8Naming + def inner(dataframe: DataFrame, batchId: int) -> None: self._build_staging_table(dataframe, key_columns, non_key_columns, staging_table) self._merge_staging_table_into_target() @@ -1223,17 +1228,19 @@ def _compute_latest_changes_per_pk( dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str] ) -> DataFrame: """Compute the latest changes per primary key""" - windowSpec = Window.partitionBy(*key_columns).orderBy(f.col("_commit_version").desc()) + window_spec = Window.partitionBy(*key_columns).orderBy(f.col("_commit_version").desc()) ranked_df = ( dataframe.filter("_change_type != 'update_preimage'") - .withColumn("rank", f.rank().over(windowSpec)) + .withColumn("rank", f.rank().over(window_spec)) # type: ignore .filter("rank = 1") .select(*key_columns, *non_key_columns, "_change_type") # discard unused columns .distinct() ) return ranked_df - def _build_staging_table(self, dataframe, key_columns, non_key_columns, staging_table): + def _build_staging_table( + self, dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str], staging_table: str + ) -> None: """Build snowflake staging table""" ranked_df = self._compute_latest_changes_per_pk(dataframe, key_columns, non_key_columns) batch_writer = SnowflakeWriter( @@ -1248,9 +1255,9 @@ def _merge_staging_table_into_target(self) -> None: merge_query = self._build_sf_merge_query( 
target_table=self.target_table, stage_table=self.staging_table, - pk_columns=self.key_columns, + pk_columns=self.key_columns, # type: ignore non_pk_columns=self.non_key_columns, - enable_deletion=self.enable_deletion, + enable_deletion=self.enable_deletion, # type: ignore ) query_executor = RunQuery( @@ -1261,8 +1268,12 @@ def _merge_staging_table_into_target(self) -> None: @staticmethod def _build_sf_merge_query( - target_table: str, stage_table: str, pk_columns: List[str], non_pk_columns, enable_deletion: bool = False - ): + target_table: str, + stage_table: str, + pk_columns: List[str], + non_pk_columns: List[str], + enable_deletion: bool = False, + ) -> str: """Build a CDF merge query string Parameters @@ -1316,7 +1327,7 @@ def extract(self) -> DataFrame: self.output.source_df = df return df - def load(self, df) -> DataFrame: + def load(self, df: DataFrame) -> DataFrame: """Load source table into snowflake""" if self.synchronisation_mode == BatchOutputMode.MERGE: self.log.info(f"Truncating staging table {self.staging_table}") @@ -1325,7 +1336,7 @@ def load(self, df) -> DataFrame: self.output.target_df = df return df - def execute(self) -> None: + def execute(self) -> SnowflakeStep.Output: # extract df = self.extract() self.output.source_df = df @@ -1336,9 +1347,5 @@ def execute(self) -> None: if not self.persist_staging: # If it's a streaming job, await for termination before dropping staging table if self.streaming: - self.writer.await_termination() + self.writer.await_termination() # type: ignore self.drop_table(self.staging_table) - - def run(self): - """alias of execute""" - return self.execute() diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index 251d66fa..3f273a85 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -21,16 +21,14 @@ Extended ColumnsTransformation class with an additional `target_column` field """ -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Union from abc import ABC, abstractmethod -from pyspark.sql import Column from pyspark.sql import functions as f -from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import DataType from koheesio.models import Field, ListOfColumns, field_validator -from koheesio.spark import SparkStep +from koheesio.spark import Column, DataFrame, SparkStep from koheesio.spark.utils import SparkDatatype @@ -58,9 +56,7 @@ class Transformation(SparkStep, ABC): class AddOne(Transformation): def execute(self): - self.output.df = self.df.withColumn( - "new_column", f.col("old_column") + 1 - ) + self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the @@ -236,7 +232,7 @@ class ColumnConfig: allows to run the transformation for all columns of a given type. A user can trigger this behavior by either omitting the `columns` parameter or by passing a single `*` as a column name. In both cases, the `run_for_all_data_type` will be used to determine the data type. - Value should be be passed as a SparkDatatype enum. + Value should be passed as a SparkDatatype enum. 
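The `ColumnConfig` mechanics described above are typically exercised by subclassing. A hypothetical transformation that runs over every string column when `columns` is omitted or set to `*`; this is a sketch only, and assumes `SparkDatatype` exposes a `STRING` member:

```python
from pyspark.sql import functions as f

from koheesio.spark import Column
from koheesio.spark.transformations import ColumnsTransformationWithTarget
from koheesio.spark.utils import SparkDatatype


class TrimStrings(ColumnsTransformationWithTarget):
    """Hypothetical example: trim every string column when no columns are given."""

    class ColumnConfig(ColumnsTransformationWithTarget.ColumnConfig):
        run_for_all_data_type = [SparkDatatype.STRING]
        limit_data_type = [SparkDatatype.STRING]

    def func(self, column: Column) -> Column:
        return f.trim(column)
```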
(default: [None]) limit_data_type : Optional[List[SparkDatatype]] @@ -251,12 +247,12 @@ class ColumnConfig: (default: False) """ - run_for_all_data_type: Optional[List[SparkDatatype]] = [None] + run_for_all_data_type: Optional[List[SparkDatatype]] = [None] # type: ignore limit_data_type: Optional[List[SparkDatatype]] = [None] data_type_strict_mode: bool = False @field_validator("columns", mode="before") - def set_columns(cls, columns_value): + def set_columns(cls, columns_value: ListOfColumns) -> ListOfColumns: """Validate columns through the columns configuration provided""" columns = columns_value run_for_all_data_type = cls.ColumnConfig.run_for_all_data_type @@ -280,7 +276,7 @@ def run_for_all_is_set(self) -> bool: @property def limit_data_type_is_set(self) -> bool: """Returns True if limit_data_type is set""" - return self.ColumnConfig.limit_data_type[0] is not None + return self.ColumnConfig.limit_data_type[0] is not None # type: ignore[index] @property def data_type_strict_mode_is_set(self) -> bool: @@ -288,7 +284,10 @@ def data_type_strict_mode_is_set(self) -> bool: return self.ColumnConfig.data_type_strict_mode def column_type_of_col( - self, col: Union[str, Column], df: Optional[DataFrame] = None, simple_return_mode: bool = True + self, + col: Union[Column, str], + df: Optional[DataFrame] = None, + simple_return_mode: bool = True, ) -> Union[DataType, str]: """ Returns the dataType of a Column object as a string. @@ -338,12 +337,15 @@ def column_type_of_col( if not df: raise RuntimeError("No valid Dataframe was passed") - if not isinstance(col, Column): - col = f.col(col) + if not isinstance(col, Column): # type:ignore[misc, arg-type] + col = f.col(col) # type:ignore[arg-type] - # ask the JVM for the name of the column - # noinspection PyProtectedMember - col_name = col._jc.toString() + # noinspection PyProtectedMember,PyUnresolvedReferences + col_name = ( + col._expr._unparsed_identifier + if col.__class__.__module__ == "pyspark.sql.connect.column" + else col._jc.toString() # type: ignore # noqa: E721 + ) # In order to check the datatype of the column, we have to ask the DataFrame its schema df_col = [c for c in df.schema if c.name == col_name][0] @@ -382,7 +384,7 @@ def get_all_columns_of_specific_type(self, data_type: Union[str, SparkDatatype]) ] return columns_of_given_type - def is_column_type_correct(self, column): + def is_column_type_correct(self, column: Union[Column, str]) -> bool: """Check if column type is correct and handle it if not, when limit_data_type is set""" if not self.limit_data_type_is_set: return True @@ -401,19 +403,19 @@ def is_column_type_correct(self, column): self.log.warning(f"Column `{column}` is not of type `{limit_data_types}` and will be skipped.") return False - def get_limit_data_types(self): + def get_limit_data_types(self) -> list: """Get the limit_data_type as a list of strings""" - return [dt.value for dt in self.ColumnConfig.limit_data_type] + return [dt.value for dt in self.ColumnConfig.limit_data_type] # type: ignore - def get_columns(self) -> iter: + def get_columns(self) -> Iterator[str]: """Return an iterator of the columns""" # If `run_for_all_is_set` is True, we want to run the transformation for all columns of a given type if self.run_for_all_is_set: columns = [] - for data_type in self.ColumnConfig.run_for_all_data_type: + for data_type in self.ColumnConfig.run_for_all_data_type: # type: ignore columns += self.get_all_columns_of_specific_type(data_type) else: - columns = self.columns + columns = self.columns # 
type:ignore[assignment] for column in columns: if self.is_column_type_correct(column): @@ -521,7 +523,7 @@ def func(self, column: Column) -> Column: """ raise NotImplementedError - def get_columns_with_target(self) -> iter: + def get_columns_with_target(self) -> Iterator[tuple[str, str]]: """Return an iterator of the columns Works just like in get_columns from the ColumnsTransformation class except that it handles the `target_column` @@ -550,7 +552,7 @@ def get_columns_with_target(self) -> iter: yield target_column, column - def execute(self): + def execute(self) -> None: """Execute on a ColumnsTransformationWithTarget handles self.df (input) and set self.output.df (output) This can be left unchanged, and hence should not be implemented in the child class. """ @@ -560,7 +562,7 @@ def execute(self): func = self.func # select the applicable function df = df.withColumn( target_column, - func(f.col(column)), + func(f.col(column)), # type:ignore[arg-type] ) self.output.df = df diff --git a/src/koheesio/spark/transformations/arrays.py b/src/koheesio/spark/transformations/arrays.py index d58a1333..dfc59c9e 100644 --- a/src/koheesio/spark/transformations/arrays.py +++ b/src/koheesio/spark/transformations/arrays.py @@ -27,15 +27,15 @@ from abc import ABC from functools import reduce -from pyspark.sql import Column -from pyspark.sql import functions as F +from pyspark.sql import functions as f from koheesio.models import Field +from koheesio.spark import Column from koheesio.spark.transformations import ColumnsTransformationWithTarget from koheesio.spark.utils import ( + SPARK_MINOR_VERSION, SparkDatatype, spark_data_type_is_numeric, - spark_minor_version, ) __all__ = [ @@ -87,7 +87,7 @@ class ArrayDistinct(ArrayTransformation): ) def func(self, column: Column) -> Column: - _fn = F.array_distinct(column) + _fn = f.array_distinct(column) # noinspection PyUnresolvedReferences element_type = self.column_type_of_col(column, None, False).elementType @@ -95,7 +95,7 @@ def func(self, column: Column) -> Column: if self.filter_empty: # Remove null values from array - if spark_minor_version >= 3.4: + if SPARK_MINOR_VERSION >= 3.4: # Run array_compact if spark version is 3.4 or higher # https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.array_compact.html # pylint: disable=E0611 @@ -105,15 +105,15 @@ def func(self, column: Column) -> Column: # pylint: enable=E0611 else: # Otherwise, remove null from array using array_except - _fn = F.array_except(_fn, F.array(F.lit(None))) + _fn = f.array_except(_fn, f.array(f.lit(None))) # Remove nan or empty values from array (depends on the type of the elements in array) if is_numeric: # Remove nan from array (float/int/numbers) - _fn = F.array_except(_fn, F.array(F.lit(float("nan")).cast(element_type))) + _fn = f.array_except(_fn, f.array(f.lit(float("nan")).cast(element_type))) else: # Remove empty values from array (string/text) - _fn = F.array_except(_fn, F.array(F.lit(""), F.lit(" "))) + _fn = f.array_except(_fn, f.array(f.lit(""), f.lit(" "))) return _fn @@ -139,7 +139,7 @@ class Explode(ArrayTransformation): def func(self, column: Column) -> Column: if self.distinct: column = ArrayDistinct.from_step(self).func(column) - return F.explode_outer(column) if self.preserve_nulls else F.explode(column) + return f.explode_outer(column) if self.preserve_nulls else f.explode(column) class ExplodeDistinct(Explode): @@ -168,7 +168,7 @@ class ArrayReverse(ArrayTransformation): """ def func(self, column: Column) -> Column: - return 
F.reverse(column) + return f.reverse(column) class ArraySort(ArrayTransformation): @@ -190,7 +190,7 @@ class ArraySort(ArrayTransformation): ) def func(self, column: Column) -> Column: - column = F.array_sort(column) + column = f.array_sort(column) if self.reverse: # Reverse the order of elements in the array column = ArrayReverse.from_step(self).func(column) @@ -277,18 +277,19 @@ def func(self, column: Column) -> Column: The processed column with NaN and/or NULL values removed from elements. """ - def apply_logic(x: Column): + def apply_logic(x: Column) -> Column: if self.keep_nan is False and self.keep_null is False: - logic = x.isNotNull() & ~F.isnan(x) + logic = x.isNotNull() & ~f.isnan(x) elif self.keep_nan is False: - logic = ~F.isnan(x) + logic = ~f.isnan(x) elif self.keep_null is False: logic = x.isNotNull() - + else: + raise ValueError("unexpected condition") return logic if self.keep_nan is False or self.keep_null is False: - column = F.filter(column, apply_logic) + column = f.filter(column, apply_logic) return column @@ -322,25 +323,25 @@ def func(self, column: Column) -> Column: def filter_logic(x: Column, _val: Any): if self.keep_null and self.keep_nan: - logic = (x != F.lit(_val)) | x.isNull() | F.isnan(x) + logic = (x != f.lit(_val)) | x.isNull() | f.isnan(x) elif self.keep_null: - logic = (x != F.lit(_val)) | x.isNull() + logic = (x != f.lit(_val)) | x.isNull() elif self.keep_nan: - logic = (x != F.lit(_val)) | F.isnan(x) + logic = (x != f.lit(_val)) | f.isnan(x) else: - logic = x != F.lit(_val) + logic = x != f.lit(_val) return logic # Check if the value is iterable (i.e., a list, tuple, or set) if isinstance(value, (list, tuple, set)): - result = reduce(lambda res, val: F.filter(res, lambda x: filter_logic(x, val)), value, column) + result = reduce(lambda res, val: f.filter(res, lambda x: filter_logic(x, val)), value, column) else: # If the value is not iterable, simply remove the value from the array - result = F.filter(column, lambda x: filter_logic(x, value)) + result = f.filter(column, lambda x: filter_logic(x, value)) if self.make_distinct: - result = F.array_distinct(result) + result = f.array_distinct(result) return result @@ -357,7 +358,7 @@ class ArrayMin(ArrayTransformation): """ def func(self, column: Column) -> Column: - return F.array_min(column) + return f.array_min(column) class ArrayMax(ArrayNullNanProcess): @@ -375,7 +376,7 @@ def func(self, column: Column) -> Column: # Call for processing of nan values column = super().func(column) - return F.array_max(column) + return f.array_max(column) class ArraySum(ArrayNullNanProcess): @@ -400,6 +401,7 @@ class ArraySum(ArrayNullNanProcess): def func(self, column: Column) -> Column: """Using the `aggregate` function to sum the values in the array""" # raise an error if the array contains non-numeric elements + # noinspection PyUnresolvedReferences element_type = self.column_type_of_col(column, None, False).elementType if not spark_data_type_is_numeric(element_type): raise ValueError( @@ -413,8 +415,8 @@ def func(self, column: Column) -> Column: # Using the `aggregate` function to sum the values in the array by providing the initial value as 0.0 and the # lambda function to add the elements together. Pyspark will automatically infer the type of the initial value # making 0.0 valid for both integer and float types. 
- initial_value = F.lit(0.0) - return F.aggregate(column, initial_value, lambda accumulator, x: accumulator + x) + initial_value = f.lit(0.0) + return f.aggregate(column, initial_value, lambda accumulator, x: accumulator + x) class ArrayMean(ArrayNullNanProcess): @@ -433,6 +435,7 @@ class ArrayMean(ArrayNullNanProcess): def func(self, column: Column) -> Column: """Calculate the mean of the values in the array""" # raise an error if the array contains non-numeric elements + # noinspection PyUnresolvedReferences element_type = self.column_type_of_col(col=column, df=None, simple_return_mode=False).elementType if not spark_data_type_is_numeric(element_type): @@ -444,9 +447,9 @@ def func(self, column: Column) -> Column: _sum = ArraySum.from_step(self).func(column) # Call for processing of nan values column = super().func(column) - _size = F.size(column) + _size = f.size(column) # return 0 if the size of the array is 0 to avoid division by zero - return F.when(_size == 0, F.lit(0)).otherwise(_sum / _size) + return f.when(_size == 0, f.lit(0)).otherwise(_sum / _size) class ArrayMedian(ArrayNullNanProcess): @@ -467,42 +470,44 @@ class ArrayMedian(ArrayNullNanProcess): ``` """ - def func(self, column: Column) -> Column: + def func(self, column: Column) -> Column: # type: ignore """Calculate the median of the values in the array""" # Call for processing of nan values column = super().func(column) sorted_array = ArraySort.from_step(self).func(column) - _size: Column = F.size(sorted_array) + _size: Column = f.size(sorted_array) # Calculate the middle index. If the size is odd, PySpark discards the fractional part. # Use floor function to ensure the result is an integer - middle: Column = F.floor((_size + 1) / 2).cast("int") + # noinspection PyTypeChecker + middle: Column = f.floor((_size + 1) / 2).cast("int") # Define conditions is_size_zero: Column = _size == 0 is_column_null: Column = column.isNull() + # noinspection PyTypeChecker is_size_even: Column = _size % 2 == 0 # Define actions / responses # For even-sized arrays, calculate the average of the two middle elements - average_of_middle_elements = (F.element_at(sorted_array, middle) + F.element_at(sorted_array, middle + 1)) / 2 + average_of_middle_elements = (f.element_at(sorted_array, middle) + f.element_at(sorted_array, middle + 1)) / 2 # For odd-sized arrays, select the middle element - middle_element = F.element_at(sorted_array, middle) + middle_element = f.element_at(sorted_array, middle) # In case the array is empty, return either None or 0 - none_value = F.lit(None) - zero_value = F.lit(0) + none_value = f.lit(None) + zero_value = f.lit(0) median = ( # Check if the size of the array is 0 - F.when( + f.when( is_size_zero, # If the size of the array is 0 and the column is null, return None # If the size of the array is 0 and the column is not null, return 0 - F.when(is_column_null, none_value).otherwise(zero_value), + f.when(is_column_null, none_value).otherwise(zero_value), ).otherwise( # If the size of the array is not 0, calculate the median - F.when(is_size_even, average_of_middle_elements).otherwise(middle_element) + f.when(is_size_even, average_of_middle_elements).otherwise(middle_element) ) ) diff --git a/src/koheesio/spark/transformations/camel_to_snake.py b/src/koheesio/spark/transformations/camel_to_snake.py index 7a0b8ebb..33f5d237 100644 --- a/src/koheesio/spark/transformations/camel_to_snake.py +++ b/src/koheesio/spark/transformations/camel_to_snake.py @@ -11,7 +11,7 @@ camel_to_snake_re = re.compile("([a-z0-9])([A-Z])") -def 
convert_camel_to_snake(name: str): +def convert_camel_to_snake(name: str) -> str: """ Converts a string from camelCase to snake_case. @@ -65,14 +65,14 @@ class CamelToSnakeTransformation(ColumnsTransformation): """ - columns: Optional[ListOfColumns] = Field( + columns: Optional[ListOfColumns] = Field( # type: ignore default="", alias="column", description="The column or columns to convert. If no columns are specified, all columns will be converted. " "A list of columns or a single column can be specified. For example: `['column1', 'column2']` or `'column1'` ", ) - def execute(self): + def execute(self) -> ColumnsTransformation.Output: _df = self.df # Prepare columns input: diff --git a/src/koheesio/spark/transformations/cast_to_datatype.py b/src/koheesio/spark/transformations/cast_to_datatype.py index 004c0efb..42e2ebae 100644 --- a/src/koheesio/spark/transformations/cast_to_datatype.py +++ b/src/koheesio/spark/transformations/cast_to_datatype.py @@ -1,7 +1,8 @@ +# noinspection PyUnresolvedReferences """ Transformations to cast a column or set of columns to a given datatype. -Each one of these have been vetted to throw warnings when wrong datatypes are passed (to skip erroring any job or +Each one of these have been vetted to throw warnings when wrong datatypes are passed (to prevent errors in any job or pipeline). Furthermore, detailed tests have been added to ensure that types are actually compatible as prescribed. @@ -124,7 +125,7 @@ class CastToDatatype(ColumnsTransformationWithTarget): datatype: Union[str, SparkDatatype] = Field(default=..., description="Datatype. Choose from SparkDatatype Enum") @field_validator("datatype") - def validate_datatype(cls, datatype_value) -> SparkDatatype: + def validate_datatype(cls, datatype_value: Union[str, SparkDatatype]) -> SparkDatatype: # type: ignore """Validate the datatype.""" # handle string input try: @@ -142,7 +143,7 @@ def validate_datatype(cls, datatype_value) -> SparkDatatype: def func(self, column: Column) -> Column: # This is to let the IDE explicitly know that the datatype is not a string, but a `SparkDatatype` Enum - datatype: SparkDatatype = self.datatype + datatype: SparkDatatype = self.datatype # type: ignore return column.cast(datatype.spark_type()) @@ -631,7 +632,7 @@ class ColumnConfig(CastToDatatype.ColumnConfig): ) @model_validator(mode="after") - def validate_scale_and_precisions(self): + def validate_scale_and_precisions(self) -> "CastToDecimal": """Validate the precision and scale values.""" precision_value = self.precision scale_value = self.scale diff --git a/src/koheesio/spark/transformations/date_time/__init__.py b/src/koheesio/spark/transformations/date_time/__init__.py index 9270110f..931fe5df 100644 --- a/src/koheesio/spark/transformations/date_time/__init__.py +++ b/src/koheesio/spark/transformations/date_time/__init__.py @@ -4,7 +4,6 @@ from pytz import all_timezones_set -from pyspark.sql import Column from pyspark.sql import functions as f from pyspark.sql.functions import ( col, @@ -17,10 +16,11 @@ ) from koheesio.models import Field, field_validator, model_validator +from koheesio.spark import Column from koheesio.spark.transformations import ColumnsTransformationWithTarget -def change_timezone(column: Union[str, Column], source_timezone: str, target_timezone: str): +def change_timezone(column: Union[str, Column], source_timezone: str, target_timezone: str) -> Column: """Helper function to change from one timezone to another wrapper around `pyspark.sql.functions.from_utc_timestamp` and `to_utc_timestamp` @@ 
-140,7 +140,7 @@ class ChangeTimeZone(ColumnsTransformationWithTarget): ) @model_validator(mode="before") - def validate_no_duplicate_timezones(cls, values): + def validate_no_duplicate_timezones(cls, values: dict) -> dict: """Validate that source and target timezone are not the same""" from_timezone_value = values.get("from_timezone") to_timezone_value = values.get("o_timezone") @@ -151,7 +151,7 @@ def validate_no_duplicate_timezones(cls, values): return values @field_validator("from_timezone", "to_timezone") - def validate_timezone(cls, timezone_value): + def validate_timezone(cls, timezone_value: str) -> str: """Validate that the timezone is a valid timezone.""" if timezone_value not in all_timezones_set: raise ValueError( @@ -163,7 +163,7 @@ def validate_timezone(cls, timezone_value): def func(self, column: Column) -> Column: return change_timezone(column=column, source_timezone=self.from_timezone, target_timezone=self.to_timezone) - def execute(self): + def execute(self) -> ColumnsTransformationWithTarget.Output: df = self.df for target_column, column in self.get_columns_with_target(): diff --git a/src/koheesio/spark/transformations/date_time/interval.py b/src/koheesio/spark/transformations/date_time/interval.py index 9b574a7d..e30244aa 100644 --- a/src/koheesio/spark/transformations/date_time/interval.py +++ b/src/koheesio/spark/transformations/date_time/interval.py @@ -29,7 +29,7 @@ These classes are subclasses of `ColumnsTransformationWithTarget` and hence can be used to perform transformations on multiple columns at once. -The above transformations both use the provided `asjust_time()` function to perform the actual transformation. +The above transformations both use the provided `adjust_time()` function to perform the actual transformation. See also: --------- @@ -120,47 +120,70 @@ `DateTimeSubtractInterval` works in a similar way, but subtracts an interval value from a datetime column. """ +from __future__ import annotations + from typing import Literal, Union -from pyspark.sql import Column +from pyspark.sql import Column as SparkColumn from pyspark.sql.functions import col, expr -from pyspark.sql.utils import ParseException from koheesio.models import Field, field_validator +from koheesio.spark import Column, ParseException from koheesio.spark.transformations import ColumnsTransformationWithTarget +from koheesio.spark.utils import check_if_pyspark_connect_is_supported, get_column_name # create a literal constraining the operations to 'add' and 'subtract' Operations = Literal["add", "subtract"] -class DateTimeColumn(Column): +class DateTimeColumn(SparkColumn): """A datetime column that can be adjusted by adding or subtracting an interval value using the `+` and `-` operators. """ - def __add__(self, value: str): + def __add__(self, value: str) -> Column: """Add an `interval` value to a date or time column A valid value is a string that can be parsed by the `interval` function in Spark SQL. See https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html#interval-literal """ - return self.from_column(adjust_time(self, operation="add", interval=value)) + print(f"__add__: {value = }") + return adjust_time(self, operation="add", interval=value) - def __sub__(self, value: str): + def __sub__(self, value: str) -> Column: """Subtract an `interval` value to a date or time column A valid value is a string that can be parsed by the `interval` function in Spark SQL. 
See https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html#interval-literal """ - return self.from_column(adjust_time(self, operation="subtract", interval=value)) + return adjust_time(self, operation="subtract", interval=value) + # noinspection PyProtectedMember @classmethod - def from_column(cls, column: Column): + def from_column(cls, column: Column) -> Union["DateTimeColumn", "DateTimeColumnConnect"]: """Create a DateTimeColumn from an existing Column""" - return cls(column._jc) + if isinstance(column, SparkColumn): + return DateTimeColumn(column._jc) + return DateTimeColumnConnect(expr=column._expr) + + +# if spark version is 3.5 or higher, we have to account for the connect mode +if check_if_pyspark_connect_is_supported(): + from pyspark.sql.connect.column import Column as ConnectColumn + + class DateTimeColumnConnect(ConnectColumn): + """A datetime column that can be adjusted by adding or subtracting an interval value using the `+` and `-` + operators. + Optimized for Spark Connect mode. + """ + + __add__ = DateTimeColumn.__add__ + __sub__ = DateTimeColumn.__sub__ + from_column = DateTimeColumn.from_column -def validate_interval(interval: str): + +def validate_interval(interval: str) -> str: """Validate an interval string Parameters @@ -173,14 +196,20 @@ def validate_interval(interval: str): ValueError If the interval string is invalid """ + from koheesio.spark.utils.common import get_active_session + from koheesio.spark.utils.connect import is_remote_session + try: - expr(f"interval '{interval}'") + if is_remote_session(): + get_active_session().sql(f"SELECT interval '{interval}'") # type: ignore + else: + expr(f"interval '{interval}'") except ParseException as e: raise ValueError(f"Value '{interval}' is not a valid interval.") from e return interval -def dt_column(column: Union[str, Column]) -> DateTimeColumn: +def dt_column(column: Column) -> DateTimeColumn: """Convert a column to a DateTimeColumn Aims to be a drop-in replacement for `pyspark.sql.functions.col` that returns a DateTimeColumn instead of a Column. @@ -204,7 +233,7 @@ def dt_column(column: Union[str, Column]) -> DateTimeColumn: """ if isinstance(column, str): column = col(column) - elif not isinstance(column, Column): + elif type(column) not in ("pyspark.sql.Column", "pyspark.sql.connect.column.Column"): raise TypeError(f"Expected column to be of type str or Column, got {type(column)} instead.") return DateTimeColumn.from_column(column) @@ -268,14 +297,16 @@ def adjust_time(column: Column, operation: Operations, interval: str) -> Column: # check that value is a valid interval interval = validate_interval(interval) - column_name = column._jc.toString() + column_name = get_column_name(column) # determine the operation to perform try: operation = { "add": "try_add", "subtract": "try_subtract", - }[operation] + }[ + operation + ] # type: ignore except KeyError as e: raise ValueError(f"Operation '{operation}' is not valid. 
Must be either 'add' or 'subtract'.") from e @@ -336,7 +367,7 @@ class DateTimeAddInterval(ColumnsTransformationWithTarget): # validators validate_interval = field_validator("interval")(validate_interval) - def func(self, column: Column): + def func(self, column: Column) -> Column: return adjust_time(column, operation=self.operation, interval=self.interval) diff --git a/src/koheesio/spark/transformations/drop_column.py b/src/koheesio/spark/transformations/drop_column.py index 975ad506..d4da7772 100644 --- a/src/koheesio/spark/transformations/drop_column.py +++ b/src/koheesio/spark/transformations/drop_column.py @@ -45,6 +45,6 @@ class DropColumn(ColumnsTransformation): In this example, the `product` column is dropped from the DataFrame `df`. """ - def execute(self): + def execute(self) -> ColumnsTransformation.Output: self.log.info(f"{self.column=}") self.output.df = self.df.drop(*self.columns) diff --git a/src/koheesio/spark/transformations/dummy.py b/src/koheesio/spark/transformations/dummy.py index 21e9a88c..c8baf902 100644 --- a/src/koheesio/spark/transformations/dummy.py +++ b/src/koheesio/spark/transformations/dummy.py @@ -34,5 +34,5 @@ class DummyTransformation(Transformation): """ - def execute(self): + def execute(self) -> Transformation.Output: self.output.df = self.df.withColumn("hello", lit("world")) diff --git a/src/koheesio/spark/transformations/get_item.py b/src/koheesio/spark/transformations/get_item.py index 941daec3..647508e8 100644 --- a/src/koheesio/spark/transformations/get_item.py +++ b/src/koheesio/spark/transformations/get_item.py @@ -11,7 +11,7 @@ from koheesio.spark.utils import SparkDatatype -def get_item(column: Column, key: Union[str, int]): +def get_item(column: Column, key: Union[str, int]) -> Column: """ Wrapper around pyspark.sql.functions.getItem diff --git a/src/koheesio/spark/transformations/hash.py b/src/koheesio/spark/transformations/hash.py index 4c55dd10..a6e8608c 100644 --- a/src/koheesio/spark/transformations/hash.py +++ b/src/koheesio/spark/transformations/hash.py @@ -17,7 +17,7 @@ STRING = SparkDatatype.STRING -def sha2_hash(columns: List[str], delimiter: Optional[str] = "|", num_bits: Optional[HASH_ALGORITHM] = 256): +def sha2_hash(columns: List[str], delimiter: Optional[str] = "|", num_bits: Optional[HASH_ALGORITHM] = 256) -> Column: """ hash the value of 1 or more columns using SHA-2 family of hash functions @@ -43,16 +43,16 @@ def sha2_hash(columns: List[str], delimiter: Optional[str] = "|", num_bits: Opti _columns = [] for c in columns: if isinstance(c, str): - c: Column = col(c) + c: Column = col(c) # type: ignore _columns.append(c.cast(STRING.spark_type())) # concatenate columns if more than 1 column is provided if len(_columns) > 1: - column = concat_ws(delimiter, *_columns) + column = concat_ws(delimiter, *_columns) # type: ignore else: column = _columns[0] - return sha2(column, num_bits) + return sha2(column, num_bits) # type: ignore class Sha2Hash(ColumnsTransformation): @@ -92,7 +92,7 @@ class Sha2Hash(ColumnsTransformation): default=..., description="The generated hash will be written to the column name specified here" ) - def execute(self): + def execute(self) -> ColumnsTransformation.Output: columns = list(self.get_columns()) self.output.df = ( self.df.withColumn( diff --git a/src/koheesio/spark/transformations/lookup.py b/src/koheesio/spark/transformations/lookup.py index f1b5a9c2..73292ec8 100644 --- a/src/koheesio/spark/transformations/lookup.py +++ b/src/koheesio/spark/transformations/lookup.py @@ -12,10 +12,11 @@ from typing 
import List, Optional, Union from enum import Enum -import pyspark.sql.functions as f -from pyspark.sql import Column, DataFrame +from pyspark.sql import Column +from pyspark.sql import functions as f from koheesio.models import BaseModel, Field, field_validator +from koheesio.spark import DataFrame from koheesio.spark.transformations import Transformation @@ -122,8 +123,8 @@ class DataframeLookup(Transformation): column from the `right_df` is aliased as `right_value` in the output dataframe. """ - df: DataFrame = Field(default=None, description="The left Spark DataFrame") - other: DataFrame = Field(default=None, description="The right Spark DataFrame") + df: Optional[DataFrame] = Field(default=None, description="The left Spark DataFrame") + other: Optional[DataFrame] = Field(default=None, description="The right Spark DataFrame") on: Union[List[JoinMapping], JoinMapping] = Field( default=..., alias="join_mapping", @@ -135,14 +136,14 @@ class DataframeLookup(Transformation): description="List of target columns. If only one target is passed, it can be passed as a single object.", ) how: Optional[JoinType] = Field( - default=JoinType.LEFT, description="What type of join to perform. Defaults to left. " + JoinType.__doc__ + default=JoinType.LEFT, description="What type of join to perform. Defaults to left. " + str(JoinType.__doc__) ) hint: Optional[JoinHint] = Field( - default=None, description="What type of join hint to use. Defaults to None. " + JoinHint.__doc__ + default=None, description="What type of join hint to use. Defaults to None. " + str(JoinHint.__doc__) ) @field_validator("on", "targets") - def set_list(cls, value): + def set_list(cls, value: Union[List[JoinMapping], JoinMapping, List[TargetColumn], TargetColumn]) -> List: """Ensure that we can pass either a single object, or a list of objects""" return [value] if not isinstance(value, list) else value @@ -160,8 +161,8 @@ def execute(self) -> Output: """Execute the lookup transformation""" # prepare the right dataframe prepared_right_df = self.get_right_df().select( - *[join_mapping.column for join_mapping in self.on], - *[target.column for target in self.targets], + *[join_mapping.column for join_mapping in self.on], # type: ignore + *[target.column for target in self.targets], # type: ignore ) if self.hint: prepared_right_df = prepared_right_df.hint(self.hint) @@ -170,7 +171,7 @@ def execute(self) -> Output: self.output.left_df = self.df self.output.right_df = prepared_right_df self.output.df = self.df.join( - prepared_right_df, + prepared_right_df, # type: ignore on=[join_mapping.source_column for join_mapping in self.on], how=self.how, ) diff --git a/src/koheesio/spark/transformations/repartition.py b/src/koheesio/spark/transformations/repartition.py index 6d466236..73a6363b 100644 --- a/src/koheesio/spark/transformations/repartition.py +++ b/src/koheesio/spark/transformations/repartition.py @@ -12,7 +12,7 @@ class Repartition(ColumnsTransformation): With repartition, the number of partitions can be given as an optional value. If this is not provided, a default value is used. The default number of partitions is defined by the spark config 'spark.sql.shuffle.partitions', for - which the default value is 200 and will never exceed the number or rows in the DataFrame (whichever is value is + which the default value is 200 and will never exceed the number of rows in the DataFrame (whichever is value is lower). 
If columns are omitted, the entire DataFrame is repartitioned without considering the particular values in the @@ -20,9 +20,9 @@ class Repartition(ColumnsTransformation): Parameters ---------- - column : Optional[Union[str, List[str]]], optional, default=None + columns : Optional[Union[str, List[str]]], optional, default=None Name of the source column(s). If omitted, the entire DataFrame is repartitioned without considering the - particular values in the columns. Alias: columns + particular values in the columns. Alias: column num_partitions : Optional[int], optional, default=None The number of partitions to repartition to. If omitted, the default number of partitions is used as defined by the spark config 'spark.sql.shuffle.partitions'. @@ -38,15 +38,15 @@ class Repartition(ColumnsTransformation): """ columns: Optional[ListOfColumns] = Field(default="", alias="column", description="Name of the source column(s)") - numPartitions: Optional[int] = Field( + num_partitions: Optional[int] = Field( default=None, - alias="num_partitions", + alias="numPartitions", description="The number of partitions to repartition to. If omitted, the default number of partitions is used " "as defined by the spark config 'spark.sql.shuffle.partitions'.", ) @model_validator(mode="before") - def _validate_field_and_num_partitions(cls, values): + def _validate_field_and_num_partitions(cls, values: dict) -> dict: """Ensure that at least one of the fields 'columns' and 'num_partitions' is provided.""" columns_value = values.get("columns") or values.get("column") num_partitions_value = values.get("numPartitions") or values.get("num_partitions") @@ -57,10 +57,10 @@ def _validate_field_and_num_partitions(cls, values): values["numPartitions"] = num_partitions_value return values - def execute(self): + def execute(self) -> ColumnsTransformation.Output: # Prepare columns input: columns = self.df.columns if self.columns == ["*"] else self.columns # Prepare repartition input: # num_partitions comes first, but if it is not provided it should not be included as None. 
- repartition_inputs = [i for i in [self.numPartitions, *columns] if i] + repartition_inputs = [i for i in [self.num_partitions, *columns] if i] # type: ignore self.output.df = self.df.repartition(*repartition_inputs) diff --git a/src/koheesio/spark/transformations/replace.py b/src/koheesio/spark/transformations/replace.py index 977b11be..6f10613e 100644 --- a/src/koheesio/spark/transformations/replace.py +++ b/src/koheesio/spark/transformations/replace.py @@ -2,15 +2,15 @@ from typing import Optional, Union -from pyspark.sql import Column from pyspark.sql.functions import col, lit, when from koheesio.models import Field +from koheesio.spark import Column from koheesio.spark.transformations import ColumnsTransformationWithTarget from koheesio.spark.utils import SparkDatatype -def replace(column: Union[Column, str], to_value: str, from_value: Optional[str] = None): +def replace(column: Union[Column, str], to_value: str, from_value: Optional[str] = None) -> Column: """Function to replace a particular value in a column with another one""" # make sure we have a Column object if isinstance(column, str): diff --git a/src/koheesio/spark/transformations/row_number_dedup.py b/src/koheesio/spark/transformations/row_number_dedup.py index c0d80f17..13625282 100644 --- a/src/koheesio/spark/transformations/row_number_dedup.py +++ b/src/koheesio/spark/transformations/row_number_dedup.py @@ -6,12 +6,13 @@ from __future__ import annotations -from typing import Optional, Union +from typing import List, Optional, Union -from pyspark.sql import Column, Window, WindowSpec +from pyspark.sql import Window, WindowSpec from pyspark.sql.functions import col, desc, row_number from koheesio.models import Field, conlist, field_validator +from koheesio.spark import Column from koheesio.spark.transformations import ColumnsTransformation @@ -24,7 +25,7 @@ class RowNumberDedup(ColumnsTransformation): the top-row_number row for each group of duplicates. The row_number of each row can be stored in a specified target column or a default column named "meta_row_number_column". The class also provides an option to preserve meta columns - (like the row_numberk column) in the output DataFrame. + (like the `row_number` column) in the output DataFrame. Attributes ---------- @@ -58,7 +59,7 @@ class RowNumberDedup(ColumnsTransformation): ) @field_validator("sort_columns", mode="before") - def set_sort_columns(cls, columns_value): + def set_sort_columns(cls, columns_value: Union[str, Column, List[Union[str, Column]]]) -> List[Union[str, Column]]: """ Validates and optimizes the sort_columns parameter. @@ -75,7 +76,6 @@ def set_sort_columns(cls, columns_value): List[Union[str, Column]] The optimized and deduplicated list of sort columns. """ - # Convert single string or Column object to a list columns = [columns_value] if isinstance(columns_value, (str, Column)) else [*columns_value] # Remove empty strings, None, etc. @@ -117,7 +117,7 @@ def window_spec(self) -> WindowSpec: return Window.partitionBy([*self.get_columns()]).orderBy(*order_clause) - def execute(self) -> RowNumberDedup.Output: + def execute(self) -> RowNumberDedup.Output: # type: ignore """ Performs the row_number deduplication operation on the DataFrame. 
diff --git a/src/koheesio/spark/transformations/sql_transform.py b/src/koheesio/spark/transformations/sql_transform.py index 4d47f2a9..030e1d47 100644 --- a/src/koheesio/spark/transformations/sql_transform.py +++ b/src/koheesio/spark/transformations/sql_transform.py @@ -6,6 +6,7 @@ from koheesio.models.sql import SqlBaseStep from koheesio.spark.transformations import Transformation +from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.utils import get_random_string @@ -26,12 +27,20 @@ class SqlTransform(SqlBaseStep, Transformation): ``` """ - def execute(self): + def execute(self) -> Transformation.Output: table_name = get_random_string(prefix="sql_transform") self.params = {**self.params, "table_name": table_name} - df = self.df - df.createOrReplaceTempView(table_name) + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session() and self.df.isStreaming: + raise RuntimeError( + "SQL Transform is not supported in remote sessions with streaming dataframes." + "See https://issues.apache.org/jira/browse/SPARK-45957" + "It is fixed in PySpark 4.0.0" + ) + + self.df.createOrReplaceTempView(table_name) query = self.query self.output.df = self.spark.sql(query) diff --git a/src/koheesio/spark/transformations/strings/change_case.py b/src/koheesio/spark/transformations/strings/change_case.py index 42d63018..3906b352 100644 --- a/src/koheesio/spark/transformations/strings/change_case.py +++ b/src/koheesio/spark/transformations/strings/change_case.py @@ -74,7 +74,7 @@ class ColumnConfig(ColumnsTransformationWithTarget.ColumnConfig): run_for_all_data_type = [SparkDatatype.STRING] limit_data_type = [SparkDatatype.STRING] - def func(self, column: Column): + def func(self, column: Column) -> Column: return lower(column) @@ -126,7 +126,7 @@ class UpperCase(LowerCase): to upper case. """ - def func(self, column: Column): + def func(self, column: Column) -> Column: return upper(column) @@ -179,7 +179,7 @@ class TitleCase(LowerCase): to title case (each word now starts with an upper case). """ - def func(self, column: Column): + def func(self, column: Column) -> Column: return initcap(column) diff --git a/src/koheesio/spark/transformations/strings/concat.py b/src/koheesio/spark/transformations/strings/concat.py index 9f7a68dd..f0c8c8ae 100644 --- a/src/koheesio/spark/transformations/strings/concat.py +++ b/src/koheesio/spark/transformations/strings/concat.py @@ -4,7 +4,6 @@ from typing import List, Optional -from pyspark.sql import DataFrame from pyspark.sql.functions import col, concat, concat_ws from koheesio.models import Field, field_validator @@ -111,7 +110,7 @@ class Concat(ColumnsTransformation): ) @field_validator("target_column") - def get_target_column(cls, target_column_value, values): + def get_target_column(cls, target_column_value: str, values: dict) -> str: """Get the target column name if it is not provided. 
If not provided, a name will be generated by concatenating the names of the source columns with an '_'.""" @@ -122,8 +121,8 @@ def get_target_column(cls, target_column_value, values): return target_column_value - def execute(self) -> DataFrame: + def execute(self) -> ColumnsTransformation.Output: columns = [col(s) for s in self.get_columns()] - self.output.df = self.df.withColumn( + self.output.df = self.df.withColumn( # type: ignore self.target_column, concat_ws(self.spacer, *columns) if self.spacer else concat(*columns) ) diff --git a/src/koheesio/spark/transformations/strings/pad.py b/src/koheesio/spark/transformations/strings/pad.py index 45cccdc5..132faf8f 100644 --- a/src/koheesio/spark/transformations/strings/pad.py +++ b/src/koheesio/spark/transformations/strings/pad.py @@ -82,7 +82,7 @@ class Pad(ColumnsTransformationWithTarget): default="left", description='On which side to add the characters . Either "left" or "right". Defaults to "left"' ) - def func(self, column: Column): + def func(self, column: Column) -> Column: func = lpad if self.direction == "left" else rpad return func(column, self.length, self.character) diff --git a/src/koheesio/spark/transformations/strings/regexp.py b/src/koheesio/spark/transformations/strings/regexp.py index 63f31719..4747e109 100644 --- a/src/koheesio/spark/transformations/strings/regexp.py +++ b/src/koheesio/spark/transformations/strings/regexp.py @@ -13,8 +13,6 @@ """ -from typing import Optional - from pyspark.sql import Column from pyspark.sql.functions import regexp_extract, regexp_replace @@ -95,13 +93,13 @@ class RegexpExtract(ColumnsTransformationWithTarget): """ regexp: str = Field(default=..., description="The Java regular expression to extract") - index: Optional[int] = Field( + index: int = Field( default=0, description="When there are more groups in the match, you can indicate which one you want. " "0 means the whole match. 
1 and above are groups within that match.", ) - def func(self, column: Column): + def func(self, column: Column) -> Column: return regexp_extract(column, self.regexp, self.index) @@ -154,5 +152,5 @@ class RegexpReplace(ColumnsTransformationWithTarget): description="String to replace matched pattern with.", ) - def func(self, column: Column): + def func(self, column: Column) -> Column: return regexp_replace(column, self.regexp, self.replacement) diff --git a/src/koheesio/spark/transformations/strings/replace.py b/src/koheesio/spark/transformations/strings/replace.py index 8c18892f..689835f0 100644 --- a/src/koheesio/spark/transformations/strings/replace.py +++ b/src/koheesio/spark/transformations/strings/replace.py @@ -91,12 +91,12 @@ class Replace(ColumnsTransformationWithTarget): new_value: str = Field(default=..., alias="to", description="The new value to replace this with") @field_validator("original_value", "new_value", mode="before") - def cast_values_to_str(cls, value): + def cast_values_to_str(cls, value: Optional[str]) -> Optional[str]: """Cast values to string if they are not None""" if value: return str(value) - def func(self, column: Column): + def func(self, column: Column) -> Column: when_statement = ( when(column.isNull(), lit(self.new_value)) if not self.original_value diff --git a/src/koheesio/spark/transformations/strings/split.py b/src/koheesio/spark/transformations/strings/split.py index a7ef90af..0c71d370 100644 --- a/src/koheesio/spark/transformations/strings/split.py +++ b/src/koheesio/spark/transformations/strings/split.py @@ -67,13 +67,13 @@ class SplitAll(ColumnsTransformationWithTarget): split_pattern: str = Field(default=..., description="The pattern to split the column contents.") - def func(self, column: Column): + def func(self, column: Column) -> Column: return split(column, pattern=self.split_pattern) class SplitAtFirstMatch(SplitAll): """ - Like SplitAll, but only splits the string once. You can specify whether you want the first or second part.. + Like SplitAll, but only splits the string once. You can specify whether you want the first or second part. Note ---- @@ -128,7 +128,7 @@ class SplitAtFirstMatch(SplitAll): description="Takes the first part of the split when true, the second part when False. Other parts are ignored.", ) - def func(self, column: Column): + def func(self, column: Column) -> Column: split_func = split(column, pattern=self.split_pattern) # first part diff --git a/src/koheesio/spark/transformations/strings/substring.py b/src/koheesio/spark/transformations/strings/substring.py index 14b0b219..09fef9f0 100644 --- a/src/koheesio/spark/transformations/strings/substring.py +++ b/src/koheesio/spark/transformations/strings/substring.py @@ -2,8 +2,6 @@ Extracts a substring from a string column starting at the given position. """ -from typing import Optional - from pyspark.sql import Column from pyspark.sql.functions import substring, when from pyspark.sql.types import StringType @@ -63,18 +61,18 @@ class Substring(ColumnsTransformationWithTarget): """ start: PositiveInt = Field(default=..., description="The starting position") - length: Optional[int] = Field( + length: int = Field( default=-1, description="The target length for the string. use -1 to perform until end", ) @field_validator("length") - def _valid_length(cls, length_value): + def _valid_length(cls, length_value: int) -> int: """Integer.maxint fix for Java. 
Python's sys.maxsize is larger which makes f.substring fail""" if length_value == -1: return 2147483647 return length_value - def func(self, column: Column): + def func(self, column: Column) -> Column: return when(column.isNull(), None).otherwise(substring(column, self.start, self.length)).cast(StringType()) diff --git a/src/koheesio/spark/transformations/strings/trim.py b/src/koheesio/spark/transformations/strings/trim.py index 36a9105e..ce116e24 100644 --- a/src/koheesio/spark/transformations/strings/trim.py +++ b/src/koheesio/spark/transformations/strings/trim.py @@ -114,7 +114,7 @@ class ColumnConfig(ColumnsTransformationWithTarget.ColumnConfig): default="left-right", description="On which side to remove the spaces. Either 'left', 'right' or 'left-right'" ) - def func(self, column: Column): + def func(self, column: Column) -> Column: if self.direction == "left": return f.ltrim(column) diff --git a/src/koheesio/spark/transformations/transform.py b/src/koheesio/spark/transformations/transform.py index 69d39e12..8ec4d916 100644 --- a/src/koheesio/spark/transformations/transform.py +++ b/src/koheesio/spark/transformations/transform.py @@ -6,12 +6,11 @@ from __future__ import annotations -from typing import Callable, Dict +from typing import Callable, Dict, Optional from functools import partial -from pyspark.sql import DataFrame - from koheesio.models import ExtraParamsMixin, Field +from koheesio.spark import DataFrame from koheesio.spark.transformations import Transformation from koheesio.utils import get_args_for_func @@ -71,19 +70,19 @@ def some_func(df, a: str, b: str): ``` """ - func: Callable = Field(default=None, description="The function to be called on the DataFrame.") + func: Callable = Field(default=..., description="The function to be called on the DataFrame.") - def __init__(self, func: Callable, params: Dict = None, df: DataFrame = None, **kwargs): + def __init__(self, func: Callable, params: Dict = None, df: Optional[DataFrame] = None, **kwargs: dict): params = {**(params or {}), **kwargs} super().__init__(func=func, params=params, df=df) - def execute(self): + def execute(self) -> Transformation.Output: """Call the function on the DataFrame with the given keyword arguments.""" func, kwargs = get_args_for_func(self.func, self.params) self.output.df = self.df.transform(func=func, **kwargs) @classmethod - def from_func(cls, func: Callable, **kwargs) -> Callable[..., Transform]: + def from_func(cls, func: Callable, **kwargs: dict) -> Callable[..., Transform]: """Create a Transform class from a function. Useful for creating a new class with a different name. 
This method uses the `functools.partial` function to create a new class with the given function and keyword diff --git a/src/koheesio/spark/transformations/uuid5.py b/src/koheesio/spark/transformations/uuid5.py index ec735329..545a2f9d 100644 --- a/src/koheesio/spark/transformations/uuid5.py +++ b/src/koheesio/spark/transformations/uuid5.py @@ -38,9 +38,9 @@ def uuid5_namespace(ns: Optional[Union[str, uuid.UUID]]) -> uuid.UUID: def hash_uuid5( input_value: str, - namespace: Optional[Union[str, uuid.UUID]] = "", - extra_string: Optional[str] = "", -): + namespace: Union[str, uuid.UUID] = "", + extra_string: str = "", +) -> str: """pure python implementation of HashUUID5 See: https://docs.python.org/3/library/uuid.html#uuid.uuid5 @@ -49,9 +49,9 @@ def hash_uuid5( ---------- input_value : str value that will be hashed - namespace : Optional[str | uuid.UUID] + namespace : str | uuid.UUID, optional, default="" namespace DNS - extra_string : Optional[str] + extra_string : str, optional, default="" optional extra string that will be prepended to the input_value Returns @@ -127,7 +127,7 @@ class HashUUID5(Transformation): description="List of columns that should be hashed. Should contain the name of at least 1 column. A list of " "columns or a single column can be specified. For example: `['column1', 'column2']` or `'column1'`", ) - delimiter: Optional[str] = Field(default="|", description="Separator for the string that will eventually be hashed") + delimiter: str = Field(default="|", description="Separator for the string that will eventually be hashed") namespace: Optional[Union[str, uuid.UUID]] = Field(default="", description="Namespace DNS") extra_string: Optional[str] = Field( default="", @@ -138,7 +138,7 @@ class HashUUID5(Transformation): description: str = "Generate a UUID with the UUID5 algorithm" @field_validator("source_columns") - def _set_columns(cls, columns): + def _set_columns(cls, columns: ListOfColumns) -> ListOfColumns: """Ensures every column is wrapped in backticks""" columns = [f"`{column}`" for column in columns] return columns diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py deleted file mode 100644 index b382c4b1..00000000 --- a/src/koheesio/spark/utils.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Spark Utility functions -""" - -import os -from enum import Enum - -from pyspark.sql.types import ( - ArrayType, - BinaryType, - BooleanType, - ByteType, - DataType, - DateType, - DecimalType, - DoubleType, - FloatType, - IntegerType, - LongType, - MapType, - NullType, - ShortType, - StringType, - StructType, - TimestampType, -) -from pyspark.version import __version__ as spark_version - -__all__ = [ - "SparkDatatype", - "get_spark_minor_version", - "import_pandas_based_on_pyspark_version", - "on_databricks", - "schema_struct_to_schema_str", - "spark_data_type_is_array", - "spark_data_type_is_numeric", - "spark_minor_version", -] - - -class SparkDatatype(Enum): - """ - Allowed spark datatypes - - The following table lists the data types that are supported by Spark SQL. 
- - | Data type | SQL name | - |---------------|---------------------------| - | ByteType | BYTE, TINYINT | - | ShortType | SHORT, SMALLINT | - | IntegerType | INT, INTEGER | - | LongType | LONG, BIGINT | - | FloatType | FLOAT, REAL | - | DoubleType | DOUBLE | - | DecimalType | DECIMAL, DEC, NUMERIC | - | StringType | STRING | - | BinaryType | BINARY | - | BooleanType | BOOLEAN | - | TimestampType | TIMESTAMP, TIMESTAMP_LTZ | - | DateType | DATE | - | ArrayType | ARRAY | - | MapType | MAP | - | NullType | VOID | - - Not supported yet - ---------------- - * __TimestampNTZType__ - TIMESTAMP_NTZ - * __YearMonthIntervalType__ - INTERVAL YEAR, INTERVAL YEAR TO MONTH, INTERVAL MONTH - * __DayTimeIntervalType__ - INTERVAL DAY, INTERVAL DAY TO HOUR, INTERVAL DAY TO MINUTE, INTERVAL DAY TO SECOND, INTERVAL HOUR, - INTERVAL HOUR TO MINUTE, INTERVAL HOUR TO SECOND, INTERVAL MINUTE, INTERVAL MINUTE TO SECOND, INTERVAL SECOND - - See Also - -------- - https://spark.apache.org/docs/latest/sql-ref-datatypes.html#supported-data-types - """ - - # byte - BYTE = "byte" - TINYINT = "byte" - - # short - SHORT = "short" - SMALLINT = "short" - - # integer - INTEGER = "integer" - INT = "integer" - - # long - LONG = "long" - BIGINT = "long" - - # float - FLOAT = "float" - REAL = "float" - - # timestamp - TIMESTAMP = "timestamp" - TIMESTAMP_LTZ = "timestamp" - - # decimal - DECIMAL = "decimal" - DEC = "decimal" - NUMERIC = "decimal" - - DATE = "date" - DOUBLE = "double" - STRING = "string" - BINARY = "binary" - BOOLEAN = "boolean" - ARRAY = "array" - MAP = "map" - VOID = "void" - - @property - def spark_type(self) -> DataType: - """Returns the spark type for the given enum value""" - mapping_dict = { - "byte": ByteType, - "short": ShortType, - "integer": IntegerType, - "long": LongType, - "float": FloatType, - "double": DoubleType, - "decimal": DecimalType, - "string": StringType, - "binary": BinaryType, - "boolean": BooleanType, - "timestamp": TimestampType, - "date": DateType, - "array": ArrayType, - "map": MapType, - "void": NullType, - } - return mapping_dict[self.value] - - @classmethod - def from_string(cls, value: str) -> "SparkDatatype": - """Allows for getting the right Enum value by simply passing a string value - This method is not case-sensitive - """ - return getattr(cls, value.upper()) - - -def get_spark_minor_version() -> float: - """Returns the minor version of the spark instance. 
- - For example, if the spark version is 3.3.2, this function would return 3.3 - """ - return float(".".join(spark_version.split(".")[:2])) - - -# short-hand for the get_spark_minor_version function -spark_minor_version: float = get_spark_minor_version() - - -def on_databricks() -> bool: - """Retrieve if we're running on databricks or elsewhere""" - dbr_version = os.getenv("DATABRICKS_RUNTIME_VERSION", None) - return dbr_version is not None and dbr_version != "" - - -def spark_data_type_is_array(data_type: DataType) -> bool: - """Check if the column's dataType is of type ArrayType""" - return isinstance(data_type, ArrayType) - - -def spark_data_type_is_numeric(data_type: DataType) -> bool: - """Check if the column's dataType is of type ArrayType""" - return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)) - - -def schema_struct_to_schema_str(schema: StructType) -> str: - """Converts a StructType to a schema str""" - if not schema: - return "" - return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) - - -def import_pandas_based_on_pyspark_version(): - """ - This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. - If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version - of pandas should be installed. - """ - try: - import pandas as pd - - pyspark_version = get_spark_minor_version() - pandas_version = pd.__version__ - - if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): - raise ImportError( - f"For PySpark {pyspark_version}, " - f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" - ) - - return pd - except ImportError as e: - raise ImportError("Pandas module is not installed.") from e diff --git a/src/koheesio/spark/utils/__init__.py b/src/koheesio/spark/utils/__init__.py new file mode 100644 index 00000000..1ecc4449 --- /dev/null +++ b/src/koheesio/spark/utils/__init__.py @@ -0,0 +1,27 @@ +from koheesio.spark.utils.common import ( + SPARK_MINOR_VERSION, + SparkDatatype, + check_if_pyspark_connect_is_supported, + get_column_name, + get_spark_minor_version, + import_pandas_based_on_pyspark_version, + on_databricks, + schema_struct_to_schema_str, + show_string, + spark_data_type_is_array, + spark_data_type_is_numeric, +) + +__all__ = [ + "SparkDatatype", + "import_pandas_based_on_pyspark_version", + "on_databricks", + "schema_struct_to_schema_str", + "spark_data_type_is_array", + "spark_data_type_is_numeric", + "show_string", + "get_spark_minor_version", + "SPARK_MINOR_VERSION", + "check_if_pyspark_connect_is_supported", + "get_column_name", +] diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py new file mode 100644 index 00000000..10050d5f --- /dev/null +++ b/src/koheesio/spark/utils/common.py @@ -0,0 +1,382 @@ +""" +Spark Utility functions +""" + +import importlib +import inspect +import os +from typing import Union +from enum import Enum +from types import ModuleType + +from pyspark import sql +from pyspark.sql.types import ( + ArrayType, + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + MapType, + NullType, + ShortType, + StringType, + StructType, + TimestampType, +) +from pyspark.version import __version__ as spark_version + +__all__ = [ + "SparkDatatype", + "import_pandas_based_on_pyspark_version", + 
"on_databricks", + "schema_struct_to_schema_str", + "spark_data_type_is_array", + "spark_data_type_is_numeric", + "show_string", + "get_spark_minor_version", + "SPARK_MINOR_VERSION", + "AnalysisException", + "Column", + "DataFrame", + "SparkSession", + "ParseException", + "DataType", + "DataFrameReader", + "DataStreamReader", + "DataFrameWriter", + "DataStreamWriter", + "StreamingQuery", + "get_active_session", + "check_if_pyspark_connect_is_supported", + "get_column_name", +] + +try: + from pyspark.errors.exceptions.base import AnalysisException # type: ignore +except (ImportError, ModuleNotFoundError): + from pyspark.sql.utils import AnalysisException # type: ignore + + +AnalysisException = AnalysisException + + +def get_spark_minor_version() -> float: + """Returns the minor version of the spark instance. + + For example, if the spark version is 3.3.2, this function would return 3.3 + """ + return float(".".join(spark_version.split(".")[:2])) + + +# shorthand for the get_spark_minor_version function +SPARK_MINOR_VERSION: float = get_spark_minor_version() + + +def check_if_pyspark_connect_is_supported() -> bool: + result = False + module_name: str = "pyspark" + if SPARK_MINOR_VERSION >= 3.5: + try: + importlib.import_module(f"{module_name}.sql.connect") + from pyspark.sql.connect.column import Column + + _col: Column + result = True + except (ModuleNotFoundError, ImportError): + result = False + return result + + +if check_if_pyspark_connect_is_supported(): + from pyspark.errors.exceptions.captured import ( + ParseException as CapturedParseException, + ) + from pyspark.errors.exceptions.connect import ( + ParseException as ConnectParseException, + ) + from pyspark.sql.connect.column import Column as ConnectColumn + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + from pyspark.sql.connect.proto.types_pb2 import DataType as ConnectDataType + from pyspark.sql.connect.readwriter import DataFrameReader, DataFrameWriter + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + from pyspark.sql.connect.streaming.readwriter import ( + DataStreamReader, + DataStreamWriter, + ) + from pyspark.sql.streaming.query import StreamingQuery + from pyspark.sql.types import DataType as SqlDataType + + Column = Union[sql.Column, ConnectColumn] + DataFrame = Union[sql.DataFrame, ConnectDataFrame] + SparkSession = Union[sql.SparkSession, ConnectSparkSession] + ParseException = (CapturedParseException, ConnectParseException) + DataType = Union[SqlDataType, ConnectDataType] + DataFrameReader = Union[sql.readwriter.DataFrameReader, DataFrameReader] + DataStreamReader = Union[sql.streaming.readwriter.DataStreamReader, DataStreamReader] + DataFrameWriter = Union[sql.readwriter.DataFrameWriter, DataFrameWriter] + DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter] + StreamingQuery = StreamingQuery +else: + try: + from pyspark.errors.exceptions.captured import ParseException # type: ignore + except (ImportError, ModuleNotFoundError): + from pyspark.sql.utils import ParseException # type: ignore + + ParseException = ParseException + + from pyspark.sql.column import Column # type: ignore + from pyspark.sql.dataframe import DataFrame # type: ignore + from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter # type: ignore + from pyspark.sql.session import SparkSession # type: ignore + from pyspark.sql.types import DataType # type: ignore + + try: + from pyspark.sql.streaming.query import StreamingQuery + from 
pyspark.sql.streaming.readwriter import DataStreamReader, DataStreamWriter + except (ImportError, ModuleNotFoundError): + from pyspark.sql.streaming import ( # type: ignore + DataStreamReader, + DataStreamWriter, + StreamingQuery, + ) + DataFrameReader = DataFrameReader + DataStreamReader = DataStreamReader + DataFrameWriter = DataFrameWriter + DataStreamWriter = DataStreamWriter + StreamingQuery = StreamingQuery + + +def get_active_session() -> SparkSession: # type: ignore + if check_if_pyspark_connect_is_supported(): + from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession + + session = _ConnectSparkSession.getActiveSession() or sql.SparkSession.getActiveSession() # type: ignore + else: + session = sql.SparkSession.getActiveSession() # type: ignore + + if not session: + raise RuntimeError( + "No active Spark session found. Please create a Spark session before using module connect_utils." + " Or perform local import of the module." + ) + + return session + + +class SparkDatatype(Enum): + """ + Allowed spark datatypes + + The following table lists the data types that are supported by Spark SQL. + + | Data type | SQL name | + |---------------|---------------------------| + | ByteType | BYTE, TINYINT | + | ShortType | SHORT, SMALLINT | + | IntegerType | INT, INTEGER | + | LongType | LONG, BIGINT | + | FloatType | FLOAT, REAL | + | DoubleType | DOUBLE | + | DecimalType | DECIMAL, DEC, NUMERIC | + | StringType | STRING | + | BinaryType | BINARY | + | BooleanType | BOOLEAN | + | TimestampType | TIMESTAMP, TIMESTAMP_LTZ | + | DateType | DATE | + | ArrayType | ARRAY | + | MapType | MAP | + | NullType | VOID | + + Not supported yet + ---------------- + * __TimestampNTZType__ + TIMESTAMP_NTZ + * __YearMonthIntervalType__ + INTERVAL YEAR, INTERVAL YEAR TO MONTH, INTERVAL MONTH + * __DayTimeIntervalType__ + INTERVAL DAY, INTERVAL DAY TO HOUR, INTERVAL DAY TO MINUTE, INTERVAL DAY TO SECOND, INTERVAL HOUR, + INTERVAL HOUR TO MINUTE, INTERVAL HOUR TO SECOND, INTERVAL MINUTE, INTERVAL MINUTE TO SECOND, INTERVAL SECOND + + See Also + -------- + https://spark.apache.org/docs/latest/sql-ref-datatypes.html#supported-data-types + """ + + # byte + BYTE = "byte" + TINYINT = "byte" + + # short + SHORT = "short" + SMALLINT = "short" + + # integer + INTEGER = "integer" + INT = "integer" + + # long + LONG = "long" + BIGINT = "long" + + # float + FLOAT = "float" + REAL = "float" + + # timestamp + TIMESTAMP = "timestamp" + TIMESTAMP_LTZ = "timestamp" + + # decimal + DECIMAL = "decimal" + DEC = "decimal" + NUMERIC = "decimal" + + DATE = "date" + DOUBLE = "double" + STRING = "string" + BINARY = "binary" + BOOLEAN = "boolean" + ARRAY = "array" + MAP = "map" + VOID = "void" + + @property + def spark_type(self) -> type: + """Returns the spark type for the given enum value""" + mapping_dict = { + "byte": ByteType, + "short": ShortType, + "integer": IntegerType, + "long": LongType, + "float": FloatType, + "double": DoubleType, + "decimal": DecimalType, + "string": StringType, + "binary": BinaryType, + "boolean": BooleanType, + "timestamp": TimestampType, + "date": DateType, + "array": ArrayType, + "map": MapType, + "void": NullType, + } + return mapping_dict[self.value] + + @classmethod + def from_string(cls, value: str) -> "SparkDatatype": + """Allows for getting the right Enum value by simply passing a string value + This method is not case-sensitive + """ + return getattr(cls, value.upper()) + + +def on_databricks() -> bool: + """Retrieve if we're running on databricks or elsewhere""" + dbr_version 
= os.getenv("DATABRICKS_RUNTIME_VERSION", None) + return dbr_version is not None and dbr_version != "" + + +def spark_data_type_is_array(data_type: DataType) -> bool: # type: ignore + """Check if the column's dataType is of type ArrayType""" + return isinstance(data_type, ArrayType) + + +def spark_data_type_is_numeric(data_type: DataType) -> bool: # type: ignore + """Check if the column's dataType is a numeric type""" + return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)) + + +def schema_struct_to_schema_str(schema: StructType) -> str: + """Converts a StructType to a schema str""" + if not schema: + return "" + return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) + + +def import_pandas_based_on_pyspark_version() -> ModuleType: + """ + This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. + If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version + of pandas should be installed. + """ + try: + import pandas as pd + + pyspark_version = get_spark_minor_version() + pandas_version = pd.__version__ + + if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): + raise ImportError( + f"For PySpark {pyspark_version}, " + f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" + ) + + return pd + except ImportError as e: + raise ImportError("Pandas module is not installed.") from e + + +# noinspection PyProtectedMember +def show_string(df: DataFrame, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> str: # type: ignore + """Returns a string representation of the DataFrame + The default implementation of DataFrame.show() hardcodes a print statement, which is not always desirable. + With this function, you can get the string representation of the DataFrame instead, and choose how to display it. + + Example + ------- + ```python + print(show_string(df)) + + # or use with a logger + logger.info(show_string(df)) + ``` + + Parameters + ---------- + df : DataFrame + The DataFrame to display + n : int, optional + The number of rows to display, by default 20 + truncate : Union[bool, int], optional + If set to True, truncate the displayed columns, by default True + vertical : bool, optional + If set to True, display the DataFrame vertically, by default False + """ + if SPARK_MINOR_VERSION < 3.5: + return df._jdf.showString(n, truncate, vertical) # type: ignore + # as per spark 3.5, the _show_string method is now available making calls to _jdf.showString obsolete + return df._show_string(n, truncate, vertical) + + +# noinspection PyProtectedMember +def get_column_name(col: Column) -> str: # type: ignore + """Get the column name from a Column object + + Normally, the name of a Column object is not directly accessible in the regular pyspark API. This function + extracts the name of the given column object without needing to provide it in the context of a DataFrame.
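
For orientation, a minimal sketch of how these helpers from `koheesio.spark.utils.common` might be exercised together; it assumes an active classic (non-Connect) SparkSession, and the outputs noted in the comments are illustrative rather than verified:

```python
# Minimal sketch (assumes a local SparkSession and the helpers defined in this module)
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from koheesio.spark.utils.common import (
    SparkDatatype,
    get_column_name,
    schema_struct_to_schema_str,
    show_string,
)

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], "id INT, value STRING")

SparkDatatype.from_string("bigint").spark_type  # resolves to LongType
get_column_name(F.col("id"))                    # "id", without needing a DataFrame
schema_struct_to_schema_str(df.schema)          # "id INTEGER,\nvalue STRING"
print(show_string(df, n=5, truncate=False))     # string equivalent of df.show()
```
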
+ + Parameters + ---------- + col: Column + The Column object + + Returns + ------- + str + The name of the given column + """ + # we have to distinguish between the Column object from column from local session and remote + if hasattr(col, "_jc"): + # In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute + name = col._jc.toString() # type: ignore[operator] + elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)): + name = col._expr.name() + else: + raise ValueError("Column object is not a valid Column object") + + return name diff --git a/src/koheesio/spark/utils/connect.py b/src/koheesio/spark/utils/connect.py new file mode 100644 index 00000000..9cf7f028 --- /dev/null +++ b/src/koheesio/spark/utils/connect.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pyspark.sql import SparkSession + +from koheesio.spark.utils.common import ( + check_if_pyspark_connect_is_supported, + get_active_session, +) + +__all__ = ["is_remote_session"] + + +def is_remote_session(spark: Optional[SparkSession] = None) -> bool: + result = False + + if (_spark := spark or get_active_session()) and check_if_pyspark_connect_is_supported(): + result = True if _spark.conf.get("spark.remote", None) else False # type: ignore + + return result diff --git a/src/koheesio/spark/writers/__init__.py b/src/koheesio/spark/writers/__init__.py index 945d4ceb..7f6fa656 100644 --- a/src/koheesio/spark/writers/__init__.py +++ b/src/koheesio/spark/writers/__init__.py @@ -4,10 +4,8 @@ from abc import ABC, abstractmethod from enum import Enum -from pyspark.sql import DataFrame - from koheesio.models import Field -from koheesio.spark import SparkStep +from koheesio.spark import DataFrame, SparkStep # TODO: Investigate if we can clean various OutputModes into a more streamlined structure @@ -52,16 +50,19 @@ class StreamingOutputMode(str, Enum): class Writer(SparkStep, ABC): """The Writer class is used to write the DataFrame to a target.""" - df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame") + df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame", exclude=True) format: str = Field(default="delta", description="The format of the output") @property def streaming(self) -> bool: """Check if the DataFrame is a streaming DataFrame or not.""" + if not self.df: + raise RuntimeError("No valid Dataframe was passed") + return self.df.isStreaming @abstractmethod - def execute(self): + def execute(self) -> SparkStep.Output: """Execute on a Writer should handle writing of the self.df (input) as a minimum""" # self.df # input dataframe ... diff --git a/src/koheesio/spark/writers/buffer.py b/src/koheesio/spark/writers/buffer.py index 64f57db7..e94b5f81 100644 --- a/src/koheesio/spark/writers/buffer.py +++ b/src/koheesio/spark/writers/buffer.py @@ -13,13 +13,16 @@ to more arbitrary file systems (e.g., SFTP). 
""" +from __future__ import annotations + import gzip -from typing import Literal, Optional +from typing import AnyStr, Literal, Optional from abc import ABC from functools import partial from os import linesep from tempfile import SpooledTemporaryFile +# noinspection PyProtectedMember from pandas._typing import CompressionOptions as PandasCompressionOptions from pydantic import InstanceOf @@ -27,6 +30,7 @@ from pyspark import pandas from koheesio.models import ExtraParamsMixin, Field, constr +from koheesio.spark import DataFrame from koheesio.spark.writers import Writer @@ -53,32 +57,32 @@ class Output(Writer.Output, ABC): default_factory=partial(SpooledTemporaryFile, mode="w+b", max_size=0), exclude=True ) - def read(self): + def read(self) -> AnyStr: """Read the buffer""" self.rewind_buffer() data = self.buffer.read() self.rewind_buffer() return data - def rewind_buffer(self): + def rewind_buffer(self): # type: ignore """Rewind the buffer""" self.buffer.seek(0) return self - def reset_buffer(self): + def reset_buffer(self): # type: ignore """Reset the buffer""" self.buffer.truncate(0) self.rewind_buffer() return self - def is_compressed(self): + def is_compressed(self): # type: ignore """Check if the buffer is compressed.""" self.rewind_buffer() magic_number_present = self.buffer.read(2) == b"\x1f\x8b" self.rewind_buffer() return magic_number_present - def compress(self): + def compress(self): # type: ignore """Compress the file_buffer in place using GZIP""" # check if the buffer is already compressed if self.is_compressed(): @@ -95,7 +99,7 @@ def compress(self): return self # to allow for chaining - def write(self, df=None) -> Output: + def write(self, df: DataFrame = None) -> Output: """Write the DataFrame to the buffer""" self.df = df or self.df if not self.df: @@ -260,7 +264,7 @@ class Output(BufferWriter.Output): pandas_df: Optional[pandas.DataFrame] = Field(None, description="The Pandas DataFrame that was written") - def get_options(self, options_type: str = "csv"): + def get_options(self, options_type: str = "csv") -> dict: """Returns the options to pass to Pandas' to_csv() method.""" try: import pandas as _pd @@ -294,7 +298,7 @@ def get_options(self, options_type: str = "csv"): return csv_options - def execute(self): + def execute(self) -> BufferWriter.Output: """Write the DataFrame to the buffer using Pandas to_csv() method. Compression is handled by pandas to_csv() method. """ @@ -353,8 +357,8 @@ class PandasJsonBufferWriter(BufferWriter, ExtraParamsMixin): all other `orient` values, the default is 'epoch'. However, in Koheesio, the default is set to 'iso' irrespective of the `orient` parameter. - - `date_unit`: This parameter specifies the time unit for encoding timestamps and datetime objects. It accepts four - options: 's' for seconds, 'ms' for milliseconds, 'us' for microseconds, and 'ns' for nanoseconds. + - `date_unit`: This parameter specifies the time unit for encoding timestamps and datetime objects. It accepts + four options: 's' for seconds, 'ms' for milliseconds, 'us' for microseconds, and 'ns' for nanoseconds. The default is 'ms'. Note that this parameter is ignored when `date_format='iso'`. ### Orient Parameter @@ -405,13 +409,33 @@ class PandasJsonBufferWriter(BufferWriter, ExtraParamsMixin): - Preserves data types and indexes of the original DataFrame. 
- Example: ```json - {"schema":{"fields": [{"name": index, "type": dtype}], "primaryKey": [index]}, "pandas_version":"1.4.0"}, "data": [{"column1": value1, "column2": value2}]} + { + "schema": { + "fields": [ + { + "name": "index", + "type": "dtype" + } + ], + "primaryKey": ["index"] + }, + "pandas_version": "1.4.0", + "data": [ + { + "column1": "value1", + "column2": "value2" + } + ] + } ``` - Note: For 'records' orient, set `lines` to True to write each record as a separate line. The pandas output will + Note + ---- + For 'records' orient, set `lines` to True to write each record as a separate line. The pandas output will then match the PySpark output (orient='records' and lines=True parameters). - References: + References + ---------- - [Pandas DataFrame to_json documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_json.html) - [Pandas IO tools (text, CSV, HDF5, …) documentation](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html) """ @@ -454,7 +478,7 @@ class Output(BufferWriter.Output): pandas_df: Optional[pandas.DataFrame] = Field(None, description="The Pandas DataFrame that was written") - def get_options(self): + def get_options(self) -> dict: """Returns the options to pass to Pandas' to_json() method.""" json_options = { "orient": self.orient, @@ -471,7 +495,7 @@ def get_options(self): return json_options - def execute(self): + def execute(self) -> BufferWriter.Output: """Write the DataFrame to the buffer using Pandas to_json() method.""" df = self.df if self.columns: diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index 7334f272..7fd8376b 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -34,12 +34,11 @@ ``` """ -from typing import List, Optional, Set, Type, Union from functools import partial +from typing import List, Optional, Set, Type, Union from delta.tables import DeltaMergeBuilder, DeltaTable from py4j.protocol import Py4JError - from pyspark.sql import DataFrameWriter from koheesio.models import ExtraParamsMixin, Field, field_validator @@ -138,7 +137,7 @@ class DeltaTableWriter(Writer, ExtraParamsMixin): alias="outputMode", description=f"{BatchOutputMode.__doc__}\n{StreamingOutputMode.__doc__}", ) - params: Optional[dict] = Field( + params: dict = Field( default_factory=dict, alias="output_mode_params", description="Additional parameters to use for specific mode", @@ -149,7 +148,7 @@ class DeltaTableWriter(Writer, ExtraParamsMixin): ) format: str = "delta" # The format to use for writing the dataframe to the Delta table - _merge_builder: DeltaMergeBuilder = None + _merge_builder: Optional[DeltaMergeBuilder] = None # noinspection PyProtectedMember def __merge(self, merge_builder: Optional[DeltaMergeBuilder] = None) -> Union[DeltaMergeBuilder, DataFrameWriter]: @@ -208,9 +207,7 @@ def __merge(self, merge_builder: Optional[DeltaMergeBuilder] = None) -> Union[De def __merge_all(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: """Merge dataframes using DeltaMergeBuilder or DataFrameWriter""" - merge_cond = self.params.get("merge_cond", None) - - if merge_cond is None: + if (merge_cond := self.params.get("merge_cond")) is None: raise ValueError( "Provide `merge_cond` in DeltaTableWriter(output_mode_params={'merge_cond':''})" ) @@ -233,7 +230,7 @@ def __merge_all(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: return self.__merge(merge_builder=builder) - def _get_merge_builder(self, provided_merge_builder=None) -> 
DeltaMergeBuilder: + def _get_merge_builder(self, provided_merge_builder: DeltaMergeBuilder = None) -> DeltaMergeBuilder: """Resolves the merge builder. If provided, it will be used, otherwise it will be created from the args""" # A merge builder has been already created - case for merge_all @@ -261,7 +258,7 @@ def _get_merge_builder(self, provided_merge_builder=None) -> DeltaMergeBuilder: "See documentation for options." ) - def _merge_builder_from_args(self): + def _merge_builder_from_args(self) -> DeltaMergeBuilder: """Creates the DeltaMergeBuilder from the provided configuration""" merge_clauses = self.params.get("merge_builder", None) merge_cond = self.params.get("merge_cond", None) @@ -282,10 +279,11 @@ def _merge_builder_from_args(self): return builder @field_validator("output_mode") - def _validate_output_mode(cls, mode): + def _validate_output_mode(cls, mode: Union[str, BatchOutputMode, StreamingOutputMode]) -> str: """Validate `output_mode` value""" if isinstance(mode, str): mode = cls.get_output_mode(mode, options={StreamingOutputMode, BatchOutputMode}) + if not isinstance(mode, BatchOutputMode) and not isinstance(mode, StreamingOutputMode): raise AttributeError( f""" @@ -294,17 +292,18 @@ def _validate_output_mode(cls, mode): Streaming Mode - {StreamingOutputMode.__doc__} """ ) + return str(mode.value) @field_validator("table") - def _validate_table(cls, table): + def _validate_table(cls, table: Union[DeltaTableStep, str]) -> Union[DeltaTableStep, str]: """Validate `table` value""" if isinstance(table, str): return DeltaTableStep(table=table) return table @field_validator("params") - def _validate_params(cls, params): + def _validate_params(cls, params: dict) -> dict: """Validates params. If an array of merge clauses is provided, they will be validated against the available ones in DeltaMergeBuilder""" @@ -331,8 +330,16 @@ def get_output_mode(cls, choice: str, options: Set[Type]) -> Union[BatchOutputMo - BatchOutputMode - StreamingOutputMode """ + from koheesio.spark.utils.connect import is_remote_session + + if ( + choice.upper() in (BatchOutputMode.MERGEALL, BatchOutputMode.MERGE_ALL, BatchOutputMode.MERGE) + and is_remote_session() + ): + raise RuntimeError(f"Output mode {choice.upper()} is not supported in remote mode") + for enum_type in options: - if choice.upper() in [om.value.upper() for om in enum_type]: + if choice.upper() in [om.value.upper() for om in enum_type]: # type: ignore return getattr(enum_type, choice.upper()) raise AttributeError( f""" @@ -352,14 +359,13 @@ def __data_frame_writer(self) -> DataFrameWriter: @property def writer(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: """Specify DeltaTableWriter""" - map_mode_writer = { + map_mode_to_writer = { BatchOutputMode.MERGEALL.value: self.__merge_all, BatchOutputMode.MERGE.value: self.__merge, } + return map_mode_to_writer.get(self.output_mode, self.__data_frame_writer)() # type: ignore - return map_mode_writer.get(self.output_mode, self.__data_frame_writer)() - - def execute(self): + def execute(self) -> Writer.Output: _writer = self.writer if self.table.create_if_not_exists and not self.table.exists: diff --git a/src/koheesio/spark/writers/delta/scd.py b/src/koheesio/spark/writers/delta/scd.py index 87be0fe6..f93762ef 100644 --- a/src/koheesio/spark/writers/delta/scd.py +++ b/src/koheesio/spark/writers/delta/scd.py @@ -22,13 +22,13 @@ from pydantic import InstanceOf -from pyspark.sql import Column -from pyspark.sql import functions as F +from pyspark.sql import functions as f from pyspark.sql.types 
import DateType, TimestampType from koheesio.models import Field -from koheesio.spark import DataFrame, SparkSession, current_timestamp_utc +from koheesio.spark import Column, DataFrame, SparkSession from koheesio.spark.delta import DeltaTableStep +from koheesio.spark.functions import current_timestamp_utc from koheesio.spark.writers import Writer @@ -73,7 +73,7 @@ class SCD2DeltaTableWriter(Writer): scd2_columns: List[str] = Field( default_factory=list, description="List of attributes for scd2 type (track changes)" ) - scd2_timestamp_col: Optional[Column] = Field( + scd2_timestamp_col: Column = Field( default=None, description="Timestamp column for SCD2 type (track changes). Default to current_timestamp", ) @@ -114,7 +114,7 @@ def _prepare_attr_clause(attrs: List[str], src_alias: str, dest_alias: str) -> O if attrs: attr_clause = list(map(lambda attr: f"NOT ({src_alias}.{attr} <=> {dest_alias}.{attr})", attrs)) - attr_clause = " OR ".join(attr_clause) + attr_clause = " OR ".join(attr_clause) # type: ignore return attr_clause @@ -147,7 +147,7 @@ def _scd2_timestamp(spark: SparkSession, scd2_timestamp_col: Optional[Column] = return scd2_timestamp @staticmethod - def _scd2_end_time(meta_scd2_end_time_col: str, **_kwargs) -> Column: + def _scd2_end_time(meta_scd2_end_time_col: str, **_kwargs: dict) -> Column: """ Generate a SCD2 end time column. @@ -166,7 +166,7 @@ def _scd2_end_time(meta_scd2_end_time_col: str, **_kwargs) -> Column: The generated SCD2 end time column. """ - scd2_end_time = F.expr( + scd2_end_time = f.expr( "CASE WHEN __meta_scd2_system_merge_action='UC' AND cross.__meta_scd2_rn=2 THEN __meta_scd2_timestamp " f" ELSE tgt.{meta_scd2_end_time_col} END" ) @@ -174,7 +174,7 @@ def _scd2_end_time(meta_scd2_end_time_col: str, **_kwargs) -> Column: return scd2_end_time @staticmethod - def _scd2_effective_time(meta_scd2_effective_time_col: str, **_kwargs) -> Column: + def _scd2_effective_time(meta_scd2_effective_time_col: str, **_kwargs: dict) -> Column: """ Generate a SCD2 effective time column. @@ -194,15 +194,15 @@ def _scd2_effective_time(meta_scd2_effective_time_col: str, **_kwargs) -> Column The generated SCD2 effective time column. """ - scd2_effective_time = F.when( - F.expr("__meta_scd2_system_merge_action='UC' and cross.__meta_scd2_rn=1"), - F.col("__meta_scd2_timestamp"), - ).otherwise(F.coalesce(meta_scd2_effective_time_col, "__meta_scd2_timestamp")) + scd2_effective_time = f.when( + f.expr("__meta_scd2_system_merge_action='UC' and cross.__meta_scd2_rn=1"), + f.col("__meta_scd2_timestamp"), + ).otherwise(f.coalesce(meta_scd2_effective_time_col, "__meta_scd2_timestamp")) return scd2_effective_time @staticmethod - def _scd2_is_current(**_kwargs) -> Column: + def _scd2_is_current(**_kwargs: dict) -> Column: """ Generate a SCD2 is_current column. @@ -215,7 +215,7 @@ def _scd2_is_current(**_kwargs) -> Column: The generated SCD2 is_current column. """ - scd2_is_current = F.expr( + scd2_is_current = f.expr( "CASE WHEN __meta_scd2_system_merge_action='UC' AND cross.__meta_scd2_rn=2 THEN False ELSE True END" ) @@ -231,7 +231,7 @@ def _prepare_staging( src_alias: str, dest_alias: str, cross_alias: str, - **_kwargs, + **_kwargs: dict, ) -> DataFrame: """ Prepare a DataFrame for staging. 
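
As a rough sketch of how the merge path above might be driven: the `merge_cond` key comes from the error message in `__merge_all`, while the shape of the `merge_builder` clause dicts and the target table name are assumptions. Note that the merge output modes are rejected on a Spark Connect (remote) session, so this assumes a classic SparkSession:

```python
# Sketch only: the merge_builder clause format below is illustrative, not verified against the validator.
from pyspark.sql import SparkSession

from koheesio.spark.writers import BatchOutputMode
from koheesio.spark.writers.delta.batch import DeltaTableWriter

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a")], "id INT, value STRING")

writer = DeltaTableWriter(
    df=df,
    table="catalog.schema.target_table",  # plain strings are wrapped into a DeltaTableStep
    output_mode=BatchOutputMode.MERGE,
    output_mode_params={
        "merge_cond": "target.id = source.id",
        # hypothetical clause specs, checked against DeltaMergeBuilder by _validate_params
        "merge_builder": [
            {"clause": "whenMatchedUpdateAll"},
            {"clause": "whenNotMatchedInsertAll"},
        ],
    },
)
writer.execute()
```
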
@@ -272,7 +272,7 @@ def _prepare_staging( .alias(src_alias) .join( other=delta_table.toDF() - .filter(F.col(meta_scd2_is_current_col).eqNullSafe(F.lit(True))) + .filter(f.col(meta_scd2_is_current_col).eqNullSafe(f.lit(True))) .alias(dest_alias), on=self.merge_key, how="left", @@ -283,7 +283,7 @@ def _prepare_staging( # Filter cross joined data so that we have one row for U # and another for I in case of closing SCD2 # and keep just one for SCD1 or NEW row - .filter(F.expr("__meta_scd2_system_merge_action='UC' OR cross.__meta_scd2_rn=1")) + .filter(f.expr("__meta_scd2_system_merge_action='UC' OR cross.__meta_scd2_rn=1")) ) return df @@ -297,7 +297,7 @@ def _preserve_existing_target_values( cross_alias: str, dest_alias: str, logger: Logger, - **_kwargs, + **_kwargs: dict, ) -> DataFrame: """ Preserve existing target values in the DataFrame. @@ -342,14 +342,14 @@ def _preserve_existing_target_values( df = ( df.withColumn( f"newly_{c}", - F.when( - F.col("__meta_scd2_system_merge_action").eqNullSafe(F.lit("UC")) - & F.col(f"{cross_alias}.__meta_scd2_rn").eqNullSafe(F.lit(2)), - F.col(f"{dest_alias}.{c}"), - ).otherwise(F.col(f"{src_alias}.{c}")), + f.when( + f.col("__meta_scd2_system_merge_action").eqNullSafe(f.lit("UC")) + & f.col(f"{cross_alias}.__meta_scd2_rn").eqNullSafe(f.lit(2)), + f.col(f"{dest_alias}.{c}"), + ).otherwise(f.col(f"{src_alias}.{c}")), ) - .drop(F.col(f"{src_alias}.{c}")) - .drop(F.col(f"{dest_alias}.{c}")) + .drop(f.col(f"{src_alias}.{c}")) + .drop(f.col(f"{dest_alias}.{c}")) .withColumnRenamed(f"newly_{c}", c) ) @@ -364,7 +364,7 @@ def _add_scd2_columns( meta_scd2_effective_time_col_name: str, meta_scd2_end_time_col_name: str, meta_scd2_is_current_col_name: str, - **_kwargs, + **_kwargs: dict, ) -> DataFrame: """ Add SCD2 columns to the DataFrame. @@ -391,10 +391,10 @@ def _add_scd2_columns( """ df = df.withColumn( meta_scd2_struct_col_name, - F.struct( - F.col("__meta_scd2_effective_time").alias(meta_scd2_effective_time_col_name), - F.col("__meta_scd2_end_time").alias(meta_scd2_end_time_col_name), - F.col("__meta_scd2_is_current").alias(meta_scd2_is_current_col_name), + f.struct( + f.col("__meta_scd2_effective_time").alias(meta_scd2_effective_time_col_name), + f.col("__meta_scd2_end_time").alias(meta_scd2_end_time_col_name), + f.col("__meta_scd2_is_current").alias(meta_scd2_is_current_col_name), ), ).drop( "__meta_scd2_end_time", @@ -415,7 +415,7 @@ def _prepare_merge_builder( merge_key: str, columns_to_process: List[str], meta_scd2_effective_time_col: str, - **_kwargs, + **_kwargs: dict, ) -> DeltaMergeBuilder: """ Prepare a DeltaMergeBuilder for merging data. @@ -536,7 +536,7 @@ def execute(self) -> None: .transform( func=self._prepare_staging, delta_table=delta_table, - merge_action_logic=F.expr(system_merge_action), + merge_action_logic=f.expr(system_merge_action), meta_scd2_is_current_col=meta_scd2_is_current_col, columns_to_process=columns_to_process, src_alias=src_alias, diff --git a/src/koheesio/spark/writers/delta/stream.py b/src/koheesio/spark/writers/delta/stream.py index 33eb7543..aea03a57 100644 --- a/src/koheesio/spark/writers/delta/stream.py +++ b/src/koheesio/spark/writers/delta/stream.py @@ -2,8 +2,8 @@ This module defines the DeltaTableStreamWriter class, which is used to write streaming dataframes to Delta tables. 
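
Before the streaming writer below, a heavily hedged sketch of how the SCD2 writer above might be configured. Only `scd2_columns` and `scd2_timestamp_col` are visible in this diff; `table` and `merge_key` are assumed field names (`merge_key` is referenced as `self.merge_key` in `_prepare_staging`):

```python
# Hedged sketch: table and merge_key are assumed fields; scd2_columns / scd2_timestamp_col are shown above.
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from koheesio.spark.delta import DeltaTableStep
from koheesio.spark.writers.delta.scd import SCD2DeltaTableWriter

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, "Main Street", "2024-01-01")], "customer_id INT, address STRING, change_ts STRING"
)

writer = SCD2DeltaTableWriter(
    df=df,
    table=DeltaTableStep(table="catalog.schema.dim_customer"),  # assumed field
    merge_key="customer_id",                                    # assumed field, join key for staging
    scd2_columns=["address"],                                   # attributes tracked as SCD2 changes
    scd2_timestamp_col=f.col("change_ts").cast("timestamp"),    # defaults to current timestamp if omitted
)
writer.execute()
```
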
""" -from typing import Optional from email.policy import default +from typing import Optional from pydantic import Field @@ -29,7 +29,7 @@ class Options(BaseModel): description="The maximum number of new files to be considered in every trigger (default: 1000).", ) - def execute(self): + def execute(self) -> DeltaTableWriter.Output: if self.batch_function: self.streaming_query = self.writer.start() else: diff --git a/src/koheesio/spark/writers/delta/utils.py b/src/koheesio/spark/writers/delta/utils.py index ef47e979..2e08a16f 100644 --- a/src/koheesio/spark/writers/delta/utils.py +++ b/src/koheesio/spark/writers/delta/utils.py @@ -4,7 +4,7 @@ from typing import Optional -from py4j.java_gateway import JavaObject +from py4j.java_gateway import JavaObject # type: ignore[import-untyped] def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Optional[str]: @@ -39,7 +39,7 @@ def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Op if not clauses.isEmpty(): clauses_type = clauses.last().nodeName().replace("DeltaMergeInto", "") - _processed_clauses = {} + _processed_clauses: dict = {} for i in range(0, clauses.length()): clause = clauses.apply(i) @@ -54,6 +54,8 @@ def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Op ) elif condition.toString() == "None": condition_clause = "No conditions required" + else: + raise ValueError(f"Condition {condition} is not supported") clause_type: str = clause.clauseType().capitalize() columns = "ALL" if clause_type == "Delete" else clause.actions().toList().apply(0).toString() diff --git a/src/koheesio/spark/writers/dummy.py b/src/koheesio/spark/writers/dummy.py index da26f734..5e90a989 100644 --- a/src/koheesio/spark/writers/dummy.py +++ b/src/koheesio/spark/writers/dummy.py @@ -2,9 +2,8 @@ from typing import Any, Dict, Union -from pyspark.sql import DataFrame - from koheesio.models import Field, PositiveInt, field_validator +from koheesio.spark.utils import show_string from koheesio.spark.writers import Writer @@ -43,7 +42,7 @@ class DummyWriter(Writer): ) @field_validator("truncate") - def int_truncate(cls, truncate_value) -> int: + def int_truncate(cls, truncate_value: Union[int, bool]) -> int: """ Truncate is either a bool or an int. 
@@ -72,12 +71,8 @@ class Output(Writer.Output): def execute(self) -> Output: """Execute the DummyWriter""" - df: DataFrame = self.df - - # noinspection PyProtectedMember - df_content = df._jdf.showString(self.n, self.truncate, self.vertical) - # logs the equivalent of doing df.show() + df_content = show_string(df=self.df, n=self.n, truncate=self.truncate, vertical=self.vertical) self.log.info(f"content of df that was passed to DummyWriter:\n{df_content}") self.output.head = self.df.head().asDict() diff --git a/src/koheesio/spark/writers/file_writer.py b/src/koheesio/spark/writers/file_writer.py index 362c6200..9e79d562 100644 --- a/src/koheesio/spark/writers/file_writer.py +++ b/src/koheesio/spark/writers/file_writer.py @@ -13,6 +13,8 @@ """ +from __future__ import annotations + from typing import Union from enum import Enum from pathlib import Path @@ -63,24 +65,24 @@ class FileWriter(Writer, ExtraParamsMixin): """ output_mode: BatchOutputMode = Field(default=BatchOutputMode.APPEND, description="The output mode to use") - format: FileFormat = Field(None, description="The file format to use when writing the data.") - path: Union[Path, str] = Field(default=None, description="The path to write the file to") + format: FileFormat = Field(..., description="The file format to use when writing the data.") + path: Union[Path, str] = Field(default=..., description="The path to write the file to") @field_validator("path") - def ensure_path_is_str(cls, v): + def ensure_path_is_str(cls, v: Union[Path, str]) -> str: """Ensure that the path is a string as required by Spark.""" if isinstance(v, Path): return str(v.absolute().as_posix()) return v - def execute(self): + def execute(self) -> FileWriter.Output: writer = self.df.write if self.extra_params: self.log.info(f"Setting extra parameters for the writer: {self.extra_params}") writer = writer.options(**self.extra_params) - writer.save(path=self.path, format=self.format, mode=self.output_mode) + writer.save(path=self.path, format=self.format, mode=self.output_mode) # type: ignore self.output.df = self.df diff --git a/src/koheesio/spark/writers/kafka.py b/src/koheesio/spark/writers/kafka.py index 34b3beea..9cc61f93 100644 --- a/src/koheesio/spark/writers/kafka.py +++ b/src/koheesio/spark/writers/kafka.py @@ -74,12 +74,12 @@ def streaming_query(self) -> Optional[Union[str, StreamingQuery]]: return self.output.streaming_query @property - def _trigger(self): - """return the trigger value as a Trigger object if it is not already one.""" + def _trigger(self) -> dict[str, str]: + """return the value of the Trigger object""" return self.trigger.value @field_validator("trigger") - def _validate_trigger(cls, trigger): + def _validate_trigger(cls, trigger: Optional[Union[Trigger, str, Dict]]) -> Trigger: """Validate the trigger value and convert it to a Trigger object if it is not already one.""" return Trigger.from_any(trigger) @@ -131,7 +131,7 @@ def writer(self) -> Union[DataStreamWriter, DataFrameWriter]: return self.stream_writer if self.streaming else self.batch_writer @property - def options(self): + def options(self) -> Dict[str, str]: """retrieve the kafka options incl topic and broker. 
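
To illustrate the tightened `FileWriter` contract above, a sketch; the `FileFormat.csv` member and the example path are assumptions, and `header` stands for any extra keyword that gets forwarded to `writer.options()`:

```python
# Sketch: format and path are now required; Path objects are normalised to POSIX strings,
# and extra keyword arguments are passed to the underlying DataFrameWriter via options().
from pathlib import Path

from pyspark.sql import SparkSession

from koheesio.spark.writers import BatchOutputMode
from koheesio.spark.writers.file_writer import FileFormat, FileWriter

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a")], "id INT, value STRING")

writer = FileWriter(
    df=df,
    path=Path("/tmp/koheesio-example/output"),  # converted to str by the path validator
    format=FileFormat.csv,                      # assumed enum member
    output_mode=BatchOutputMode.APPEND,
    header=True,                                # extra param, forwarded through writer.options()
)
writer.execute()
```
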
Returns @@ -151,7 +151,7 @@ def options(self): return options @property - def logged_option_keys(self): + def logged_option_keys(self) -> set: """keys to be logged""" return { "kafka.bootstrap.servers", @@ -163,7 +163,7 @@ def logged_option_keys(self): "checkpointLocation", } - def execute(self): + def execute(self) -> Writer.Output: """Effectively write the data from the dataframe (streaming of batch) to kafka topic. Returns diff --git a/src/koheesio/spark/writers/stream.py b/src/koheesio/spark/writers/stream.py index 0c55a8be..661ff86e 100644 --- a/src/koheesio/spark/writers/stream.py +++ b/src/koheesio/spark/writers/stream.py @@ -18,10 +18,9 @@ class to run a writer for each batch from typing import Callable, Dict, Optional, Union from abc import ABC, abstractmethod -from pyspark.sql.streaming import DataStreamWriter, StreamingQuery - from koheesio import Step from koheesio.models import ConfigDict, Field, field_validator, model_validator +from koheesio.spark import DataFrame, DataStreamWriter, StreamingQuery from koheesio.spark.writers import StreamingOutputMode, Writer from koheesio.utils import convert_str_to_bool @@ -71,7 +70,7 @@ class Trigger(Step): model_config = ConfigDict(validate_default=False, extra="forbid") @classmethod - def _all_triggers_with_alias(cls): + def _all_triggers_with_alias(cls) -> list: """Internal method to return all trigger types with their alias. Used for logging purposes""" fields = cls.model_fields triggers = [ @@ -82,12 +81,12 @@ def _all_triggers_with_alias(cls): return triggers @property - def triggers(self): + def triggers(self) -> Dict: """Returns a list of tuples with the value for each trigger""" return self.model_dump(exclude={"name", "description"}, by_alias=True) @model_validator(mode="before") - def validate_triggers(cls, triggers: Dict): + def validate_triggers(cls, triggers: Dict) -> Dict: """Validate the trigger value""" params = [*triggers.values()] @@ -100,7 +99,7 @@ def validate_triggers(cls, triggers: Dict): return triggers @field_validator("processing_time", mode="before") - def validate_processing_time(cls, processing_time): + def validate_processing_time(cls, processing_time: str) -> str: """Validate the processing time trigger value""" # adapted from `pyspark.sql.streaming.readwriter.DataStreamWriter.trigger` if not isinstance(processing_time, str): @@ -111,7 +110,7 @@ def validate_processing_time(cls, processing_time): return processing_time @field_validator("continuous", mode="before") - def validate_continuous(cls, continuous): + def validate_continuous(cls, continuous: str) -> str: """Validate the continuous trigger value""" # adapted from `pyspark.sql.streaming.readwriter.DataStreamWriter.trigger` except that the if statement is not # split in two parts @@ -123,7 +122,7 @@ def validate_continuous(cls, continuous): return continuous @field_validator("once", mode="before") - def validate_once(cls, once): + def validate_once(cls, once: str) -> bool: """Validate the once trigger value""" # making value a boolean when given once = convert_str_to_bool(once) @@ -134,7 +133,7 @@ def validate_once(cls, once): return once @field_validator("available_now", mode="before") - def validate_available_now(cls, available_now): + def validate_available_now(cls, available_now: str) -> bool: """Validate the available_now trigger value""" # making value a boolean when given available_now = convert_str_to_bool(available_now) @@ -151,12 +150,12 @@ def value(self) -> Dict[str, str]: return trigger @classmethod - def from_dict(cls, _dict): + def 
from_dict(cls, _dict: dict) -> "Trigger": """Creates a Trigger class based on a dictionary""" return cls(**_dict) @classmethod - def from_string(cls, trigger: str): + def from_string(cls, trigger: str) -> "Trigger": """Creates a Trigger class based on a string Example @@ -202,7 +201,7 @@ def from_string(cls, trigger: str): return cls.from_dict({trigger_type: value}) @classmethod - def from_any(cls, value): + def from_any(cls, value: Union["Trigger", str, dict]) -> "Trigger": """Dynamically creates a Trigger class based on either another Trigger class, a passed string value, or a dictionary @@ -219,7 +218,7 @@ def from_any(cls, value): raise RuntimeError(f"Unable to create Trigger based on the given value: {value}") - def execute(self): + def execute(self) -> None: """Returns the trigger value as a dictionary This method can be skipped, as the value can be accessed directly from the `value` property """ @@ -251,7 +250,7 @@ class StreamWriter(Writer, ABC): ) trigger: Optional[Union[Trigger, str, Dict]] = Field( - default=Trigger(available_now=True), + default=Trigger(available_now=True), # type: ignore[call-arg] description="Set the trigger for the stream query. If this is not set it process data as batch", ) @@ -260,28 +259,28 @@ class StreamWriter(Writer, ABC): ) @property - def _trigger(self): + def _trigger(self) -> dict: """Returns the trigger value as a dictionary""" return self.trigger.value @field_validator("output_mode") - def _validate_output_mode(cls, mode): + def _validate_output_mode(cls, mode: Union[str, StreamingOutputMode]) -> str: """Ensure that the given mode is a valid StreamingOutputMode""" if isinstance(mode, str): return mode return str(mode.value) @field_validator("trigger") - def _validate_trigger(cls, trigger): + def _validate_trigger(cls, trigger: Union[Trigger, str, Dict]) -> Trigger: """Ensure that the given trigger is a valid Trigger class""" return Trigger.from_any(trigger) - def await_termination(self, timeout: Optional[int] = None): + def await_termination(self, timeout: Optional[int] = None) -> None: """Await termination of the stream query""" self.streaming_query.awaitTermination(timeout=timeout) @property - def stream_writer(self) -> DataStreamWriter: + def stream_writer(self) -> DataStreamWriter: # type: ignore """Returns the stream writer for the given DataFrame and settings""" write_stream = self.df.writeStream.format(self.format).outputMode(self.output_mode) @@ -297,12 +296,12 @@ def stream_writer(self) -> DataStreamWriter: return write_stream @property - def writer(self): + def writer(self) -> DataStreamWriter: # type: ignore """Returns the stream writer since we don't have a batch mode for streams""" return self.stream_writer @abstractmethod - def execute(self): + def execute(self) -> None: raise NotImplementedError @@ -310,17 +309,17 @@ class ForEachBatchStreamWriter(StreamWriter): """Runnable ForEachBatchWriter""" @field_validator("batch_function") - def _validate_batch_function_exists(cls, batch_function): + def _validate_batch_function_exists(cls, batch_function: Callable) -> Callable: """Ensure that a batch_function is defined""" - if not batch_function or not isinstance(batch_function, Callable): + if not batch_function or not isinstance(batch_function, Callable): # type: ignore[truthy-function, arg-type] raise ValueError(f"{cls.__name__} requires a defined for `batch_function`") return batch_function - def execute(self): + def execute(self) -> None: self.streaming_query = self.writer.start() -def writer_to_foreachbatch(writer: Writer): +def 
writer_to_foreachbatch(writer: Writer) -> Callable: """Call `writer.execute` on each batch To be passed as batch_function for StreamWriter (sub)classes. @@ -343,7 +342,7 @@ def writer_to_foreachbatch(writer: Writer): ``` """ - def inner(df, batch_id: int): + def inner(df: DataFrame, batch_id: int) -> None: """Inner method As per the Spark documentation: diff --git a/src/koheesio/sso/okta.py b/src/koheesio/sso/okta.py index 4a0e840e..5a20ca05 100644 --- a/src/koheesio/sso/okta.py +++ b/src/koheesio/sso/okta.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import Dict, Optional, Union -from logging import Filter +from logging import Filter, LogRecord from requests import HTTPError @@ -26,7 +26,7 @@ class Okta(HttpPostStep): ) @model_validator(mode="before") - def _set_auth_param(cls, v): + def _set_auth_param(cls, v: dict) -> dict: """ Assign auth parameter with Okta client and secret to the params dictionary. If auth parameter already exists, it will be overwritten. @@ -43,7 +43,7 @@ def __init__(self, okta_object: OktaAccessToken, name: str = "OktaToken"): self.__okta_object = okta_object super().__init__(name=name) - def filter(self, record): + def filter(self, record: LogRecord) -> bool: # noinspection PyUnresolvedReferences if token := self.__okta_object.output.token: token_value = token.get_secret_value() @@ -79,13 +79,13 @@ class Output(Okta.Output): token: Optional[SecretStr] = Field(default=None, description="Okta authentication token") - def __init__(self, **kwargs): + def __init__(self, **kwargs): # type: ignore[no-untyped-def] _logger = LoggingFactory.get_logger(name=self.__class__.__name__, inherit_from_koheesio=True) logger_filter = LoggerOktaTokenFilter(okta_object=self) _logger.addFilter(logger_filter) super().__init__(**kwargs) - def execute(self): + def execute(self) -> None: """ Execute an HTTP Post call to Okta service and retrieve the access token. """ @@ -97,7 +97,11 @@ def execute(self): raw_payload = self.output.raw_payload if status_code != 200: - raise HTTPError(f"Request failed with '{status_code}' code. Payload: {raw_payload}") + raise HTTPError( + f"Request failed with '{status_code}' code. Payload: {raw_payload}", + response=self.output.response_raw, + request=None, + ) # noinspection PyUnresolvedReferences json_payload = self.output.json_payload diff --git a/src/koheesio/steps/__init__.py b/src/koheesio/steps/__init__.py index cb7cd4d3..89c04ba0 100644 --- a/src/koheesio/steps/__init__.py +++ b/src/koheesio/steps/__init__.py @@ -20,12 +20,15 @@ import json import sys import warnings -from typing import Any +from typing import Any, Callable, Union from abc import abstractmethod from functools import partialmethod, wraps import yaml +from pydantic import BaseModel as PydanticBaseModel +from pydantic import InstanceOf + from koheesio.models import BaseModel, ConfigDict, ModelMetaclass __all__ = [ @@ -59,7 +62,7 @@ def validate_output(self) -> StepOutput: Essentially, this method is a wrapper around the validate method of the BaseModel class """ - validated_model = self.validate() + validated_model = self.validate() # type: ignore[call-arg] return StepOutput.from_basemodel(validated_model) @@ -73,8 +76,9 @@ class StepMetaClass(ModelMetaclass): # When partialmethod is forgetting that _execute_wrapper # is a method of wrapper, and it needs to pass that in as the first arg. 
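
A hedged sketch tying together the streaming pieces above: `Trigger.from_any` normalises strings, dicts and `Trigger` instances, and `writer_to_foreachbatch` adapts any batch `Writer` into a per-batch callback. The `rate` source, the `processingTime` alias shown in the comment, and the absence of a checkpoint location are all assumptions:

```python
# Sketch: field names follow the hunks above; the rate source is just a convenient dummy stream.
from pyspark.sql import DataFrame, SparkSession

from koheesio.spark.writers.dummy import DummyWriter
from koheesio.spark.writers.stream import (
    ForEachBatchStreamWriter,
    Trigger,
    writer_to_foreachbatch,
)

Trigger(processing_time="5 seconds").value  # roughly {"processingTime": "5 seconds"}
Trigger.from_any({"available_now": True})   # dicts, strings and Trigger instances are normalised

spark = SparkSession.builder.getOrCreate()
streaming_df = spark.readStream.format("rate").load()


def log_batch(df: DataFrame, batch_id: int) -> None:
    # minimal per-batch callback
    print(f"batch {batch_id}: {df.count()} rows")


# alternative: adapt any batch Writer into the same kind of callback
alternative_callback = writer_to_foreachbatch(DummyWriter(truncate=False))

writer = ForEachBatchStreamWriter(
    df=streaming_df,
    batch_function=log_batch,
    trigger=Trigger(available_now=True),
)
writer.execute()            # starts the streaming query
writer.await_termination()  # waits for the available-now batch to finish
```
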
# https://github.com/python/cpython/issues/99152 + # noinspection PyPep8Naming,PyUnresolvedReferences class _partialmethod_with_self(partialmethod): - def __get__(self, obj, cls=None): + def __get__(self, obj: Any, cls=None): # type: ignore[no-untyped-def] return self._make_unbound_method().__get__(obj, cls) # Unique object to mark a function as wrapped @@ -116,11 +120,12 @@ def __new__( The method wraps the `execute` method of the class with a partial method if it is not already wrapped. The wrapped method is then set as the new `execute` method of the class. - If the `execute` method is already wrapped, the method is not modified. + If the execute method is already wrapped, the class does not modify the method. The method also keeps track of the number of times the `execute` method has been wrapped. """ + # noinspection PyTypeChecker cls = super().__new__( mcs, cls_name, @@ -140,6 +145,7 @@ def __new__( # Check if the sentinel is the same as the class's sentinel. If they are the same, # it means the function is already wrapped. + # noinspection PyUnresolvedReferences is_already_wrapped = sentinel is cls._step_execute_wrapper_sentinel # Get the wrap count of the function. If the function is not wrapped yet, the default value is 0. @@ -150,6 +156,8 @@ def __new__( if not is_already_wrapped: # Create a partial method with the execute_method as one of the arguments. # This is the new function that will be called instead of the original execute_method. + + # noinspection PyProtectedMember,PyUnresolvedReferences wrapper = mcs._partialmethod_impl(cls=cls, execute_method=execute_method) # Updating the attributes of the wrapping function to those of the original function. @@ -157,6 +165,7 @@ def __new__( # Set the sentinel attribute to the wrapper. This is done so that we can check # if the function is already wrapped. + # noinspection PyUnresolvedReferences setattr(wrapper, "_step_execute_wrapper_sentinel", cls._step_execute_wrapper_sentinel) # Increase the wrap count of the function. This is done to keep track of @@ -167,7 +176,7 @@ def __new__( return cls @staticmethod - def _is_called_through_super(caller_self: Any, caller_name: str, *_args, **_kwargs) -> bool: + def _is_called_through_super(caller_self: Any, caller_name: str, *_args, **_kwargs) -> bool: # type: ignore[no-untyped-def] """ Check if the method is called through super() in the immediate parent class. @@ -193,7 +202,7 @@ def _is_called_through_super(caller_self: Any, caller_name: str, *_args, **_kwar return caller_name in base_class.__dict__ @classmethod - def _partialmethod_impl(mcs, cls: type, execute_method) -> partialmethod: + def _partialmethod_impl(mcs, cls: type, execute_method: Callable) -> partialmethod: """ This method creates a partial method implementation for a given class and execute method. It handles a specific issue with python>=3.11 where partialmethod forgets that _execute_wrapper @@ -212,13 +221,15 @@ def _partialmethod_impl(mcs, cls: type, execute_method) -> partialmethod: # When partialmethod is forgetting that _execute_wrapper # is a method of wrapper, and it needs to pass that in as the first arg. # https://github.com/python/cpython/issues/99152 + # noinspection PyPep8Naming class _partialmethod_with_self(partialmethod): """ This class is a workaround for the issue with python>=3.11 where partialmethod forgets that _execute_wrapper is a method of wrapper, and it needs to pass that in as the first argument. 
""" - def __get__(self, obj, cls=None): + # noinspection PyShadowingNames + def __get__(self, obj: Any, cls=None): # type: ignore[no-untyped-def] """ This method returns the unbound method for the given object and class. @@ -229,15 +240,17 @@ def __get__(self, obj, cls=None): Returns: The unbound method. """ + # noinspection PyUnresolvedReferences return self._make_unbound_method().__get__(obj, cls) _partialmethod_impl = partialmethod if sys.version_info < (3, 11) else _partialmethod_with_self + # noinspection PyUnresolvedReferences wrapper = _partialmethod_impl(cls._execute_wrapper, execute_method=execute_method) return wrapper @classmethod - def _execute_wrapper(mcs, step: Step, execute_method, *args, **kwargs) -> StepOutput: + def _execute_wrapper(cls, step: Step, execute_method: Callable, *args, **kwargs) -> StepOutput: # type: ignore[no-untyped-def] """ Method that wraps some common functionalities on Steps Ensures proper logging and makes it so that a Steps execute method always returns the StepOutput @@ -262,18 +275,18 @@ def _execute_wrapper(mcs, step: Step, execute_method, *args, **kwargs) -> StepOu # check if the method is called through super() in the immediate parent class caller_name = inspect.currentframe().f_back.f_back.f_code.co_name - is_called_through_super_ = mcs._is_called_through_super(step, caller_name) + is_called_through_super_ = cls._is_called_through_super(step, caller_name) - mcs._log_start_message(step=step, skip_logging=is_called_through_super_) - return_value = mcs._run_execute(step=step, execute_method=execute_method, *args, **kwargs) - mcs._configure_step_output(step=step, return_value=return_value) - mcs._validate_output(step=step, skip_validating=is_called_through_super_) - mcs._log_end_message(step=step, skip_logging=is_called_through_super_) + cls._log_start_message(step=step, skip_logging=is_called_through_super_) + return_value = cls._run_execute(step=step, execute_method=execute_method, *args, **kwargs) # type: ignore[misc] + cls._configure_step_output(step=step, return_value=return_value) + cls._validate_output(step=step, skip_validating=is_called_through_super_) + cls._log_end_message(step=step, skip_logging=is_called_through_super_) return step.output @classmethod - def _log_start_message(mcs, step: Step, *_args, skip_logging: bool = False, **_kwargs): + def _log_start_message(cls, step: Step, *_args, skip_logging: bool = False, **_kwargs) -> None: # type: ignore[no-untyped-def] """ Log the start message of the step execution @@ -292,10 +305,10 @@ def _log_start_message(mcs, step: Step, *_args, skip_logging: bool = False, **_k if not skip_logging: step.log.info("Start running step") - step.log.debug(f"Step Input: {step.__repr_str__(' ')}") + step.log.debug(f"Step Input: {step.__repr_str__(' ')}") # type: ignore[misc] @classmethod - def _log_end_message(mcs, step: Step, *_args, skip_logging: bool = False, **_kwargs): + def _log_end_message(cls, step: Step, *_args, skip_logging: bool = False, **_kwargs) -> None: # type: ignore[no-untyped-def] """ Log the end message of the step execution @@ -313,11 +326,11 @@ def _log_end_message(mcs, step: Step, *_args, skip_logging: bool = False, **_kwa """ if not skip_logging: - step.log.debug(f"Step Output: {step.output.__repr_str__(' ')}") + step.log.debug(f"Step Output: {step.output.__repr_str__(' ')}") # type: ignore[misc] step.log.info("Finished running step") @classmethod - def _validate_output(mcs, step: Step, *_args, skip_validating: bool = False, **_kwargs): + def _validate_output(cls, step: Step, *_args, 
skip_validating: bool = False, **_kwargs) -> None: # type: ignore[no-untyped-def] """ Validate the output of the step @@ -338,7 +351,7 @@ def _validate_output(mcs, step: Step, *_args, skip_validating: bool = False, **_ step.output.validate_output() @classmethod - def _configure_step_output(mcs, step, return_value: Any, *_args, **_kwargs): + def _configure_step_output(cls, step, return_value: Any, *_args, **_kwargs) -> None: # type: ignore[no-untyped-def] """ Configure the output of the step. If the execute method returns a value, and it is not the output, set the output to the return value @@ -372,7 +385,7 @@ def _configure_step_output(mcs, step, return_value: Any, *_args, **_kwargs): step.output = output @classmethod - def _run_execute(mcs, execute_method, step, *args, **kwargs) -> Any: + def _run_execute(cls, execute_method: Callable, step, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] """ Run the execute method of the step, and catch any errors @@ -529,17 +542,17 @@ def output(self) -> Output: """Interact with the output of the Step""" if not self.__output__: self.__output__ = self.Output.lazy() - self.__output__.name = self.name + ".Output" - self.__output__.description = "Output for " + self.name + self.__output__.name = self.name + ".Output" # type: ignore[operator] + self.__output__.description = "Output for " + self.name # type: ignore[operator] return self.__output__ @output.setter - def output(self, value: Output): + def output(self, value: Output) -> None: """Set the output of the Step""" self.__output__ = value @abstractmethod - def execute(self): + def execute(self) -> InstanceOf[StepOutput]: """Abstract method to implement for new steps. The Inputs of the step can be accessed, using `self.input_name` @@ -549,7 +562,7 @@ def execute(self): """ raise NotImplementedError - def run(self): + def run(self) -> InstanceOf[StepOutput]: """Alias to .execute()""" return self.execute() @@ -564,7 +577,7 @@ def __str__(self) -> str: """String representation of a step""" return self.__repr__() - def repr_json(self, simple=False) -> str: + def repr_json(self, simple: bool = False) -> str: """dump the step to json, meant for representation Note: use to_json if you want to dump the step to json for serialization @@ -593,7 +606,7 @@ def repr_json(self, simple=False) -> str: _result = {} # extract input - _input = self.model_dump(**model_dump_options) + _input = self.model_dump(**model_dump_options) # type: ignore[arg-type] # remove name and description from input and add to result if simple is not set name = _input.pop("name", None) @@ -607,7 +620,7 @@ def repr_json(self, simple=False) -> str: model_dump_options["exclude"] = {"name", "description"} # extract output - _output = self.output.model_dump(**model_dump_options) + _output = self.output.model_dump(**model_dump_options) # type: ignore[arg-type] # add output to result if _output: @@ -630,7 +643,7 @@ def default(self, o: Any) -> Any: return json_str - def repr_yaml(self, simple=False) -> str: + def repr_yaml(self, simple: bool = False) -> str: """dump the step to yaml, meant for representation Note: use to_yaml if you want to dump the step to yaml for serialization @@ -662,7 +675,7 @@ def repr_yaml(self, simple=False) -> str: return yaml.dump(_result) - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Union[Any, None]: """__getattr__ dunder Allows input to be accessed through `self.input_name` @@ -680,6 +693,6 @@ def __getattr__(self, key: str): return self.model_dump().get(key) @classmethod - def 
from_step(cls, step: Step, **kwargs): + def from_step(cls, step: Step, **kwargs) -> InstanceOf[PydanticBaseModel]: # type: ignore[no-untyped-def] """Returns a new Step instance based on the data of another Step or BaseModel instance""" return cls.from_basemodel(step, **kwargs) diff --git a/src/koheesio/steps/dummy.py b/src/koheesio/steps/dummy.py index a7ab8b72..9cab47bb 100644 --- a/src/koheesio/steps/dummy.py +++ b/src/koheesio/steps/dummy.py @@ -36,7 +36,7 @@ class Output(DummyOutput): c: str - def execute(self): + def execute(self) -> None: """Dummy execute for testing purposes.""" self.output.a = self.a self.output.b = self.b diff --git a/src/koheesio/steps/http.py b/src/koheesio/steps/http.py index 5981eb18..68329cc7 100644 --- a/src/koheesio/steps/http.py +++ b/src/koheesio/steps/http.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional, Union from enum import Enum -import requests +import requests # type: ignore[import-untyped] from koheesio import Step from koheesio.models import ( @@ -34,7 +34,7 @@ "HttpPostStep", "HttpPutStep", "HttpDeleteStep", - "PaginatedHtppGetStep", + "PaginatedHttpGetStep", ] @@ -49,7 +49,7 @@ class HttpMethod(str, Enum): DELETE = "delete" @classmethod - def from_string(cls, value: str): + def from_string(cls, value: str) -> str: """Allows for getting the right Method Enum by simply passing a string value This method is not case-sensitive """ @@ -102,7 +102,7 @@ class HttpStep(Step, ExtraParamsMixin): data: Optional[Union[Dict[str, str], str]] = Field( default_factory=dict, description="[Optional] Data to be sent along with the request", alias="body" ) - params: Optional[Dict[str, Any]] = Field( + params: Optional[Dict[str, Any]] = Field( # type: ignore[assignment] default_factory=dict, description="[Optional] Set of extra parameters that should be passed to HTTP request", ) @@ -135,12 +135,12 @@ class Output(Step.Output): status_code: Optional[int] = Field(default=None, description="The status return code of the request") @property - def json_payload(self): + def json_payload(self) -> Union[dict, list, None]: """Alias for response_json""" return self.response_json @field_validator("method") - def get_proper_http_method_from_str_value(cls, method_value): + def get_proper_http_method_from_str_value(cls, method_value: str) -> str: """Converts string value to HttpMethod enum value""" if isinstance(method_value, str): try: @@ -154,7 +154,7 @@ def get_proper_http_method_from_str_value(cls, method_value): return method_value @field_validator("headers", mode="before") - def encode_sensitive_headers(cls, headers): + def encode_sensitive_headers(cls, headers: dict) -> dict: """ Encode potentially sensitive data into pydantic.SecretStr class to prevent them being displayed as plain text in logs. @@ -164,7 +164,7 @@ def encode_sensitive_headers(cls, headers): return headers @field_serializer("headers", when_used="json") - def decode_sensitive_headers(self, headers): + def decode_sensitive_headers(self, headers: dict) -> dict: """ Authorization headers are being converted into SecretStr under the hood to avoid dumping any sensitive content into logs by the `encode_sensitive_headers` method. @@ -178,13 +178,13 @@ def decode_sensitive_headers(self, headers): headers[k] = v.get_secret_value() if isinstance(v, SecretStr) else v return headers - def get_headers(self): + def get_headers(self) -> dict: """ Dump headers into JSON without SecretStr masking. 
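
A short sketch of the Step contract the metaclass enforces, mirroring the dummy step above: `execute()` fills `self.output`, and the wrapper logs start/end, validates the `Output` model, and returns it. The greeting step itself is purely illustrative:

```python
# Sketch: a custom Step whose execute() is wrapped by StepMetaClass.
from koheesio import Step
from koheesio.models import Field


class HelloStep(Step):
    name_to_greet: str = Field(..., description="Who to greet")

    class Output(Step.Output):
        greeting: str = Field(..., description="The rendered greeting")

    def execute(self) -> None:
        self.output.greeting = f"Hello, {self.name_to_greet}!"


step = HelloStep(name_to_greet="Koheesio")
output = step.execute()             # wrapper logs, validates the Output, and returns it
print(output.greeting)              # "Hello, Koheesio!"
print(step.repr_yaml(simple=True))  # yaml representation of input and output
```
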
""" return json.loads(self.model_dump_json()).get("headers") - def set_outputs(self, response): + def set_outputs(self, response: requests.Response) -> None: """ Types of response output """ @@ -200,7 +200,7 @@ def set_outputs(self, response): except json.decoder.JSONDecodeError as e: self.log.info(f"An error occurred while processing the JSON payload. Error message:\n{e.msg}") - def get_options(self): + def get_options(self) -> dict: """options to be passed to requests.request()""" return { "url": self.url, @@ -253,6 +253,7 @@ def request(self, method: Optional[HttpMethod] = None) -> requests.Response: return response + # noinspection PyMethodOverriding def get(self) -> requests.Response: """Execute an HTTP GET call""" self.method = HttpMethod.GET @@ -273,7 +274,7 @@ def delete(self) -> requests.Response: self.method = HttpMethod.DELETE return self.request() - def execute(self) -> Output: + def execute(self) -> None: """ Executes the HTTP request. @@ -320,7 +321,7 @@ class HttpDeleteStep(HttpStep): method: HttpMethod = HttpMethod.DELETE -class PaginatedHtppGetStep(HttpGetStep): +class PaginatedHttpGetStep(HttpGetStep): """ Represents a paginated HTTP GET step. @@ -366,7 +367,7 @@ def _adjust_params(self) -> Dict[str, Any]: """ return {k: v for k, v in self.params.items() if k not in ["paginate"]} # type: ignore - def get_options(self): + def get_options(self) -> dict: """ Returns the options to be passed to the requests.request() function. @@ -414,7 +415,7 @@ def _url(self, basic_url: str, page: Optional[int] = None) -> str: return basic_url.format(**url_params) - def execute(self) -> HttpGetStep.Output: + def execute(self) -> None: """ Executes the HTTP GET request and handles pagination. @@ -428,7 +429,7 @@ def execute(self) -> HttpGetStep.Output: data = [] _basic_url = self.url - for page in range(offset, pages): + for page in range(offset, pages): # type: ignore[arg-type] if self.paginate: self.log.info(f"Fetching page {page} of {pages - 1}") diff --git a/src/koheesio/utils.py b/src/koheesio/utils.py index 9547f8c7..253a985e 100644 --- a/src/koheesio/utils.py +++ b/src/koheesio/utils.py @@ -2,12 +2,14 @@ Utility functions """ +import datetime import inspect import uuid from typing import Any, Callable, Dict, Optional, Tuple from functools import partial from importlib import import_module from pathlib import Path +from sys import version_info as PYTHON_VERSION __all__ = [ "get_args_for_func", @@ -18,10 +20,14 @@ ] +PYTHON_MINOR_VERSION = PYTHON_VERSION.major + PYTHON_VERSION.minor / 10 +"""float: Python minor version as a float (e.g. 3.7)""" + + def get_args_for_func(func: Callable, params: Dict) -> Tuple[Callable, Dict[str, Any]]: """Helper function that matches keyword arguments (params) on a given function - This function uses inspect to extract the signature on the passed Callable, and then uses functools.partial to + This function uses inspect to extract the signature on the passed Callable, and then uses `functools.partial` to construct a new Callable (partial) function on which the input was mapped. 
Example @@ -94,8 +100,15 @@ def get_random_string(length: int = 64, prefix: Optional[str] = None) -> str: return f"{uuid.uuid4().hex}"[0:length] -def convert_str_to_bool(value) -> Any: +def convert_str_to_bool(value: str) -> Any: """Converts a string to a boolean if the string is either 'true' or 'false'""" if isinstance(value, str) and (v := value.lower()) in ["true", "false"]: value = v == "true" return value + + +def utc_now() -> datetime.datetime: + """Get current time in UTC""" + if PYTHON_MINOR_VERSION < 3.11: + return datetime.datetime.utcnow() + return datetime.datetime.now(datetime.timezone.utc) diff --git a/tests/asyncio/test_asyncio_http.py b/tests/asyncio/test_asyncio_http.py index 13bdcaf5..8625c710 100644 --- a/tests/asyncio/test_asyncio_http.py +++ b/tests/asyncio/test_asyncio_http.py @@ -10,6 +10,7 @@ from koheesio.asyncio.http import AsyncHttpStep from koheesio.steps.http import HttpMethod +# noinspection HttpUrlsUsage ASYNC_BASE_URL = "http://httpbin.org" ASYNC_GET_ENDPOINT = URL(f"{ASYNC_BASE_URL}/get") ASYNC_STATUS_503_ENDPOINT = URL(f"{ASYNC_BASE_URL}/status/503") diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py new file mode 100644 index 00000000..0541bdf8 --- /dev/null +++ b/tests/snowflake/test_snowflake.py @@ -0,0 +1,255 @@ +# flake8: noqa: F811 +from unittest import mock + +import pytest +from pydantic_core._pydantic_core import ValidationError + +from koheesio.integrations.snowflake import ( + GrantPrivilegesOnObject, + GrantPrivilegesOnTable, + GrantPrivilegesOnView, + SnowflakeBaseModel, + SnowflakeRunQueryPython, + SnowflakeStep, + SnowflakeTableStep, +) +from koheesio.integrations.snowflake.test_utils import mock_query + +COMMON_OPTIONS = { + "url": "url", + "user": "user", + "password": "password", + "database": "db", + "schema": "schema", + "role": "role", + "warehouse": "warehouse", +} + + +class TestGrantPrivilegesOnObject: + options = dict( + **COMMON_OPTIONS, + account="42", + object="foo", + type="TABLE", + privileges=["DELETE", "SELECT"], + roles=["role_1", "role_2"], + ) + + def test_execute(self, mock_query): + """Test that the query is correctly generated""" + # Arrange + del self.options["role"] # role is not required for this test as we are setting "roles" + mock_query.expected_data = [None] + expected_query = [ + "GRANT DELETE,SELECT ON TABLE foo TO ROLE role_1", + "GRANT DELETE,SELECT ON TABLE foo TO ROLE role_2", + ] + + # Act + kls = GrantPrivilegesOnObject(**self.options) + output = kls.execute() + + # Assert - 2 queries are expected, result should be None + assert output.query == expected_query + assert output.results == [None, None] + + +class TestGrantPrivilegesOnTable: + options = {**COMMON_OPTIONS, **dict(account="42", table="foo", privileges=["SELECT"], roles=["role_1"])} + + def test_execute(self, mock_query): + """Test that the query is correctly generated""" + # Arrange + del self.options["role"] # role is not required for this test as we are setting "roles" + mock_query.expected_data = [None] + expected_query = ["GRANT SELECT ON TABLE db.schema.foo TO ROLE role_1"] + + # Act + kls = GrantPrivilegesOnTable(**self.options) + output = kls.execute() + + # Assert - 1 query is expected, result should be None + assert output.query == expected_query + assert output.results == mock_query.expected_data + + +class TestGrantPrivilegesOnView: + options = {**COMMON_OPTIONS, **dict(account="42", view="foo", privileges=["SELECT"], roles=["role_1"])} + + def test_execute(self, mock_query): + """Test that the query 
is correctly generated""" + # Arrange + del self.options["role"] # role is not required for this test as we are setting "roles" + mock_query.expected_data = [None] + expected_query = ["GRANT SELECT ON VIEW db.schema.foo TO ROLE role_1"] + + # Act + kls = GrantPrivilegesOnView(**self.options) + output = kls.execute() + + # Assert - 1 query is expected, result should be None + assert output.query == expected_query + assert output.results == mock_query.expected_data + + +class TestSnowflakeRunQueryPython: + def test_mandatory_fields(self): + """Test that query and account fields are mandatory""" + with pytest.raises(ValidationError): + _1 = SnowflakeRunQueryPython(**COMMON_OPTIONS) + + # sql/query and account should work without raising an error + _2 = SnowflakeRunQueryPython(**COMMON_OPTIONS, sql="SELECT foo", account="42") + _3 = SnowflakeRunQueryPython(**COMMON_OPTIONS, query="SELECT foo", account="42") + + def test_get_options(self): + """Test that the options are correctly generated""" + # Arrange + expected_query = "SELECT foo" + kls = SnowflakeRunQueryPython(**COMMON_OPTIONS, sql=expected_query, account="42") + + # Act + actual_options = kls.get_options() + query_in_options = kls.get_options(include={"query"}, by_alias=True) + + # Assert + expected_options = { + "account": "42", + "database": "db", + "password": "password", + "role": "role", + "schema": "schema", + "url": "url", + "user": "user", + "warehouse": "warehouse", + } + assert actual_options == expected_options + assert query_in_options["query"] == expected_query, "query should be returned regardless of the input" + + def test_execute(self, mock_query): + # Arrange + query = "SELECT * FROM two_row_table" + expected_data = [("row1",), ("row2",)] + mock_query.expected_data = expected_data + + # Act + instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query=query, account="42") + instance.execute() + + # Assert + mock_query.assert_called_with(query) + assert instance.output.results == expected_data + + def test_with_missing_dependencies(self): + """Missing dependency should throw a warning first, and raise an error if execution is attempted""" + # Arrange -- remove the snowflake connector + with mock.patch.dict("sys.modules", {"snowflake": None}): + from koheesio.integrations.snowflake import safe_import_snowflake_connector + + # Act & Assert -- first test for the warning, then test for the error + match_text = "You need to have the `snowflake-connector-python` package installed" + with pytest.warns(UserWarning, match=match_text): + safe_import_snowflake_connector() + with pytest.warns(UserWarning, match=match_text): + instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query="", account="42") + with pytest.raises(RuntimeError): + instance.execute() + + +class TestSnowflakeBaseModel: + + def test_get_options_using_alias(self): + """Test that the options are correctly generated using alias""" + k = SnowflakeBaseModel( + sfURL="url", + sfUser="user", + sfPassword="password", + sfDatabase="database", + sfRole="role", + sfWarehouse="warehouse", + schema="schema", + ) + options = k.get_options() # alias should be used by default + assert options["sfURL"] == "url" + assert options["sfUser"] == "user" + assert options["sfDatabase"] == "database" + assert options["sfRole"] == "role" + assert options["sfWarehouse"] == "warehouse" + assert options["sfSchema"] == "schema" + + def test_get_options(self): + """Test that the options are correctly generated not using alias""" + k = SnowflakeBaseModel( + url="url", + user="user", + 
password="password", + database="database", + role="role", + warehouse="warehouse", + schema="schema", + ) + options = k.get_options(by_alias=False) + assert options["url"] == "url" + assert options["user"] == "user" + assert options["database"] == "database" + assert options["role"] == "role" + assert options["warehouse"] == "warehouse" + assert options["schema"] == "schema" + + # make sure none of the koheesio options are present + assert "description" not in options + assert "name" not in options + + def test_get_options_include(self): + """Test that the options are correctly generated using include""" + k = SnowflakeBaseModel( + url="url", + user="user", + password="password", + database="database", + role="role", + warehouse="warehouse", + schema="schema", + options={"foo": "bar"}, + ) + options = k.get_options(include={"url", "user", "description", "options"}, by_alias=False) + + # should be present + assert options["url"] == "url" + assert options["user"] == "user" + assert "description" in options + + # options should be expanded + assert "options" not in options + assert options["foo"] == "bar" + + # should not be present + assert "database" not in options + assert "role" not in options + assert "warehouse" not in options + assert "schema" not in options + + +class TestSnowflakeStep: + def test_initialization(self): + """Test that the Step fields come through correctly""" + # Arrange + kls = SnowflakeStep(**COMMON_OPTIONS) + + # Act + options = kls.get_options() + + # Assert + assert kls.name == "SnowflakeStep" + assert kls.description == "Expands the SnowflakeBaseModel so that it can be used as a Step" + assert ( + "name" not in options and "description" not in options + ), "koheesio options should not be present in get_options" + + +class TestSnowflakeTableStep: + def test_initialization(self): + """Test that the table is correctly set""" + kls = SnowflakeTableStep(**COMMON_OPTIONS, table="table") + assert kls.table == "table" diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index f6f40f93..b0a7c51e 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -1,15 +1,17 @@ import datetime import os +import socket import sys +from collections import namedtuple from decimal import Decimal from pathlib import Path from textwrap import dedent -from unittest.mock import Mock +from unittest import mock import pytest from delta import configure_spark_with_delta_pip -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession from pyspark.sql.types import ( ArrayType, BinaryType, @@ -34,6 +36,15 @@ from koheesio.spark.readers.dummy import DummyReader +def is_port_free(port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return True + except socket.error: + return False + + @pytest.fixture(scope="session") def warehouse_path(tmp_path_factory, random_uuid, logger): fldr = tmp_path_factory.mktemp("spark-warehouse" + random_uuid) @@ -51,10 +62,25 @@ def checkpoint_folder(tmp_path_factory, random_uuid, logger): @pytest.fixture(scope="session") def spark(warehouse_path, random_uuid): """Spark session fixture with Delta enabled.""" + builder = SparkSession.builder.appName("test_session" + random_uuid) + + if os.environ.get("SPARK_REMOTE") == "local": + # SPARK_TESTING is set in environment variables + # This triggers spark connect logic + # ---->>>> For testing, we use 0 to use an ephemeral port to allow parallel testing. + # --->>>>>> See also SPARK-42272. 
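+        # The spark-connect artifact matching the installed PySpark version is added below as an
+        # extra package so that configure_spark_with_delta_pip can build a remote (Spark Connect)
+        # Delta-enabled session.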
+ from pyspark.version import __version__ as spark_version + + builder = configure_spark_with_delta_pip( + spark_session_builder=builder.remote("local"), + extra_packages=[f"org.apache.spark:spark-connect_2.12:{spark_version}"], + ) + else: + builder = builder.master("local[*]") + builder = configure_spark_with_delta_pip(spark_session_builder=builder) + builder = ( - SparkSession.builder.appName("test_session" + random_uuid) - .master("local[*]") - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + builder.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") .config("spark.sql.warehouse.dir", warehouse_path) .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") .config("spark.sql.session.timeZone", "UTC") @@ -62,7 +88,8 @@ def spark(warehouse_path, random_uuid): .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "true") ) - spark_session = configure_spark_with_delta_pip(builder).getOrCreate() + spark_session = builder.getOrCreate() + yield spark_session spark_session.stop() @@ -121,10 +148,10 @@ def dummy_df(spark): @pytest.fixture(scope="class") def sample_df_to_partition(spark): """ - | paritition | Value - |----|----| - | BE | 12 | - | FR | 20 | + | partition | Value | + |-----------|-------| + | BE | 12 | + | FR | 20 | """ data = [["BE", 12], ["FR", 20]] schema = ["partition", "value"] @@ -156,15 +183,16 @@ def sample_df_with_strings(spark): def sample_df_with_timestamp(spark): """ df: - | id | a_date | a_timestamp - |----|---------------------|--------------------- - | 1 | 1970-04-20 12:33:09 | - | 2 | 1980-05-21 13:34:08 | - | 3 | 1990-06-22 14:35:07 | + | id | a_date | a_timestamp | + |----|---------------------|---------------------| + | 1 | 1970-04-20 12:33:09 | 2000-07-01 01:01:00 | + | 2 | 1980-05-21 13:34:08 | 2010-08-02 02:02:00 | + | 3 | 1990-06-22 14:35:07 | 2020-09-03 03:03:00 | Schema: - id: bigint (nullable = true) - - date: timestamp (nullable = true) + - a_date: timestamp (nullable = true) + - a_timestamp: timestamp (nullable = true) """ data = [ (1, datetime.datetime(1970, 4, 20, 12, 33, 9), datetime.datetime(2000, 7, 1, 1, 1)), @@ -207,43 +235,67 @@ def setup_test_data(spark, delta_file): ) +SparkContextData = namedtuple("SparkContextData", ["spark", "options_dict"]) +"""A named tuple containing the Spark session and the options dictionary used to create the DataFrame""" + + @pytest.fixture(scope="class") -def dummy_spark(): - class DummySpark: - """Mocking SparkSession""" +def dummy_spark(spark, sample_df_with_strings) -> SparkContextData: + """SparkSession fixture that makes any call to SparkSession.read.load() return a DataFrame with strings. + + Because of the use of `type(spark.read)`, this fixture automatically alters its behavior for either a remote or + regular Spark session. 
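+    Any keyword arguments passed to `options()` are captured in the returned `options_dict`,
+    so tests can assert which options a reader forwarded to Spark.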
+ + Example + ------- + ```python + def test_dummy_spark(dummy_spark, sample_df_with_strings): + df = dummy_spark.read.load() + assert df.count() == sample_df_with_strings.count() + ``` + + Returns + ------- + SparkContextData + A named tuple containing the Spark session and the options dictionary used to create the DataFrame + """ + _options_dict = {} - def __init__(self): - self.options_dict = {} + def mock_options(*args, **kwargs): + _options_dict.update(kwargs) + return spark.read - def mock_method(self, *args, **kwargs): - return self + spark_reader = type(spark.read) + with mock.patch.object(spark_reader, "options", side_effect=mock_options): + with mock.patch.object(spark_reader, "load", return_value=sample_df_with_strings): + yield SparkContextData(spark, _options_dict) - @property - def mock_property(self): - return self - def mock_options(self, *args, **kwargs): - self.options_dict = kwargs - return self +@pytest.fixture(scope="class") +def mock_df(spark) -> mock.Mock: + """Fixture to mock a DataFrame's methods.""" + # create a local DataFrame so we can get the spec of the DataFrame + df = spark.range(1) - options = mock_options - format = mock_method - read = mock_property + # mock the df.write method + mock_df_write = mock.create_autospec(type(df.write)) - _jvm = Mock() - _jvm.net.snowflake.spark.snowflake.Utils.runQuery.return_value = True + # mock the save method + mock_df_write.save = mock.Mock(return_value=None) - @staticmethod - def load() -> DataFrame: - df = Mock(spec=DataFrame) - df.count.return_value = 1 - df.schema = StructType([StructField("foo", StringType(), True)]) - return df + # mock the format, option(s), and mode methods + mock_df_write.format.return_value = mock_df_write + mock_df_write.options.return_value = mock_df_write + mock_df_write.option.return_value = mock_df_write + mock_df_write.mode.return_value = mock_df_write - return DummySpark() + # now create a mock DataFrame with the mocked write method + mock_df = mock.create_autospec(type(df), instance=True) + mock_df.write = mock_df_write + yield mock_df -def await_job_completion(timeout=300, query_id=None): +def await_job_completion(spark, timeout=300, query_id=None): """ Waits for a Spark streaming job to complete. 
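+    The session is supplied through the `spark` argument (typically the test fixture) instead of
+    being looked up on the pyspark SparkSession class.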
@@ -254,7 +306,7 @@ def await_job_completion(timeout=300, query_id=None): logger = LoggingFactory.get_logger(name="await_job_completion", inherit_from_koheesio=True) start_time = datetime.datetime.now() - spark = SparkSession.getActiveSession() + spark = spark.getActiveSession() logger.info("Waiting for streaming job to complete") if query_id is not None: stream = spark.streams.get(query_id) diff --git a/tests/spark/integrations/snowflake/test_snowflake.py b/tests/spark/integrations/snowflake/test_snowflake.py deleted file mode 100644 index c2f8bc14..00000000 --- a/tests/spark/integrations/snowflake/test_snowflake.py +++ /dev/null @@ -1,375 +0,0 @@ -from textwrap import dedent -from unittest import mock -from unittest.mock import Mock, patch - -import pytest - -from pyspark.sql import SparkSession -from pyspark.sql import types as t - -from koheesio.spark.snowflake import ( - AddColumn, - CreateOrReplaceTableFromDataFrame, - DbTableQuery, - GetTableSchema, - GrantPrivilegesOnObject, - GrantPrivilegesOnTable, - GrantPrivilegesOnView, - Query, - RunQuery, - SnowflakeBaseModel, - SnowflakeReader, - SnowflakeWriter, - SyncTableAndDataFrameSchema, - TableExists, - TagSnowflakeQuery, - map_spark_type, -) -from koheesio.spark.writers import BatchOutputMode - -pytestmark = pytest.mark.spark - -COMMON_OPTIONS = { - "url": "url", - "user": "user", - "password": "password", - "database": "db", - "schema": "schema", - "role": "role", - "warehouse": "warehouse", -} - - -def test_snowflake_module_import(): - # test that the pass-through imports in the koheesio.spark snowflake modules are working - from koheesio.spark.readers import snowflake as snowflake_writers - from koheesio.spark.writers import snowflake as snowflake_readers - - -class TestSnowflakeReader: - @pytest.mark.parametrize( - "reader_options", [{"dbtable": "table", **COMMON_OPTIONS}, {"table": "table", **COMMON_OPTIONS}] - ) - def test_get_options(self, reader_options): - sf = SnowflakeReader(**(reader_options | {"authenticator": None})) - o = sf.get_options() - assert sf.format == "snowflake" - assert o["sfUser"] == "user" - assert o["sfCompress"] == "on" - assert "authenticator" not in o - - @pytest.mark.parametrize( - "reader_options", [{"dbtable": "table", **COMMON_OPTIONS}, {"table": "table", **COMMON_OPTIONS}] - ) - def test_execute(self, dummy_spark, reader_options): - """Method should be callable from parent class""" - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = SnowflakeReader(**reader_options).execute() - assert k.df.count() == 1 - - -class TestRunQuery: - query_options = {"query": "query", **COMMON_OPTIONS} - - def test_get_options(self): - k = RunQuery(**self.query_options) - o = k.get_options() - - assert o["host"] == o["sfURL"] - - def test_execute(self, dummy_spark): - pass - - -class TestQuery: - query_options = {"query": "query", **COMMON_OPTIONS} - - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = Query(**self.query_options) - assert k.df.count() == 1 - - -class TestTableQuery: - options = {"table": "table", **COMMON_OPTIONS} - - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = DbTableQuery(**self.options).execute() - assert k.df.count() == 1 - - -class TestTableExists: - table_exists_options = {"table": "table", **COMMON_OPTIONS} 
- - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = TableExists(**self.table_exists_options).execute() - assert k.exists is True - - -class TestCreateOrReplaceTableFromDataFrame: - options = {"table": "table", **COMMON_OPTIONS} - - def test_execute(self, dummy_spark, dummy_df): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = CreateOrReplaceTableFromDataFrame(**self.options, df=dummy_df).execute() - assert k.snowflake_schema == "id BIGINT" - assert k.query == "CREATE OR REPLACE TABLE db.schema.table (id BIGINT)" - assert len(k.input_schema) > 0 - - -class TestGetTableSchema: - get_table_schema_options = {"table": "table", **COMMON_OPTIONS} - - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = GetTableSchema(**self.get_table_schema_options) - assert len(k.execute().table_schema.fields) == 1 - - -class TestAddColumn: - options = {"table": "foo", "column": "bar", "type": t.DateType(), **COMMON_OPTIONS} - - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = AddColumn(**self.options).execute() - assert k.query == "ALTER TABLE FOO ADD COLUMN BAR DATE" - - -def test_grant_privileges_on_object(dummy_spark): - options = dict( - **COMMON_OPTIONS, object="foo", type="TABLE", privileges=["DELETE", "SELECT"], roles=["role_1", "role_2"] - ) - del options["role"] # role is not required for this step as we are setting "roles" - - kls = GrantPrivilegesOnObject(**options) - - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - k = kls.execute() - - assert len(k.query) == 2, "expecting 2 queries (one for each role)" - assert "DELETE" in k.query[0] - assert "SELECT" in k.query[0] - - -def test_grant_privileges_on_table(dummy_spark): - options = {**COMMON_OPTIONS, **dict(table="foo", privileges=["SELECT"], roles=["role_1"])} - del options["role"] # role is not required for this step as we are setting "roles" - - kls = GrantPrivilegesOnTable( - **options, - ) - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = kls.execute() - assert k.query == [ - "GRANT SELECT ON TABLE DB.SCHEMA.FOO TO ROLE ROLE_1", - ] - - -class TestGrantPrivilegesOnView: - options = {**COMMON_OPTIONS} - - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = GrantPrivilegesOnView(**self.options, view="foo", privileges=["SELECT"], roles=["role_1"]).execute() - assert k.query == [ - "GRANT SELECT ON VIEW DB.SCHEMA.FOO TO ROLE ROLE_1", - ] - - -class TestSnowflakeWriter: - def test_execute(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - k = SnowflakeWriter( - **COMMON_OPTIONS, - table="foo", - df=dummy_spark.load(), - mode=BatchOutputMode.OVERWRITE, - ) - k.execute() - - -class TestSyncTableAndDataFrameSchema: - @mock.patch("koheesio.spark.snowflake.AddColumn") - @mock.patch("koheesio.spark.snowflake.GetTableSchema") - def test_execute(self, mock_get_table_schema, mock_add_column, spark, caplog): - from pyspark.sql.types import 
StringType, StructField, StructType - - df = spark.createDataFrame(data=[["val"]], schema=["foo"]) - sf_schema_before = StructType([StructField("bar", StringType(), True)]) - sf_schema_after = StructType([StructField("bar", StringType(), True), StructField("foo", StringType(), True)]) - - mock_get_table_schema_instance = mock_get_table_schema() - mock_get_table_schema_instance.execute.side_effect = [ - mock.Mock(table_schema=sf_schema_before), - mock.Mock(table_schema=sf_schema_after), - ] - - with caplog.at_level("DEBUG"): - k = SyncTableAndDataFrameSchema( - **COMMON_OPTIONS, - table="foo", - df=df, - dry_run=True, - ).execute() - print(f"{caplog.text = }") - assert "Columns to be added to Snowflake table: {'foo'}" in caplog.text - assert "Columns to be added to Spark DataFrame: {'bar'}" in caplog.text - assert k.new_df_schema == StructType() - - k = SyncTableAndDataFrameSchema( - **COMMON_OPTIONS, - table="foo", - df=df, - ).execute() - assert k.df.columns == ["bar", "foo"] - - -@pytest.mark.parametrize( - "input_value,expected", - [ - (t.BinaryType(), "VARBINARY"), - (t.BooleanType(), "BOOLEAN"), - (t.ByteType(), "BINARY"), - (t.DateType(), "DATE"), - (t.TimestampType(), "TIMESTAMP"), - (t.DoubleType(), "DOUBLE"), - (t.FloatType(), "FLOAT"), - (t.IntegerType(), "INT"), - (t.LongType(), "BIGINT"), - (t.NullType(), "STRING"), - (t.ShortType(), "SMALLINT"), - (t.StringType(), "STRING"), - (t.NumericType(), "FLOAT"), - (t.DecimalType(0, 1), "DECIMAL(0,1)"), - (t.DecimalType(0, 100), "DECIMAL(0,100)"), - (t.DecimalType(10, 0), "DECIMAL(10,0)"), - (t.DecimalType(), "DECIMAL(10,0)"), - (t.MapType(t.IntegerType(), t.StringType()), "VARIANT"), - (t.ArrayType(t.StringType()), "VARIANT"), - (t.StructType([t.StructField(name="foo", dataType=t.StringType())]), "VARIANT"), - (t.DayTimeIntervalType(), "STRING"), - ], -) -def test_map_spark_type(input_value, expected): - assert map_spark_type(input_value) == expected - - -class TestSnowflakeBaseModel: - def test_get_options(self, dummy_spark): - k = SnowflakeBaseModel( - sfURL="url", - sfUser="user", - sfPassword="password", - sfDatabase="database", - sfRole="role", - sfWarehouse="warehouse", - schema="schema", - ) - options = k.get_options() - assert options["sfURL"] == "url" - assert options["sfUser"] == "user" - assert options["sfDatabase"] == "database" - assert options["sfRole"] == "role" - assert options["sfWarehouse"] == "warehouse" - assert options["sfSchema"] == "schema" - - -class TestTagSnowflakeQuery: - def test_tag_query_no_existing_preactions(self): - expected_preactions = ( - """ALTER SESSION SET QUERY_TAG = '{"pipeline_name": "test-pipeline-1","task_name": "test_task_1"}';""" - ) - - tagged_options = ( - TagSnowflakeQuery( - task_name="test_task_1", - pipeline_name="test-pipeline-1", - ) - .execute() - .options - ) - - assert len(tagged_options) == 1 - preactions = tagged_options["preactions"].replace(" ", "").replace("\n", "") - assert preactions == expected_preactions - - def test_tag_query_present_existing_preactions(self): - options = { - "otherSfOption": "value", - "preactions": "SET TEST_VAR = 'ABC';", - } - query_tag_preaction = ( - """ALTER SESSION SET QUERY_TAG = '{"pipeline_name": "test-pipeline-2","task_name": "test_task_2"}';""" - ) - expected_preactions = f"SET TEST_VAR = 'ABC';{query_tag_preaction}" "" - - tagged_options = ( - TagSnowflakeQuery(task_name="test_task_2", pipeline_name="test-pipeline-2", options=options) - .execute() - .options - ) - - assert len(tagged_options) == 2 - assert tagged_options["otherSfOption"] 
== "value" - preactions = tagged_options["preactions"].replace(" ", "").replace("\n", "") - assert preactions == expected_preactions - - -def test_table_exists(spark): - # Create a TableExists instance - te = TableExists( - sfURL="url", - sfUser="user", - sfPassword="password", - sfDatabase="database", - sfRole="role", - sfWarehouse="warehouse", - schema="schema", - table="table", - ) - - expected_query = dedent( - """ - SELECT * - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_CATALOG = 'DATABASE' - AND TABLE_SCHEMA = 'SCHEMA' - AND TABLE_TYPE = 'BASE TABLE' - AND UPPER(TABLE_NAME) = 'TABLE' - """ - ).strip() - - # Create a Mock object for the Query class - mock_query = Mock(spec=Query) - mock_query.read.return_value = spark.range(1) - - # Patch the Query class to return the mock_query when instantiated - with patch("koheesio.spark.snowflake.Query", return_value=mock_query) as mock_query_class: - # Execute the SnowflakeBaseModel instance - te.execute() - - # Assert that the query is as expected - assert mock_query_class.call_args[1]["query"] == expected_query diff --git a/tests/spark/integrations/snowflake/test_spark_snowflake.py b/tests/spark/integrations/snowflake/test_spark_snowflake.py new file mode 100644 index 00000000..c2346aa4 --- /dev/null +++ b/tests/spark/integrations/snowflake/test_spark_snowflake.py @@ -0,0 +1,296 @@ +# flake8: noqa: F811 +import logging +from textwrap import dedent +from unittest import mock +from unittest.mock import Mock + +import pytest + +from pyspark.sql import types as t + +from koheesio.integrations.snowflake.test_utils import mock_query +from koheesio.integrations.spark.snowflake import ( + AddColumn, + CreateOrReplaceTableFromDataFrame, + DbTableQuery, + GetTableSchema, + Query, + RunQuery, + SnowflakeReader, + SnowflakeWriter, + SyncTableAndDataFrameSchema, + TableExists, + TagSnowflakeQuery, + map_spark_type, +) +from koheesio.spark.writers import BatchOutputMode + +pytestmark = pytest.mark.spark + +COMMON_OPTIONS = { + "url": "url", + "user": "user", + "password": "password", + "database": "db", + "schema": "schema", + "role": "role", + "warehouse": "warehouse", +} + + +def test_snowflake_module_import(): + # test that the pass-through imports in the koheesio.spark snowflake modules are working + from koheesio.spark.readers import snowflake as snowflake_writers + from koheesio.spark.writers import snowflake as snowflake_readers + + +class TestSnowflakeReader: + reader_options = {"dbtable": "table", **COMMON_OPTIONS} + + def test_get_options(self): + sf = SnowflakeReader(**(self.reader_options | {"authenticator": None})) + o = sf.get_options() + assert sf.format == "snowflake" + assert o["sfUser"] == "user" + assert o["sfCompress"] == "on" + assert "authenticator" not in o + + def test_execute(self, dummy_spark): + """Method should be callable from parent class""" + k = SnowflakeReader(**self.reader_options).execute() + assert k.df.count() == 3 + + +class TestRunQuery: + def test_deprecation(self): + """Test for the deprecation warning""" + with pytest.warns( + DeprecationWarning, match="The RunQuery class is deprecated and will be removed in a future release." 
+ ): + try: + _ = RunQuery( + **COMMON_OPTIONS, + query="", + ) + except RuntimeError: + pass # Ignore any RuntimeError that occur after the warning + + def test_spark_connect(self, spark): + """Test that we get a RuntimeError when using a SparkSession without a JVM""" + from koheesio.spark.utils.connect import is_remote_session + + if not is_remote_session(spark): + pytest.skip(reason="Test only runs when we have a remote SparkSession") + + with pytest.raises(RuntimeError): + _ = RunQuery( + **COMMON_OPTIONS, + query="", + ) + + +class TestQuery: + options = {"query": "query", **COMMON_OPTIONS} + + def test_execute(self, dummy_spark): + k = Query(**self.options).execute() + assert k.df.count() == 3 + + +class TestTableQuery: + options = {"table": "table", **COMMON_OPTIONS} + + def test_execute(self, dummy_spark): + k = DbTableQuery(**self.options).execute() + assert k.df.count() == 3 + + +class TestCreateOrReplaceTableFromDataFrame: + options = {"table": "table", "account": "bar", **COMMON_OPTIONS} + + def test_execute(self, dummy_spark, dummy_df, mock_query): + k = CreateOrReplaceTableFromDataFrame(**self.options, df=dummy_df).execute() + assert k.snowflake_schema == "id BIGINT" + assert k.query == "CREATE OR REPLACE TABLE db.schema.table (id BIGINT)" + assert len(k.input_schema) > 0 + mock_query.assert_called_with(k.query) + + +class TestGetTableSchema: + options = {"table": "table", **COMMON_OPTIONS} + + def test_execute(self, dummy_spark): + k = GetTableSchema(**self.options) + assert len(k.execute().table_schema.fields) == 2 + + +class TestAddColumn: + options = {"table": "foo", "column": "bar", "type": t.DateType(), "account": "foo", **COMMON_OPTIONS} + + def test_execute(self, dummy_spark, mock_query): + k = AddColumn(**self.options).execute() + assert k.query == "ALTER TABLE FOO ADD COLUMN BAR DATE" + mock_query.assert_called_with(k.query) + + +class TestSnowflakeWriter: + def test_execute(self, mock_df): + k = SnowflakeWriter( + **COMMON_OPTIONS, + table="foo", + df=mock_df, + mode=BatchOutputMode.OVERWRITE, + ) + k.execute() + + # check that the format was set to snowflake + mocked_format: Mock = mock_df.write.format + assert mocked_format.call_args[0][0] == "snowflake" + mock_df.write.format.assert_called_with("snowflake") + + +class TestSyncTableAndDataFrameSchema: + @mock.patch("koheesio.integrations.spark.snowflake.AddColumn") + @mock.patch("koheesio.integrations.spark.snowflake.GetTableSchema") + def test_execute(self, mock_get_table_schema, mock_add_column, spark, caplog): + # Arrange + from pyspark.sql.types import StringType, StructField, StructType + + df = spark.createDataFrame(data=[["val"]], schema=["foo"]) + sf_schema_before = StructType([StructField("bar", StringType(), True)]) + sf_schema_after = StructType([StructField("bar", StringType(), True), StructField("foo", StringType(), True)]) + + mock_get_table_schema_instance = mock_get_table_schema() + mock_get_table_schema_instance.execute.side_effect = [ + mock.Mock(table_schema=sf_schema_before), + mock.Mock(table_schema=sf_schema_after), + ] + + logger = logging.getLogger("koheesio") + logger.setLevel(logging.WARNING) + + # Act and Assert -- dry run + with caplog.at_level(logging.WARNING): + k = SyncTableAndDataFrameSchema( + **COMMON_OPTIONS, + table="foo", + df=df, + dry_run=True, + ).execute() + print(f"{caplog.text = }") + assert "Columns to be added to Snowflake table: {'foo'}" in caplog.text + assert "Columns to be added to Spark DataFrame: {'bar'}" in caplog.text + assert k.new_df_schema == StructType() + 
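+        # The dry run only reports the schema differences (asserted via caplog above),
+        # which is why new_df_schema is still an empty StructType at this point.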
+ # Act and Assert -- execute + k = SyncTableAndDataFrameSchema( + **COMMON_OPTIONS, + table="foo", + df=df, + ).execute() + assert sorted(k.df.columns) == ["bar", "foo"] + + +@pytest.mark.parametrize( + "input_value,expected", + [ + (t.BinaryType(), "VARBINARY"), + (t.BooleanType(), "BOOLEAN"), + (t.ByteType(), "BINARY"), + (t.DateType(), "DATE"), + (t.TimestampType(), "TIMESTAMP"), + (t.DoubleType(), "DOUBLE"), + (t.FloatType(), "FLOAT"), + (t.IntegerType(), "INT"), + (t.LongType(), "BIGINT"), + (t.NullType(), "STRING"), + (t.ShortType(), "SMALLINT"), + (t.StringType(), "STRING"), + (t.NumericType(), "FLOAT"), + (t.DecimalType(0, 1), "DECIMAL(0,1)"), + (t.DecimalType(0, 100), "DECIMAL(0,100)"), + (t.DecimalType(10, 0), "DECIMAL(10,0)"), + (t.DecimalType(), "DECIMAL(10,0)"), + (t.MapType(t.IntegerType(), t.StringType()), "VARIANT"), + (t.ArrayType(t.StringType()), "VARIANT"), + (t.StructType([t.StructField(name="foo", dataType=t.StringType())]), "VARIANT"), + (t.DayTimeIntervalType(), "STRING"), + ], +) +def test_map_spark_type(input_value, expected): + assert map_spark_type(input_value) == expected + + +class TestTableExists: + options = dict( + sfURL="url", + sfUser="user", + sfPassword="password", + sfDatabase="database", + sfRole="role", + sfWarehouse="warehouse", + schema="schema", + table="table", + ) + + def test_table_exists(self, dummy_spark): + # Arrange + te = TableExists(**self.options) + expected_query = dedent( + """ + SELECT * + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_CATALOG = 'DATABASE' + AND TABLE_SCHEMA = 'SCHEMA' + AND TABLE_TYPE = 'BASE TABLE' + AND UPPER(TABLE_NAME) = 'TABLE' + """ + ).strip() + + # Act + output = te.execute() + + # Assert that the query is as expected and that we got exists as True + assert dummy_spark.options_dict["query"] == expected_query + assert output.exists + + +class TestTagSnowflakeQuery: + def test_tag_query_no_existing_preactions(self): + expected_preactions = ( + """ALTER SESSION SET QUERY_TAG = '{"pipeline_name": "test-pipeline-1","task_name": "test_task_1"}';""" + ) + + tagged_options = ( + TagSnowflakeQuery( + task_name="test_task_1", + pipeline_name="test-pipeline-1", + ) + .execute() + .options + ) + + assert len(tagged_options) == 1 + preactions = tagged_options["preactions"].replace(" ", "").replace("\n", "") + assert preactions == expected_preactions + + def test_tag_query_present_existing_preactions(self): + options = { + "otherSfOption": "value", + "preactions": "SET TEST_VAR = 'ABC';", + } + query_tag_preaction = ( + """ALTER SESSION SET QUERY_TAG = '{"pipeline_name": "test-pipeline-2","task_name": "test_task_2"}';""" + ) + expected_preactions = f"SET TEST_VAR = 'ABC';{query_tag_preaction}" "" + + tagged_options = ( + TagSnowflakeQuery(task_name="test_task_2", pipeline_name="test-pipeline-2", options=options) + .execute() + .options + ) + + assert len(tagged_options) == 2 + assert tagged_options["otherSfOption"] == "value" + preactions = tagged_options["preactions"].replace(" ", "").replace("\n", "") + assert preactions == expected_preactions diff --git a/tests/spark/integrations/snowflake/test_sync_task.py b/tests/spark/integrations/snowflake/test_sync_task.py index 178253ca..a4c50e8a 100644 --- a/tests/spark/integrations/snowflake/test_sync_task.py +++ b/tests/spark/integrations/snowflake/test_sync_task.py @@ -1,4 +1,5 @@ from datetime import datetime +from textwrap import dedent from unittest import mock import chispa @@ -7,15 +8,15 @@ import pydantic -from pyspark.sql import DataFrame - -from koheesio.spark.delta 
import DeltaTableStep -from koheesio.spark.readers.delta import DeltaTableReader -from koheesio.spark.snowflake import ( - RunQuery, +from koheesio.integrations.snowflake import SnowflakeRunQueryPython +from koheesio.integrations.snowflake.test_utils import mock_query +from koheesio.integrations.spark.snowflake import ( SnowflakeWriter, SynchronizeDeltaToSnowflakeTask, ) +from koheesio.spark import DataFrame +from koheesio.spark.delta import DeltaTableStep +from koheesio.spark.readers.delta import DeltaTableReader from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableWriter from koheesio.spark.writers.stream import ForEachBatchStreamWriter @@ -23,7 +24,6 @@ pytestmark = pytest.mark.spark COMMON_OPTIONS = { - "source_table": DeltaTableStep(table=""), "target_table": "foo.bar", "key_columns": [ "Country", @@ -126,7 +126,7 @@ def mock_drop_table(table): task.execute() chispa.assert_df_equality(task.output.target_df, df) - @mock.patch.object(RunQuery, "execute") + @mock.patch.object(SnowflakeRunQueryPython, "execute") def test_merge( self, mocked_sf_query_execute, @@ -134,42 +134,46 @@ def test_merge( foreach_batch_stream_local, snowflake_staging_file, ): - # Prepare Delta requirements - source_table = DeltaTableStep(datbase="klettern", table="test_merge") + # Arrange - Prepare Delta requirements + source_table = DeltaTableStep(database="klettern", table="test_merge") spark.sql( - f""" - CREATE OR REPLACE TABLE {source_table.table_name} - (Country STRING, NumVaccinated LONG, AvailableDoses LONG) - USING DELTA - TBLPROPERTIES ('delta.enableChangeDataFeed' = true); - """ + dedent( + f""" + CREATE OR REPLACE TABLE {source_table.table_name} + (Country STRING, NumVaccinated LONG, AvailableDoses LONG) + USING DELTA + TBLPROPERTIES ('delta.enableChangeDataFeed' = true); + """ + ) ) - # Prepare local representation of snowflake + # Arrange - Prepare local representation of snowflake task = SynchronizeDeltaToSnowflakeTask( streaming=True, synchronisation_mode=BatchOutputMode.MERGE, - **{**COMMON_OPTIONS, "source_table": source_table}, + **{**COMMON_OPTIONS, "source_table": source_table, "account": "sf_account"}, ) - # Perform actions + # Arrange - Add data to previously empty Delta table spark.sql( - f"""INSERT INTO {source_table.table_name} VALUES - ("Australia", 100, 3000), - ("USA", 10000, 20000), - ("UK", 7000, 10000); - """ + dedent( + f""" + INSERT INTO {source_table.table_name} VALUES + ("Australia", 100, 3000), + ("USA", 10000, 20000), + ("UK", 7000, 10000); + """ + ) ) - # Run code - + # Act - Run code + # Note: We are using the foreach_batch_stream_local fixture to simulate writing to a live environment with mock.patch.object(SynchronizeDeltaToSnowflakeTask, "writer", new=foreach_batch_stream_local): task.execute() task.writer.await_termination() - # Validate result + # Assert - Validate result df = spark.read.parquet(snowflake_staging_file).select("Country", "NumVaccinated", "AvailableDoses") - chispa.assert_df_equality( df, spark.sql(f"SELECT * FROM {source_table.table_name}"), @@ -187,7 +191,7 @@ def test_merge( # Test that this call doesn't raise exception after all queries were completed task.writer.await_termination() task.execute() - await_job_completion() + await_job_completion(spark) # Validate result df = spark.read.parquet(snowflake_staging_file).select("Country", "NumVaccinated", "AvailableDoses") @@ -368,6 +372,13 @@ def test_changed_table(self, spark, sample_df_with_timestamp): class TestValidations: + options 
= {**COMMON_OPTIONS} + + @pytest.fixture(autouse=True, scope="class") + def set_spark(self, spark): + self.options["source_table"] = DeltaTableStep(table="") + yield spark + @pytest.mark.parametrize( "sync_mode,streaming", [ @@ -381,7 +392,7 @@ def test_snowflake_sync_task_allowed_options(self, sync_mode: BatchOutputMode, s task = SynchronizeDeltaToSnowflakeTask( streaming=streaming, synchronisation_mode=sync_mode, - **COMMON_OPTIONS, + **self.options, ) assert task.reader.streaming == streaming @@ -430,21 +441,21 @@ def test_snowflake_sync_task_allowed_writers( task = SynchronizeDeltaToSnowflakeTask( streaming=streaming, synchronisation_mode=sync_mode, - **COMMON_OPTIONS, + **self.options, ) - print(f"{task.writer = }") - print(f"{type(task.writer) = }") assert isinstance(task.writer, expected_writer_type) def test_merge_cdf_enabled(self, spark): table = DeltaTableStep(database="klettern", table="sync_test_table") spark.sql( - f""" - CREATE OR REPLACE TABLE {table.table_name} - (Country STRING, NumVaccinated INT, AvailableDoses INT) - USING DELTA - TBLPROPERTIES ('delta.enableChangeDataFeed' = false); - """ + dedent( + f""" + CREATE OR REPLACE TABLE {table.table_name} + (Country STRING, NumVaccinated INT, AvailableDoses INT) + USING DELTA + TBLPROPERTIES ('delta.enableChangeDataFeed' = false); + """ + ) ) task = SynchronizeDeltaToSnowflakeTask( streaming=True, @@ -466,12 +477,16 @@ def test_merge_query_no_delete(self): pk_columns=["Country"], non_pk_columns=["NumVaccinated", "AvailableDoses"], ) - expected_query = """ - MERGE INTO target_table target - USING tmp_table temp ON target.Country = temp.Country - WHEN MATCHED AND temp._change_type = 'update_postimage' THEN UPDATE SET NumVaccinated = temp.NumVaccinated, AvailableDoses = temp.AvailableDoses - WHEN NOT MATCHED AND temp._change_type != 'delete' THEN INSERT (Country, NumVaccinated, AvailableDoses) VALUES (temp.Country, temp.NumVaccinated, temp.AvailableDoses) - """ + expected_query = dedent( + """ + MERGE INTO target_table target + USING tmp_table temp ON target.Country = temp.Country + WHEN MATCHED AND temp._change_type = 'update_postimage' + THEN UPDATE SET NumVaccinated = temp.NumVaccinated, AvailableDoses = temp.AvailableDoses + WHEN NOT MATCHED AND temp._change_type != 'delete' + THEN INSERT (Country, NumVaccinated, AvailableDoses) + VALUES (temp.Country, temp.NumVaccinated, temp.AvailableDoses)""" + ).strip() assert query == expected_query @@ -483,12 +498,17 @@ def test_merge_query_with_delete(self): non_pk_columns=["NumVaccinated", "AvailableDoses"], enable_deletion=True, ) - expected_query = """ - MERGE INTO target_table target - USING tmp_table temp ON target.Country = temp.Country - WHEN MATCHED AND temp._change_type = 'update_postimage' THEN UPDATE SET NumVaccinated = temp.NumVaccinated, AvailableDoses = temp.AvailableDoses - WHEN NOT MATCHED AND temp._change_type != 'delete' THEN INSERT (Country, NumVaccinated, AvailableDoses) VALUES (temp.Country, temp.NumVaccinated, temp.AvailableDoses) - WHEN MATCHED AND temp._change_type = 'delete' THEN DELETE""" + expected_query = dedent( + """ + MERGE INTO target_table target + USING tmp_table temp ON target.Country = temp.Country + WHEN MATCHED AND temp._change_type = 'update_postimage' + THEN UPDATE SET NumVaccinated = temp.NumVaccinated, AvailableDoses = temp.AvailableDoses + WHEN NOT MATCHED AND temp._change_type != 'delete' + THEN INSERT (Country, NumVaccinated, AvailableDoses) + VALUES (temp.Country, temp.NumVaccinated, temp.AvailableDoses) + WHEN MATCHED AND 
temp._change_type = 'delete' THEN DELETE""" + ).strip() assert query == expected_query diff --git a/tests/spark/integrations/tableau/test_hyper.py b/tests/spark/integrations/tableau/test_hyper.py index 691e45a9..d57cd971 100644 --- a/tests/spark/integrations/tableau/test_hyper.py +++ b/tests/spark/integrations/tableau/test_hyper.py @@ -2,6 +2,7 @@ from pathlib import Path, PurePath import pytest + from pyspark.sql.functions import lit from koheesio.integrations.spark.tableau.hyper import ( diff --git a/tests/spark/readers/test_delta_reader.py b/tests/spark/readers/test_delta_reader.py index bf0ec4f1..ab1c6b2e 100644 --- a/tests/spark/readers/test_delta_reader.py +++ b/tests/spark/readers/test_delta_reader.py @@ -1,9 +1,8 @@ import pytest from pyspark.sql import functions as F -from pyspark.sql.dataframe import DataFrame -from koheesio.spark import AnalysisException +from koheesio.spark import AnalysisException, DataFrame from koheesio.spark.readers.delta import DeltaTableReader pytestmark = pytest.mark.spark @@ -61,8 +60,11 @@ def test_delta_table_cdf_reader(spark, streaming_dummy_df, random_uuid): def test_delta_reader_view(spark): reader = DeltaTableReader(table="delta_test_table") + with pytest.raises(AnalysisException): _ = spark.table(reader.view) + # In Spark remote session the above statetment will not raise an exception + _ = spark.table(reader.view).take(1) reader.read() df = spark.table(reader.view) assert df.count() == 10 diff --git a/tests/spark/readers/test_hana.py b/tests/spark/readers/test_hana.py index 2b226db8..35c603ad 100644 --- a/tests/spark/readers/test_hana.py +++ b/tests/spark/readers/test_hana.py @@ -24,11 +24,3 @@ def test_get_options(self): assert o["driver"] == "com.sap.db.jdbc.Driver" assert o["fetchsize"] == 2000 assert o["numPartitions"] == 10 - - def test_execute(self, dummy_spark): - """Method should be callable from parent class""" - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - hana = HanaReader(**self.common_options) - assert hana.execute().df.count() == 1 diff --git a/tests/spark/readers/test_jdbc.py b/tests/spark/readers/test_jdbc.py index 1c50c2d2..b75c2ce1 100644 --- a/tests/spark/readers/test_jdbc.py +++ b/tests/spark/readers/test_jdbc.py @@ -55,22 +55,18 @@ def test_execute_wo_dbtable_and_query(self): assert e.type is ValueError def test_execute_w_dbtable_and_query(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark + """query should take precedence over dbtable""" + jr = JdbcReader(**self.common_options, dbtable="foo", query="bar") + jr.execute() - jr = JdbcReader(**self.common_options, dbtable="foo", query="bar") - jr.execute() - - assert jr.df.count() == 1 - assert mock_spark.return_value.options_dict["query"] == "bar" - assert "dbtable" not in mock_spark.return_value.options_dict + assert jr.df.count() == 3 + assert dummy_spark.options_dict["query"] == "bar" + assert dummy_spark.options_dict.get("dbtable") is None def test_execute_w_dbtable(self, dummy_spark): - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - jr = JdbcReader(**self.common_options, dbtable="foo") - jr.execute() + """check that dbtable is passed to the reader correctly""" + jr = JdbcReader(**self.common_options, dbtable="foo") + jr.execute() - assert jr.df.count() == 1 - assert mock_spark.return_value.options_dict["dbtable"] == "foo" + assert jr.df.count() == 3 + 
assert dummy_spark.options_dict["dbtable"] == "foo" diff --git a/tests/spark/readers/test_memory.py b/tests/spark/readers/test_memory.py index 1cf949e2..21b5d53f 100644 --- a/tests/spark/readers/test_memory.py +++ b/tests/spark/readers/test_memory.py @@ -1,6 +1,5 @@ import pytest from chispa import assert_df_equality - from pyspark.sql.types import StructType from koheesio.spark.readers.memory import DataFormat, InMemoryDataReader @@ -14,10 +13,10 @@ class TestInMemoryDataReader: "data,format,params,expect_filter", [ pytest.param( - "id,string\n1,hello,\n2,world", DataFormat.CSV, {"header": True}, "id < 3" + "id,string\n1,hello\n2,world", DataFormat.CSV, {"header":True}, "id < 3" ), pytest.param( - b"id,string\n1,hello,\n2,world", DataFormat.CSV, {"header": True}, "id < 3" + b"id,string\n1,hello\n2,world", DataFormat.CSV, {"header":0}, "id < 3" ), pytest.param( '{"id": 1, "string": "hello"}', DataFormat.JSON, {}, "id < 2" diff --git a/tests/spark/readers/test_metastore_reader.py b/tests/spark/readers/test_metastore_reader.py index 3e3d2944..4af75ea9 100644 --- a/tests/spark/readers/test_metastore_reader.py +++ b/tests/spark/readers/test_metastore_reader.py @@ -1,7 +1,6 @@ import pytest -from pyspark.sql.dataframe import DataFrame - +from koheesio.spark import DataFrame from koheesio.spark.readers.metastore import MetastoreReader pytestmark = pytest.mark.spark diff --git a/tests/spark/readers/test_rest_api.py b/tests/spark/readers/test_rest_api.py index 328c8549..9c22ea3f 100644 --- a/tests/spark/readers/test_rest_api.py +++ b/tests/spark/readers/test_rest_api.py @@ -8,7 +8,7 @@ from koheesio.asyncio.http import AsyncHttpStep from koheesio.spark.readers.rest_api import AsyncHttpGetStep, RestApiReader -from koheesio.steps.http import PaginatedHtppGetStep +from koheesio.steps.http import PaginatedHttpGetStep ASYNC_BASE_URL = "http://httpbin.org" ASYNC_GET_ENDPOINT = URL(f"{ASYNC_BASE_URL}/get") @@ -27,10 +27,10 @@ def mock_paginated_api(): def test_paginated_api(mock_paginated_api): # Test that the paginated API returns all the data - transport = PaginatedHtppGetStep(url="https://api.example.com/data?page={page}", paginate=True, pages=3) + transport = PaginatedHttpGetStep(url="https://api.example.com/data?page={page}", paginate=True, pages=3) task = RestApiReader(transport=transport, spark_schema="id: int, page:int, value: string") - assert isinstance(task.transport, PaginatedHtppGetStep) + assert isinstance(task.transport, PaginatedHttpGetStep) task.execute() diff --git a/tests/spark/readers/test_teradata.py b/tests/spark/readers/test_teradata.py index 3ba2a1ad..8ac74aa7 100644 --- a/tests/spark/readers/test_teradata.py +++ b/tests/spark/readers/test_teradata.py @@ -1,9 +1,5 @@ -from unittest import mock - import pytest -from pyspark.sql import SparkSession - from koheesio.spark.readers.teradata import TeradataReader pytestmark = pytest.mark.spark @@ -27,8 +23,5 @@ def test_get_options(self): def test_execute(self, dummy_spark): """Method should be callable from parent class""" - with mock.patch.object(SparkSession, "getActiveSession") as mock_spark: - mock_spark.return_value = dummy_spark - - tr = TeradataReader(**self.common_options) - assert tr.execute().df.count() == 1 + tr = TeradataReader(**self.common_options) + assert tr.execute().df.count() == 3 diff --git a/tests/spark/tasks/test_etl_task.py b/tests/spark/tasks/test_etl_task.py index 3025d25d..be5f5a2f 100644 --- a/tests/spark/tasks/test_etl_task.py +++ b/tests/spark/tasks/test_etl_task.py @@ -1,5 +1,4 @@ import pytest -from 
conftest import await_job_completion from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col, lit @@ -11,6 +10,7 @@ from koheesio.spark.readers.dummy import DummyReader from koheesio.spark.transformations.sql_transform import SqlTransform from koheesio.spark.transformations.transform import Transform +from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter from koheesio.spark.writers.dummy import DummyWriter @@ -70,27 +70,39 @@ def test_delta_task(spark): def test_delta_stream_task(spark, checkpoint_folder): + from koheesio.spark.utils.connect import is_remote_session + delta_table = DeltaTableStep(table="delta_stream_table") DummyReader(range=5).read().write.format("delta").mode("append").saveAsTable("delta_stream_table") + writer = DeltaTableStreamWriter(table="delta_stream_table_out", checkpoint_location=checkpoint_folder) + transformations = [ + SqlTransform( + sql="SELECT ${field} FROM ${table_name} WHERE id = 0", + table_name="temp_view", + field="id", + ), + Transform(dummy_function2, name="pari"), + ] delta_task = EtlTask( source=DeltaTableStreamReader(table=delta_table), - target=DeltaTableStreamWriter(table="delta_stream_table_out", checkpoint_location=checkpoint_folder), - transformations=[ - SqlTransform( - sql="SELECT ${field} FROM ${table_name} WHERE id = 0", table_name="temp_view", params={"field": "id"} - ), - Transform(dummy_function2, name="pari"), - ], + target=writer, + transformations=transformations, ) - delta_task.run() - await_job_completion(timeout=20) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(RuntimeError) as excinfo: + delta_task.run() - out_df = spark.table("delta_stream_table_out") - actual = out_df.head().asDict() - expected = {"id": 0, "name": "pari"} - assert actual == expected + assert "https://issues.apache.org/jira/browse/SPARK-45957" in str(excinfo.value.args[0]) + else: + delta_task.run() + writer.streaming_query.awaitTermination(timeout=20) # type: ignore + + out_df = spark.table("delta_stream_table_out") + actual = out_df.head().asDict() + expected = {"id": 0, "name": "pari"} + assert actual == expected def test_transformations_alias(spark: SparkSession) -> None: diff --git a/tests/spark/test_spark.py b/tests/spark/test_spark.py index 24c77ecb..e19b3e02 100644 --- a/tests/spark/test_spark.py +++ b/tests/spark/test_spark.py @@ -14,7 +14,8 @@ from pyspark.sql import SparkSession from koheesio.models import SecretStr -from koheesio.spark import SparkStep +from koheesio.spark import DataFrame, SparkStep +from koheesio.spark.transformations.transform import Transform pytestmark = pytest.mark.spark @@ -49,3 +50,13 @@ def test_spark_property_without_session(self): spark = SparkSession.builder.appName("pytest-pyspark-local-testing-implicit").master("local[*]").getOrCreate() step = SparkStep() assert step.spark is spark + + def test_transformation(self): + from pyspark.sql import functions as F + + def dummy_function(df: DataFrame): + return df.withColumn("hello", F.lit("world")) + + test_transformation = Transform(dummy_function) + + assert test_transformation diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py index 6455bea6..cbd83bad 100644 --- a/tests/spark/test_spark_utils.py +++ b/tests/spark/test_spark_utils.py @@ -6,9 +6,11 @@ from pyspark.sql.types import StringType, StructField, StructType from koheesio.spark.utils import ( + get_column_name, 
import_pandas_based_on_pyspark_version, on_databricks, schema_struct_to_schema_str, + show_string, ) @@ -43,7 +45,7 @@ def test_on_databricks(env_var_value, expected_result): ) def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): with ( - patch("koheesio.spark.utils.get_spark_minor_version", return_value=spark_version), + patch("koheesio.spark.utils.common.get_spark_minor_version", return_value=spark_version), patch("pandas.__version__", new=pandas_version), ): if expected_error: @@ -51,3 +53,16 @@ def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, e import_pandas_based_on_pyspark_version() else: import_pandas_based_on_pyspark_version() # This should not raise an error + + +def test_show_string(dummy_df): + actual = show_string(dummy_df, n=1, truncate=1, vertical=False) + assert actual == "+---+\n| id|\n+---+\n| 0|\n+---+\n" + + +def test_column_name(): + from pyspark.sql.functions import col + + name = "my_column" + column = col(name) + assert get_column_name(column) == name diff --git a/tests/spark/transformations/date_time/test_interval.py b/tests/spark/transformations/date_time/test_interval.py index 016dcbfb..71208da0 100644 --- a/tests/spark/transformations/date_time/test_interval.py +++ b/tests/spark/transformations/date_time/test_interval.py @@ -12,6 +12,7 @@ adjust_time, col, dt_column, + validate_interval, ) pytestmark = pytest.mark.spark @@ -107,13 +108,12 @@ def test_interval(input_data, column_name, operation, interval, expected, spark) df = spark.createDataFrame([(input_data,)], [column_name]) column = col(column_name) - - print(f"{df.dtypes = }") + column = DateTimeColumn.from_column(column) if operation == "-": - df_adjusted = df.withColumn("adjusted", DateTimeColumn.from_column(column) - interval) + df_adjusted = df.withColumn("adjusted", column - interval) elif operation == "+": - df_adjusted = df.withColumn("adjusted", DateTimeColumn.from_column(column) + interval) + df_adjusted = df.withColumn("adjusted", column + interval) else: raise RuntimeError(f"Invalid operation: {operation}") @@ -122,6 +122,8 @@ def test_interval(input_data, column_name, operation, interval, expected, spark) def test_interval_unhappy(spark): + with pytest.raises(ValueError): + validate_interval("some random sym*bol*s") # invalid operation with pytest.raises(ValueError): _ = adjust_time(col("some_col"), "invalid operation", "1 day") diff --git a/tests/spark/transformations/strings/test_change_case.py b/tests/spark/transformations/strings/test_change_case.py index 69c422f7..3750e79d 100644 --- a/tests/spark/transformations/strings/test_change_case.py +++ b/tests/spark/transformations/strings/test_change_case.py @@ -11,6 +11,7 @@ TitleCase, UpperCase, ) +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -76,7 +77,7 @@ def test_happy_flow(input_values, input_data, input_schema, expected, spark): target_column = change_case.target_column # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row[target_column] for row in output_df.select(target_column).collect()] assert actual == expected[kls.__name__] diff --git a/tests/spark/transformations/strings/test_concat.py b/tests/spark/transformations/strings/test_concat.py index 90eacdf9..c6af459e 100644 --- a/tests/spark/transformations/strings/test_concat.py +++ b/tests/spark/transformations/strings/test_concat.py @@ 
-4,6 +4,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.concat import Concat +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -166,7 +167,8 @@ def test_happy_flow(input_values, input_data, input_schema, expected, spark): output_df = concat.transform(input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") + actual = [row[target_column] for row in output_df.select(target_column).collect()] assert actual == expected diff --git a/tests/spark/transformations/strings/test_pad.py b/tests/spark/transformations/strings/test_pad.py index 05c13aa7..a8da5e92 100644 --- a/tests/spark/transformations/strings/test_pad.py +++ b/tests/spark/transformations/strings/test_pad.py @@ -7,6 +7,7 @@ from koheesio.logger import LoggingFactory from koheesio.models import ValidationError from koheesio.spark.transformations.strings.pad import LPad, Pad, RPad +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -72,7 +73,7 @@ def test_happy_flow(input_values, expected, spark): target_column = trim.target_column # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row[target_column] for row in output_df.select(target_column).collect()] assert actual == expected[kls.__name__] diff --git a/tests/spark/transformations/strings/test_regexp.py b/tests/spark/transformations/strings/test_regexp.py index 112c6257..a4e95c76 100644 --- a/tests/spark/transformations/strings/test_regexp.py +++ b/tests/spark/transformations/strings/test_regexp.py @@ -6,6 +6,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.regexp import RegexpExtract, RegexpReplace +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -64,7 +65,7 @@ def test_regexp_extract(input_values, expected, spark): output_df = RegexpExtract(**input_values).transform(input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row.asDict() for row in output_df.collect()] assert actual == expected @@ -122,7 +123,7 @@ def test_regexp_replace(input_values, expected, spark): output_df = regexp_replace.transform(input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row.asDict()[target_column] for row in output_df.collect()] assert actual == expected diff --git a/tests/spark/transformations/strings/test_split.py b/tests/spark/transformations/strings/test_split.py index 0d858bad..1bd4e8e0 100644 --- a/tests/spark/transformations/strings/test_split.py +++ b/tests/spark/transformations/strings/test_split.py @@ -6,6 +6,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.split import SplitAll, SplitAtFirstMatch +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -83,7 +84,7 @@ def test_split_all(input_values, data, schema, expected, spark): output_df = split_all.transform(df=input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, 
False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row.asDict()[filter_column] for row in output_df.collect()] assert actual == expected @@ -165,7 +166,7 @@ def test_split_at_first_match(input_values, data, schema, expected, spark): output_df = split_at_first_match.transform(df=input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row.asDict()[filter_column] for row in output_df.collect()] assert actual == expected diff --git a/tests/spark/transformations/strings/test_string_replace.py b/tests/spark/transformations/strings/test_string_replace.py index 22081dc1..fb9b8534 100644 --- a/tests/spark/transformations/strings/test_string_replace.py +++ b/tests/spark/transformations/strings/test_string_replace.py @@ -6,6 +6,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.replace import Replace +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -49,7 +50,8 @@ def test_happy_flow(input_values, expected, spark): output_df = replace.transform(input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") + actual = [row.asDict()[target_column] for row in output_df.collect()] assert actual == expected diff --git a/tests/spark/transformations/strings/test_substring.py b/tests/spark/transformations/strings/test_substring.py index 29c43018..b31b14ed 100644 --- a/tests/spark/transformations/strings/test_substring.py +++ b/tests/spark/transformations/strings/test_substring.py @@ -6,6 +6,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.substring import Substring +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -68,7 +69,7 @@ def test_substring(input_values, data, schema, expected, spark): output_df = substring.transform(input_df) # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") if target_column := substring.target_column: actual = [row.asDict()[target_column] for row in output_df.collect()] diff --git a/tests/spark/transformations/strings/test_trim.py b/tests/spark/transformations/strings/test_trim.py index 2252aabc..63a63109 100644 --- a/tests/spark/transformations/strings/test_trim.py +++ b/tests/spark/transformations/strings/test_trim.py @@ -8,6 +8,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.strings.trim import LTrim, RTrim, Trim +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -49,7 +50,7 @@ def test_happy_flow(input_values, input_data, input_schema, expected, spark): target_column = trim.target_column # log equivalent of doing df.show() - log.info(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row[target_column] for row in output_df.select(target_column).collect()] assert actual == expected[kls.__name__] diff --git a/tests/spark/transformations/test_arrays.py b/tests/spark/transformations/test_arrays.py index 2e89faed..49efc8c9 100644 --- a/tests/spark/transformations/test_arrays.py +++ b/tests/spark/transformations/test_arrays.py @@ 
-340,7 +340,7 @@ def test_array(kls, column, expected_data, params, spark): # noinspection PyCallingNonCallable df = kls(df=test_data, column=column, **params).transform() - actual_data = df.select(column).rdd.flatMap(lambda x: x).collect() + actual_data = [row.asDict()[column] for row in df.select(column).collect()] def check_result(_actual_data: list, _expected_data: list): _data = _expected_data or _actual_data diff --git a/tests/spark/transformations/test_cast_to_datatype.py b/tests/spark/transformations/test_cast_to_datatype.py index b0d9bf13..a0fc628c 100644 --- a/tests/spark/transformations/test_cast_to_datatype.py +++ b/tests/spark/transformations/test_cast_to_datatype.py @@ -9,10 +9,10 @@ from pydantic import ValidationError -from pyspark.sql import DataFrame from pyspark.sql import functions as f from koheesio.logger import LoggingFactory +from koheesio.spark import DataFrame from koheesio.spark.transformations.cast_to_datatype import ( CastToBinary, CastToBoolean, @@ -27,7 +27,7 @@ CastToString, CastToTimestamp, ) -from koheesio.spark.utils import SparkDatatype +from koheesio.spark.utils import SparkDatatype, show_string pytestmark = pytest.mark.spark @@ -165,7 +165,7 @@ def test_happy_flow(input_values, expected, df_with_all_types: DataFrame): target_column = cast_to_datatype.target_column # log equivalent of doing df.show() - log.error(f"show output_df: \n{output_df.select(col, target_column)._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df.select(col, target_column), 20, 20, False)}") actual = [row[target_column] for row in output_df.select(target_column).collect()][0] assert actual == expected @@ -385,7 +385,7 @@ def test_cast_to_specific_type(klass, expected, df_with_all_types): actual = output_df.head().asDict() # log equivalent of doing df.show() - log.error(f"show actual: \n{output_df._jdf.showString(20, 20, False)}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") assert target_columns == list(expected.keys()) assert actual == expected @@ -430,13 +430,11 @@ def test_decimal_precision_and_scale(precision, scale, alternative_value, expect .select("c1", "c2") ) - # log equivalent of doing df.show() and df.printSchema() - log.error(f"show input_df: \n{input_df._jdf.showString(20, 20, False)}") - log.error(f"printSchema input_df: \n{input_df._jdf.schema().treeString()}") + # log equivalent of doing df.show() + log.info(f"show output_df: \n{show_string(input_df, 20, 20, False)}") output_df = CastToDecimal(columns=["c1", "c2"], scale=scale, precision=precision).transform(input_df) - log.error(f"show output_df: \n{output_df._jdf.showString(20, 20, False)}") - log.error(f"printSchema output_df: \n{output_df._jdf.schema().treeString()}") + log.info(f"show output_df: \n{show_string(output_df, 20, 20, False)}") actual = [row.asDict() for row in output_df.collect()] assert actual == expected diff --git a/tests/spark/transformations/test_get_item.py b/tests/spark/transformations/test_get_item.py index 8401680c..54f70360 100644 --- a/tests/spark/transformations/test_get_item.py +++ b/tests/spark/transformations/test_get_item.py @@ -55,7 +55,10 @@ def test_transform_get_item(input_values, input_data, input_schema, expected, sp input_df = spark.createDataFrame(data=input_data, schema=input_schema) gi = GetItem(**input_values) output_df = gi.transform(input_df) - actual = output_df.orderBy(input_schema[0]).select(gi.target_column).rdd.map(lambda r: r[0]).collect() + target_column = gi.target_column + actual = [ + 
row.asDict()[target_column] for row in output_df.orderBy(input_schema[0]).select(gi.target_column).collect() + ] assert actual == expected diff --git a/tests/spark/transformations/test_repartition.py b/tests/spark/transformations/test_repartition.py index 8b5f6a02..9a1ee5db 100644 --- a/tests/spark/transformations/test_repartition.py +++ b/tests/spark/transformations/test_repartition.py @@ -1,5 +1,7 @@ import pytest +from pyspark.sql import DataFrame + from koheesio.models import ValidationError from koheesio.spark.transformations.repartition import Repartition @@ -53,9 +55,9 @@ def test_repartition(input_values, expected, spark): ], schema="product string, amount int, country string", ) - df = Repartition(**input_values).transform(input_df) - assert df.rdd.getNumPartitions() == expected + if isinstance(input_df, DataFrame): + assert df.rdd.getNumPartitions() == expected def test_repartition_should_raise_error(): diff --git a/tests/spark/transformations/test_replace.py b/tests/spark/transformations/test_replace.py index 202f0933..207242cc 100644 --- a/tests/spark/transformations/test_replace.py +++ b/tests/spark/transformations/test_replace.py @@ -2,6 +2,7 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.replace import Replace +from koheesio.spark.utils import show_string pytestmark = pytest.mark.spark @@ -98,13 +99,13 @@ def test_all_data_types(input_values, df_with_all_types): input_values["to_value"] = input_values.get("to_value", "happy") expected = input_values["to_value"] df = Replace(**input_values).transform(df_with_all_types) - log.error(f"show df: \n{df._jdf.showString(20, 20, False)}") + log.info(f"show df: \n{show_string(df,20, 20, False)}") actual = df.head().asDict()[column] assert actual == expected else: input_values["to_value"] = "unhappy" expected = df_with_all_types.head().asDict()[column] # stay the same df = Replace(**input_values).transform(df_with_all_types) - log.error(f"show df: \n{df._jdf.showString(20, 20, False)}") + log.info(f"show df: \n{show_string(df,20, 20, False)}") actual = df.head().asDict()[column] assert actual == expected diff --git a/tests/spark/transformations/test_row_number_dedup.py b/tests/spark/transformations/test_row_number_dedup.py index 7949a9d4..ba521a5d 100644 --- a/tests/spark/transformations/test_row_number_dedup.py +++ b/tests/spark/transformations/test_row_number_dedup.py @@ -10,7 +10,7 @@ pytestmark = pytest.mark.spark -@pytest.mark.parametrize("target_column", ["col_row_nuber"]) +@pytest.mark.parametrize("target_column", ["col_row_number"]) def test_row_number_dedup(spark: SparkSession, target_column: str) -> None: df = spark.createDataFrame( [ @@ -48,7 +48,7 @@ def test_row_number_dedup(spark: SparkSession, target_column: str) -> None: } -@pytest.mark.parametrize("target_column", ["col_row_nuber"]) +@pytest.mark.parametrize("target_column", ["col_row_number"]) def test_row_number_dedup_not_list_column(spark: SparkSession, target_column: str) -> None: df = spark.createDataFrame( [ @@ -88,7 +88,7 @@ def test_row_number_dedup_not_list_column(spark: SparkSession, target_column: st } -@pytest.mark.parametrize("target_column", ["col_row_nuber"]) +@pytest.mark.parametrize("target_column", ["col_row_number"]) def test_row_number_dedup_with_columns(spark: SparkSession, target_column: str) -> None: df = spark.createDataFrame( [ @@ -128,7 +128,7 @@ def test_row_number_dedup_with_columns(spark: SparkSession, target_column: str) } -@pytest.mark.parametrize("target_column", ["col_row_nuber"]) 
+@pytest.mark.parametrize("target_column", ["col_row_number"]) def test_row_number_dedup_with_duplicated_columns(spark: SparkSession, target_column: str) -> None: df = spark.createDataFrame( [ diff --git a/tests/spark/transformations/test_transform.py b/tests/spark/transformations/test_transform.py index c30c4343..1f92e490 100644 --- a/tests/spark/transformations/test_transform.py +++ b/tests/spark/transformations/test_transform.py @@ -2,10 +2,10 @@ import pytest -from pyspark.sql import DataFrame from pyspark.sql import functions as f from koheesio.logger import LoggingFactory +from koheesio.spark import DataFrame from koheesio.spark.transformations.transform import Transform pytestmark = pytest.mark.spark diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 66306de1..92a349c8 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -11,6 +11,7 @@ from koheesio.spark import AnalysisException from koheesio.spark.delta import DeltaTableStep +from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter from koheesio.spark.writers.delta.utils import log_clauses @@ -18,6 +19,8 @@ pytestmark = pytest.mark.spark +skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. Test requires pyspark version >= 4.0" + def test_delta_table_writer(dummy_df, spark): table_name = "test_table" @@ -48,6 +51,11 @@ def test_delta_partitioning(spark, sample_df_to_partition): def test_delta_table_merge_all(spark): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + table_name = "test_merge_all_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "xxxx"}] @@ -86,6 +94,11 @@ def test_delta_table_merge_all(spark): def test_deltatablewriter_with_invalid_conditions(spark, dummy_df): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + table_name = "delta_test_table" merge_builder = ( DeltaTable.forName(sparkSession=spark, tableOrViewName=table_name) @@ -186,7 +199,7 @@ def test_delta_stream_table_writer(streaming_dummy_df, spark, checkpoint_folder) df=streaming_dummy_df, ) delta_writer.write() - await_job_completion(timeout=20, query_id=delta_writer.streaming_query.id) + await_job_completion(spark, timeout=20, query_id=delta_writer.streaming_query.id) df = spark.read.table(table_name) assert df.count() == 10 @@ -271,6 +284,11 @@ def test_delta_with_options(spark): def test_merge_from_args(spark, dummy_df): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + table_name = "test_table_merge_from_args" dummy_df.write.format("delta").saveAsTable(table_name) @@ -330,6 +348,11 @@ def test_merge_from_args_raise_value_error(spark, output_mode_params): def test_merge_no_table(spark): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + table_name = "test_merge_no_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, 
{"id": 2, "value": "expected_merge"}, {"id": 5, "value": "expected_merge"}] diff --git a/tests/spark/writers/delta/test_scd.py b/tests/spark/writers/delta/test_scd.py index 3d91e658..087f957d 100644 --- a/tests/spark/writers/delta/test_scd.py +++ b/tests/spark/writers/delta/test_scd.py @@ -11,14 +11,23 @@ from pyspark.sql import functions as F from pyspark.sql.types import Row -from koheesio.spark import DataFrame, current_timestamp_utc +from koheesio.spark import DataFrame from koheesio.spark.delta import DeltaTableStep +from koheesio.spark.functions import current_timestamp_utc +from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers.delta.scd import SCD2DeltaTableWriter pytestmark = pytest.mark.spark +skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. Test requires pyspark version >= 4.0" + def test_scd2_custom_logic(spark): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + def _get_result(target_df: DataFrame, expr: str): res = ( target_df.where(expr) @@ -249,6 +258,11 @@ def _prepare_merge_builder( def test_scd2_logic(spark): + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + pytest.skip(reason=skip_reason) + changes_data = [ [("key1", "value1", "scd1-value11", "2024-05-01"), ("key2", "value2", "scd1-value21", "2024-04-01")], [("key1", "value1_updated", "scd1-value12", "2024-05-02"), ("key3", "value3", "scd1-value31", "2024-05-03")], diff --git a/tests/spark/writers/test_file_writer.py b/tests/spark/writers/test_file_writer.py index 3f497571..29b3eb2a 100644 --- a/tests/spark/writers/test_file_writer.py +++ b/tests/spark/writers/test_file_writer.py @@ -1,7 +1,6 @@ from pathlib import Path from unittest.mock import MagicMock -from koheesio.spark import DataFrame, SparkSession from koheesio.spark.writers import BatchOutputMode from koheesio.spark.writers.file_writer import FileFormat, FileWriter @@ -20,7 +19,20 @@ def test_execute(dummy_df, mocker): writer = FileWriter(df=dummy_df, output_mode=output_mode, path=path, format=format, **options) mock_df_writer = MagicMock() - mocker.patch.object(DataFrame, "write", mock_df_writer) + + from koheesio.spark.utils.connect import is_remote_session + + if is_remote_session(): + from pyspark.sql import DataFrame as SparkDataFrame + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + + mocker.patch.object(SparkDataFrame, "write", mock_df_writer) + mocker.patch.object(ConnectDataFrame, "write", mock_df_writer) + else: + from pyspark.sql import DataFrame + + mocker.patch.object(DataFrame, "write", mock_df_writer) + mock_df_writer.options.return_value = mock_df_writer writer.execute() diff --git a/tests/steps/test_http.py b/tests/steps/test_http.py index f0270385..5ddee35a 100644 --- a/tests/steps/test_http.py +++ b/tests/steps/test_http.py @@ -162,6 +162,7 @@ def test_max_retries(max_retries, endpoint, status_code, expected_count, error_t session = requests.Session() retry_logic = Retry(total=max_retries, status_forcelist=[status_code]) session.mount("https://", HTTPAdapter(max_retries=retry_logic)) + # noinspection HttpUrlsUsage session.mount("http://", HTTPAdapter(max_retries=retry_logic)) step = HttpGetStep(url=endpoint, session=session) @@ -187,6 +188,7 @@ def test_initial_delay_and_backoff(monkeypatch, backoff, expected): session = requests.Session() retry_logic = Retry(total=3, 
backoff_factor=backoff, status_forcelist=[503]) session.mount("https://", HTTPAdapter(max_retries=retry_logic)) + # noinspection HttpUrlsUsage session.mount("http://", HTTPAdapter(max_retries=retry_logic)) step = HttpGetStep( diff --git a/tests/steps/test_steps.py b/tests/steps/test_steps.py index 7ae53541..92c563a7 100644 --- a/tests/steps/test_steps.py +++ b/tests/steps/test_steps.py @@ -11,20 +11,11 @@ from pydantic import ValidationError -from pyspark.sql import DataFrame -from pyspark.sql.functions import lit - from koheesio.models import Field -from koheesio.spark.transformations.transform import Transform from koheesio.steps import Step, StepMetaClass, StepOutput from koheesio.steps.dummy import DummyOutput, DummyStep from koheesio.utils import get_project_root - -def dummy_function(df: DataFrame): - return df.withColumn("hello", lit("world")) - - output_dict_1 = dict(a="foo", b=42) test_output_1 = DummyOutput(**output_dict_1) @@ -35,7 +26,6 @@ def dummy_function(df: DataFrame): # we put the newline in the description to test that the newline is removed test_step = DummyStep(a="foo", b=2, description="Dummy step for testing purposes.\nwith a newline") -test_transformation = Transform(dummy_function) PROJECT_ROOT = get_project_root() diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index ead642f4..fab14393 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,11 +1,3 @@ -import os -from unittest.mock import patch - -import pytest - -from pyspark.sql.types import StringType, StructField, StructType - -from koheesio.spark.utils import on_databricks, schema_struct_to_schema_str from koheesio.utils import get_args_for_func, get_random_string From 90e64629c34185fe276b25d689ab00cf3da94954 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:56:57 +0100 Subject: [PATCH 02/33] refactor: change private attr and step getter (#82) ## Description Adjust Output class and fix getting of attributes from Step classes. ## Related Issue #81 ## Motivation and Context Get correct behavior for class attributes and also raise Exception in proper manner ## How Has This Been Tested? Existing tests ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. 
--------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- pyproject.toml | 128 ++++------ src/koheesio/asyncio/__init__.py | 10 +- src/koheesio/asyncio/http.py | 16 +- src/koheesio/context.py | 4 +- src/koheesio/integrations/box.py | 7 +- .../integrations/snowflake/test_utils.py | 4 +- src/koheesio/integrations/spark/sftp.py | 26 +- src/koheesio/integrations/spark/snowflake.py | 30 ++- .../integrations/spark/tableau/hyper.py | 5 +- .../integrations/spark/tableau/server.py | 4 +- src/koheesio/logger.py | 4 +- src/koheesio/models/__init__.py | 10 +- src/koheesio/models/reader.py | 2 +- src/koheesio/models/sql.py | 2 +- src/koheesio/notifications/slack.py | 2 +- src/koheesio/spark/__init__.py | 4 +- src/koheesio/spark/delta.py | 2 +- src/koheesio/spark/readers/hana.py | 3 +- src/koheesio/spark/readers/memory.py | 5 +- src/koheesio/spark/snowflake.py | 34 +-- .../spark/transformations/__init__.py | 4 +- .../spark/transformations/camel_to_snake.py | 2 +- .../transformations/date_time/interval.py | 4 +- .../spark/transformations/drop_column.py | 2 +- .../spark/transformations/strings/trim.py | 2 +- src/koheesio/spark/transformations/uuid5.py | 2 +- src/koheesio/spark/utils/common.py | 4 +- src/koheesio/spark/writers/buffer.py | 11 +- src/koheesio/spark/writers/delta/batch.py | 16 +- src/koheesio/spark/writers/delta/scd.py | 3 +- src/koheesio/spark/writers/delta/stream.py | 2 +- src/koheesio/spark/writers/delta/utils.py | 54 +++- src/koheesio/spark/writers/dummy.py | 1 + src/koheesio/steps/__init__.py | 39 +-- src/koheesio/steps/http.py | 2 +- src/koheesio/utils.py | 6 +- tests/asyncio/test_asyncio_http.py | 2 +- tests/conftest.py | 2 +- tests/core/test_logger.py | 2 +- tests/models/test_models.py | 2 +- tests/snowflake/test_snowflake.py | 3 +- tests/spark/conftest.py | 8 +- .../integrations/snowflake/test_sync_task.py | 3 +- tests/spark/readers/test_auto_loader.py | 2 +- tests/spark/readers/test_memory.py | 3 +- tests/spark/readers/test_rest_api.py | 4 +- tests/spark/test_delta.py | 2 +- .../spark/writers/delta/test_delta_writer.py | 125 ++++++---- tests/spark/writers/delta/test_scd.py | 235 ++++++++++-------- tests/spark/writers/test_buffer.py | 2 +- tests/spark/writers/test_sftp.py | 2 +- tests/sso/test_okta.py | 2 +- tests/steps/test_steps.py | 4 +- 53 files changed, 473 insertions(+), 386 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3edd7eca..493eb783 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,17 +70,7 @@ tableau = ["tableauhyperapi>=0.0.19484", "tableauserverclient>=0.25"] # Snowflake dependencies snowflake = ["snowflake-connector-python>=3.12.0"] # Development dependencies -dev = [ - "black", - "isort", - "ruff", - "mypy", - "pylint", - "colorama", - "types-PyYAML", - "types-requests", - -] +dev = ["ruff", "mypy", "pylint", "colorama", "types-PyYAML", "types-requests"] test = [ "chispa", "coverage[toml]", @@ -153,23 +143,19 @@ Run `hatch run` to run scripts in the default environment. # Code Quality To check and format the codebase, we use: - - `black` for code formatting - - `isort` for import sorting (includes colorama for colored output) - - `ruff` for linting. + - `ruff` for linting, formtting and sorting imports - `mypy` for static type checking. - `pylint` for code quality checks. --- There are several ways to run style checks and formatting: - `hatch run black-check` will check the codebase with black without applying fixes. - `hatch run black-fmt` will format the codebase using black. 
-- `hatch run isort-check` will check the codebase with isort without applying fixes. -- `hatch run isort-fmt` will format the codebase using isort. - `hatch run ruff-check` will check the codebase with ruff without applying fixes. - `hatch run ruff-fmt` will format the codebase using ruff. - `hatch run mypy-check` will check the codebase with mypy. - `hatch run pylint-check` will check the codebase with pylint. - `hatch run check` will run all the above checks (including pylint and mypy). -- `hatch run fmt` or `hatch run fix` will format the codebase using black, isort, and ruff. +- `hatch run fmt` or `hatch run fix` will format the codebase using ruff. - `hatch run lint` will run ruff, mypy, and pylint. # Testing and Coverage @@ -207,22 +193,14 @@ features = [ # TODO: add scripts section based on Makefile # TODO: add bandit # Code Quality commands -black-check = "black --check --diff ." -black-fmt = "black ." -isort-check = "isort . --check --diff --color" -isort-fmt = "isort ." -ruff-check = "ruff check ." -ruff-fmt = "ruff check . --fix" +ruff-fmt = "ruff format --check --diff ." +ruff-fmt-fix = "ruff format ." +ruff-check = "ruff check . --diff" +ruff-check-fix = "ruff check . --fix" mypy-check = "mypy src" pylint-check = "pylint --output-format=colorized -d W0511 src" -check = [ - "- black-check", - "- isort-check", - "- ruff-check", - "- mypy-check", - "- pylint-check", -] -fmt = ["black-fmt", "isort-fmt", "ruff-fmt"] +check = ["- ruff-fmt", "- ruff-check", "- mypy-check", "- pylint-check"] +fmt = ["ruff-fmt-fix", "ruff-check-fix"] fix = "fmt" lint = ["- ruff-fmt", "- mypy-check", "pylint-check"] log-versions = "python --version && {env:HATCH_UV} pip freeze | grep pyspark" @@ -353,6 +331,7 @@ filterwarnings = [ "ignore:'PYARROW_IGNORE_TIMEZONE'.*:UserWarning:pyspark.pandas.*", # pydantic warnings "ignore:A custom validator is returning a value other than `self`.*.*:UserWarning:pydantic.main.*", + "ignore: 79 characters)' -> let Black handle this instead @@ -549,7 +494,6 @@ ignore = [ ] # Unlike Flake8, default to a complexity level of 10. mccabe.max-complexity = 10 - # Allow autofix for all enabled rules (when `--fix` is provided). fixable = [ "A", @@ -602,6 +546,22 @@ unfixable = [] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +[tool.ruff.lint.isort] +force-to-top = ["__future__", "typing"] +section-order = [ + "future", + "standard-library", + "third-party", + "pydantic", + "pyspark", + "first-party", + "local-folder", +] +sections.pydantic = ["pydantic"] +sections.pyspark = ["pyspark"] +detect-same-package = true +force-sort-within-sections = true + [tool.mypy] python_version = "3.10" files = ["koheesio/**/*.py"] diff --git a/src/koheesio/asyncio/__init__.py b/src/koheesio/asyncio/__init__.py index 093c4a09..3dc63fea 100644 --- a/src/koheesio/asyncio/__init__.py +++ b/src/koheesio/asyncio/__init__.py @@ -2,10 +2,12 @@ This module provides classes for asynchronous steps in the koheesio package. 
""" -from typing import Dict, Union +from typing import Dict, Optional, Union from abc import ABC from asyncio import iscoroutine +from pydantic import PrivateAttr + from koheesio.steps import Step, StepMetaClass, StepOutput @@ -65,7 +67,9 @@ def merge(self, other: Union[Dict, StepOutput]) -> "AsyncStepOutput": -------- ```python step_output = StepOutput(foo="bar") - step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge( + {"lorem": "ipsum"} + ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Functionally similar to adding two dicts together; like running `{**dict_a, **dict_b}`. @@ -103,4 +107,4 @@ class Output(AsyncStepOutput): This class represents the output of the asyncio step. It inherits from the AsyncStepOutput class. """ - __output__: Output + _output: Optional[Output] = PrivateAttr(default=None) diff --git a/src/koheesio/asyncio/http.py b/src/koheesio/asyncio/http.py index ece14f15..14e4ad52 100644 --- a/src/koheesio/asyncio/http.py +++ b/src/koheesio/asyncio/http.py @@ -4,14 +4,14 @@ from __future__ import annotations +from typing import Any, Dict, List, Optional, Tuple, Union import asyncio import warnings -from typing import Any, Dict, List, Optional, Tuple, Union -import nest_asyncio # type: ignore[import-untyped] -import yarl from aiohttp import BaseConnector, ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry, RetryClient, RetryOptionsBase +import nest_asyncio # type: ignore[import-untyped] +import yarl from pydantic import Field, SecretStr, field_validator, model_validator @@ -54,26 +54,28 @@ class AsyncHttpStep(AsyncStep, ExtraParamsMixin): from yarl import URL from typing import Dict, Any, Union, List, Tuple + # Initialize the AsyncHttpStep async def main(): session = ClientSession() - urls = [URL('https://example.com/api/1'), URL('https://example.com/api/2')] + urls = [URL("https://example.com/api/1"), URL("https://example.com/api/2")] retry_options = ExponentialRetry() connector = TCPConnector(limit=10) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} step = AsyncHttpStep( client_session=session, url=urls, retry_options=retry_options, connector=connector, - headers=headers + headers=headers, ) # Execute the step - responses_urls= await step.get() + responses_urls = await step.get() return responses_urls + # Run the main function responses_urls = asyncio.run(main()) ``` diff --git a/src/koheesio/context.py b/src/koheesio/context.py index e0b818a4..6fe725c2 100644 --- a/src/koheesio/context.py +++ b/src/koheesio/context.py @@ -13,10 +13,10 @@ from __future__ import annotations -import re +from typing import Any, Dict, Iterator, Union from collections.abc import Mapping from pathlib import Path -from typing import Any, Dict, Iterator, Union +import re import jsonpickle # type: ignore[import-untyped] import tomli diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index cd5baab9..114fdc04 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -10,17 +10,16 @@ * Application is authorized for the enterprise (Developer Portal - MyApp - Authorization) """ -import datetime -import re from typing import Any, Dict, Optional, Union from abc import ABC from io import BytesIO, StringIO from pathlib import PurePath +import re -import pandas as pd from boxsdk import Client, JWTAuth from boxsdk.object.file import File from boxsdk.object.folder import Folder +import pandas 
as pd from pyspark.sql.functions import expr, lit from pyspark.sql.types import StructType @@ -475,7 +474,7 @@ def execute(self) -> BoxReaderBase.Output: if len(files) > 0: self.log.info( - f"A total of {len(files)} files, that match the mask '{self.mask}' has been detected in {self.path}." + f"A total of {len(files)} files, that match the mask '{self.filter}' has been detected in {self.path}." f" They will be loaded into Spark Dataframe: {files}" ) else: diff --git a/src/koheesio/integrations/snowflake/test_utils.py b/src/koheesio/integrations/snowflake/test_utils.py index 8b85e97d..8ae9ac3e 100644 --- a/src/koheesio/integrations/snowflake/test_utils.py +++ b/src/koheesio/integrations/snowflake/test_utils.py @@ -25,7 +25,9 @@ def test_execute(self, mock_query): mock_query.expected_data = [("row1",), ("row2",)] # Act - instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query=query, account="42") + instance = SnowflakeRunQueryPython( + **COMMON_OPTIONS, query=query, account="42" + ) instance.execute() # Assert diff --git a/src/koheesio/integrations/spark/sftp.py b/src/koheesio/integrations/spark/sftp.py index 90812b8d..d983913f 100644 --- a/src/koheesio/integrations/spark/sftp.py +++ b/src/koheesio/integrations/spark/sftp.py @@ -12,15 +12,17 @@ For more details on each mode, see the docstring of the SFTPWriteMode enum. """ -import hashlib -import time from typing import Optional, Union from enum import Enum +import hashlib from pathlib import Path +import time from paramiko.sftp_client import SFTPClient from paramiko.transport import Transport +from pydantic import PrivateAttr + from koheesio.models import ( Field, InstanceOf, @@ -152,8 +154,8 @@ class SFTPWriter(Writer): ) # private attrs - __client__: SFTPClient - __transport__: Transport + _client: Optional[SFTPClient] = PrivateAttr(default=None) + _transport: Optional[Transport] = PrivateAttr(default=None) @model_validator(mode="before") def validate_path_and_file_name(cls, data: dict) -> dict: @@ -203,26 +205,26 @@ def transport(self) -> Transport: If the username and password are provided, use them to connect to the SFTP server. """ - if not self.__transport__: - self.__transport__ = Transport((self.host, self.port)) + if not self._transport: + self._transport = Transport((self.host, self.port)) if self.username and self.password: - self.__transport__.connect( + self._transport.connect( username=self.username.get_secret_value(), password=self.password.get_secret_value() ) else: - self.__transport__.connect() - return self.__transport__ + self._transport.connect() + return self._transport @property def client(self) -> SFTPClient: """Return the SFTP client. If it doesn't exist, create it.""" - if not self.__client__: + if not self._client: try: - self.__client__ = SFTPClient.from_transport(self.transport) + self._client = SFTPClient.from_transport(self.transport) except EOFError as e: self.log.error(f"Failed to create SFTP client. 
Transport active: {self.transport.is_active()}") raise e - return self.__client__ + return self._client def _close_client(self) -> None: """Close the SFTP client and transport.""" diff --git a/src/koheesio/integrations/spark/snowflake.py b/src/koheesio/integrations/spark/snowflake.py index 59686d93..8a4ad9a2 100644 --- a/src/koheesio/integrations/spark/snowflake.py +++ b/src/koheesio/integrations/spark/snowflake.py @@ -43,10 +43,10 @@ from __future__ import annotations -import json from typing import Any, Callable, Dict, List, Optional, Set, Union from abc import ABC from copy import deepcopy +import json from textwrap import dedent from pyspark.sql import Window @@ -989,7 +989,7 @@ def extract(self) -> DataFrame: raise RuntimeError( f"Source table {self.source_table.table_name} does not have CDF enabled. " f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " - f"Current properties = {self.source_table_properties}" + f"Current properties = {self.source_table.get_persisted_properties()}" ) df = self.reader.read() @@ -1042,17 +1042,21 @@ class TagSnowflakeQuery(Step, ExtraParamsMixin): ------- #### Using `options` parameter ```python - query_tag = AddQueryTag( - options={"preactions": "ALTER SESSION"}, - task_name="cleanse_task", - pipeline_name="ingestion-pipeline", - etl_date="2022-01-01", - pipeline_execution_time="2022-01-01T00:00:00", - task_execution_time="2022-01-01T01:00:00", - environment="dev", - trace_id="acd4f3f96045", - span_id="546d2d66f6cb", - ).execute().options + query_tag = ( + AddQueryTag( + options={"preactions": "ALTER SESSION"}, + task_name="cleanse_task", + pipeline_name="ingestion-pipeline", + etl_date="2022-01-01", + pipeline_execution_time="2022-01-01T00:00:00", + task_execution_time="2022-01-01T01:00:00", + environment="dev", + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", + ) + .execute() + .options + ) ``` In this example, the query tag pre-action will be added to the Snowflake options. 
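The CDF guard in the extract step above raises a RuntimeError when the source Delta table does not have Change Data Feed enabled, and its message points at TBLPROPERTIES as the fix. A minimal sketch of enabling CDF before running the sync — assuming an existing `spark` session and a hypothetical table name `my_db.source_table` (neither is part of this patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Enable Change Data Feed on the (hypothetical) source table so the CDF check passes.
spark.sql(
    "ALTER TABLE my_db.source_table "
    "SET TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true')"
)

# Confirm the property is persisted, mirroring the get_persisted_properties() check above.
properties = {
    row["key"]: row["value"]
    for row in spark.sql("SHOW TBLPROPERTIES my_db.source_table").collect()
}
assert properties.get("delta.enableChangeDataFeed") == "true"
```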
diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py index 992d9f19..94230d87 100644 --- a/src/koheesio/integrations/spark/tableau/hyper.py +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -1,6 +1,6 @@ -import os from typing import Any, List, Optional, Union from abc import ABC, abstractmethod +import os from pathlib import PurePath from tempfile import TemporaryDirectory @@ -435,7 +435,8 @@ def clean_dataframe(self) -> DataFrame: if d_col.dataType.precision > 18: # noinspection PyUnresolvedReferences _df = _df.withColumn( - d_col.name, col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)) # type: ignore + d_col.name, + col(d_col.name).cast(DecimalType(precision=18, scale=d_col.dataType.scale)), # type: ignore ) if len(decimal_col_names) > 0: _df = _df.na.fill(0.0, decimal_col_names) diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py index 7770f627..14d80a6f 100644 --- a/src/koheesio/integrations/spark/tableau/server.py +++ b/src/koheesio/integrations/spark/tableau/server.py @@ -1,9 +1,8 @@ -import os from typing import Any, ContextManager, Optional, Union from enum import Enum +import os from pathlib import PurePath -import urllib3 # type: ignore from tableauserverclient import ( DatasourceItem, PersonalAccessTokenAuth, @@ -12,6 +11,7 @@ ) from tableauserverclient.server.pager import Pager from tableauserverclient.server.server import Server +import urllib3 # type: ignore from pydantic import Field, SecretStr diff --git a/src/koheesio/logger.py b/src/koheesio/logger.py index cad22138..82474b61 100644 --- a/src/koheesio/logger.py +++ b/src/koheesio/logger.py @@ -29,12 +29,12 @@ from __future__ import annotations +from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar import inspect import logging +from logging import Formatter, Logger, LogRecord, getLogger import os import sys -from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar -from logging import Formatter, Logger, LogRecord, getLogger from uuid import uuid4 from warnings import warn diff --git a/src/koheesio/models/__init__.py b/src/koheesio/models/__init__.py index dd0a8c8b..a2db492f 100644 --- a/src/koheesio/models/__init__.py +++ b/src/koheesio/models/__init__.py @@ -11,10 +11,10 @@ from __future__ import annotations +from typing import Annotated, Any, Dict, List, Optional, Union from abc import ABC from functools import cached_property from pathlib import Path -from typing import Annotated, Any, Dict, List, Optional, Union # to ensure that koheesio.models is a drop in replacement for pydantic from pydantic import BaseModel as PydanticBaseModel @@ -407,7 +407,9 @@ def __add__(self, other: Union[Dict, BaseModel]) -> BaseModel: ```python step_output_1 = StepOutput(foo="bar") step_output_2 = StepOutput(lorem="ipsum") - (step_output_1 + step_output_2) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} + ( + step_output_1 + step_output_2 + ) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters @@ -531,7 +533,9 @@ def merge(self, other: Union[Dict, BaseModel]) -> BaseModel: -------- ```python step_output = StepOutput(foo="bar") - step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge( + {"lorem": "ipsum"} + ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters diff --git 
a/src/koheesio/models/reader.py b/src/koheesio/models/reader.py index 3f351923..c1227940 100644 --- a/src/koheesio/models/reader.py +++ b/src/koheesio/models/reader.py @@ -2,8 +2,8 @@ Module for the BaseReader class """ -from abc import ABC, abstractmethod from typing import Optional, TypeVar +from abc import ABC, abstractmethod from koheesio import Step diff --git a/src/koheesio/models/sql.py b/src/koheesio/models/sql.py index f19bc967..86ad38e3 100644 --- a/src/koheesio/models/sql.py +++ b/src/koheesio/models/sql.py @@ -1,8 +1,8 @@ """This module contains the base class for SQL steps.""" +from typing import Any, Dict, Optional, Union from abc import ABC from pathlib import Path -from typing import Any, Dict, Optional, Union from koheesio import Step from koheesio.models import ExtraParamsMixin, Field, model_validator diff --git a/src/koheesio/notifications/slack.py b/src/koheesio/notifications/slack.py index 423b2f37..d38655b5 100644 --- a/src/koheesio/notifications/slack.py +++ b/src/koheesio/notifications/slack.py @@ -2,9 +2,9 @@ Classes to ease interaction with Slack """ +from typing import Any, Dict, Optional import datetime import json -from typing import Any, Dict, Optional from textwrap import dedent from koheesio.models import ConfigDict, Field diff --git a/src/koheesio/spark/__init__.py b/src/koheesio/spark/__init__.py index c72cfb0e..2beccf40 100644 --- a/src/koheesio/spark/__init__.py +++ b/src/koheesio/spark/__init__.py @@ -4,9 +4,9 @@ from __future__ import annotations -import warnings -from abc import ABC from typing import Optional +from abc import ABC +import warnings from pydantic import Field diff --git a/src/koheesio/spark/delta.py b/src/koheesio/spark/delta.py index 8d252a68..e4ef31c3 100644 --- a/src/koheesio/spark/delta.py +++ b/src/koheesio/spark/delta.py @@ -2,8 +2,8 @@ Module for creating and managing Delta tables. """ -import warnings from typing import Dict, List, Optional, Union +import warnings from py4j.protocol import Py4JJavaError # type: ignore diff --git a/src/koheesio/spark/readers/hana.py b/src/koheesio/spark/readers/hana.py index 76168560..a98bed87 100644 --- a/src/koheesio/spark/readers/hana.py +++ b/src/koheesio/spark/readers/hana.py @@ -27,11 +27,12 @@ class HanaReader(JdbcReader): ```python from koheesio.spark.readers.hana import HanaReader + jdbc_hana = HanaReader( url="jdbc:sap://:/?", user="YOUR_USERNAME", password="***", - dbtable="schema_name.table_name" + dbtable="schema_name.table_name", ) df = jdbc_hana.read() ``` diff --git a/src/koheesio/spark/readers/memory.py b/src/koheesio/spark/readers/memory.py index 79002058..a90e09ee 100644 --- a/src/koheesio/spark/readers/memory.py +++ b/src/koheesio/spark/readers/memory.py @@ -2,13 +2,14 @@ Create Spark DataFrame directly from the data stored in a Python variable """ -import json +from typing import Any, Dict, Optional, Union from enum import Enum from functools import partial from io import StringIO -from typing import Any, Dict, Optional, Union +import json import pandas as pd + from pyspark.sql.types import StructType from koheesio.models import ExtraParamsMixin, Field diff --git a/src/koheesio/spark/snowflake.py b/src/koheesio/spark/snowflake.py index 67cab023..4a70b8f5 100644 --- a/src/koheesio/spark/snowflake.py +++ b/src/koheesio/spark/snowflake.py @@ -41,10 +41,10 @@ environments and make sure to install required JARs. 
""" -import json from typing import Any, Callable, Dict, List, Optional, Set, Union from abc import ABC from copy import deepcopy +import json from textwrap import dedent from pyspark.sql import Window @@ -666,9 +666,7 @@ def get_query(self, role: str) -> str: query : str The Query that performs the grant """ - query = ( - f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() - ) # nosec B608: hardcoded_sql_expressions + query = f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() # nosec B608: hardcoded_sql_expressions return query def execute(self) -> SnowflakeStep.Output: @@ -950,17 +948,21 @@ class TagSnowflakeQuery(Step, ExtraParamsMixin): Example ------- ```python - query_tag = AddQueryTag( - options={"preactions": ...}, - task_name="cleanse_task", - pipeline_name="ingestion-pipeline", - etl_date="2022-01-01", - pipeline_execution_time="2022-01-01T00:00:00", - task_execution_time="2022-01-01T01:00:00", - environment="dev", - trace_id="acd4f3f96045", - span_id="546d2d66f6cb", - ).execute().options + query_tag = ( + AddQueryTag( + options={"preactions": ...}, + task_name="cleanse_task", + pipeline_name="ingestion-pipeline", + etl_date="2022-01-01", + pipeline_execution_time="2022-01-01T00:00:00", + task_execution_time="2022-01-01T01:00:00", + environment="dev", + trace_id="acd4f3f96045", + span_id="546d2d66f6cb", + ) + .execute() + .options + ) ``` """ @@ -1320,7 +1322,7 @@ def extract(self) -> DataFrame: raise RuntimeError( f"Source table {self.source_table.table_name} does not have CDF enabled. " f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " - f"Current properties = {self.source_table_properties}" + f"Current properties = {self.source_table.get_persisted_properties()}" ) df = self.reader.read() diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index 3f273a85..c44a9981 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -56,7 +56,9 @@ class Transformation(SparkStep, ABC): class AddOne(Transformation): def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + "new_column", f.col("old_column") + 1 + ) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the diff --git a/src/koheesio/spark/transformations/camel_to_snake.py b/src/koheesio/spark/transformations/camel_to_snake.py index 33f5d237..f62c822b 100644 --- a/src/koheesio/spark/transformations/camel_to_snake.py +++ b/src/koheesio/spark/transformations/camel_to_snake.py @@ -2,8 +2,8 @@ Class for converting DataFrame column names from camel case to snake case. 
""" -import re from typing import Optional +import re from koheesio.models import Field, ListOfColumns from koheesio.spark.transformations import ColumnsTransformation diff --git a/src/koheesio/spark/transformations/date_time/interval.py b/src/koheesio/spark/transformations/date_time/interval.py index e30244aa..e56bd548 100644 --- a/src/koheesio/spark/transformations/date_time/interval.py +++ b/src/koheesio/spark/transformations/date_time/interval.py @@ -304,9 +304,7 @@ def adjust_time(column: Column, operation: Operations, interval: str) -> Column: operation = { "add": "try_add", "subtract": "try_subtract", - }[ - operation - ] # type: ignore + }[operation] # type: ignore except KeyError as e: raise ValueError(f"Operation '{operation}' is not valid. Must be either 'add' or 'subtract'.") from e diff --git a/src/koheesio/spark/transformations/drop_column.py b/src/koheesio/spark/transformations/drop_column.py index d4da7772..9cd114e1 100644 --- a/src/koheesio/spark/transformations/drop_column.py +++ b/src/koheesio/spark/transformations/drop_column.py @@ -46,5 +46,5 @@ class DropColumn(ColumnsTransformation): """ def execute(self) -> ColumnsTransformation.Output: - self.log.info(f"{self.column=}") + self.log.info(f"{self.columns=}") self.output.df = self.df.drop(*self.columns) diff --git a/src/koheesio/spark/transformations/strings/trim.py b/src/koheesio/spark/transformations/strings/trim.py index ce116e24..6f72eae8 100644 --- a/src/koheesio/spark/transformations/strings/trim.py +++ b/src/koheesio/spark/transformations/strings/trim.py @@ -15,8 +15,8 @@ from typing import Literal -import pyspark.sql.functions as f from pyspark.sql import Column +import pyspark.sql.functions as f from koheesio.models import Field, ListOfColumns from koheesio.spark.transformations import ColumnsTransformationWithTarget diff --git a/src/koheesio/spark/transformations/uuid5.py b/src/koheesio/spark/transformations/uuid5.py index 545a2f9d..cd709a86 100644 --- a/src/koheesio/spark/transformations/uuid5.py +++ b/src/koheesio/spark/transformations/uuid5.py @@ -1,7 +1,7 @@ """Ability to generate UUID5 using native pyspark (no udf)""" -import uuid from typing import Optional, Union +import uuid from pyspark.sql import functions as f diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index 10050d5f..1f9b47cb 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -2,11 +2,11 @@ Spark Utility functions """ +from typing import Union +from enum import Enum import importlib import inspect import os -from typing import Union -from enum import Enum from types import ModuleType from pyspark import sql diff --git a/src/koheesio/spark/writers/buffer.py b/src/koheesio/spark/writers/buffer.py index e94b5f81..e83880e0 100644 --- a/src/koheesio/spark/writers/buffer.py +++ b/src/koheesio/spark/writers/buffer.py @@ -15,10 +15,10 @@ from __future__ import annotations -import gzip from typing import AnyStr, Literal, Optional from abc import ABC from functools import partial +import gzip from os import linesep from tempfile import SpooledTemporaryFile @@ -252,6 +252,15 @@ class PandasCsvBufferWriter(BufferWriter, ExtraParamsMixin): "by default. Can be set to one of 'infer', 'gzip', 'bz2', 'zip', 'xz', 'zstd', or 'tar'. " "See Pandas documentation for more details.", ) + emptyValue: Optional[str] = Field( + default="", + description="The string to use for missing values. 
Koheesio sets this default to an empty string.", + ) + + nullValue: Optional[str] = Field( + default="", + description="The string to use for missing values. Koheesio sets this default to an empty string.", + ) # -- Pandas specific properties -- index: bool = Field( diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index 7fd8376b..6959ef0e 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -34,18 +34,19 @@ ``` """ +from typing import Callable, Dict, List, Optional, Set, Type, Union from functools import partial -from typing import List, Optional, Set, Type, Union -from delta.tables import DeltaMergeBuilder, DeltaTable +from delta.tables import DeltaMergeBuilder from py4j.protocol import Py4JError + from pyspark.sql import DataFrameWriter from koheesio.models import ExtraParamsMixin, Field, field_validator from koheesio.spark.delta import DeltaTableStep from koheesio.spark.utils import on_databricks from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode, Writer -from koheesio.spark.writers.delta.utils import log_clauses +from koheesio.spark.writers.delta.utils import get_delta_table_for_name, log_clauses class DeltaTableWriter(Writer, ExtraParamsMixin): @@ -157,8 +158,9 @@ def __merge(self, merge_builder: Optional[DeltaMergeBuilder] = None) -> Union[De if self.table.exists: merge_builder = self._get_merge_builder(merge_builder) + from koheesio.spark.utils.connect import is_remote_session - if on_databricks(): + if on_databricks() and not is_remote_session(): try: source_alias = merge_builder._jbuilder.getMergePlan().source().alias() target_alias = merge_builder._jbuilder.getMergePlan().target().alias() @@ -219,7 +221,7 @@ def __merge_all(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: if self.table.exists: builder = ( - DeltaTable.forName(sparkSession=self.spark, tableOrViewName=self.table.table_name) + get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) .alias(target_alias) .merge(source=self.df.alias(source_alias), condition=merge_cond) .whenMatchedUpdateAll(condition=update_cond) @@ -266,7 +268,7 @@ def _merge_builder_from_args(self) -> DeltaMergeBuilder: target_alias = self.params.get("target_alias", "target") builder = ( - DeltaTable.forName(self.spark, self.table.table_name) + get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) .alias(target_alias) .merge(self.df.alias(source_alias), merge_cond) ) @@ -359,7 +361,7 @@ def __data_frame_writer(self) -> DataFrameWriter: @property def writer(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: """Specify DeltaTableWriter""" - map_mode_to_writer = { + map_mode_to_writer: Dict[str, Callable] = { BatchOutputMode.MERGEALL.value: self.__merge_all, BatchOutputMode.MERGE.value: self.__merge, } diff --git a/src/koheesio/spark/writers/delta/scd.py b/src/koheesio/spark/writers/delta/scd.py index f93762ef..6f0087ef 100644 --- a/src/koheesio/spark/writers/delta/scd.py +++ b/src/koheesio/spark/writers/delta/scd.py @@ -30,6 +30,7 @@ from koheesio.spark.delta import DeltaTableStep from koheesio.spark.functions import current_timestamp_utc from koheesio.spark.writers import Writer +from koheesio.spark.writers.delta.utils import get_delta_table_for_name class SCD2DeltaTableWriter(Writer): @@ -476,7 +477,7 @@ def execute(self) -> None: """ self.df: DataFrame self.spark: SparkSession - delta_table = DeltaTable.forName(sparkSession=self.spark, 
tableOrViewName=self.table.table_name) + delta_table = get_delta_table_for_name(spark_session=self.spark, table_name=self.table.table_name) src_alias, cross_alias, dest_alias = "src", "cross", "tgt" # Prepare required merge columns diff --git a/src/koheesio/spark/writers/delta/stream.py b/src/koheesio/spark/writers/delta/stream.py index aea03a57..ada8f5b4 100644 --- a/src/koheesio/spark/writers/delta/stream.py +++ b/src/koheesio/spark/writers/delta/stream.py @@ -2,8 +2,8 @@ This module defines the DeltaTableStreamWriter class, which is used to write streaming dataframes to Delta tables. """ -from email.policy import default from typing import Optional +from email.policy import default from pydantic import Field diff --git a/src/koheesio/spark/writers/delta/utils.py b/src/koheesio/spark/writers/delta/utils.py index 2e08a16f..03c5d75e 100644 --- a/src/koheesio/spark/writers/delta/utils.py +++ b/src/koheesio/spark/writers/delta/utils.py @@ -4,7 +4,21 @@ from typing import Optional -from py4j.java_gateway import JavaObject # type: ignore[import-untyped] +from delta import DeltaTable +from py4j.java_gateway import JavaObject + +from koheesio.spark import SparkSession +from koheesio.spark.utils import SPARK_MINOR_VERSION + + +class SparkConnectDeltaTableException(AttributeError): + EXCEPTION_CONNECT_TEXT: str = """`DeltaTable.forName` is not supported due to delta calling _sc, + which is not available in Spark Connect and PySpark>=3.5,<4.0. Required version of PySpark >=4.0. + Possible workaround to use spark.read and Spark SQL for any Delta operation (e.g. merge)""" + + def __init__(self, original_exception: AttributeError): + custom_message = f"{self.EXCEPTION_CONNECT_TEXT}\nOriginal exception: {str(original_exception)}" + super().__init__(custom_message) def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Optional[str]: @@ -68,3 +82,41 @@ def log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Op ) return log_message + + +def get_delta_table_for_name(spark_session: SparkSession, table_name: str) -> DeltaTable: + """ + Retrieves the DeltaTable instance for the specified table name. + + This method attempts to get the DeltaTable using the provided Spark session and table name. + If an AttributeError occurs and the Spark version is between 3.4 and 4.0, and the session is remote, + it raises a SparkConnectDeltaTableException. + + Parameters + ---------- + spark_session : SparkSession + The Spark Session to use. + table_name : str + The table name. + + Returns + ------- + DeltaTable + The DeltaTable instance for the specified table name. + + Raises + ------ + SparkConnectDeltaTableException + If the Spark version is between 3.4 and 4.0, the session is remote, and an AttributeError occurs. 
+ """ + try: + delta_table = DeltaTable.forName(sparkSession=spark_session, tableOrViewName=table_name) + except AttributeError as e: + from koheesio.spark.utils.connect import is_remote_session + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + raise SparkConnectDeltaTableException(e) from e + else: + raise e + + return delta_table diff --git a/src/koheesio/spark/writers/dummy.py b/src/koheesio/spark/writers/dummy.py index 5e90a989..a6183811 100644 --- a/src/koheesio/spark/writers/dummy.py +++ b/src/koheesio/spark/writers/dummy.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Union from koheesio.models import Field, PositiveInt, field_validator +from koheesio.spark import DataFrame from koheesio.spark.utils import show_string from koheesio.spark.writers import Writer diff --git a/src/koheesio/steps/__init__.py b/src/koheesio/steps/__init__.py index 89c04ba0..5a1faa7c 100644 --- a/src/koheesio/steps/__init__.py +++ b/src/koheesio/steps/__init__.py @@ -16,18 +16,18 @@ from __future__ import annotations +from typing import Any, Callable, Optional +from abc import abstractmethod +from functools import partialmethod, wraps import inspect import json import sys import warnings -from typing import Any, Callable, Union -from abc import abstractmethod -from functools import partialmethod, wraps import yaml from pydantic import BaseModel as PydanticBaseModel -from pydantic import InstanceOf +from pydantic import InstanceOf, PrivateAttr from koheesio.models import BaseModel, ConfigDict, ModelMetaclass @@ -535,21 +535,21 @@ def execute(self) -> MyStep.Output: class Output(StepOutput): """Output class for Step""" - __output__: Output + _output: Optional[Output] = PrivateAttr(default=None) @property def output(self) -> Output: """Interact with the output of the Step""" - if not self.__output__: - self.__output__ = self.Output.lazy() - self.__output__.name = self.name + ".Output" # type: ignore[operator] - self.__output__.description = "Output for " + self.name # type: ignore[operator] - return self.__output__ + if not self._output: + self._output = self.Output.lazy() + self._output.name = self.name + ".Output" # type: ignore + self._output.description = "Output for " + self.name # type: ignore + return self._output @output.setter def output(self, value: Output) -> None: """Set the output of the Step""" - self.__output__ = value + self._output = value @abstractmethod def execute(self) -> InstanceOf[StepOutput]: @@ -675,23 +675,6 @@ def repr_yaml(self, simple: bool = False) -> str: return yaml.dump(_result) - def __getattr__(self, key: str) -> Union[Any, None]: - """__getattr__ dunder - - Allows input to be accessed through `self.input_name` - - Parameters - ---------- - key: str - Name of the attribute to return the value of - - Returns - ------- - Any - The value of the attribute - """ - return self.model_dump().get(key) - @classmethod def from_step(cls, step: Step, **kwargs) -> InstanceOf[PydanticBaseModel]: # type: ignore[no-untyped-def] """Returns a new Step instance based on the data of another Step or BaseModel instance""" diff --git a/src/koheesio/steps/http.py b/src/koheesio/steps/http.py index 68329cc7..8a16a8f3 100644 --- a/src/koheesio/steps/http.py +++ b/src/koheesio/steps/http.py @@ -12,9 +12,9 @@ In the above example, the `response` variable will contain the JSON response from the HTTP request. 
""" -import json from typing import Any, Dict, List, Optional, Union from enum import Enum +import json import requests # type: ignore[import-untyped] diff --git a/src/koheesio/utils.py b/src/koheesio/utils.py index 253a985e..0556a394 100644 --- a/src/koheesio/utils.py +++ b/src/koheesio/utils.py @@ -2,14 +2,14 @@ Utility functions """ -import datetime -import inspect -import uuid from typing import Any, Callable, Dict, Optional, Tuple +import datetime from functools import partial from importlib import import_module +import inspect from pathlib import Path from sys import version_info as PYTHON_VERSION +import uuid __all__ = [ "get_args_for_func", diff --git a/tests/asyncio/test_asyncio_http.py b/tests/asyncio/test_asyncio_http.py index 8625c710..5dcbc119 100644 --- a/tests/asyncio/test_asyncio_http.py +++ b/tests/asyncio/test_asyncio_http.py @@ -1,8 +1,8 @@ import warnings -import pytest from aiohttp import ClientResponseError, ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry +import pytest from yarl import URL from pydantic import ValidationError diff --git a/tests/conftest.py b/tests/conftest.py index a0090a0c..36e7d748 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os +from pathlib import Path import time import uuid -from pathlib import Path import pytest diff --git a/tests/core/test_logger.py b/tests/core/test_logger.py index 277d7664..c739ae8f 100644 --- a/tests/core/test_logger.py +++ b/tests/core/test_logger.py @@ -1,5 +1,5 @@ -import logging from io import StringIO +import logging from logging import Logger from unittest.mock import MagicMock, patch diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 3d981f95..29e0ea6d 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,5 +1,5 @@ -import json from typing import Optional +import json from textwrap import dedent import pytest diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py index 0541bdf8..8cf2fb4b 100644 --- a/tests/snowflake/test_snowflake.py +++ b/tests/snowflake/test_snowflake.py @@ -1,8 +1,8 @@ # flake8: noqa: F811 from unittest import mock -import pytest from pydantic_core._pydantic_core import ValidationError +import pytest from koheesio.integrations.snowflake import ( GrantPrivilegesOnObject, @@ -158,7 +158,6 @@ def test_with_missing_dependencies(self): class TestSnowflakeBaseModel: - def test_get_options_using_alias(self): """Test that the options are correctly generated using alias""" k = SnowflakeBaseModel( diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index b0a7c51e..f918ae40 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -1,15 +1,15 @@ +from collections import namedtuple import datetime +from decimal import Decimal import os +from pathlib import Path import socket import sys -from collections import namedtuple -from decimal import Decimal -from pathlib import Path from textwrap import dedent from unittest import mock -import pytest from delta import configure_spark_with_delta_pip +import pytest from pyspark.sql import SparkSession from pyspark.sql.types import ( diff --git a/tests/spark/integrations/snowflake/test_sync_task.py b/tests/spark/integrations/snowflake/test_sync_task.py index a4c50e8a..980e03ee 100644 --- a/tests/spark/integrations/snowflake/test_sync_task.py +++ b/tests/spark/integrations/snowflake/test_sync_task.py @@ -3,13 +3,12 @@ from unittest import mock import chispa -import pytest from conftest import await_job_completion 
+import pytest import pydantic from koheesio.integrations.snowflake import SnowflakeRunQueryPython -from koheesio.integrations.snowflake.test_utils import mock_query from koheesio.integrations.spark.snowflake import ( SnowflakeWriter, SynchronizeDeltaToSnowflakeTask, diff --git a/tests/spark/readers/test_auto_loader.py b/tests/spark/readers/test_auto_loader.py index 8f2b168c..71e6cea2 100644 --- a/tests/spark/readers/test_auto_loader.py +++ b/tests/spark/readers/test_auto_loader.py @@ -1,5 +1,5 @@ -import pytest from chispa import assert_df_equality +import pytest from pyspark.sql.types import * diff --git a/tests/spark/readers/test_memory.py b/tests/spark/readers/test_memory.py index 21b5d53f..19c5b5c2 100644 --- a/tests/spark/readers/test_memory.py +++ b/tests/spark/readers/test_memory.py @@ -1,5 +1,6 @@ -import pytest from chispa import assert_df_equality +import pytest + from pyspark.sql.types import StructType from koheesio.spark.readers.memory import DataFormat, InMemoryDataReader diff --git a/tests/spark/readers/test_rest_api.py b/tests/spark/readers/test_rest_api.py index 9c22ea3f..0803f630 100644 --- a/tests/spark/readers/test_rest_api.py +++ b/tests/spark/readers/test_rest_api.py @@ -1,7 +1,7 @@ -import pytest -import requests_mock from aiohttp import ClientSession, TCPConnector from aiohttp_retry import ExponentialRetry +import pytest +import requests_mock from yarl import URL from pyspark.sql.types import MapType, StringType, StructField, StructType diff --git a/tests/spark/test_delta.py b/tests/spark/test_delta.py index 1806ac09..920d2b6b 100644 --- a/tests/spark/test_delta.py +++ b/tests/spark/test_delta.py @@ -3,8 +3,8 @@ from pathlib import Path from unittest.mock import patch -import pytest from conftest import setup_test_data +import pytest from pydantic import ValidationError diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 92a349c8..c916b0da 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -1,9 +1,8 @@ import os from unittest.mock import MagicMock, patch -import pytest from conftest import await_job_completion -from delta import DeltaTable +import pytest from pydantic import ValidationError @@ -14,13 +13,14 @@ from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter -from koheesio.spark.writers.delta.utils import log_clauses +from koheesio.spark.writers.delta.utils import ( + SparkConnectDeltaTableException, + log_clauses, +) from koheesio.spark.writers.stream import Trigger pytestmark = pytest.mark.spark -skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. 
Test requires pyspark version >= 4.0" - def test_delta_table_writer(dummy_df, spark): table_name = "test_table" @@ -53,9 +53,6 @@ def test_delta_partitioning(spark, sample_df_to_partition): def test_delta_table_merge_all(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_merge_all_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "xxxx"}] @@ -77,7 +74,7 @@ def test_delta_table_merge_all(spark): 5: "xxxx", } DeltaTableWriter(table=table_name, output_mode=BatchOutputMode.APPEND, df=target_df).execute() - DeltaTableWriter( + merge_writer = DeltaTableWriter( table=table_name, output_mode=BatchOutputMode.MERGEALL, output_mode_params={ @@ -86,33 +83,46 @@ def test_delta_table_merge_all(spark): "insert_cond": F.expr("source.value IS NOT NULL"), }, df=source_df, - ).execute() - result = { - list(row.asDict().values())[0]: list(row.asDict().values())[1] for row in spark.read.table(table_name).collect() - } - assert result == expected + ) + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + merge_writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + merge_writer.execute() + result = { + list(row.asDict().values())[0]: list(row.asDict().values())[1] + for row in spark.read.table(table_name).collect() + } + assert result == expected def test_deltatablewriter_with_invalid_conditions(spark, dummy_df): from koheesio.spark.utils.connect import is_remote_session - - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) + from koheesio.spark.writers.delta.utils import get_delta_table_for_name table_name = "delta_test_table" - merge_builder = ( - DeltaTable.forName(sparkSession=spark, tableOrViewName=table_name) - .alias("target") - .merge(condition="invalid_condition", source=dummy_df.alias("source")) - ) - writer = DeltaTableWriter( - table=table_name, - output_mode=BatchOutputMode.MERGE, - output_mode_params={"merge_builder": merge_builder}, - df=dummy_df, - ) - with pytest.raises(AnalysisException): - writer.execute() + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + builder = get_delta_table_for_name(spark_session=spark, table_name=table_name) + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + with pytest.raises(AnalysisException): + builder = get_delta_table_for_name(spark_session=spark, table_name=table_name) + merge_builder = builder.alias("target").merge( + condition="invalid_condition", source=dummy_df.alias("source") + ) + writer = DeltaTableWriter( + table=table_name, + output_mode=BatchOutputMode.MERGE, + output_mode_params={"merge_builder": merge_builder}, + df=dummy_df, + ) + writer.execute() @patch.dict( @@ -286,9 +296,6 @@ def test_delta_with_options(spark): def test_merge_from_args(spark, dummy_df): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_table_merge_from_args" dummy_df.write.format("delta").saveAsTable(table_name) @@ -316,14 +323,20 @@ def test_merge_from_args(spark, dummy_df): "merge_cond": 
"source.id=target.id", }, ) - writer._merge_builder_from_args() - mock_delta_builder.whenMatchedUpdate.assert_called_once_with( - set={"id": "source.id"}, condition="source.id=target.id" - ) - mock_delta_builder.whenNotMatchedInsert.assert_called_once_with( - values={"id": "source.id"}, condition="source.id IS NOT NULL" - ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer._merge_builder_from_args() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer._merge_builder_from_args() + mock_delta_builder.whenMatchedUpdate.assert_called_once_with( + set={"id": "source.id"}, condition="source.id=target.id" + ) + mock_delta_builder.whenNotMatchedInsert.assert_called_once_with( + values={"id": "source.id"}, condition="source.id IS NOT NULL" + ) @pytest.mark.parametrize( @@ -350,9 +363,6 @@ def test_merge_from_args_raise_value_error(spark, output_mode_params): def test_merge_no_table(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - table_name = "test_merge_no_table" target_df = spark.createDataFrame( [{"id": 1, "value": "no_merge"}, {"id": 2, "value": "expected_merge"}, {"id": 5, "value": "expected_merge"}] @@ -388,20 +398,29 @@ def test_merge_no_table(spark): ], "merge_cond": "source.id=target.id", } - - DeltaTableWriter( + writer1 = DeltaTableWriter( df=target_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params - ).execute() - - DeltaTableWriter( + ) + writer2 = DeltaTableWriter( df=source_df, table=table_name, output_mode=BatchOutputMode.MERGE, output_mode_params=params - ).execute() + ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + writer1.execute() - result = { - list(row.asDict().values())[0]: list(row.asDict().values())[1] for row in spark.read.table(table_name).collect() - } + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer2.execute() - assert result == expected + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer1.execute() + writer2.execute() + + result = { + list(row.asDict().values())[0]: list(row.asDict().values())[1] + for row in spark.read.table(table_name).collect() + } + + assert result == expected def test_log_clauses(mocker): diff --git a/tests/spark/writers/delta/test_scd.py b/tests/spark/writers/delta/test_scd.py index 087f957d..df4106b9 100644 --- a/tests/spark/writers/delta/test_scd.py +++ b/tests/spark/writers/delta/test_scd.py @@ -1,9 +1,9 @@ -import datetime from typing import List, Optional +import datetime -import pytest from delta import DeltaTable from delta.tables import DeltaMergeBuilder +import pytest from pydantic import Field @@ -16,18 +16,14 @@ from koheesio.spark.functions import current_timestamp_utc from koheesio.spark.utils import SPARK_MINOR_VERSION from koheesio.spark.writers.delta.scd import SCD2DeltaTableWriter +from koheesio.spark.writers.delta.utils import SparkConnectDeltaTableException pytestmark = pytest.mark.spark -skip_reason = "Tests are not working with PySpark 3.5 due to delta calling _sc. 
Test requires pyspark version >= 4.0" - def test_scd2_custom_logic(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - def _get_result(target_df: DataFrame, expr: str): res = ( target_df.where(expr) @@ -145,124 +141,161 @@ def _prepare_merge_builder( meta_scd2_end_time_col_name="valid_to_timestamp", df=source_df, ) - writer.execute() - expected = { - "id": 4, - "last_updated_at": datetime.datetime(2024, 4, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-4", - } + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - target_df = spark.read.table(target_table) - result = _get_result(target_df, "id = 4") + expected = { + "id": 4, + "last_updated_at": datetime.datetime(2024, 4, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-4", + } - assert spark.table(target_table).count() == 4 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result == expected + target_df = spark.read.table(target_table) + result = _get_result(target_df, "id = 4") + + assert spark.table(target_table).count() == 4 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result == expected source_df2 = source_df.withColumn( "value", F.expr("CASE WHEN id = 2 THEN 'value-2-change' ELSE value END") ).withColumn("last_updated_at", F.expr("CASE WHEN id = 2 THEN TIMESTAMP'2024-02-02' ELSE last_updated_at END")) writer.df = source_df2 - writer.execute() - - expected_insert = { - "id": 2, - "last_updated_at": datetime.datetime(2024, 2, 2, 0, 0), - "valid_from_timestamp": datetime.datetime(2024, 2, 2, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-2-change", - } - - expected_update = { - "id": 2, - "last_updated_at": datetime.datetime(2024, 2, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 2, 2, 0, 0), - "value": "value-2", - } - - result_insert = _get_result(target_df, "id = 2 and meta.valid_to_timestamp = '2999-12-31'") - result_update = _get_result(target_df, "id = 2 and meta.valid_from_timestamp = '1970-01-01'") - - assert spark.table(target_table).count() == 5 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result_insert == expected_insert - assert result_update == expected_update + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected_insert = { + "id": 2, + "last_updated_at": datetime.datetime(2024, 2, 2, 0, 0), + "valid_from_timestamp": datetime.datetime(2024, 2, 2, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-2-change", + } + + expected_update = { + "id": 2, + "last_updated_at": datetime.datetime(2024, 2, 1, 0, 0), + 
"valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 2, 2, 0, 0), + "value": "value-2", + } + + result_insert = _get_result(target_df, "id = 2 and meta.valid_to_timestamp = '2999-12-31'") + result_update = _get_result(target_df, "id = 2 and meta.valid_from_timestamp = '1970-01-01'") + + assert spark.table(target_table).count() == 5 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result_insert == expected_insert + assert result_update == expected_update source_df3 = source_df2.withColumn( "value", F.expr("CASE WHEN id = 3 THEN 'value-3-change' ELSE value END") ).withColumn("last_updated_at", F.expr("CASE WHEN id = 3 THEN TIMESTAMP'2024-03-02' ELSE last_updated_at END")) writer.df = source_df3 - writer.execute() - - expected_insert = { - "id": 3, - "last_updated_at": datetime.datetime(2024, 3, 2, 0, 0), - "valid_from_timestamp": datetime.datetime(2024, 3, 2, 0, 0), - "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), - "value": "value-3-change", - } - - expected_update = { - "id": 3, - "last_updated_at": datetime.datetime(2024, 3, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 3, 2, 0, 0), - "value": None, - } - - result_insert = _get_result(target_df, "id = 3 and meta.valid_to_timestamp = '2999-12-31'") - result_update = _get_result(target_df, "id = 3 and meta.valid_from_timestamp = '1970-01-01'") - - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 - assert result_insert == expected_insert - assert result_update == expected_update + + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected_insert = { + "id": 3, + "last_updated_at": datetime.datetime(2024, 3, 2, 0, 0), + "valid_from_timestamp": datetime.datetime(2024, 3, 2, 0, 0), + "valid_to_timestamp": datetime.datetime(2999, 12, 31, 0, 0), + "value": "value-3-change", + } + + expected_update = { + "id": 3, + "last_updated_at": datetime.datetime(2024, 3, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 3, 2, 0, 0), + "value": None, + } + + result_insert = _get_result(target_df, "id = 3 and meta.valid_to_timestamp = '2999-12-31'") + result_update = _get_result(target_df, "id = 3 and meta.valid_from_timestamp = '1970-01-01'") + + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("meta.valid_to_timestamp = '2999-12-31'").count() == 4 + assert result_insert == expected_insert + assert result_update == expected_update source_df4 = source_df3.where("id != 4") writer.df = source_df4 - writer.execute() + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 1 + assert spark.table(target_table).count() == 6 + assert 
spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 1 writer.orphaned_records_close = True - writer.execute() + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("id = 4 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 source_df5 = source_df4.where("id != 5") writer.orphaned_records_close_ts = F.col("snapshot_at") writer.df = source_df5 - writer.execute() - expected = { - "id": 5, - "last_updated_at": datetime.datetime(2024, 5, 1, 0, 0), - "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), - "valid_to_timestamp": datetime.datetime(2024, 12, 31, 0, 0), - "value": "value-5", - } - result = _get_result(target_df, "id = 5") + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + + expected = { + "id": 5, + "last_updated_at": datetime.datetime(2024, 5, 1, 0, 0), + "valid_from_timestamp": datetime.datetime(1970, 1, 1, 0, 0), + "valid_to_timestamp": datetime.datetime(2024, 12, 31, 0, 0), + "value": "value-5", + } + result = _get_result(target_df, "id = 5") - assert spark.table(target_table).count() == 6 - assert spark.table(target_table).where("id = 5 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 - assert result == expected + assert spark.table(target_table).count() == 6 + assert spark.table(target_table).where("id = 5 and meta.valid_to_timestamp = '2999-12-31'").count() == 0 + assert result == expected def test_scd2_logic(spark): from koheesio.spark.utils.connect import is_remote_session - if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): - pytest.skip(reason=skip_reason) - changes_data = [ [("key1", "value1", "scd1-value11", "2024-05-01"), ("key2", "value2", "scd1-value21", "2024-04-01")], [("key1", "value1_updated", "scd1-value12", "2024-05-02"), ("key3", "value3", "scd1-value31", "2024-05-03")], @@ -400,11 +433,17 @@ def test_scd2_logic(spark): changes_df = spark.createDataFrame(changes, ["merge_key", "value_scd2", "value_scd1", "run_date"]) changes_df = changes_df.withColumn("run_date", F.to_timestamp("run_date")) writer.df = changes_df - writer.execute() - res = ( - spark.sql("SELECT merge_key,value_scd2, value_scd1, _scd2.* FROM scd2_test_data_set") - .orderBy("merge_key", "effective_time") - .collect() - ) + if 3.4 < SPARK_MINOR_VERSION < 4.0 and is_remote_session(): + with pytest.raises(SparkConnectDeltaTableException) as exc_info: + writer.execute() + + assert str(exc_info.value).startswith("`DeltaTable.forName` is not supported due to delta calling _sc") + else: + writer.execute() + res = ( + spark.sql("SELECT merge_key,value_scd2, value_scd1, _scd2.* FROM scd2_test_data_set") + .orderBy("merge_key", "effective_time") + .collect() + ) - assert res == expected + assert res == expected diff --git a/tests/spark/writers/test_buffer.py b/tests/spark/writers/test_buffer.py index 6da783a6..141bf578 100644 --- 
a/tests/spark/writers/test_buffer.py +++ b/tests/spark/writers/test_buffer.py @@ -1,5 +1,5 @@ -import gzip from datetime import datetime, timezone +import gzip from importlib.util import find_spec import pytest diff --git a/tests/spark/writers/test_sftp.py b/tests/spark/writers/test_sftp.py index 7119edda..a19e2fbd 100644 --- a/tests/spark/writers/test_sftp.py +++ b/tests/spark/writers/test_sftp.py @@ -1,8 +1,8 @@ from unittest import mock import paramiko -import pytest from paramiko import SSHException +import pytest from koheesio.integrations.spark.sftp import ( SendCsvToSftp, diff --git a/tests/sso/test_okta.py b/tests/sso/test_okta.py index 5247493e..8c7a5483 100644 --- a/tests/sso/test_okta.py +++ b/tests/sso/test_okta.py @@ -1,5 +1,5 @@ -import logging from io import StringIO +import logging import pytest from requests_mock.mocker import Mocker diff --git a/tests/steps/test_steps.py b/tests/steps/test_steps.py index 92c563a7..484f3b15 100644 --- a/tests/steps/test_steps.py +++ b/tests/steps/test_steps.py @@ -1,11 +1,11 @@ from __future__ import annotations -import io -import warnings from copy import deepcopy from functools import wraps +import io from unittest import mock from unittest.mock import call, patch +import warnings import pytest From b37a302dbbd16a38e018559d8405009bb2131910 Mon Sep 17 00:00:00 2001 From: louis-paulvlx <90868690+louis-paulvlx@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:30:09 +0100 Subject: [PATCH 03/33] 90-Bug-fix-file-encoding-box-integration (#96) Allow setting File Encoding in Box implementation --- src/koheesio/integrations/box.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index 114fdc04..2b4c0e8f 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -362,7 +362,11 @@ class BoxReaderBase(Box, Reader, ABC): default_factory=dict, description="[Optional] Set of extra parameters that should be passed to the Spark reader.", ) - + + file_encoding: Optional[str] = Field( + default="utf-8", + description="[Optional] Set file encoding format. By default is utf-8." + ) class BoxCsvFileReader(BoxReaderBase): """ @@ -412,7 +416,7 @@ def execute(self) -> BoxReaderBase.Output: for f in self.file: self.log.debug(f"Reading contents of file with the ID '{f}' into Spark DataFrame") file = self.client.file(file_id=f) - data = file.content().decode("utf-8") + data = file.content().decode(self.file_encoding) data_buffer = StringIO(data) temp_df_pandas = pd.read_csv(data_buffer, header=0, dtype=str if not self.schema_ else None, **self.params) # type: ignore From 6d6ccbd1167b78c16719d3df62bc03193ab0f5d2 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:41:39 +0100 Subject: [PATCH 04/33] added documentation --- src/koheesio/integrations/box.py | 6 +++--- src/koheesio/spark/utils/common.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index 2b4c0e8f..9124b9a7 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -362,12 +362,12 @@ class BoxReaderBase(Box, Reader, ABC): default_factory=dict, description="[Optional] Set of extra parameters that should be passed to the Spark reader.", ) - + file_encoding: Optional[str] = Field( - default="utf-8", - description="[Optional] Set file encoding format. By default is utf-8." 
+ default="utf-8", description="[Optional] Set file encoding format. By default is utf-8." ) + class BoxCsvFileReader(BoxReaderBase): """ Class facilitates reading one or multiple CSV files with the same structure directly from Box and diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index 1f9b47cb..ed4fcc69 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -78,6 +78,7 @@ def get_spark_minor_version() -> float: def check_if_pyspark_connect_is_supported() -> bool: + """Check if the current version of PySpark supports the connect module""" result = False module_name: str = "pyspark" if SPARK_MINOR_VERSION >= 3.5: @@ -93,6 +94,7 @@ def check_if_pyspark_connect_is_supported() -> bool: if check_if_pyspark_connect_is_supported(): + """Only import the connect module if the current version of PySpark supports it""" from pyspark.errors.exceptions.captured import ( ParseException as CapturedParseException, ) @@ -122,6 +124,7 @@ def check_if_pyspark_connect_is_supported() -> bool: DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter] StreamingQuery = StreamingQuery else: + """Import the regular PySpark modules if the current version of PySpark does not support the connect module""" try: from pyspark.errors.exceptions.captured import ParseException # type: ignore except (ImportError, ModuleNotFoundError): @@ -152,6 +155,7 @@ def check_if_pyspark_connect_is_supported() -> bool: def get_active_session() -> SparkSession: # type: ignore + """Get the active Spark session""" if check_if_pyspark_connect_is_supported(): from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession @@ -321,7 +325,6 @@ def import_pandas_based_on_pyspark_version() -> ModuleType: raise ImportError("Pandas module is not installed.") from e -# noinspection PyProtectedMember def show_string(df: DataFrame, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> str: # type: ignore """Returns a string representation of the DataFrame The default implementation of DataFrame.show() hardcodes a print statement, which is not always desirable. 
@@ -348,12 +351,13 @@ def show_string(df: DataFrame, n: int = 20, truncate: Union[bool, int] = True, v If set to True, display the DataFrame vertically, by default False """ if SPARK_MINOR_VERSION < 3.5: + # noinspection PyProtectedMember return df._jdf.showString(n, truncate, vertical) # type: ignore # as per spark 3.5, the _show_string method is now available making calls to _jdf.showString obsolete + # noinspection PyProtectedMember return df._show_string(n, truncate, vertical) -# noinspection PyProtectedMember def get_column_name(col: Column) -> str: # type: ignore """Get the column name from a Column object @@ -373,8 +377,10 @@ def get_column_name(col: Column) -> str: # type: ignore # we have to distinguish between the Column object from column from local session and remote if hasattr(col, "_jc"): # In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute + # noinspection PyProtectedMember name = col._jc.toString() # type: ignore[operator] elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)): + # noinspection PyProtectedMember name = col._expr.name() else: raise ValueError("Column object is not a valid Column object") From a0aa8a26d12fe61134ceba23625b0b0fe7ebdb83 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 11 Nov 2024 19:23:49 +0100 Subject: [PATCH 05/33] fix for DeltaMergeBuilder, when the instance doesn't check out (#100) --- src/koheesio/spark/writers/delta/batch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index 6959ef0e..5f66df99 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -232,7 +232,7 @@ def __merge_all(self) -> Union[DeltaMergeBuilder, DataFrameWriter]: return self.__merge(merge_builder=builder) - def _get_merge_builder(self, provided_merge_builder: DeltaMergeBuilder = None) -> DeltaMergeBuilder: + def _get_merge_builder(self, provided_merge_builder: DeltaMergeBuilder = None) -> "DeltaMergeBuilder": """Resolves the merge builder. 
If provided, it will be used, otherwise it will be created from the args""" # A merge builder has been already created - case for merge_all @@ -251,6 +251,11 @@ def _get_merge_builder(self, provided_merge_builder: DeltaMergeBuilder = None) - if isinstance(merge_builder, DeltaMergeBuilder): return merge_builder + if type(merge_builder).__name__ == "DeltaMergeBuilder": + # This check is to account for the case when the merge_builder is not a DeltaMergeBuilder instance, but + # still a compatible object + return merge_builder # type: ignore + if isinstance(merge_builder, list) and "merge_cond" in self.params: # type: ignore return self._merge_builder_from_args() From 6ca9f306da2f3c6a94be9e7238f0d8888eeabac2 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 11 Nov 2024 19:37:40 +0100 Subject: [PATCH 06/33] Bump version to rc1 --- src/koheesio/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index ff52467e..64285c72 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc0" +__version__ = "0.9.0rc1" __logo__ = ( 75, ( From 76207e78bd7d9a75a796fd590deade738165a924 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 11 Nov 2024 19:47:24 +0100 Subject: [PATCH 07/33] Fix/98 sparkexpectations bump version to 220 (#99) --- makefile | 6 ++ pyproject.toml | 18 ++-- .../spark/dq/spark_expectations.py | 5 - src/koheesio/spark/utils/common.py | 32 ++++-- .../dq/test_spark_expectations.py | 3 - tests/spark/test_spark_utils.py | 98 ++++++++++++++++++- 6 files changed, 131 insertions(+), 31 deletions(-) diff --git a/makefile b/makefile index 584ce49a..03690076 100644 --- a/makefile +++ b/makefile @@ -40,6 +40,12 @@ hatch-install: fi init: hatch-install +.PHONY: sync ## hatch - Update dependencies if you changed project dependencies in pyproject.toml +.PHONY: update ## hatch - alias for sync (if you are used to poetry, thi is similar to running `poetry update`) +sync: + @hatch run dev:uv sync --all-extras +update: sync + # Code Quality .PHONY: black black-fmt ## code quality - Use black to (re)format the codebase black-fmt: diff --git a/pyproject.toml b/pyproject.toml index 493eb783..aff2a9a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,16 +60,11 @@ box = ["boxsdk[jwt]==3.8.1"] pandas = ["pandas>=1.3", "setuptools", "numpy<2.0.0", "pandas-stubs"] pyspark = ["pyspark>=3.2.0", "pyarrow>13"] pyspark_connect = ["pyspark[connect]>=3.5"] -se = ["spark-expectations>=2.1.0"] -# SFTP dependencies in to_csv line_iterator sftp = ["paramiko>=2.6.0"] delta = ["delta-spark>=2.2"] excel = ["openpyxl>=3.0.0"] -# Tableau dependencies tableau = ["tableauhyperapi>=0.0.19484", "tableauserverclient>=0.25"] -# Snowflake dependencies snowflake = ["snowflake-connector-python>=3.12.0"] -# Development dependencies dev = ["ruff", "mypy", "pylint", "colorama", "types-PyYAML", "types-requests"] test = [ "chispa", @@ -104,6 +99,10 @@ docs = [ "pymdown-extensions>=10.7.0", "black", ] +se = ["spark-expectations>=2.2.1,<2.3.0"] + +[tool.hatch.metadata] +allow-direct-references = true ### ~~~~~~~~~~~~~~~ ### @@ -237,14 +236,15 @@ features = [ "async", "async_http", "box", + "delta", + "dev", + "excel", "pandas", "pyspark", + "se", "sftp", - "delta", - "excel", "snowflake", "tableau", - "dev", "test", ] @@ -416,7 +416,7 @@ features = [ "box", "pandas", "pyspark", - # "se", + "se", "sftp", "snowflake", 
"delta", diff --git a/src/koheesio/integrations/spark/dq/spark_expectations.py b/src/koheesio/integrations/spark/dq/spark_expectations.py index 2f90b8f3..7053218e 100644 --- a/src/koheesio/integrations/spark/dq/spark_expectations.py +++ b/src/koheesio/integrations/spark/dq/spark_expectations.py @@ -15,15 +15,10 @@ from pydantic import Field -import pyspark - from koheesio.spark import DataFrame from koheesio.spark.transformations import Transformation from koheesio.spark.writers import BatchOutputMode -if pyspark.__version__.startswith("3.5"): - raise ImportError("Spark Expectations is not supported for Spark 3.5") - class SparkExpectationsTransformation(Transformation): """ diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index ed4fcc69..a099fd47 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -80,10 +80,10 @@ def get_spark_minor_version() -> float: def check_if_pyspark_connect_is_supported() -> bool: """Check if the current version of PySpark supports the connect module""" result = False - module_name: str = "pyspark" + if SPARK_MINOR_VERSION >= 3.5: try: - importlib.import_module(f"{module_name}.sql.connect") + importlib.import_module("pyspark.sql.connect") from pyspark.sql.connect.column import Column _col: Column @@ -119,9 +119,13 @@ def check_if_pyspark_connect_is_supported() -> bool: ParseException = (CapturedParseException, ConnectParseException) DataType = Union[SqlDataType, ConnectDataType] DataFrameReader = Union[sql.readwriter.DataFrameReader, DataFrameReader] - DataStreamReader = Union[sql.streaming.readwriter.DataStreamReader, DataStreamReader] + DataStreamReader = Union[ + sql.streaming.readwriter.DataStreamReader, DataStreamReader + ] DataFrameWriter = Union[sql.readwriter.DataFrameWriter, DataFrameWriter] - DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter] + DataStreamWriter = Union[ + sql.streaming.readwriter.DataStreamWriter, DataStreamWriter + ] StreamingQuery = StreamingQuery else: """Import the regular PySpark modules if the current version of PySpark does not support the connect module""" @@ -156,8 +160,9 @@ def check_if_pyspark_connect_is_supported() -> bool: def get_active_session() -> SparkSession: # type: ignore """Get the active Spark session""" + print("Entering get_active_session") if check_if_pyspark_connect_is_supported(): - from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession + from pyspark.sql.connect import SparkSession as _ConnectSparkSession session = _ConnectSparkSession.getActiveSession() or sql.SparkSession.getActiveSession() # type: ignore else: @@ -292,14 +297,18 @@ def spark_data_type_is_array(data_type: DataType) -> bool: # type: ignore def spark_data_type_is_numeric(data_type: DataType) -> bool: # type: ignore """Check if the column's dataType is of type ArrayType""" - return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)) + return isinstance( + data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType) + ) def schema_struct_to_schema_str(schema: StructType) -> str: """Converts a StructType to a schema str""" if not schema: return "" - return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) + return ",\n".join( + [f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields] + ) def import_pandas_based_on_pyspark_version() -> ModuleType: @@ -314,7 +323,9 @@ def import_pandas_based_on_pyspark_version() 
-> ModuleType: pyspark_version = get_spark_minor_version() pandas_version = pd.__version__ - if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): + if (pyspark_version < 3.4 and pandas_version >= "2") or ( + pyspark_version >= 3.4 and pandas_version < "2" + ): raise ImportError( f"For PySpark {pyspark_version}, " f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" @@ -379,7 +390,10 @@ def get_column_name(col: Column) -> str: # type: ignore # In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute # noinspection PyProtectedMember name = col._jc.toString() # type: ignore[operator] - elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)): + elif any( + cls.__module__ == "pyspark.sql.connect.column" + for cls in inspect.getmro(col.__class__) + ): # noinspection PyProtectedMember name = col._expr.name() else: diff --git a/tests/spark/integrations/dq/test_spark_expectations.py b/tests/spark/integrations/dq/test_spark_expectations.py index 6aa05e65..e776bd40 100644 --- a/tests/spark/integrations/dq/test_spark_expectations.py +++ b/tests/spark/integrations/dq/test_spark_expectations.py @@ -11,9 +11,6 @@ pytestmark = pytest.mark.spark -if pyspark.__version__.startswith("3.5"): - pytestmark = pytest.mark.skip("Spark Expectations is not supported for Spark 3.5") - class TestSparkExpectationsTransform: """ diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py index cbd83bad..ab1b608f 100644 --- a/tests/spark/test_spark_utils.py +++ b/tests/spark/test_spark_utils.py @@ -1,5 +1,5 @@ from os import environ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -12,10 +12,89 @@ schema_struct_to_schema_str, show_string, ) +from koheesio.spark.utils.common import ( + check_if_pyspark_connect_is_supported, + get_active_session, + get_spark_minor_version, +) + + +class TestGetActiveSession: + def test_unhappy_get_active_session_spark_connect(self): + """Test that get_active_session raises an error when no active session is found when using spark connect.""" + with ( + # ensure that we are forcing the code to think that we are using spark connect + patch( + "koheesio.spark.utils.common.check_if_pyspark_connect_is_supported", + return_value=True, + ), + # make sure that spark session is not found + patch("pyspark.sql.SparkSession.getActiveSession", return_value=None), + ): + session = MagicMock( + SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None)) + ) + with patch.dict("sys.modules", {"pyspark.sql.connect": session}): + with pytest.raises( + RuntimeError, + match="No active Spark session found. Please create a Spark session before using module " + "connect_utils. Or perform local import of the module.", + ): + get_active_session() + + def test_unhappy_get_active_session(self): + """Test that get_active_session raises an error when no active session is found.""" + with ( + patch( + "koheesio.spark.utils.common.check_if_pyspark_connect_is_supported", + return_value=False, + ), + patch("pyspark.sql.SparkSession.getActiveSession", return_value=None), + ): + with pytest.raises( + RuntimeError, + match="No active Spark session found. Please create a Spark session before using module connect_utils. 
" + "Or perform local import of the module.", + ): + get_active_session() + + def test_get_active_session_with_spark(self, spark): + """Test get_active_session when an active session is found""" + session = get_active_session() + assert session is not None + + +class TestCheckIfPysparkConnectIsSupported: + def test_if_pyspark_connect_is_not_supported(self): + """Test that check_if_pyspark_connect_is_supported returns False when pyspark connect is not supported.""" + with patch.dict("sys.modules", {"pyspark.sql.connect": None}): + assert check_if_pyspark_connect_is_supported() is False + + def test_check_if_pyspark_connect_is_supported(self): + """Test that check_if_pyspark_connect_is_supported returns True when pyspark connect is supported.""" + with ( + patch("koheesio.spark.utils.common.SPARK_MINOR_VERSION", 3.5), + patch.dict( + "sys.modules", + { + "pyspark.sql.connect.column": MagicMock(Column=MagicMock()), + "pyspark.sql.connect": MagicMock(), + }, + ), + ): + assert check_if_pyspark_connect_is_supported() is True + + +def test_get_spark_minor_version(): + """Test that get_spark_minor_version returns the correctly formatted version.""" + with patch("koheesio.spark.utils.common.spark_version", "9.9.42"): + assert get_spark_minor_version() == 9.9 def test_schema_struct_to_schema_str(): - struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) + struct_schema = StructType( + [StructField("a", StringType()), StructField("b", StringType())] + ) val = schema_struct_to_schema_str(struct_schema) assert val == "a STRING,\nb STRING" assert schema_struct_to_schema_str(None) == "" @@ -40,12 +119,21 @@ def test_on_databricks(env_var_value, expected_result): (3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error (3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error (3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error - (3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error + ( + 3.4, + "1.2.3", + ImportError, + ), # PySpark not 3.3, pandas < 2, should raise an error ], ) -def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): +def test_import_pandas_based_on_pyspark_version( + spark_version, pandas_version, expected_error +): with ( - patch("koheesio.spark.utils.common.get_spark_minor_version", return_value=spark_version), + patch( + "koheesio.spark.utils.common.get_spark_minor_version", + return_value=spark_version, + ), patch("pandas.__version__", new=pandas_version), ): if expected_error: From 1e645d80e6e9dee65b5221356c07bf54bec6ef0e Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:29:06 +0100 Subject: [PATCH 08/33] quick fix --- src/koheesio/__about__.py | 2 +- .../transformations/date_time/interval.py | 1 - src/koheesio/spark/utils/common.py | 28 +++++-------------- tests/spark/test_spark_utils.py | 12 ++------ 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 64285c72..740b6265 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc1" +__version__ = "0.9.0rc2" __logo__ = ( 75, ( diff --git a/src/koheesio/spark/transformations/date_time/interval.py b/src/koheesio/spark/transformations/date_time/interval.py index e56bd548..1904a187 
100644 --- a/src/koheesio/spark/transformations/date_time/interval.py +++ b/src/koheesio/spark/transformations/date_time/interval.py @@ -147,7 +147,6 @@ def __add__(self, value: str) -> Column: A valid value is a string that can be parsed by the `interval` function in Spark SQL. See https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html#interval-literal """ - print(f"__add__: {value = }") return adjust_time(self, operation="add", interval=value) def __sub__(self, value: str) -> Column: diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index a099fd47..3d0d79ef 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -119,13 +119,9 @@ def check_if_pyspark_connect_is_supported() -> bool: ParseException = (CapturedParseException, ConnectParseException) DataType = Union[SqlDataType, ConnectDataType] DataFrameReader = Union[sql.readwriter.DataFrameReader, DataFrameReader] - DataStreamReader = Union[ - sql.streaming.readwriter.DataStreamReader, DataStreamReader - ] + DataStreamReader = Union[sql.streaming.readwriter.DataStreamReader, DataStreamReader] DataFrameWriter = Union[sql.readwriter.DataFrameWriter, DataFrameWriter] - DataStreamWriter = Union[ - sql.streaming.readwriter.DataStreamWriter, DataStreamWriter - ] + DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter] StreamingQuery = StreamingQuery else: """Import the regular PySpark modules if the current version of PySpark does not support the connect module""" @@ -160,9 +156,8 @@ def check_if_pyspark_connect_is_supported() -> bool: def get_active_session() -> SparkSession: # type: ignore """Get the active Spark session""" - print("Entering get_active_session") if check_if_pyspark_connect_is_supported(): - from pyspark.sql.connect import SparkSession as _ConnectSparkSession + from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession session = _ConnectSparkSession.getActiveSession() or sql.SparkSession.getActiveSession() # type: ignore else: @@ -297,18 +292,14 @@ def spark_data_type_is_array(data_type: DataType) -> bool: # type: ignore def spark_data_type_is_numeric(data_type: DataType) -> bool: # type: ignore """Check if the column's dataType is of type ArrayType""" - return isinstance( - data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType) - ) + return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)) def schema_struct_to_schema_str(schema: StructType) -> str: """Converts a StructType to a schema str""" if not schema: return "" - return ",\n".join( - [f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields] - ) + return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) def import_pandas_based_on_pyspark_version() -> ModuleType: @@ -323,9 +314,7 @@ def import_pandas_based_on_pyspark_version() -> ModuleType: pyspark_version = get_spark_minor_version() pandas_version = pd.__version__ - if (pyspark_version < 3.4 and pandas_version >= "2") or ( - pyspark_version >= 3.4 and pandas_version < "2" - ): + if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): raise ImportError( f"For PySpark {pyspark_version}, " f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" @@ -390,10 +379,7 @@ def get_column_name(col: Column) -> str: # type: ignore # In case of a 'regular' Column object, we can directly access the name attribute through 
the _jc attribute # noinspection PyProtectedMember name = col._jc.toString() # type: ignore[operator] - elif any( - cls.__module__ == "pyspark.sql.connect.column" - for cls in inspect.getmro(col.__class__) - ): + elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)): # noinspection PyProtectedMember name = col._expr.name() else: diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py index ab1b608f..db76351d 100644 --- a/tests/spark/test_spark_utils.py +++ b/tests/spark/test_spark_utils.py @@ -31,9 +31,7 @@ def test_unhappy_get_active_session_spark_connect(self): # make sure that spark session is not found patch("pyspark.sql.SparkSession.getActiveSession", return_value=None), ): - session = MagicMock( - SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None)) - ) + session = MagicMock(SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None))) with patch.dict("sys.modules", {"pyspark.sql.connect": session}): with pytest.raises( RuntimeError, @@ -92,9 +90,7 @@ def test_get_spark_minor_version(): def test_schema_struct_to_schema_str(): - struct_schema = StructType( - [StructField("a", StringType()), StructField("b", StringType())] - ) + struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) val = schema_struct_to_schema_str(struct_schema) assert val == "a STRING,\nb STRING" assert schema_struct_to_schema_str(None) == "" @@ -126,9 +122,7 @@ def test_on_databricks(env_var_value, expected_result): ), # PySpark not 3.3, pandas < 2, should raise an error ], ) -def test_import_pandas_based_on_pyspark_version( - spark_version, pandas_version, expected_error -): +def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): with ( patch( "koheesio.spark.utils.common.get_spark_minor_version", From abb435b409b620ae71dd868a5647256eeb4194b4 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:54:02 +0100 Subject: [PATCH 09/33] quick fix #2 --- src/koheesio/spark/utils/common.py | 10 ++++------ tests/spark/test_spark_utils.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/koheesio/spark/utils/common.py b/src/koheesio/spark/utils/common.py index 3d0d79ef..de918a7c 100644 --- a/src/koheesio/spark/utils/common.py +++ b/src/koheesio/spark/utils/common.py @@ -79,18 +79,16 @@ def get_spark_minor_version() -> float: def check_if_pyspark_connect_is_supported() -> bool: """Check if the current version of PySpark supports the connect module""" - result = False - if SPARK_MINOR_VERSION >= 3.5: try: importlib.import_module("pyspark.sql.connect") from pyspark.sql.connect.column import Column - _col: Column - result = True + _col: Column # type: ignore + return True except (ModuleNotFoundError, ImportError): - result = False - return result + return False + return False if check_if_pyspark_connect_is_supported(): diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py index db76351d..dfb4f9b5 100644 --- a/tests/spark/test_spark_utils.py +++ b/tests/spark/test_spark_utils.py @@ -32,7 +32,7 @@ def test_unhappy_get_active_session_spark_connect(self): patch("pyspark.sql.SparkSession.getActiveSession", return_value=None), ): session = MagicMock(SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None))) - with patch.dict("sys.modules", {"pyspark.sql.connect": session}): + with patch.dict("sys.modules", {"pyspark.sql.connect.session": 
session}): with pytest.raises( RuntimeError, match="No active Spark session found. Please create a Spark session before using module " From fb54aefeee1d2f3278c5c7ff941f9e49b5486cb3 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Mon, 18 Nov 2024 18:47:49 +0100 Subject: [PATCH 10/33] [FIX] Accidental duplication of logs (#105) ## Description Enhanced the Step class and its metaclass StepMetaClass to ensure that the execute method is wrapped only once, even when inherited multiple times. Added tests to verify that the log and wrapper are called only once in such scenarios. Also, included custom metaclass examples and their respective tests. ## Related Issue N/A ## Motivation and Context It was established that (under the right conditions) our code allowed the execute method to be wrapped multiple times, even though we had code in place to prevent that. This reflected outwardly by seeing log duplication. For example like this: #### Wrong output: ```bash [1111] [2024-11-18 14:22:18,355] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:307} - Start running step [1111] [2024-11-18 14:22:18,355] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:308} - Step Input: name='MyCustomGreatGrandChildStep' description='MyCustomGreatGrandChildStep' foo='foo' bar='bar' qux='qux' [1111] [2024-11-18 14:22:18,355] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:307} - Start running step [1111] [2024-11-18 14:22:18,355] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:308} - Step Input: name='MyCustomGreatGrandChildStep' description='MyCustomGreatGrandChildStep' foo='foo' bar='bar' qux='qux' [1111] [2024-11-18 14:22:18,355] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:307} - Start running step [1111] [2024-11-18 14:22:18,355] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:308} - Step Input: name='MyCustomGreatGrandChildStep' description='MyCustomGreatGrandChildStep' foo='foo' bar='bar' qux='qux' [1111] [2024-11-18 14:22:18,356] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:329} - Step Output: name='MyCustomGreatGrandChildStep.Output' description='Output for MyCustomGreatGrandChildStep' baz='foobar' [1111] [2024-11-18 14:22:18,356] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:330} - Finished running step [1111] [2024-11-18 14:22:18,356] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:329} - Step Output: name='MyCustomGreatGrandChildStep.Output' description='Output for MyCustomGreatGrandChildStep' baz='foobar' [1111] [2024-11-18 14:22:18,356] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:330} - Finished running step [1111] [2024-11-18 14:22:18,356] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:329} - Step Output: name='MyCustomGreatGrandChildStep.Output' description='Output for MyCustomGreatGrandChildStep' baz='foobar' [1111] [2024-11-18 14:22:18,356] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:330} - Finished running step ``` #### Expected output: ```bash [1111] [2024-11-18 14:22:18,355] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:307} - Start running step [1111] [2024-11-18 14:22:18,355] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_start_message:308} - Step Input: 
name='MyCustomGreatGrandChildStep' description='MyCustomGreatGrandChildStep' foo='foo' bar='bar' qux='qux' [1111] [2024-11-18 14:22:18,356] [DEBUG] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:329} - Step Output: name='MyCustomGreatGrandChildStep.Output' description='Output for MyCustomGreatGrandChildStep' baz='foobar' [1111] [2024-11-18 14:22:18,356] [INFO] [koheesio.MyCustomGreatGrandChildStep] {__init__.py:_log_end_message:330} - Finished running step ``` ## How Has This Been Tested? Tested by adding unit tests in tests/steps/test_steps.py to verify that the execute method is wrapped only once and that the log messages are called in the correct order. The tests also cover custom metaclass functionality and ensure that the output validation and logging work as expected. ## Screenshots (if appropriate): ... ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. --------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- src/koheesio/asyncio/__init__.py | 4 +- src/koheesio/integrations/box.py | 8 +-- .../integrations/snowflake/test_utils.py | 4 +- .../integrations/spark/tableau/hyper.py | 24 ++----- src/koheesio/models/__init__.py | 8 +-- src/koheesio/spark/readers/file_loader.py | 4 +- src/koheesio/spark/readers/rest_api.py | 8 +-- .../spark/transformations/__init__.py | 4 +- .../spark/transformations/camel_to_snake.py | 4 +- .../transformations/date_time/interval.py | 8 +-- src/koheesio/spark/transformations/lookup.py | 4 +- .../transformations/strings/change_case.py | 12 +--- .../spark/transformations/strings/split.py | 8 +-- .../spark/transformations/strings/trim.py | 8 +-- src/koheesio/spark/transformations/uuid5.py | 4 +- src/koheesio/spark/writers/file_writer.py | 16 ++--- src/koheesio/steps/__init__.py | 70 +++++++++--------- tests/steps/test_steps.py | 72 ++++++++++++++++--- 18 files changed, 131 insertions(+), 139 deletions(-) diff --git a/src/koheesio/asyncio/__init__.py b/src/koheesio/asyncio/__init__.py index 3dc63fea..048ab979 100644 --- a/src/koheesio/asyncio/__init__.py +++ b/src/koheesio/asyncio/__init__.py @@ -67,9 +67,7 @@ def merge(self, other: Union[Dict, StepOutput]) -> "AsyncStepOutput": -------- ```python step_output = StepOutput(foo="bar") - step_output.merge( - {"lorem": "ipsum"} - ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Functionally similar to adding two dicts together; like running `{**dict_a, **dict_b}`. 
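The Box hunk that follows touches the `file_encoding` option added to the Box reader earlier in this series. As a usage sketch only — the auth parameters are elided the same way the module's own examples elide them, and the file ID is a made-up placeholder:

```python
# Illustrative sketch, not part of the patch: read a latin-1 encoded CSV from Box.
from koheesio.integrations.box import BoxCsvFileReader

auth_params = {...}  # placeholder for the Box JWT auth settings, as in the module's own examples

result = BoxCsvFileReader(
    **auth_params,
    file=["123456789"],       # Box file ID(s); hypothetical value
    file_encoding="latin-1",  # overrides the utf-8 default introduced by this series
).execute()
# As with the other readers in this package, the resulting Output carries the parsed Spark DataFrame.
```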
diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index 9124b9a7..6d2d055e 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -619,16 +619,12 @@ class BoxFileWriter(BoxFolderBase): from koheesio.steps.integrations.box import BoxFileWriter auth_params = {...} - f1 = BoxFileWriter( - **auth_params, path="/foo/bar", file="path/to/my/file.ext" - ).execute() + f1 = BoxFileWriter(**auth_params, path="/foo/bar", file="path/to/my/file.ext").execute() # or import io b = io.BytesIO(b"my-sample-data") - f2 = BoxFileWriter( - **auth_params, path="/foo/bar", file=b, name="file.ext" - ).execute() + f2 = BoxFileWriter(**auth_params, path="/foo/bar", file=b, name="file.ext").execute() ``` """ diff --git a/src/koheesio/integrations/snowflake/test_utils.py b/src/koheesio/integrations/snowflake/test_utils.py index 8ae9ac3e..8b85e97d 100644 --- a/src/koheesio/integrations/snowflake/test_utils.py +++ b/src/koheesio/integrations/snowflake/test_utils.py @@ -25,9 +25,7 @@ def test_execute(self, mock_query): mock_query.expected_data = [("row1",), ("row2",)] # Act - instance = SnowflakeRunQueryPython( - **COMMON_OPTIONS, query=query, account="42" - ) + instance = SnowflakeRunQueryPython(**COMMON_OPTIONS, query=query, account="42") instance.execute() # Assert diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py index 94230d87..c7df9be0 100644 --- a/src/koheesio/integrations/spark/tableau/hyper.py +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -199,15 +199,9 @@ class HyperFileListWriter(HyperFileWriter): table_definition=TableDefinition( table_name=TableName("Extract", "Extract"), columns=[ - TableDefinition.Column( - name="string", type=SqlType.text(), nullability=NOT_NULLABLE - ), - TableDefinition.Column( - name="int", type=SqlType.int(), nullability=NULLABLE - ), - TableDefinition.Column( - name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE - ), + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), ], ), data=[ @@ -261,15 +255,9 @@ class HyperFileParquetWriter(HyperFileWriter): table_definition=TableDefinition( table_name=TableName("Extract", "Extract"), columns=[ - TableDefinition.Column( - name="string", type=SqlType.text(), nullability=NOT_NULLABLE - ), - TableDefinition.Column( - name="int", type=SqlType.int(), nullability=NULLABLE - ), - TableDefinition.Column( - name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE - ), + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE), + TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE), ], ), files=[ diff --git a/src/koheesio/models/__init__.py b/src/koheesio/models/__init__.py index a2db492f..89fa7af8 100644 --- a/src/koheesio/models/__init__.py +++ b/src/koheesio/models/__init__.py @@ -407,9 +407,7 @@ def __add__(self, other: Union[Dict, BaseModel]) -> BaseModel: ```python step_output_1 = StepOutput(foo="bar") step_output_2 = StepOutput(lorem="ipsum") - ( - step_output_1 + step_output_2 - ) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} + (step_output_1 + step_output_2) # step_output_1 will now contain {'foo': 'bar', 'lorem': 'ipsum'} 
``` Parameters @@ -533,9 +531,7 @@ def merge(self, other: Union[Dict, BaseModel]) -> BaseModel: -------- ```python step_output = StepOutput(foo="bar") - step_output.merge( - {"lorem": "ipsum"} - ) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} + step_output.merge({"lorem": "ipsum"}) # step_output will now contain {'foo': 'bar', 'lorem': 'ipsum'} ``` Parameters diff --git a/src/koheesio/spark/readers/file_loader.py b/src/koheesio/spark/readers/file_loader.py index 2bb3cd8b..6ccd5de8 100644 --- a/src/koheesio/spark/readers/file_loader.py +++ b/src/koheesio/spark/readers/file_loader.py @@ -80,9 +80,7 @@ class FileLoader(Reader, ExtraParamsMixin): Example: ```python - reader = FileLoader( - path="path/to/textfile.txt", format="text", header=True, lineSep="\n" - ) + reader = FileLoader(path="path/to/textfile.txt", format="text", header=True, lineSep="\n") ``` For more information about the available options, see Spark's diff --git a/src/koheesio/spark/readers/rest_api.py b/src/koheesio/spark/readers/rest_api.py index bad9036c..45b5fbbe 100644 --- a/src/koheesio/spark/readers/rest_api.py +++ b/src/koheesio/spark/readers/rest_api.py @@ -70,9 +70,7 @@ class RestApiReader(Reader): pages=3, session=session, ) - task = RestApiReader( - transport=transport, spark_schema="id: int, page:int, value: string" - ) + task = RestApiReader(transport=transport, spark_schema="id: int, page:int, value: string") task.execute() all_data = [row.asDict() for row in task.output.df.collect()] ``` @@ -97,9 +95,7 @@ class RestApiReader(Reader): connector=connector, ) - task = RestApiReader( - transport=transport, spark_schema="id: int, page:int, value: string" - ) + task = RestApiReader(transport=transport, spark_schema="id: int, page:int, value: string") task.execute() all_data = [row.asDict() for row in task.output.df.collect()] ``` diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index c44a9981..3f273a85 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -56,9 +56,7 @@ class Transformation(SparkStep, ABC): class AddOne(Transformation): def execute(self): - self.output.df = self.df.withColumn( - "new_column", f.col("old_column") + 1 - ) + self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the diff --git a/src/koheesio/spark/transformations/camel_to_snake.py b/src/koheesio/spark/transformations/camel_to_snake.py index f62c822b..e8198639 100644 --- a/src/koheesio/spark/transformations/camel_to_snake.py +++ b/src/koheesio/spark/transformations/camel_to_snake.py @@ -48,9 +48,7 @@ class CamelToSnakeTransformation(ColumnsTransformation): | ... | ... 
| ```python - output_df = CamelToSnakeTransformation(column="camelCaseColumn").transform( - input_df - ) + output_df = CamelToSnakeTransformation(column="camelCaseColumn").transform(input_df) ``` __output_df:__ diff --git a/src/koheesio/spark/transformations/date_time/interval.py b/src/koheesio/spark/transformations/date_time/interval.py index 1904a187..df7d1f1b 100644 --- a/src/koheesio/spark/transformations/date_time/interval.py +++ b/src/koheesio/spark/transformations/date_time/interval.py @@ -102,14 +102,10 @@ DateTimeAddInterval, ) -input_df = spark.createDataFrame( - [(1, "2022-01-01 00:00:00")], ["id", "my_column"] -) +input_df = spark.createDataFrame([(1, "2022-01-01 00:00:00")], ["id", "my_column"]) # add 1 day to my_column and store the result in a new column called 'one_day_later' -output_df = DateTimeAddInterval( - column="my_column", target_column="one_day_later", interval="1 day" -).transform(input_df) +output_df = DateTimeAddInterval(column="my_column", target_column="one_day_later", interval="1 day").transform(input_df) ``` __output_df__: diff --git a/src/koheesio/spark/transformations/lookup.py b/src/koheesio/spark/transformations/lookup.py index 73292ec8..31b35cb3 100644 --- a/src/koheesio/spark/transformations/lookup.py +++ b/src/koheesio/spark/transformations/lookup.py @@ -103,9 +103,7 @@ class DataframeLookup(Transformation): df=left_df, other=right_df, on=JoinMapping(source_column="id", joined_column="id"), - targets=TargetColumn( - target_column="value", target_column_alias="right_value" - ), + targets=TargetColumn(target_column="value", target_column_alias="right_value"), how=JoinType.LEFT, ) diff --git a/src/koheesio/spark/transformations/strings/change_case.py b/src/koheesio/spark/transformations/strings/change_case.py index 3906b352..20a1ce60 100644 --- a/src/koheesio/spark/transformations/strings/change_case.py +++ b/src/koheesio/spark/transformations/strings/change_case.py @@ -51,9 +51,7 @@ class LowerCase(ColumnsTransformationWithTarget): | Beans| 1600| USA| ```python - output_df = LowerCase( - column="product", target_column="product_lower" - ).transform(df) + output_df = LowerCase(column="product", target_column="product_lower").transform(df) ``` __output_df:__ @@ -109,9 +107,7 @@ class UpperCase(LowerCase): | Beans| 1600| USA| ```python - output_df = UpperCase( - column="product", target_column="product_upper" - ).transform(df) + output_df = UpperCase(column="product", target_column="product_upper").transform(df) ``` __output_df:__ @@ -162,9 +158,7 @@ class TitleCase(LowerCase): | Beans| 1600| USA| ```python - output_df = TitleCase( - column="product", target_column="product_title" - ).transform(df) + output_df = TitleCase(column="product", target_column="product_title").transform(df) ``` __output_df:__ diff --git a/src/koheesio/spark/transformations/strings/split.py b/src/koheesio/spark/transformations/strings/split.py index 0c71d370..6c54586d 100644 --- a/src/koheesio/spark/transformations/strings/split.py +++ b/src/koheesio/spark/transformations/strings/split.py @@ -51,9 +51,7 @@ class SplitAll(ColumnsTransformationWithTarget): | Beans| 1600| USA| ```python - output_df = SplitColumn( - column="product", target_column="split", split_pattern=" " - ).transform(input_df) + output_df = SplitColumn(column="product", target_column="split", split_pattern=" ").transform(input_df) ``` __output_df:__ @@ -109,9 +107,7 @@ class SplitAtFirstMatch(SplitAll): | Beans| 1600| USA| ```python - output_df = SplitColumn( - column="product", target_column="split_first", 
split_pattern="an" - ).transform(input_df) + output_df = SplitColumn(column="product", target_column="split_first", split_pattern="an").transform(input_df) ``` __output_df:__ diff --git a/src/koheesio/spark/transformations/strings/trim.py b/src/koheesio/spark/transformations/strings/trim.py index 6f72eae8..4a7a8529 100644 --- a/src/koheesio/spark/transformations/strings/trim.py +++ b/src/koheesio/spark/transformations/strings/trim.py @@ -57,9 +57,7 @@ class Trim(ColumnsTransformationWithTarget): ### Trim whitespace from the beginning of a string ```python - output_df = Trim( - column="column", target_column="trimmed_column", direction="left" - ).transform(input_df) + output_df = Trim(column="column", target_column="trimmed_column", direction="left").transform(input_df) ``` __output_df:__ @@ -86,9 +84,7 @@ class Trim(ColumnsTransformationWithTarget): ### Trim whitespace from the end of a string ```python - output_df = Trim( - column="column", target_column="trimmed_column", direction="right" - ).transform(input_df) + output_df = Trim(column="column", target_column="trimmed_column", direction="right").transform(input_df) ``` __output_df:__ diff --git a/src/koheesio/spark/transformations/uuid5.py b/src/koheesio/spark/transformations/uuid5.py index cd709a86..793cb210 100644 --- a/src/koheesio/spark/transformations/uuid5.py +++ b/src/koheesio/spark/transformations/uuid5.py @@ -110,9 +110,7 @@ class HashUUID5(Transformation): In code: ```python - HashUUID5(source_columns=["id", "string"], target_column="uuid5").transform( - input_df - ) + HashUUID5(source_columns=["id", "string"], target_column="uuid5").transform(input_df) ``` In this example, the `id` and `string` columns are concatenated and hashed using the UUID5 algorithm. The result is diff --git a/src/koheesio/spark/writers/file_writer.py b/src/koheesio/spark/writers/file_writer.py index 9e79d562..a3b63ee6 100644 --- a/src/koheesio/spark/writers/file_writer.py +++ b/src/koheesio/spark/writers/file_writer.py @@ -141,9 +141,7 @@ class AvroFileWriter(FileWriter): Examples -------- ```python - writer = AvroFileWriter( - df=df, path="path/to/file.avro", output_mode=BatchOutputMode.APPEND - ) + writer = AvroFileWriter(df=df, path="path/to/file.avro", output_mode=BatchOutputMode.APPEND) ``` """ @@ -160,9 +158,7 @@ class JsonFileWriter(FileWriter): Examples -------- ```python - writer = JsonFileWriter( - df=df, path="path/to/file.json", output_mode=BatchOutputMode.APPEND - ) + writer = JsonFileWriter(df=df, path="path/to/file.json", output_mode=BatchOutputMode.APPEND) ``` """ @@ -179,9 +175,7 @@ class OrcFileWriter(FileWriter): Examples -------- ```python - writer = OrcFileWriter( - df=df, path="path/to/file.orc", output_mode=BatchOutputMode.APPEND - ) + writer = OrcFileWriter(df=df, path="path/to/file.orc", output_mode=BatchOutputMode.APPEND) ``` """ @@ -198,9 +192,7 @@ class TextFileWriter(FileWriter): Examples -------- ```python - writer = TextFileWriter( - df=df, path="path/to/file.txt", output_mode=BatchOutputMode.APPEND - ) + writer = TextFileWriter(df=df, path="path/to/file.txt", output_mode=BatchOutputMode.APPEND) ``` """ diff --git a/src/koheesio/steps/__init__.py b/src/koheesio/steps/__init__.py index 5a1faa7c..d879a6e1 100644 --- a/src/koheesio/steps/__init__.py +++ b/src/koheesio/steps/__init__.py @@ -72,12 +72,13 @@ class StepMetaClass(ModelMetaclass): allowing for the execute method to be auto-decorated with do_execute """ - # Solution to overcome issue with python>=3.11, - # When partialmethod is forgetting that _execute_wrapper - # 
is a method of wrapper, and it needs to pass that in as the first arg. - # https://github.com/python/cpython/issues/99152 # noinspection PyPep8Naming,PyUnresolvedReferences class _partialmethod_with_self(partialmethod): + """Solution to overcome issue with python>=3.11, when partialmethod is forgetting that _execute_wrapper is a + method of wrapper, and it needs to pass that in as the first arg. + See: https://github.com/python/cpython/issues/99152 + """ + def __get__(self, obj: Any, cls=None): # type: ignore[no-untyped-def] return self._make_unbound_method().__get__(obj, cls) @@ -134,44 +135,43 @@ def __new__( **kwargs, ) - # Extract execute method present in the class - execute_method = getattr(cls, "execute") + # Traverse the MRO to find the first occurrence of the execute method + execute_method = None + for base in cls.__mro__: + if "execute" in base.__dict__: + execute_method = base.__dict__["execute"] + break - # check if function is already wrapped with do_execute - # Here we are trying to get the attribute "_partialmethod__step_execute_wrapper_sentinel" - # from the execute_method function. - # If the function is already wrapped, this attribute should exist. - sentinel = getattr(execute_method, "_partialmethod__step_execute_wrapper_sentinel", None) + if execute_method: + # Check if the execute method is already wrapped + is_already_wrapped = ( + getattr(execute_method, "_step_execute_wrapper_sentinel", None) is cls._step_execute_wrapper_sentinel + ) - # Check if the sentinel is the same as the class's sentinel. If they are the same, - # it means the function is already wrapped. - # noinspection PyUnresolvedReferences - is_already_wrapped = sentinel is cls._step_execute_wrapper_sentinel + # Get the wrap count of the function. If the function is not wrapped yet, the default value is 0. + wrap_count = getattr(execute_method, "_wrap_count", 0) - # Get the wrap count of the function. If the function is not wrapped yet, the default value is 0. - wrap_count = getattr(execute_method, "_partialmethod_wrap_count", 0) + # prevent multiple wrapping + # If the function is not already wrapped, we proceed to wrap it. + if not is_already_wrapped: + # Create a partial method with the execute_method as one of the arguments. + # This is the new function that will be called instead of the original execute_method. - # prevent multiple wrapping - # If the function is not already wrapped, we proceed to wrap it. - if not is_already_wrapped: - # Create a partial method with the execute_method as one of the arguments. - # This is the new function that will be called instead of the original execute_method. + # noinspection PyProtectedMember,PyUnresolvedReferences + wrapper = mcs._partialmethod_impl(cls=cls, execute_method=execute_method) - # noinspection PyProtectedMember,PyUnresolvedReferences - wrapper = mcs._partialmethod_impl(cls=cls, execute_method=execute_method) + # Updating the attributes of the wrapping function to those of the original function. + wraps(execute_method)(wrapper) # type: ignore - # Updating the attributes of the wrapping function to those of the original function. - wraps(execute_method)(wrapper) # type: ignore - - # Set the sentinel attribute to the wrapper. This is done so that we can check - # if the function is already wrapped. - # noinspection PyUnresolvedReferences - setattr(wrapper, "_step_execute_wrapper_sentinel", cls._step_execute_wrapper_sentinel) + # Set the sentinel attribute to the wrapper. This is done so that we can check + # if the function is already wrapped. 
+ # noinspection PyUnresolvedReferences + setattr(wrapper, "_step_execute_wrapper_sentinel", cls._step_execute_wrapper_sentinel) - # Increase the wrap count of the function. This is done to keep track of - # how many times the function has been wrapped. - setattr(wrapper, "_wrap_count", wrap_count + 1) - setattr(cls, "execute", wrapper) + # Increase the wrap count of the function. This is done to keep track of + # how many times the function has been wrapped. + setattr(wrapper, "_wrap_count", wrap_count + 1) + setattr(cls, "execute", wrapper) return cls diff --git a/tests/steps/test_steps.py b/tests/steps/test_steps.py index 484f3b15..7b239efb 100644 --- a/tests/steps/test_steps.py +++ b/tests/steps/test_steps.py @@ -31,7 +31,10 @@ class TestStepOutput: - @pytest.mark.parametrize("output_dict, expected", [(output_dict_1, output_dict_1), (output_dict_2, output_dict_2)]) + @pytest.mark.parametrize( + "output_dict, expected", + [(output_dict_1, output_dict_1), (output_dict_2, output_dict_2)], + ) def test_stepoutput_validate_output(self, output_dict, expected): """Tests that validate_output returns the expected output dict""" test_output = DummyOutput(**output_dict) @@ -59,7 +62,12 @@ def test_stepoutput_unhappy_flow(self, output_dict): ( "foo", 42, - {"a": "foo", "b": 42, "name": "DummyOutput", "description": "Dummy output for testing purposes."}, + { + "a": "foo", + "b": 42, + "name": "DummyOutput", + "description": "Dummy output for testing purposes.", + }, ), # test wrong type assigned ("foo", "invalid type", ValidationError), @@ -82,7 +90,6 @@ def test_stepoutput_lazy(self, a, b, expected): lazy_output.validate_output() else: actual = lazy_output.validate_output().model_dump() - print(f"{actual=}") assert actual == expected @pytest.mark.parametrize("attribute, expected", [("a", True), ("d", False)]) @@ -159,7 +166,9 @@ def execute(self) -> Output: assert step.model_dump() == dict(a="foo", description="SimpleStep", name="SimpleStep") assert step.execute().model_dump() == dict( - b="foo-some-suffix", name="SimpleStep.Output", description="Output for SimpleStep" + b="foo-some-suffix", + name="SimpleStep.Output", + description="Output for SimpleStep", ) # as long as the following doesn't raise an error, we're good @@ -180,9 +189,7 @@ def test_step_execute_and_run(self): b=2, c="foofoo", ) - print(f"{actual_execute=}") - print(f"{actual_run=}") - print(f"{expected=}") + assert actual_execute == actual_run == expected # 3 ways to retrieve output @@ -283,7 +290,11 @@ def execute(self): # Check that do_execute was not called multiple times assert ( - getattr(getattr(obj.execute, "_partialmethod", None), "_step_execute_wrapper_sentinel", None) + getattr( + getattr(obj.execute, "_partialmethod", None), + "_step_execute_wrapper_sentinel", + None, + ) is StepMetaClass._step_execute_wrapper_sentinel ) assert getattr(getattr(obj.execute, "_partialmethod", None), "_wrap_count", 0) == 1 @@ -345,6 +356,51 @@ def test_custom_metaclass_log(self, test_class): assert "It's me from custom meta class" in print_output + def test_log_and_wrapper_duplication(self): + """ + Test that when a step is inherited multiple times, the log and wrapper are only called once when subclasses do + not have explicit execute methods set. 
+ """ + + class MyCustomParentStep(Step): + foo: str = Field(default=..., description="Foo") + bar: str = Field(default=..., description="Bar") + + class Output(Step.Output): + baz: str = Field(default=..., description="Baz") + + def execute(self) -> Output: + self.log.error("This should not be logged") + self.output.baz = self.foo + self.bar + + class MyCustomChildStep(MyCustomParentStep): ... + + class MyCustomGrandChildStep(MyCustomChildStep): + def execute(self) -> MyCustomChildStep.Output: + self.output.baz = self.foo + self.bar + + class MyCustomGreatGrandChildStep(MyCustomGrandChildStep): ... + + with ( + patch.object(MyCustomGreatGrandChildStep, "log", autospec=True) as mock_log, + ): + obj = MyCustomGreatGrandChildStep(foo="foo", bar="bar", qux="qux") + obj.execute() + + name = MyCustomGreatGrandChildStep.__name__ + + # Check that logs were called once (and only once) with the correct messages, and in the correct order + calls = [ + call.info("Start running step"), + call.debug(f"Step Input: name='{name}' description='{name}' foo='foo' bar='bar' qux='qux'"), + call.debug(f"Step Output: name='{name}.Output' description='Output for {name}' baz='foobar'"), + call.info("Finished running step"), + ] + mock_log.assert_has_calls(calls, any_order=False) + + # Check that the execute method is only wrapped once + assert getattr(getattr(obj.execute, "_partialmethod", None), "_wrap_count", 0) == 1 + @pytest.mark.parametrize("test_class", [YourClassWithCustomMeta2]) def test_custom_metaclass_output(self, test_class): with patch.object(test_class, "log", autospec=True) as mock_log: From 93f413ebf0c2c4e00eebedec23c145046cc48c54 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:19:45 +0100 Subject: [PATCH 11/33] version bump --- src/koheesio/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 740b6265..8ebac8d2 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc2" +__version__ = "0.9.0rc3" __logo__ = ( 75, ( From 5146f591f231f15ce690d35f216d0ff802e53ae7 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:33:54 +0100 Subject: [PATCH 12/33] fix: adjust branch fetching (#106) ## Description Fix github action for fetching correct branch in case of fork ## How Has This Been Tested? Github action run ## Types of changes - [x ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [x] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. 
--- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1d6e640b..c48edcb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,7 +42,7 @@ jobs: ref: ${{ github.event.pull_request.head.ref }} repository: ${{ github.event.pull_request.head.repo.full_name }} - name: Fetch target branch - run: git fetch origin ${{ github.event.pull_request.base.ref || 'main'}}:${{ github.event.pull_request.base.ref || 'main'}} + run: git fetch origin ${{ github.event.pull_request.head.ref || 'main'}}:${{ github.event.pull_request.base.ref || 'main'}} - name: Check changes id: check run: | From 9f32fcc8a831694956c353cea796032c7b01e614 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 19 Nov 2024 12:23:43 +0100 Subject: [PATCH 13/33] [FIX] broken import statements and updated hello-world.md (#107) Fix broken import statements and updated `hello-world.md` ## Description This pull request addresses documentation errors in the `hello-world.md` (spark) example as part of fixing issue [#104](https://github.com/Nike-Inc/koheesio/issues/104). - The changes involve updating the hello_world.md file; a paragraph was added explaining the need of setting up a SparkSession prior to invoking - Additionally, any other import errors that were present were addressed ## Related Issue #104 ## Motivation and Context N/A ## How Has This Been Tested? N/A ## Screenshots (if appropriate): N/A ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. --------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- README.md | 12 ++++--- docs/reference/spark/transformations.md | 12 +++---- docs/tutorials/advanced-data-processing.md | 6 ++-- docs/tutorials/hello-world.md | 38 ++++++++++++++------ docs/tutorials/testing-koheesio-steps.md | 32 +++++++++-------- src/koheesio/spark/transformations/lookup.py | 8 +++++ 6 files changed, 70 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 41e7f0a5..75694533 100644 --- a/README.md +++ b/README.md @@ -194,23 +194,27 @@ the `pyproject.toml` entry mentioned above or installing through pip. ### Integrations - __Spark Expectations:__ - Available through the `koheesio.steps.integration.spark.dq.spark_expectations` module; installable through the `se` extra. + Available through the `koheesio.integrations.spark.dq.spark_expectations` module; installable through the `se` extra. - SE Provides Data Quality checks for Spark DataFrames. - For more information, refer to the [Spark Expectations docs](https://engineering.nike.com/spark-expectations). -[//]: # (- **Brickflow:** Available through the `koheesio.steps.integration.workflow` module; installable through the `bf` extra.) +[//]: # (- **Brickflow:** Available through the `koheesio.integrations.workflow` module; installable through the `bf` extra.) [//]: # ( - Brickflow is a workflow orchestration tool that allows you to define and execute workflows in a declarative way.) 
[//]: # ( - For more information, refer to the [Brickflow docs](https://engineering.nike.com/brickflow)) - __Box__: - Available through the `koheesio.integration.box` module; installable through the `box` extra. + Available through the `koheesio.integrations.box` module; installable through the `box` extra. - [Box](https://www.box.com) is a cloud content management and file sharing service for businesses. - __SFTP__: - Available through the `koheesio.integration.spark.sftp` module; installable through the `sftp` extra. + Available through the `koheesio.integrations.spark.sftp` module; installable through the `sftp` extra. - SFTP is a network protocol used for secure file transfer over a secure shell. - The SFTP integration of Koheesio relies on [paramiko](https://www.paramiko.org/) +- __Snowflake__: + Available through the `koheesio.integrations.snowflake` module; installable through the `snowflake` extra. + - [Snowflake](https://www.snowflake.com) is a cloud-based data warehousing platform. + [//]: # (TODO: add implementations) [//]: # (## Implementations) [//]: # (TODO: add async extra) diff --git a/docs/reference/spark/transformations.md b/docs/reference/spark/transformations.md index 21363a3b..f6eec994 100644 --- a/docs/reference/spark/transformations.md +++ b/docs/reference/spark/transformations.md @@ -76,7 +76,7 @@ Here's an example of a `ColumnsTransformation`: ```python from pyspark.sql import functions as f -from koheesio.steps.transformations import ColumnsTransformation +from koheesio.spark.transformations import ColumnsTransformation class AddOne(ColumnsTransformation): def execute(self): @@ -109,7 +109,7 @@ Here's an example of a `ColumnsTransformationWithTarget`: ```python from pyspark.sql import Column -from koheesio.steps.transformations import ColumnsTransformationWithTarget +from koheesio.spark.transformations import ColumnsTransformationWithTarget class AddOneWithTarget(ColumnsTransformationWithTarget): def func(self, col: Column): @@ -167,7 +167,7 @@ examples: ```python from pyspark.sql import SparkSession - from koheesio.steps.transformations import DataframeLookup, JoinMapping, TargetColumn, JoinType + from koheesio.spark.transformations.lookup import DataframeLookup, JoinMapping, TargetColumn, JoinType spark = SparkSession.builder.getOrCreate() left_df = spark.createDataFrame([(1, "A"), (2, "B")], ["id", "value"]) @@ -191,7 +191,7 @@ examples: ```python from pyspark.sql import SparkSession - from koheesio.steps.transformations import HashUUID5 + from koheesio.spark.transformations.uuid5 import HashUUID5 spark = SparkSession.builder.getOrCreate() df = spark.createDataFrame([(1, "A"), (2, "B")], ["id", "value"]) @@ -245,8 +245,8 @@ how to chain transformations: ```python from pyspark.sql import SparkSession -from koheesio.steps.transformations import HashUUID5 -from koheesio.steps.transformations import DataframeLookup, JoinMapping, TargetColumn, JoinType +from koheesio.spark.transformations.uuid5 import HashUUID5 +from koheesio.spark.transformations.lookup import DataframeLookup, JoinMapping, TargetColumn, JoinType # Create a SparkSession spark = SparkSession.builder.getOrCreate() diff --git a/docs/tutorials/advanced-data-processing.md b/docs/tutorials/advanced-data-processing.md index 99111048..ac0f95b6 100644 --- a/docs/tutorials/advanced-data-processing.md +++ b/docs/tutorials/advanced-data-processing.md @@ -36,8 +36,8 @@ Partitioning is a technique that divides your data into smaller, more manageable allows you to specify the partitioning scheme for your data when 
writing it to a target. ```python -from koheesio.steps.writers.delta import DeltaTableWriter -from koheesio.tasks.etl_task import EtlTask +from koheesio.spark.writers.delta import DeltaTableWriter +from koheesio.spark.etl_task import EtlTask class MyTask(EtlTask): target = DeltaTableWriter(table="my_table", partitionBy=["column1", "column2"]) @@ -52,7 +52,7 @@ class MyTask(EtlTask): [//]: # () [//]: # (```python) -[//]: # (from koheesio.steps.transformations.cache import CacheTransformation) +[//]: # (from koheesio.spark.transformations.cache import CacheTransformation) [//]: # () [//]: # (class MyTask(EtlTask):) diff --git a/docs/tutorials/hello-world.md b/docs/tutorials/hello-world.md index 596466b4..3b73d987 100644 --- a/docs/tutorials/hello-world.md +++ b/docs/tutorials/hello-world.md @@ -1,5 +1,23 @@ # Simple Examples +## Bring your own SparkSession + +The Koheesio Spark module does not set up a SparkSession for you. You need to create a SparkSession before using +Koheesio spark classes. This is the entry point for any Spark functionality, allowing the step to interact with the +Spark cluster. + +- Every `SparkStep` has a `spark` attribute, which is the active SparkSession. +- Koheesio supports both local and remote (connect) Spark Sessions +- The SparkSession you created can be explicitly passed to the `SparkStep` constructor (this is optional) + +To create a simple SparkSession, you can use the following code: + +```python +from pyspark.sql import SparkSession + +spark = SparkSession.builder.getOrCreate() +``` + ## Creating a Custom Step This example demonstrates how to use the `SparkStep` class from the `koheesio` library to create a custom step named @@ -8,7 +26,7 @@ This example demonstrates how to use the `SparkStep` class from the `koheesio` l ### Code ```python -from koheesio.steps.step import SparkStep +from koheesio.spark import SparkStep class HelloWorldStep(SparkStep): message: str @@ -21,7 +39,7 @@ class HelloWorldStep(SparkStep): ### Usage ```python -hello_world_step = HelloWorldStep(message="Hello, World!") +hello_world_step = HelloWorldStep(message="Hello, World!", spark=spark) # optionally pass the spark session hello_world_step.execute() hello_world_step.output.df.show() @@ -33,16 +51,15 @@ The `HelloWorldStep` class is a `SparkStep` in Koheesio, designed to generate a - `HelloWorldStep` inherits from `SparkStep`, a fundamental building block in Koheesio for creating data processing steps with Apache Spark. - It has a `message` attribute. When creating an instance of `HelloWorldStep`, you can pass a custom message that will be used in the DataFrame. -- `SparkStep` has a `spark` attribute, which is the active SparkSession. This is the entry point for any Spark functionality, allowing the step to interact with the Spark cluster. - `SparkStep` also includes an `Output` class, used to store the output of the step. In this case, `Output` has a `df` attribute to store the output DataFrame. - The `execute` method creates a DataFrame with the custom message and stores it in `output.df`. It doesn't return a value explicitly; instead, the output DataFrame can be accessed via `output.df`. - Koheesio uses pydantic for automatic validation of the step's input and output, ensuring they are correctly defined and of the correct types. +- The `spark` attribute can be optionally passed to the constructor when creating an instance of `HelloWorldStep`. This allows you to use an existing SparkSession or create a new one specifically for the step. 
+- If no `SparkSession` is passed to a `SparkStep`, Koheesio will use the `SparkSession.getActiveSession()` method to attempt retrieving an active SparkSession. If no active session is found, your code will not work. Note: Pydantic is a data validation library that provides a way to validate that the data (in this case, the input and output of the step) conforms to the expected format. ---- - ## Creating a Custom Task This example demonstrates how to use the `EtlTask` from the `koheesio` library to create a custom task named `MyFavoriteMovieTask`. @@ -51,9 +68,10 @@ This example demonstrates how to use the `EtlTask` from the `koheesio` library t ```python from typing import Any -from pyspark.sql import DataFrame, functions as f -from koheesio.steps.transformations import Transform -from koheesio.tasks.etl_task import EtlTask +from pyspark.sql import functions as f +from koheesio.spark import DataFrame +from koheesio.spark.transformations.transform import Transform +from koheesio.spark.etl_task import EtlTask def add_column(df: DataFrame, target_column: str, value: Any): @@ -104,8 +122,8 @@ source: ```python from pyspark.sql import SparkSession from koheesio.context import Context -from koheesio.steps.readers import DummyReader -from koheesio.steps.writers.dummy import DummyWriter +from koheesio.spark.readers.dummy import DummyReader +from koheesio.spark.writers.dummy import DummyWriter context = Context.from_yaml("sample.yaml") diff --git a/docs/tutorials/testing-koheesio-steps.md b/docs/tutorials/testing-koheesio-steps.md index 015ee554..f55a7c83 100644 --- a/docs/tutorials/testing-koheesio-steps.md +++ b/docs/tutorials/testing-koheesio-steps.md @@ -9,13 +9,15 @@ Unit testing involves testing individual components of the software in isolation Here's an example of how to unit test a Koheesio task: ```python -from koheesio.tasks.etl_task import EtlTask -from koheesio.steps.readers import DummyReader -from koheesio.steps.writers.dummy import DummyWriter -from koheesio.steps.transformations import Transform -from pyspark.sql import SparkSession, DataFrame +from pyspark.sql import SparkSession from pyspark.sql.functions import col +from koheesio.spark import DataFrame +from koheesio.spark.etl_task import EtlTask +from koheesio.spark.readers.dummy import DummyReader +from koheesio.spark.writers.dummy import DummyWriter +from koheesio.spark.transformations.transform import Transform + def filter_age(df: DataFrame) -> DataFrame: return df.filter(col("Age") > 18) @@ -62,12 +64,12 @@ Here's an example of how to write an integration test for this task: ```python # my_module.py -from koheesio.tasks.etl_task import EtlTask -from koheesio.spark.readers.delta import DeltaReader -from koheesio.steps.writers.delta import DeltaWriter -from koheesio.steps.transformations import Transform -from koheesio.context import Context from pyspark.sql.functions import col +from koheesio.spark.etl_task import EtlTask +from koheesio.spark.readers.delta import DeltaTableReader +from koheesio.spark.writers.delta import DeltaTableWriter +from koheesio.spark.transformations.transform import Transform +from koheesio.context import Context def filter_age(df): @@ -84,8 +86,8 @@ context = Context({ }) task = EtlTask( - source=DeltaReader(**context.reader_options), - target=DeltaWriter(**context.writer_options), + source=DeltaTableReader(**context.reader_options), + target=DeltaTableWriter(**context.writer_options), transformations=[ Transform(filter_age) ] @@ -97,11 +99,11 @@ Now, let's create a test for this task. 
We'll use pytest and unittest.mock to mo ```python # test_my_module.py import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import patch from pyspark.sql import SparkSession from koheesio.context import Context -from koheesio.steps.readers import Reader -from koheesio.steps.writers import Writer +from koheesio.spark.readers import Reader +from koheesio.spark.writers import Writer from my_module import task diff --git a/src/koheesio/spark/transformations/lookup.py b/src/koheesio/spark/transformations/lookup.py index 31b35cb3..6a8c54af 100644 --- a/src/koheesio/spark/transformations/lookup.py +++ b/src/koheesio/spark/transformations/lookup.py @@ -19,6 +19,14 @@ from koheesio.spark import DataFrame from koheesio.spark.transformations import Transformation +__all__ = [ + "JoinMapping", + "TargetColumn", + "JoinType", + "JoinHint", + "DataframeLookup", +] + class JoinMapping(BaseModel): """Mapping for joining two dataframes together""" From 047506a4f33eb0b5c2e8dfd6ac94dfc99080242c Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:15:11 +0100 Subject: [PATCH 14/33] fix: test github (#109) ## Description ## Related Issue ## Motivation and Context ## How Has This Been Tested? ## Screenshots (if appropriate): ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [ ] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [ ] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [ ] All new and existing tests passed. --- .github/workflows/test.yml | 2 +- src/koheesio/spark/delta.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c48edcb7..cb19d5c4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -62,7 +62,7 @@ jobs: tests: needs: check_changes - if: needs.check_changes.outputs.python_changed > 0 || needs.check_changes.outputs.toml_changed > 0 || github.event_name == 'workflow_dispatch' + if: needs.check_changes.outputs.python_changed > 0 || needs.check_changes.outputs.toml_changed > 0 || github.event_name == 'workflow_dispatch' || ${{ github.event.pull_request.head.repo.owner.login }} != ${{ github.event.pull_request.base.repo.owner.login }} name: Python ${{ matrix.python-version }} with PySpark ${{ matrix.pyspark-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} runs-on: ${{ matrix.os }} diff --git a/src/koheesio/spark/delta.py b/src/koheesio/spark/delta.py index e4ef31c3..291cd31d 100644 --- a/src/koheesio/spark/delta.py +++ b/src/koheesio/spark/delta.py @@ -20,6 +20,7 @@ class DeltaTableStep(SparkStep): DeltaTable aims to provide a simple interface to create and manage Delta tables. It is a wrapper around the Spark SQL API for Delta tables. + ## Description Introduce private attributes for batch and stream readers to enable adding options to DeltaReader both streaming and writing. ## Related Issue #110 ## Motivation and Context Provide possibility to override readers, e.g. add more options to readers. ## How Has This Been Tested? 
Current tests ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [x] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. --- .gitignore | 1 + src/koheesio/spark/readers/delta.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 0184ed10..275084e1 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,4 @@ out/** # DevContainer .devcontainer +uv.lock diff --git a/src/koheesio/spark/readers/delta.py b/src/koheesio/spark/readers/delta.py index 8983f1aa..e7dde7cc 100644 --- a/src/koheesio/spark/readers/delta.py +++ b/src/koheesio/spark/readers/delta.py @@ -12,6 +12,8 @@ from typing import Any, Dict, Optional, Union +from pydantic import PrivateAttr + from pyspark.sql import DataFrameReader from pyspark.sql import functions as f @@ -163,6 +165,7 @@ class DeltaTableReader(Reader): # private attrs __temp_view_name__: Optional[str] = None + __reader: Optional[Union[DataStreamReader, DataFrameReader]] = PrivateAttr(default=None) @property def temp_view_name(self) -> str: @@ -286,23 +289,26 @@ def normalize(v: Union[str, bool]) -> str: # Any options with `value == None` are filtered out return {k: normalize(v) for k, v in options.items() if v is not None} - @property - def _stream_reader(self) -> DataStreamReader: + def __get_stream_reader(self) -> DataStreamReader: """Returns a basic DataStreamReader (streaming mode)""" return self.spark.readStream.format("delta") - @property - def _batch_reader(self) -> DataFrameReader: + def __get_batch_reader(self) -> DataFrameReader: """Returns a basic DataFrameReader (batch mode)""" return self.spark.read.format("delta") @property def reader(self) -> Union[DataStreamReader, DataFrameReader]: """Return the reader for the DeltaTableReader based on the `streaming` attribute""" - reader = self._stream_reader if self.streaming else self._batch_reader - for key, value in self.get_options().items(): - reader = reader.option(key, value) - return reader + if not self.__reader: + self.__reader = self.__get_stream_reader() if self.streaming else self.__get_batch_reader() + self.__reader = self.__reader.options(**self.get_options()) + + return self.__reader + + @reader.setter + def reader(self, value: Union[DataStreamReader, DataFrameReader]): + self.__reader = value def execute(self) -> Reader.Output: df = self.reader.table(self.table.table_name) From ac95f2d5df9b08406d70151b695e3ae56a009cf0 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:23:39 +0100 Subject: [PATCH 17/33] chore: bump version --- src/koheesio/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 8ebac8d2..db20611c 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc3" +__version__ = "0.9.0rc4" __logo__ = ( 75, ( From 602866b7decc79487c39df6cfbf985f243328812 Mon Sep 17 00:00:00 2001 From: Max 
<43565398+maxim-mityutko@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:04:25 +0100 Subject: [PATCH 18/33] [feature] Add support for HyperProcess parameters (#112) ## Description Refer to #113 ## Related Issue #113 ## Motivation and Context Exposing HyperProcess parameters allow certain customizations to how the Hyper process runs. ## How Has This Been Tested? Unit tests ## Screenshots (if appropriate): ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. Co-authored-by: Maxim Mityutko --- .../integrations/spark/tableau/hyper.py | 20 +++++++++++++--- .../spark/integrations/tableau/test_hyper.py | 23 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/koheesio/integrations/spark/tableau/hyper.py b/src/koheesio/integrations/spark/tableau/hyper.py index c7df9be0..dafd310b 100644 --- a/src/koheesio/integrations/spark/tableau/hyper.py +++ b/src/koheesio/integrations/spark/tableau/hyper.py @@ -141,6 +141,10 @@ def execute(self) -> SparkStep.Output: class HyperFileWriter(HyperFile): """ Base class for all HyperFileWriter classes + + Reference + --------- + HyperProcess parameters: https://tableau.github.io/hyper-db/docs/hyper-api/hyper_process/#process-settings """ path: PurePath = Field( @@ -155,6 +159,12 @@ class HyperFileWriter(HyperFile): description="Table definition to write to the Hyper file as described in " "https://tableau.github.io/hyper-db/lang_docs/py/tableauhyperapi.html#tableauhyperapi.TableDefinition", ) + hyper_process_parameters: dict = Field( + # Disable logging by default, if logging is required remove the "log_config" key and refer to the Hyper API docs + default={"log_config": ""}, + description="Set HyperProcess parameters, see Tableau Hyper API documentation for more details: " + "https://tableau.github.io/hyper-db/docs/hyper-api/hyper_process/#process-settings", + ) class Output(StepOutput): """ @@ -218,8 +228,10 @@ class HyperFileListWriter(HyperFileWriter): data: conlist(List[Any], min_length=1) = Field(default=..., description="List of rows to write to the Hyper file") - def execute(self) -> HyperFileWriter.Output: - with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp: + def execute(self): + with HyperProcess( + telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, parameters=self.hyper_process_parameters + ) as hp: with Connection( endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE ) as connection: @@ -279,7 +291,9 @@ def execute(self) -> HyperFileWriter.Output: _file = [str(f) for f in self.file] array_files = "'" + "','".join(_file) + "'" - with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp: + with HyperProcess( + telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, parameters=self.hyper_process_parameters + ) as hp: with Connection( endpoint=hp.endpoint, database=self.hyper_path, create_mode=CreateMode.CREATE_AND_REPLACE ) as connection: diff --git a/tests/spark/integrations/tableau/test_hyper.py b/tests/spark/integrations/tableau/test_hyper.py index 
d57cd971..c960969b 100644 --- a/tests/spark/integrations/tableau/test_hyper.py +++ b/tests/spark/integrations/tableau/test_hyper.py @@ -115,3 +115,26 @@ def test_hyper_file_dataframe_writer(self, data_path, df_with_all_types): ("timestamp", "timestamp"), ("date", "date"), ] + + @pytest.fixture() + def hyper_file_writer(self): + return HyperFileListWriter( + name="test", + table_definition=TableDefinition( + table_name=TableName("Extract", "Extract"), + columns=[ + TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE), + ], + ), + data=[["text_1"]], + ) + + def test_hyper_file_process_custom_log_dir(self, hyper_file_writer): + import os + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + hyper_file_writer.hyper_process_parameters = {"log_dir": temp_dir} + hyper_file_writer.execute() + + assert os.path.exists(f"{temp_dir}/hyperd.log") From c34abbeb88672b048cb1f7caf233132612e83815 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 24 Nov 2024 13:22:25 +0100 Subject: [PATCH 19/33] [HOTFIX] Remove duplicated implementation (#116) --- src/koheesio/spark/snowflake.py | 1346 +------------------------------ 1 file changed, 20 insertions(+), 1326 deletions(-) diff --git a/src/koheesio/spark/snowflake.py b/src/koheesio/spark/snowflake.py index 4a70b8f5..786420cb 100644 --- a/src/koheesio/spark/snowflake.py +++ b/src/koheesio/spark/snowflake.py @@ -1,76 +1,27 @@ -# noinspection PyUnresolvedReferences """ Snowflake steps and tasks for Koheesio - -Every class in this module is a subclass of `Step` or `Task` and is used to perform operations on Snowflake. - -Notes ------ -Every Step in this module is based on [SnowflakeBaseModel](./snowflake.md#koheesio.spark.snowflake.SnowflakeBaseModel). -The following parameters are available for every Step. - -Parameters ----------- -url : str - Hostname for the Snowflake account, e.g. .snowflakecomputing.com. - Alias for `sfURL`. -user : str - Login name for the Snowflake user. - Alias for `sfUser`. -password : SecretStr - Password for the Snowflake user. - Alias for `sfPassword`. -database : str - The database to use for the session after connecting. - Alias for `sfDatabase`. -sfSchema : str - The schema to use for the session after connecting. - Alias for `schema` ("schema" is a reserved name in Pydantic, so we use `sfSchema` as main name instead). -role : str - The default security role to use for the session after connecting. - Alias for `sfRole`. -warehouse : str - The default virtual warehouse to use for the session after connecting. - Alias for `sfWarehouse`. -authenticator : Optional[str], optional, default=None - Authenticator for the Snowflake user. Example: "okta.com". -options : Optional[Dict[str, Any]], optional, default={"sfCompress": "on", "continue_on_error": "off"} - Extra options to pass to the Snowflake connector. -format : str, optional, default="snowflake" - The default `snowflake` format can be used natively in Databricks, use `net.snowflake.spark.snowflake` in other - environments and make sure to install required JARs. 
""" -from typing import Any, Callable, Dict, List, Optional, Set, Union -from abc import ABC -from copy import deepcopy -import json -from textwrap import dedent - -from pyspark.sql import Window -from pyspark.sql import functions as f -from pyspark.sql import types as t - -from koheesio import Step, StepOutput -from koheesio.logger import LoggingFactory, warn -from koheesio.models import ( - BaseModel, - ExtraParamsMixin, - Field, - SecretStr, - conlist, - field_validator, - model_validator, -) -from koheesio.spark import DataFrame, SparkStep -from koheesio.spark.delta import DeltaTableStep -from koheesio.spark.readers.delta import DeltaTableReader, DeltaTableStreamReader -from koheesio.spark.readers.jdbc import JdbcReader -from koheesio.spark.transformations import Transformation -from koheesio.spark.writers import BatchOutputMode, Writer -from koheesio.spark.writers.stream import ( - ForEachBatchStreamWriter, - writer_to_foreachbatch, +from koheesio.integrations.spark.snowflake import ( + AddColumn, + CreateOrReplaceTableFromDataFrame, + DbTableQuery, + GetTableSchema, + GrantPrivilegesOnFullyQualifiedObject, + GrantPrivilegesOnObject, + GrantPrivilegesOnTable, + GrantPrivilegesOnView, + Query, + RunQuery, + SnowflakeBaseModel, + SnowflakeReader, + SnowflakeStep, + SnowflakeTableStep, + SnowflakeTransformation, + SnowflakeWriter, + SynchronizeDeltaToSnowflakeTask, + SyncTableAndDataFrameSchema, + TableExists, ) __all__ = [ @@ -94,1260 +45,3 @@ "SynchronizeDeltaToSnowflakeTask", "TableExists", ] - -# pylint: disable=inconsistent-mro, too-many-lines -# Turning off inconsistent-mro because we are using ABCs and Pydantic models and Tasks together in the same class -# Turning off too-many-lines because we are defining a lot of classes in this file - - -class SnowflakeBaseModel(BaseModel, ExtraParamsMixin, ABC): - """ - BaseModel for setting up Snowflake Driver options. - - Notes - ----- - * Snowflake is supported natively in Databricks 4.2 and newer: - https://docs.snowflake.com/en/user-guide/spark-connector-databricks - * Refer to Snowflake docs for the installation instructions for non-Databricks environments: - https://docs.snowflake.com/en/user-guide/spark-connector-install - * Refer to Snowflake docs for connection options: - https://docs.snowflake.com/en/user-guide/spark-connector-use#setting-configuration-options-for-the-connector - - Parameters - ---------- - url : str - Hostname for the Snowflake account, e.g. .snowflakecomputing.com. - Alias for `sfURL`. - user : str - Login name for the Snowflake user. - Alias for `sfUser`. - password : SecretStr - Password for the Snowflake user. - Alias for `sfPassword`. - database : str - The database to use for the session after connecting. - Alias for `sfDatabase`. - sfSchema : str - The schema to use for the session after connecting. - Alias for `schema` ("schema" is a reserved name in Pydantic, so we use `sfSchema` as main name instead). - role : str - The default security role to use for the session after connecting. - Alias for `sfRole`. - warehouse : str - The default virtual warehouse to use for the session after connecting. - Alias for `sfWarehouse`. - authenticator : Optional[str], optional, default=None - Authenticator for the Snowflake user. Example: "okta.com". - options : Optional[Dict[str, Any]], optional, default={"sfCompress": "on", "continue_on_error": "off"} - Extra options to pass to the Snowflake connector. 
- format : str, optional, default="snowflake" - The default `snowflake` format can be used natively in Databricks, use `net.snowflake.spark.snowflake` in other - environments and make sure to install required JARs. - - """ - - url: str = Field( - default=..., - alias="sfURL", - description="Hostname for the Snowflake account, e.g. .snowflakecomputing.com", - examples=["example.snowflakecomputing.com"], - ) - user: str = Field(default=..., alias="sfUser", description="Login name for the Snowflake user") - password: SecretStr = Field(default=..., alias="sfPassword", description="Password for the Snowflake user") - authenticator: Optional[str] = Field( - default=None, - description="Authenticator for the Snowflake user", - examples=["okta.com"], - ) - database: str = Field( - default=..., alias="sfDatabase", description="The database to use for the session after connecting" - ) - sfSchema: str = Field(default=..., alias="schema", description="The schema to use for the session after connecting") - role: str = Field( - default=..., alias="sfRole", description="The default security role to use for the session after connecting" - ) - warehouse: str = Field( - default=..., - alias="sfWarehouse", - description="The default virtual warehouse to use for the session after connecting", - ) - options: Optional[Dict[str, Any]] = Field( - default={"sfCompress": "on", "continue_on_error": "off"}, - description="Extra options to pass to the Snowflake connector", - ) - format: str = Field( - default="snowflake", - description="The default `snowflake` format can be used natively in Databricks, use " - "`net.snowflake.spark.snowflake` in other environments and make sure to install required JARs.", - ) - - def get_options(self) -> Dict[str, Any]: - """Get the sfOptions as a dictionary.""" - return { - key: value - for key, value in { - "sfURL": self.url, - "sfUser": self.user, - "sfPassword": self.password.get_secret_value(), - "authenticator": self.authenticator, - "sfDatabase": self.database, - "sfSchema": self.sfSchema, - "sfRole": self.role, - "sfWarehouse": self.warehouse, - **self.options, # type: ignore - }.items() - if value is not None - } - - -class SnowflakeStep(SnowflakeBaseModel, SparkStep, ABC): - """Expands the SnowflakeBaseModel so that it can be used as a Step""" - - -class SnowflakeTableStep(SnowflakeStep, ABC): - """Expands the SnowflakeStep, adding a 'table' parameter""" - - table: str = Field(default=..., description="The name of the table") - - def get_options(self) -> Dict[str, Any]: - options = super().get_options() - options["table"] = self.table - return options - - -class SnowflakeReader(SnowflakeBaseModel, JdbcReader): - """ - Wrapper around JdbcReader for Snowflake. 
- - Example - ------- - ```python - sr = SnowflakeReader( - url="foo.snowflakecomputing.com", - user="YOUR_USERNAME", - password="***", - database="db", - schema="schema", - ) - df = sr.read() - ``` - - Notes - ----- - * Snowflake is supported natively in Databricks 4.2 and newer: - https://docs.snowflake.com/en/user-guide/spark-connector-databricks - * Refer to Snowflake docs for the installation instructions for non-Databricks environments: - https://docs.snowflake.com/en/user-guide/spark-connector-install - * Refer to Snowflake docs for connection options: - https://docs.snowflake.com/en/user-guide/spark-connector-use#setting-configuration-options-for-the-connector - """ - - # overriding `driver` property of JdbcReader, because it is not required by Snowflake - driver: Optional[str] = None # type: ignore - - -class SnowflakeTransformation(SnowflakeBaseModel, Transformation, ABC): - """Adds Snowflake parameters to the Transformation class""" - - -class RunQuery(SnowflakeStep): - """ - Run a query on Snowflake that does not return a result, e.g. create table statement - - This is a wrapper around 'net.snowflake.spark.snowflake.Utils.runQuery' on the JVM - - Example - ------- - ```python - RunQuery( - database="MY_DB", - schema="MY_SCHEMA", - warehouse="MY_WH", - user="account", - password="***", - role="APPLICATION.SNOWFLAKE.ADMIN", - query="CREATE TABLE test (col1 string)", - ).execute() - ``` - """ - - query: str = Field(default=..., description="The query to run", alias="sql") - - @field_validator("query") - def validate_query(cls, query: str) -> str: - """Replace escape characters""" - return query.replace("\\n", "\n").replace("\\t", "\t").strip() - - def get_options(self) -> Dict[str, Any]: - # Executing the RunQuery without `host` option in Databricks throws: - # An error occurred while calling z:net.snowflake.spark.snowflake.Utils.runQuery. 
- # : java.util.NoSuchElementException: key not found: host - options = super().get_options() - options["host"] = options["sfURL"] - return options - - def execute(self) -> SnowflakeStep.Output: - if not self.query: - self.log.warning("Empty string given as query input, skipping execution") - return - # noinspection PyProtectedMember - self.spark._jvm.net.snowflake.spark.snowflake.Utils.runQuery(self.get_options(), self.query) - - -class Query(SnowflakeReader): - """ - Query data from Snowflake and return the result as a DataFrame - - Example - ------- - ```python - Query( - database="MY_DB", - schema_="MY_SCHEMA", - warehouse="MY_WH", - user="gid.account@nike.com", - password=Secret("super-secret-password"), - role="APPLICATION.SNOWFLAKE.ADMIN", - query="SELECT * FROM MY_TABLE", - ).execute().df - ``` - """ - - query: str = Field(default=..., description="The query to run") - - @field_validator("query") - def validate_query(cls, query: str) -> str: - """Replace escape characters""" - query = query.replace("\\n", "\n").replace("\\t", "\t").strip() - return query - - def get_options(self) -> Dict[str, Any]: - """add query to options""" - options = super().get_options() - options["query"] = self.query - return options - - -class DbTableQuery(SnowflakeReader): - """ - Read table from Snowflake using the `dbtable` option instead of `query` - - Example - ------- - ```python - DbTableQuery( - database="MY_DB", - schema_="MY_SCHEMA", - warehouse="MY_WH", - user="user", - password=Secret("super-secret-password"), - role="APPLICATION.SNOWFLAKE.ADMIN", - table="db.schema.table", - ).execute().df - ``` - """ - - dbtable: str = Field(default=..., alias="table", description="The name of the table") - - -class TableExists(SnowflakeTableStep): - """ - Check if the table exists in Snowflake by using INFORMATION_SCHEMA. 
- - Example - ------- - ```python - k = TableExists( - url="foo.snowflakecomputing.com", - user="YOUR_USERNAME", - password="***", - database="db", - schema="schema", - table="table", - ) - ``` - """ - - class Output(StepOutput): - """Output class for TableExists""" - - exists: bool = Field(default=..., description="Whether or not the table exists") - - def execute(self) -> Output: - query = ( - dedent( - # Force upper case, due to case-sensitivity of where clause - f""" - SELECT * - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_CATALOG = '{self.database}' - AND TABLE_SCHEMA = '{self.sfSchema}' - AND TABLE_TYPE = 'BASE TABLE' - AND upper(TABLE_NAME) = '{self.table.upper()}' - """ # nosec B608: hardcoded_sql_expressions - ) - .upper() - .strip() - ) - - self.log.debug(f"Query that was executed to check if the table exists:\n{query}") - - df = Query(**self.get_options(), query=query).read() - - exists = df.count() > 0 - self.log.info(f"Table {self.table} {'exists' if exists else 'does not exist'}") - self.output.exists = exists - - -def map_spark_type(spark_type: t.DataType) -> str: - """ - Translates Spark DataFrame Schema type to SnowFlake type - - | Basic Types | Snowflake Type | - |-------------------|----------------| - | StringType | STRING | - | NullType | STRING | - | BooleanType | BOOLEAN | - - | Numeric Types | Snowflake Type | - |-------------------|----------------| - | LongType | BIGINT | - | IntegerType | INT | - | ShortType | SMALLINT | - | DoubleType | DOUBLE | - | FloatType | FLOAT | - | NumericType | FLOAT | - | ByteType | BINARY | - - | Date / Time Types | Snowflake Type | - |-------------------|----------------| - | DateType | DATE | - | TimestampType | TIMESTAMP | - - | Advanced Types | Snowflake Type | - |-------------------|----------------| - | DecimalType | DECIMAL | - | MapType | VARIANT | - | ArrayType | VARIANT | - | StructType | VARIANT | - - References - ---------- - - Spark SQL DataTypes: https://spark.apache.org/docs/latest/sql-ref-datatypes.html - - Snowflake DataTypes: https://docs.snowflake.com/en/sql-reference/data-types.html - - Parameters - ---------- - spark_type : pyspark.sql.types.DataType - DataType taken out of the StructField - - Returns - ------- - str - The Snowflake data type - """ - # StructField means that the entire Field was passed, we need to extract just the dataType before continuing - if isinstance(spark_type, t.StructField): - spark_type = spark_type.dataType - - # Check if the type is DayTimeIntervalType - if isinstance(spark_type, t.DayTimeIntervalType): - warn( - "DayTimeIntervalType is being converted to STRING. " - "Consider converting to a more supported date/time/timestamp type in Snowflake." 
- ) - - # fmt: off - # noinspection PyUnresolvedReferences - data_type_map = { - # Basic Types - t.StringType: "STRING", - t.NullType: "STRING", - t.BooleanType: "BOOLEAN", - - # Numeric Types - t.LongType: "BIGINT", - t.IntegerType: "INT", - t.ShortType: "SMALLINT", - t.DoubleType: "DOUBLE", - t.FloatType: "FLOAT", - t.NumericType: "FLOAT", - t.ByteType: "BINARY", - t.BinaryType: "VARBINARY", - - # Date / Time Types - t.DateType: "DATE", - t.TimestampType: "TIMESTAMP", - t.DayTimeIntervalType: "STRING", - - # Advanced Types - t.DecimalType: - f"DECIMAL({spark_type.precision},{spark_type.scale})" # pylint: disable=no-member - if isinstance(spark_type, t.DecimalType) else "DECIMAL(38,0)", - t.MapType: "VARIANT", - t.ArrayType: "VARIANT", - t.StructType: "VARIANT", - } - return data_type_map.get(type(spark_type), 'STRING') - # fmt: on - - -class CreateOrReplaceTableFromDataFrame(SnowflakeTransformation): - """ - Create (or Replace) a Snowflake table which has the same schema as a Spark DataFrame - - Can be used as any Transformation. The DataFrame is however left unchanged, and only used for determining the - schema of the Snowflake Table that is to be created (or replaced). - - Example - ------- - ```python - CreateOrReplaceTableFromDataFrame( - database="MY_DB", - schema="MY_SCHEMA", - warehouse="MY_WH", - user="gid.account@nike.com", - password="super-secret-password", - role="APPLICATION.SNOWFLAKE.ADMIN", - table="MY_TABLE", - df=df, - ).execute() - ``` - - Or, as a Transformation: - ```python - CreateOrReplaceTableFromDataFrame( - ... - table="MY_TABLE", - ).transform(df) - ``` - - """ - - table: str = Field(default=..., alias="table_name", description="The name of the (new) table") - - class Output(SnowflakeTransformation.Output): - """Output class for CreateOrReplaceTableFromDataFrame""" - - input_schema: t.StructType = Field(default=..., description="The original schema from the input DataFrame") - snowflake_schema: str = Field( - default=..., description="Derived Snowflake table schema based on the input DataFrame" - ) - query: str = Field(default=..., description="Query that was executed to create the table") - - def execute(self) -> Output: - self.output.df = self.df - - input_schema = self.df.schema - self.output.input_schema = input_schema - - snowflake_schema = ", ".join([f"{c.name} {map_spark_type(c.dataType)}" for c in input_schema]) - self.output.snowflake_schema = snowflake_schema - - table_name = f"{self.database}.{self.sfSchema}.{self.table}" - query = f"CREATE OR REPLACE TABLE {table_name} ({snowflake_schema})" - self.output.query = query - - RunQuery(**self.get_options(), query=query).execute() - - -class GrantPrivilegesOnObject(SnowflakeStep): - """ - A wrapper on Snowflake GRANT privileges - - With this Step, you can grant Snowflake privileges to a set of roles on a table, a view, or an object - - See Also - -------- - https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html - - Parameters - ---------- - warehouse : str - The name of the warehouse. Alias for `sfWarehouse` - user : str - The username. Alias for `sfUser` - password : SecretStr - The password. Alias for `sfPassword` - role : str - The role name - object : str - The name of the object to grant privileges on - type : str - The type of object to grant privileges on, e.g. TABLE, VIEW - privileges : Union[conlist(str, min_length=1), str] - The Privilege/Permission or list of Privileges/Permissions to grant on the given object. 
- roles : Union[conlist(str, min_length=1), str] - The Role or list of Roles to grant the privileges to - - Example - ------- - ```python - GrantPermissionsOnTable( - object="MY_TABLE", - type="TABLE", - warehouse="MY_WH", - user="gid.account@nike.com", - password=Secret("super-secret-password"), - role="APPLICATION.SNOWFLAKE.ADMIN", - permissions=["SELECT", "INSERT"], - ).execute() - ``` - - In this example, the `APPLICATION.SNOWFLAKE.ADMIN` role will be granted `SELECT` and `INSERT` privileges on - the `MY_TABLE` table using the `MY_WH` warehouse. - """ - - object: str = Field(default=..., description="The name of the object to grant privileges on") - type: str = Field(default=..., description="The type of object to grant privileges on, e.g. TABLE, VIEW") - - privileges: Union[conlist(str, min_length=1), str] = Field( - default=..., - alias="permissions", - description="The Privilege/Permission or list of Privileges/Permissions to grant on the given object. " - "See https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html", - ) - roles: Union[conlist(str, min_length=1), str] = Field( - default=..., - alias="role", - validation_alias="roles", - description="The Role or list of Roles to grant the privileges to", - ) - - class Output(SnowflakeStep.Output): - """Output class for GrantPrivilegesOnObject""" - - query: conlist(str, min_length=1) = Field( - default=..., description="Query that was executed to grant privileges", validate_default=False - ) - - @model_validator(mode="before") - def set_roles_privileges(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Coerce roles and privileges to be lists if they are not already.""" - roles_value = values.get("roles") or values.get("role") - privileges_value = values.get("privileges") - - if not (roles_value and privileges_value): - raise ValueError("You have to specify roles AND privileges when using 'GrantPrivilegesOnObject'.") - - # coerce values to be lists - values["roles"] = [roles_value] if isinstance(roles_value, str) else roles_value - values["role"] = values["roles"][0] # hack to keep the validator happy - values["privileges"] = [privileges_value] if isinstance(privileges_value, str) else privileges_value - - return values - - @model_validator(mode="after") - def validate_object_and_object_type(self) -> "GrantPrivilegesOnObject": - """Validate that the object and type are set.""" - object_value = self.object - if not object_value: - raise ValueError("You must provide an `object`, this should be the name of the object. ") - - object_type = self.type - if not object_type: - raise ValueError( - "You must provide a `type`, e.g. TABLE, VIEW, DATABASE. " - "See https://docs.snowflake.com/en/sql-reference/sql/grant-privilege.html" - ) - - return self - - def get_query(self, role: str) -> str: - """Build the GRANT query - - Parameters - ---------- - role: str - The role name - - Returns - ------- - query : str - The Query that performs the grant - """ - query = f"GRANT {','.join(self.privileges)} ON {self.type} {self.object} TO ROLE {role}".upper() # nosec B608: hardcoded_sql_expressions - return query - - def execute(self) -> SnowflakeStep.Output: - self.output.query = [] - roles = self.roles - - for role in roles: - query = self.get_query(role) - self.output.query.append(query) - RunQuery(**self.get_options(), query=query).execute() - - -class GrantPrivilegesOnFullyQualifiedObject(GrantPrivilegesOnObject): - """Grant Snowflake privileges to a set of roles on a fully qualified object, i.e. 
`database.schema.object_name` - - This class is a subclass of `GrantPrivilegesOnObject` and is used to grant privileges on a fully qualified object. - The advantage of using this class is that it sets the object name to be fully qualified, i.e. - `database.schema.object_name`. - - Meaning, you can set the `database`, `schema` and `object` separately and the object name will be set to be fully - qualified, i.e. `database.schema.object_name`. - - Example - ------- - ```python - GrantPrivilegesOnFullyQualifiedObject( - database="MY_DB", - schema="MY_SCHEMA", - warehouse="MY_WH", - ... - object="MY_TABLE", - type="TABLE", - ... - ) - ``` - - In this example, the object name will be set to be fully qualified, i.e. `MY_DB.MY_SCHEMA.MY_TABLE`. - If you were to use `GrantPrivilegesOnObject` instead, you would have to set the object name to be fully qualified - yourself. - """ - - @model_validator(mode="after") - def set_object_name(self) -> "GrantPrivilegesOnFullyQualifiedObject": - """Set the object name to be fully qualified, i.e. database.schema.object_name""" - # database, schema, obj_name - db = self.database - schema = self.model_dump()["sfSchema"] # since "schema" is a reserved name - obj_name = self.object - - self.object = f"{db}.{schema}.{obj_name}" - - return self - - -class GrantPrivilegesOnTable(GrantPrivilegesOnFullyQualifiedObject): - """Grant Snowflake privileges to a set of roles on a table""" - - type: str = "TABLE" - object: str = Field( - default=..., - alias="table", - description="The name of the Table to grant Privileges on. This should be just the name of the table; so " - "without Database and Schema, use sfDatabase/database and sfSchema/schema to set those instead.", - ) - - -class GrantPrivilegesOnView(GrantPrivilegesOnFullyQualifiedObject): - """Grant Snowflake privileges to a set of roles on a view""" - - type: str = "VIEW" - object: str = Field( - default=..., - alias="view", - description="The name of the View to grant Privileges on. This should be just the name of the view; so " - "without Database and Schema, use sfDatabase/database and sfSchema/schema to set those instead.", - ) - - -class GetTableSchema(SnowflakeStep): - """ - Get the schema from a Snowflake table as a Spark Schema - - Notes - ----- - * This Step will execute a `SELECT * FROM
LIMIT 1` query to get the schema of the table. - * The schema will be stored in the `table_schema` attribute of the output. - * `table_schema` is used as the attribute name to avoid conflicts with the `schema` attribute of Pydantic's - BaseModel. - - Example - ------- - ```python - schema = ( - GetTableSchema( - database="MY_DB", - schema_="MY_SCHEMA", - warehouse="MY_WH", - user="gid.account@nike.com", - password="super-secret-password", - role="APPLICATION.SNOWFLAKE.ADMIN", - table="MY_TABLE", - ) - .execute() - .table_schema - ) - ``` - """ - - table: str = Field(default=..., description="The Snowflake table name") - - class Output(StepOutput): - """Output class for GetTableSchema""" - - table_schema: t.StructType = Field(default=..., serialization_alias="schema", description="The Spark Schema") - - def execute(self) -> Output: - query = f"SELECT * FROM {self.table} LIMIT 1" # nosec B608: hardcoded_sql_expressions - df = Query(**self.get_options(), query=query).execute().df - self.output.table_schema = df.schema - - -class AddColumn(SnowflakeStep): - """ - Add an empty column to a Snowflake table with given name and DataType - - Example - ------- - ```python - AddColumn( - database="MY_DB", - schema_="MY_SCHEMA", - warehouse="MY_WH", - user="gid.account@nike.com", - password=Secret("super-secret-password"), - role="APPLICATION.SNOWFLAKE.ADMIN", - table="MY_TABLE", - col="MY_COL", - dataType=StringType(), - ).execute() - ``` - """ - - table: str = Field(default=..., description="The name of the Snowflake table") - column: str = Field(default=..., description="The name of the new column") - type: t.DataType = Field(default=..., description="The DataType represented as a Spark DataType") - - class Output(SnowflakeStep.Output): - """Output class for AddColumn""" - - query: str = Field(default=..., description="Query that was executed to add the column") - - def execute(self) -> Output: - query = f"ALTER TABLE {self.table} ADD COLUMN {self.column} {map_spark_type(self.type)}".upper() - self.output.query = query - RunQuery(**self.get_options(), query=query).execute() - - -class SyncTableAndDataFrameSchema(SnowflakeStep, SnowflakeTransformation): - """ - Sync the schema's of a Snowflake table and a DataFrame. This will add NULL columns for the columns that are not in - both and perform type casts where needed. - - The Snowflake table will take priority in case of type conflicts. 
- """ - - df: DataFrame = Field(default=..., description="The Spark DataFrame") - table: str = Field(default=..., description="The table name") - dry_run: Optional[bool] = Field(default=False, description="Only show schema differences, do not apply changes") - - class Output(SparkStep.Output): - """Output class for SyncTableAndDataFrameSchema""" - - original_df_schema: t.StructType = Field(default=..., description="Original DataFrame schema") - original_sf_schema: t.StructType = Field(default=..., description="Original Snowflake schema") - new_df_schema: t.StructType = Field(default=..., description="New DataFrame schema") - new_sf_schema: t.StructType = Field(default=..., description="New Snowflake schema") - sf_table_altered: bool = Field( - default=False, description="Flag to indicate whether Snowflake schema has been altered" - ) - - def execute(self) -> Output: - self.log.warning("Snowflake table will always take a priority in case of data type conflicts!") - - # spark side - df_schema = self.df.schema - self.output.original_df_schema = deepcopy(df_schema) # using deepcopy to avoid storing in place changes - df_cols = [c.name.lower() for c in df_schema] - - # snowflake side - sf_schema = GetTableSchema(**self.get_options(), table=self.table).execute().table_schema - self.output.original_sf_schema = sf_schema - sf_cols = [c.name.lower() for c in sf_schema] - - if self.dry_run: - # Display differences between Spark DataFrame and Snowflake schemas - # and provide dummy values that are expected as class outputs. - self.log.warning(f"Columns to be added to Snowflake table: {set(df_cols) - set(sf_cols)}") - self.log.warning(f"Columns to be added to Spark DataFrame: {set(sf_cols) - set(df_cols)}") - - self.output.new_df_schema = t.StructType() - self.output.new_sf_schema = t.StructType() - self.output.df = self.df - self.output.sf_table_altered = False - - else: - # Add columns to SnowFlake table that exist in DataFrame - for df_column in df_schema: - if df_column.name.lower() not in sf_cols: - AddColumn( - **self.get_options(), - table=self.table, - column=df_column.name, - type=df_column.dataType, - ).execute() - self.output.sf_table_altered = True - - if self.output.sf_table_altered: - sf_schema = GetTableSchema(**self.get_options(), table=self.table).execute().table_schema - sf_cols = [c.name.lower() for c in sf_schema] - - self.output.new_sf_schema = sf_schema - - # Add NULL columns to the DataFrame if they exist in SnowFlake but not in the df - df = self.df - for sf_col in self.output.original_sf_schema: - sf_col_name = sf_col.name.lower() - if sf_col_name not in df_cols: - sf_col_type = sf_col.dataType - df = df.withColumn(sf_col_name, f.lit(None).cast(sf_col_type)) # type: ignore - - # Put DataFrame columns in the same order as the Snowflake table - df = df.select(*sf_cols) - - self.output.df = df - self.output.new_df_schema = df.schema - - -class SnowflakeWriter(SnowflakeBaseModel, Writer): - """Class for writing to Snowflake - - See Also - -------- - - [koheesio.steps.writers.Writer](writers/index.md#koheesio.spark.writers.Writer) - - [koheesio.steps.writers.BatchOutputMode](writers/index.md#koheesio.spark.writers.BatchOutputMode) - - [koheesio.steps.writers.StreamingOutputMode](writers/index.md#koheesio.spark.writers.StreamingOutputMode) - """ - - table: str = Field(default=..., description="Target table name") - insert_type: Optional[BatchOutputMode] = Field( - BatchOutputMode.APPEND, alias="mode", description="The insertion type, append or overwrite" - ) - - def execute(self) -> 
Writer.Output: - """Write to Snowflake""" - self.log.debug(f"writing to {self.table} with mode {self.insert_type}") - self.df.write.format(self.format).options(**self.get_options()).option("dbtable", self.table).mode( - self.insert_type - ).save() - - -class TagSnowflakeQuery(Step, ExtraParamsMixin): - """ - Provides Snowflake query tag pre-action that can be used to easily find queries through SF history search - and further group them for debugging and cost tracking purposes. - - Takes in query tag attributes as kwargs and additional Snowflake options dict that can optionally contain - other set of pre-actions to be applied to a query, in that case existing pre-action aren't dropped, query tag - pre-action will be added to them. - - Passed Snowflake options dictionary is not modified in-place, instead anew dictionary containing updated pre-actions - is returned. - - Notes - ----- - See this article for explanation: https://select.dev/posts/snowflake-query-tags - - Arbitrary tags can be applied, such as team, dataset names, business capability, etc. - - Example - ------- - ```python - query_tag = ( - AddQueryTag( - options={"preactions": ...}, - task_name="cleanse_task", - pipeline_name="ingestion-pipeline", - etl_date="2022-01-01", - pipeline_execution_time="2022-01-01T00:00:00", - task_execution_time="2022-01-01T01:00:00", - environment="dev", - trace_id="acd4f3f96045", - span_id="546d2d66f6cb", - ) - .execute() - .options - ) - ``` - """ - - options: Dict = Field( - default_factory=dict, description="Additional Snowflake options, optionally containing additional preactions" - ) - - class Output(StepOutput): - """Output class for AddQueryTag""" - - options: Dict = Field(default=..., description="Copy of provided SF options, with added query tag preaction") - - def execute(self) -> Output: - """Add query tag preaction to Snowflake options""" - tag_json = json.dumps(self.extra_params, indent=4, sort_keys=True) - tag_preaction = f"ALTER SESSION SET QUERY_TAG = '{tag_json}';" - preactions = self.options.get("preactions", "") - preactions = f"{preactions}\n{tag_preaction}".strip() - updated_options = dict(self.options) - updated_options["preactions"] = preactions - self.output.options = updated_options - - -class SynchronizeDeltaToSnowflakeTask(SnowflakeStep): - """ - Synchronize a Delta table to a Snowflake table - - * Overwrite - only in batch mode - * Append - supports batch and streaming mode - * Merge - only in streaming mode - - Example - ------- - ```python - SynchronizeDeltaToSnowflakeTask( - url="acme.snowflakecomputing.com", - user="admin", - role="ADMIN", - warehouse="SF_WAREHOUSE", - database="SF_DATABASE", - schema="SF_SCHEMA", - source_table=DeltaTableStep(...), - target_table="my_sf_table", - key_columns=[ - "id", - ], - streaming=False, - ).run() - ``` - """ - - source_table: DeltaTableStep = Field(default=..., description="Source delta table to synchronize") - target_table: str = Field(default=..., description="Target table in snowflake to synchronize to") - synchronisation_mode: BatchOutputMode = Field( - default=BatchOutputMode.MERGE, - description="Determines if synchronisation will 'overwrite' any existing table, 'append' new rows or " - "'merge' with existing rows.", - ) - checkpoint_location: Optional[str] = Field(default=None, description="Checkpoint location to use") - schema_tracking_location: Optional[str] = Field( - default=None, - description="Schema tracking location to use. 
" - "Info: https://docs.delta.io/latest/delta-streaming.html#-schema-tracking", - ) - staging_table_name: Optional[str] = Field( - default=None, alias="staging_table", description="Optional snowflake staging name", validate_default=False - ) - key_columns: List[str] = Field( - default_factory=list, - description="Key columns on which merge statements will be MERGE statement will be applied.", - ) - streaming: Optional[bool] = Field( - default=False, - description="Should synchronisation happen in streaming or in batch mode. Streaming is supported in 'APPEND' " - "and 'MERGE' mode. Batch is supported in 'OVERWRITE' and 'APPEND' mode.", - ) - persist_staging: Optional[bool] = Field( - default=False, - description="In case of debugging, set `persist_staging` to True to retain the staging table for inspection " - "after synchronization.", - ) - - enable_deletion: Optional[bool] = Field( - default=False, - description="In case of merge synchronisation_mode add deletion statement in merge query.", - ) - - writer_: Optional[Union[ForEachBatchStreamWriter, SnowflakeWriter]] = None - - @field_validator("staging_table_name") - def _validate_staging_table(cls, staging_table_name: str) -> str: - """Validate the staging table name and return it if it's valid.""" - if "." in staging_table_name: - raise ValueError( - "Custom staging table must not contain '.', it is located in the same Schema as the target table." - ) - return staging_table_name - - @model_validator(mode="before") - def _checkpoint_location_check(cls, values: Dict) -> Dict: - """Give a warning if checkpoint location is given but not expected and vice versa""" - streaming = values.get("streaming") - checkpoint_location = values.get("checkpoint_location") - log = LoggingFactory.get_logger(cls.__name__) - - if streaming is False and checkpoint_location is not None: - log.warning("checkpoint_location is provided but will be ignored in batch mode") - if streaming is True and checkpoint_location is None: - log.warning("checkpoint_location is not provided in streaming mode") - return values - - @model_validator(mode="before") - def _synch_mode_check(cls, values: Dict) -> Dict: - """Validate requirements for various synchronisation modes""" - streaming = values.get("streaming") - synchronisation_mode = values.get("synchronisation_mode") - key_columns = values.get("key_columns") - - allowed_output_modes = [BatchOutputMode.OVERWRITE, BatchOutputMode.MERGE, BatchOutputMode.APPEND] - - if synchronisation_mode not in allowed_output_modes: - raise ValueError( - f"Synchronisation mode should be one of {', '.join([m.value for m in allowed_output_modes])}" - ) - if synchronisation_mode == BatchOutputMode.OVERWRITE and streaming is True: - raise ValueError("Synchronisation mode can't be 'OVERWRITE' with streaming enabled") - if synchronisation_mode == BatchOutputMode.MERGE and streaming is False: - raise ValueError("Synchronisation mode can't be 'MERGE' with streaming disabled") - if synchronisation_mode == BatchOutputMode.MERGE and len(key_columns) < 1: # type: ignore - raise ValueError("MERGE synchronisation mode requires a list of PK columns in `key_columns`.") - - return values - - @property - def non_key_columns(self) -> List[str]: - """Columns of source table that aren't part of the (composite) primary key""" - lowercase_key_columns: Set[str] = {c.lower() for c in self.key_columns} # type: ignore - source_table_columns = self.source_table.columns - non_key_columns: List[str] = [c for c in source_table_columns if c.lower() not in 
lowercase_key_columns] # type: ignore - return non_key_columns - - @property - def staging_table(self) -> str: - """Intermediate table on snowflake where staging results are stored""" - if stg_tbl_name := self.staging_table_name: - return stg_tbl_name - - return f"{self.source_table.table}_stg" - - @property - def reader(self) -> Union[DeltaTableReader, DeltaTableStreamReader]: - """ - DeltaTable reader - - Returns: - -------- - Union[DeltaTableReader, DeltaTableStreamReader] - DeltaTableReader that will yield source delta table - """ - # Wrap in lambda functions to mimic lazy evaluation. - # This ensures the Task doesn't fail if a config isn't provided for a reader/writer that isn't used anyway - map_mode_reader = { - BatchOutputMode.OVERWRITE: lambda: DeltaTableReader( - table=self.source_table, streaming=False, schema_tracking_location=self.schema_tracking_location - ), - BatchOutputMode.APPEND: lambda: DeltaTableReader( - table=self.source_table, - streaming=self.streaming, - schema_tracking_location=self.schema_tracking_location, - ), - BatchOutputMode.MERGE: lambda: DeltaTableStreamReader( - table=self.source_table, read_change_feed=True, schema_tracking_location=self.schema_tracking_location - ), - } - return map_mode_reader[self.synchronisation_mode]() - - def _get_writer(self) -> Union[SnowflakeWriter, ForEachBatchStreamWriter]: - """ - Writer to persist to snowflake - - Depending on configured options, this returns an SnowflakeWriter or ForEachBatchStreamWriter: - - OVERWRITE/APPEND mode yields SnowflakeWriter - - MERGE mode yields ForEachBatchStreamWriter - - Returns - ------- - ForEachBatchStreamWriter | SnowflakeWriter - The right writer for the configured options and mode - """ - # Wrap in lambda functions to mimic lazy evaluation. - # This ensures the Task doesn't fail if a config isn't provided for a reader/writer that isn't used anyway - map_mode_writer = { - (BatchOutputMode.OVERWRITE, False): lambda: SnowflakeWriter( - table=self.target_table, insert_type=BatchOutputMode.OVERWRITE, **self.get_options() - ), - (BatchOutputMode.APPEND, False): lambda: SnowflakeWriter( - table=self.target_table, insert_type=BatchOutputMode.APPEND, **self.get_options() - ), - (BatchOutputMode.APPEND, True): lambda: ForEachBatchStreamWriter( - checkpointLocation=self.checkpoint_location, - batch_function=writer_to_foreachbatch( - SnowflakeWriter(table=self.target_table, insert_type=BatchOutputMode.APPEND, **self.get_options()) - ), - ), - (BatchOutputMode.MERGE, True): lambda: ForEachBatchStreamWriter( - checkpointLocation=self.checkpoint_location, - batch_function=self._merge_batch_write_fn( - key_columns=self.key_columns, # type: ignore - non_key_columns=self.non_key_columns, - staging_table=self.staging_table, - ), - ), - } - return map_mode_writer[(self.synchronisation_mode, self.streaming)]() # type: ignore - - @property - def writer(self) -> Union[ForEachBatchStreamWriter, SnowflakeWriter]: - """ - Writer to persist to snowflake - - Depending on configured options, this returns an SnowflakeWriter or ForEachBatchStreamWriter: - - OVERWRITE/APPEND mode yields SnowflakeWriter - - MERGE mode yields ForEachBatchStreamWriter - - Returns - ------- - Union[ForEachBatchStreamWriter, SnowflakeWriter] - """ - # Cache 'writer' object in memory to ensure same object is used everywhere, this ensures access to underlying - # member objects such as active streaming queries (if any). 
- if not self.writer_: - self.writer_ = self._get_writer() - return self.writer_ - - def truncate_table(self, snowflake_table: str) -> None: - """Truncate a given snowflake table""" - truncate_query = f"""TRUNCATE TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions - query_executor = RunQuery( - **self.get_options(), - query=truncate_query, - ) - query_executor.execute() - - def drop_table(self, snowflake_table: str) -> None: - """Drop a given snowflake table""" - self.log.warning(f"Dropping table {snowflake_table} from snowflake") - drop_table_query = f"""DROP TABLE IF EXISTS {snowflake_table}""" # nosec B608: hardcoded_sql_expressions - query_executor = RunQuery(**self.get_options(), query=drop_table_query) - query_executor.execute() - - def _merge_batch_write_fn(self, key_columns: List[str], non_key_columns: List[str], staging_table: str) -> Callable: - """Build a batch write function for merge mode""" - - # pylint: disable=unused-argument - # noinspection PyUnusedLocal,PyPep8Naming - def inner(dataframe: DataFrame, batchId: int) -> None: - self._build_staging_table(dataframe, key_columns, non_key_columns, staging_table) - self._merge_staging_table_into_target() - - # pylint: enable=unused-argument - return inner - - @staticmethod - def _compute_latest_changes_per_pk( - dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str] - ) -> DataFrame: - """Compute the latest changes per primary key""" - window_spec = Window.partitionBy(*key_columns).orderBy(f.col("_commit_version").desc()) - ranked_df = ( - dataframe.filter("_change_type != 'update_preimage'") - .withColumn("rank", f.rank().over(window_spec)) # type: ignore - .filter("rank = 1") - .select(*key_columns, *non_key_columns, "_change_type") # discard unused columns - .distinct() - ) - return ranked_df - - def _build_staging_table( - self, dataframe: DataFrame, key_columns: List[str], non_key_columns: List[str], staging_table: str - ) -> None: - """Build snowflake staging table""" - ranked_df = self._compute_latest_changes_per_pk(dataframe, key_columns, non_key_columns) - batch_writer = SnowflakeWriter( - table=staging_table, df=ranked_df, insert_type=BatchOutputMode.APPEND, **self.get_options() - ) - batch_writer.execute() - - def _merge_staging_table_into_target(self) -> None: - """ - Merge snowflake staging table into final snowflake table - """ - merge_query = self._build_sf_merge_query( - target_table=self.target_table, - stage_table=self.staging_table, - pk_columns=self.key_columns, # type: ignore - non_pk_columns=self.non_key_columns, - enable_deletion=self.enable_deletion, # type: ignore - ) - - query_executor = RunQuery( - **self.get_options(), - query=merge_query, - ) - query_executor.execute() - - @staticmethod - def _build_sf_merge_query( - target_table: str, - stage_table: str, - pk_columns: List[str], - non_pk_columns: List[str], - enable_deletion: bool = False, - ) -> str: - """Build a CDF merge query string - - Parameters - ---------- - target_table: Table - Destination table to merge into - stage_table: Table - Temporary table containing updates to be executed - pk_columns: List[str] - Column names used to uniquely identify each row - non_pk_columns: List[str] - Non-key columns that may need to be inserted/updated - enable_deletion: bool - DELETE actions are synced. 
If set to False (default) then sync is non-destructive - - Returns - ------- - str - Query to be executed on the target database - """ - all_fields = [*pk_columns, *non_pk_columns] - key_join_string = " AND ".join(f"target.{k} = temp.{k}" for k in pk_columns) - columns_string = ", ".join(all_fields) - assignment_string = ", ".join(f"{k} = temp.{k}" for k in non_pk_columns) - values_string = ", ".join(f"temp.{k}" for k in all_fields) - - query = f""" - MERGE INTO {target_table} target - USING {stage_table} temp ON {key_join_string} - WHEN MATCHED AND temp._change_type = 'update_postimage' THEN UPDATE SET {assignment_string} - WHEN NOT MATCHED AND temp._change_type != 'delete' THEN INSERT ({columns_string}) VALUES ({values_string}) - """ # nosec B608: hardcoded_sql_expressions - if enable_deletion: - query += "WHEN MATCHED AND temp._change_type = 'delete' THEN DELETE" - - return query - - def extract(self) -> DataFrame: - """ - Extract source table - """ - if self.synchronisation_mode == BatchOutputMode.MERGE: - if not self.source_table.is_cdf_active: - raise RuntimeError( - f"Source table {self.source_table.table_name} does not have CDF enabled. " - f"Set TBLPROPERTIES ('delta.enableChangeDataFeed' = true) to enable. " - f"Current properties = {self.source_table.get_persisted_properties()}" - ) - - df = self.reader.read() - self.output.source_df = df - return df - - def load(self, df: DataFrame) -> DataFrame: - """Load source table into snowflake""" - if self.synchronisation_mode == BatchOutputMode.MERGE: - self.log.info(f"Truncating staging table {self.staging_table}") - self.truncate_table(self.staging_table) - self.writer.write(df) - self.output.target_df = df - return df - - def execute(self) -> SnowflakeStep.Output: - # extract - df = self.extract() - self.output.source_df = df - - # synchronize - self.output.target_df = df - self.load(df) - if not self.persist_staging: - # If it's a streaming job, await for termination before dropping staging table - if self.streaming: - self.writer.await_termination() # type: ignore - self.drop_table(self.staging_table) From 7fb1b272e7ed64414587bd20414e07c1678f262c Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 24 Nov 2024 13:22:51 +0100 Subject: [PATCH 20/33] [FEATURE] Populate account from url if not provided in SnowflakeBaseModel (#117) --- .../integrations/snowflake/__init__.py | 13 +++++++- tests/snowflake/test_snowflake.py | 32 +++++++++++++------ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py index d461ffff..11fd8b59 100644 --- a/src/koheesio/integrations/snowflake/__init__.py +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -47,6 +47,7 @@ from abc import ABC from contextlib import contextmanager from types import ModuleType +from urllib.parse import urlparse from koheesio import Step from koheesio.logger import warn @@ -281,7 +282,7 @@ class SnowflakeRunQueryPython(SnowflakeStep): """ query: str = Field(default=..., description="The query to run", alias="sql", serialization_alias="query") - account: str = Field(default=..., description="Snowflake Account Name", alias="account") + account: Optional[str] = Field(default=None, description="Snowflake Account Name", alias="account") # for internal use _snowflake_connector: Optional[ModuleType] = PrivateAttr(default_factory=safe_import_snowflake_connector) @@ -291,6 +292,16 @@ class Output(SnowflakeStep.Output): 
results: List = Field(default_factory=list, description="The results of the query") + @model_validator(mode="before") + def _validate_account(cls, values: Dict) -> Dict: + """Populate account from URL if not provided""" + if not values.get("account"): + parsed_url = urlparse(values["url"]) + base_url = parsed_url.hostname or parsed_url.path + values["account"] = base_url.split(".")[0] + + return values + @field_validator("query") def validate_query(cls, query: str) -> str: """Replace escape characters, strip whitespace, ensure it is not empty""" diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py index 8cf2fb4b..c525c5fe 100644 --- a/tests/snowflake/test_snowflake.py +++ b/tests/snowflake/test_snowflake.py @@ -1,9 +1,11 @@ # flake8: noqa: F811 +from copy import deepcopy from unittest import mock -from pydantic_core._pydantic_core import ValidationError import pytest +from pydantic import ValidationError + from koheesio.integrations.snowflake import ( GrantPrivilegesOnObject, GrantPrivilegesOnTable, @@ -15,8 +17,10 @@ ) from koheesio.integrations.snowflake.test_utils import mock_query +mock_query = mock_query + COMMON_OPTIONS = { - "url": "url", + "url": "hostname.com", "user": "user", "password": "password", "database": "db", @@ -120,13 +124,23 @@ def test_get_options(self): "password": "password", "role": "role", "schema": "schema", - "url": "url", + "url": "hostname.com", "user": "user", "warehouse": "warehouse", } assert actual_options == expected_options assert query_in_options["query"] == expected_query, "query should be returned regardless of the input" + def test_account_populated_from_url(self): + kls = SnowflakeRunQueryPython(**COMMON_OPTIONS, sql="SELECT * FROM table") + assert kls.account == "hostname" + + def test_account_populated_from_url2(self): + common_options = deepcopy(COMMON_OPTIONS) + common_options["url"] = "https://host2.host1.snowflakecomputing.com" + kls = SnowflakeRunQueryPython(**common_options, sql="SELECT * FROM table") + assert kls.account == "host2" + def test_execute(self, mock_query): # Arrange query = "SELECT * FROM two_row_table" @@ -161,7 +175,7 @@ class TestSnowflakeBaseModel: def test_get_options_using_alias(self): """Test that the options are correctly generated using alias""" k = SnowflakeBaseModel( - sfURL="url", + sfURL="hostname.com", sfUser="user", sfPassword="password", sfDatabase="database", @@ -170,7 +184,7 @@ def test_get_options_using_alias(self): schema="schema", ) options = k.get_options() # alias should be used by default - assert options["sfURL"] == "url" + assert options["sfURL"] == "hostname.com" assert options["sfUser"] == "user" assert options["sfDatabase"] == "database" assert options["sfRole"] == "role" @@ -180,7 +194,7 @@ def test_get_options_using_alias(self): def test_get_options(self): """Test that the options are correctly generated not using alias""" k = SnowflakeBaseModel( - url="url", + url="hostname.com", user="user", password="password", database="database", @@ -189,7 +203,7 @@ def test_get_options(self): schema="schema", ) options = k.get_options(by_alias=False) - assert options["url"] == "url" + assert options["url"] == "hostname.com" assert options["user"] == "user" assert options["database"] == "database" assert options["role"] == "role" @@ -203,7 +217,7 @@ def test_get_options(self): def test_get_options_include(self): """Test that the options are correctly generated using include""" k = SnowflakeBaseModel( - url="url", + url="hostname.com", user="user", password="password", 
database="database", @@ -215,7 +229,7 @@ def test_get_options_include(self): options = k.get_options(include={"url", "user", "description", "options"}, by_alias=False) # should be present - assert options["url"] == "url" + assert options["url"] == "hostname.com" assert options["user"] == "user" assert "description" in options From b9c02998d9b097d8f08869debb1b28d5fb6b41cd Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:13:21 +0100 Subject: [PATCH 21/33] hotfix: check url alias - sfURL --- src/koheesio/integrations/snowflake/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py index 11fd8b59..dc038a13 100644 --- a/src/koheesio/integrations/snowflake/__init__.py +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -296,7 +296,7 @@ class Output(SnowflakeStep.Output): def _validate_account(cls, values: Dict) -> Dict: """Populate account from URL if not provided""" if not values.get("account"): - parsed_url = urlparse(values["url"]) + parsed_url = urlparse(values.get("url") or values.get("sfURL")) base_url = parsed_url.hostname or parsed_url.path values["account"] = base_url.split(".")[0] From 7fe59208ed1d34bc8ce86706e31702c115cf4407 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:15:47 +0100 Subject: [PATCH 22/33] test: add test for account population from sfURL in SnowflakeRunQueryPython --- tests/snowflake/test_snowflake.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py index c525c5fe..c2b5420f 100644 --- a/tests/snowflake/test_snowflake.py +++ b/tests/snowflake/test_snowflake.py @@ -141,6 +141,12 @@ def test_account_populated_from_url2(self): kls = SnowflakeRunQueryPython(**common_options, sql="SELECT * FROM table") assert kls.account == "host2" + def test_account_populated_from_sf_url(self): + common_options = deepcopy(COMMON_OPTIONS) + common_options["sfURL"] = common_options.pop("url") + kls = SnowflakeRunQueryPython(**common_options, sql="SELECT * FROM table") + assert kls.account == "hostname" + def test_execute(self, mock_query): # Arrange query = "SELECT * FROM two_row_table" From 7c00d7fd03be9b2743c2092232b429ecc739821f Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:34:26 +0100 Subject: [PATCH 23/33] chore: bump version to 0.9.0rc5 --- src/koheesio/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index db20611c..131e2d15 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc4" +__version__ = "0.9.0rc5" __logo__ = ( 75, ( From 6faa20da03b493a08930d9588eaa46d9e04f8f5c Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:11:52 +0100 Subject: [PATCH 24/33] refactor: replace RunQuery with SnowflakeRunQueryPython (#121) ## Description Fix usage of RunQuery ## Related Issue #120 ## Motivation and Context ## How Has This Been Tested? 
Existing tests ## Screenshots (if appropriate): ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. --- src/koheesio/integrations/spark/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/integrations/spark/snowflake.py b/src/koheesio/integrations/spark/snowflake.py index 8a4ad9a2..a6e6941f 100644 --- a/src/koheesio/integrations/spark/snowflake.py +++ b/src/koheesio/integrations/spark/snowflake.py @@ -926,7 +926,7 @@ def _merge_staging_table_into_target(self) -> None: enable_deletion=self.enable_deletion, ) # type: ignore - query_executor = RunQuery( + query_executor = SnowflakeRunQueryPython( **self.get_options(), query=merge_query, ) From 9496eb55d724f5e84d177488bda7fa60fb351523 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Mon, 25 Nov 2024 22:07:31 +0100 Subject: [PATCH 25/33] hotfix: snowflake python connector default config dir (#125) ## Description Snowflake python connector have a logic to get default[ root directory for configurations](https://github.com/snowflakedb/snowflake-connector-python/blob/9ddb2050cde13819a289fb6982b48431f253a53e/src/snowflake/connector/sf_dirs.py#L47): ```python def _resolve_platform_dirs() -> PlatformDirsProto: """Decide on what PlatformDirs class to use. In case a folder exists (which can be customized with the environmental variable `SNOWFLAKE_HOME`) we use that directory as all platform directories. If this folder does not exist we'll fall back to platformdirs defaults. This helper function was introduced to make this code testable. """ platformdir_kwargs = { "appname": "snowflake", "appauthor": False, } snowflake_home = pathlib.Path( os.environ.get("SNOWFLAKE_HOME", "~/.snowflake/"), ).expanduser() if snowflake_home.exists(): return SFPlatformDirs( str(snowflake_home), **platformdir_kwargs, ) else: # In case SNOWFLAKE_HOME does not exist we fall back to using # platformdirs to determine where system files should be placed. Please # see docs for all the directories defined in the module at # https://platformdirs.readthedocs.io/ return PlatformDirs(**platformdir_kwargs) ``` Currently in databricks jobs execution this one is being set to `'/root/.snowflake'` which is not allowed to be accessed. The fix is to catch the error and provide `tmp` folder instead of `root`. ## Related Issue #124 ## Motivation and Context Be able to use snowflake python in databricks ## How Has This Been Tested? Added new tests ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [ ] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. 
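In short, the fix described above amounts to: probe whether the connector's default configuration directory can be read and, if it cannot (as with `/root/.snowflake` on Databricks), point `SNOWFLAKE_HOME` at a writable temporary directory before `snowflake.connector` is imported. Below is a minimal, self-contained sketch of that fallback for illustration only; the helper names `snowflake_home_is_accessible` and `ensure_snowflake_home` are hypothetical, and the actual change is the diff that follows.

```python
import os
import tempfile
from pathlib import Path


def snowflake_home_is_accessible() -> bool:
    """Probe the directory the connector would use as SNOWFLAKE_HOME (default ~/.snowflake/)."""
    home = Path(os.environ.get("SNOWFLAKE_HOME", "~/.snowflake/")).expanduser()
    try:
        os.stat(home)  # raises PermissionError for e.g. /root/.snowflake on Databricks
        return True
    except FileNotFoundError:
        # A missing directory is fine: the connector then falls back to platformdirs defaults.
        return True
    except PermissionError:
        return False


def ensure_snowflake_home() -> None:
    """Redirect SNOWFLAKE_HOME to a writable temp dir *before* snowflake.connector is imported."""
    if not snowflake_home_is_accessible():
        os.environ["SNOWFLAKE_HOME"] = tempfile.mkdtemp(prefix="snowflake_tmp_", dir="/tmp")
```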
--- .../integrations/snowflake/__init__.py | 59 ++++++++++++++++++- tests/snowflake/test_snowflake.py | 16 +++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py index dc038a13..9e1603d9 100644 --- a/src/koheesio/integrations/snowflake/__init__.py +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -46,6 +46,8 @@ from typing import Any, Dict, Generator, List, Optional, Set, Union from abc import ABC from contextlib import contextmanager +import os +import tempfile from types import ModuleType from urllib.parse import urlparse @@ -61,6 +63,7 @@ field_validator, model_validator, ) +from koheesio.spark.utils.common import on_databricks __all__ = [ "GrantPrivilegesOnFullyQualifiedObject", @@ -79,6 +82,38 @@ # Turning off too-many-lines because we are defining a lot of classes in this file +def __check_access_snowflake_config_dir() -> bool: + """Check if the Snowflake configuration directory is accessible + + Returns + ------- + bool + True if the Snowflake configuration directory is accessible, otherwise False + + Raises + ------ + RuntimeError + If `snowflake-connector-python` is not installed + """ + check_result = False + + try: + from snowflake.connector.sf_dirs import _resolve_platform_dirs # noqa: F401 + + _resolve_platform_dirs().user_config_path + check_result = True + except PermissionError as e: + warn(f"Snowflake configuration directory is not accessible. Please check the permissions.Catched error: {e}") + except (ImportError, ModuleNotFoundError) as e: + raise RuntimeError( + "You need to have the `snowflake-connector-python` package installed to use the Snowflake steps that are" + "based around SnowflakeRunQueryPython. You can install this in Koheesio by adding `koheesio[snowflake]` to " + "your package dependencies.", + ) from e + + return check_result + + def safe_import_snowflake_connector() -> Optional[ModuleType]: """Validate that the Snowflake connector is installed @@ -87,7 +122,17 @@ def safe_import_snowflake_connector() -> Optional[ModuleType]: Optional[ModuleType] The Snowflake connector module if it is installed, otherwise None """ + is_accessable_sf_conf_dir = __check_access_snowflake_config_dir() + + if not is_accessable_sf_conf_dir and on_databricks(): + snowflake_home: str = tempfile.mkdtemp(prefix="snowflake_tmp_", dir="/tmp") # nosec B108:ignore bandit check for CWE-377 + os.environ["SNOWFLAKE_HOME"] = snowflake_home + warn(f"Getting error for snowflake config directory. Going to use temp directory `{snowflake_home}` instead.") + elif not is_accessable_sf_conf_dir: + raise PermissionError("Snowflake configuration directory is not accessible. 
Please check the permissions.") + try: + # Keep the import here as it is perfroming resolution of snowflake configuration directory from snowflake import connector as snowflake_connector return snowflake_connector @@ -336,8 +381,18 @@ def conn(self) -> Generator: self.log.info(f"Connected to Snowflake account: {sf_options['account']}") try: + from snowflake.connector.connection import logger as snowflake_logger + + _preserve_snowflake_logger = snowflake_logger + snowflake_logger = self.log + snowflake_logger.debug("Replace snowflake logger with Koheesio logger") yield _conn finally: + if _preserve_snowflake_logger: + if snowflake_logger: + snowflake_logger.debug("Restore snowflake logger") + snowflake_logger = _preserve_snowflake_logger + if _conn: _conn.close() @@ -348,7 +403,9 @@ def get_query(self) -> str: def execute(self) -> None: """Execute the query""" with self.conn as conn: - cursors = conn.execute_string(self.get_query()) + cursors = conn.execute_string( + self.get_query(), + ) for cursor in cursors: self.log.debug(f"Cursor executed: {cursor}") self.output.results.extend(cursor.fetchall()) diff --git a/tests/snowflake/test_snowflake.py b/tests/snowflake/test_snowflake.py index c2b5420f..08c074f0 100644 --- a/tests/snowflake/test_snowflake.py +++ b/tests/snowflake/test_snowflake.py @@ -1,5 +1,6 @@ # flake8: noqa: F811 from copy import deepcopy +import os from unittest import mock import pytest @@ -14,6 +15,7 @@ SnowflakeRunQueryPython, SnowflakeStep, SnowflakeTableStep, + safe_import_snowflake_connector, ) from koheesio.integrations.snowflake.test_utils import mock_query @@ -272,3 +274,17 @@ def test_initialization(self): """Test that the table is correctly set""" kls = SnowflakeTableStep(**COMMON_OPTIONS, table="table") assert kls.table == "table" + + +class TestSnowflakeConfigDir: + @mock.patch("koheesio.integrations.snowflake.__check_access_snowflake_config_dir", return_value=False) + @mock.patch("koheesio.integrations.snowflake.on_databricks", return_value=True) + def test_initialization_on_databricks(self, mock_on_databricks, mock_check_access): + """Test that the config dir is correctly set""" + safe_import_snowflake_connector() + assert os.environ["SNOWFLAKE_HOME"].startswith("/tmp/snowflake_tmp_") + + def test_initialization(self): + origin_snowflake_home = os.environ.get("SNOWFLAKE_HOME") + safe_import_snowflake_connector() + assert os.environ.get("SNOWFLAKE_HOME") == origin_snowflake_home From ea2d15e00fae6d654096a24992e779d24145c949 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:40:33 +0100 Subject: [PATCH 26/33] version bump --- src/koheesio/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 131e2d15..350ac2de 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc5" +__version__ = "0.9.0rc6" __logo__ = ( 75, ( From c72f381d00397fdce8b81039ac15ce2c6dd52407 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Tue, 26 Nov 2024 11:46:41 +0100 Subject: [PATCH 27/33] Fix/delta merge builder instance check for connect + util fix (#130) ## Description Additonal fixes for #101 and #102 ## Related Issue #101, #102 ## Motivation and Context ## How Has This Been Tested? 
## Screenshots (if appropriate): ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. --------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- src/koheesio/spark/utils/connect.py | 5 ++++- src/koheesio/spark/writers/delta/batch.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/koheesio/spark/utils/connect.py b/src/koheesio/spark/utils/connect.py index 9cf7f028..d5728902 100644 --- a/src/koheesio/spark/utils/connect.py +++ b/src/koheesio/spark/utils/connect.py @@ -14,6 +14,9 @@ def is_remote_session(spark: Optional[SparkSession] = None) -> bool: result = False if (_spark := spark or get_active_session()) and check_if_pyspark_connect_is_supported(): - result = True if _spark.conf.get("spark.remote", None) else False # type: ignore + # result = True if _spark.conf.get("spark.remote", None) else False # type: ignore + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + + result = isinstance(_spark, ConnectSparkSession) return result diff --git a/src/koheesio/spark/writers/delta/batch.py b/src/koheesio/spark/writers/delta/batch.py index 5f66df99..2c124cd5 100644 --- a/src/koheesio/spark/writers/delta/batch.py +++ b/src/koheesio/spark/writers/delta/batch.py @@ -323,7 +323,10 @@ def _validate_params(cls, params: dict) -> dict: clause = merge_conf.get("clause") if clause not in valid_clauses: raise ValueError(f"Invalid merge clause '{clause}' provided") - elif not isinstance(merge_builder, DeltaMergeBuilder): + elif ( + not isinstance(merge_builder, DeltaMergeBuilder) + or not type(merge_builder).__name__ == "DeltaMergeBuilder" + ): raise ValueError("merge_builder must be a list or merge clauses or a DeltaMergeBuilder instance") return params @@ -378,7 +381,7 @@ def execute(self) -> Writer.Output: if self.table.create_if_not_exists and not self.table.exists: _writer = _writer.options(**self.table.default_create_properties) - if isinstance(_writer, DeltaMergeBuilder): + if isinstance(_writer, DeltaMergeBuilder) or type(_writer).__name__ == "DeltaMergeBuilder": _writer.execute() else: if options := self.params: From a085947bb99809b38e04fc28f318b7a2c315598b Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Wed, 27 Nov 2024 15:09:03 +0100 Subject: [PATCH 28/33] Release/0.9 - final version bump and docs (#132) - Final version bump - Documentation update - Ran ruff (2 files changed) --------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- README.md | 8 ++++++++ src/koheesio/__about__.py | 2 +- src/koheesio/integrations/box.py | 2 +- src/koheesio/spark/delta.py | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 75694533..31eff830 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,14 @@ the `pyproject.toml` entry mentioned above or installing through pip. - SE Provides Data Quality checks for Spark DataFrames. - For more information, refer to the [Spark Expectations docs](https://engineering.nike.com/spark-expectations). 
+- __Spark Connect and Delta:__ + Koheesio is ready to be used with Spark Connect. If you use the Delta package in combination with a remote/connect session, you get full support on Databricks and partial support for the Delta package in Apache Spark. Full support for Delta in Apache Spark is coming with the release of PySpark 4.0. + - The spark extra can be installed by adding `koheesio[spark]` to the `pyproject.toml` entry mentioned above. + - The spark module is available through the `koheesio.spark` module. + - The delta module is available through the `koheesio.spark.writers.delta` module. + - For more information, refer to the [Databricks documentation](https://docs.databricks.com/). + - For more information on Apache Spark, refer to the [Apache Spark documentation](https://spark.apache.org/docs/latest/). + [//]: # (- **Brickflow:** Available through the `koheesio.integrations.workflow` module; installable through the `bf` extra.) [//]: # ( - Brickflow is a workflow orchestration tool that allows you to define and execute workflows in a declarative way.) [//]: # ( - For more information, refer to the [Brickflow docs](https://engineering.nike.com/brickflow)) diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 350ac2de..7a8c687d 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0rc6" +__version__ = "0.9.0" __logo__ = ( 75, ( diff --git a/src/koheesio/integrations/box.py b/src/koheesio/integrations/box.py index 0e61fe87..149d9898 100644 --- a/src/koheesio/integrations/box.py +++ b/src/koheesio/integrations/box.py @@ -10,7 +10,7 @@ * Application is authorized for the enterprise (Developer Portal - MyApp - Authorization) """ -from typing import Any, Dict, Optional, Union, IO +from typing import IO, Any, Dict, Optional, Union from abc import ABC from io import BytesIO, StringIO from pathlib import PurePath diff --git a/src/koheesio/spark/delta.py b/src/koheesio/spark/delta.py index 291cd31d..70a084df 100644 --- a/src/koheesio/spark/delta.py +++ b/src/koheesio/spark/delta.py @@ -20,7 +20,7 @@ class DeltaTableStep(SparkStep): DeltaTable aims to provide a simple interface to create and manage Delta tables. It is a wrapper around the Spark SQL API for Delta tables. - + ## Description This pull request aims to make transformations callable within the Koheesio framework. ## Related Issue N/A ## Motivation and Context ## How Has This Been Tested? Added applicable unit test. ## Screenshots (if appropriate): ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [x] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed.
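For reference, a minimal sketch of the usage pattern this change enables. `AddOne` is the illustrative transformation used in the updated docstring, not a class shipped with Koheesio, and the column names are placeholders:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from koheesio.spark.transformations import Transformation


class AddOne(Transformation):
    # Illustrative transformation: writes `id` + 1 into `target_column`
    target_column: str = "new_column"

    def execute(self):
        self.output.df = self.df.withColumn(self.target_column, f.col("id") + 1)


spark = SparkSession.builder.getOrCreate()
input_df = spark.range(3)  # DataFrame with a single `id` column

# Because Transformation now implements __call__ (delegating to .transform),
# instances can be passed straight to DataFrame.transform and chained:
output_df = input_df.transform(AddOne(target_column="foo")).transform(AddOne(target_column="bar"))
```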
--------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- pyproject.toml | 1 + .../spark/transformations/__init__.py | 62 +++++++++++++++++-- tests/spark/transformations/test_transform.py | 28 +++++++++ 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8493bfe3..610bd9af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -467,6 +467,7 @@ exclude = [ [tool.ruff.format] # https://docs.astral.sh/ruff/formatter/#docstring-formatting docstring-code-format = true +docstring-code-line-length = 70 [tool.ruff.lint] select = [ diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index 3f273a85..ca9866b6 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -49,19 +49,25 @@ class Transformation(SparkStep, ABC): Example ------- + ### Implementing a transformation using the Transformation class: ```python from koheesio.steps.transformations import Transformation from pyspark.sql import functions as f class AddOne(Transformation): + target_column: str = "new_column" + def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + self.target_column, f.col("old_column") + 1 + ) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the result in a new column called `new_column`. + ### Using the transformation: In order to use this transformation, we can call the `transform` method: ```python @@ -85,6 +91,7 @@ def execute(self): | 2| 3| ... + ### Alternative ways to use the transformation: Alternatively, we can pass the DataFrame to the constructor and call the `execute` or `transform` method without any arguments: @@ -94,9 +101,24 @@ def execute(self): output_df = AddOne(df).execute().output.df ``` - Note: that the transform method was not implemented explicitly in the AddOne class. This is because the `transform` + > Note: that the transform method was not implemented explicitly in the AddOne class. This is because the `transform` method is already implemented in the `Transformation` class. This means that all classes that inherit from the Transformation class will have the `transform` method available. Only the execute method needs to be implemented. + + ### Using the transformation as a function: + The transformation can also be used as a function as part of a DataFrame's `transform` method: + + ```python + input_df = spark.range(3) + + output_df = input_df.transform(AddOne(target_column="foo")).transform( + AddOne(target_column="bar") + ) + ``` + + In the above example, the `AddOne` transformation is applied to the `input_df` DataFrame using the `transform` + method. The `output_df` will now contain the original DataFrame with an additional columns called `foo` and + `bar', each with the values of `id` + 1. """ df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame") @@ -111,7 +133,9 @@ def execute(self) -> SparkStep.Output: For example: ```python def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + "new_column", f.col("old_column") + 1 + ) ``` The transform method will call this method and return the output DataFrame. 
@@ -147,6 +171,26 @@ def transform(self, df: Optional[DataFrame] = None) -> DataFrame: self.execute() return self.output.df + def __call__(self, *args, **kwargs): + """Allow the class to be called as a function. + This is especially useful when using a DataFrame's transform method. + + Example + ------- + ```python + input_df = spark.range(3) + + output_df = input_df.transform(AddOne(target_column="foo")).transform( + AddOne(target_column="bar") + ) + ``` + + In the above example, the `AddOne` transformation is applied to the `input_df` DataFrame using the `transform` + method. The `output_df` will now contain the original DataFrame with an additional columns called `foo` and + `bar', each with the values of `id` + 1. + """ + return self.transform(*args, **kwargs) + class ColumnsTransformation(Transformation, ABC): """Extended Transformation class with a preset validator for handling column(s) data with a standardized input @@ -204,7 +248,9 @@ class ColumnsTransformation(Transformation, ABC): class AddOne(ColumnsTransformation): def execute(self): for column in self.get_columns(): - self.output.df = self.df.withColumn(column, f.col(column) + 1) + self.output.df = self.df.withColumn( + column, f.col(column) + 1 + ) ``` In the above example, the `execute` method is implemented to add 1 to the values of a given column. @@ -460,7 +506,9 @@ class ColumnsTransformationWithTarget(ColumnsTransformation, ABC): ```python from pyspark.sql import Column - from koheesio.steps.transformations import ColumnsTransformationWithTarget + from koheesio.steps.transformations import ( + ColumnsTransformationWithTarget, + ) class AddOneWithTarget(ColumnsTransformationWithTarget): @@ -477,7 +525,9 @@ def func(self, col: Column): # create a DataFrame with 3 rows df = SparkSession.builder.getOrCreate().range(3) - output_df = AddOneWithTarget(column="id", target_column="new_id").transform(df) + output_df = AddOneWithTarget( + column="id", target_column="new_id" + ).transform(df) ``` The `output_df` will now contain the original DataFrame with an additional column called `new_id` with the values of diff --git a/tests/spark/transformations/test_transform.py b/tests/spark/transformations/test_transform.py index 1f92e490..ce46a3d5 100644 --- a/tests/spark/transformations/test_transform.py +++ b/tests/spark/transformations/test_transform.py @@ -6,7 +6,9 @@ from koheesio.logger import LoggingFactory from koheesio.spark import DataFrame +from koheesio.spark.transformations.strings.substring import Substring from koheesio.spark.transformations.transform import Transform +from koheesio.spark.transformations.hash import Sha2Hash pytestmark = pytest.mark.spark @@ -83,3 +85,29 @@ def test_from_func(dummy_df): AddFooColumn = Transform.from_func(dummy_transform_func, target_column="foo") df = AddFooColumn(value="bar").transform(dummy_df) assert transform_output_test(df, {"id": 0, "foo": "bar"}) + + +def test_df_transform_compatibility(dummy_df: DataFrame): + + expected_data = { + "id": 0, + "foo": "bar", + "bar": "baz", + "foo_hash": "fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9", + "foo_sub": "fcde2b", + } + + # set up a reusable Transform from a function + add_column = Transform.from_func(dummy_transform_func, value="bar") + + output_df = ( + dummy_df + # test the Transform class with multiple chained transforms + .transform(add_column(target_column="foo")) + .transform(add_column(target_column="bar", value="baz")) + # test that Transformation classes can be called directly by DataFrame.transform + 
.transform(Sha2Hash(columns="foo", target_column="foo_hash")) + .transform(Substring(column="foo_hash", start=1, length=6, target_column="foo_sub")) + ) + + assert output_df.head().asDict() == expected_data From de56d0033d79fff9bee282a2ed9cd77c1e24e8c9 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Thu, 28 Nov 2024 14:06:06 +0100 Subject: [PATCH 30/33] [BUG] small fix for Tableau Server path checking (#134) ## Description Small fix for a tiny bug ## Related Issue ## Motivation and Context ## How Has This Been Tested? ## Screenshots (if appropriate): ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- src/koheesio/integrations/spark/tableau/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/integrations/spark/tableau/server.py b/src/koheesio/integrations/spark/tableau/server.py index 14d80a6f..d3eddc05 100644 --- a/src/koheesio/integrations/spark/tableau/server.py +++ b/src/koheesio/integrations/spark/tableau/server.py @@ -79,7 +79,7 @@ def validate_project(self) -> "TableauServer": if self.project and self.project_id: raise ValueError("Both 'project' and 'project_id' parameters cannot be provided at the same time.") - if not self.project_id and not self.project_id: + if not self.project and not self.project_id: raise ValueError("Either 'project' or 'project_id' parameters should be provided, none is set") return self From a7d2997e6b5c97f97d6f03e4f4d9590286d5398a Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Fri, 29 Nov 2024 09:19:22 +0100 Subject: [PATCH 31/33] [FEATURE] DataBricksSecret for getting secrets from DataBricks scope (#133) ## Description `DataBricksSecret` class can be used to get secrets from DataBricks scopes. ## Related Issue #66 ## Motivation and Context Support secret scope in Databricks ## How Has This Been Tested? Add mocked test ## Screenshots (if appropriate): ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [ ] All new and existing tests passed. 
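For reference, a minimal sketch of how the new class is constructed, mirroring the behaviour exercised by the unit tests added in this PR (the scope, parent, and alias values are illustrative only):

```python
from koheesio.integrations.spark.databricks.secrets import DataBricksSecret

# When `parent` is not given, it is derived from the scope name,
# with "/" and "-" replaced by "_": "secret-scope" -> "secret_scope"
secret = DataBricksSecret(scope="secret-scope")
assert secret.parent == "secret_scope"

# Keys can be renamed on retrieval via an alias mapping; here the
# Databricks secret "key1" would surface as "new_name_key1"
secret = DataBricksSecret(scope="dummy", parent="kafka", alias={"key1": "new_name_key1"})
```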
--------- Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Co-authored-by: Danny Meijer --- docs/tutorials/getting-started.md | 17 ---- src/koheesio/__about__.py | 2 +- .../integrations/snowflake/__init__.py | 2 +- .../integrations/spark/databricks/__init__.py | 0 .../integrations/spark/databricks/secrets.py | 79 +++++++++++++++++++ .../integrations/spark/databricks/utils.py | 16 ++++ src/koheesio/integrations/spark/snowflake.py | 8 +- tests/spark/conftest.py | 18 +++++ .../integrations/databrikcs/test_secrets.py | 48 +++++++++++ tests/spark/test_spark.py | 2 +- 10 files changed, 168 insertions(+), 24 deletions(-) create mode 100644 src/koheesio/integrations/spark/databricks/__init__.py create mode 100644 src/koheesio/integrations/spark/databricks/secrets.py create mode 100644 src/koheesio/integrations/spark/databricks/utils.py create mode 100644 tests/spark/integrations/databrikcs/test_secrets.py diff --git a/docs/tutorials/getting-started.md b/docs/tutorials/getting-started.md index 8fd80ff1..6fe88897 100644 --- a/docs/tutorials/getting-started.md +++ b/docs/tutorials/getting-started.md @@ -19,23 +19,6 @@ ``` -
- poetry - - If you're using Poetry, add the following entry to the `pyproject.toml` file: - - ```toml title="pyproject.toml" - [[tool.poetry.source]] - name = "nike" - url = "https://artifactory.nike.com/artifactory/api/pypi/python-virtual/simple" - secondary = true - ``` - - ```bash - poetry add koheesio - ``` -
-
pip diff --git a/src/koheesio/__about__.py b/src/koheesio/__about__.py index 7a8c687d..20f1a1b9 100644 --- a/src/koheesio/__about__.py +++ b/src/koheesio/__about__.py @@ -12,7 +12,7 @@ LICENSE_INFO = "Licensed as Apache 2.0" SOURCE = "https://github.com/Nike-Inc/koheesio" -__version__ = "0.9.0" +__version__ = "0.9.0rc7" __logo__ = ( 75, ( diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py index 9e1603d9..dcabbcb5 100644 --- a/src/koheesio/integrations/snowflake/__init__.py +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -449,7 +449,7 @@ class GrantPrivilegesOnObject(SnowflakeRunQueryPython): object="MY_TABLE", type="TABLE", warehouse="MY_WH", - user="gid.account@nike.com", + user="gid.account@abc.com", password=Secret("super-secret-password"), role="APPLICATION.SNOWFLAKE.ADMIN", permissions=["SELECT", "INSERT"], diff --git a/src/koheesio/integrations/spark/databricks/__init__.py b/src/koheesio/integrations/spark/databricks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/koheesio/integrations/spark/databricks/secrets.py b/src/koheesio/integrations/spark/databricks/secrets.py new file mode 100644 index 00000000..1f04d54d --- /dev/null +++ b/src/koheesio/integrations/spark/databricks/secrets.py @@ -0,0 +1,79 @@ +"""Module for retrieving secrets from DataBricks Scopes. + +Secrets are stored as SecretContext and can be accessed accordingly. + +See DataBricksSecret for more information. +""" + +from typing import Dict, Optional +import re + +from pyspark.sql import SparkSession + +from koheesio.integrations.spark.databricks.utils import get_dbutils +from koheesio.models import Field, model_validator +from koheesio.secrets import Secret + + +class DataBricksSecret(Secret): + """ + Retrieve secrets from DataBricks secret scope and wrap them into Context class for easy access. + All secrets are stored under the "secret" root and "parent". "Parent" either derived from the + secure scope by replacing "/" and "-", or manually provided by the user. + Secrets are wrapped into the pydantic.SecretStr. + + Examples + --------- + + ```python + context = {"secrets": {"parent": {"webhook": SecretStr("**********"), "description": SecretStr("**********")}}} + ``` + + Values can be decoded like this: + ```python + context.secrets.parent.webhook.get_secret_value() + ``` + or if working with dictionary is preferable: + ```python + for key, value in context.get_all().items(): + value.get_secret_value() + ``` + """ + + scope: str = Field(description="Scope") + alias: Optional[Dict[str, str]] = Field(default_factory=dict, description="Alias for secret keys") + + @model_validator(mode="before") + def _set_parent_to_scope(cls, values): + """ + Set default value for `parent` parameter on model initialization when it was not + explicitly set by the user. In this scenario scope will be used: + + 'secret-scope' -> secret_scope + """ + regex = re.compile(r"[/-]") + path = values.get("scope") + + if not values.get("parent"): + values["parent"] = regex.sub("_", path) + + return values + + @property + def _client(self): + """ + Instantiated Databricks client. 
+ """ + + return get_dbutils(SparkSession.getActiveSession()) # type: ignore + + def _get_secrets(self): + """Dictionary of secrets.""" + all_keys = (secret_meta.key for secret_meta in self._client.secrets.list(scope=self.scope)) + secret_data = {} + + for key in all_keys: + key_name = key if not (self.alias and self.alias.get(key)) else self.alias[key] # pylint: disable=E1101 + secret_data[key_name] = self._client.secrets.get(scope=self.scope, key=key) + + return secret_data diff --git a/src/koheesio/integrations/spark/databricks/utils.py b/src/koheesio/integrations/spark/databricks/utils.py new file mode 100644 index 00000000..9a2a8849 --- /dev/null +++ b/src/koheesio/integrations/spark/databricks/utils.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from pyspark.sql import SparkSession + +from koheesio.spark.utils import on_databricks + + +def get_dbutils(spark_session: SparkSession) -> DBUtils: # type: ignore # noqa: F821 + if not on_databricks(): + raise RuntimeError("dbutils not available") + + from pyspark.dbutils import DBUtils # pylint: disable=E0611,E0401 # type: ignore + + dbutils = DBUtils(spark_session) + + return dbutils diff --git a/src/koheesio/integrations/spark/snowflake.py b/src/koheesio/integrations/spark/snowflake.py index a6e6941f..ba292afe 100644 --- a/src/koheesio/integrations/spark/snowflake.py +++ b/src/koheesio/integrations/spark/snowflake.py @@ -302,7 +302,7 @@ class Query(SnowflakeReader): database="MY_DB", schema_="MY_SCHEMA", warehouse="MY_WH", - user="gid.account@nike.com", + user="gid.account@abc.com", password=Secret("super-secret-password"), role="APPLICATION.SNOWFLAKE.ADMIN", query="SELECT * FROM MY_TABLE", @@ -412,7 +412,7 @@ class CreateOrReplaceTableFromDataFrame(SnowflakeTransformation): database="MY_DB", schema="MY_SCHEMA", warehouse="MY_WH", - user="gid.account@nike.com", + user="gid.account@abc.com", password="super-secret-password", role="APPLICATION.SNOWFLAKE.ADMIN", table="MY_TABLE", @@ -477,7 +477,7 @@ class GetTableSchema(SnowflakeStep): database="MY_DB", schema_="MY_SCHEMA", warehouse="MY_WH", - user="gid.account@nike.com", + user="gid.account@abc.com", password="super-secret-password", role="APPLICATION.SNOWFLAKE.ADMIN", table="MY_TABLE", @@ -512,7 +512,7 @@ class AddColumn(SnowflakeStep): database="MY_DB", schema_="MY_SCHEMA", warehouse="MY_WH", - user="gid.account@nike.com", + user="gid.account@abc.com", password=Secret("super-secret-password"), role="APPLICATION.SNOWFLAKE.ADMIN", table="MY_TABLE", diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index f918ae40..9809b9be 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -1,3 +1,4 @@ +from typing import Any from collections import namedtuple import datetime from decimal import Decimal @@ -347,3 +348,20 @@ def df_with_all_types(spark): data=[[v[0] for v in data.values()]], schema=StructType([StructField(name=v[1], dataType=v[2]) for v in data.values()]), ) + + +class ScopeSecrets: + class SecretMeta: + def __init__(self, key: str): + self.key = key + + def __init__(self, secrets: dict): + self.secrets = secrets + + def get(self, scope: str, key: str) -> Any: + return self.secrets.get(key) + + def list(self, scope: str): + keys = [ScopeSecrets.SecretMeta(key=key) for key in self.secrets.keys()] + + return keys diff --git a/tests/spark/integrations/databrikcs/test_secrets.py b/tests/spark/integrations/databrikcs/test_secrets.py new file mode 100644 index 00000000..11af1170 --- /dev/null +++ b/tests/spark/integrations/databrikcs/test_secrets.py @@ 
-0,0 +1,48 @@ +from unittest.mock import patch + +from conftest import ScopeSecrets + +from koheesio.integrations.spark.databricks.secrets import DataBricksSecret + + +class TestDatabricksSecret: + def test_set_parent_to_scope(self): + # Test when parent is not provided + secret = DataBricksSecret(scope="secret-scope") + assert secret.parent == "secret_scope" + + # Test when parent is provided + secret = DataBricksSecret(scope="secret-scope", parent="custom_parent") + assert secret.parent == "custom_parent" + + @patch("koheesio.integrations.spark.databricks.secrets.DataBricksSecret._client") + def test_get_secrets_no_alias(self, mock_databricks_client): + with patch("koheesio.integrations.spark.databricks.utils.on_databricks", return_value=True): + dd = { + "key1": "value_of_key1", + "key2": "value_of_key2", + } + databricks = DataBricksSecret(scope="dummy", parent="kafka") + mock_databricks_client.secrets = ScopeSecrets(dd) + secrets = databricks._get_secrets() + + assert secrets["key1"] == "value_of_key1" + assert secrets["key2"] == "value_of_key2" + + @patch("koheesio.integrations.spark.databricks.secrets.DataBricksSecret._client") + def test_get_secrets_alias(self, mock_databricks_client): + with patch("koheesio.integrations.spark.databricks.utils.on_databricks", return_value=True): + dd = { + "key1": "value_of_key1", + "key2": "value_of_key2", + } + alias = { + "key1": "new_name_key1", + "key2": "new_name_key2", + } + databricks = DataBricksSecret(scope="dummy", parent="kafka", alias=alias) + mock_databricks_client.secrets = ScopeSecrets(dd) + secrets = databricks._get_secrets() + + assert secrets["new_name_key1"] == "value_of_key1" + assert secrets["new_name_key2"] == "value_of_key2" diff --git a/tests/spark/test_spark.py b/tests/spark/test_spark.py index e19b3e02..003060b5 100644 --- a/tests/spark/test_spark.py +++ b/tests/spark/test_spark.py @@ -26,7 +26,7 @@ def test_import_error_no_error(self): with mock.patch.dict("sys.modules", {"pyspark": None}): from koheesio.sso.okta import OktaAccessToken - OktaAccessToken(url="https://nike.okta.com", client_id="client_id", client_secret=secret) + OktaAccessToken(url="https://abc.okta.com", client_id="client_id", client_secret=secret) def test_import_error_with_error(self): with mock.patch.dict("sys.modules", {"pyspark.sql": None, "koheesio.steps.spark": None}): From 1e21e37e01bd86738b64d79d381db714b88a0e03 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Fri, 29 Nov 2024 15:33:30 +0100 Subject: [PATCH 32/33] [FIX] Remove mention of non-existent class type in docs (#138) ## Description ## Related Issue #135 ## Motivation and Context ## How Has This Been Tested? ## Screenshots (if appropriate): ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [x] My change requires a change to the documentation. - [x] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [ ] I have added tests to cover my changes. - [x] All new and existing tests passed. 
--- src/koheesio/integrations/snowflake/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/koheesio/integrations/snowflake/__init__.py b/src/koheesio/integrations/snowflake/__init__.py index dcabbcb5..acbe9f6f 100644 --- a/src/koheesio/integrations/snowflake/__init__.py +++ b/src/koheesio/integrations/snowflake/__init__.py @@ -2,7 +2,7 @@ """ Snowflake steps and tasks for Koheesio -Every class in this module is a subclass of `Step` or `Task` and is used to perform operations on Snowflake. +Every class in this module is a subclass of `Step` or `BaseModel` and is used to perform operations on Snowflake. Notes ----- From 5298c2b4cb01e3e062061ef997ea1fa2e9280f33 Mon Sep 17 00:00:00 2001 From: Danny Meijer Date: Fri, 29 Nov 2024 15:34:09 +0100 Subject: [PATCH 33/33] [FIX] unused SparkSession being import from pyspark.sql in several tests (#140) ## Description ## Related Issue #139 ## Motivation and Context ## How Has This Been Tested? ## Screenshots (if appropriate): ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Checklist: - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **CONTRIBUTING** document. - [x] I have added tests to cover my changes. - [x] All new and existing tests passed. Co-authored-by: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> --- tests/spark/integrations/dq/test_spark_expectations.py | 1 - tests/spark/readers/test_hana.py | 4 ---- tests/spark/readers/test_jdbc.py | 4 ---- tests/spark/transformations/test_lookup.py | 4 +--- tests/spark/transformations/test_row_number_dedup.py | 9 ++++----- 5 files changed, 5 insertions(+), 17 deletions(-) diff --git a/tests/spark/integrations/dq/test_spark_expectations.py b/tests/spark/integrations/dq/test_spark_expectations.py index e776bd40..3ad43146 100644 --- a/tests/spark/integrations/dq/test_spark_expectations.py +++ b/tests/spark/integrations/dq/test_spark_expectations.py @@ -2,7 +2,6 @@ import pytest -import pyspark from pyspark.sql import SparkSession from koheesio.utils import get_project_root diff --git a/tests/spark/readers/test_hana.py b/tests/spark/readers/test_hana.py index 35c603ad..6049aadb 100644 --- a/tests/spark/readers/test_hana.py +++ b/tests/spark/readers/test_hana.py @@ -1,9 +1,5 @@ -from unittest import mock - import pytest -from pyspark.sql import SparkSession - from koheesio.spark.readers.hana import HanaReader pytestmark = pytest.mark.spark diff --git a/tests/spark/readers/test_jdbc.py b/tests/spark/readers/test_jdbc.py index b75c2ce1..2cecff49 100644 --- a/tests/spark/readers/test_jdbc.py +++ b/tests/spark/readers/test_jdbc.py @@ -1,9 +1,5 @@ -from unittest import mock - import pytest -from pyspark.sql import SparkSession - from koheesio.spark.readers.jdbc import JdbcReader pytestmark = pytest.mark.spark diff --git a/tests/spark/transformations/test_lookup.py b/tests/spark/transformations/test_lookup.py index 8c667375..6df2d785 100644 --- a/tests/spark/transformations/test_lookup.py +++ b/tests/spark/transformations/test_lookup.py @@ -1,7 +1,5 @@ import pytest -from pyspark.sql import SparkSession - from koheesio.spark.transformations.lookup import ( DataframeLookup, JoinHint, @@ -37,7 +35,7 @@ def test_join_hint_values() -> None: 
@pytest.mark.parametrize("join_hint", [None, JoinHint.BROADCAST]) -def test_dataframe_lookup(spark: SparkSession, join_hint: JoinHint) -> None: +def test_dataframe_lookup(spark, join_hint: JoinHint) -> None: df = spark.createDataFrame( [("1", "a", "a"), ("2", "b", "b")], schema="key string, second_key string, field string", diff --git a/tests/spark/transformations/test_row_number_dedup.py b/tests/spark/transformations/test_row_number_dedup.py index ba521a5d..b2690cfe 100644 --- a/tests/spark/transformations/test_row_number_dedup.py +++ b/tests/spark/transformations/test_row_number_dedup.py @@ -2,7 +2,6 @@ import pytest -from pyspark.sql import SparkSession from pyspark.sql import functions as F from koheesio.spark.transformations.row_number_dedup import RowNumberDedup @@ -11,7 +10,7 @@ @pytest.mark.parametrize("target_column", ["col_row_number"]) -def test_row_number_dedup(spark: SparkSession, target_column: str) -> None: +def test_row_number_dedup(spark, target_column: str) -> None: df = spark.createDataFrame( [ ( @@ -49,7 +48,7 @@ def test_row_number_dedup(spark: SparkSession, target_column: str) -> None: @pytest.mark.parametrize("target_column", ["col_row_number"]) -def test_row_number_dedup_not_list_column(spark: SparkSession, target_column: str) -> None: +def test_row_number_dedup_not_list_column(spark, target_column: str) -> None: df = spark.createDataFrame( [ ( @@ -89,7 +88,7 @@ def test_row_number_dedup_not_list_column(spark: SparkSession, target_column: st @pytest.mark.parametrize("target_column", ["col_row_number"]) -def test_row_number_dedup_with_columns(spark: SparkSession, target_column: str) -> None: +def test_row_number_dedup_with_columns(spark, target_column: str) -> None: df = spark.createDataFrame( [ ( @@ -129,7 +128,7 @@ def test_row_number_dedup_with_columns(spark: SparkSession, target_column: str) @pytest.mark.parametrize("target_column", ["col_row_number"]) -def test_row_number_dedup_with_duplicated_columns(spark: SparkSession, target_column: str) -> None: +def test_row_number_dedup_with_duplicated_columns(spark, target_column: str) -> None: df = spark.createDataFrame( [ (