/** Class for converting group column values to unique integer group IDs.
  *
  * Int values are returned unchanged; Long and String values are interned in
  * per-type maps so that equal inputs always yield the same dense, zero-based
  * ID. Interning is guarded by a single lock so the manager is safe to share
  * across partition tasks.
  */
class GroupIdManager {
  private val stringGroupIds = mutable.Map[String, Int]()
  private val longGroupIds = mutable.Map[Long, Int]()
  private[this] val lock = new Object()

  /** Convert a group ID into a unique Int.
    *
    * @param groupValue The original group ID value (Int, Long, or String).
    * @return A stable Int ID; the same input always maps to the same output.
    * @throws IllegalArgumentException if the value is not an Int, Long, or String.
    */
  def getUniqueIdForGroup(groupValue: Any): Int = {
    groupValue match {
      case iVal: Int =>
        iVal // Already an Int, just return it
      case lVal: Long =>
        lock.synchronized {
          // First sighting of this Long gets the next dense ID (current map size)
          longGroupIds.getOrElseUpdate(lVal, longGroupIds.size)
        }
      case sVal: String =>
        lock.synchronized {
          // BUG FIX: this previously used longGroupIds.size, which assigned the
          // SAME ID to every distinct String group once interned (the long map's
          // size never changed here), silently merging unrelated ranking groups.
          // New strings must be numbered by the string map's own size.
          stringGroupIds.getOrElseUpdate(sVal, stringGroupIds.size)
        }
      case _ =>
        throw new IllegalArgumentException(s"Unsupported group column type: '${groupValue.getClass.getName}'")
    }
  }
}
StreamingPartitionTask extends BasePartitionTask { private def loadOneMetadataRow(state: StreamingState, row: Row, index: Int): Unit = { state.labelBuffer.setItem(index, row.getDouble(state.labelIndex).toFloat) if (state.hasWeights) state.weightBuffer.setItem(index, row.getDouble(state.weightIndex).toFloat) - if (state.hasGroups) state.groupBuffer.setItem(index, row.getAs[Int](state.groupIndex)) + if (state.hasGroups) { + val groupIdManager = state.ctx.sharedState.groupIdManager + state.groupBuffer.setItem(index, groupIdManager.getUniqueIdForGroup(row.getAs[Any](state.groupIndex))) + } // Initial scores are passed in column-based format, where the score for each class is contiguous if (state.hasInitialScores) { diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala index 5e5e9a77a4..423ea6300a 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/params/LightGBMParams.scala @@ -92,7 +92,7 @@ trait LightGBMExecutionParams extends Wrappable { val dataTransferMode = new Param[String](this, "dataTransferMode", "Specify how SynapseML transfers data from Spark to LightGBM. " + "Values can be streaming, bulk. 
Default is bulk, which is the legacy mode.") - setDefault(dataTransferMode -> LightGBMConstants.BulkDataTransferMode) + setDefault(dataTransferMode -> LightGBMConstants.StreamingDataTransferMode) def getDataTransferMode: String = $(dataTransferMode) def setDataTransferMode(value: String): this.type = set(dataTransferMode, value) diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala index 1e9ebfc91b..e66d8b3f8c 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRankerTestData.scala @@ -51,6 +51,7 @@ abstract class LightGBMRankerTestData extends Benchmarks with EstimatorFuzzing[L .setGroupCol(queryCol) .setDefaultListenPort(getAndIncrementPort()) .setRepartitionByGroupingColumn(false) + .setDataTransferMode(dataTransferMode) .setNumLeaves(5) .setNumIterations(10) } diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala index e6e58b0dc6..be28beccf0 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/LightGBMRegressorTestData.scala @@ -29,6 +29,7 @@ abstract class LightGBMRegressorTestData extends Benchmarks .setLabelCol(labelCol) .setFeaturesCol(featuresCol) .setDefaultListenPort(getAndIncrementPort()) + .setDataTransferMode(dataTransferMode) .setNumLeaves(5) .setNumIterations(10) } diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala 
b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala index 526fae8aa4..c649bb7763 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRankerStream.scala @@ -3,6 +3,7 @@ package com.microsoft.azure.synapse.ml.lightgbm.split2 +import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants import com.microsoft.azure.synapse.ml.lightgbm.dataset.DatasetUtils.countCardinality import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vectors @@ -14,6 +15,7 @@ import scala.language.postfixOps //scalastyle:off magic.number /** Tests to validate the functionality of LightGBM Ranker module in streaming mode. */ class VerifyLightGBMRankerStream extends LightGBMRankerTestData { + override val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode import spark.implicits._ @@ -76,6 +78,11 @@ class VerifyLightGBMRankerStream extends LightGBMRankerTestData { assert(counts === Seq(2, 3, 1)) } + test("verify cardinality counts: long" + executionModeSuffix) { + val counts = countCardinality(Seq(1L, 1L, 2L, 2L, 2L, 3L)) + assert(counts === Seq(2, 3, 1)) + } + test("verify cardinality counts: string" + executionModeSuffix) { val counts = countCardinality(Seq("a", "a", "b", "b", "b", "c")) assert(counts === Seq(2, 3, 1)) diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRegressorStream.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRegressorStream.scala index 5645bc194a..15286f4499 100644 --- a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRegressorStream.scala +++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/split2/VerifyLightGBMRegressorStream.scala @@ -3,7 +3,7 @@ package 
com.microsoft.azure.synapse.ml.lightgbm.split2 -import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel +import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMRegressionModel} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder, TrainValidationSplit} @@ -15,6 +15,8 @@ import org.apache.spark.sql.{DataFrame, Row} /** Tests to validate the functionality of LightGBM module in streaming mode. */ class VerifyLightGBMRegressorStream extends LightGBMRegressorTestData { + override val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode + test(verifyLearnerTitleTemplate.format(energyEffFile, dataTransferMode)) { verifyLearnerOnRegressionCsvFile(energyEffFile, "Y1", 0, Some(Seq("X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "Y2")))