
Commit

additional coverage

dannymeijer committed Nov 8, 2024
1 parent 09959fe commit 10bbb70
Showing 2 changed files with 33 additions and 4 deletions.
2 changes: 0 additions & 2 deletions spark_expectations/core/expectations.py
@@ -57,15 +57,13 @@ def check_if_pyspark_connect_is_supported() -> bool:

# pylint: disable=ungrouped-imports
if check_if_pyspark_connect_is_supported():
-    print("PySpark connect is supported")
    # Import the connect module if the current version of PySpark supports it
    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession

    DataFrame = Union[sql.DataFrame, ConnectDataFrame]  # type: ignore
    SparkSession = Union[sql.SparkSession, ConnectSparkSession]  # type: ignore
else:
-    print("PySpark connect is not supported")
    # Otherwise, use the default PySpark classes
    from pyspark.sql.dataframe import DataFrame  # type: ignore
    from pyspark.sql.session import SparkSession  # type: ignore
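The hunk above removes two debug prints from the import-time guard that selects between Spark Connect and classic PySpark classes. The body of check_if_pyspark_connect_is_supported is not visible in this diff; the sketch below is a hypothetical reconstruction inferred from the tests added further down (the 3.5 version patch and the sys.modules entries), not the actual source:

# Hypothetical sketch; inferred from the tests in this commit, not copied from the repo.
from pyspark import __version__ as spark_version  # assumed source of the version string

SPARK_MINOR_VERSION: float = float(".".join(spark_version.split(".")[:2]))  # "3.5.1" -> 3.5

def check_if_pyspark_connect_is_supported() -> bool:
    """Return True when the installed PySpark ships a usable Spark Connect module."""
    if SPARK_MINOR_VERSION < 3.5:  # assumed cutoff, matching the 3.5 patched in the tests
        return False
    try:
        # Builds without Spark Connect raise ImportError here.
        from pyspark.sql.connect.column import Column  # noqa: F401
        return True
    except ImportError:
        return False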
35 changes: 33 additions & 2 deletions tests/core/test_expectations.py
@@ -1,7 +1,7 @@
# pylint: disable=too-many-lines
import os
import datetime
-from unittest.mock import Mock, PropertyMock
+from unittest.mock import MagicMock, Mock, PropertyMock
from unittest.mock import patch
import pytest
from pyspark.sql import DataFrame, SparkSession
@@ -19,6 +19,8 @@
from spark_expectations.core.expectations import (
    SparkExpectations,
    WrappedDataFrameWriter,
+    check_if_pyspark_connect_is_supported,
+    get_spark_minor_version,
)
from spark_expectations.config.user_config import Constants as user_config
from spark_expectations.core import get_spark_session
@@ -37,7 +39,7 @@ def fixture_setup_local_kafka_topic():
        os.getenv("UNIT_TESTING_ENV")
        != "spark_expectations_unit_testing_on_github_actions"
    ):
-        # remove if docker conatiner is running
+        # remove if docker container is running
        os.system(
            f"sh {current_dir}/../../spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh"
        )
@@ -3404,3 +3406,32 @@ def test_delta_bucketby_exception():
        match=r"Bucketing is not supported for delta tables yet",
    ):
        writer.build()


+class TestCheckIfPysparkConnectIsSupported:
+    def test_if_pyspark_connect_is_not_supported(self):
+        """Test that check_if_pyspark_connect_is_supported returns False when
+        pyspark connect is not supported."""
+        with patch.dict("sys.modules", {"pyspark.sql.connect": None}):
+            assert check_if_pyspark_connect_is_supported() is False
+
+    def test_check_if_pyspark_connect_is_supported(self):
+        """Test that check_if_pyspark_connect_is_supported returns True when
+        pyspark connect is supported."""
+        with (
+            patch("spark_expectations.core.expectations.SPARK_MINOR_VERSION", 3.5),
+            patch.dict(
+                "sys.modules",
+                {
+                    "pyspark.sql.connect.column": MagicMock(Column=MagicMock()),
+                    "pyspark.sql.connect": MagicMock(),
+                },
+            ),
+        ):
+            assert check_if_pyspark_connect_is_supported() is True
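Both tests above hinge on how Python treats sys.modules: an entry set to None blocks the module, so any import of it raises ImportError, while a MagicMock entry makes the same import succeed and hand back the mock. The constant patch pins SPARK_MINOR_VERSION to 3.5 so the version check passes regardless of the locally installed PySpark. A minimal standalone illustration of the blocking side, independent of spark_expectations:

# Setting a sys.modules entry to None forces the next import of it to fail.
from unittest.mock import patch

with patch.dict("sys.modules", {"json": None}):
    try:
        import json  # noqa: F401  (finds None in sys.modules and raises)
        blocked = False
    except ImportError:
        blocked = True

assert blocked  # the import was refused while the patch was active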


+def test_get_spark_minor_version():
+    """Test that get_spark_minor_version returns the correctly formatted version."""
+    with patch("spark_expectations.core.expectations.spark_version", "9.9.42"):
+        assert get_spark_minor_version() == 9.9

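The expected value pins down the implementation: keep the first two dot-separated components of the version string and parse them as a float. A plausible sketch of the function under test (assumed, not copied from the repository):

# Hypothetical sketch consistent with the test above ("9.9.42" -> 9.9).
from pyspark import __version__ as spark_version  # assumed; the test patches a module-level spark_version

def get_spark_minor_version() -> float:
    """Return the installed PySpark's major.minor version as a float."""
    return float(".".join(spark_version.split(".")[:2]))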