Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Oct 22, 2024
1 parent 35518d8 commit d42d80a
Showing 1 changed file with 39 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.apache.spark.sql.rapids.aggregate

import scala.collection.immutable.Seq

import ai.rapids.cudf
import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
import com.nvidia.spark.rapids._
Expand All @@ -25,24 +27,29 @@ import com.nvidia.spark.rapids.jni.HLL
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{GenericArrayData, HyperLogLogPlusPlusHelper}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

// CudfHLL: a CudfAggregate that exposes an HLL reduction (whole column -> one scalar)
// and an HLL group-by aggregation, both obtained from the cuDF aggregation factories.
// NOTE(review): this span comes from a rendered diff, so two variants of the header and of
// each member appear back to back without +/- markers. Presumably the one-argument header
// and the HLL() / HLL(32 * 1024) / DType.LIST lines are the removed side, and the
// numRegistersPerSketch / DType.STRUCT lines are the added side — confirm against the commit.
case class CudfHLL(override val dataType: DataType) extends CudfAggregate {
case class CudfHLL(override val dataType: DataType,
numRegistersPerSketch: Int) extends CudfAggregate {
// Reduction form: reduces the input column to a single cudf.Scalar.
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
// Variant with no size argument, requesting a LIST-typed result.
(input: cudf.ColumnVector) => input.reduce(ReductionAggregation.HLL(), DType.LIST)
// Variant with the sketch size hard-coded to 32 * 1024.
override lazy val groupByAggregate: GroupByAggregation = GroupByAggregation.HLL(32 * 1024)
// Variant that forwards numRegistersPerSketch and requests a STRUCT-typed result.
(input: cudf.ColumnVector) => input.reduce(
ReductionAggregation.HLL(numRegistersPerSketch), DType.STRUCT)
// Group-by form, using the same numRegistersPerSketch as the reduction form.
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.HLL(numRegistersPerSketch)
override val name: String = "CudfHLL"
}

// CudfMergeHLL: a CudfAggregate that exposes the mergeHLL reduction and the mergeHLL
// group-by aggregation from the cuDF aggregation factories.
// NOTE(review): this span comes from a rendered diff, so two variants of the header and of
// each member appear back to back without +/- markers. Presumably the one-argument header
// and the mergeHLL() / DType.LIST lines are the removed side, and the
// numRegistersPerSketch / DType.STRUCT lines are the added side — confirm against the commit.
case class CudfMergeHLL(override val dataType: DataType)
case class CudfMergeHLL(override val dataType: DataType,
numRegistersPerSketch: Int)
extends CudfAggregate {
// Reduction form: reduces the input column to a single cudf.Scalar.
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(input: cudf.ColumnVector) =>
// Variant with no size argument, requesting a LIST-typed result.
input.reduce(ReductionAggregation.mergeHLL(), DType.LIST)
// Variant that forwards numRegistersPerSketch and requests a STRUCT-typed result.
input.reduce(ReductionAggregation.mergeHLL(numRegistersPerSketch), DType.STRUCT)

// Group-by form; the second variant passes the same numRegistersPerSketch as the reduction.
override lazy val groupByAggregate: GroupByAggregation = GroupByAggregation.mergeHLL()
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.mergeHLL(numRegistersPerSketch)
override val name: String = "CudfMergeHLL"
}

Expand All @@ -67,27 +74,45 @@ case class GpuHLLEvaluation(childExpr: Expression, precision: Int)
}
}

