Support map_from_arrays
WangGuangxin committed Aug 10, 2024
1 parent 920cfaf commit c16df96
Showing 11 changed files with 167 additions and 0 deletions.
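For context, map_from_arrays(keys, values) builds a map column by zipping a key array with a value array, and this commit lets Gluten keep such expressions in native Velox execution (the new tests below check for ProjectExecTransformer rather than a fallback). A minimal, illustrative spark-shell style sketch; the SparkSession named `spark` and the sample data are assumptions that mirror the new tests, not part of the diff:

import org.apache.spark.sql.functions.map_from_arrays
import spark.implicits._ // assumes an active SparkSession named `spark`

val df = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
df.select(map_from_arrays($"k", $"v")).show(false)
// expected result, the same on vanilla Spark and, after this change, on the Velox backend:
// {1 -> a, 2 -> b}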
@@ -2074,6 +2074,28 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
test("Deduplicate sorting keys") {
runQueryAndCompare("select * from lineitem order by l_orderkey, l_orderkey") {
checkGlutenOperatorMatch[SortExecTransformer]
}

test("test map_from_arrays") {
withTempView("t") {
Seq((Seq(1, 2, 1), Seq("a", "b", "c"))).toDF("k", "v").createOrReplaceTempView("t")
withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
runQueryAndCompare(
"""
|select map_from_arrays(k, v) from t
|""".stripMargin
) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

withSQLConf(
SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
val msg = intercept[Exception] {
spark.sql("select map_from_arrays(k, v) from t").collect()
}.getMessage
assert(msg.contains("Duplicate map keys (1) are not allowed"))
}
}
}
}
2 changes: 2 additions & 0 deletions cpp/core/config/GlutenConfig.h
@@ -36,6 +36,8 @@ const std::string kAllowPrecisionLoss = "spark.sql.decimalOperations.allowPrecis

const std::string kIgnoreMissingFiles = "spark.sql.files.ignoreMissingFiles";

const std::string kMapKeyDedupPolicy = "spark.sql.mapKeyDedupPolicy";

const std::string kDefaultSessionTimezone = "spark.gluten.sql.session.timeZone.default";

const std::string kSparkOffHeapMemory = "spark.gluten.memory.offHeap.size.in.bytes";
3 changes: 3 additions & 0 deletions cpp/velox/compute/WholeStageResultIterator.cc
@@ -512,6 +512,9 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC

    configs[velox::core::QueryConfig::kSparkPartitionId] = std::to_string(taskInfo_.partitionId);

    configs[velox::core::QueryConfig::kSparkMapKeyDedupPolicy] =
        veloxCfg_->get<std::string>(kMapKeyDedupPolicy, "EXCEPTION");

  } catch (const std::invalid_argument& err) {
    std::string errDetails = err.what();
    throw std::runtime_error("Invalid conf arg: " + errDetails);
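The value forwarded above comes from Spark's spark.sql.mapKeyDedupPolicy session conf (the kMapKeyDedupPolicy constant added in GlutenConfig.h), with EXCEPTION as the fallback when the conf is unset. A minimal, illustrative sketch of the two policies, mirroring the new TestOperator test; the SparkSession named `spark` is an assumption:

spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
spark.sql("select map_from_arrays(array(1, 2, 1), array('a', 'b', 'c'))").show(false)
// LAST_WIN: the duplicate key 1 keeps the last value -> {1 -> c, 2 -> b}

spark.conf.set("spark.sql.mapKeyDedupPolicy", "EXCEPTION")
spark.sql("select map_from_arrays(array(1, 2, 1), array('a', 'b', 'c'))").collect()
// EXCEPTION (the default): on the Velox backend the query fails with
// "Duplicate map keys (1) are not allowed"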
@@ -273,6 +273,8 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
// Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
.exclude("map_zip_with function - map of primitive types")
// Rewrite this test because Velox's exception message is different with vanilla spark.
.exclude("map with arrays")
enableSuite[GlutenDataFrameTungstenSuite]
enableSuite[GlutenDataFrameSetOperationsSuite]
// Result depends on the implementation for nondeterministic expression rand.
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
  import testImplicits._
@@ -131,4 +132,36 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
      oneRowDF.selectExpr("flatten(null)")
    }
  }

  testGluten("map with arrays") {
    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
    val row = df1.select(map_from_arrays($"k", $"v")).first()
    assert(row.schema(0).dataType === expectedType)
    assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
    checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

    val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
    checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

    val df3 = Seq((null, null)).toDF("k", "v")
    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

    val df4 = Seq((1, "a")).toDF("k", "v")
    intercept[AnalysisException] {
      df4.select(map_from_arrays($"k", $"v"))
    }

    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
    val msg1 = intercept[Exception] {
      df5.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg1.contains("map key cannot be null"))

    val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
    val msg2 = intercept[Exception] {
      df6.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg2.contains("Key and value arrays must be the same length"))
  }
}
@@ -989,6 +989,8 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
// Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
.exclude("map_zip_with function - map of primitive types")
// Rewrite this test because Velox's exception message is different with vanilla spark.
.exclude("map with arrays")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite]
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
  import testImplicits._
@@ -131,4 +132,36 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
      oneRowDF.selectExpr("flatten(null)")
    }
  }

  testGluten("map with arrays") {
    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
    val row = df1.select(map_from_arrays($"k", $"v")).first()
    assert(row.schema(0).dataType === expectedType)
    assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
    checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

    val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
    checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

    val df3 = Seq((null, null)).toDF("k", "v")
    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

    val df4 = Seq((1, "a")).toDF("k", "v")
    intercept[AnalysisException] {
      df4.select(map_from_arrays($"k", $"v"))
    }

    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
    val msg1 = intercept[Exception] {
      df5.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg1.contains("map key cannot be null"))

    val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
    val msg2 = intercept[Exception] {
      df6.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg2.contains("Key and value arrays must be the same length"))
  }
}
@@ -998,6 +998,8 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
// Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
.exclude("map_zip_with function - map of primitive types")
// Rewrite this test because Velox's exception message is different with vanilla spark.
.exclude("map with arrays")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite]
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
  import testImplicits._
@@ -49,4 +50,36 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
      false
    )
  }

  testGluten("map with arrays") {
    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
    val row = df1.select(map_from_arrays($"k", $"v")).first()
    assert(row.schema(0).dataType === expectedType)
    assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
    checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

    val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
    checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

    val df3 = Seq((null, null)).toDF("k", "v")
    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

    val df4 = Seq((1, "a")).toDF("k", "v")
    intercept[AnalysisException] {
      df4.select(map_from_arrays($"k", $"v"))
    }

    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
    val msg1 = intercept[Exception] {
      df5.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg1.contains("map key cannot be null"))

    val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
    val msg2 = intercept[Exception] {
      df6.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg2.contains("Key and value arrays must be the same length"))
  }
}
@@ -1014,6 +1014,8 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
// Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
.exclude("map_zip_with function - map of primitive types")
// Rewrite this test because Velox's exception message is different with vanilla spark.
.exclude("map with arrays")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite]
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
  import testImplicits._
@@ -49,4 +50,36 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
      false
    )
  }

  testGluten("map with arrays") {
    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
    val row = df1.select(map_from_arrays($"k", $"v")).first()
    assert(row.schema(0).dataType === expectedType)
    assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
    checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

    val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
    checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

    val df3 = Seq((null, null)).toDF("k", "v")
    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

    val df4 = Seq((1, "a")).toDF("k", "v")
    intercept[AnalysisException] {
      df4.select(map_from_arrays($"k", $"v"))
    }

    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
    val msg1 = intercept[Exception] {
      df5.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg1.contains("map key cannot be null"))

    val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
    val msg2 = intercept[Exception] {
      df6.select(map_from_arrays($"k", $"v")).collect
    }.getMessage
    assert(msg2.contains("Key and value arrays must be the same length"))
  }
}
