From cea85745a449aa0f2c32312460e8246db97fc2ba Mon Sep 17 00:00:00 2001 From: Ponkin Date: Fri, 12 Aug 2016 04:23:33 +0300 Subject: [PATCH 1/6] SPARKC-312: Implementing FilterOptimizer --- .../sql/CassandraPrunedFilteredScanSpec.scala | 35 ++++++++ .../cassandra/CassandraSourceRelation.scala | 19 +++- .../sql/cassandra/FiltersOptimizer.scala | 88 +++++++++++++++++++ 3 files changed, 139 insertions(+), 3 deletions(-) create mode 100644 spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala diff --git a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala index 2bf1adc22..e01913cb9 100644 --- a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala +++ b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala @@ -32,6 +32,12 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with s"""CREATE TABLE IF NOT EXISTS $ks.fields |(k INT, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, PRIMARY KEY (k)) """ .stripMargin) + }, + Future { + session.execute( + s"""CREATE TABLE IF NOT EXISTS $ks.metrics + |(k TEXT, a INT, b INT, c INT, PRIMARY KEY (k, a, b)) """ + .stripMargin) } ) } @@ -39,6 +45,7 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with val colorOptions = Map("keyspace" -> ks, "table" -> "colors") val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields") + val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics") val withPushdown = Map("pushdown" -> "true") val withoutPushdown = Map("pushdown" -> "false") @@ -74,6 +81,16 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with cts.get.selectedColumnNames should contain theSameElementsAs Seq("b", "c", "d") } + it should "optimize table scan if all filters can be pushed down" in { + val fieldsDF = sqlContext.read.format(cassandraFormat).options(metricsOptions ++ withPushdown).load() + val df = fieldsDF.filter("a = 5 and (b > 5 or b < 3)") + val executionPlan = df.queryExecution.executedPlan + val cts = findAllCassandraTableScanRDD(executionPlan) + cts.nonEmpty shouldBe true + cts.head.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" > ?"""), List(5, 5)) + cts.last.where shouldBe CqlWhereClause(Seq(""""a" = ? 
AND "b" < ?"""), List(5, 3)) + } + def findCassandraTableScanRDD(sparkPlan: SparkPlan): Option[CassandraTableScanRDD[_]] = { def _findCassandraTableScanRDD(rdd: RDD[_]): Option[CassandraTableScanRDD[_]] = { rdd match { @@ -92,4 +109,22 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with } } + def findAllCassandraTableScanRDD(sparkPlan: SparkPlan): List[CassandraTableScanRDD[_]] = { + def _findAllCassandraTableScanRDD(rdd: RDD[_]): List[CassandraTableScanRDD[_]] = { + rdd match { + case ctsrdd: CassandraTableScanRDD[_] => List(ctsrdd) + case other: RDD[_] => other.dependencies.iterator + .flatMap(dep => _findAllCassandraTableScanRDD(dep.rdd)).toList + } + } + + sparkPlan match { + case prdd: RDDScanExec => _findAllCassandraTableScanRDD(prdd.rdd) + case prdd: RowDataSourceScanExec => _findAllCassandraTableScanRDD(prdd.rdd) + case filter: FilterExec => findAllCassandraTableScanRDD(filter.child) + case wsc: WholeStageCodegenExec => findAllCassandraTableScanRDD(wsc.child) + case _ => List.empty + } + } + } \ No newline at end of file diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 6b220ead6..96a70822f 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -11,7 +11,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.SparkConf - import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema} import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._ import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates @@ -129,8 +128,22 @@ private[cassandra] class CassandraSourceRelation( val prunedRdd = maybeSelect(baseRdd, requiredColumns) val prunedFilteredRdd = { if(filterPushdown) { - val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray - val filteredRdd = maybePushdownFilters(prunedRdd, pushdownFilters) + val optimizedFilters = new FiltersOptimizer(filters).build() + val optimizationCanBeApplied = + optimizedFilters.size > 1 && + optimizedFilters.sliding(2).foldLeft(true){ (acc, sets) => + // check whether all non-pushed down filters are equals for each separate rdd + predicatePushDown(sets.head).handledBySpark == predicatePushDown(sets.last).handledBySpark + } + val filteredRdd = if(optimizationCanBeApplied) { + optimizedFilters.map { predicate => + val pushdownFilters = predicatePushDown(predicate).handledByCassandra.toArray + maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]] + }.reduce(_ union _) + } else { + val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray + maybePushdownFilters(prunedRdd, pushdownFilters) + } filteredRdd.asInstanceOf[RDD[Row]] } else { prunedRdd diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala new file mode 100644 index 000000000..d3bb55e59 --- /dev/null +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala @@ -0,0 +1,88 @@ +package 
org.apache.spark.sql.cassandra + +import org.apache.spark.sql.sources._ + + +/** + * Optimizer will try to transform pushdown filter into `sum of products`. + * So that the filter like + * '(field1 < 3 OR field1 > 7) AND (field2 = 'val1' OR field2 = 'val2')' + * will become equivalent + * 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR + * field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"' + */ +class FiltersOptimizer(filters: Array[Filter]) { + + private val fullFilterAst = + if (filters.nonEmpty) Some(filters.reduce((left, right) => And(left, right))) else None + + private def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match { + case (Or(l, r), p) => Or(dist(l, p), dist(r, p)) + case (p, Or(l, r)) => Or(dist(p, l), dist(p, r)) + case (l, r) => And(l, r) + } + + /** The 'toNNF' function converts expressions to negation normal form. This + * function is total: it's defined for all expressions, not just those which + * only use negation, conjunction and disjunction, although all expressions in + * negation normal form do in fact only use those connectives. + * + * Then de Morgan's laws are applied to convert negated + * conjunctions and disjunctions into the conjunction or disjunction of the + * negation of their conjuncts: ¬(φ ∧ ψ) is converted to (¬φ ∨ ¬ψ) + * while ¬(φ ∨ ψ) becomes (¬φ ∧ ¬ψ). + */ + private val toNNF: Filter => Filter = { + case a@(EqualTo(_, _) | EqualNullSafe(_, _) | GreaterThan(_, _) | + GreaterThanOrEqual(_, _) | LessThan(_, _) | LessThanOrEqual(_, _) | + In(_, _) | IsNull(_) | IsNotNull(_) | + StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a + case a@Not(EqualTo(_, _) | EqualNullSafe(_, _) | In(_, _) | + StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a + case Not(GreaterThan(a, v)) => LessThanOrEqual(a, v) + case Not(LessThanOrEqual(a, v)) => GreaterThan(a, v) + case Not(LessThan(a, v)) => GreaterThanOrEqual(a, v) + case Not(GreaterThanOrEqual(a, v)) => LessThan(a, v) + case Not(IsNull(a)) => IsNotNull(a) + case Not(IsNotNull(a)) => IsNull(a) + case Not(Not(p)) => p + case And(l, r) => And(toNNF(l), toNNF(r)) + case Not(And(l, r)) => toNNF(Or(Not(l), Not(r))) + case Or(l, r) => Or(toNNF(l), toNNF(r)) + case Not(Or(l, r)) => toNNF(And(Not(l), Not(r))) + } + + /** The 'toDNF' function converts expressions to disjunctive normal form: a + * disjunction of clauses, where a clause is a conjunction of literals + * (variables and negated variables). + * + * The conversion is carried out by first converting the expression into + * negation normal form, and then applying the distributive law. 
+ */ + private val toDNF: Filter => Filter = { + case And(l, r) => dist(toDNF(l), toDNF(r)) + case Or(l, r) => Or(toDNF(l), toDNF(r)) + case p => p + } + + /** + * Traverse over disjunctive parts of AST + */ + private val traverse: Filter => List[Filter] = { + case Or(l, r) => traverse(l) ++ traverse(r) + case a => a :: Nil + } + + private val andToArray: Filter => Array[Filter] = { + case And(l, r) => andToArray(l) ++ andToArray(r) + case a => Array(a) + } + + private val groupAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray) + + def build(): List[Array[Filter]] = fullFilterAst match { + case Some(ast) => (toNNF andThen toDNF andThen traverse andThen groupAnd).apply(ast) + case None => List.empty + } + +} From 3441215695cd4e2c9ac1d79c68bf300e53dc4d96 Mon Sep 17 00:00:00 2001 From: Ponkin Date: Fri, 12 Aug 2016 04:52:35 +0300 Subject: [PATCH 2/6] SPARKC-312: Fixing incorrect optimization checking --- .../cassandra/CassandraSourceRelation.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 96a70822f..9a128f81d 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -3,16 +3,7 @@ package org.apache.spark.sql.cassandra import java.net.InetAddress import java.util.UUID -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader -import org.apache.spark.sql.cassandra.DataTypeConverter._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.SparkConf import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema} -import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._ import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf} @@ -21,6 +12,14 @@ import com.datastax.spark.connector.util.Quote._ import com.datastax.spark.connector.util.{ConfigParameter, Logging, ReflectionUtil} import com.datastax.spark.connector.writer.{SqlRowWriter, WriteConf} import com.datastax.spark.connector.{SomeColumns, _} +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader +import org.apache.spark.sql.cassandra.DataTypeConverter._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} +import org.apache.spark.unsafe.types.UTF8String /** * Implements [[BaseRelation]]]], [[InsertableRelation]]]] and [[PrunedFilteredScan]]]] @@ -131,9 +130,9 @@ private[cassandra] class CassandraSourceRelation( val optimizedFilters = new FiltersOptimizer(filters).build() val optimizationCanBeApplied = optimizedFilters.size > 1 && - optimizedFilters.sliding(2).foldLeft(true){ (acc, sets) => + optimizedFilters.sliding(2).forall{ set => // check whether all non-pushed down filters are equals for each separate 
rdd - predicatePushDown(sets.head).handledBySpark == predicatePushDown(sets.last).handledBySpark + predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark } val filteredRdd = if(optimizationCanBeApplied) { optimizedFilters.map { predicate => From 5d593f27e5ff8aa24ea65943db13518ef6e571e8 Mon Sep 17 00:00:00 2001 From: Alexey Ponkin Date: Fri, 12 Aug 2016 17:39:03 +0300 Subject: [PATCH 3/6] SPARKC-312: Fixing bug in unhandledFilters method when optimization is available --- .../cassandra/CassandraSourceRelation.scala | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 9a128f81d..6cf13b88a 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -77,7 +77,15 @@ private[cassandra] class CassandraSourceRelation( def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]] override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match { - case true => predicatePushDown(filters).handledBySpark.toArray + case true => + val optimizedFilters = new FiltersOptimizer(filters).build() + val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters) + if(optimizationCanBeApplied) { + // all such filters are the same, take first one + predicatePushDown(optimizedFilters.head).handledBySpark.toArray + } else { + predicatePushDown(filters).handledBySpark.toArray + } case false => filters } @@ -123,17 +131,20 @@ private[cassandra] class CassandraSourceRelation( finalPushdown } + private def isOptimizationAvailable(optimizedFilters: List[Array[Filter]]): Boolean ={ + optimizedFilters.size > 1 && + optimizedFilters.sliding(2).forall{ set => + // check whether all non-pushed down filters are equals for each separate rdd + predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark + } + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val prunedRdd = maybeSelect(baseRdd, requiredColumns) val prunedFilteredRdd = { if(filterPushdown) { val optimizedFilters = new FiltersOptimizer(filters).build() - val optimizationCanBeApplied = - optimizedFilters.size > 1 && - optimizedFilters.sliding(2).forall{ set => - // check whether all non-pushed down filters are equals for each separate rdd - predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark - } + val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters) val filteredRdd = if(optimizationCanBeApplied) { optimizedFilters.map { predicate => val pushdownFilters = predicatePushDown(predicate).handledByCassandra.toArray From 1ec147d441b875526d5c8f5264a2f6ec7f027dc1 Mon Sep 17 00:00:00 2001 From: Alexey Ponkin Date: Tue, 16 Aug 2016 16:54:02 +0300 Subject: [PATCH 4/6] SPARKC-312: Introducing option spark.cassandra.sql.enable.where.clause.optimization + some more tests --- doc/reference.md | 6 ++ .../sql/CassandraPrunedFilteredScanSpec.scala | 5 +- .../cassandra/CassandraSourceRelation.scala | 25 ++++++-- .../sql/cassandra/FiltersOptimizer.scala | 42 +++++++++---- .../sql/cassandra/FiltersOptimizerCheck.scala | 63 +++++++++++++++++++ 5 files changed, 122 insertions(+), 19 
deletions(-) create mode 100644 spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala diff --git a/doc/reference.md b/doc/reference.md index 77c51cf99..ea7e2e061 100644 --- a/doc/reference.md +++ b/doc/reference.md @@ -92,6 +92,12 @@ may also be used. ("127.0.0.1,192.168.0.1") + + + + + diff --git a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala index e01913cb9..8f579d200 100644 --- a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala +++ b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala @@ -47,6 +47,7 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields") val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics") val withPushdown = Map("pushdown" -> "true") + val withWhereClauseOptimizationEnabled = Map("spark.cassandra.sql.enable.where.clause.optimization" -> "true") val withoutPushdown = Map("pushdown" -> "false") "CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in { @@ -82,7 +83,7 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with } it should "optimize table scan if all filters can be pushed down" in { - val fieldsDF = sqlContext.read.format(cassandraFormat).options(metricsOptions ++ withPushdown).load() + val fieldsDF = sqlContext.read.format(cassandraFormat).options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled ).load() val df = fieldsDF.filter("a = 5 and (b > 5 or b < 3)") val executionPlan = df.queryExecution.executedPlan val cts = findAllCassandraTableScanRDD(executionPlan) @@ -127,4 +128,4 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with } } -} \ No newline at end of file +} diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 6cf13b88a..b19965449 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -32,6 +32,7 @@ private[cassandra] class CassandraSourceRelation( userSpecifiedSchema: Option[StructType], filterPushdown: Boolean, tableSizeInBytes: Option[Long], + enableWhereClauseOptimization: Boolean, connector: CassandraConnector, readConf: ReadConf, writeConf: WriteConf, @@ -78,7 +79,7 @@ private[cassandra] class CassandraSourceRelation( override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match { case true => - val optimizedFilters = new FiltersOptimizer(filters).build() + val optimizedFilters = FiltersOptimizer(filters).build() val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters) if(optimizationCanBeApplied) { // all such filters are the same, take first one @@ -131,13 +132,13 @@ private[cassandra] class CassandraSourceRelation( finalPushdown } - private def isOptimizationAvailable(optimizedFilters: List[Array[Filter]]): Boolean ={ - optimizedFilters.size > 1 && + private def isOptimizationAvailable(optimizedFilters: List[Array[Filter]]): Boolean = + enableWhereClauseOptimization && 
optimizedFilters.size > 1 && optimizedFilters.sliding(2).forall{ set => // check whether all non-pushed down filters are equals for each separate rdd predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark } - } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val prunedRdd = maybeSelect(baseRdd, requiredColumns) @@ -252,9 +253,19 @@ object CassandraSourceRelation { """.stripMargin ) + val EnableWhereClauseOptimizationParam = ConfigParameter[Boolean]( + name = "spark.cassandra.sql.enable.where.clause.optimization", + section = ReferenceSection, + default = false, + description = + """Connector will try to optimize sql query `where`-clause, to increase + | number of filters that can be pushed down. Experimental.""".stripMargin + ) + val Properties = Seq( AdditionalCassandraPushDownRulesParam, - TableSizeInBytesParam + TableSizeInBytesParam, + EnableWhereClauseOptimizationParam ) val defaultClusterName = "default" @@ -270,6 +281,9 @@ object CassandraSourceRelation { val conf = consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs) val tableSizeInBytesString = conf.getOption(TableSizeInBytesParam.name) + val enableWhereClauseOptimization = + conf.getOption(EnableWhereClauseOptimizationParam.name) + .map( _.equalsIgnoreCase("true") ).getOrElse(false) val cassandraConnector = new CassandraConnector(CassandraConnectorConf(conf)) val tableSizeInBytes = tableSizeInBytesString match { @@ -295,6 +309,7 @@ object CassandraSourceRelation { userSpecifiedSchema = schema, filterPushdown = options.pushdown, tableSizeInBytes = tableSizeInBytes, + enableWhereClauseOptimization = enableWhereClauseOptimization, connector = cassandraConnector, readConf = readConf, writeConf = writeConf, diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala index d3bb55e59..a57d23dda 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala @@ -10,13 +10,31 @@ import org.apache.spark.sql.sources._ * will become equivalent * 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR * field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"' + * + * @param filters Array of logical statements [[org.apache.spark.sql.sources.Filter]] + * that forms `where`-clause with `AND` operator, for example: + * val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... 
AND fn` + * */ class FiltersOptimizer(filters: Array[Filter]) { private val fullFilterAst = if (filters.nonEmpty) Some(filters.reduce((left, right) => And(left, right))) else None - private def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match { + import FiltersOptimizer._ + + def build(): List[Array[Filter]] = fullFilterAst match { + case Some(ast) => (toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast) + case None => List.empty + } + +} + +object FiltersOptimizer{ + + def apply(filters: Array[Filter]): FiltersOptimizer = new FiltersOptimizer(filters) + + private[cassandra] def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match { case (Or(l, r), p) => Or(dist(l, p), dist(r, p)) case (p, Or(l, r)) => Or(dist(p, l), dist(p, r)) case (l, r) => And(l, r) @@ -32,7 +50,7 @@ class FiltersOptimizer(filters: Array[Filter]) { * negation of their conjuncts: ¬(φ ∧ ψ) is converted to (¬φ ∨ ¬ψ) * while ¬(φ ∨ ψ) becomes (¬φ ∧ ¬ψ). */ - private val toNNF: Filter => Filter = { + private[cassandra] val toNNF: Filter => Filter = { case a@(EqualTo(_, _) | EqualNullSafe(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) | LessThan(_, _) | LessThanOrEqual(_, _) | In(_, _) | IsNull(_) | IsNotNull(_) | @@ -50,6 +68,7 @@ class FiltersOptimizer(filters: Array[Filter]) { case Not(And(l, r)) => toNNF(Or(Not(l), Not(r))) case Or(l, r) => Or(toNNF(l), toNNF(r)) case Not(Or(l, r)) => toNNF(And(Not(l), Not(r))) + case p => p } /** The 'toDNF' function converts expressions to disjunctive normal form: a @@ -59,30 +78,29 @@ class FiltersOptimizer(filters: Array[Filter]) { * The conversion is carried out by first converting the expression into * negation normal form, and then applying the distributive law. */ - private val toDNF: Filter => Filter = { + private[cassandra] val toDNF: Filter => Filter = { case And(l, r) => dist(toDNF(l), toDNF(r)) case Or(l, r) => Or(toDNF(l), toDNF(r)) case p => p } /** - * Traverse over disjunctive parts of AST + * Traverse over disjunctive clauses of AST */ - private val traverse: Filter => List[Filter] = { + private[cassandra] val traverse: Filter => List[Filter] = { case Or(l, r) => traverse(l) ++ traverse(r) case a => a :: Nil } - private val andToArray: Filter => Array[Filter] = { + /** + * Group all conjunctive clauses into Array[Filter] + * f1 && f2 && ... && fn => Array(f1, f2, ... 
fn) + */ + private[cassandra] val andToArray: Filter => Array[Filter] = { case And(l, r) => andToArray(l) ++ andToArray(r) case a => Array(a) } - private val groupAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray) - - def build(): List[Array[Filter]] = fullFilterAst match { - case Some(ast) => (toNNF andThen toDNF andThen traverse andThen groupAnd).apply(ast) - case None => List.empty - } + private[cassandra] val groupByAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray) } diff --git a/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala b/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala new file mode 100644 index 000000000..1fc293f87 --- /dev/null +++ b/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala @@ -0,0 +1,63 @@ +package org.apache.spark.sql.cassandra + +import org.apache.spark.sql.sources._ + +import org.scalacheck._ +import org.scalacheck.Prop.forAll +import org.scalatest.prop.PropertyChecks +import org.scalatest.{FlatSpec, ShouldMatchers} + +class FiltersOptimizerCheck extends FlatSpec with PropertyChecks with ShouldMatchers { + + // For testing purpose + case object True extends Filter + case object False extends Filter + + val genFullTree = for { + size <- Gen.choose(0, 500) + tree <- genTree(size) + } yield tree + + def genTree(maxDepth: Int): Gen[Filter] = + if (maxDepth == 0) leaf else Gen.oneOf(leaf, genAnd(maxDepth), genOr(maxDepth), genNot(maxDepth)) + + def genAnd(maxDepth: Int): Gen[Filter] = for { + depthL <- Gen.choose(0, maxDepth - 1) + depthR <- Gen.choose(0, maxDepth - 1) + left <- genTree(depthL) + right <- genTree(depthR) + } yield And(left, right) + + def genOr(maxDepth: Int): Gen[Filter] = for { + depthL <- Gen.choose(0, maxDepth - 1) + depthR <- Gen.choose(0, maxDepth - 1) + left <- genTree(depthL) + right <- genTree(depthR) + } yield Or(left, right) + + def genNot(maxDepth: Int): Gen[Filter] = for { + depth <- Gen.choose(0, maxDepth - 1) + expr <- genTree(depth) + } yield Not(expr) + + def leaf: Gen[Filter] = Gen.oneOf(True, False) + + /** + * Evaluate logical ADT + **/ + private def eval(clause: Filter): Boolean = clause match { + case And(left, right) => eval(left) && eval(right) + case Or(left, right) => eval(left) || eval(right) + case Not(predicate) => !eval(predicate) + case True => true + case False => false + } + + "FiltersOptimizer" should "generate equivalent disjunction normal form for arbitrary logical statement" in { + forAll(genFullTree){ expr => + val dnf = (FiltersOptimizer.toNNF andThen FiltersOptimizer.toDNF).apply(expr) + assert(eval(dnf) == eval(expr)) + } + } + +} From 2bbb34fce9be9d639344424a3f81c4d5d68d8bde Mon Sep 17 00:00:00 2001 From: Alexey Ponkin Date: Sun, 20 Nov 2016 01:16:23 +0300 Subject: [PATCH 5/6] SPARKC-312: Fixing review comments --- .../sql/CassandraPrunedFilteredScanSpec.scala | 4 +- .../cassandra/CassandraSourceRelation.scala | 50 ++++++++----------- .../sql/cassandra/FiltersOptimizer.scala | 34 ++++++------- 3 files changed, 39 insertions(+), 49 deletions(-) diff --git a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala index 8f579d200..bcbc8c791 100644 --- a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala +++ 
b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala @@ -5,7 +5,9 @@ import com.datastax.spark.connector.cql.CassandraConnector import com.datastax.spark.connector.rdd.{CassandraTableScanRDD, CqlWhereClause} import com.datastax.spark.connector.util.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.cassandra.CassandraSourceRelation import org.apache.spark.sql.execution.{FilterExec, RDDScanExec, RowDataSourceScanExec, SparkPlan, WholeStageCodegenExec} + import scala.concurrent.Future class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with Logging { @@ -47,7 +49,7 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields") val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics") val withPushdown = Map("pushdown" -> "true") - val withWhereClauseOptimizationEnabled = Map("spark.cassandra.sql.enable.where.clause.optimization" -> "true") + val withWhereClauseOptimizationEnabled = Map(CassandraSourceRelation.EnableWhereClauseOptimizationParam.name -> "true") val withoutPushdown = Map("pushdown" -> "false") "CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in { diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index b19965449..ffb277962 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -78,15 +78,7 @@ private[cassandra] class CassandraSourceRelation( def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]] override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match { - case true => - val optimizedFilters = FiltersOptimizer(filters).build() - val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters) - if(optimizationCanBeApplied) { - // all such filters are the same, take first one - predicatePushDown(optimizedFilters.head).handledBySpark.toArray - } else { - predicatePushDown(filters).handledBySpark.toArray - } + case true => analyzePredicates(filters).head.handledBySpark.toArray case false => filters } @@ -132,30 +124,32 @@ private[cassandra] class CassandraSourceRelation( finalPushdown } - private def isOptimizationAvailable(optimizedFilters: List[Array[Filter]]): Boolean = - enableWhereClauseOptimization && optimizedFilters.size > 1 && - optimizedFilters.sliding(2).forall{ set => - // check whether all non-pushed down filters are equals for each separate rdd - predicatePushDown(set.head).handledBySpark == predicatePushDown(set.last).handledBySpark - } - + private def analyzePredicates(filters: Array[Filter]): List[AnalyzedPredicates] = { + if (enableWhereClauseOptimization){ + val optimizedFilters = FiltersOptimizer.build(filters) + val partitions = optimizedFilters.map(predicatePushDown) + val allHandledBySparkAreTheSame = partitions.map(_.handledBySpark).sliding(2).forall { tuple => + tuple.head == tuple.last + } + if(allHandledBySparkAreTheSame){ + partitions + } else { + List(predicatePushDown(filters)) + } + } else { + List(predicatePushDown(filters)) + } + } override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val prunedRdd = maybeSelect(baseRdd, 
requiredColumns) val prunedFilteredRdd = { if(filterPushdown) { - val optimizedFilters = new FiltersOptimizer(filters).build() - val optimizationCanBeApplied = isOptimizationAvailable(optimizedFilters) - val filteredRdd = if(optimizationCanBeApplied) { - optimizedFilters.map { predicate => - val pushdownFilters = predicatePushDown(predicate).handledByCassandra.toArray - maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]] - }.reduce(_ union _) - } else { - val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray - maybePushdownFilters(prunedRdd, pushdownFilters) - } - filteredRdd.asInstanceOf[RDD[Row]] + val pushdownFilters = analyzePredicates(filters) + pushdownFilters.map { predicate => + val pushdownFilters = predicate.handledByCassandra.toArray + maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]] + }.reduce(_ union _) } else { prunedRdd } diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala index a57d23dda..0fff11a8f 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala @@ -2,7 +2,6 @@ package org.apache.spark.sql.cassandra import org.apache.spark.sql.sources._ - /** * Optimizer will try to transform pushdown filter into `sum of products`. * So that the filter like @@ -11,28 +10,23 @@ import org.apache.spark.sql.sources._ * 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR * field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"' * - * @param filters Array of logical statements [[org.apache.spark.sql.sources.Filter]] - * that forms `where`-clause with `AND` operator, for example: - * val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... AND fn` - * */ -class FiltersOptimizer(filters: Array[Filter]) { - - private val fullFilterAst = - if (filters.nonEmpty) Some(filters.reduce((left, right) => And(left, right))) else None - - import FiltersOptimizer._ - - def build(): List[Array[Filter]] = fullFilterAst match { - case Some(ast) => (toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast) - case None => List.empty - } - -} - object FiltersOptimizer{ - def apply(filters: Array[Filter]): FiltersOptimizer = new FiltersOptimizer(filters) + /** + * @param filters Array of logical statements [[org.apache.spark.sql.sources.Filter]] + * that forms `where`-clause with `AND` operator, for example: + * val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... 
AND fn`* + * @return list of filters in disjunctive form + */ + def build(filters: Array[Filter]): List[Array[Filter]] = { + if (filters.nonEmpty) { + val ast = filters.reduce((left, right) => And(left, right)) + (toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast) + } else { + List.empty + } + } private[cassandra] def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match { case (Or(l, r), p) => Or(dist(l, p), dist(r, p)) From 3eafd8714612bfb5ca14ba0cdf514e021ab2eb6e Mon Sep 17 00:00:00 2001 From: Alexey Ponkin Date: Mon, 21 Nov 2016 21:55:02 +0300 Subject: [PATCH 6/6] Fixing test: SparkSession instead of sqlContext --- .../sql/CassandraPrunedFilteredScanSpec.scala | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala index 38a99aaeb..371bad743 100644 --- a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala +++ b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala @@ -7,9 +7,13 @@ import com.datastax.spark.connector.rdd.{CassandraTableScanRDD, CqlWhereClause} import com.datastax.spark.connector.util.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.cassandra.CassandraSourceRelation -import org.apache.spark.sql.execution.{FilterExec, RDDScanExec, RowDataSourceScanExec, SparkPlan, WholeStageCodegenExec} - -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{ + FilterExec, + RDDScanExec, + RowDataSourceScanExec, + SparkPlan, + WholeStageCodegenExec +} import scala.concurrent.Future @@ -52,7 +56,8 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields") val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics") val withPushdown = Map("pushdown" -> "true") - val withWhereClauseOptimizationEnabled = Map(CassandraSourceRelation.EnableWhereClauseOptimizationParam.name -> "true") + val withWhereClauseOptimizationEnabled = + Map(CassandraSourceRelation.EnableWhereClauseOptimizationParam.name -> "true") val withoutPushdown = Map("pushdown" -> "false") "CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in { @@ -88,7 +93,10 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with } it should "optimize table scan if all filters can be pushed down" in { - val fieldsDF = sqlContext.read.format(cassandraFormat).options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled ).load() + val fieldsDF = sparkSession.read + .format(cassandraFormat) + .options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled ) + .load() val df = fieldsDF.filter("a = 5 and (b > 5 or b < 3)") val executionPlan = df.queryExecution.executedPlan val cts = findAllCassandraTableScanRDD(executionPlan)
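Taken together, the six patches implement a small, self-contained rewrite. As a reference, here is a minimal sketch of the sum-of-products step: `dist` and `toDNF` are lifted from `FiltersOptimizer` above (`toNNF` and the connector plumbing are omitted, so the input must already be negation-free), applied to the filter used in the integration test, `a = 5 AND (b > 5 OR b < 3)`. It only assumes Spark's public `org.apache.spark.sql.sources` filter ADT on the classpath.

```scala
import org.apache.spark.sql.sources._

object DnfSketch {

  // Distributive law: (a OR b) AND c  ==>  (a AND c) OR (b AND c)
  def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match {
    case (Or(l, r), p) => Or(dist(l, p), dist(r, p))
    case (p, Or(l, r)) => Or(dist(p, l), dist(p, r))
    case (l, r)        => And(l, r)
  }

  // Push conjunctions below disjunctions until the tree is an OR of ANDs.
  def toDNF(f: Filter): Filter = f match {
    case And(l, r) => dist(toDNF(l), toDNF(r))
    case Or(l, r)  => Or(toDNF(l), toDNF(r))
    case p         => p
  }

  def main(args: Array[String]): Unit = {
    // The filter from the test: a = 5 AND (b > 5 OR b < 3)
    val filter = And(EqualTo("a", 5), Or(GreaterThan("b", 5), LessThan("b", 3)))

    // Prints Or(And(EqualTo(a,5),GreaterThan(b,5)),And(EqualTo(a,5),LessThan(b,3))):
    // two conjunctions, matching the two CqlWhereClauses asserted in the test.
    println(toDNF(filter))
  }
}
```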
Table rows added to doc/reference.md (patch 4):

<tr><th>Property Name</th><th>Default</th><th>Description</th></tr>
+<tr>
+  <td><code>sql.enable.where.clause.optimization</code></td>
+  <td>false</td>
+  <td>Connector will try to optimize sql query `where`-clause, to increase
+  number of filters that can be pushed down. Experimental.</td>
+</tr>
<tr><td><code>sql.pushdown.additionalClasses</code></td>
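A hedged usage sketch of the new option: the keyspace `test` and the `metrics` table layout mirror the integration test, and the SparkSession setup is standard boilerplate; only the two option keys (`pushdown` and `spark.cassandra.sql.enable.where.clause.optimization`) come from the patches themselves.

```scala
import org.apache.spark.sql.SparkSession

object WhereClauseOptimizationExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("where-clause-optimization-demo")
      .getOrCreate()

    // "pushdown" enables predicate pushdown as before; the new, experimental
    // flag additionally rewrites the filter into disjunctive normal form and,
    // when every clause is fully pushable, runs one Cassandra scan per clause
    // and unions the results.
    val df = spark.read
      .format("org.apache.spark.sql.cassandra")
      .options(Map(
        "keyspace" -> "test",    // placeholder keyspace
        "table"    -> "metrics", // layout as created in the test above
        "pushdown" -> "true",
        "spark.cassandra.sql.enable.where.clause.optimization" -> "true"))
      .load()

    // Executed as the union of two scans:
    //   WHERE "a" = ? AND "b" > ?   and   WHERE "a" = ? AND "b" < ?
    df.filter("a = 5 and (b > 5 or b < 3)").show()
  }
}
```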