
avoid accessing global spark session in feb sink
jingz-db committed Jan 24, 2025
1 parent 78d1c7d commit 983b476
Showing 3 changed files with 30 additions and 17 deletions.
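
Here "feb sink" is the foreachBatch sink: the check functions passed to foreachBatch in the tests stop reading through the driver-global session (self.spark) and instead use the session attached to the incoming micro-batch DataFrame (batch_df.sparkSession), presumably so the function does not depend on a global SparkSession being defined where it runs (the commit also touches the Spark Connect code path). A minimal sketch of the pattern, with an illustrative checkpoint path rather than one from this commit:

from pyspark.sql import DataFrame

def check_results(batch_df: DataFrame, batch_id: int) -> None:
    # Use the session bound to the incoming micro-batch rather than a
    # module- or test-level global session.
    spark = batch_df.sparkSession
    # Illustrative read; the real tests query the state-metadata and
    # statestore sources against their own checkpoint path.
    metadata_df = spark.read.format("state-metadata").load("/tmp/checkpoint")
    assert metadata_df.count() > 0

# Attached to a streaming query in the usual way:
# query = df.writeStream.foreachBatch(check_results).start()
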
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/group.py
@@ -499,7 +499,7 @@ def transformWithStateUDF(
evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)

# TODO figure out if we need to handle for string type
# TODO add a string struct type test
output_schema: str = (
outputStructType.json()
if isinstance(outputStructType, StructType)
2 changes: 1 addition & 1 deletion python/pyspark/sql/pandas/group_ops.py
@@ -519,7 +519,7 @@ def handle_pre_init(
# won't be used again on JVM.
statefulProcessor.close()

# return a dummy results, no return value is needed for pre init
# return a dummy result, no return value is needed for pre init
return iter([])

def handle_data_rows(
43 changes: 28 additions & 15 deletions python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -105,8 +105,11 @@ def _prepare_input_data_with_3_cols(self, input_path, col1, col2, col3):
for e1, e2, e3 in zip(col1, col2, col3):
fw.write(f"{e1},{e2},{e3}\n")

def end_query_from_feb_sink(self):
raise Exception(f"Ending the query by throw an exception for ProcessingTime mode")
# TODO SPARK-50180 This is a hack to exit the query when all assertions are passed
# for processing time mode
def _check_query_end_exception(self, error):
error_msg = str(error)
return "Checks passed, ending the query for processing time mode" in error_msg

def build_test_df_with_3_cols(self, input_path):
df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 1).load(input_path)
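
The _check_query_end_exception helper added above supports the SPARK-50180 workaround used later in this diff for processing-time queries: once all assertions in the foreachBatch check function have passed, the function raises an exception carrying a sentinel message to end the query, and the test catches it and verifies the message. A condensed sketch of that flow as a test method on this class; the test name and placeholder comments are illustrative, while the helper calls and the sentinel message come from the diff:

def test_processing_time_checks_example(self):
    def check_results(batch_df, batch_id):
        if batch_id == 0:
            # assertions for the first micro-batch go here
            return
        # assertions for later batches (and state data source checks) go here
        # TODO SPARK-50180 This is a hack to exit the query when all
        # assertions are passed for processing time mode
        raise Exception("Checks passed, ending the query for processing time mode")

    try:
        self._test_transform_with_state_in_pandas_basic(
            ListStateLargeTTLProcessor(), check_results, True, "processingTime"
        )
    except Exception as e:
        self.assertTrue(self._check_query_end_exception(e))
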
@@ -285,7 +288,8 @@ def check_results(batch_df, batch_id):
self._test_transform_with_state_in_pandas_basic(
ListStateLargeTTLProcessor(), check_results, True, "processingTime"
)
"""
def test_transform_with_state_in_pandas_map_state(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
@@ -294,6 +298,7 @@ def check_results(batch_df, _):
}

self._test_transform_with_state_in_pandas_basic(MapStateProcessor(), check_results, True)
"""
# test map state with ttl has the same behavior as map state when state doesn't expire.
def test_transform_with_state_in_pandas_map_state_large_ttl(self):
@@ -1067,6 +1072,7 @@ def check_results(batch_df, batch_id):
checkpoint_path=checkpoint_path,
initial_state=initial_state,
)
"""

# This test covers multiple list state variables and flatten option
def test_transform_with_list_state_metadata(self):
@@ -1080,7 +1086,7 @@ def check_results(batch_df, batch_id):
}
else:
# check for state metadata source
metadata_df = self.spark.read.format("state-metadata").load(checkpoint_path)
metadata_df = batch_df.sparkSession.read.format("state-metadata").load(checkpoint_path)
operator_properties_json_obj = json.loads(
metadata_df.select("operatorProperties").collect()[0][0]
)
@@ -1095,7 +1101,7 @@ def check_results(batch_df, batch_id):

# check for state data source and flatten option
list_state_1_df = (
self.spark.read.format("statestore")
batch_df.sparkSession.read.format("statestore")
.option("path", checkpoint_path)
.option("stateVarName", "listState1")
.option("flattenCollectionTypes", True)
@@ -1118,7 +1124,7 @@ def check_results(batch_df, batch_id):
]

list_state_2_df = (
self.spark.read.format("statestore")
batch_df.sparkSession.read.format("statestore")
.option("path", checkpoint_path)
.option("stateVarName", "listState2")
.option("flattenCollectionTypes", False)
@@ -1135,16 +1141,23 @@ def check_results(batch_df, batch_id):
Row(groupingKey="1", valueSortedList=[20, 20, 120, 120, 222]),
]

self.end_query_from_feb_sink()
# TODO SPARK-50180 This is a hack to exit the query when all assertions are passed
# for processing time mode
raise Exception("Checks passed, ending the query for processing time mode")

self._test_transform_with_state_in_pandas_basic(
ListStateProcessor(),
check_results,
True,
"processingTime",
checkpoint_path=checkpoint_path,
initial_state=None,
)
try:
self._test_transform_with_state_in_pandas_basic(
ListStateProcessor(),
check_results,
True,
"processingTime",
checkpoint_path=checkpoint_path,
initial_state=None,
)
except Exception as e:
self.assertTrue(self._check_query_end_exception(e))

"""
# This test covers value state variable and read change feed,
# snapshotStartBatchId related options
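
For reference, the same state readers exercised inside check_results above can also be queried from an ordinary session once the streaming query has stopped; a minimal sketch assuming an active SparkSession named spark and the checkpoint_path and state variable names used in the test above:

# Operator metadata recorded in the checkpoint.
metadata_df = spark.read.format("state-metadata").load(checkpoint_path)
metadata_df.select("operatorProperties").show(truncate=False)

# One of the list state variables, flattened to one row per list element.
list_state_1_df = (
    spark.read.format("statestore")
    .option("path", checkpoint_path)
    .option("stateVarName", "listState1")
    .option("flattenCollectionTypes", True)
    .load()
)
list_state_1_df.show()
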
