diff --git a/doc/reference.md b/doc/reference.md index 7c59fae26..5d1de1d6e 100644 --- a/doc/reference.md +++ b/doc/reference.md @@ -86,6 +86,12 @@ may also be used. ("127.0.0.1,192.168.0.1")
Property Name | Default | Description |
---|---|---|
sql.enable.where.clause.optimization |
+ false | +The connector will try to optimize the SQL query `where` clause to increase the + number of filters that can be pushed down. Experimental. | +
sql.pushdown.additionalClasses |
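For orientation, a minimal usage sketch (not part of this diff): the new property is passed like any other connector option, here under its full `spark.cassandra.` prefix, the same way the integration test below passes it. The keyspace and table names are placeholders.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Read a Cassandra table with predicate pushdown and the experimental
// where-clause optimization enabled (placeholder keyspace/table names).
val df = spark.read
  .format("org.apache.spark.sql.cassandra")
  .options(Map(
    "keyspace" -> "ks",
    "table" -> "metrics",
    "pushdown" -> "true",
    "spark.cassandra.sql.enable.where.clause.optimization" -> "true"))
  .load()

// With the optimization on, an OR over a clustering column can be split into
// two pushed-down scans, as exercised by the integration test below.
df.filter("a = 5 and (b > 5 or b < 3)").show()
```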
diff --git a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala index 30fc8ebfc..371bad743 100644 --- a/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala +++ b/spark-cassandra-connector/src/it/scala/org/apache/spark/sql/CassandraPrunedFilteredScanSpec.scala @@ -6,7 +6,14 @@ import com.datastax.spark.connector.embedded.YamlTransformations import com.datastax.spark.connector.rdd.{CassandraTableScanRDD, CqlWhereClause} import com.datastax.spark.connector.util.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.cassandra.CassandraSourceRelation +import org.apache.spark.sql.execution.{ + FilterExec, + RDDScanExec, + RowDataSourceScanExec, + SparkPlan, + WholeStageCodegenExec +} import scala.concurrent.Future @@ -34,6 +41,12 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with s"""CREATE TABLE IF NOT EXISTS $ks.fields |(k INT, a TEXT, b TEXT, c TEXT, d TEXT, e TEXT, PRIMARY KEY (k)) """ .stripMargin) + }, + Future { + session.execute( + s"""CREATE TABLE IF NOT EXISTS $ks.metrics + |(k TEXT, a INT, b INT, c INT, PRIMARY KEY (k, a, b)) """ + .stripMargin) } ) } @@ -41,7 +54,10 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with val colorOptions = Map("keyspace" -> ks, "table" -> "colors") val fieldsOptions = Map("keyspace" -> ks, "table" -> "fields") + val metricsOptions = Map("keyspace" -> ks, "table" -> "metrics") val withPushdown = Map("pushdown" -> "true") + val withWhereClauseOptimizationEnabled = + Map(CassandraSourceRelation.EnableWhereClauseOptimizationParam.name -> "true") val withoutPushdown = Map("pushdown" -> "false") "CassandraPrunedFilteredScan" should "pushdown predicates for clustering keys" in { @@ -76,6 +92,19 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with cts.get.selectedColumnNames should contain theSameElementsAs Seq("b", "c", "d") } + it should "optimize table scan if all filters can be pushed down" in { + val fieldsDF = sparkSession.read + .format(cassandraFormat) + .options(metricsOptions ++ withPushdown ++ withWhereClauseOptimizationEnabled ) + .load() + val df = fieldsDF.filter("a = 5 and (b > 5 or b < 3)") + val executionPlan = df.queryExecution.executedPlan + val cts = findAllCassandraTableScanRDD(executionPlan) + cts.nonEmpty shouldBe true + cts.head.where shouldBe CqlWhereClause(Seq(""""a" = ? AND "b" > ?"""), List(5, 5)) + cts.last.where shouldBe CqlWhereClause(Seq(""""a" = ? 
AND "b" < ?"""), List(5, 3)) + } + def findCassandraTableScanRDD(sparkPlan: SparkPlan): Option[CassandraTableScanRDD[_]] = { def _findCassandraTableScanRDD(rdd: RDD[_]): Option[CassandraTableScanRDD[_]] = { rdd match { @@ -94,4 +123,22 @@ class CassandraPrunedFilteredScanSpec extends SparkCassandraITFlatSpecBase with } } + def findAllCassandraTableScanRDD(sparkPlan: SparkPlan): List[CassandraTableScanRDD[_]] = { + def _findAllCassandraTableScanRDD(rdd: RDD[_]): List[CassandraTableScanRDD[_]] = { + rdd match { + case ctsrdd: CassandraTableScanRDD[_] => List(ctsrdd) + case other: RDD[_] => other.dependencies.iterator + .flatMap(dep => _findAllCassandraTableScanRDD(dep.rdd)).toList + } + } + + sparkPlan match { + case prdd: RDDScanExec => _findAllCassandraTableScanRDD(prdd.rdd) + case prdd: RowDataSourceScanExec => _findAllCassandraTableScanRDD(prdd.rdd) + case filter: FilterExec => findAllCassandraTableScanRDD(filter.child) + case wsc: WholeStageCodegenExec => findAllCassandraTableScanRDD(wsc.child) + case _ => List.empty + } + } + } \ No newline at end of file diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 1a38e9275..7d71990c5 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -3,17 +3,7 @@ package org.apache.spark.sql.cassandra import java.net.InetAddress import java.util.UUID -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader -import org.apache.spark.sql.cassandra.DataTypeConverter._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.SparkConf - import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema} -import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._ import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf} @@ -22,6 +12,14 @@ import com.datastax.spark.connector.util.Quote._ import com.datastax.spark.connector.util.{ConfigParameter, Logging, ReflectionUtil} import com.datastax.spark.connector.writer.{SqlRowWriter, WriteConf} import com.datastax.spark.connector.{SomeColumns, _} +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader +import org.apache.spark.sql.cassandra.DataTypeConverter._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources} +import org.apache.spark.unsafe.types.UTF8String /** * Implements [[BaseRelation]]]], [[InsertableRelation]]]] and [[PrunedFilteredScan]]]] @@ -34,6 +32,7 @@ private[cassandra] class CassandraSourceRelation( userSpecifiedSchema: Option[StructType], filterPushdown: Boolean, tableSizeInBytes: Option[Long], + enableWhereClauseOptimization: Boolean, connector: CassandraConnector, readConf: ReadConf, writeConf: WriteConf, @@ -79,7 +78,7 @@ private[cassandra] class 
CassandraSourceRelation( def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]] override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filterPushdown match { - case true => predicatePushDown(filters).handledBySpark.toArray + case true => analyzePredicates(filters).head.handledBySpark.toArray case false => filters } @@ -127,13 +126,32 @@ private[cassandra] class CassandraSourceRelation( finalPushdown } + private def analyzePredicates(filters: Array[Filter]): List[AnalyzedPredicates] = { + if (enableWhereClauseOptimization){ + val optimizedFilters = FiltersOptimizer.build(filters) + val partitions = optimizedFilters.map(predicatePushDown) + val allHandledBySparkAreTheSame = partitions.map(_.handledBySpark).sliding(2).forall { tuple => + tuple.head == tuple.last + } + if(allHandledBySparkAreTheSame){ + partitions + } else { + List(predicatePushDown(filters)) + } + } else { + List(predicatePushDown(filters)) + } + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val prunedRdd = maybeSelect(baseRdd, requiredColumns) val prunedFilteredRdd = { if(filterPushdown) { - val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray - val filteredRdd = maybePushdownFilters(prunedRdd, pushdownFilters) - filteredRdd.asInstanceOf[RDD[Row]] + val pushdownFilters = analyzePredicates(filters) + pushdownFilters.map { predicate => + val pushdownFilters = predicate.handledByCassandra.toArray + maybePushdownFilters(prunedRdd, pushdownFilters).asInstanceOf[RDD[Row]] + }.reduce(_ union _) } else { prunedRdd } @@ -231,9 +249,19 @@ object CassandraSourceRelation { """.stripMargin ) + val EnableWhereClauseOptimizationParam = ConfigParameter[Boolean]( + name = "spark.cassandra.sql.enable.where.clause.optimization", + section = ReferenceSection, + default = false, + description = + """Connector will try to optimize sql query `where`-clause, to increase + | number of filters that can be pushed down. Experimental.""".stripMargin + ) + val Properties = Seq( AdditionalCassandraPushDownRulesParam, - TableSizeInBytesParam + TableSizeInBytesParam, + EnableWhereClauseOptimizationParam ) val defaultClusterName = "default" @@ -249,6 +277,9 @@ object CassandraSourceRelation { val conf = consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs) val tableSizeInBytesString = conf.getOption(TableSizeInBytesParam.name) + val enableWhereClauseOptimization = + conf.getOption(EnableWhereClauseOptimizationParam.name) + .map( _.equalsIgnoreCase("true") ).getOrElse(false) val cassandraConnector = new CassandraConnector(CassandraConnectorConf(conf)) val tableSizeInBytes = tableSizeInBytesString match { @@ -274,6 +305,7 @@ object CassandraSourceRelation { userSpecifiedSchema = schema, filterPushdown = options.pushdown, tableSizeInBytes = tableSizeInBytes, + enableWhereClauseOptimization = enableWhereClauseOptimization, connector = cassandraConnector, readConf = readConf, writeConf = writeConf, diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala new file mode 100644 index 000000000..0fff11a8f --- /dev/null +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/FiltersOptimizer.scala @@ -0,0 +1,100 @@ +package org.apache.spark.sql.cassandra + +import org.apache.spark.sql.sources._ + +/** + * Optimizer will try to transform pushdown filter into `sum of products`. 
+ * So that the filter like + * '(field1 < 3 OR field1 > 7) AND (field2 = 'val1' OR field2 = 'val2')' + * will become equivalent + * 'field1 < 3 AND field2 = "val1" OR field1 < 3 AND field2 = "val2" OR + * field1 > 7 AND field2 = "val1" OR field1 > 7 AND field2 = "val2"' + * + */ +object FiltersOptimizer{ + + /** + * @param filters Array of logical statements [[org.apache.spark.sql.sources.Filter]] + * that forms `where`-clause with `AND` operator, for example: + * val Array(f1, f2, ... fn) = ... // such that `where f1 AND f2 AND ... AND fn`* + * @return list of filters in disjunctive form + */ + def build(filters: Array[Filter]): List[Array[Filter]] = { + if (filters.nonEmpty) { + val ast = filters.reduce((left, right) => And(left, right)) + (toNNF andThen toDNF andThen traverse andThen groupByAnd).apply(ast) + } else { + List.empty + } + } + + private[cassandra] def dist(predL: Filter, predR: Filter): Filter = (predL, predR) match { + case (Or(l, r), p) => Or(dist(l, p), dist(r, p)) + case (p, Or(l, r)) => Or(dist(p, l), dist(p, r)) + case (l, r) => And(l, r) + } + + /** The 'toNNF' function converts expressions to negation normal form. This + * function is total: it's defined for all expressions, not just those which + * only use negation, conjunction and disjunction, although all expressions in + * negation normal form do in fact only use those connectives. + * + * Then de Morgan's laws are applied to convert negated + * conjunctions and disjunctions into the conjunction or disjunction of the + * negation of their conjuncts: ¬(φ ∧ ψ) is converted to (¬φ ∨ ¬ψ) + * while ¬(φ ∨ ψ) becomes (¬φ ∧ ¬ψ). + */ + private[cassandra] val toNNF: Filter => Filter = { + case a@(EqualTo(_, _) | EqualNullSafe(_, _) | GreaterThan(_, _) | + GreaterThanOrEqual(_, _) | LessThan(_, _) | LessThanOrEqual(_, _) | + In(_, _) | IsNull(_) | IsNotNull(_) | + StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a + case a@Not(EqualTo(_, _) | EqualNullSafe(_, _) | In(_, _) | + StringStartsWith(_, _) | StringEndsWith(_, _) | StringContains(_, _)) => a + case Not(GreaterThan(a, v)) => LessThanOrEqual(a, v) + case Not(LessThanOrEqual(a, v)) => GreaterThan(a, v) + case Not(LessThan(a, v)) => GreaterThanOrEqual(a, v) + case Not(GreaterThanOrEqual(a, v)) => LessThan(a, v) + case Not(IsNull(a)) => IsNotNull(a) + case Not(IsNotNull(a)) => IsNull(a) + case Not(Not(p)) => p + case And(l, r) => And(toNNF(l), toNNF(r)) + case Not(And(l, r)) => toNNF(Or(Not(l), Not(r))) + case Or(l, r) => Or(toNNF(l), toNNF(r)) + case Not(Or(l, r)) => toNNF(And(Not(l), Not(r))) + case p => p + } + + /** The 'toDNF' function converts expressions to disjunctive normal form: a + * disjunction of clauses, where a clause is a conjunction of literals + * (variables and negated variables). + * + * The conversion is carried out by first converting the expression into + * negation normal form, and then applying the distributive law. + */ + private[cassandra] val toDNF: Filter => Filter = { + case And(l, r) => dist(toDNF(l), toDNF(r)) + case Or(l, r) => Or(toDNF(l), toDNF(r)) + case p => p + } + + /** + * Traverse over disjunctive clauses of AST + */ + private[cassandra] val traverse: Filter => List[Filter] = { + case Or(l, r) => traverse(l) ++ traverse(r) + case a => a :: Nil + } + + /** + * Group all conjunctive clauses into Array[Filter] + * f1 && f2 && ... && fn => Array(f1, f2, ... 
fn) + */ + private[cassandra] val andToArray: Filter => Array[Filter] = { + case And(l, r) => andToArray(l) ++ andToArray(r) + case a => Array(a) + } + + private[cassandra] val groupByAnd: List[Filter] => List[Array[Filter]] = _.map(andToArray) + +} diff --git a/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala b/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala new file mode 100644 index 000000000..1fc293f87 --- /dev/null +++ b/spark-cassandra-connector/src/test/scala/org/apache/spark/sql/cassandra/FiltersOptimizerCheck.scala @@ -0,0 +1,63 @@ +package org.apache.spark.sql.cassandra + +import org.apache.spark.sql.sources._ + +import org.scalacheck._ +import org.scalacheck.Prop.forAll +import org.scalatest.prop.PropertyChecks +import org.scalatest.{FlatSpec, ShouldMatchers} + +class FiltersOptimizerCheck extends FlatSpec with PropertyChecks with ShouldMatchers { + + // For testing purpose + case object True extends Filter + case object False extends Filter + + val genFullTree = for { + size <- Gen.choose(0, 500) + tree <- genTree(size) + } yield tree + + def genTree(maxDepth: Int): Gen[Filter] = + if (maxDepth == 0) leaf else Gen.oneOf(leaf, genAnd(maxDepth), genOr(maxDepth), genNot(maxDepth)) + + def genAnd(maxDepth: Int): Gen[Filter] = for { + depthL <- Gen.choose(0, maxDepth - 1) + depthR <- Gen.choose(0, maxDepth - 1) + left <- genTree(depthL) + right <- genTree(depthR) + } yield And(left, right) + + def genOr(maxDepth: Int): Gen[Filter] = for { + depthL <- Gen.choose(0, maxDepth - 1) + depthR <- Gen.choose(0, maxDepth - 1) + left <- genTree(depthL) + right <- genTree(depthR) + } yield Or(left, right) + + def genNot(maxDepth: Int): Gen[Filter] = for { + depth <- Gen.choose(0, maxDepth - 1) + expr <- genTree(depth) + } yield Not(expr) + + def leaf: Gen[Filter] = Gen.oneOf(True, False) + + /** + * Evaluate logical ADT + **/ + private def eval(clause: Filter): Boolean = clause match { + case And(left, right) => eval(left) && eval(right) + case Or(left, right) => eval(left) || eval(right) + case Not(predicate) => !eval(predicate) + case True => true + case False => false + } + + "FiltersOptimizer" should "generate equivalent disjunction normal form for arbitrary logical statement" in { + forAll(genFullTree){ expr => + val dnf = (FiltersOptimizer.toNNF andThen FiltersOptimizer.toDNF).apply(expr) + assert(eval(dnf) == eval(expr)) + } + } + +} |
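As a reading aid (not part of the diff), here is a small sketch of what `FiltersOptimizer.build` is expected to produce for the predicate used in the integration test above; the result follows from the scaladoc and the test assertions.

```scala
import org.apache.spark.sql.cassandra.FiltersOptimizer
import org.apache.spark.sql.sources._

// Filters as Spark hands them to buildScan for "a = 5 and (b > 5 or b < 3)":
// the array elements are implicitly AND-ed together.
val filters: Array[Filter] =
  Array(EqualTo("a", 5), Or(GreaterThan("b", 5), LessThan("b", 3)))

// build() rewrites the conjunction into disjunctive normal form and returns one
// Array[Filter] per disjunct, so each product can be pushed down on its own:
//   List(Array(EqualTo("a", 5), GreaterThan("b", 5)),
//        Array(EqualTo("a", 5), LessThan("b", 3)))
val products: List[Array[Filter]] = FiltersOptimizer.build(filters)
```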