[Performance Improvement] Support for AQE mode for delayed query pushdown for optimum runtime & improved debugging #535

Open
wants to merge 1 commit into base: master
@@ -39,11 +39,13 @@ import scala.util.Try
* Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with
* minor modifications for Snowflake-specific features and limitations.
*/
private[snowflake] class JDBCWrapper {
private[snowflake] class JDBCWrapper extends Serializable {

private val log = LoggerFactory.getLogger(getClass)

private val ec: ExecutionContext = {
// Note: marking this field `@transient implicit lazy` allows Spark to
// recreate it upon deserialization
@transient implicit private lazy val ec: ExecutionContext = {
log.debug("Creating a new ExecutionContext")
val threadFactory: ThreadFactory = new ThreadFactory {
private[this] val count = new AtomicInteger()
@@ -353,7 +355,7 @@ private[snowflake] class JDBCWrapper {
TelemetryClient.createTelemetry(conn.jdbcConnection)
}

private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper with Serializable {

private val LOGGER = LoggerFactory.getLogger(getClass.getName)

@@ -588,7 +590,7 @@ private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
private[snowflake] class SnowflakeSQLStatement(
val numOfVar: Int = 0,
val list: List[StatementElement] = Nil
) {
) extends Serializable {

private val log = LoggerFactory.getLogger(getClass)

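For context on the `@transient implicit private lazy val ec` change above: a transient lazy val is not written into the serialized form and is re-initialized on first access after deserialization, which is what lets the now-Serializable JDBCWrapper carry a non-serializable ExecutionContext. A minimal stand-alone sketch of that behaviour (the Wrapper and TransientLazyDemo names are illustrative, not part of the connector):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

class Wrapper extends Serializable {
  // Skipped during serialization; rebuilt lazily on the deserializing side.
  @transient lazy val id: String = java.util.UUID.randomUUID().toString
}

object TransientLazyDemo {
  def main(args: Array[String]): Unit = {
    val original = new Wrapper
    println(original.id) // forces initialization on the sending side

    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(original)
    out.close()

    val copy = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
      .readObject().asInstanceOf[Wrapper]
    println(copy.id) // a new value: the lazy val was re-created after deserialization
  }
}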
6 changes: 5 additions & 1 deletion src/main/scala/net/snowflake/spark/snowflake/Utils.scala
@@ -20,7 +20,6 @@ package net.snowflake.spark.snowflake
import java.net.URI
import java.sql.{Connection, ResultSet}
import java.util.{Properties, UUID}

import net.snowflake.client.jdbc.{SnowflakeDriver, SnowflakeResultSet, SnowflakeResultSetSerializable}
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import org.apache.spark.{SPARK_VERSION, SparkContext, SparkEnv}
@@ -37,6 +36,7 @@ import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.Object
import net.snowflake.spark.snowflake.FSType.FSType
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StructField, StructType}
import org.slf4j.LoggerFactory

@@ -77,6 +77,10 @@ object Utils {
} else {
""
}
private[snowflake] lazy val lazyMode = SparkSession.active
.conf
.get("spark.snowflakedb.lazyModeForAQE", "true")
.toBoolean
private[snowflake] lazy val scalaVersion =
util.Properties.versionNumberString
private[snowflake] lazy val javaVersion =
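The new flag is read once, lazily, from the active session's runtime conf, so it should be set before the first Snowflake query is planned. A minimal sketch of toggling it (assuming a plain local SparkSession; only the spark.snowflakedb.lazyModeForAQE key and its "true" default come from the diff above, the LazyModeToggle object is illustrative):

import org.apache.spark.sql.SparkSession

object LazyModeToggle {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      // Opt out of the lazy AQE pushdown path; Utils.lazyMode defaults to true.
      .config("spark.snowflakedb.lazyModeForAQE", "false")
      .getOrCreate()

    // Utils.lazyMode is a lazy val, so the value in effect when the first
    // pushdown is planned is the one the connector keeps for the session.
    println(spark.conf.get("spark.snowflakedb.lazyModeForAQE"))
    spark.stop()
  }
}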
@@ -0,0 +1,75 @@
package net.snowflake.spark.snowflake.pushdowns

import net.snowflake.spark.snowflake.{SnowflakeRelation, SnowflakeSQLStatement}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.execution.LeafExecNode

import java.util.concurrent.{Callable, ExecutorService, Executors, Future}

/**
* Snowflake scan plan that pushes a query fragment down to the Snowflake endpoint.
*
* @param projection projected columns
* @param snowflakeSQL SQL query that is pushed to Snowflake for evaluation
* @param relation the Snowflake data source
*/
case class SnowflakeScanExec(projection: Seq[Attribute],
snowflakeSQL: SnowflakeSQLStatement,
relation: SnowflakeRelation) extends LeafExecNode {
// result holder
@transient implicit private var data: Future[PushDownResult] = _
@transient implicit private val service: ExecutorService = Executors.newCachedThreadPool()

override protected def doPrepare(): Unit = {

Reviewer comment: What is the benefit of building the RDD in doPrepare instead of doExecute?

Author reply: doPrepare lets the Spark planner perform the initial metadata-collection work asynchronously, whereas doExecute is always a blocking call. By doing this work in doPrepare we can move steps such as building the SQL and opening the Snowflake connection onto a background thread while the planner works on other nodes in the plan, which gives some performance gains.

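// The Callable below does the expensive part of the scan (building the SQL-backed
// RDD via relation.buildScanFromSQL) on a background thread; doExecute() later
// blocks on `data` and rethrows any failure captured in PushDownResult.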
logInfo(s"Preparing query to push down - $snowflakeSQL")

val work = new Callable[PushDownResult]() {
override def call(): PushDownResult = {
val result = {
try {
val data = relation.buildScanFromSQL[InternalRow](snowflakeSQL, Some(schema))
PushDownResult(data = Some(data))
} catch {
case e: Exception =>
logError("Failure in query execution", e)
PushDownResult(failure = Some(e))
}
}
result
}
}
data = service.submit(work)
logInfo("submitted query asynchronously")
}

override protected def doExecute(): RDD[InternalRow] = {
if (data.get().failure.nonEmpty) {
// raise original exception
throw data.get().failure.get
}

data.get().data.get.mapPartitions { iter =>
val project = UnsafeProjection.create(schema)
iter.map(project)
}
}

override def cleanupResources(): Unit = {
logDebug(s"shutting down service to clean up resources")
service.shutdown()
}

override def output: Seq[Attribute] = projection
}

/**
* Result holder
*
* @param data RDD that holds the data
* @param failure failure information if we are unable to push down
*/
private case class PushDownResult(data: Option[RDD[InternalRow]] = None,
failure: Option[Exception] = None)
extends Serializable
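To make the doPrepare/doExecute discussion above concrete, here is a stripped-down sketch of the same pattern outside Spark: schedule the expensive work when a node is prepared, block on the Future when it is executed. The AsyncStage and AsyncStageDemo names are hypothetical and only illustrate the idea; they are not part of the PR.

import java.util.concurrent.{Callable, Executors, ExecutorService, Future}

class AsyncStage[T](work: () => T) {
  private lazy val service: ExecutorService = Executors.newCachedThreadPool()
  private var result: Future[T] = _

  def prepare(): Unit = {
    // Submit the work without blocking the caller (mirrors doPrepare()).
    result = service.submit(new Callable[T] { override def call(): T = work() })
  }

  def execute(): T = {
    // Block until the background work finishes (mirrors doExecute()).
    try result.get() finally service.shutdown()
  }
}

object AsyncStageDemo {
  def main(args: Array[String]): Unit = {
    val stage = new AsyncStage(() => { Thread.sleep(100); 42 })
    stage.prepare()          // returns immediately, work runs in the background
    println(stage.execute()) // blocks until the result (42) is ready
  }
}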
@@ -1,6 +1,6 @@
package net.snowflake.spark.snowflake.pushdowns

import net.snowflake.spark.snowflake.SnowflakeConnectorFeatureNotSupportException
import net.snowflake.spark.snowflake.{SnowflakeConnectorFeatureNotSupportException, Utils}
import net.snowflake.spark.snowflake.pushdowns.querygeneration.QueryBuilder
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical._
@@ -38,9 +38,18 @@ class SnowflakeStrategy extends Strategy {
* @return An Option of Seq[SnowflakePlan] that contains the PhysicalPlan if
* query generation was successful, None if not.
*/
private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SnowflakePlan]] =
QueryBuilder.getRDDFromPlan(plan).map {
case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
Seq(SnowflakePlan(output, rdd))
private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SparkPlan]] = {
if (Utils.lazyMode) {
logInfo("Using lazy mode for push down")
QueryBuilder.getSnowflakeScanExecPlan(plan).map {
case (projection, snowflakeSQL, relation) =>
Seq(SnowflakeScanExec(projection, snowflakeSQL, relation))
}
} else {
QueryBuilder.getRDDFromPlan(plan).map {
case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
Seq(SnowflakePlan(output, rdd))
}
}
}
}
@@ -1,7 +1,7 @@
package net.snowflake.spark.snowflake.pushdowns.querygeneration


import java.io.{PrintWriter, StringWriter}
import java.util.NoSuchElementException

import net.snowflake.spark.snowflake.{
ConnectionCacheKey,
@@ -307,4 +307,13 @@ private[snowflake] object QueryBuilder {
(executedBuilder.getOutput, executedBuilder.rdd)
}
}

def getSnowflakeScanExecPlan(plan: LogicalPlan):
Option[(Seq[Attribute], SnowflakeSQLStatement, SnowflakeRelation)] = {
val qb = new QueryBuilder(plan)

qb.tryBuild.map { executedBuilder =>
(executedBuilder.getOutput, executedBuilder.statement, executedBuilder.source.relation)
}
}
}
@@ -0,0 +1,62 @@
package net.snowflake.spark.snowflake

import net.snowflake.spark.snowflake.pushdowns.SnowflakeScanExec
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{ExplainMode, FormattedMode}
import org.scalatest.{BeforeAndAfter, FunSuite}

class SparkQuerySuite extends FunSuite with BeforeAndAfter {
private var spark: SparkSession = _

before {
spark = SparkSession
.builder()
.master("local[2]")
.getOrCreate()
}

after {
spark.stop()
}

test("pushdown scan to snowflake") {
spark.sql(
"""
CREATE TABLE student(name string)
USING net.snowflake.spark.snowflake
OPTIONS (dbtable 'default.student',
sfdatabase 'sf-db',
tempdir '/tmp/dir',
sfurl 'accountname.snowflakecomputing.com:443',
sfuser 'alice',
sfpassword 'hello-snowflake')
""").show()

val df = spark.sql(
"""
|SELECT *
| FROM student
|""".stripMargin)
val plan = df.queryExecution.executedPlan

assert(plan.isInstanceOf[SnowflakeScanExec])
val sfPlan = plan.asInstanceOf[SnowflakeScanExec]
assert(sfPlan.snowflakeSQL.toString ==
"""SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS""""
.stripMargin)

// explain plan
val planString = df.queryExecution.explainString(FormattedMode)
val expectedString =
"""== Physical Plan ==
|SnowflakeScan (1)
|
|
|(1) SnowflakeScan
|Output [1]: [name#1]
|Arguments: [name#1], SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS", SnowflakeRelation
""".stripMargin
assert(planString.trim == expectedString.trim)
}

}