LIVY-19. Add Spark SQL support #148

Open · wants to merge 1 commit into base: master
5 changes: 5 additions & 0 deletions core/src/main/scala/com/cloudera/livy/sessions/Kind.scala
@@ -35,12 +35,17 @@ case class SparkR() extends Kind {
override def toString: String = "sparkr"
}

case class SparkSql() extends Kind {
override def toString: String = "sparksql"
}

object Kind {

def apply(kind: String): Kind = kind match {
case "spark" | "scala" => Spark()
case "pyspark" | "python" => PySpark()
case "sparkr" | "r" => SparkR()
case "sparksql" | "sql" => SparkSql()
case other => throw new IllegalArgumentException(s"Invalid kind: $other")
}

@@ -48,6 +48,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
val interpreter = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) match {
case PySpark() => PythonInterpreter(conf)
case Spark() => new SparkInterpreter(conf)
case SparkSql() => new SparkSqlInterpreter(conf)
case SparkR() => SparkRInterpreter(conf)
}

301 changes: 301 additions & 0 deletions repl/src/main/scala/com/cloudera/livy/repl/SparkSqlInterpreter.scala
@@ -0,0 +1,301 @@
/*
* Licensed to Cloudera, Inc. under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Cloudera, Inc. licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.cloudera.livy.repl

import java.io._

import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.{JPrintWriter, Results}
import scala.util.{Failure, Success, Try}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.repl.SparkIMain
import org.apache.spark.sql.hive.HiveContext
import org.json4s.{DefaultFormats, Extraction}
import org.json4s.JsonAST._
import org.json4s.JsonDSL._

object SparkSqlInterpreter {
private val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r
}

/**
* This represents a Spark SQL interpreter. It is not thread safe.
*/
class SparkSqlInterpreter(conf: SparkConf) extends Interpreter {
import SparkSqlInterpreter._

private implicit def formats = DefaultFormats

private val outputStream = new ByteArrayOutputStream()
private var sparkIMain: SparkIMain = _
private var sparkContext: SparkContext = _
private var hiveContext: HiveContext = _

def kind: String = "sparksql"

override def start(): SparkContext = {
require(sparkIMain == null && sparkContext == null && hiveContext == null)

val settings = new Settings()
settings.embeddedDefaults(Thread.currentThread().getContextClassLoader())
settings.usejavacp.value = true

sparkIMain = new SparkIMain(settings, new JPrintWriter(outputStream, true))
sparkIMain.initializeSynchronous()

// Spark 1.6 does not have "classServerUri"; instead, the local directory where class files
// are stored needs to be registered in SparkConf. See comment in
// SparkILoop::createSparkContext().
Try(sparkIMain.getClass().getMethod("classServerUri")) match {
case Success(method) =>
method.setAccessible(true)
conf.set("spark.repl.class.uri", method.invoke(sparkIMain).asInstanceOf[String])

case Failure(_) =>
val outputDir = sparkIMain.getClass().getMethod("getClassOutputDirectory")
outputDir.setAccessible(true)
conf.set("spark.repl.class.outputDir",
outputDir.invoke(sparkIMain).asInstanceOf[File].getAbsolutePath())
}

restoreContextClassLoader {
// Call sparkIMain.setContextClassLoader() to make sure SparkContext and the repl are using the
// same ClassLoader. Otherwise, if someone defines a new class in the interactive shell,
// SparkContext cannot see it and job stages will fail.
val setContextClassLoaderMethod = sparkIMain.getClass().getMethod("setContextClassLoader")
setContextClassLoaderMethod.setAccessible(true)
setContextClassLoaderMethod.invoke(sparkIMain)

sparkContext = SparkContext.getOrCreate(conf)
hiveContext = new HiveContext(sparkContext)

Review comment from @zjffdu (Contributor), Jun 11, 2016:

I think sometimes just SQLContext is sufficient, especially when there is no hive-site.xml on the classpath. Linked to #132, which is related. I can see two issues with using HiveContext:

  • It is slower to create a HiveContext than a SQLContext, because the HiveContext has to initialize a lot of Hive-related machinery.
  • The HiveContext may fail to start when the embedded Derby metastore is used and multiple sessions run in yarn-client mode, although yarn-client mode is not recommended.
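
A minimal sketch of the reviewer's suggestion (hypothetical, not part of this patch): prefer HiveContext but fall back to a plain SQLContext when Hive support cannot be initialized, e.g. when no hive-site.xml is on the classpath:

import scala.util.{Failure, Success, Try}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveContext

// Hypothetical helper: degrade to SQLContext if HiveContext cannot start.
def createSqlContext(sc: SparkContext): SQLContext = Try(new HiveContext(sc)) match {
  case Success(hive) => hive
  case Failure(_) =>
    // e.g. missing hive-site.xml, or the embedded Derby metastore is locked by another session
    new SQLContext(sc)
}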

sparkIMain.beQuietDuring {
sparkIMain.bind("sc", "org.apache.spark.SparkContext", sparkContext, List("""@transient"""))
val hiveClassName: String = classOf[HiveContext].getCanonicalName
sparkIMain.bind("hiveContext", hiveClassName, hiveContext, List("""@transient"""))
sparkIMain.addImports("hiveContext.implicits._")
sparkIMain.addImports("org.apache.spark.sql.types._")
}
}

sparkContext
}

override def execute(code: String): Interpreter.ExecuteResponse = restoreContextClassLoader {
require(sparkIMain != null && sparkContext != null && hiveContext != null)

executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject(
(TEXT_PLAIN, JString(""))
)))
}

override def close(): Unit = synchronized {
if (hiveContext != null) {
// clean up and close hive context here
}
Review comment from a contributor:

Maybe also set hiveContext to null here.
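
A minimal sketch of that suggestion (not part of the patch): drop the reference after cleanup so a repeated close() is a no-op:

if (hiveContext != null) {
  // clean up and close the Hive context here, then release the reference
  hiveContext = null
}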


if (sparkContext != null) {
sparkContext.stop()
}

if (sparkIMain != null) {
sparkIMain.close()
sparkIMain = null
}
}

private def executeMagic(magic: String, rest: String): Interpreter.ExecuteResponse = {
magic match {
case "json" => executeJsonMagic(rest)
case "table" => executeTableMagic(rest)
case _ =>
Interpreter.ExecuteError("UnknownMagic", f"Unknown magic command $magic")
}
}
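
For illustration, statements that start with % are treated as magics and dispatched as above; everything else is handed to hiveContext.sql. A sketch (the real MAGIC_REGEX is private to this file):

val Magic = "^%(\\w+)\\W*(.*)".r

def route(line: String): String = line match {
  case Magic(magic, rest) => s"magic=$magic, rest=$rest"  // "%json result" -> magic=json, rest=result
  case sql                => s"hiveContext.sql($sql)"     // "SELECT 1"     -> runs as SQL
}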

private def executeJsonMagic(name: String): Interpreter.ExecuteResponse = {
try {
val value = sparkIMain.valueOfTerm(name) match {
case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10)
case Some(obj) => obj
case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist")
}

Interpreter.ExecuteSuccess(JObject(
(APPLICATION_JSON, Extraction.decompose(value))
))
} catch {
case _: Throwable =>
Interpreter.ExecuteError("ValueError", "Failed to convert value into a JSON value")
}
}

private class TypesDoNotMatch extends Exception

private def convertTableType(value: JValue): String = {
value match {
case (JNothing | JNull) => "NULL_TYPE"
case JBool(_) => "BOOLEAN_TYPE"
case JString(_) => "STRING_TYPE"
case JInt(_) => "BIGINT_TYPE"
case JDouble(_) => "DOUBLE_TYPE"
case JDecimal(_) => "DECIMAL_TYPE"
case JArray(arr) =>
if (allSameType(arr.iterator)) {
"ARRAY_TYPE"
} else {
throw new TypesDoNotMatch
}
case JObject(obj) =>
if (allSameType(obj.iterator.map(_._2))) {
"MAP_TYPE"
} else {
throw new TypesDoNotMatch
}
}
}

private def allSameType(values: Iterator[JValue]): Boolean = {
if (values.hasNext) {
val type_name = convertTableType(values.next())
values.forall { case value => type_name.equals(convertTableType(value)) }
} else {
true
}
}
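
For reference, a few example inputs and the table column types the helpers above assign (illustrative; the helpers are private to the interpreter):

import org.json4s.JsonAST._

val samples: List[(JValue, String)] = List(
  JNull                          -> "NULL_TYPE",
  JBool(true)                    -> "BOOLEAN_TYPE",
  JString("a")                   -> "STRING_TYPE",
  JInt(1)                        -> "BIGINT_TYPE",
  JDouble(1.5)                   -> "DOUBLE_TYPE",
  JArray(List(JInt(1), JInt(2))) -> "ARRAY_TYPE"  // mixed element types would throw TypesDoNotMatch
)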

private def executeTableMagic(name: String): Interpreter.ExecuteResponse = {
val value = sparkIMain.valueOfTerm(name) match {
case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10)
case Some(obj) => obj
case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist")
}

extractTableFromJValue(Extraction.decompose(value))
}

private def extractTableFromJValue(value: JValue): Interpreter.ExecuteResponse = {
// Convert the value into JSON and map it to a table.
val rows: List[JValue] = value match {
case JArray(arr) => arr
case _ => List(value)
}

try {
val headers = scala.collection.mutable.Map[String, Map[String, String]]()

val data = rows.map { case row =>
val cols: List[JField] = row match {
case JArray(arr: List[JValue]) =>
arr.zipWithIndex.map { case (v, index) => JField(index.toString, v) }
case JObject(obj) => obj.sortBy(_._1)
case value: JValue => List(JField("0", value))
}

cols.map { case (k, v) =>
val typeName = convertTableType(v)

headers.get(k) match {
case Some(header) =>
if (header.get("type").get != typeName) {
throw new TypesDoNotMatch
}
case None =>
headers.put(k, Map(
"type" -> typeName,
"name" -> k
))
}

v
}
}

Interpreter.ExecuteSuccess(
APPLICATION_LIVY_TABLE_JSON -> (
("headers" -> headers.toSeq.sortBy(_._1).map(_._2)) ~ ("data" -> data)
))
} catch {
case _: TypesDoNotMatch =>
Interpreter.ExecuteError("TypeError", "table rows have different types")
}
}
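
To make the table conversion concrete, a hypothetical decomposed value and the payload shape it would produce (illustrative; the method is private):

import org.json4s.JsonAST._

// As %table would see it after Extraction.decompose of two case-class instances:
val rows: JValue = JArray(List(
  JObject(List("age" -> JInt(1), "name" -> JString("alice"))),
  JObject(List("age" -> JInt(2), "name" -> JString("bob")))))

// extractTableFromJValue(rows) answers with an APPLICATION_LIVY_TABLE_JSON payload:
//   headers: [{"name": "age", "type": "BIGINT_TYPE"}, {"name": "name", "type": "STRING_TYPE"}]
//   data:    [[1, "alice"], [2, "bob"]]   (columns sorted by field name)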

private def executeLines(
lines: List[String],
result: Interpreter.ExecuteResponse): Interpreter.ExecuteResponse = {
lines match {
case Nil => result
case head :: tail =>
val result = executeLine(head)

result match {
case Interpreter.ExecuteIncomplete() =>
tail match {
case Nil =>
result

case next :: nextTail =>
executeLines(head + "\n" + next :: nextTail, result)
}
case Interpreter.ExecuteError(_, _, _) =>
result

case _ =>
executeLines(tail, result)
}
}
}

private def executeLine(code: String): Interpreter.ExecuteResponse = {
code match {
case MAGIC_REGEX(magic, rest) =>
executeMagic(magic, rest)
case _ =>
scala.Console.withOut(outputStream) {
sparkIMain.interpret(f"""hiveContext.sql("$code")""") match {
case Results.Success =>
Interpreter.ExecuteSuccess(
TEXT_PLAIN -> readStdout()
)
case Results.Incomplete => Interpreter.ExecuteIncomplete()
case Results.Error => Interpreter.ExecuteError("Error", readStdout())
}
}
}
}
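
To make the wrapping above concrete: every non-magic line is passed verbatim into hiveContext.sql inside the embedded SparkIMain, and whatever the interpreter prints is returned as TEXT_PLAIN (sketch; the table name is hypothetical):

val code = "SELECT count(*) FROM logs"
// executeLine(code) effectively interprets:
//   hiveContext.sql("SELECT count(*) FROM logs")
// while "%json x" and "%table x" bypass SQL and go through executeMagic instead.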

private def restoreContextClassLoader[T](fn: => T): T = {
val currentClassLoader = Thread.currentThread().getContextClassLoader()
try {
fn
} finally {
Thread.currentThread().setContextClassLoader(currentClassLoader)
}
}

private def readStdout() = {
val output = outputStream.toString("UTF-8").trim
outputStream.reset()

output
}
}