From 10bbb70a9cbcb6a7307f888d6631801ebbb91b49 Mon Sep 17 00:00:00 2001
From: Danny Meijer
Date: Sat, 9 Nov 2024 00:07:34 +0100
Subject: [PATCH] additional coverage

---
 spark_expectations/core/expectations.py |  2 --
 tests/core/test_expectations.py         | 35 +++++++++++++++++++++++--
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/spark_expectations/core/expectations.py b/spark_expectations/core/expectations.py
index a6bc049..1ec3220 100644
--- a/spark_expectations/core/expectations.py
+++ b/spark_expectations/core/expectations.py
@@ -57,7 +57,6 @@ def check_if_pyspark_connect_is_supported() -> bool:
 
 # pylint: disable=ungrouped-imports
 if check_if_pyspark_connect_is_supported():
-    print("PySpark connect is supported")
     # Import the connect module if the current version of PySpark supports it
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
     from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
@@ -65,7 +64,6 @@ def check_if_pyspark_connect_is_supported() -> bool:
     DataFrame = Union[sql.DataFrame, ConnectDataFrame]  # type: ignore
     SparkSession = Union[sql.SparkSession, ConnectSparkSession]  # type: ignore
 else:
-    print("PySpark connect is not supported")
     # Otherwise, use the default PySpark classes
     from pyspark.sql.dataframe import DataFrame  # type: ignore
     from pyspark.sql.session import SparkSession  # type: ignore
diff --git a/tests/core/test_expectations.py b/tests/core/test_expectations.py
index 52e4fb5..e7264d6 100644
--- a/tests/core/test_expectations.py
+++ b/tests/core/test_expectations.py
@@ -1,7 +1,7 @@
 # pylint: disable=too-many-lines
 import os
 import datetime
-from unittest.mock import Mock, PropertyMock
+from unittest.mock import MagicMock, Mock, PropertyMock
 from unittest.mock import patch
 import pytest
 from pyspark.sql import DataFrame, SparkSession
@@ -19,6 +19,8 @@
 from spark_expectations.core.expectations import (
     SparkExpectations,
     WrappedDataFrameWriter,
+    check_if_pyspark_connect_is_supported,
+    get_spark_minor_version,
 )
 from spark_expectations.config.user_config import Constants as user_config
 from spark_expectations.core import get_spark_session
@@ -37,7 +39,7 @@ def fixture_setup_local_kafka_topic():
         os.getenv("UNIT_TESTING_ENV")
         != "spark_expectations_unit_testing_on_github_actions"
     ):
-        # remove if docker conatiner is running
+        # remove if docker container is running
         os.system(
             f"sh {current_dir}/../../spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh"
         )
@@ -3404,3 +3406,32 @@ def test_delta_bucketby_exception():
         match=r"Bucketing is not supported for delta tables yet",
     ):
         writer.build()
+
+
+class TestCheckIfPysparkConnectIsSupported:
+    def test_if_pyspark_connect_is_not_supported(self):
+        """Test that check_if_pyspark_connect_is_supported returns False when
+        pyspark connect is not supported."""
+        with patch.dict("sys.modules", {"pyspark.sql.connect": None}):
+            assert check_if_pyspark_connect_is_supported() is False
+
+    def test_check_if_pyspark_connect_is_supported(self):
+        """Test that check_if_pyspark_connect_is_supported returns True when
+        pyspark connect is supported."""
+        with (
+            patch("spark_expectations.core.expectations.SPARK_MINOR_VERSION", 3.5),
+            patch.dict(
+                "sys.modules",
+                {
+                    "pyspark.sql.connect.column": MagicMock(Column=MagicMock()),
+                    "pyspark.sql.connect": MagicMock(),
+                },
+            ),
+        ):
+            assert check_if_pyspark_connect_is_supported() is True
+
+
+def test_get_spark_minor_version():
+    """Test that get_spark_minor_version returns the correctly formatted version."""
+    with patch("spark_expectations.core.expectations.spark_version", "9.9.42"):
+        assert get_spark_minor_version() == 9.9