case class GpuHLL(childExpr: Expression, precision: Int)
case class GpuHLL(childExpr: Expression, relativeSD: Double)
extends GpuAggregateFunction with Serializable {

// specify the HLL sketch type: list<byte>
// NOTE(review): a second hllBufferType definition (struct of longs) follows a few lines
// below; this list<byte> form is presumably the removed side of the diff — confirm.
private lazy val hllBufferType: DataType = ArrayType(ByteType, containsNull = false)
// Consistent with Spark: derive a bit count from the requested relative standard deviation,
// then use 2^bits as the number of registers in one sketch.
private lazy val numRegistersPerSketch: Int = {
  val precisionBits = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt
  1 << precisionBits
}

// Consistent with Spark: number of longs derived from the register count.
// NOTE(review): the divisor 10 presumably mirrors Spark's registers-per-word packing — confirm.
private lazy val numLongs = {
  val registersPerWord = 10
  numRegistersPerSketch / registersPerWord + 1
}

// specify the HLL sketch type: struct<long, ..., long>
// The struct is built from aggBufferAttributes, so it has one field per buffer attribute.
private lazy val hllBufferType: DataType = StructType.fromAttributes(aggBufferAttributes)

// Single attribute named "hllAttr" whose type is the whole sketch buffer type;
// it is what evaluateExpression hands to GpuHLLEvaluation.
private lazy val hllBufferAttribute: AttributeReference =
AttributeReference("hllAttr", hllBufferType)()

// TODO: should be long array literal
// Initial aggregation buffer: one literal holding a zero-filled byte array of length
// 32 * 1024 (Array.ofDim leaves every element at 0), typed as hllBufferType.
// NOTE(review): the literal is still a byte array of fixed size, while the struct variant
// of hllBufferType is made of LongType fields sized from relativeSD — confirm this literal
// is updated together with the buffer type, as the TODO above suggests.
override lazy val initialValues: Seq[Expression] =
Seq(GpuLiteral.create(new GenericArrayData(Array.ofDim[Byte](32 * 1024)), hllBufferType))

// The input projection consists of the single child expression.
override lazy val inputProjection: Seq[Expression] = Seq(childExpr)

// Update step: a single CudfHLL aggregate over the sketch buffer type.
// NOTE(review): two variants appear because this is a rendered diff; presumably the
// one-argument CudfHLL(hllBufferType) call is the removed side — confirm.
override lazy val updateAggregates: Seq[CudfAggregate] = Seq(CudfHLL(hllBufferType))
override lazy val updateAggregates: Seq[CudfAggregate] =
Seq(CudfHLL(hllBufferType, numRegistersPerSketch))

// Merge step: a single CudfMergeHLL aggregate over the sketch buffer type.
// NOTE(review): two variants appear because this is a rendered diff; presumably the
// one-argument CudfMergeHLL(hllBufferType) call is the removed side — confirm.
override lazy val mergeAggregates: Seq[CudfAggregate] = Seq(CudfMergeHLL(hllBufferType))
override lazy val mergeAggregates: Seq[CudfAggregate] =
Seq(CudfMergeHLL(hllBufferType, numRegistersPerSketch))

// Final evaluation: wraps the sketch buffer attribute in a GpuHLLEvaluation expression.
// NOTE(review): the hunk header above shows GpuHLLEvaluation(childExpr: Expression,
// precision: Int); the second variant here passes numRegistersPerSketch (a register count,
// 1 << bits) in that position — confirm GpuHLLEvaluation expects a register count and not
// a precision.
override lazy val evaluateExpression: Expression = GpuHLLEvaluation(hllBufferAttribute, precision)
override lazy val evaluateExpression: Expression =
GpuHLLEvaluation(hllBufferAttribute, numRegistersPerSketch)

// Aggregation buffer layout.
// NOTE(review): two variants appear because this is a rendered diff; presumably the
// single-attribute `hllBufferAttribute :: Nil` form is the removed side — confirm.
override def aggBufferAttributes: Seq[AttributeReference] = hllBufferAttribute :: Nil
// Spark's HyperLogLog++ helper, constructed from the same relativeSD; used here only for
// its numWords value.
private val hllppHelper = new HyperLogLogPlusPlusHelper(relativeSD)

/** Allocate enough words to store all registers. */
// One LongType attribute per helper word, named MS[0], MS[1], ...
override val aggBufferAttributes: Seq[AttributeReference] = {
Seq.tabulate(hllppHelper.numWords) { i =>
AttributeReference(s"MS[$i]", LongType)()
}
}

// The data type reported by this function is the sketch buffer type itself.
override def dataType: DataType = hllBufferType

Expand Down

0 comments on commit d42d80a

Please sign in to comment.