From 2e7e0da64117a2159db96258f32d3070e77c56e8 Mon Sep 17 00:00:00 2001 From: bogao007 Date: Thu, 23 Jan 2025 18:34:12 -0800 Subject: [PATCH 1/2] Update API naming --- python/pyspark/sql/pandas/group_ops.py | 4 +- .../sql/streaming/stateful_processor.py | 182 +++++++++--------- .../test_pandas_transform_with_state.py | 92 ++++----- 3 files changed, 139 insertions(+), 139 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 343a68bf010bf..e0108da34f0c2 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -566,8 +566,8 @@ def handle_expired_timers( statefulProcessorApiClient.set_implicit_key(key_obj) for pd in statefulProcessor.handleExpiredTimer( key=key_obj, - timer_values=TimerValues(batch_timestamp, watermark_timestamp), - expired_timer_info=ExpiredTimerInfo(expiry_timestamp), + timerValues=TimerValues(batch_timestamp, watermark_timestamp), + expiredTimerInfo=ExpiredTimerInfo(expiry_timestamp), ): yield pd statefulProcessorApiClient.delete_timer(expiry_timestamp) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index b04bb955488ab..9885fe57b45d8 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -45,33 +45,33 @@ class ValueState: .. versionadded:: 4.0.0 """ - def __init__(self, value_state_client: ValueStateClient, state_name: str) -> None: - self._value_state_client = value_state_client - self._state_name = state_name + def __init__(self, valueStateClient: ValueStateClient, stateName: str) -> None: + self._valueStateClient = valueStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether state exists or not. """ - return self._value_state_client.exists(self._state_name) + return self._valueStateClient.exists(self._stateName) def get(self) -> Optional[Tuple]: """ Get the state value if it exists. 
Returns None if the state variable does not have a value. """ - return self._value_state_client.get(self._state_name) + return self._valueStateClient.get(self._stateName) - def update(self, new_value: Tuple) -> None: + def update(self, newValue: Tuple) -> None: """ Update the value of the state. """ - self._value_state_client.update(self._state_name, new_value) + self._valueStateClient.update(self._stateName, newValue) def clear(self) -> None: """ Remove this state. """ - self._value_state_client.clear(self._state_name) + self._valueStateClient.clear(self._stateName) class TimerValues: @@ -82,22 +82,22 @@ class TimerValues: """ def __init__( - self, current_processing_time_in_ms: int = -1, current_watermark_in_ms: int = -1 + self, currentProcessingTimeInMs: int = -1, currentWatermarkInMs: int = -1 ) -> None: - self._current_processing_time_in_ms = current_processing_time_in_ms - self._current_watermark_in_ms = current_watermark_in_ms + self._currentProcessingTimeInMs = currentProcessingTimeInMs + self._currentWatermarkInMs = currentWatermarkInMs - def get_current_processing_time_in_ms(self) -> int: + def getCurrentProcessingTimeInMs(self) -> int: """ Get processing time for current batch, return timestamp in millisecond. """ - return self._current_processing_time_in_ms + return self._currentProcessingTimeInMs - def get_current_watermark_in_ms(self) -> int: + def getCurrentWatermarkInMs(self) -> int: """ Get watermark for current batch, return timestamp in millisecond. """ - return self._current_watermark_in_ms + return self._currentWatermarkInMs class ExpiredTimerInfo: @@ -106,14 +106,14 @@ class ExpiredTimerInfo: .. 
versionadded:: 4.0.0 """ - def __init__(self, expiry_time_in_ms: int = -1) -> None: - self._expiry_time_in_ms = expiry_time_in_ms + def __init__(self, expiryTimeInMs: int = -1) -> None: + self._expiryTimeInMs = expiryTimeInMs - def get_expiry_time_in_ms(self) -> int: + def getExpiryTimeInMs(self) -> int: """ Get the timestamp for expired timer, return timestamp in millisecond. """ - return self._expiry_time_in_ms + return self._expiryTimeInMs class ListState: @@ -124,45 +124,45 @@ class ListState: .. versionadded:: 4.0.0 """ - def __init__(self, list_state_client: ListStateClient, state_name: str) -> None: - self._list_state_client = list_state_client - self._state_name = state_name + def __init__(self, listStateClient: ListStateClient, stateName: str) -> None: + self._listStateClient = listStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether list state exists or not. """ - return self._list_state_client.exists(self._state_name) + return self._listStateClient.exists(self._stateName) def get(self) -> Iterator[Tuple]: """ Get list state with an iterator. """ - return ListStateIterator(self._list_state_client, self._state_name) + return ListStateIterator(self._listStateClient, self._stateName) - def put(self, new_state: List[Tuple]) -> None: + def put(self, newState: List[Tuple]) -> None: """ Update the values of the list state. """ - self._list_state_client.put(self._state_name, new_state) + self._listStateClient.put(self._stateName, newState) - def append_value(self, new_state: Tuple) -> None: + def appendValue(self, newState: Tuple) -> None: """ Append a new value to the list state. """ - self._list_state_client.append_value(self._state_name, new_state) + self._listStateClient.append_value(self._stateName, newState) - def append_list(self, new_state: List[Tuple]) -> None: + def appendList(self, newState: List[Tuple]) -> None: """ Append a list of new values to the list state. 
""" - self._list_state_client.append_list(self._state_name, new_state) + self._listStateClient.append_list(self._stateName, newState) def clear(self) -> None: """ Remove this state. """ - self._list_state_client.clear(self._state_name) + self._listStateClient.clear(self._stateName) class MapState: @@ -175,65 +175,65 @@ class MapState: def __init__( self, - map_state_client: MapStateClient, - state_name: str, + mapStateClient: MapStateClient, + stateName: str, ) -> None: - self._map_state_client = map_state_client - self._state_name = state_name + self._mapStateClient = mapStateClient + self._stateName = stateName def exists(self) -> bool: """ Whether state exists or not. """ - return self._map_state_client.exists(self._state_name) + return self._mapStateClient.exists(self._stateName) - def get_value(self, key: Tuple) -> Optional[Tuple]: + def getValue(self, key: Tuple) -> Optional[Tuple]: """ Get the state value for given user key if it exists. """ - return self._map_state_client.get_value(self._state_name, key) + return self._mapStateClient.get_value(self._stateName, key) - def contains_key(self, key: Tuple) -> bool: + def containsKey(self, key: Tuple) -> bool: """ Check if the user key is contained in the map. """ - return self._map_state_client.contains_key(self._state_name, key) + return self._mapStateClient.contains_key(self._stateName, key) - def update_value(self, key: Tuple, value: Tuple) -> None: + def updateValue(self, key: Tuple, value: Tuple) -> None: """ Update value for given user key. """ - return self._map_state_client.update_value(self._state_name, key, value) + return self._mapStateClient.update_value(self._stateName, key, value) def iterator(self) -> Iterator[Tuple[Tuple, Tuple]]: """ Get the map associated with grouping key. 
""" - return MapStateKeyValuePairIterator(self._map_state_client, self._state_name) + return MapStateKeyValuePairIterator(self._mapStateClient, self._stateName) def keys(self) -> Iterator[Tuple]: """ Get the list of keys present in map associated with grouping key. """ - return MapStateIterator(self._map_state_client, self._state_name, True) + return MapStateIterator(self._mapStateClient, self._stateName, True) def values(self) -> Iterator[Tuple]: """ Get the list of values present in map associated with grouping key. """ - return MapStateIterator(self._map_state_client, self._state_name, False) + return MapStateIterator(self._mapStateClient, self._stateName, False) - def remove_key(self, key: Tuple) -> None: + def removeKey(self, key: Tuple) -> None: """ Remove user key from map state. """ - return self._map_state_client.remove_key(self._state_name, key) + return self._mapStateClient.remove_key(self._stateName, key) def clear(self) -> None: """ Remove this state. """ - self._map_state_client.clear(self._state_name) + self._mapStateClient.clear(self._stateName) class StatefulProcessorHandle: @@ -244,11 +244,11 @@ class StatefulProcessorHandle: .. versionadded:: 4.0.0 """ - def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: - self.stateful_processor_api_client = stateful_processor_api_client + def __init__(self, statefulProcessorApiClient: StatefulProcessorApiClient) -> None: + self._statefulProcessorApiClient = statefulProcessorApiClient def getValueState( - self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + self, stateName: str, schema: Union[StructType, str], ttlDurationMs: Optional[int] = None ) -> ValueState: """ Function to create new or return existing single value state variable of given type. 
@@ -257,7 +257,7 @@ def getValueState( Parameters ---------- - state_name : str + stateName : str name of the state variable schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a @@ -268,11 +268,11 @@ def getValueState( resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) - return ValueState(ValueStateClient(self.stateful_processor_api_client, schema), state_name) + self._statefulProcessorApiClient.get_value_state(stateName, schema, ttlDurationMs) + return ValueState(ValueStateClient(self._statefulProcessorApiClient, schema), stateName) def getListState( - self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + self, stateName: str, schema: Union[StructType, str], ttlDurationMs: Optional[int] = None ) -> ListState: """ Function to create new or return existing single value state variable of given type. @@ -281,7 +281,7 @@ def getListState( Parameters ---------- - state_name : str + stateName : str name of the state variable schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a @@ -292,15 +292,15 @@ def getListState( resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. 
""" - self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms) - return ListState(ListStateClient(self.stateful_processor_api_client, schema), state_name) + self._statefulProcessorApiClient.get_list_state(stateName, schema, ttlDurationMs) + return ListState(ListStateClient(self._statefulProcessorApiClient, schema), stateName) def getMapState( self, - state_name: str, - user_key_schema: Union[StructType, str], - value_schema: Union[StructType, str], - ttl_duration_ms: Optional[int] = None, + stateName: str, + userKeySchema: Union[StructType, str], + valueSchema: Union[StructType, str], + ttlDurationMs: Optional[int] = None, ) -> MapState: """ Function to create new or return existing single map state variable of given type. @@ -309,51 +309,51 @@ def getMapState( Parameters ---------- - state_name : str + stateName : str name of the state variable - user_key_schema : :class:`pyspark.sql.types.DataType` or str + userKeySchema : :class:`pyspark.sql.types.DataType` or str The schema of the key of map state. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - value_schema : :class:`pyspark.sql.types.DataType` or str + valueSchema : :class:`pyspark.sql.types.DataType` or str The schema of the value of map state The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - ttl_duration_ms: int + ttlDurationMs: int Time to live duration of the state in milliseconds. State values will not be returned past ttlDuration and will be eventually removed from the state store. Any state update resets the expiration time to current processing time plus ttlDuration. If ttl is not specified the state will never expire. 
""" - self.stateful_processor_api_client.get_map_state( - state_name, user_key_schema, value_schema, ttl_duration_ms + self._statefulProcessorApiClient.get_map_state( + stateName, userKeySchema, valueSchema, ttlDurationMs ) return MapState( - MapStateClient(self.stateful_processor_api_client, user_key_schema, value_schema), - state_name, + MapStateClient(self._statefulProcessorApiClient, userKeySchema, valueSchema), + stateName, ) - def registerTimer(self, expiry_time_stamp_ms: int) -> None: + def registerTimer(self, expiryTimestampMs: int) -> None: """ Register a timer for a given expiry timestamp in milliseconds for the grouping key. """ - self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms) + self._statefulProcessorApiClient.register_timer(expiryTimestampMs) - def deleteTimer(self, expiry_time_stamp_ms: int) -> None: + def deleteTimer(self, expiryTimestampMs: int) -> None: """ Delete a timer for a given expiry timestamp in milliseconds for the grouping key. """ - self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms) + self._statefulProcessorApiClient.delete_timer(expiryTimestampMs) def listTimers(self) -> Iterator[int]: """ List all timers of their expiry timestamps in milliseconds for the grouping key. 
""" - return ListTimerIterator(self.stateful_processor_api_client) + return ListTimerIterator(self._statefulProcessorApiClient) - def deleteIfExists(self, state_name: str) -> None: + def deleteIfExists(self, stateName: str) -> None: """ Function to delete and purge state variable if defined previously """ - self.stateful_processor_api_client.delete_if_exists(state_name) + self._statefulProcessorApiClient.delete_if_exists(stateName) class StatefulProcessor(ABC): @@ -383,7 +383,7 @@ def handleInputRows( self, key: Any, rows: Iterator["PandasDataFrameLike"], - timer_values: TimerValues, + timerValues: TimerValues, ) -> Iterator["PandasDataFrameLike"]: """ Function that will allow users to interact with input data rows along with the grouping key. @@ -402,14 +402,14 @@ def handleInputRows( grouping key. rows : iterable of :class:`pandas.DataFrame` iterator of input rows associated with grouping key - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. + timerValues: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. """ ... def handleExpiredTimer( - self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo + self, key: Any, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo ) -> Iterator["PandasDataFrameLike"]: """ Optional to implement. Will act return an empty iterator if not defined. @@ -420,11 +420,11 @@ def handleExpiredTimer( ---------- key : Any grouping key. - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. - expired_timer_info: ExpiredTimerInfo - Instance of ExpiredTimerInfo that provides access to expired timer. + timerValues: TimerValues + Timer value for the current batch that process the input rows. 
+ Users can get the processing or event time timestamp from TimerValues. + expiredTimerInfo: ExpiredTimerInfo + Instance of ExpiredTimerInfo that provides access to expired timer. """ return iter([]) @@ -437,7 +437,7 @@ def close(self) -> None: ... def handleInitialState( - self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues + self, key: Any, initialState: "PandasDataFrameLike", timerValues: TimerValues ) -> None: """ Optional to implement. Will act as no-op if not defined or no initial state input. @@ -449,8 +449,8 @@ def handleInitialState( grouping key. initialState: :class:`pandas.DataFrame` One dataframe in the initial state associated with the key. - timer_values: TimerValues - Timer value for the current batch that process the input rows. - Users can get the processing or event time timestamp from TimerValues. + timerValues: TimerValues + Timer value for the current batch that process the input rows. + Users can get the processing or event time timestamp from TimerValues. 
""" pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index fec2e5d0caa2e..202d4b5799c4d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1481,7 +1481,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.value_state = handle.getValueState("value_state", state_schema) self.handle = handle - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: exists = self.value_state.exists() if exists: value_row = self.value_state.get() @@ -1504,7 +1504,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: else: yield pd.DataFrame({"id": key, "value": str(accumulated_value)}) - def handleInitialState(self, key, initialState, timer_values) -> None: + def handleInitialState(self, key, initialState, timerValues) -> None: init_val = initialState.at[0, "initVal"] self.value_state.update((init_val,)) if len(key) == 1: @@ -1515,16 +1515,16 @@ def close(self) -> None: class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState): - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) str_key = f"{str(key[0])}-expired" yield pd.DataFrame( - {"id": (str_key,), "value": str(expired_timer_info.get_expiry_time_in_ms())} + {"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())} ) - def handleInitialState(self, key, initialState, timer_values) -> None: - super().handleInitialState(key, initialState, timer_values) - 
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1) + def handleInitialState(self, key, initialState, timerValues) -> None: + super().handleInitialState(key, initialState, timerValues) + self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - 1) class StatefulProcessorWithListStateInitialState(SimpleStatefulProcessorWithInitialState): @@ -1533,9 +1533,9 @@ def init(self, handle: StatefulProcessorHandle) -> None: list_ele_schema = StructType([StructField("value", IntegerType(), True)]) self.list_state = handle.getListState("list_state", list_ele_schema) - def handleInitialState(self, key, initialState, timer_values) -> None: + def handleInitialState(self, key, initialState, timerValues) -> None: for val in initialState["initVal"].tolist(): - self.list_state.append_value((val,)) + self.list_state.appendValue((val,)) # A stateful processor that output the max event time it has seen. Register timer for @@ -1546,15 +1546,15 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.max_state = handle.getValueState("max_state", state_schema) - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: self.max_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) str_key = f"{str(key[0])}-expired" yield pd.DataFrame( - {"id": (str_key,), "timestamp": str(expired_timer_info.get_expiry_time_in_ms())} + {"id": (str_key,), "timestamp": str(expiredTimerInfo.getExpiryTimeInMs())} ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: timestamp_list = [] for pdf in rows: # int64 will represent timestamp in nanosecond, restore to second @@ -1567,7 +1567,7 @@ def handleInputRows(self, 
key, rows, timer_values) -> Iterator[pd.DataFrame]: max_event_time = str(max(cur_max, max(timestamp_list))) self.max_state.update((max_event_time,)) - self.handle.registerTimer(timer_values.get_current_watermark_in_ms()) + self.handle.registerTimer(timerValues.getCurrentWatermarkInMs()) yield pd.DataFrame({"id": key, "timestamp": max_event_time}) @@ -1583,7 +1583,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.handle = handle self.count_state = handle.getValueState("count_state", state_schema) - def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]: + def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: # reset count state each time the timer is expired timer_list_1 = [e for e in self.handle.listTimers()] timer_list_2 = [] @@ -1597,23 +1597,23 @@ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[ if len(timer_list_1) > 0: assert len(timer_list_1) == 2 self.count_state.clear() - self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms()) + self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) yield pd.DataFrame( { "id": key, "countAsString": str("-1"), - "timeValues": str(expired_timer_info.get_expiry_time_in_ms()), + "timeValues": str(expiredTimerInfo.getExpiryTimeInMs()), } ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: if not self.count_state.exists(): count = 0 else: count = int(self.count_state.get()[0]) if key == ("0",): - self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1) + self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 1) rows_count = 0 for pdf in rows: @@ -1623,7 +1623,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: count = count + rows_count self.count_state.update((str(count),)) - timestamp = 
str(timer_values.get_current_processing_time_in_ms()) + timestamp = str(timerValues.getCurrentProcessingTimeInMs()) yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp}) @@ -1642,7 +1642,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.temp_state = handle.getValueState("tempState", state_schema) handle.deleteIfExists("tempState") - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"): self.temp_state.exists() new_violations = 0 @@ -1674,7 +1674,7 @@ class StatefulProcessorChainingOps(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: pass - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: timestamp_list = pdf["eventTime"].tolist() yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]}) @@ -1704,7 +1704,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: "ttl-map-state", user_key_schema, state_schema, 10000 ) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 ttl_list_state_count = 0 @@ -1719,7 +1719,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: for s in iter: ttl_list_state_count += s[0] if self.ttl_map_state.exists(): - ttl_map_state_count = self.ttl_map_state.get_value(key)[0] + ttl_map_state_count = self.ttl_map_state.getValue(key)[0] for pdf in rows: pdf_count = pdf.count().get("temperature") count += pdf_count @@ -1732,7 +1732,7 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: if not (ttl_count == 2 and id == "0"): self.ttl_count_state.update((ttl_count,)) 
self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)]) - self.ttl_map_state.update_value(key, (ttl_map_state_count,)) + self.ttl_map_state.updateValue(key, (ttl_map_state_count,)) yield pd.DataFrame( { "id": [ @@ -1754,7 +1754,7 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = handle.getValueState("numViolations", state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 exists = self.num_violations_state.exists() assert not exists @@ -1778,16 +1778,16 @@ def init(self, handle: StatefulProcessorHandle) -> None: self.list_state1 = handle.getListState("listState1", state_schema) self.list_state2 = handle.getListState("listState2", state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 for pdf in rows: list_state_rows = [(120,), (20,)] self.list_state1.put(list_state_rows) self.list_state2.put(list_state_rows) - self.list_state1.append_value((111,)) - self.list_state2.append_value((222,)) - self.list_state1.append_list(list_state_rows) - self.list_state2.append_list(list_state_rows) + self.list_state1.appendValue((111,)) + self.list_state2.appendValue((222,)) + self.list_state1.appendList(list_state_rows) + self.list_state2.appendList(list_state_rows) pdf_count = pdf.count() count += pdf_count.get("temperature") iter1 = self.list_state1.get() @@ -1832,7 +1832,7 @@ def init(self, handle: StatefulProcessorHandle): # Test string type schemas self.map_state = handle.getMapState("mapState", "name string", "count int") - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: count = 0 key1 = 
("key1",) key2 = ("key2",) @@ -1842,12 +1842,12 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: value1 = count value2 = count if self.map_state.exists(): - if self.map_state.contains_key(key1): - value1 += self.map_state.get_value(key1)[0] - if self.map_state.contains_key(key2): - value2 += self.map_state.get_value(key2)[0] - self.map_state.update_value(key1, (value1,)) - self.map_state.update_value(key2, (value2,)) + if self.map_state.containsKey(key1): + value1 += self.map_state.getValue(key1)[0] + if self.map_state.containsKey(key2): + value2 += self.map_state.getValue(key2)[0] + self.map_state.updateValue(key1, (value1,)) + self.map_state.updateValue(key2, (value2,)) key_iter = self.map_state.keys() assert next(key_iter)[0] == "key1" assert next(key_iter)[0] == "key2" @@ -1857,8 +1857,8 @@ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: map_iter = self.map_state.iterator() assert next(map_iter)[0] == key1 assert next(map_iter)[1] == (value2,) - self.map_state.remove_key(key1) - assert not self.map_state.contains_key(key1) + self.map_state.removeKey(key1) + assert not self.map_state.containsKey(key1) yield pd.DataFrame({"id": key, "countAsString": str(count)}) def close(self) -> None: @@ -1884,7 +1884,7 @@ class BasicProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1910,7 +1910,7 @@ class AddFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1958,7 +1958,7 @@ 
class RemoveFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -1986,7 +1986,7 @@ class ReorderedFieldsProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) @@ -2035,7 +2035,7 @@ class UpcastProcessor(StatefulProcessor): def init(self, handle): self.state = handle.getValueState("state", self.state_schema) - def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]: + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: for pdf in rows: pass id_val = int(key[0]) From f6332be3ec621088c4ef9a3d27514d8f9f02ca93 Mon Sep 17 00:00:00 2001 From: bogao007 Date: Thu, 23 Jan 2025 21:40:22 -0800 Subject: [PATCH 2/2] lint --- python/pyspark/sql/streaming/stateful_processor.py | 4 +--- .../sql/tests/pandas/test_pandas_transform_with_state.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 9885fe57b45d8..ba2707ccfb892 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -81,9 +81,7 @@ class TimerValues: .. 
versionadded:: 4.0.0 """ - def __init__( - self, currentProcessingTimeInMs: int = -1, currentWatermarkInMs: int = -1 - ) -> None: + def __init__(self, currentProcessingTimeInMs: int = -1, currentWatermarkInMs: int = -1) -> None: self._currentProcessingTimeInMs = currentProcessingTimeInMs self._currentWatermarkInMs = currentWatermarkInMs diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 202d4b5799c4d..64924ded824b9 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1518,9 +1518,7 @@ class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitial def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]: self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs()) str_key = f"{str(key[0])}-expired" - yield pd.DataFrame( - {"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())} - ) + yield pd.DataFrame({"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())}) def handleInitialState(self, key, initialState, timerValues) -> None: super().handleInitialState(key, initialState, timerValues)