Skip to content

Commit

Permalink
Improve fuzz testing coverage (#668)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jul 17, 2024
1 parent ae7ea99 commit a8ebd0b
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 17 deletions.
4 changes: 4 additions & 0 deletions fuzz-testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ $SPARK_HOME/bin/spark-submit \
data --num-files=2 --num-rows=200 --num-columns=100
```

There is an optional `--exclude-negative-zero` flag for excluding `-0.0` from the generated data. This is
sometimes useful because `-0.0` is a known edge case where behavior frequently differs, due to
differences in how Rust and Java handle this value.

### Generating Queries

Generate random queries that are based on the available test files.
Expand Down
40 changes: 28 additions & 12 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,16 @@ object DataGen {
spark: SparkSession,
numFiles: Int,
numRows: Int,
numColumns: Int): Unit = {
numColumns: Int,
generateNegativeZero: Boolean): Unit = {
for (i <- 0 until numFiles) {
generateRandomParquetFile(r, spark, s"test$i.parquet", numRows, numColumns)
generateRandomParquetFile(
r,
spark,
s"test$i.parquet",
numRows,
numColumns,
generateNegativeZero)
}
}

Expand All @@ -46,7 +53,8 @@ object DataGen {
spark: SparkSession,
filename: String,
numRows: Int,
numColumns: Int): Unit = {
numColumns: Int,
generateNegativeZero: Boolean): Unit = {

// generate schema using random data types
val fields = Range(0, numColumns)
Expand All @@ -55,7 +63,8 @@ object DataGen {
val schema = StructType(fields)

// generate columnar data
val cols: Seq[Seq[Any]] = fields.map(f => generateColumn(r, f.dataType, numRows))
val cols: Seq[Seq[Any]] =
fields.map(f => generateColumn(r, f.dataType, numRows, generateNegativeZero))

// convert to rows
val rows = Range(0, numRows).map(rowIndex => {
Expand All @@ -66,18 +75,25 @@ object DataGen {
df.write.mode(SaveMode.Overwrite).parquet(filename)
}

def generateColumn(r: Random, dataType: DataType, numRows: Int): Seq[Any] = {
def generateColumn(
r: Random,
dataType: DataType,
numRows: Int,
generateNegativeZero: Boolean): Seq[Any] = {
dataType match {
case DataTypes.BooleanType =>
generateColumn(r, DataTypes.LongType, numRows)
generateColumn(r, DataTypes.LongType, numRows, generateNegativeZero)
.map(_.asInstanceOf[Long].toShort)
.map(s => s % 2 == 0)
case DataTypes.ByteType =>
generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toByte)
generateColumn(r, DataTypes.LongType, numRows, generateNegativeZero)
.map(_.asInstanceOf[Long].toByte)
case DataTypes.ShortType =>
generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toShort)
generateColumn(r, DataTypes.LongType, numRows, generateNegativeZero)
.map(_.asInstanceOf[Long].toShort)
case DataTypes.IntegerType =>
generateColumn(r, DataTypes.LongType, numRows).map(_.asInstanceOf[Long].toInt)
generateColumn(r, DataTypes.LongType, numRows, generateNegativeZero)
.map(_.asInstanceOf[Long].toInt)
case DataTypes.LongType =>
Range(0, numRows).map(_ => {
r.nextInt(50) match {
Expand All @@ -103,7 +119,7 @@ object DataGen {
case 3 => Float.MinValue
case 4 => Float.MaxValue
case 5 => 0.0f
case 6 => -0.0f
case 6 if generateNegativeZero => -0.0f
case _ => r.nextFloat()
}
})
Expand All @@ -116,7 +132,7 @@ object DataGen {
case 3 => Double.MinValue
case 4 => Double.MaxValue
case 5 => 0.0
case 6 => -0.0
case 6 if generateNegativeZero => -0.0
case _ => r.nextDouble()
}
})
Expand All @@ -134,7 +150,7 @@ object DataGen {
}
})
case DataTypes.BinaryType =>
generateColumn(r, DataTypes.StringType, numRows)
generateColumn(r, DataTypes.StringType, numRows, generateNegativeZero)
.map {
case x: String =>
x.getBytes(Charset.defaultCharset())
Expand Down
4 changes: 3 additions & 1 deletion fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
val numFiles: ScallopOption[Int] = opt[Int](required = true)
val numRows: ScallopOption[Int] = opt[Int](required = true)
val numColumns: ScallopOption[Int] = opt[Int](required = true)
val excludeNegativeZero: ScallopOption[Boolean] = opt[Boolean](required = false)
}
addSubcommand(generateData)
object generateQueries extends Subcommand("queries") {
Expand Down Expand Up @@ -64,7 +65,8 @@ object Main {
spark,
numFiles = conf.generateData.numFiles(),
numRows = conf.generateData.numRows(),
numColumns = conf.generateData.numColumns())
numColumns = conf.generateData.numColumns(),
generateNegativeZero = !conf.generateData.excludeNegativeZero())
case Some(conf.generateQueries) =>
QueryGen.generateRandomQueries(
r,
Expand Down
24 changes: 21 additions & 3 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ object Meta {
Function("trim", 1),
Function("ltrim", 1),
Function("rtrim", 1),
Function("string_space", 1),
Function("rpad", 2),
Function("rpad", 3), // rpad can have 2 or 3 arguments
Function("hex", 1),
Function("unhex", 1),
Function("xxhash64", 1),
Function("sha1", 1),
// Function("sha2", 1), -- needs a second argument for number of bits
Function("substring", 3),
Function("btrim", 1),
Function("concat_ws", 2),
Function("repeat", 2),
Expand Down Expand Up @@ -86,9 +95,16 @@ object Meta {
Function("Sqrt", 1),
Function("Tan", 1),
Function("Ceil", 1),
Function("Floor", 1))
Function("Floor", 1),
Function("bool_and", 1),
Function("bool_or", 1),
Function("bitwise_not", 1))

val miscScalarFunc: Seq[Function] =
Seq(Function("isnan", 1), Function("isnull", 1), Function("isnotnull", 1))

val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++ mathScalarFunc
val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++
mathScalarFunc ++ miscScalarFunc

val aggFunc: Seq[Function] = Seq(
Function("min", 1),
Expand All @@ -108,6 +124,8 @@ object Meta {

val unaryArithmeticOps: Seq[String] = Seq("+", "-")

val binaryArithmeticOps: Seq[String] = Seq("+", "-", "*", "/", "%", "&", "|", "^")
val binaryArithmeticOps: Seq[String] = Seq("+", "-", "*", "/", "%", "&", "|", "^", "<<", ">>")

val comparisonOps: Seq[String] = Seq("=", "<=>", ">", ">=", "<", "<=")

}
32 changes: 31 additions & 1 deletion fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ object QueryGen {
val uniqueQueries = mutable.HashSet[String]()

for (_ <- 0 until numQueries) {
val sql = r.nextInt().abs % 6 match {
val sql = r.nextInt().abs % 8 match {
case 0 => generateJoin(r, spark, numFiles)
case 1 => generateAggregate(r, spark, numFiles)
case 2 => generateScalar(r, spark, numFiles)
case 3 => generateCast(r, spark, numFiles)
case 4 => generateUnaryArithmetic(r, spark, numFiles)
case 5 => generateBinaryArithmetic(r, spark, numFiles)
case 6 => generateBinaryComparison(r, spark, numFiles)
case _ => generateConditional(r, spark, numFiles)
}
if (!uniqueQueries.contains(sql)) {
uniqueQueries += sql
Expand Down Expand Up @@ -121,6 +123,34 @@ object QueryGen {
s"ORDER BY $a, $b;"
}

/**
 * Generates a random SQL query that projects two randomly chosen columns of a
 * randomly chosen test table along with a binary comparison between them.
 *
 * Produces a query of the shape `SELECT a, b, a <op> b FROM testN ORDER BY a, b;`
 * where `<op>` is drawn from [[Meta.comparisonOps]]. The two columns may be the
 * same column, and their types are not constrained to be comparable — type
 * mismatches are part of the fuzzing surface.
 */
private def generateBinaryComparison(r: Random, spark: SparkSession, numFiles: Int): String = {
  // Choose one of the generated parquet tables at random.
  val table = s"test${r.nextInt(numFiles)}"
  val columns = spark.table(table).columns

  // Random comparison operator and random left/right columns.
  val op = Utils.randomChoice(Meta.comparisonOps, r)
  val lhs = Utils.randomChoice(columns, r)
  val rhs = Utils.randomChoice(columns, r)

  // ORDER BY makes results deterministic so Spark and Comet output can be diffed.
  s"SELECT $lhs, $rhs, $lhs $op $rhs FROM $table ORDER BY $lhs, $rhs;"
}

/**
 * Generates a random SQL query exercising conditional expressions (`IF` and
 * `CASE WHEN`) built from a random comparison predicate.
 *
 * Produces a query of the shape
 * `SELECT a, b, a <op> b, IF(a <op> b, 1, 2), CASE WHEN a <op> b THEN 1 ELSE 2 END FROM testN ORDER BY a, b;`
 * where `<op>` is drawn from [[Meta.comparisonOps]]. The same predicate is
 * evaluated three ways so the plain comparison result can be checked against
 * both conditional forms.
 */
private def generateConditional(r: Random, spark: SparkSession, numFiles: Int): String = {
  // Choose one of the generated parquet tables at random.
  val table = s"test${r.nextInt(numFiles)}"
  val columns = spark.table(table).columns

  // Random comparison operator and random left/right columns.
  val op = Utils.randomChoice(Meta.comparisonOps, r)
  val lhs = Utils.randomChoice(columns, r)
  val rhs = Utils.randomChoice(columns, r)

  val predicate = s"$lhs $op $rhs"
  // ORDER BY makes results deterministic so Spark and Comet output can be diffed.
  s"SELECT $lhs, $rhs, $predicate, IF($predicate, 1, 2), CASE WHEN $predicate THEN 1 ELSE 2 END " +
    s"FROM $table " +
    s"ORDER BY $lhs, $rhs;"
}

private def generateCast(r: Random, spark: SparkSession, numFiles: Int): String = {
val tableName = s"test${r.nextInt(numFiles)}"
val table = spark.table(tableName)
Expand Down

0 comments on commit a8ebd0b

Please sign in to comment.