feat: Make streaming the default LightGBM data transfer mode (#2088)
* Make streaming the default

* revert to test

* test again

* group id fixes

* nit fix

* nit fix

---------

Co-authored-by: Mark Hamilton <[email protected]>
svotaw and mhamilton723 authored Oct 21, 2023
1 parent b08b33d commit 2f31f74
Showing 8 changed files with 64 additions and 3 deletions.
@@ -0,0 +1,45 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm
+
+import scala.collection.mutable
+import scala.language.existentials
+
+/** Converts group column values to integer group IDs.
+  *
+  * Int values are returned as-is, but maps of Long and String values are maintained so that
+  * unique, consistent Int IDs can be returned for those types.
+  */
+class GroupIdManager {
+  private val stringGroupIds = mutable.Map[String, Int]()
+  private val longGroupIds = mutable.Map[Long, Int]()
+  private[this] val lock = new Object()
+
+  /** Convert a group ID into a unique Int.
+    *
+    * @param groupValue The original group ID value
+    */
+  def getUniqueIdForGroup(groupValue: Any): Int = {
+    groupValue match {
+      case iVal: Int =>
+        iVal // If it's already an Int, just return it
+      case lVal: Long =>
+        lock.synchronized {
+          if (!longGroupIds.contains(lVal)) {
+            longGroupIds(lVal) = longGroupIds.size // assign the next sequential Long ID
+          }
+          longGroupIds(lVal)
+        }
+      case sVal: String =>
+        lock.synchronized {
+          if (!stringGroupIds.contains(sVal)) {
+            stringGroupIds(sVal) = stringGroupIds.size // assign the next sequential String ID
+          }
+          stringGroupIds(sVal)
+        }
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported group column type: '${groupValue.getClass.getName}'")
+    }
+  }
+}
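
A minimal usage sketch (hypothetical, not part of this commit): mixed group column types map to small, stable Int IDs, and the Long and String maps are independent, so each type gets its own ID sequence.

val manager = new GroupIdManager()
manager.getUniqueIdForGroup(7)       // Int passes through unchanged: 7
manager.getUniqueIdForGroup(100L)    // first Long seen -> 0
manager.getUniqueIdForGroup(200L)    // second Long seen -> 1
manager.getUniqueIdForGroup(100L)    // repeated value -> same ID: 0
manager.getUniqueIdForGroup("qid-a") // first String seen -> 0 (separate map)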
@@ -36,6 +36,8 @@ class SharedState(trainParams: BaseTrainParams) {
   val datasetState: SharedDatasetState = new SharedDatasetState(trainParams, isForValidation = false)
   val validationDatasetState: SharedDatasetState = new SharedDatasetState(trainParams, isForValidation = true)
 
+  lazy val groupIdManager: GroupIdManager = new GroupIdManager()
+
   @volatile var isSparse: Option[Boolean] = None
   @volatile var mainExecutorWorker: Option[Long] = None
   @volatile var validationDatasetWorker: Option[Long] = None
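
The lazy val matters here: Scala initializes a lazy val at most once, under a lock, so every partition task that reaches the shared state sees the same GroupIdManager instance. A rough sketch of that guarantee (hypothetical, not part of this commit):

object LazyInitSketch {
  lazy val manager = new GroupIdManager() // constructed once, on first access

  def main(args: Array[String]): Unit = {
    val threads = (1 to 4).map { _ =>
      new Thread(() => println(manager.getUniqueIdForGroup("query-a")))
    }
    threads.foreach(_.start())
    threads.foreach(_.join()) // every thread prints 0: one shared instance
  }
}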
@@ -331,7 +331,10 @@ class StreamingPartitionTask extends BasePartitionTask {
   private def loadOneMetadataRow(state: StreamingState, row: Row, index: Int): Unit = {
     state.labelBuffer.setItem(index, row.getDouble(state.labelIndex).toFloat)
     if (state.hasWeights) state.weightBuffer.setItem(index, row.getDouble(state.weightIndex).toFloat)
-    if (state.hasGroups) state.groupBuffer.setItem(index, row.getAs[Int](state.groupIndex))
+    if (state.hasGroups) {
+      val groupIdManager = state.ctx.sharedState.groupIdManager
+      state.groupBuffer.setItem(index, groupIdManager.getUniqueIdForGroup(row.getAs[Any](state.groupIndex)))
+    }
 
     // Initial scores are passed in column-based format, where the score for each class is contiguous
     if (state.hasInitialScores) {
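
The removed line read the group column with row.getAs[Int], which fails at runtime when a ranking dataset uses a Long or String query/group column; the new path reads the value untyped and routes it through GroupIdManager. A hedged sketch of the conversion (field names simplified, not the repo's code):

import org.apache.spark.sql.Row

val groupIndex = 1
val row = Row(0.5 /* label */, "qid-42" /* group column */)
val manager = new GroupIdManager()
val groupId = manager.getUniqueIdForGroup(row.getAs[Any](groupIndex)) // -> 0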
@@ -92,7 +92,7 @@ trait LightGBMExecutionParams extends Wrappable {
   val dataTransferMode = new Param[String](this, "dataTransferMode",
     "Specify how SynapseML transfers data from Spark to LightGBM. " +
     "Values can be streaming, bulk. Default is bulk, which is the legacy mode.")
-  setDefault(dataTransferMode -> LightGBMConstants.BulkDataTransferMode)
+  setDefault(dataTransferMode -> LightGBMConstants.StreamingDataTransferMode)
   def getDataTransferMode: String = $(dataTransferMode)
   def setDataTransferMode(value: String): this.type = set(dataTransferMode, value)
 
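
Note that only the setDefault line changes; the param description above it still says bulk is the default, which no longer matches the new behavior. With streaming now the default, setDataTransferMode is only needed to opt back into the legacy mode. A hypothetical opt-out sketch (LightGBMRegressor stands in for any of the LightGBM estimators):

import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMRegressor}

val regressor = new LightGBMRegressor()
  .setDataTransferMode(LightGBMConstants.BulkDataTransferMode) // fall back to bulk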
@@ -51,6 +51,7 @@ abstract class LightGBMRankerTestData extends Benchmarks with EstimatorFuzzing[L
       .setGroupCol(queryCol)
       .setDefaultListenPort(getAndIncrementPort())
       .setRepartitionByGroupingColumn(false)
+      .setDataTransferMode(dataTransferMode)
       .setNumLeaves(5)
       .setNumIterations(10)
   }
@@ -29,6 +29,7 @@ abstract class LightGBMRegressorTestData extends Benchmarks
       .setLabelCol(labelCol)
       .setFeaturesCol(featuresCol)
       .setDefaultListenPort(getAndIncrementPort())
+      .setDataTransferMode(dataTransferMode)
       .setNumLeaves(5)
       .setNumIterations(10)
   }
@@ -3,6 +3,7 @@

 package com.microsoft.azure.synapse.ml.lightgbm.split2
 
+import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants
 import com.microsoft.azure.synapse.ml.lightgbm.dataset.DatasetUtils.countCardinality
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.ml.linalg.Vectors
@@ -14,6 +15,7 @@ import scala.language.postfixOps
 //scalastyle:off magic.number
 /** Tests to validate the functionality of LightGBM Ranker module in streaming mode. */
 class VerifyLightGBMRankerStream extends LightGBMRankerTestData {
+  override val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode
 
   import spark.implicits._
 
@@ -76,6 +78,11 @@ class VerifyLightGBMRankerStream extends LightGBMRankerTestData {
     assert(counts === Seq(2, 3, 1))
   }
 
+  test("verify cardinality counts: long" + executionModeSuffix) {
+    val counts = countCardinality(Seq(1L, 1L, 2L, 2L, 2L, 3L))
+    assert(counts === Seq(2, 3, 1))
+  }
+
   test("verify cardinality counts: string" + executionModeSuffix) {
     val counts = countCardinality(Seq("a", "a", "b", "b", "b", "c"))
     assert(counts === Seq(2, 3, 1))
@@ -3,7 +3,7 @@

 package com.microsoft.azure.synapse.ml.lightgbm.split2
 
-import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel
+import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMRegressionModel}
 import org.apache.spark.ml.evaluation.RegressionEvaluator
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder, TrainValidationSplit}
@@ -15,6 +15,8 @@ import org.apache.spark.sql.{DataFrame, Row}
 /** Tests to validate the functionality of LightGBM module in streaming mode.
   */
 class VerifyLightGBMRegressorStream extends LightGBMRegressorTestData {
+  override val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode
+
   test(verifyLearnerTitleTemplate.format(energyEffFile, dataTransferMode)) {
     verifyLearnerOnRegressionCsvFile(energyEffFile, "Y1", 0,
       Some(Seq("X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "Y2")))
