Updating Sparksession #31

Merged: 2 commits, Sep 14, 2023

13 changes: 5 additions & 8 deletions spark_expectations/core/__init__.py
@@ -1,19 +1,18 @@
 import os
 from pyspark.sql.session import SparkSession
-from delta import configure_spark_with_delta_pip

 current_dir = os.path.dirname(os.path.abspath(__file__))


 def get_spark_session() -> SparkSession:
-    builder = SparkSession.builder
-
     if (
         os.environ.get("UNIT_TESTING_ENV")
         == "spark_expectations_unit_testing_on_github_actions"
     ) or (os.environ.get("SPARKEXPECTATIONS_ENV") == "local"):
+        from delta import configure_spark_with_delta_pip
+
         builder = (
-            builder.config(
+            SparkSession.builder.config(
                 "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"
             )
             .config(
@@ -30,9 +29,7 @@ def get_spark_session() -> SparkSession:
                 f"{current_dir}/../../jars/commons-pool2-2.8.0.jar,"
                 f"{current_dir}/../../jars/spark-token-provider-kafka-0-10_2.12-3.0.0.jar",
             )
-            # .config("spark.databricks.delta.checkLatestSchemaOnRead", "false")
         )
+        return configure_spark_with_delta_pip(builder).getOrCreate()

-    spark = configure_spark_with_delta_pip(builder).getOrCreate()
-
-    return spark
+    return SparkSession.builder.getOrCreate()
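
With this change, get_spark_session() wires up the Delta extension and the bundled Kafka jars only when one of the testing flags is set (UNIT_TESTING_ENV on GitHub Actions, or SPARKEXPECTATIONS_ENV=local); in any other environment it falls through to a plain SparkSession.builder.getOrCreate(), which reuses an already-active session. A minimal sketch of exercising the local path, assuming the package and its jars are available locally (illustrative only, not part of this PR):

import os

# Opt in to the Delta-configured builder before the first call; any other
# value (or no value at all) takes the plain getOrCreate() branch instead.
os.environ["SPARKEXPECTATIONS_ENV"] = "local"

from spark_expectations.core import get_spark_session

spark = get_spark_session()
# The Delta extension should now be registered on the session.
print(spark.conf.get("spark.sql.extensions"))
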
14 changes: 13 additions & 1 deletion tests/core/test__init__.py
@@ -1,8 +1,11 @@
+import os
+from unittest import mock
 from unittest.mock import patch
+from pyspark.sql.session import SparkSession
 from spark_expectations.core import get_spark_session
 from spark_expectations.core.__init__ import current_dir


 @patch('spark_expectations.core.__init__.current_dir', autospec=True, spec_set=True)
 def test_get_spark_session(_mock_os):
     spark = get_spark_session()
@@ -11,7 +14,7 @@ def test_get_spark_session(_mock_os):
     # Add additional assertions as needed to test the SparkSession configuration
     assert "io.delta.sql.DeltaSparkSessionExtension" in spark.sparkContext.getConf().get("spark.sql.extensions")
     assert "org.apache.spark.sql.delta.catalog.DeltaCatalog" in spark.sparkContext.getConf().get(
-        "spark.sql.catalog.spark_catalog")
+        "spark.sql.catalog.spark_catalog")

     # Test that the warehouse and derby directories are properly configured
     assert "/tmp/hive/warehouse" in spark.sparkContext.getConf().get("spark.sql.warehouse.dir")
@@ -23,3 +26,12 @@ def test_get_spark_session(_mock_os):
         f"{current_dir}/../../jars/spark-token-provider-kafka-0-10_2.12-3.0.0.jar"

     # Add more assertions to test any other desired SparkSession configuration options
+
+
+@mock.patch.dict(os.environ, {"UNIT_TESTING_ENV": "disable", "SPARKEXPECTATIONS_ENV": "disable"})
+def test_get_spark_active_session():
+    spark = SparkSession.builder.getOrCreate()
+
+    # Now try to get the active session as we disabled unittest flags for this test
+    active = get_spark_session()
+    assert active == spark
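
The new test_get_spark_active_session case leans on mock.patch.dict to set both flags to a value that matches neither branch condition for the duration of the test, so get_spark_session() returns the session the test already created via the plain getOrCreate() path. A small standalone sketch of that standard-library mechanism, independent of this repository:

import os
from unittest import mock

with mock.patch.dict(os.environ, {"UNIT_TESTING_ENV": "disable", "SPARKEXPECTATIONS_ENV": "disable"}):
    # Inside the block both variables carry the overridden values ...
    assert os.environ["UNIT_TESTING_ENV"] == "disable"
# ... and the previous environment (including keys that were absent) is restored on exit.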