[CORE] Use collection interface in method parameter and return type (#…
liujiayi771 authored Nov 7, 2023
1 parent 7308fdb commit c470b9b
Showing 52 changed files with 537 additions and 594 deletions.
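
The pattern applied throughout is to declare method parameters and return types against the collection interfaces (java.util.List, java.util.Map) while keeping concrete types such as java.util.ArrayList confined to the construction site. A minimal Scala sketch of the idea, with hypothetical names not taken from this commit:

    import java.util.{ArrayList => JArrayList, List => JList}

    object SignatureSketch extends App {
      // Before: def collectNames(input: Seq[String]): JArrayList[String]
      // After: callers see only the List interface; the ArrayList stays an
      // implementation detail of the body and can be swapped without
      // touching any call site.
      def collectNames(input: Seq[String]): JList[String] = {
        val names = new JArrayList[String](input.size)
        input.foreach(n => names.add(n))
        names
      }

      println(collectNames(Seq("a", "b"))) // prints [a, b]
    }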
@@ -28,7 +28,6 @@
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.storage.CHShuffleReadStreamFactory;

-import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -88,8 +87,8 @@ public static long build(

   /** create table named struct */
   private static NamedStruct toNameStruct(List<Attribute> output) {
-    ArrayList<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
-    ArrayList<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
+    List<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
+    List<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
     Type.Struct.Builder structBuilder = Type.Struct.newBuilder();
     for (TypeNode typeNode : typeList) {
       structBuilder.addTypes(typeNode.toProtobuf());
@@ -41,7 +41,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

 import java.lang.{Long => JLong}
 import java.net.URI
-import java.util
+import java.util.{ArrayList => JArrayList}

 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -68,9 +68,9 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
             .makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
           SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
       case f: FilePartition =>
-        val paths = new util.ArrayList[String]()
-        val starts = new util.ArrayList[JLong]()
-        val lengths = new util.ArrayList[JLong]()
+        val paths = new JArrayList[String]()
+        val starts = new JArrayList[JLong]()
+        val lengths = new JArrayList[JLong]()
         val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
         f.files.foreach {
           file =>
@@ -122,7 +122,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
     val resIter: GeneralOutIterator = GlutenTimeMetric.millis(pipelineTime) {
       _ =>
         val transKernel = new CHNativeExpressionEvaluator()
-        val inBatchIters = new util.ArrayList[GeneralInIterator](inputIterators.map {
+        val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
          iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
        }.asJava)
        transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false)
@@ -180,7 +180,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
       _ =>
         val transKernel = new CHNativeExpressionEvaluator()
         val columnarNativeIterator =
-          new java.util.ArrayList[GeneralInIterator](inputIterators.map {
+          new JArrayList[GeneralInIterator](inputIterators.map {
            iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
          }.asJava)
         // we need to complete dependency RDD's firstly
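
The bare import java.util is replaced by renaming imports, the usual Scala idiom for Java types whose simple names collide with Scala's own Long, List, and Map. A self-contained sketch of how the aliases behave (not code from this commit):

    import java.lang.{Long => JLong}
    import java.util.{ArrayList => JArrayList}

    object AliasSketch extends App {
      // JLong is java.lang.Long, the boxed reference type a Java-facing API
      // stores; it is distinct from Scala's value type Long.
      val starts = new JArrayList[JLong]()
      starts.add(JLong.valueOf(42L)) // explicit boxing to java.lang.Long
      println(starts) // prints [42]
    }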
@@ -26,14 +26,15 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{List => JList, Map => JMap}

 class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
   override def metricsUpdatingFunction(
       child: SparkPlan,
-      relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
-      joinParamsMap: util.HashMap[lang.Long, JoinParams],
-      aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
     MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
   }
@@ -55,7 +55,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import com.google.common.collect.Lists
 import org.apache.commons.lang3.ClassUtils

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

 import scala.collection.mutable.ArrayBuffer
@@ -64,7 +65,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   /** Transform GetArrayItem to Substrait. */
   override def genGetArrayItemExpressionNode(
       substraitExprName: String,
-      functionMap: java.util.HashMap[String, java.lang.Long],
+      functionMap: JMap[String, JLong],
       leftNode: ExpressionNode,
       rightNode: ExpressionNode,
       original: GetArrayItem): ExpressionNode = {
@@ -436,9 +437,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   /** Generate window function node */
   override def genWindowFunctionsNode(
       windowExpression: Seq[NamedExpression],
-      windowExpressionNodes: util.ArrayList[WindowFunctionNode],
+      windowExpressionNodes: JList[WindowFunctionNode],
       originalInputAttributes: Seq[Attribute],
-      args: util.HashMap[String, lang.Long]): Unit = {
+      args: JMap[String, JLong]): Unit = {

     windowExpression.map {
       windowExpr =>
@@ -451,7 +452,7 @@
           val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
           val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
             WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
-            new util.ArrayList[ExpressionNode](),
+            new JArrayList[ExpressionNode](),
             columnName,
             ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
             WindowExecTransformer.getFrameBound(frame.upper),
@@ -467,7 +468,7 @@
             throw new UnsupportedOperationException(s"Not currently supported: $aggregateFunc.")
           }

-          val childrenNodeList = new util.ArrayList[ExpressionNode]()
+          val childrenNodeList = new JArrayList[ExpressionNode]()
           aggregateFunc.children.foreach(
             expr =>
               childrenNodeList.add(
@@ -505,7 +506,7 @@
             }
           }

-          val childrenNodeList = new util.ArrayList[ExpressionNode]()
+          val childrenNodeList = new JArrayList[ExpressionNode]()
           childrenNodeList.add(
             ExpressionConverter
               .replaceWithExpressionTransformer(
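
genWindowFunctionsNode now takes windowExpressionNodes as a JList, the weakest type that still supports the add calls the method performs. A toy illustration of a callee mutating through the interface while the caller keeps choosing the implementation (hypothetical names, not commit code):

    import java.util.{ArrayList => JArrayList, List => JList}

    object CalleeMutationSketch extends App {
      // The callee only needs add(), so it asks for the interface.
      def appendSquares(out: JList[Int], upTo: Int): Unit =
        (1 to upTo).foreach(i => out.add(i * i))

      val nodes = new JArrayList[Int]() // the caller picks the implementation
      appendSquares(nodes, 3)
      println(nodes) // prints [1, 4, 9]
    }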
@@ -23,6 +23,9 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric

+import java.lang.{Long => JLong}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+
 import scala.collection.JavaConverters._

 object MetricsUtil extends Logging {
@@ -56,9 +59,9 @@ object MetricsUtil extends Logging {
    */
   def updateNativeMetrics(
       child: SparkPlan,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {

     val mut: MetricsUpdaterTree = treeifyMetricsUpdaters(child)

@@ -90,10 +93,10 @@ object MetricsUtil extends Logging {
    */
   def updateTransformerMetrics(
       mutNode: MetricsUpdaterTree,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      operatorIdx: java.lang.Long,
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      operatorIdx: JLong,
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
     imetrics =>
       try {
         val metrics = imetrics.asInstanceOf[NativeMetrics]
@@ -129,13 +132,13 @@ object MetricsUtil extends Logging {
    */
   def updateTransformerMetricsInternal(
       mutNode: MetricsUpdaterTree,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      operatorIdx: java.lang.Long,
+      relMap: JMap[JLong, JList[JLong]],
+      operatorIdx: JLong,
       metrics: NativeMetrics,
       metricsIdx: Int,
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): (java.lang.Long, Int) = {
-    val nodeMetricsList = new java.util.ArrayList[MetricsData]()
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
+    val nodeMetricsList = new JArrayList[MetricsData]()
     var curMetricsIdx = metricsIdx
     relMap
       .get(operatorIdx)
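
These helpers only look operator ids up in relMap and walk the returned lists, so JMap[JLong, JList[JLong]] is all they need to require. A sketch of reading such a structure from Scala via JavaConverters, using made-up data rather than real metrics:

    import java.lang.{Long => JLong}
    import java.util.{Arrays, HashMap => JHashMap, List => JList, Map => JMap}

    import scala.collection.JavaConverters._

    object RelMapSketch extends App {
      // Hypothetical mapping: operator id -> ids of the native rels it produced.
      val relMap: JMap[JLong, JList[JLong]] = new JHashMap[JLong, JList[JLong]]()
      relMap.put(JLong.valueOf(0L), Arrays.asList(JLong.valueOf(0L), JLong.valueOf(1L)))
      relMap.put(JLong.valueOf(1L), Arrays.asList(JLong.valueOf(2L)))

      // asScala bridges the Java collections back into Scala for traversal.
      relMap.asScala.foreach {
        case (opId, relIds) =>
          println(s"operator $opId -> rels ${relIds.asScala.mkString(", ")}")
      }
    }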
@@ -42,10 +42,11 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ExecutorManager

+import java.lang.{Long => JLong}
 import java.net.URLDecoder
 import java.nio.charset.StandardCharsets
 import java.time.ZoneOffset
-import java.util
+import java.util.{ArrayList => JArrayList}
 import java.util.concurrent.TimeUnit

 import scala.collection.JavaConverters._
@@ -67,14 +68,14 @@ class IteratorApiImpl extends IteratorApi with Logging {

   def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
     val paths = mutable.ArrayBuffer.empty[String]
-    val starts = mutable.ArrayBuffer.empty[java.lang.Long]
-    val lengths = mutable.ArrayBuffer.empty[java.lang.Long]
+    val starts = mutable.ArrayBuffer.empty[JLong]
+    val lengths = mutable.ArrayBuffer.empty[JLong]
     val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
     files.foreach {
       file =>
         paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name()))
-        starts.append(java.lang.Long.valueOf(file.start))
-        lengths.append(java.lang.Long.valueOf(file.length))
+        starts.append(JLong.valueOf(file.start))
+        lengths.append(JLong.valueOf(file.length))

         val partitionColumn = mutable.Map.empty[String, String]
         for (i <- 0 until file.partitionValues.numFields) {
@@ -90,7 +91,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
               case _: TimestampType =>
                 TimestampFormatter
                   .getFractionFormatter(ZoneOffset.UTC)
-                  .format(pn.asInstanceOf[java.lang.Long])
+                  .format(pn.asInstanceOf[JLong])
               case _ => pn.toString
             }
           }
@@ -139,7 +140,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
       inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = {
     val beforeBuild = System.nanoTime()
     val columnarNativeIterators =
-      new util.ArrayList[GeneralInIterator](inputIterators.map {
+      new JArrayList[GeneralInIterator](inputIterators.map {
        iter => new ColumnarBatchInIterator(iter.asJava)
      }.asJava)
     val transKernel = NativePlanEvaluator.create()
@@ -183,7 +184,7 @@ class IteratorApiImpl extends IteratorApi with Logging {

     val transKernel = NativePlanEvaluator.create()
     val columnarNativeIterator =
-      new util.ArrayList[GeneralInIterator](inputIterators.map {
+      new JArrayList[GeneralInIterator](inputIterators.map {
        iter => new ColumnarBatchInIterator(iter.asJava)
      }.asJava)
     val nativeResultIterator =
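
The double asJava in these constructors is unchanged; only the spelling of ArrayList is. The inner asJava turns each Scala Iterator into a java.util.Iterator for the native-facing wrapper, and the outer asJava exposes the Scala Seq as a Java collection so the JArrayList copy constructor can consume it. A reduced, dependency-free sketch of the same shape, with a stand-in for GeneralInIterator and plain strings instead of ColumnarBatch:

    import java.util.{ArrayList => JArrayList, Iterator => JIterator}

    import scala.collection.JavaConverters._

    object AsJavaSketch extends App {
      // Stand-in for GeneralInIterator: the native side consumes Java iterators.
      final class InIterator(val underlying: JIterator[String])

      val inputIterators: Seq[Iterator[String]] = Seq(Iterator("a", "b"), Iterator("c"))

      val inIters = new JArrayList[InIterator](inputIterators.map {
        iter => new InIterator(iter.asJava) // Scala Iterator -> java.util.Iterator
      }.asJava) // Scala Seq -> java.util.List, copied into the ArrayList

      println(inIters.size()) // prints 2
    }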
@@ -25,14 +25,15 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{List => JList, Map => JMap}

 class MetricsApiImpl extends MetricsApi with Logging {
   override def metricsUpdatingFunction(
       child: SparkPlan,
-      relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
-      joinParamsMap: util.HashMap[lang.Long, JoinParams],
-      aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
     MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
   }
@@ -56,6 +56,9 @@ import org.apache.commons.lang3.ClassUtils

 import javax.ws.rs.core.UriBuilder

+import java.lang.{Long => JLong}
+import java.util.{Map => JMap}
+
 import scala.collection.mutable.ArrayBuffer

 class SparkPlanExecApiImpl extends SparkPlanExecApi {
@@ -67,7 +70,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
    */
   override def genGetArrayItemExpressionNode(
       substraitExprName: String,
-      functionMap: java.util.HashMap[String, java.lang.Long],
+      functionMap: JMap[String, JLong],
       leftNode: ExpressionNode,
       rightNode: ExpressionNode,
       original: GetArrayItem): ExpressionNode = {
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.BitSet

-import java.util
+import java.util.{Map => JMap}

 class TransformerApiImpl extends TransformerApi with Logging {
@@ -65,7 +65,7 @@ class TransformerApiImpl extends TransformerApi with Logging {
   }

   override def postProcessNativeConfig(
-      nativeConfMap: util.Map[String, String],
+      nativeConfMap: JMap[String, String],
       backendPrefix: String): Unit = {
     // TODO: IMPLEMENT SPECIAL PROCESS FOR VELOX BACKEND
   }