diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
index bdeaeedc..a2fe045d 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
@@ -39,11 +39,13 @@ import scala.util.Try
  * Shim which exposes some JDBC helper functions. Most of this code is copied from Spark SQL, with
  * minor modifications for Snowflake-specific features and limitations.
  */
-private[snowflake] class JDBCWrapper {
+private[snowflake] class JDBCWrapper extends Serializable {

   private val log = LoggerFactory.getLogger(getClass)

-  private val ec: ExecutionContext = {
+  // Note: marking this field `@transient implicit lazy` allows Spark to
+  // recreate it upon de-serialization
+  @transient implicit private lazy val ec: ExecutionContext = {
     log.debug("Creating a new ExecutionContext")
     val threadFactory: ThreadFactory = new ThreadFactory {
       private[this] val count = new AtomicInteger()
@@ -353,7 +355,7 @@ private[snowflake] class JDBCWrapper {
     TelemetryClient.createTelemetry(conn.jdbcConnection)
   }

-private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper {
+private[snowflake] object DefaultJDBCWrapper extends JDBCWrapper with Serializable {

   private val LOGGER = LoggerFactory.getLogger(getClass.getName)

@@ -588,7 +590,7 @@ private[snowflake] object DefaultJDBCWrapper {
 private[snowflake] class SnowflakeSQLStatement(
     val numOfVar: Int = 0,
     val list: List[StatementElement] = Nil
-) {
+) extends Serializable {

   private val log = LoggerFactory.getLogger(getClass)

diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
index 7eaed10a..00a29b5c 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala
@@ -20,7 +20,6 @@ package net.snowflake.spark.snowflake
 import java.net.URI
 import java.sql.{Connection, ResultSet}
 import java.util.{Properties, UUID}
-
 import net.snowflake.client.jdbc.{SnowflakeDriver, SnowflakeResultSet, SnowflakeResultSetSerializable}
 import net.snowflake.spark.snowflake.Parameters.MergedParameters
 import org.apache.spark.{SPARK_VERSION, SparkContext, SparkEnv}
@@ -37,6 +36,7 @@ import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode
 import net.snowflake.spark.snowflake.FSType.FSType
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.slf4j.LoggerFactory

@@ -77,6 +77,10 @@ object Utils {
     } else {
       ""
     }
+  private[snowflake] lazy val lazyMode = SparkSession.active
+    .conf
+    .get("spark.snowflakedb.lazyModeForAQE", "true")
+    .toBoolean
   private[snowflake] lazy val scalaVersion =
     util.Properties.versionNumberString
   private[snowflake] lazy val javaVersion =
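The new `lazyMode` flag above is read lazily from the active SparkSession and defaults to true, so lazy pushdown is on unless a user opts out. A minimal sketch of how the flag could be disabled at session construction (the config key comes from this patch; the builder calls are standard Spark API, and the object name and session setup here are purely illustrative):

    import org.apache.spark.sql.SparkSession

    object LazyModeToggleSketch {
      def main(args: Array[String]): Unit = {
        // Set the flag before the first pushdown is planned: Utils.lazyMode is a
        // lazy val, so later changes to the conf are not picked up in the same JVM.
        val spark = SparkSession
          .builder()
          .master("local[2]")
          .config("spark.snowflakedb.lazyModeForAQE", "false") // fall back to the eager path
          .getOrCreate()

        println(spark.conf.get("spark.snowflakedb.lazyModeForAQE", "true").toBoolean)
        spark.stop()
      }
    }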
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala
new file mode 100644
index 00000000..afee8109
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeScanExec.scala
@@ -0,0 +1,75 @@
+package net.snowflake.spark.snowflake.pushdowns
+
+import net.snowflake.spark.snowflake.{SnowflakeRelation, SnowflakeSQLStatement}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.execution.LeafExecNode
+
+import java.util.concurrent.{Callable, ExecutorService, Executors, Future}
+
+/**
+ * Snowflake scan plan that pushes a query fragment down to the Snowflake endpoint
+ *
+ * @param projection   projected columns
+ * @param snowflakeSQL SQL query that is pushed to Snowflake for evaluation
+ * @param relation     Snowflake data source
+ */
+case class SnowflakeScanExec(projection: Seq[Attribute],
+                             snowflakeSQL: SnowflakeSQLStatement,
+                             relation: SnowflakeRelation) extends LeafExecNode {
+  // result holder
+  @transient implicit private var data: Future[PushDownResult] = _
+  @transient implicit private val service: ExecutorService = Executors.newCachedThreadPool()
+
+  override protected def doPrepare(): Unit = {
+    logInfo(s"Preparing query to push down - $snowflakeSQL")
+
+    val work = new Callable[PushDownResult]() {
+      override def call(): PushDownResult = {
+        val result = {
+          try {
+            val data = relation.buildScanFromSQL[InternalRow](snowflakeSQL, Some(schema))
+            PushDownResult(data = Some(data))
+          } catch {
+            case e: Exception =>
+              logError("Failure in query execution", e)
+              PushDownResult(failure = Some(e))
+          }
+        }
+        result
+      }
+    }
+    data = service.submit(work)
+    logInfo("submitted query asynchronously")
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (data.get().failure.nonEmpty) {
+      // raise original exception
+      throw data.get().failure.get
+    }
+
+    data.get().data.get.mapPartitions { iter =>
+      val project = UnsafeProjection.create(schema)
+      iter.map(project)
+    }
+  }
+
+  override def cleanupResources(): Unit = {
+    logDebug(s"shutting down service to clean up resources")
+    service.shutdown()
+  }
+
+  override def output: Seq[Attribute] = projection
+}
+
+/**
+ * Result holder
+ *
+ * @param data    RDD that holds the data
+ * @param failure failure information if we are unable to push down
+ */
+private case class PushDownResult(data: Option[RDD[InternalRow]] = None,
+                                  failure: Option[Exception] = None)
+  extends Serializable
\ No newline at end of file
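SnowflakeScanExec above splits the pushdown across Spark's plan lifecycle: doPrepare submits the Snowflake query to a cached thread pool, and doExecute blocks on the Future only when the rows are actually demanded, rethrowing the original failure if the query did not succeed. A self-contained sketch of that same Callable/Future handshake, using stand-in types instead of the connector's RDD and relation (everything here is illustrative, not connector code):

    import java.util.concurrent.{Callable, Executors, Future}

    object AsyncPrepareSketch {
      // stand-in for PushDownResult: either the rows or the original failure
      case class Result(rows: Option[Seq[String]] = None,
                        failure: Option[Exception] = None)

      def main(args: Array[String]): Unit = {
        val service = Executors.newCachedThreadPool()

        // "doPrepare": kick off the remote work without blocking the planner
        val pending: Future[Result] = service.submit(new Callable[Result] {
          override def call(): Result =
            try Result(rows = Some(Seq("row-1", "row-2"))) // stand-in for buildScanFromSQL
            catch { case e: Exception => Result(failure = Some(e)) }
        })

        // "doExecute": block only when the result is consumed; surface the original error
        val result = pending.get()
        result.failure.foreach(e => throw e)
        result.rows.get.foreach(println)

        // "cleanupResources": release the pool
        service.shutdown()
      }
    }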
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
index ddb0b577..bc2ce849 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/SnowflakeStrategy.scala
@@ -1,6 +1,6 @@
 package net.snowflake.spark.snowflake.pushdowns

-import net.snowflake.spark.snowflake.SnowflakeConnectorFeatureNotSupportException
+import net.snowflake.spark.snowflake.{SnowflakeConnectorFeatureNotSupportException, Utils}
 import net.snowflake.spark.snowflake.pushdowns.querygeneration.QueryBuilder
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -38,9 +38,18 @@
   * @return An Option of Seq[SnowflakePlan] that contains the PhysicalPlan if
   *         query generation was successful, None if not.
   */
-  private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SnowflakePlan]] =
-    QueryBuilder.getRDDFromPlan(plan).map {
-      case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
-        Seq(SnowflakePlan(output, rdd))
+  private def buildQueryRDD(plan: LogicalPlan): Option[Seq[SparkPlan]] = {
+    if (Utils.lazyMode) {
+      logInfo("Using lazy mode for push down")
+      QueryBuilder.getSnowflakeScanExecPlan(plan).map {
+        case (projection, snowflakeSQL, relation) =>
+          Seq(SnowflakeScanExec(projection, snowflakeSQL, relation))
+      }
+    } else {
+      QueryBuilder.getRDDFromPlan(plan).map {
+        case (output: Seq[Attribute], rdd: RDD[InternalRow]) =>
+          Seq(SnowflakePlan(output, rdd))
+      }
     }
+  }
 }
diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
index 1a1b847a..86c473c5 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala
@@ -1,7 +1,7 @@
 package net.snowflake.spark.snowflake.pushdowns.querygeneration
+
 import java.io.{PrintWriter, StringWriter}
-import java.util.NoSuchElementException

 import net.snowflake.spark.snowflake.{
   ConnectionCacheKey,
@@ -307,4 +307,13 @@
       (executedBuilder.getOutput, executedBuilder.rdd)
     }
   }
+
+  def getSnowflakeScanExecPlan(plan: LogicalPlan):
+    Option[(Seq[Attribute], SnowflakeSQLStatement, SnowflakeRelation)] = {
+    val qb = new QueryBuilder(plan)
+
+    qb.tryBuild.map { executedBuilder =>
+      (executedBuilder.getOutput, executedBuilder.statement, executedBuilder.source.relation)
+    }
+  }
 }
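Taken together, the strategy now has two code paths keyed off Utils.lazyMode: the new getSnowflakeScanExecPlan hands back the (output attributes, SQL statement, relation) triple without contacting Snowflake, while the pre-existing getRDDFromPlan builds the RDD during planning. A condensed, illustrative view of that branch, assuming it sits in the pushdowns package so the private[snowflake] members are visible (this mirrors buildQueryRDD above and is not additional connector code):

    package net.snowflake.spark.snowflake.pushdowns

    import net.snowflake.spark.snowflake.Utils
    import net.snowflake.spark.snowflake.pushdowns.querygeneration.QueryBuilder
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    object PushdownPlannerSketch {
      def planPushdown(plan: LogicalPlan): Option[SparkPlan] =
        if (Utils.lazyMode) {
          // lazy: the Snowflake query runs when the physical plan is prepared/executed
          QueryBuilder.getSnowflakeScanExecPlan(plan).map {
            case (projection, snowflakeSQL, relation) =>
              SnowflakeScanExec(projection, snowflakeSQL, relation)
          }
        } else {
          // eager: the RDD (and the Snowflake query) is produced at planning time
          QueryBuilder.getRDDFromPlan(plan).map {
            case (output, rdd) => SnowflakePlan(output, rdd)
          }
        }
    }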
diff --git a/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala b/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala
new file mode 100644
index 00000000..be34c46c
--- /dev/null
+++ b/src/test/scala/net/snowflake/spark/snowflake/SparkQuerySuite.scala
@@ -0,0 +1,62 @@
+package net.snowflake.spark.snowflake
+
+import net.snowflake.spark.snowflake.pushdowns.SnowflakeScanExec
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.{ExplainMode, FormattedMode}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class SparkQuerySuite extends FunSuite with BeforeAndAfter {
+  private var spark: SparkSession = _
+
+  before {
+    spark = SparkSession
+      .builder()
+      .master("local[2]")
+      .getOrCreate()
+  }
+
+  after {
+    spark.stop()
+  }
+
+  test("pushdown scan to snowflake") {
+    spark.sql(
+      """
+        CREATE TABLE student(name string)
+        USING net.snowflake.spark.snowflake
+        OPTIONS (dbtable 'default.student',
+                 sfdatabase 'sf-db',
+                 tempdir '/tmp/dir',
+                 sfurl 'accountname.snowflakecomputing.com:443',
+                 sfuser 'alice',
+                 sfpassword 'hello-snowflake')
+      """).show()
+
+    val df = spark.sql(
+      """
+        |SELECT *
+        |  FROM student
+        |""".stripMargin)
+    val plan = df.queryExecution.executedPlan
+
+    assert(plan.isInstanceOf[SnowflakeScanExec])
+    val sfPlan = plan.asInstanceOf[SnowflakeScanExec]
+    assert(sfPlan.snowflakeSQL.toString ==
+      """SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS""""
+        .stripMargin)
+
+    // explain plan
+    val planString = df.queryExecution.explainString(FormattedMode)
+    val expectedString =
+      """== Physical Plan ==
+        |SnowflakeScan (1)
+        |
+        |
+        |(1) SnowflakeScan
+        |Output [1]: [name#1]
+        |Arguments: [name#1], SELECT * FROM ( default.student ) AS "SF_CONNECTOR_QUERY_ALIAS", SnowflakeRelation
+      """.stripMargin
+    assert(planString.trim == expectedString.trim)
+  }
+
+}
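As a side note on the assertion style above: the formatted plan the test captures through queryExecution.explainString(FormattedMode) can also be printed with the standard Spark 3 Dataset API, which is handy when debugging the suite interactively (df here is the DataFrame defined in the test; this line is not part of the patch):

    // prints the "== Physical Plan ==" block containing the SnowflakeScan node
    df.explain("formatted")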