Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50967][SS] Add option to skip emitting initial state keys within the FMGWS operator #49632

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,15 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

// Internal switch for flatMapGroupsWithState when a user-supplied initial state is present.
// When enabled, a key that appears only in the initial state (no input rows for that key) does
// not trigger a call to the user function, so nothing is emitted for it. Keys that do have
// input rows are processed as usual. The flag is read by both the streaming and the batch
// planning paths.
val FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS =
  buildConf("spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys")
    .internal()
    .doc("When true, the flatMapGroupsWithState operation will not invoke the user function, " +
      "and therefore will not emit results, for keys that are present only in the initial " +
      "state and have no input data. Keys that have input data are not affected.")
    .version("4.0.0")
    .booleanConf
    .createWithDefault(false)

val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
.doc("The default location for storing checkpoint data for streaming queries.")
.version("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,11 +736,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val skipEmittingInitialStateKeys =
conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
planLater(initialState), hasInitialState, planLater(child)
planLater(initialState), hasInitialState, skipEmittingInitialStateKeys, planLater(child)
)
execPlan :: Nil
case _ =>
Expand Down Expand Up @@ -828,7 +830,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val execPlan = python.FlatMapGroupsInPandasWithStateExec(
func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None, planLater(child)
eventTimeWatermarkForEviction = None,
skipEmittingInitialStateKeys = false,
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
planLater(child)
)
execPlan :: Nil
case _ =>
Expand Down Expand Up @@ -953,10 +957,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
val skipEmittingInitialStateKeys =
conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
hasInitialState, skipEmittingInitialStateKeys, planLater(initialState), planLater(child)
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeMode, outputMode, keyEncoder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import org.apache.spark.util.CompletionIterator
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param skipEmittingInitialStateKeys whether to skip emitting output for keys that are present
*                                     only in the initial state and have no input data
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsInPandasWithStateExec(
Expand All @@ -64,6 +65,7 @@ case class FlatMapGroupsInPandasWithStateExec(
batchTimestampMs: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
skipEmittingInitialStateKeys: Boolean,
child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {

// TODO(SPARK-40444): Add the support of initial state.
Expand Down Expand Up @@ -137,7 +139,8 @@ case class FlatMapGroupsInPandasWithStateExec(

override def processNewDataWithInitialState(
childDataIter: Iterator[InternalRow],
initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
initStateIter: Iterator[InternalRow],
skipEmittingInitialStateKeys: Boolean): Iterator[InternalRow] = {
throw SparkUnsupportedOperationException()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ trait FlatMapGroupsWithStateExecBase
protected val initialStateDataAttrs: Seq[Attribute]
protected val initialState: SparkPlan
protected val hasInitialState: Boolean
protected val skipEmittingInitialStateKeys: Boolean

val stateInfo: Option[StatefulOperatorStateInfo]
protected val stateEncoder: ExpressionEncoder[Any]
Expand Down Expand Up @@ -145,7 +146,8 @@ trait FlatMapGroupsWithStateExecBase

val processedOutputIterator = initialStateIterOption match {
case Some(initStateIter) if initStateIter.hasNext =>
processor.processNewDataWithInitialState(filteredIter, initStateIter)
processor.processNewDataWithInitialState(filteredIter, initStateIter,
skipEmittingInitialStateKeys)
case _ => processor.processNewData(filteredIter)
}

Expand Down Expand Up @@ -301,7 +303,8 @@ trait FlatMapGroupsWithStateExecBase
*/
def processNewDataWithInitialState(
childDataIter: Iterator[InternalRow],
initStateIter: Iterator[InternalRow]
initStateIter: Iterator[InternalRow],
skipEmittingInitialStateKeys: Boolean
): Iterator[InternalRow] = {

if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty
Expand All @@ -312,7 +315,8 @@ trait FlatMapGroupsWithStateExecBase
val groupedInitialStateIter =
GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)

// Create a CoGroupedIterator that will group the two iterators together for every key group.
// Create a CoGroupedIterator that will group the two iterators together for every
// key group.
new CoGroupedIterator(
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
case (keyRow, valueRowIter, initialStateRowIter) =>
Expand All @@ -326,12 +330,17 @@ trait FlatMapGroupsWithStateExecBase
val initStateObj = getStateObj.get(initialStateRow)
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
}
// We apply the values for the key after applying the initial state.
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),

if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) {
// If the user has specified to skip emitting the keys that only have initial state
// and no data, then we should not call the function for such keys.
Iterator.empty
} else {
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false
)
hasTimedOut = false)
}
}
}

Expand Down Expand Up @@ -388,6 +397,7 @@ trait FlatMapGroupsWithStateExecBase
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
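* @param skipEmittingInitialStateKeys whether to skip calling the user function (and hence
*                                     emitting output) for keys that are present only in the
*                                     initial state and have no input data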
* @param child the physical plan for the underlying data
*/
case class FlatMapGroupsWithStateExec(
Expand All @@ -410,6 +420,7 @@ case class FlatMapGroupsWithStateExec(
eventTimeWatermarkForEviction: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
skipEmittingInitialStateKeys: Boolean,
child: SparkPlan)
extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
import GroupStateImpl._
Expand Down Expand Up @@ -533,6 +544,7 @@ object FlatMapGroupsWithStateExec {
outputObjAttr: Attribute,
timeoutConf: GroupStateTimeout,
hasInitialState: Boolean,
skipEmittingInitialStateKeys: Boolean,
initialState: SparkPlan,
child: SparkPlan): SparkPlan = {
if (hasInitialState) {
Expand All @@ -541,27 +553,31 @@ object FlatMapGroupsWithStateExec {
case _ => false
}
val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
if (skipEmittingInitialStateKeys && values.isEmpty) {
Iterator.empty
} else {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
}
}
CoGroupExec(
func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark),
RDDScanExec(g, emptyRdd, "rdd"),
hasInitialState,
false,
RDDScanExec(g, emptyRdd, "rdd"))
}.get
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,135 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest
)
}

// With skipEmittingInitialStateKeys=true, keys that exist only in the initial state and have
// no input data in a batch ("orange" and "mango" for the first batch) must not be emitted.
// Their initial state is still applied: once "orange" receives data, its count continues from
// the initial value of 2.
testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
  "skipEmittingInitialStateKeys=true") {
  withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") {
    val initialState = Seq(
      ("apple", 1L),
      ("orange", 2L),
      ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)

    // Running count per key: previous state (or 0 if absent) plus the number of new values.
    val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      Iterator.single((key, count))
    }

    val inputData = MemoryStream[String]
    val result =
      inputData.toDS()
        .groupByKey(x => x)
        .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
    testStream(result, Update)(
      AddData(inputData, "apple"),
      AddData(inputData, "banana"),
      // Only keys with input data are emitted; "orange" and "mango" are skipped.
      CheckNewAnswer(("apple", 2), ("banana", 1)),
      AddData(inputData, "orange"),
      // The initial state for "orange" (2) was retained even though it was not emitted.
      CheckNewAnswer(("orange", 3)),
      StopStream
    )
  }
}

// With skipEmittingInitialStateKeys=false, keys that exist only in the initial state and have
// no input data ("orange" and "mango" for the first batch) are still passed to the user
// function and therefore show up in the first batch's output.
testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
  "skipEmittingInitialStateKeys=false") {
  withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") {
    val initialState = Seq(
      ("apple", 1L),
      ("orange", 2L),
      ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)

    // Running count per key: previous state (or 0 if absent) plus the number of new values.
    val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      Iterator.single((key, count))
    }

    val inputData = MemoryStream[String]
    val result =
      inputData.toDS()
        .groupByKey(x => x)
        .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
    testStream(result, Update)(
      AddData(inputData, "apple"),
      AddData(inputData, "banana"),
      // Initial-state-only keys are emitted with their initial counts.
      CheckNewAnswer(("apple", 2), ("banana", 1), ("orange", 2), ("mango", 5)),
      AddData(inputData, "orange"),
      CheckNewAnswer(("orange", 3)),
      StopStream
    )
  }
}

// When every initial-state key also has input data in the first batch, no key is
// "initial state only", so the output must be identical for both values of
// skipEmittingInitialStateKeys.
Seq(true, false).foreach { skipEmittingInitialStateKeys =>
  testWithAllStateVersions("flatMapGroupsWithState - initial state and initial batch " +
    s"have same keys and skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") {
    withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key ->
      skipEmittingInitialStateKeys.toString) {
      val initialState = Seq(
        ("apple", 1L),
        ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2)

      // Running count per key: previous state (or 0 if absent) plus the number of new values.
      val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
        val count = state.getOption.getOrElse(0L) + values.size
        state.update(count)
        Iterator.single((key, count))
      }

      val inputData = MemoryStream[String]
      val result =
        inputData.toDS()
          .groupByKey(x => x)
          .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
      testStream(result, Update)(
        AddData(inputData, "apple"),
        AddData(inputData, "apple"),
        AddData(inputData, "orange"),
        // apple: 1 (initial) + 2 inputs; orange: 2 (initial) + 1 input.
        CheckNewAnswer(("apple", 3), ("orange", 3)),
        AddData(inputData, "orange"),
        CheckNewAnswer(("orange", 4)),
        StopStream
      )
    }
  }
}

// Batch (non-streaming) query: "apple" exists only in the initial state, so it is emitted
// only when skipEmittingInitialStateKeys is false. "orange" (initial state plus input) and
// "mango" (input only) are emitted in both cases.
Seq(true, false).foreach { skipEmittingInitialStateKeys =>
  testWithAllStateVersions("flatMapGroupsWithState - batch query and " +
    s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") {
    withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key ->
      skipEmittingInitialStateKeys.toString) {
      val initialState = Seq(
        ("apple", 1L),
        ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2)

      // Running count per key: previous state (or 0 if absent) plus the number of new values.
      val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
        val count = state.getOption.getOrElse(0L) + values.size
        state.update(count)
        Iterator.single((key, count))
      }

      val inputData = Seq("orange", "mango")
      val result =
        inputData.toDS()
          .groupByKey(x => x)
          .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
      val df = result.toDF()
      val expected =
        if (skipEmittingInitialStateKeys) {
          Seq(("orange", 3), ("mango", 1))
        } else {
          Seq(("apple", 1), ("orange", 3), ("mango", 1))
        }
      checkAnswer(df, expected.toDF())
    }
  }
}

def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
test(s"$name - state format version $version") {
Expand Down