diff --git a/gateway/build.sbt b/gateway/build.sbt index 7a816fd..d175988 100644 --- a/gateway/build.sbt +++ b/gateway/build.sbt @@ -20,8 +20,14 @@ libraryDependencies ++= Seq( "net.databinder.dispatch" %% "dispatch-core" % "0.11.2", // parsing of program arguments "com.github.scopt" %% "scopt" % "3.2.0", + // apache curator + "org.apache.curator" % "apache-curator" % "2.8.0", + "org.apache.curator" % "curator-framework" % "2.8.0", + "org.apache.curator" % "curator-recipes" % "2.8.0", // testing - "org.scalatest" %% "scalatest" % "2.2.1" + "org.scalatest" %% "scalatest" % "2.2.1", + "org.apache.curator" % "curator-test" % "2.8.0", + "org.scala-lang.modules" %% "scala-async" % "0.9.2" ) // disable parallel test execution to avoid BindException when mocking diff --git a/gateway/src/main/resources/log4j.xml b/gateway/src/main/resources/log4j.xml index 8e69dd3..5a623f5 100644 --- a/gateway/src/main/resources/log4j.xml +++ b/gateway/src/main/resources/log4j.xml @@ -5,7 +5,7 @@ - + @@ -16,7 +16,12 @@ - + + + + + + diff --git a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/GatewayPlan.scala b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/GatewayPlan.scala index 7f30aec..b0565f2 100644 --- a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/GatewayPlan.scala +++ b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/GatewayPlan.scala @@ -20,17 +20,20 @@ import org.jboss.netty.handler.execution.MemoryAwareThreadPoolExecutor import unfiltered.response._ import unfiltered.request._ import unfiltered.netty.{cycle, ServerErrorResponse} -import us.jubat.jubaql_server.gateway.json.{Unregister, Register, Query, SessionId, QueryToProcessor} +import us.jubat.jubaql_server.gateway.json._ import scala.collection.mutable import org.json4s.DefaultFormats import org.json4s.native.Serialization.write -import scala.util.Random import scala.util.{Try, Success, Failure} import java.io._ import dispatch._ import dispatch.Defaults._ import java.util.jar.JarFile import 
java.nio.file.{StandardCopyOption, Files} +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} +import scala.collection.Map // A Netty plan using async IO // cf. . @@ -40,7 +43,8 @@ import java.nio.file.{StandardCopyOption, Files} class GatewayPlan(ipAddress: String, port: Int, envpForProcessor: Array[String], runMode: RunMode, sparkDistribution: String, fatjar: String, - checkpointDir: String) + checkpointDir: String, gatewayId: String, persistSession: Boolean, + threads: Int, channelMemory: Long, totalMemory: Long) extends cycle.Plan /* With cycle.SynchronousExecution, there is a group of N (16?) threads (named "nioEventLoopGroup-5-*") that will process N requests in @@ -62,12 +66,23 @@ class GatewayPlan(ipAddress: String, port: Int, // for error handling with ServerErrorResponse with LazyLogging { - lazy val underlying = new MemoryAwareThreadPoolExecutor(16, 65536, 1048576) + lazy val underlying = new MemoryAwareThreadPoolExecutor(threads, channelMemory, totalMemory) - // holds session ids mapping to keys and host:port locations, respectively - val session2key: mutable.Map[String, String] = new mutable.HashMap() - val key2session: mutable.Map[String, String] = new mutable.HashMap() - val session2loc: mutable.Map[String, (String, Int)] = new mutable.HashMap() + val Lineseparator = System.lineSeparator() + val StateRunStr = "state: RUNNING" + + // SessionManager: holds session ids mapping to keys and host:port locations, respectively + val sessionManager = try { + persistSession match { + case true => new SessionManager(gatewayId, new ZookeeperStore) + case false => new SessionManager(gatewayId) + } + } catch { + case e: Throwable => { + logger.error("failed to start session manager", e) + throw e + } + } /* When starting the processor using spark-submit, we rely on a certain * logging behavior. 
It seems like the log4j.xml file bundled with @@ -93,113 +108,229 @@ class GatewayPlan(ipAddress: String, port: Int, val errorMsgContentType = ContentType("text/plain; charset=utf-8") + val DefaultTimeout: Long = 60000 + val submitTimeout = try { + val timeoutString: String = System.getProperty("jubaql.gateway.submitTimeout", DefaultTimeout.toString) + val timeout = timeoutString.toLong + if (timeout < 1) { + throw new Exception(s"""jubaql.gateway.submitTimeout value must be "1 <= n <= ${Long.MaxValue}"""") + } + timeout + } catch { + case e: Exception => + logger.warn(s"failed get jubaql.gateway.submitTimeout property. Use default Timeout : ${DefaultTimeout}", e) + DefaultTimeout + } + logger.debug(s"set spark-submit startingTimeout = $submitTimeout ms") + implicit val formats = DefaultFormats + private val statusLock = new AnyRef + var queryTransferCount: Long = 0 + var queryReceivedCount: Long = 0 + val startTime = System.currentTimeMillis() + def intent = { case req@POST(Path("/login")) => - var sessionId = "" - var key = "" + val body = readAllFromReader(req.reader) val reqSource = req.remoteAddr - logger.info(f"received HTTP request at /login from $reqSource%s") - session2key.synchronized { - do { - sessionId = Alphanumeric.generate(20) // TODO: generate in a more sophisticated way. - } while (session2key.get(sessionId) != None) - do { - key = Alphanumeric.generate(20) // TODO: generate in a more sophisticated way. 
- } while (key2session.get(key) != None) - session2key += (sessionId -> key) - key2session += (key -> sessionId) - } - val callbackUrl = composeCallbackUrl(ipAddress, port, key) - - val runtime = Runtime.getRuntime - val cmd = mutable.ArrayBuffer(f"$sparkDistribution%s/bin/spark-submit", - "--class", "us.jubat.jubaql_server.processor.JubaQLProcessor", - "--master", "", // set later - "--conf", "", // set later - "--conf", s"log4j.configuration=file:$tmpLog4jPath", - fatjar, - callbackUrl) - logger.info(f"starting Spark in run mode $runMode%s (session_id: $sessionId%s)") - val divide = runMode match { - case RunMode.Production(zookeeper, numExecutors, coresPerExecutor, sparkJar) => - cmd.update(4, "yarn-cluster") // --master - // When we run the processor on YARN, any options passed in with run.mode - // will be passed to the SparkSubmit class, not the the Spark driver. To - // get the run.mode passed one step further, we use the extraJavaOptions - // variable. It is important to NOT ADD ANY QUOTES HERE or they will be - // double-escaped on their way to the Spark driver and probably never end - // up there. - cmd.update(6, "spark.driver.extraJavaOptions=-Drun.mode=production " + - s"-Djubaql.zookeeper=$zookeeper " + - s"-Djubaql.checkpointdir=$checkpointDir") // --conf - // also specify the location of the Spark jar file, if given - val sparkJarParams = sparkJar match { - case Some(url) => "--conf" :: s"spark.yarn.jar=$url" :: Nil - case _ => Nil - } - cmd.insertAll(9, "--num-executors" :: numExecutors.toString :: - "--executor-cores" :: coresPerExecutor.toString :: sparkJarParams) - logger.debug("executing: " + cmd.mkString(" ")) - - Try { - val maybeProcess = Try(runtime.exec(cmd.toArray, envpForProcessor)) - - maybeProcess.flatMap { process => - // NB. which stream we have to use and whether the message we are - // waiting for actually appears, depends on the log4j.xml file - // bundled in the application jar... 
- val is: InputStream = process.getInputStream - val isr = new InputStreamReader(is) - val br = new BufferedReader(isr) - var line: String = br.readLine() - while (line != null && !line.trim.contains("state: RUNNING")) { - if (line.contains("Exception")) { - logger.error(line) - throw new RuntimeException("could not start spark-submit") - } - line = br.readLine() + logger.debug(f"received HTTP request at /login from $reqSource%s with body: $body%s") + + val maybeJson = org.json4s.native.JsonMethods.parseOpt(body) + val maybeSessionId = maybeJson.flatMap(_.extractOpt[SessionId]) + maybeSessionId match { + case Some(sessionId) => + // connect existing session + val session = sessionManager.getSession(sessionId.session_id) + session match { + case Failure(t) => + InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to get session") + case Success(sessionInfo) => + sessionInfo match { + case SessionState.NotFound => + logger.warn("received a query JSON without a usable session_id") + Unauthorized ~> errorMsgContentType ~> ResponseString("Unknown session_id") + case SessionState.Registering(key) => + logger.warn(s"processor for session $key has not registered yet") + ServiceUnavailable ~> errorMsgContentType ~> ResponseString("This session has not been registered. 
Wait a second.") + case SessionState.Ready(host, port, key) => + logger.info(s"received login request for existing session (${sessionId}) from ${reqSource}") + val sessionIdJson = write(SessionId(sessionId.session_id)) + Ok ~> errorMsgContentType ~> ResponseString(sessionIdJson) } - process.destroy() - // TODO: consider to check line is not null here - Success(1) - } } - case RunMode.Development(numThreads) => - cmd.update(4, s"local[$numThreads]") // --master - cmd.update(6, "run.mode=development") // --conf - cmd.insertAll(7, Seq("--conf", s"jubaql.checkpointdir=$checkpointDir")) - logger.debug("executing: " + cmd.mkString(" ")) - - Try { - val maybeProcess = Try(runtime.exec(cmd.toArray)) - - maybeProcess.flatMap { process => - handleSubProcessOutput(process.getInputStream, System.out) - handleSubProcessOutput(process.getErrorStream, System.err) - Success(1) - } + case None => + // connect new session + val session = sessionManager.createNewSession() + session match { + case Failure(t) => + InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to create session") + case Success((sessionId, key)) => + val callbackUrl = composeCallbackUrl(ipAddress, port, key) + val gatewayAddress = s"${ipAddress}:${port}" + + val runtime = Runtime.getRuntime + val cmd = mutable.ArrayBuffer(f"$sparkDistribution%s/bin/spark-submit", + "--class", "us.jubat.jubaql_server.processor.JubaQLProcessor", + "--master", "", // set later + "--conf", "", // set later + "--conf", s"log4j.configuration=file:$tmpLog4jPath", + "--name", s"JubaQLProcessor:$gatewayAddress:$sessionId", + fatjar, + callbackUrl) + logger.info(f"starting Spark in run mode $runMode%s (session_id: $sessionId%s)") + val divide = runMode match { + case RunMode.Production(zookeeper, numExecutors, coresPerExecutor, sparkJar, sparkDriverMemory, sparkExecutorMemory) => + cmd.update(4, "yarn-cluster") // --master + // When we run the processor on YARN, any options passed in with run.mode + // will be passed to the 
SparkSubmit class, not the Spark driver. To + // get the run.mode passed one step further, we use the extraJavaOptions + // variable. It is important to NOT ADD ANY QUOTES HERE or they will be + // double-escaped on their way to the Spark driver and probably never end + // up there. + cmd.update(6, "spark.driver.extraJavaOptions=-Drun.mode=production " + + s"-Djubaql.zookeeper=$zookeeper " + + s"-Djubaql.checkpointdir=$checkpointDir "+ + s"-Djubaql.gateway.address=$gatewayAddress " + + s"-Djubaql.processor.sessionId=$sessionId " + + s"-XX:MaxPermSize=128m") // --conf + // also specify the location of the Spark jar file, if given + val sparkJarParams = sparkJar match { + case Some(url) => "--conf" :: s"spark.yarn.jar=$url" :: Nil + case _ => Nil + } + cmd.insertAll(9, "--num-executors" :: numExecutors.toString :: + "--executor-cores" :: coresPerExecutor.toString :: sparkJarParams) + if (sparkDriverMemory != None) { + cmd ++= mutable.ArrayBuffer("--driver-memory", sparkDriverMemory.get) + } + if (sparkExecutorMemory != None) { + cmd ++= mutable.ArrayBuffer("--executor-memory", sparkExecutorMemory.get) + } + logger.debug("executing: " + cmd.mkString(" ")) + Try { + val maybeProcess = Try(runtime.exec(cmd.toArray, envpForProcessor)) + maybeProcess match { + case Success(_) => + maybeProcess.flatMap { process => + // cache spark-submit.log + val logBuffer = new scala.collection.mutable.StringBuilder() + + val sparkProcessFuture = Future { + // NB. which stream we have to use and whether the message we are + // waiting for actually appears, depends on the log4j.xml file + // bundled in the application jar... 
+ val is: InputStream = process.getInputStream + val isr = new InputStreamReader(is) + val br = new BufferedReader(isr) + var line: String = br.readLine() + while (line != null && !line.trim.contains(StateRunStr)) { + logBuffer.append("\t" + line + Lineseparator) + line = br.readLine() + } + val isStateRun = (line != null && line.trim.contains(StateRunStr)) + Try(process.exitValue) match { + case Failure(t) if t.isInstanceOf[IllegalThreadStateException] => + if (isStateRun) { + logger.debug("Succeed to spark-submit") + } else { + val returnCode = process.waitFor + throw new RuntimeException(s"Failed to finish process. returnCode: ${returnCode}") + } + case Failure(t) => + throw new RuntimeException("Failed to spark-submit", t) + case Success(returnCode) => + // process finished => abnormal + throw new RuntimeException(s"Failed to finish process. returnCode: ${returnCode}") + } + } + Try(Await.result(sparkProcessFuture, Duration(submitTimeout, MILLISECONDS))) match { + case Success(_) => + //watch spark-submit starting process + handleSubProcess(process, sessionId) + Success(1) + case Failure(t) => + //get standard error + val is: InputStream = process.getErrorStream + val isr = new InputStreamReader(is) + val br = new BufferedReader(isr) + try { + var line = "" + while ({ line = br.readLine(); line ne null }) { + logBuffer.append("\t" + line + Lineseparator) + } + } finally { + br.close() + } + if (t.isInstanceOf[java.util.concurrent.TimeoutException]) { + logger.error(s"processor did not start within timeout period (${submitTimeout / 1000} seconds)${Lineseparator}spark-submit log : ${Lineseparator}" + logBuffer.toString()) + } else { + logger.error(s"${t.getMessage}${Lineseparator}spark-submit log : ${Lineseparator}" + logBuffer.toString()) + } + process.destroy() + Failure(t) + } + } + case Failure(e) => + Failure(e) + } + } + case RunMode.Development(numThreads) => + cmd.update(4, s"local[$numThreads]") // --master + cmd.update(6, "run.mode=development") // 
--conf + cmd.insertAll(7, Seq("--conf", s"jubaql.checkpointdir=$checkpointDir")) + logger.debug("executing: " + cmd.mkString(" ")) + + Try { + val maybeProcess = Try(runtime.exec(cmd.toArray)) + maybeProcess match { + case Success(_) => + maybeProcess.flatMap { process => + handleSubProcessOutput(process.getInputStream, System.out) + handleSubProcessOutput(process.getErrorStream, System.err) + Success(1) + } + case Failure(t) => + Failure(t) + } + } + case RunMode.Test => + // do nothing in test mode. + Success(1) + } + divide match { + case Success(result) => + result match { + case Failure(e) => + logger.error(e.getMessage, e) + InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to start Spark\n") + case Success(_) => + logger.info(f"started Spark with callback URL $callbackUrl%s") + val sessionIdJson = write(SessionId(sessionId)) + Ok ~> errorMsgContentType ~> ResponseString(sessionIdJson) + case 1 => + // test mode result + logger.debug(f"started Spark with callback URL $callbackUrl%s in run.mode=test") + val sessionIdJson = write(SessionId(sessionId)) + Ok ~> errorMsgContentType ~> ResponseString(sessionIdJson) + } + case Failure(e) => + logger.error(e.getMessage, e) + InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to start Spark\n") + } } - case RunMode.Test => - // do nothing in test mode. 
- Success(1) - } - divide match { - case Success(_) => - logger.info(f"started Spark with callback URL $callbackUrl%s") - val sessionIdJson = write(SessionId(sessionId)) - Ok ~> errorMsgContentType ~> ResponseString(sessionIdJson) - case Failure(e) => - logger.error(e.getMessage) - InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to start Spark\n") + } case req@POST(Path("/jubaql")) => // TODO: treat very long input + statusLock.synchronized { + queryReceivedCount += 1 + logger.debug(s"queryReceivedCount: $queryReceivedCount") + } val body = readAllFromReader(req.reader) val reqSource = req.remoteAddr - logger.info(f"received HTTP request at /jubaql from $reqSource%s with body: $body%s") + logger.debug(f"received HTTP request at /jubaql from $reqSource%s with body: $body%s") val maybeJson = org.json4s.native.JsonMethods.parseOpt(body) val maybeQuery = maybeJson.flatMap(_.extractOpt[Query]) maybeQuery match { @@ -210,43 +341,40 @@ class GatewayPlan(ipAddress: String, port: Int, logger.warn("received an unacceptable JSON query") BadRequest ~> errorMsgContentType ~> ResponseString("Unacceptable JSON") case Some(query) => - var maybeKey: Option[String] = None - var maybeLoc: Option[(String, Int)] = None - session2key.synchronized { - maybeKey = session2key.get(query.session_id) - maybeLoc = session2loc.get(query.session_id) - } - (maybeKey, maybeLoc) match { - case (None, None) => - logger.warn("received a query JSON without a usable session_id") - Unauthorized ~> errorMsgContentType ~> ResponseString("Unknown session_id") - case (None, Some(loc)) => - logger.error("inconsistent data in this gateway server") - InternalServerError ~> errorMsgContentType ~> ResponseString("Inconsistent data") - case (Some(key), None) => - logger.warn(s"processor for session $key has not registered yet") - ServiceUnavailable ~> errorMsgContentType ~> - ResponseString("This session has not been registered. 
Wait a second.") - case (Some(key), Some(loc)) => - // TODO: check forward query - val (host, port) = loc - - val queryJson = write(QueryToProcessor(query.query)).toString - - val url = :/(host, port) / "jubaql" - val req = Http((url.POST << queryJson) > (x => x)) - - logger.debug(f"forward query to processor ($host%s:$port%d)") - req.either.apply() match { - case Left(error) => - logger.error("failed to send request to processor [" + error.getMessage + "]") - BadGateway ~> errorMsgContentType ~> ResponseString("Bad gateway") - case Right(result) => - val statusCode = result.getStatusCode - val responseBody = result.getResponseBody - val contentType = Option(result.getContentType).getOrElse("text/plain; charset=utf-8") - logger.debug(f"got result from processor [$statusCode%d: $responseBody%s]") - Status(statusCode) ~> ContentType(contentType) ~> ResponseString(responseBody) + val session = sessionManager.getSession(query.session_id) + session match { + case Failure(t) => + InternalServerError ~> errorMsgContentType ~> ResponseString("Failed to get session") + case Success(sessionInfo) => + sessionInfo match { + case SessionState.NotFound => + logger.warn("received a query JSON without a usable session_id") + Unauthorized ~> errorMsgContentType ~> ResponseString("Unknown session_id") + case SessionState.Registering(key) => + logger.warn(s"processor for session $key has not registered yet") + ServiceUnavailable ~> errorMsgContentType ~> ResponseString("This session has not been registered. 
Wait a second.") + case SessionState.Ready(host, port, key) => + val queryJson = write(QueryToProcessor(query.query)).toString + + val url = :/(host, port) / "jubaql" + val req = Http((url.POST << queryJson) > (x => x)) + + logger.debug(f"forward query to processor ($host%s:$port%d)") + statusLock.synchronized { + queryTransferCount += 1 + logger.debug(s"queryTransferCount: $queryTransferCount") + } + req.either.apply() match { + case Left(error) => + logger.error("failed to send request to processor [" + error.getMessage + "]") + BadGateway ~> errorMsgContentType ~> ResponseString("Bad gateway") + case Right(result) => + val statusCode = result.getStatusCode + val responseBody = result.getResponseBody + val contentType = Option(result.getContentType).getOrElse("text/plain; charset=utf-8") + logger.debug(f"got result from processor [$statusCode%d: $responseBody%s]") + Status(statusCode) ~> ContentType(contentType) ~> ResponseString(responseBody) + } } } } @@ -260,12 +388,13 @@ class GatewayPlan(ipAddress: String, port: Int, val maybeUnregister = maybeJson.flatMap(_.extractOpt[Unregister]). 
filter(_.action == "unregister") - if (!maybeRegister.isEmpty) + if (!maybeRegister.isEmpty) { logger.info(f"start registration (key: $key%s)") - else if (!maybeUnregister.isEmpty) + } else if (!maybeUnregister.isEmpty) { logger.info(f"start unregistration (key: $key%s)") - else + } else { logger.info(f"start registration or unregistration (key: $key%s)") + } if (maybeJson.isEmpty) { logger.warn("received query not in JSON format") @@ -274,34 +403,43 @@ class GatewayPlan(ipAddress: String, port: Int, logger.warn("received unacceptable JSON query") BadRequest ~> errorMsgContentType ~> ResponseString("Unacceptable JSON") } else { - session2key.synchronized { - val maybeSessionId = key2session.get(key) - if (!maybeRegister.isEmpty) { // register - val register = maybeRegister.get - val (ip, port) = (register.ip, register.port) - logger.debug(f"registering $ip%s:$port%d") - maybeSessionId match { - case None => - logger.error("attempted to register unknown key") - Unauthorized ~> errorMsgContentType ~> ResponseString("Unknown key") - case Some(sessionId) => - session2loc += (sessionId -> (ip, port)) - Ok ~> errorMsgContentType ~> ResponseString("Successfully registered") - } - } else { // unregister - logger.debug("unregistering") - maybeSessionId match { - case Some(sessionId) => // unregistering an existent key - session2key -= sessionId - key2session -= key - session2loc -= sessionId - case _ => // unregistering a nonexistent key - () - } - Ok ~> errorMsgContentType ~> ResponseString("Successfully unregistered") + if (!maybeRegister.isEmpty) { // register + val register = maybeRegister.get + val (ip, port) = (register.ip, register.port) + logger.debug(f"registering $ip%s:$port%d") + val result = sessionManager.attachProcessorToSession(ip, port, key) + result match { + case Failure(t) => + InternalServerError ~> errorMsgContentType ~> ResponseString(s"Failed to register key : ${key}") + case Success(sessionId) => + logger.info(s"registered session. 
sessionId : $sessionId") + Ok ~> errorMsgContentType ~> ResponseString("Successfully registered") + } + } else { // unregister + logger.debug("unregistering") + val result = sessionManager.deleteSessionByKey(key) + result match { + case Failure(t) => + InternalServerError ~> errorMsgContentType ~> ResponseString(s"Failed to unregister key : ${key}") + case Success((sessionId, key)) => + if (sessionId != null) { + logger.info(s"unregistered session. sessionId: ${sessionId}") + } else { + logger.info(s"already delete session. key: ${key}") + } + Ok ~> errorMsgContentType ~> ResponseString("Successfully unregistered") } } } + + case req@POST(Path("/status")) => + val reqSource = req.remoteAddr + logger.debug(s"received HTTP request at /status from $reqSource") + val stsMap = getGatewayStatus() + val strStatus: String = write(GatewayStatus(stsMap)) + logger.debug(s"Response: $strStatus") + Ok ~> errorMsgContentType ~> ResponseString(strStatus) + } private def composeCallbackUrl(ip: String, port: Int, key: String): String = { @@ -315,6 +453,12 @@ class GatewayPlan(ipAddress: String, port: Int, thread.start() } + private def handleSubProcess(process: Process, sessionId: String): Unit = { + val thread = new SubProcessHandlerThread(process, sessionId, this, logger) + thread.setDaemon(true) + thread.start() + } + private def readAllFromReader(reader: java.io.Reader):String = { val sb = new StringBuffer() val buffer = Array[Char](1024) @@ -325,27 +469,40 @@ class GatewayPlan(ipAddress: String, port: Int, } sb.toString } -} -// An alphanumeric string generator. 
-object Alphanumeric { - val random = new Random() - val chars = "0123456789abcdefghijklmnopqrstuvwxyz" + def close():Unit = { + sessionManager.close() + } - def generate(length: Int): String = { - val ret = new Array[Char](length) - this.synchronized { - for (i <- 0 until ret.length) { - ret(i) = chars(random.nextInt(chars.length)) - } - } - new String(ret) + private def getGatewayStatus(): Map[String, Any] = { + val curTime = System.currentTimeMillis() + val opTime = curTime - startTime + val runtime = Runtime.getRuntime() + val usedMemory = runtime.totalMemory() - runtime.freeMemory() + + var stsMap: LinkedHashMap[String, Any] = new LinkedHashMap() + stsMap.put("ipAddress", ipAddress) + stsMap.put("port", port) + stsMap.put("user", System.getProperty("user.name")) + stsMap.put("pid", java.lang.management.ManagementFactory.getRuntimeMXBean().getName().split("@")(0)) + stsMap.put("sparkDistribution", sparkDistribution) + stsMap.put("runMode", runMode.name) + stsMap.put("zookeeper", scala.util.Properties.propOrElse("jubaql.zookeeper", "")) + stsMap.put("sessionIds", sessionManager.session2key.values) + stsMap.put("startTime", startTime) + stsMap.put("currentTime", curTime) + stsMap.put("oparatingTime", opTime) + stsMap.put("maxMemory", runtime.maxMemory()) + stsMap.put("usedMemory", usedMemory) + stsMap.put("queryTransferCount", queryTransferCount) + stsMap.put("queryReceivedCount", queryReceivedCount) + stsMap } } private class SubProcessOutputHandlerThread(in: InputStream, out: PrintStream, - logger: com.typesafe.scalalogging.Logger) extends Thread { + logger: com.typesafe.scalalogging.slf4j.Logger) extends Thread { override def run(): Unit = { val reader = new BufferedReader(new InputStreamReader(in)) try { @@ -356,18 +513,63 @@ private class SubProcessOutputHandlerThread(in: InputStream, } } catch { case e: IOException => - logger.warn("caught IOException in subprocess handler") + logger.warn("caught IOException in subprocess handler", e) () } // Never close 
out here. } } -sealed trait RunMode +private class SubProcessHandlerThread(process: Process, sessionId: String, parent: GatewayPlan, logger: com.typesafe.scalalogging.slf4j.Logger) extends Thread { + override def run(): Unit = { + try { + process.waitFor() + Try(process.exitValue) match { + case Success(returnCode) => + if (returnCode == 0) { + logger.info("Finished spark-submit") + } else { + logger.error(f"Failed to spark-submit. returnCode: $returnCode%s") + } + case Failure(e) => + logger.error("Failed to spark-submit", e) + } + } catch { + case e: Exception => + logger.error("caught Exception in spark-submit error handler", e) + } finally { + try { + parent.sessionManager.getSession(sessionId) match { + case Success(state) => + state match { + case SessionState.Ready(_, _, _) | SessionState.Registering(_) => + parent.sessionManager.deleteSessionById(sessionId) match { + case Success((sessionId, key)) => + logger.debug(s"Finished terminate process. sessionId: ${sessionId}") + case Failure(t) => + throw t + } + case _ => + logger.debug(s"Terminate process is not required. sessionId: ${sessionId}") + } + case Failure(t) => + throw t + } + } catch { + case e: Exception => + logger.error(s"Failed to terminate process. 
sessionId: ${sessionId}", e) + } + process.destroy() + } + } +} + +sealed abstract class RunMode(val name: String) object RunMode { case class Production(zookeeper: String, numExecutors: Int = 3, coresPerExecutor: Int = 2, - sparkJar: Option[String] = None) extends RunMode - case class Development(numThreads: Int = 3) extends RunMode - case object Test extends RunMode + sparkJar: Option[String] = None, sparkDriverMemory: Option[String] = None, + sparkExecutorMemory: Option[String] = None) extends RunMode("Production") + case class Development(numThreads: Int = 3) extends RunMode("Development") + case object Test extends RunMode("Test") } diff --git a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/JubaQLGateway.scala b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/JubaQLGateway.scala index ff4be71..307f7c2 100644 --- a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/JubaQLGateway.scala +++ b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/JubaQLGateway.scala @@ -20,23 +20,37 @@ import scopt.OptionParser object JubaQLGateway extends LazyLogging { val defaultPort = 9877 + val defaultThreads = 16 + val defaultChannelMemory: Long = 65536 + val defalutTotalMemory: Long = 1048576 /** Main function to start the JubaQL gateway application. 
*/ def main(args: Array[String]) { val maybeParsedOptions: Option[CommandlineOptions] = parseCommandlineOption(args) - if (maybeParsedOptions.isEmpty) + if (maybeParsedOptions.isEmpty) { System.exit(1) + } val parsedOptions = maybeParsedOptions.get val ipAddress: String = parsedOptions.ip val port: Int = parsedOptions.port + val gatewayId = parsedOptions.gatewayId match { + case "" => s"${ipAddress}_${port}" + case id => id + } + + val persist: Boolean = parsedOptions.persist + var envp: Array[String] = Array() var runMode: RunMode = RunMode.Development() val runModeProperty: String = System.getProperty("run.mode") val sparkJar = Option(System.getProperty("spark.yarn.jar")) val zookeeperString = scala.util.Properties.propOrElse("jubaql.zookeeper", "") + val sparkDriverMemory = Option(System.getProperty("jubaql.processor.driverMemory")) + val sparkExecutorMemory = Option(System.getProperty("jubaql.processor.executorMemory")) + val devModeRe = "development:([0-9]+)".r val prodModeRe = "production:([0-9]+):([0-9]+)".r runModeProperty match { @@ -47,11 +61,12 @@ object JubaQLGateway extends LazyLogging { runMode = RunMode.Development(numThreadsString.toInt) case "production" => - runMode = RunMode.Production(zookeeperString, sparkJar = sparkJar) + runMode = RunMode.Production(zookeeperString, sparkJar = sparkJar, + sparkDriverMemory = sparkDriverMemory, sparkExecutorMemory = sparkExecutorMemory) case prodModeRe(numExecutorsString, coresPerExecutorString) => runMode = RunMode.Production(zookeeperString, numExecutorsString.toInt, - coresPerExecutorString.toInt, sparkJar = sparkJar) + coresPerExecutorString.toInt, sparkJar = sparkJar, sparkDriverMemory = sparkDriverMemory, sparkExecutorMemory = sparkExecutorMemory) case _ => System.err.println("Bad run.mode property") @@ -75,8 +90,17 @@ object JubaQLGateway extends LazyLogging { "in production mode (comma-separated host:port list)") System.exit(1) } + case p: RunMode.Development => + // When persist was configured, Set 
system property jubaql.zookeeper. + if (persist && zookeeperString.trim.isEmpty) { + logger.error("system property jubaql.zookeeper must be given " + + "with set persist flag (--persist)") + System.exit(1) + } else if (!persist && !zookeeperString.trim.isEmpty) { + logger.warn("persist flag is not specified; jubaql.zookeeper is ignored") + } case _ => - // don't set environment in dev mode + // don't set environment in other mode } logger.info("Starting in run mode %s".format(runMode)) @@ -86,10 +110,12 @@ object JubaQLGateway extends LazyLogging { val plan = new GatewayPlan(ipAddress, port, envp, runMode, sparkDistribution = sparkDistribution, fatjar = fatjar, - checkpointDir = checkpointDir) + checkpointDir = checkpointDir, gatewayId, persist, + parsedOptions.threads, parsedOptions.channelMemory, parsedOptions.totalMemory) val nettyServer = unfiltered.netty.Server.http(port).plan(plan) logger.info("JubaQLGateway starting") nettyServer.run() + plan.close() logger.info("JubaQLGateway shut down successfully") } @@ -106,6 +132,35 @@ object JubaQLGateway extends LazyLogging { x => if (x >= 1 && x <= 65535) success else failure("bad port number; port number n must be \"1 <= n <= 65535\"") } text (f"port (default: $defaultPort%d)") + opt[String]('g', "gatewayID") optional() valueName ("") action { + (x, o) => + o.copy(gatewayId = x) + } text ("Gateway ID (default: ip_port)") + opt[Unit]("persist") optional() valueName ("") action { + (x, o) => + o.copy(persist = true) + } text ("session persist") + opt[Int]("threads") optional() valueName ("") action { + (x, o) => + o.copy(threads = x) + } validate { + x => + if (x >= 1 && x <= Int.MaxValue) success else failure(s"invalid threads: specified in 1 or more and ${Int.MaxValue} or less") + } text (s"threads (default: $defaultThreads)") + opt[Long]("channel_memory") optional() valueName ("") action { + (x, o) => + o.copy(channelMemory = x) + } validate { + x => + if (x >= 0 && x <= Long.MaxValue) success else 
failure(s"invalid channelMemory: specified in 0 or more and ${Long.MaxValue} or less") + } text (s"channelMemory (default: $defaultChannelMemory)") + opt[Long]("total_memory") optional() valueName ("") action { + (x, o) => + o.copy(totalMemory = x) + } validate { + x => + if (x >= 0 && x <= Long.MaxValue) success else failure(s"invalid totalMemory: specified in 0 or more and ${Long.MaxValue} or less") + } text (s"totalMemory (default: $defalutTotalMemory)") } parser.parse(args, CommandlineOptions()) @@ -124,7 +179,7 @@ object JubaQLGateway extends LazyLogging { val dir = scala.util.Properties.propOrElse("jubaql.checkpointdir", "") if (dir.trim.isEmpty) { runMode match { - case RunMode.Production(_, _, _, _) => + case RunMode.Production(_, _, _, _, _, _) => "hdfs:///tmp/spark" case RunMode.Development(_) => "file:///tmp/spark" @@ -135,4 +190,5 @@ object JubaQLGateway extends LazyLogging { } } -case class CommandlineOptions(ip: String = "", port: Int = JubaQLGateway.defaultPort) +case class CommandlineOptions(ip: String = "", port: Int = JubaQLGateway.defaultPort, gatewayId: String = "", persist: Boolean = false, + threads: Int = JubaQLGateway.defaultThreads, channelMemory: Long = JubaQLGateway.defaultChannelMemory, totalMemory: Long = JubaQLGateway.defalutTotalMemory) diff --git a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/SessionManager.scala b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/SessionManager.scala new file mode 100644 index 0000000..a38ce0a --- /dev/null +++ b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/SessionManager.scala @@ -0,0 +1,492 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. 
+// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +package us.jubat.jubaql_server.gateway + +import scala.collection.mutable +import scala.collection.mutable._ +import scala.util.Random +import org.json4s._ +import org.json4s.native.JsonMethods._ +import com.typesafe.scalalogging.slf4j.LazyLogging +import scala.util.{Try, Success, Failure} + +/** + * Session Management Class + */ +class SessionManager(gatewayId: String, sessionStore: SessionStore = new NonPersistentStore) extends LazyLogging { + implicit val formats = DefaultFormats + + logger.info(s"Use SessionStore: ${sessionStore.getClass}") + + // key = sessionId, value = keyInfo + val session2key: mutable.Map[String, String] = new mutable.HashMap() + // key = keyInfo, value = sessionId + // key: identification required for registration / unregistration + val key2session: mutable.Map[String, String] = new mutable.HashMap() + // key = sessionId, vale = Processor's connectionInfo(host, port) + val session2loc: mutable.Map[String, (String, Int)] = new mutable.HashMap() + + // initialize session store(specific processing) + sessionStore.registerGateway(gatewayId) + + // initialize cache from session store + initCache() + + // --- session controller method --- + /** + * initialize cache from session store + */ + def initCache(): Unit = { + val sessionMap = sessionStore.getAllSessions(gatewayId) + session2key.synchronized { + for ((sessionId, sessionInfo) <- sessionMap) { + sessionInfo match { + case completeSesion: SessionState.Ready => + session2key += sessionId -> completeSesion.key + 
key2session += completeSesion.key -> sessionId + session2loc += sessionId -> (completeSesion.host, completeSesion.port) + sessionStore.addDeleteListener(gatewayId, sessionId, deleteFunction) + case registeringSession: SessionState.Registering => + logger.debug(s"Registering session not add to cache. session: ${sessionId}") + case _ => + logger.debug(s"Invalid session. exclude session: ${sessionId}") + } + } + } + } + + /** + * create new session + * @return resultCreateSession(session ID and registration key) + */ + def createNewSession(): Try[(String, String)] = { + var lockObj: SessionLock = null + val result = Try { + lockObj = lock() + val (sessionId, key) = createId() + sessionStore.preregisterSession(gatewayId, sessionId, key) + session2key.synchronized { + session2key += (sessionId -> key) + key2session += (key -> sessionId) + } + logger.debug(s"created session. sessionId: ${sessionId}") + (sessionId, key) + } + unlock(lockObj) + result + } + + /** + * get Session + * if not found in cache, contact to session store + * @param sessionId: sessionId + * @return SessionInformation (processor's host, processor's port, regstrationKey) + */ + def getSession(sessionId: String): Try[SessionState] = { + Try { + val cacheSession: SessionState = getSessionFromCache(sessionId) + cacheSession match { + case completeSession: SessionState.Ready => completeSession + case registeringSession: SessionState.Registering => registeringSession + case SessionState.NotFound => + logger.debug(s"session not found from cache. sessionId: ${sessionId}") + // session not found in cache, contact session store. + val storeSession = sessionStore.getSession(gatewayId, sessionId) + storeSession match { + case SessionState.NotFound => storeSession + case SessionState.Inconsistent => + throw new Exception(s"Inconsistent data. 
sessionId: ${sessionId}") + case completeSession: SessionState.Ready => + val host = completeSession.host + val port = completeSession.port + val key = completeSession.key + sessionStore.addDeleteListener(gatewayId, sessionId, deleteFunction) + session2key.synchronized { + session2key += sessionId -> key + key2session += key -> sessionId + session2loc += sessionId -> (host, port) + } + completeSession + case registeringSession: SessionState.Registering => + registeringSession + } + } + } + } + + /** + * get session from cache + * @param sessionId session ID + * @return session information + */ + def getSessionFromCache(sessionId: String): SessionState = { + session2key.get(sessionId) match { + case Some(key) => + session2loc.get(sessionId) match { + case Some(loc) => + if (key != null) { + SessionState.Ready(loc._1, loc._2, key) + } else { + throw new Exception(s"Inconsistent data. sessionId: ${sessionId}") + } + case None => + // not yet registered sessionId. + SessionState.Registering(key) + } + case None => + SessionState.NotFound + } + } + + /** + * attach Processor Information to session + * @param host Processor's host + * @param port Processor's port + * @param key registration key + * @return attachResult + */ + def attachProcessorToSession(host: String, port: Int, key: String): Try[String] = { + var lockObj: SessionLock = null + val result = Try { + val sessionId = key2session.get(key) match { + case Some(sessionId) => sessionId + case None => throw new Exception(s"non exist sessionId. key: ${key}") + } + lockObj = lock() + sessionStore.registerSession(gatewayId, sessionId, host, port, key) + sessionStore.addDeleteListener(gatewayId, sessionId, deleteFunction) + session2key.synchronized { + session2loc += (sessionId -> (host, port)) + } + logger.debug(s"attached session. 
sessionId: ${sessionId}") + sessionId + } + unlock(lockObj) + result + } + + /** + * delete Session by registration key + * @param key registration key + * @return deleteResult(sessionId, key) + */ + def deleteSessionByKey(key: String): Try[(String, String)] = { + Try { + key2session.get(key) match { + case Some(sessionId) => + deleteSessionById(sessionId) match { + case Success((deleteSessionId, deleteKey)) => + (deleteSessionId, deleteKey) + case Failure(t) => + throw t + } + case None => + logger.debug(s"non exist sessionId. key: ${key}") + (null, key) + } + } + } + + /** + * delete Session by sessionId + * @param sessionId session ID + * @return deleteResult(sessionId, key) + */ + def deleteSessionById(sessionId: String): Try[(String, String)] = { + var lockObj: SessionLock = null + val result = Try { + lockObj = lock() + sessionStore.deleteSession(gatewayId, sessionId) + session2key.synchronized { + val key = session2key.get(sessionId) match { + case Some(key) => + key2session -= key + key + case None => null + } + session2key -= sessionId + session2loc -= sessionId + logger.debug(s"deleted session. sessionId: ${sessionId}") + (sessionId, key) + } + + } + unlock(lockObj) + result + } + + /** + * gateway session lock + * @return Lock Object + */ + def lock(): SessionLock = { + val lock = sessionStore.lock(gatewayId) + logger.debug(s"locked: $gatewayId") + lock + } + + /** + * gateway session unlock + * @param lock Lock Object + */ + def unlock(lock: SessionLock): Unit = { + if (lock != null) { + try { + sessionStore.unlock(lock) + logger.debug(s"unlocked: gatewayId = $gatewayId") + } catch { + case e: Exception => + logger.error("failed to unlock.", e) + } + } else { + logger.debug(s"not lock. 
gatewayId = $gatewayId") + } + } + + /** + * generate sessionId and key + * @return (sessionId, registrationKey) + */ + def createId(): (String, String) = { + session2key.synchronized { + var sessionId = "" + var key = "" + + do { + sessionId = Alphanumeric.generate(20) + } while (session2key.get(sessionId) != None) + do { + key = Alphanumeric.generate(20) + } while (key2session.get(key) != None) + (sessionId, key) + } + } + + /** + * delete session in cache. called by delete listeners. + * @param sessionId session ID + */ + def deleteFunction(sessionId: String): Unit = { + logger.debug(s"delete session from cache. sessionId: ${sessionId}") + session2key.synchronized { + val keyOpt = session2key.get(sessionId) + keyOpt match { + case Some(key) => key2session -= key + case None => //Nothing + } + session2key -= sessionId + session2loc -= sessionId + } + } + + /** + * get session from store + * @param sessionId session ID + * @return session information + */ + def getSessionFromStore(sessionId: String): SessionState = { + sessionStore.getSession(gatewayId, sessionId) + } + + /** + * close + */ + def close(): Unit = { + sessionStore.close() + } +} + +/** + * An alphanumeric string generator. 
+ */ +object Alphanumeric { + val random = new Random(new java.security.SecureRandom()) + val chars = "0123456789abcdefghijklmnopqrstuvwxyz" + + /** + * generate ID + * @param length generated ID length + */ + def generate(length: Int): String = { + val ret = new Array[Char](length) + this.synchronized { + for (i <- 0 until ret.length) { + ret(i) = chars(random.nextInt(chars.length)) + } + } + + new String(ret) + } +} + +/** + * Session Store Interface + */ +trait SessionStore { + implicit val formats = DefaultFormats + + /** + * get All Session by gatewayId + * @param gatewayId gateway ID + * @return SessionInformation Map(key = sessionId, value = SessionState) + */ + def getAllSessions(gatewayId: String): Map[String, SessionState] + + /** + * get Session + * if not found in cache, contact to session store + * @param gatewayId gateway ID + * @param sessionId session ID + * @return SessionState + */ + def getSession(gatewayId: String, sessionId: String): SessionState + + /** + * register SessionId to session store + * @param gatewayId gateway ID + * @param sessionId session ID + * @param key registration key + */ + def preregisterSession(gatewayId: String, sessionId: String, key: String): Unit + + /** + * register Session to session store + * @param gatewayId gateway ID + * @param sessionId session ID + * @param host Processor's host + * @param port Processor's port + * @param key registration key + */ + def registerSession(gatewayId: String, sessionId: String, host: String, port: Int, key: String): Unit + + /** + * delete Session + * @param gatewayId gateway ID + * @param sessionId session Id + */ + def deleteSession(gatewayId: String, sessionId: String): Unit + + /** + * gateway session lock + * @param gatewayId gateway ID + * @return Lock Object + */ + def lock(gatewayId: String): SessionLock + + /** + * gateway session unlock + * @param lock Lock Object + */ + def unlock(lock: SessionLock): Unit + + /** + * register delete listener + * @param gatewayId gateway 
Id + * @param sessionId session Id + * @param deleteFunction delete-event trigger, call function + */ + def addDeleteListener(gatewayId: String, sessionId: String, deleteFunction: Function1[String, Unit] = null): Unit + + /** + * initialize session store + * @param gatewayId gateway Id + */ + def registerGateway(gatewayId: String): Unit + + /** + * close + */ + def close(): Unit + + // --- utility method --- + /** + * extract Session Information for jsonString + * @param jsonString + * @return SessionInfo + */ + def extractSessionInfo(jsonString: String): SessionState = { + try { + val jsonData = parse(jsonString) + if (jsonData != JNothing && jsonData.children.length > 0) { + val host = (jsonData \ "host") match { + case JNothing => null + case value: JValue => value.extract[String] + } + val port = (jsonData \ "port") match { + case JNothing => null + case value: JValue => value.extract[String] + } + val key = (jsonData \ "key") match { + case JNothing => null + case value: JValue => value.extract[String] + } + if (host == null && port == null && key == null) { + SessionState.NotFound + } else if (host == null && port == null && key != null) { + SessionState.Registering(key) + } else if (host != null && port != null && key != null) { + SessionState.Ready(host, port.toInt, key) + } else { + SessionState.Inconsistent + } + } else { + SessionState.NotFound + } + } catch { + case e: Throwable => + throw e + } + } +} + +/** + * Session Lock Object + */ +class SessionLock(lock: Any) { + val lockObject = lock +} + +/** + * SessionStore for non-persist + */ +class NonPersistentStore extends SessionStore { + override def getAllSessions(gatewayId: String): Map[String, SessionState] = Map.empty[String, SessionState] + override def getSession(gatewayId: String, sessionId: String): SessionState = SessionState.NotFound + override def preregisterSession(gatewayId: String, sessionId: String, key: String): Unit = {} + override def registerSession(gatewayId: String, sessionId: 
String, host: String, port: Int, key: String): Unit = {} + override def deleteSession(gatewayId: String, sessionId: String): Unit = {} + override def lock(gatewayId: String): SessionLock = { null } + override def unlock(lock: SessionLock): Unit = {} + override def addDeleteListener(gatewayId: String, sessionId: String, deleteFunction: Function1[String, Unit] = null): Unit = {} + override def registerGateway(gatewayId: String): Unit = {} + override def close(): Unit = {} +} + +/** + * SessionInfo + */ +trait SessionState + +object SessionState { + // Session ready + case class Ready(host: String, port: Int, key: String) extends SessionState + // Registering yet + case class Registering(key: String) extends SessionState + // Session Inconsistent + case object Inconsistent extends SessionState + // Unknown session + case object NotFound extends SessionState +} diff --git a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/ZookeeperStore.scala b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/ZookeeperStore.scala new file mode 100644 index 0000000..5f9f8f3 --- /dev/null +++ b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/ZookeeperStore.scala @@ -0,0 +1,257 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +package us.jubat.jubaql_server.gateway + +import com.typesafe.scalalogging.slf4j.LazyLogging +import scala.collection.mutable +import scala.collection.mutable._ +import org.apache.curator._ +import org.apache.curator.retry._ +import org.apache.curator.framework._ +import org.apache.curator.framework.recipes.locks._ +import collection.JavaConversions._ +import org.apache.zookeeper.KeeperException +import org.apache.zookeeper.KeeperException.NoNodeException +import org.apache.curator.framework.api.CuratorListener +import org.apache.curator.framework.api.CuratorEvent +import org.apache.zookeeper.Watcher +import org.apache.curator.framework.listen.ListenerContainer +import java.util.concurrent.TimeUnit + +/** + * ZooKeeper Store Class + */ +class ZookeeperStore extends SessionStore with LazyLogging { + + val zkJubaQLPath: String = "/jubaql" + val zkSessionPath: String = "/jubaql/session" + + val lockNodeName = "locks" + val leasesNodeName = "leases" + //lock keep time(unit: seconds) + val lockTimeout: Long = 300 + // retry sleep time(unit: ms) + val retrySleepTimeMs: Int = 1000 + val retryCount: Int = 1 + + val zookeeperString: String = scala.util.Properties.propOrElse("jubaql.zookeeper", "") + val retryPolicy: RetryPolicy = new RetryNTimes(retryCount, retrySleepTimeMs) + + logger.info(s"connecting to ZooKeeper : ${zookeeperString}") + val zookeeperClient: CuratorFramework = CuratorFrameworkFactory.newClient(zookeeperString, retryPolicy) + + zookeeperClient.start() + if (!zookeeperClient.getZookeeperClient.blockUntilConnectedOrTimedOut) { + zookeeperClient.close() + logger.error(s"zookeeper connection timeout. 
zookeeper: ${zookeeperString}") + throw new Exception("failed to connected zookeeper") + } + logger.info(s"connected to ZooKeeper : ${zookeeperString}") + + override def getAllSessions(gatewayId: String): Map[String, SessionState] = { + var result = Map.empty[String, SessionState] + try { + val isExist = zookeeperClient.checkExists().forPath(s"${zkSessionPath}/${gatewayId}") + if (isExist != null) { + val sessionIdList = zookeeperClient.getChildren.forPath(s"${zkSessionPath}/${gatewayId}") + for (sessionId <- sessionIdList) { + val sessionInfo = getSession(gatewayId, sessionId) + result += sessionId -> sessionInfo + } + } else { + logger.debug(s"non exist node. gatewayId: ${gatewayId}") + } + } catch { + case e: Exception => + val errorMessage = "failed to get all session." + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + return result + } + + override def getSession(gatewayId: String, sessionId: String): SessionState = { + try { + val sessionByteArray = zookeeperClient.getData().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}") + val sessionJsonString = new String(sessionByteArray, "UTF-8") + extractSessionInfo(sessionJsonString) + } catch { + case e: NoNodeException => + logger.debug(s"not found session. sesionId: ${sessionId}") + SessionState.NotFound + case e: Exception => + val errorMessage = s"failed to get session. sessionId: ${sessionId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def preregisterSession(gatewayId: String, sessionId: String, key: String): Unit = { + try { + val isExists = zookeeperClient.checkExists().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}") + if (isExists == null) { + zookeeperClient.create().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}", s"""{"key":"$key"}""".getBytes("UTF-8")) + } else { + throw new IllegalStateException(s"already exists session. 
sessionId: ${sessionId}") + } + } catch { + case e: Exception => + val errorMessage = s"failed to pre-register session. sessionId: ${sessionId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def registerSession(gatewayId: String, sessionId: String, host: String, port: Int, key: String): Unit = { + try { + val isExists = zookeeperClient.checkExists().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}") + if (isExists != null) { + val currentSession = getSession(gatewayId, sessionId) + if (currentSession.isInstanceOf[SessionState.Registering]) { + zookeeperClient.setData().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}", s"""{"host":"$host","port":$port,"key":"$key"}""".getBytes("UTF-8")) + } else { + throw new IllegalStateException(s"illegal session state. sessionId: ${sessionId}, state: ${currentSession}") + } + } else { + throw new IllegalStateException(s"non exists session. sessionId: ${sessionId}") + } + } catch { + case e: Exception => + val errorMessage = s"failed to register session. sessionId: ${sessionId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def deleteSession(gatewayId: String, sessionId: String): Unit = { + try { + zookeeperClient.delete().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}") + } catch { + case e: NoNodeException => logger.warn(s"No exist Node. sessionId : $sessionId") + case e: Exception => + val errorMessage = s"failed to delete session. sessionId: ${sessionId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def lock(gatewayId: String): SessionLock = { + val mutex: InterProcessSemaphoreMutex = try { + val mutex = new InterProcessSemaphoreMutex(zookeeperClient, s"${zkSessionPath}/${gatewayId}") + mutex.acquire(lockTimeout, TimeUnit.SECONDS) + mutex + } catch { + case e: Exception => + val errorMessage = s"failed to create lock object. 
gatewayId: ${gatewayId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + new SessionLock(mutex) + } + + override def unlock(lock: SessionLock): Unit = { + if (lock != null) { + val mutex = lock.lockObject + if (mutex != null && mutex.isInstanceOf[InterProcessSemaphoreMutex]) { + try { + mutex.asInstanceOf[InterProcessSemaphoreMutex].release() + } catch { + case e: Exception => + val errorMessage = "failed to unlock" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } else { + val errorMessage = if (mutex != null) { + s"failed to unlock. illegal lock object: ${mutex.getClass()}" + } else { + "failed to unlock. lock object: null" + } + logger.error(errorMessage) + throw new Exception(errorMessage) + } + } else { + val errorMessage = "failed to unlock. session lock: null" + logger.error(errorMessage) + throw new Exception(errorMessage) + } + } + + override def addDeleteListener(gatewayId: String, sessionId: String, deleteFunction: Function1[String, Unit]): Unit = { + try { + zookeeperClient.synchronized { + val listenerSize = if (zookeeperClient.getCuratorListenable.isInstanceOf[ListenerContainer[CuratorListener]]) { + val container = zookeeperClient.getCuratorListenable.asInstanceOf[ListenerContainer[CuratorListener]] + container.size() + } else { + throw new Exception(s"invalid listener class. 
class: ${zookeeperClient.getCuratorListenable.getClass}") + } + if (sessionId != lockNodeName && sessionId != leasesNodeName) { + if (listenerSize == 0) { + val listener = new CuratorListener() { + override def eventReceived(client: CuratorFramework, event: CuratorEvent): Unit = { + if (event.getWatchedEvent != null && event.getWatchedEvent.getType == Watcher.Event.EventType.NodeDeleted) { + val deletedNodePath = event.getPath + val deletedSessionID = deletedNodePath.substring(deletedNodePath.lastIndexOf("/") + 1, deletedNodePath.length) + deleteFunction(deletedSessionID) + } + } + } + zookeeperClient.getCuratorListenable.addListener(listener) + } + zookeeperClient.getChildren.watched().forPath(s"${zkSessionPath}/${gatewayId}/${sessionId}") + } else { + logger.debug(s"not add listener. exclude sessionId: ${sessionId}") + } + } + } catch { + case e: Exception => + val errorMessage = s"failed to add delete listener. sessionId: ${sessionId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def registerGateway(gatewayId: String): Unit = { + try { + if (zookeeperClient.checkExists().forPath(zkJubaQLPath) == null) { + zookeeperClient.create().forPath(zkJubaQLPath, new Array[Byte](0)) + } + if (zookeeperClient.checkExists().forPath(zkSessionPath) == null) { + zookeeperClient.create().forPath(zkSessionPath, new Array[Byte](0)) + } + if (zookeeperClient.checkExists().forPath(s"${zkSessionPath}/${gatewayId}") == null) { + zookeeperClient.create().forPath(s"${zkSessionPath}/${gatewayId}", new Array[Byte](0)) + } + } catch { + case e: Exception => + val errorMessage = s"failed to registerGateway. gatewayId: ${gatewayId}" + logger.error(errorMessage, e) + throw new Exception(errorMessage, e) + } + } + + override def close(): Unit = { + try { + zookeeperClient.close() + logger.info("zookeeperClient closed") + } catch { + case e: Exception => + val errorMessage = s"failed to close." 
+ logger.error(errorMessage, e) + } + } +} \ No newline at end of file diff --git a/gateway/src/main/scala/us/jubat/jubaql_server/gateway/json/GatewayStatus.scala b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/json/GatewayStatus.scala new file mode 100644 index 0000000..7878c95 --- /dev/null +++ b/gateway/src/main/scala/us/jubat/jubaql_server/gateway/json/GatewayStatus.scala @@ -0,0 +1,20 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +package us.jubat.jubaql_server.gateway.json + +import scala.collection.Map + +case class GatewayStatus(gateway: Map[String, Any]) diff --git a/gateway/src/test/java/DummySparkSubmit.java b/gateway/src/test/java/DummySparkSubmit.java new file mode 100644 index 0000000..2f5746e --- /dev/null +++ b/gateway/src/test/java/DummySparkSubmit.java @@ -0,0 +1,133 @@ +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; + +public class DummySparkSubmit { + public static void main(String[] args) throws Exception { + String url = ""; + boolean isTimeout = false; + boolean isReturn = false; + boolean isAfter = false; + boolean isAfterFailed = false; + boolean isException = false; + + for (String arg : args) { + System.out.println(arg); + if(arg.startsWith("http://")) { + url = arg; + } else if (arg.contains("jubaql.checkpointdir")) { + if (arg.contains("timeout")) { + isTimeout = true; + } else if (arg.contains("return")) { + isReturn = true; + } else if (arg.contains("after")) { + isAfter = true; + } else if (arg.contains("afterFailed")) { + isAfterFailed = true; + }else if (arg.contains("exception")) { + isException = true; + } + } + } + if (isTimeout) { + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < 30000) { + try { + Thread.sleep(1000); + } catch (Exception e) { + e.printStackTrace(); + } + System.out.println("state: ACCEPTED"); + } + } else if (isReturn) { + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < 5000) { + try { + Thread.sleep(1000); + } catch (Exception e) { + 
e.printStackTrace(); + } + System.err.println("Standard Error Message"); + System.out.println("state: ACCEPTED"); + } + } else if (isAfter) { + Thread.sleep(1000); + System.out.println("state: ACCEPTED"); + Thread.sleep(1000); + System.out.println("state: RUNNING"); + sendRegistRequest(url); + Thread.sleep(5000); + System.out.println("[Error] Exception: xxxxxxxxxxxx"); + } else if (isAfterFailed) { + Thread.sleep(1000); + System.out.println("state: ACCEPTED"); + Thread.sleep(1000); + System.out.println("state: RUNNING"); + sendRegistRequest(url); + Thread.sleep(5000); + System.out.println("[Error] Exception: xxxxxxxxxxxx"); + System.exit(10); + } else if (isException) { + Thread.sleep(1000); + System.out.println("state: ACCEPTED"); + Thread.sleep(1000); + System.out.println("state: RUNNING"); + sendRegistRequest(url); + Thread.sleep(5000); + throw new Exception("Runnning After Exception"); + } else { + try { + Thread.sleep(1000); + System.out.println("state: ACCEPTED"); + Thread.sleep(1000); + System.out.println("state: RUNNING"); + sendRegistRequest(url); + Thread.sleep(5000); + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < 30000) { + Thread.sleep(1000); + System.out.println("state: RUNNNING"); + } + } catch (Exception e) { + e.printStackTrace(); + throw e; + } + } + System.exit(0); + } + + private static void sendRegistRequest(String urlString) { + HttpURLConnection connection = null; + try { + URL url = new URL(urlString); + connection = (HttpURLConnection) url.openConnection(); + connection.setDoOutput(true); + connection.setRequestMethod("POST"); + BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)); + writer.write("{\"action\": \"register\", \"ip\": \"localhost\",\"port\": 12345}"); + writer.flush(); + + if (connection.getResponseCode() == HttpURLConnection.HTTP_OK) { + try (InputStreamReader isr = new 
InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8); + BufferedReader reader = new BufferedReader(isr)) { + String line; + while ((line = reader.readLine()) != null) { + System.out.println(line); + } + } + } + } catch (Exception e) { + e.printStackTrace(); + } finally { + if (connection != null) { + connection.disconnect(); + } + } + + } +} diff --git a/gateway/src/test/resources/dummy/bin/spark-submit b/gateway/src/test/resources/dummy/bin/spark-submit new file mode 100644 index 0000000..16a2e6f --- /dev/null +++ b/gateway/src/test/resources/dummy/bin/spark-submit @@ -0,0 +1,5 @@ +#!/bin/sh + +java -cp target/scala-2.10/test-classes DummySparkSubmit $@ + +exit $? \ No newline at end of file diff --git a/gateway/src/test/resources/spark.xml.dist b/gateway/src/test/resources/spark.xml.dist new file mode 100644 index 0000000..adbd44b --- /dev/null +++ b/gateway/src/test/resources/spark.xml.dist @@ -0,0 +1,7 @@ + + + + +[spark home path] +src/test/resources/dummy + diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/DummyStore.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/DummyStore.scala new file mode 100644 index 0000000..d1e2fbd --- /dev/null +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/DummyStore.scala @@ -0,0 +1,107 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. 
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import scala.collection.mutable._
+
+class DummyStore extends SessionStore { // test double: every method prints its arguments and throws when gatewayId contains "<methodName>Failed"
+
+  override def getAllSessions(gatewayId: String): Map[String, SessionState] = { // returns a fixed map of six sessions, the last one mapped to null
+    println(s"call getAllSessions gatewayId= ${gatewayId}")
+    if (gatewayId.contains("getAllSessionsFailed")) {
+      throw new Exception("getAllSessionsFailed")
+    }
+    val result = Map.empty[String, SessionState]
+    result += "dummySession1" -> SessionState.Ready("dummyHost1",11111,"dummyKey1")
+    result += "dummySession2" -> SessionState.Ready("dummyHost2",11112,"dummyKey2")
+    result += "dummySession3" -> SessionState.Registering("dummyKey3")
+    result += "dummySession4" -> SessionState.Inconsistent
+    result += "dummySession5" -> SessionState.NotFound
+    result += "dummySession6" -> null
+  }
+
+  override def getSession(gatewayId: String, sessionId: String): SessionState = { // state is selected by a marker substring inside sessionId
+    println(s"call getSession : gatewayId= ${gatewayId}, sessionId= ${sessionId}")
+    if (gatewayId.contains("getSessionFailed")) {
+      throw new Exception("getSessionFailed")
+    }
+    val notfoundRe = """(.*Notfound.*)""".r
+    val inconsistentRe = """(.*Inconsistent.*)""".r
+    val registeringRe = """(.*Registering.*)""".r
+    val readyRe = """(.*Ready.*)""".r
+
+    sessionId match { // first matching marker wins; there is no default case, so any other id raises MatchError
+      case notfoundRe(id) => SessionState.NotFound
+      case inconsistentRe(id) => SessionState.Inconsistent
+      case registeringRe(id) => SessionState.Registering("registeringKey")
+      case readyRe(id) => SessionState.Ready("readyHost",12345,"readyKey")
+    }
+  }
+  override def preregisterSession(gatewayId: String, sessionId: String, key: String): Unit = { // stores nothing; only logs or fails
+    println(s"call preregisterSession: gatewayId= ${gatewayId}, sessionId= ${sessionId}, key= ${key}")
+    if (gatewayId.contains("preregisterSessionFailed")) {
+      throw new Exception("preregisterSessionFailed")
+    }
+  }
+  override def registerSession(gatewayId: String, sessionId: String, host: String, port: Int, key: String): Unit = { // stores nothing; only logs or fails
+    println(s"call registerSession: gatewayId= ${gatewayId}, sessionId= ${sessionId}, host= ${host}, port= ${port}, key= ${key}")
+    if (gatewayId.contains("registerSessionFailed")) {
+      throw new Exception("registerSessionFailed")
+    }
+  }
+  override def deleteSession(gatewayId: String, sessionId: String): Unit = { // two distinct failure markers, each with its own exception message
+    println(s"call deleteSession: gatewayId= ${gatewayId}, sessionId= ${sessionId}")
+    if (gatewayId.contains("deleteSessionFailed")) {
+      throw new Exception("deleteSessionFailed")
+    } else if (gatewayId.contains("deleteSessionByIdFailed")) {
+      throw new Exception("deleteSessionByIdFailed")
+    }
+  }
+  override def lock(gatewayId: String): SessionLock = { // the returned lock wraps gatewayId itself
+    println(s"call lock: gatewayId= ${gatewayId}")
+    if (gatewayId.contains("lockFailed")) {
+      throw new Exception("lockFailed")
+    }
+    new SessionLock(gatewayId)
+  }
+  override def unlock(lock: SessionLock): Unit = { // tolerates a null lock; fails when the lock object's string form contains "unLockFailed"
+    if (lock != null ){
+      println(s"call unlock: lockObject: ${lock.lockObject}")
+    } else {
+      println("call unlock: lockObject: null")
+    }
+    if (lock != null && lock.lockObject != null && lock.lockObject.toString.contains("unLockFailed")) {
+      throw new Exception("unLockFailed")
+    }
+  }
+  override def addDeleteListener(gatewayId: String, sessionId: String, deleteFunction: Function1[String, Unit] = null): Unit = { // deleteFunction is accepted but never invoked here
+    println(s"call addDeleteListener: gatewayId= ${gatewayId}, sessionId= ${sessionId}")
+    if (gatewayId.contains("addDeleteListenerFailed")) {
+      throw new Exception("addDeleteListenerFailed")
+    }
+  }
+
+  override def registerGateway(gatewayId: String): Unit = { // stores nothing; only logs or fails
+    println(s"call registerGateway: gatewayId= ${gatewayId}")
+    if (gatewayId.contains("registerGatewayFailed")) {
+      throw new Exception("registerGatewayFailed")
+    }
+  }
+
+  override def close(): Unit = { // nothing to release; only logs the call
+    println(s"call close")
+  }
+}
\ No newline at end of file
diff --git
a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayPlanSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayPlanSpec.scala
new file mode 100644
index 0000000..6b0f88e
--- /dev/null
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayPlanSpec.scala
@@ -0,0 +1,73 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License version 2.1 as published by the Free Software Foundation.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import org.scalatest._
+import dispatch._
+import dispatch.Defaults._
+import scala.util.Success
+import unfiltered.netty.Server
+
+class GatewayPlanSpec extends FlatSpec with Matchers with HasSpark with BeforeAndAfter {
+
+  // System property the tests below set before constructing a GatewayPlan.
+  private val SubmitTimeoutProperty = "jubaql.gateway.submitTimeout"
+
+  // System properties are JVM-global: clear the key around every test so the
+  // "non exist" case never sees a leftover value and nothing leaks into other suites.
+  before { System.clearProperty(SubmitTimeoutProperty) }
+  after { System.clearProperty(SubmitTimeoutProperty) }
+
+  "startingTimeout propery non exist" should "reflect default value" in {
+    // the property is deliberately left unset here (cleared by `before`)
+    val plan = new GatewayPlan("example.com", 1234,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0,0 )
+    plan.submitTimeout shouldBe (60000)
+  }
+
+  "startingTimeout propery exist" should "reflect value" in {
+    System.setProperty(SubmitTimeoutProperty, "12345")
+    val plan = new GatewayPlan("example.com", 1234,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0,0 )
+    plan.submitTimeout shouldBe (12345)
+  }
+
+  "startingTimeout propery illegal value" should "reflect default value" in {
+    System.setProperty(SubmitTimeoutProperty, "fail")
+    val plan = new GatewayPlan("example.com", 1234,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0,0 )
+    plan.submitTimeout shouldBe (60000)
+  }
+
+  "startingTimeout propery negative value" should "reflect default value" in {
+    System.setProperty(SubmitTimeoutProperty, "-30000")
+    val plan = new GatewayPlan("example.com", 1234,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0,0 )
+    plan.submitTimeout shouldBe (60000)
+  }
+}
\ No newline at end of file
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayServer.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayServer.scala
index d1d531a..6971b1f 100644
--- a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayServer.scala
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/GatewayServer.scala
@@ -16,22 +16,64 @@
 package us.jubat.jubaql_server.gateway
 import org.scalatest.{Suite, BeforeAndAfterAll}
+import scala.sys.process.ProcessLogger
+import org.apache.curator.test.TestingServer
+import unfiltered.netty.Server
-trait GatewayServer extends BeforeAndAfterAll {
+trait GatewayServer extends BeforeAndAfterAll with HasSpark {
   this: Suite =>
+  val zkServer = new TestingServer(2181,true)
+
   protected val plan = new GatewayPlan("example.com", 1234,
     Array(), RunMode.Test,
     sparkDistribution = "",
     fatjar = "src/test/resources/processor-logfile.jar",
-    checkpointDir =
"file:///tmp/spark")
+    checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0, 0)
   protected val server = unfiltered.netty.Server.http(9877).plan(plan)
+  val replan = new GatewayPlan("example.com", 1234,
+    Array(), RunMode.Test,
+    sparkDistribution = "",
+    fatjar = "src/test/resources/processor-logfile.jar",
+    checkpointDir = "file:///tmp/spark", "localhost:9877", false, 16, 0, 0)
+  protected val reserver = unfiltered.netty.Server.http(9877).plan(replan)
+
+  // server configured with run.mode=production
+  protected val pro_plan = new GatewayPlan("example.com", 1235,
+    Array(), RunMode.Production("localhost"),
+    sparkDistribution = sparkPath,
+    fatjar = "src/test/resources/processor-logfile.jar",
+    checkpointDir = "file:///tmp/spark", "localhost:9878", false, 16, 0, 0)
+  protected val pro_server = unfiltered.netty.Server.http(9878).plan(pro_plan)
+
+  // server configured with run.mode=development
+  protected val dev_plan = new GatewayPlan("example.com", 1236,
+    Array(), RunMode.Development(1),
+    sparkDistribution = sparkPath,
+    fatjar = "src/test/resources/processor-logfile.jar",
+    checkpointDir = "file:///tmp/spark", "localhost:9879", false, 16, 0, 0)
+  protected val dev_server = unfiltered.netty.Server.http(9879).plan(dev_plan)
+
+  protected val persist_plan = new GatewayPlan("example.com", 1237,
+    Array(), RunMode.Test,
+    sparkDistribution = "",
+    fatjar = "src/test/resources/processor-logfile.jar",
+    checkpointDir = "file:///tmp/spark", "localhost:9880", true, 16, 0, 0)
+  protected val persist_server = unfiltered.netty.Server.http(9880).plan(persist_plan)
+
   override protected def beforeAll() = {
     server.start()
+    pro_server.start()
+    dev_server.start()
+    persist_server.start()
   }
   override protected def afterAll() = {
     server.stop()
+    pro_server.stop()
+    dev_server.stop()
+    persist_server.stop()
+    zkServer.close() // close() stops the embedded ZooKeeper and also deletes its temp directory; stop() would leave it behind
   }
 }
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/HasSpark.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/HasSpark.scala
new file mode 100644
index 0000000..c28f21c
--- /dev/null
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/HasSpark.scala
@@ -0,0 +1,49 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2014-2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License version 2.1 as published by the Free Software Foundation.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import java.io.{FileInputStream, FileNotFoundException}
+import java.util.Properties
+
+import org.scalatest._
+
+trait HasSpark extends ShouldMatchers { // test mixin exposing paths read from src/test/resources/spark.xml
+  lazy val sparkPath: String = { // value of the "spark_home_path" entry (null if the entry is absent)
+    val properties = loadProperties()
+    properties.getProperty("spark_home_path")
+  }
+
+  lazy val dummySparkPath: String = { // value of the "dummy_spark_home_path" entry (null if the entry is absent)
+    val properties = loadProperties()
+    properties.getProperty("dummy_spark_home_path")
+  }
+
+  private def loadProperties(): Properties = { // re-reads the XML file on every call (once per lazy val above)
+    val sparkConfig = "src/test/resources/spark.xml"
+
+    val is = try {
+      Some(new FileInputStream(sparkConfig))
+    } catch {
+      case _: FileNotFoundException =>
+        None
+    }
+    is shouldBe a[Some[_]] // a missing config file surfaces as a matcher failure rather than a raw exception
+
+    val properties = new Properties()
+    properties.loadFromXML(is.get)
+    properties
+  }
+}
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLGatewaySpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLGatewaySpec.scala
new file mode 100644
index 0000000..7bf54d4
--- /dev/null
+++
b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLGatewaySpec.scala
@@ -0,0 +1,140 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License version 2.1 as published by the Free Software Foundation.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import org.scalatest._
+
+class JubaQLGatewaySpec extends FlatSpec with Matchers {
+
+  "Parameter gatewyId" should "return gatewayid" in {
+    // option omitted: the gatewayID taken from the parameters is empty; it is generated later in the form [host_port]
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost"))
+    result should not be None // assert on the Option itself; `result.get` can never equal None
+    result.get.gatewayId shouldBe ("")
+    // gatewayID given via the single-letter option
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "-g", "gatewayid"))
+    result should not be None
+    result.get.gatewayId shouldBe ("gatewayid")
+    // gatewayID given via the long option
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--gatewayID", "gatewayid2"))
+    result should not be None
+    result.get.gatewayId shouldBe ("gatewayid2")
+  }
+  "Parameter persist" should "return persist" in {
+    // option omitted
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost"))
+    result should not be None
+    result.get.persist shouldBe (false)
+    // option specified
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--persist"))
+    result should not be None
+    result.get.persist shouldBe (true)
+  }
+  "Illegal parameter" should "stderr usage" in {
+    // usage check for an invalid option
+    val result = JubaQLGateway.parseCommandlineOption(Array("-i"))
+    result shouldBe (None)
+    // the usage output itself is checked by eye
+  }
+
+  "Parameter threads" should "return threads" in {
+    // option omitted
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost"))
+    result should not be None
+    result.get.threads shouldBe (16)
+    // option specified
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", "32"))
+    result should not be None
+    result.get.threads shouldBe (32)
+    // option specified: minimum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", "1"))
+    result should not be None
+    result.get.threads shouldBe (1)
+    // option specified: maximum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", s"${Int.MaxValue}"))
+    result should not be None
+    result.get.threads shouldBe (Int.MaxValue)
+  }
+  "Illegal parameter threads" should "out of range" in {
+    // option specified: out of range
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", "0"))
+    result shouldBe (None)
+    // option specified: out of range (widen to Long first; Int.MaxValue + 1 would wrap around to Int.MinValue)
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", s"${Int.MaxValue.toLong + 1}"))
+    result shouldBe (None)
+    // option specified: no value given
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--threads", ""))
+    result shouldBe (None)
+  }
+  "Parameter channel_memory" should "return channelMemory" in {
+    // option omitted
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost"))
+    result should not be None
+    result.get.channelMemory shouldBe (65536)
+    // option specified
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", "256"))
+    result should not be None
+    result.get.channelMemory shouldBe (256)
+    // option specified: minimum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", "0"))
+    result should not be None
+    result.get.channelMemory shouldBe (0)
+    // option specified: maximum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", s"${Long.MaxValue}"))
+    result should not be None
+    result.get.channelMemory shouldBe (Long.MaxValue)
+  }
+  "Illegal parameter channel_memory" should "out of range" in {
+    // option specified: out of range
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", "-1"))
+    result shouldBe (None)
+    // option specified: out of range (BigInt; Long.MaxValue + 1 would wrap around to Long.MinValue)
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", s"${BigInt(Long.MaxValue) + 1}"))
+    result shouldBe (None)
+    // option specified: no value given
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--channel_memory", ""))
+    result shouldBe (None)
+  }
+  "Parameter total_memory" should "return total_memory" in {
+    // option omitted
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost"))
+    result should not be None
+    result.get.totalMemory shouldBe (1048576)
+    // option specified
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", "256"))
+    result should not be None
+    result.get.totalMemory shouldBe (256)
+    // option specified: minimum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", "0"))
+    result should not be None
+    result.get.totalMemory shouldBe (0)
+    // option specified: maximum value
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", s"${Long.MaxValue}"))
+    result should not be None
+    result.get.totalMemory shouldBe (Long.MaxValue)
+  }
+  "Illegal parameter total_memory" should "out of range" in {
+    // option specified: out of range
+    var result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", "-1"))
+    result shouldBe (None)
+    // option specified: out of range (BigInt; Long.MaxValue + 1 would wrap around to Long.MinValue)
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", s"${BigInt(Long.MaxValue) + 1}"))
+    result shouldBe (None)
+    // option specified: no value given
+    result = JubaQLGateway.parseCommandlineOption(Array("-i", "localhost", "--total_memory", ""))
+    result shouldBe (None)
+  }
+}
\ No newline at end of file
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLSpec.scala
index 789ab94..c3d8e78 100644
--- a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLSpec.scala
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/JubaQLSpec.scala
@@ -15,23 +15,31 @@
 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 package us.jubat.jubaql_server.gateway
+import us.jubat.jubaql_server.gateway.json.SessionId
 import us.jubat.jubaql_server.gateway.json.Query
 import org.scalatest._
 import EitherValues._
 import dispatch._
 import dispatch.Defaults._
-import org.json4s.DefaultFormats
-import org.json4s.native.Serialization.write
+import org.json4s._
+import org.json4s.Formats._
+import org.json4s.native.Serialization.{read, write}
+import org.json4s.native.JsonMethods._
+import org.json4s.JsonDSL._
 // We use mock processor in this test, so queries are dummies.
class JubaQLSpec extends FlatSpec with Matchers with ProcessorAndGatewayServer {
   val jubaQlUrl = :/("localhost", 9877) / "jubaql"
+  val persist_jubaQlUrl = :/("localhost", 9880) / "jubaql"
+  val register_persist_jubaQlUrl = :/("localhost", 9880) / "registration"
+  val persist_loginUrl = :/("localhost", 9880) / "login"
+
   implicit val formats = DefaultFormats
-  def requestAsJson() = {
-    val request = (jubaQlUrl).POST
+  def requestAsJson(url : Req = jubaQlUrl) = {
+    val request = (url).POST
     request.setContentType("application/json", "UTF-8")
   }
@@ -56,6 +64,17 @@ class JubaQLSpec extends FlatSpec with Matchers with ProcessorAndGatewayServer {
     result.right.value.getContentType should include("charset=utf-8")
   }
+  "Posting jubaql with inconsistent data" should "fail" in {
+    plan.sessionManager.session2key += ("sessionId" -> null)
+    plan.sessionManager.key2session += ("sessionKey" -> "persistSessionId")
+    plan.sessionManager.session2loc += ("sessionId" -> ("localhost", 9876))
+
+    val request = requestAsJson(jubaQlUrl) << write(Query("sessionId", "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    result.right.value.getStatusCode shouldBe 500
+    result.right.value.getContentType should include("charset=utf-8")
+  }
+
   "Posting jubaql without query" should "fail" in {
     val request = requestAsJson() << f"""{"session_id": "$session%s"}"""
     val result = Http(request > (x => x)).either.apply()
@@ -76,4 +95,105 @@ class JubaQLSpec extends FlatSpec with Matchers with ProcessorAndGatewayServer {
     result.right.value.getStatusCode shouldBe 200
     result.right.value.getContentType should include("charset=utf-8")
   }
+
+  // persist
+  // no session id
+  "Posting jubaql with unknown session id in persist" should "fail" in {
+    println(session)
+    val request = requestAsJson(persist_jubaQlUrl) << write(Query(session, "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    result.right.value.getStatusCode shouldBe 401
+    result.right.value.getContentType should include("charset=utf-8")
+  }
+
+  // register yet
+  "Posting jubaql with register yet in persist" should "fail" in {
+    val connectedSessionId = {
+      val req = Http(persist_loginUrl.POST OK as.String)
+      req.option.apply.get
+    }
+    println("first connect : " + connectedSessionId)
+    // Register
+    val json = parseOpt(connectedSessionId)
+    val sessionId = json.get.extractOpt[SessionId].get.session_id
+
+    val request = requestAsJson(persist_jubaQlUrl) << write(Query(sessionId, "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    result.right.value.getStatusCode shouldBe 503
+    result.right.value.getContentType should include("charset=utf-8")
+  }
+
+  // the session is served from the cache
+  "Posting jubaql with exist session cache in persist" should "succeed" in {
+    persist_plan.sessionManager.session2key += ("persistSessionId" -> "persistSessionKey")
+    persist_plan.sessionManager.session2loc += ("persistSessionId" -> ("localhost", 9876))
+    val request = requestAsJson(persist_jubaQlUrl) << write(Query("persistSessionId", "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    result.right.value.getStatusCode shouldBe 200
+    result.right.value.getContentType should include("charset=utf-8")
+  }
+
+  // the session exists in ZooKeeper (but not in the cache)
+  "Posting jubaql with exist session zookeeper in persist" should "succeed" in {
+    val connectedSessionId = {
+      val req = Http(persist_loginUrl.POST OK as.String)
+      req.option.apply.get
+    }
+    println("first connect : " + connectedSessionId)
+    //Register
+    val json = parseOpt(connectedSessionId)
+    val sessionId = json.get.extractOpt[SessionId].get.session_id
+    val key = persist_plan.sessionManager.session2key.get(sessionId).get
+    val registJsonString = """{ "action": "register", "ip": "localhost", "port": 9876 }"""
+    Http(register_persist_jubaQlUrl./(key).POST << registJsonString).either.apply() match {
+      case Right(resultJson) =>
+        println("register success: " + resultJson.getResponseBody)
+      case Left(t) =>
+        t.printStackTrace()
+        fail(t)
+    }
+    // drop the cached entries
+    persist_plan.sessionManager.session2key -= sessionId
+    persist_plan.sessionManager.key2session -= key
+    persist_plan.sessionManager.session2loc -= sessionId
+
+    val request = requestAsJson(persist_jubaQlUrl) << write(Query(sessionId, "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    result.right.value.getStatusCode shouldBe 200
+    result.right.value.getContentType should include("charset=utf-8")
+  }
+
+  // the connection to ZooKeeper fails
+  "Posting jubaql with failed connect zookeeper in persist" should "fail" in {
+    val connectedSessionId = {
+      val req = Http(persist_loginUrl.POST OK as.String)
+      req.option.apply.get
+    }
+    println("first connect : " + connectedSessionId)
+    // Register
+    val json = parseOpt(connectedSessionId)
+    val sessionId = json.get.extractOpt[SessionId].get.session_id
+    val key = persist_plan.sessionManager.session2key.get(sessionId).get
+    val registJsonString = """{ "action": "register", "ip": "localhost", "port": 9876 }"""
+    Http(register_persist_jubaQlUrl./(key).POST << registJsonString).either.apply() match {
+      case Right(resultJson) =>
+        println("register success: " + resultJson.getResponseBody)
+      case Left(t) =>
+        t.printStackTrace()
+        fail(t)
+    }
+    // drop the cached entries for the purpose of this test
+    persist_plan.sessionManager.session2key -= sessionId
+    persist_plan.sessionManager.key2session -= key
+    persist_plan.sessionManager.session2loc -= sessionId
+
+    // stop ZooKeeper
+    zkServer.stop()
+
+    val request = requestAsJson(persist_jubaQlUrl) << write(Query(sessionId, "query")).toString
+    val result = Http(request > (x => x)).either.apply()
+    println(result.right.value.getResponseBody)
+    result.right.value.getStatusCode shouldBe 500
+    result.right.value.getContentType should include("charset=utf-8")
+  }
 }
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/LoginSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/LoginSpec.scala
index b804266..490e574 100644
---
a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/LoginSpec.scala
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/LoginSpec.scala
@@ -23,12 +23,28 @@
 import org.json4s._
 import org.json4s.Formats._
 import org.json4s.native.Serialization.{read, write}
 import org.json4s.native.JsonMethods._
+import org.json4s.JsonDSL._
+import scala.concurrent.duration.Duration
+import scala.concurrent.Await
+import scala.async.Async.{async, await}
+import scala.concurrent.ExecutionContext
+import unfiltered.netty.Server
+import scala.util.Success
-class LoginSpec extends FlatSpec with Matchers with GatewayServer {
+class LoginSpec extends FlatSpec with Matchers with GatewayServer with BeforeAndAfter {
   implicit val formats = DefaultFormats
+  before {
+    zkServer.restart
+  }
+
   val url = :/("localhost", 9877) / "login"
+  val regist_url = :/("localhost", 9877) / "registration"
+  val pro_url = :/("localhost", 9878) / "login"
+  val dev_url = :/("localhost", 9879) / "login"
+  val persist_url = :/("localhost", 9880) / "login"
+  val persist_regist_url = :/("localhost", 9880) / "registration"
   "POST to /login" should "return something" in {
     val req = Http(url.POST OK as.String)
@@ -52,5 +68,553 @@ class LoginSpec extends FlatSpec with Matchers with GatewayServer {
     val maybeSessionId = maybeJson.get.extractOpt[SessionId]
     maybeSessionId should not be None
     maybeSessionId.get.session_id.length should be > 0
+
+  }
+
+  "POST to /login reconnect" should "not exist sessionId. return Error Response(unknown sessioId)" in {
+    //
+    val sessionId = "testSessionId"
+    val payloadData = ("session_id" -> sessionId)
+    val json: String = compact(render(payloadData))
+
+    val req = Http(url.POST << json)// OK as.String)
+    req.either.apply() match {
+      case Right(response) =>
+        response.getStatusCode shouldEqual(401)
+        response.getResponseBody shouldBe("Unknown session_id")
+      case Left(t) =>
+        t.printStackTrace()
+        fail(t)
+    }
+  }
+
+  "POST to /login reconnect" should "session inconsistency. return Error Response(inconsist data)" in {
+    // first connection
+    val connectedSessionId = connection(url)
+
+    println("first connect : " + connectedSessionId)
+    // rewrite the session data into an inconsistent state
+    plan.sessionManager.session2key -= connectedSessionId
+    plan.sessionManager.session2key += connectedSessionId -> null
+    plan.sessionManager.session2loc += connectedSessionId -> ("dummyHost", 9999)
+
+    // connect using the already-connected sessionId
+    val payloadData = ("session_id" -> connectedSessionId)
+    val json: String = compact(render(payloadData))
+
+    val req = Http(url.POST << json)
+    req.either.apply() match {
+      case Right(response) =>
+        println(response.getResponseBody)
+        response.getStatusCode shouldEqual(500)
+        response.getResponseBody shouldBe("Failed to get session")
+      case Left(t) =>
+        t.printStackTrace()
+        fail(t)
+    }
+  }
+
+  "POST to /login reconnect" should "exist sessionId, but registered yet. return Error Response(registered yet)" in {
+    // first connection
+    val connectedSessionId = connection(url)
+
+    println("first connect : " + connectedSessionId)
+    // connect using the already-connected sessionId
+    val payloadData = ("session_id" -> connectedSessionId)
+    val json: String = compact(render(payloadData))
+
+    val req = Http(url.POST << json)
+    req.either.apply() match {
+      case Right(response) =>
+        response.getStatusCode shouldEqual(503)
+        response.getResponseBody shouldBe("This session has not been registered. Wait a second.")
+      case Left(t) =>
+        t.printStackTrace()
+        fail(t)
+    }
+  }
+
+  // run in RunMode.Test (no persistence)
+  "POST to /login reconnect" should "success reconnect" in {
+    // first connection
+    val connectedSessionId = connection(url)
+
+    println("first connect : " + connectedSessionId)
+
+    // simulated registration
+    registration(regist_url,plan,connectedSessionId)
+
+    // connect using the already-connected sessionId
+    reconnection(url,connectedSessionId)
+  }
+
+  // reconnect without persistence (connection-failure case)
+  "POST to /login reconnect" should "reboot gateaway" in {
+    // first connection
+    val connectedSessionId = connection(url)
+
+    println("first connect : " + connectedSessionId)
+
+    // simulated registration
+    registration(regist_url,plan,connectedSessionId)
+
+    // restart the gateway
+    server.stop()
+    reserver.start()
+
+    // connect using the already-connected sessionId
+    val payloadData = ("session_id" -> connectedSessionId)
+    val json: String = compact(render(payloadData))
+
+    val req = Http(url.POST << json)
+
+    // sessions are not persisted, so an error response is returned
+    req.option.apply.get.getStatusCode shouldBe(401)
+    req.option.apply.get.getResponseBody shouldBe("Unknown session_id")
+    reserver.stop()
+  }
+
+  // verify persistence
+  "reconnect in persist" should "save session to zookeeper" in {
+    // first connection
+    val connectedSessionId = connection(persist_url)
+
+    println("first connect : " + connectedSessionId)
+
+    // simulated registration
+    registration(persist_regist_url, persist_plan, connectedSessionId)
+
+    // connect using the already-connected sessionId
+    reconnection(persist_url, connectedSessionId)
+
+    // verify the session data was stored in the cache
+    val cacheKey = persist_plan.sessionManager.session2key.get(connectedSessionId)
+    val cacheSessionId = persist_plan.sessionManager.key2session.get(cacheKey.get)
+    val cacheLocation = persist_plan.sessionManager.session2loc.get(connectedSessionId)
+    cacheKey.get.length() should be > 0
+    cacheSessionId.get shouldBe (connectedSessionId)
+    cacheLocation.get shouldBe (("dummyhost", 1111))
+
+    // verify the session data was stored in ZooKeeper
+    val storeSession: SessionState = persist_plan.sessionManager.getSessionFromStore(connectedSessionId)
+    storeSession shouldBe SessionState.Ready("dummyhost", 1111, cacheKey.get)
+    val completeSession = storeSession.asInstanceOf[SessionState.Ready]
+    completeSession.host shouldBe "dummyhost"
+    completeSession.port shouldBe 1111
+    completeSession.key shouldBe cacheKey.get
+  }
+
+  // session connection across a gateway server restart
+  "reconnect in persist" should "reboot gateway" in {
+    // first connection
+    val firstplan = new GatewayPlan("example.com", 1880,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:8880", true, 16, 0, 0)
+    val firstserver = unfiltered.netty.Server.http(8880).plan(firstplan)
+    firstserver.start()
+
+    val url1 = :/("localhost", 8880) / "login"
+    val regist_url1 = :/("localhost", 8880) / "registration"
+    val connectedSessionId = connection(url1)
+
+    println("first connect : " + connectedSessionId)
+
+    // simulated registration
+    registration(regist_url1, firstplan, connectedSessionId)
+
+    // stop the gateway
+    firstserver.stop()
+
+    // restart the gateway
+    val secondplan = new GatewayPlan("example.com", 1880,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", "localhost:8880", true, 16, 0, 0)
+    val secondserver = unfiltered.netty.Server.http(8880).plan(secondplan)
+    secondserver.start()
+
+    // connect using the already-connected sessionId
+    reconnection(url1, connectedSessionId)
+
+    // verify the session data was stored in the cache
+    val cacheKey = secondplan.sessionManager.session2key.get(connectedSessionId)
+    val cacheSessionId = secondplan.sessionManager.key2session.get(cacheKey.get)
+    val cacheLocation = secondplan.sessionManager.session2loc.get(connectedSessionId)
+    cacheKey.get.length() should be > 0
+    cacheSessionId.get shouldBe (connectedSessionId)
+    cacheLocation.get shouldBe (("dummyhost", 1111))
+
+    // verify the session data was stored in ZooKeeper
+    val storeSession: SessionState = secondplan.sessionManager.getSessionFromStore(connectedSessionId)
+
storeSession.isInstanceOf[SessionState.Ready].shouldBe(true)
+    val completeSession = storeSession.asInstanceOf[SessionState.Ready]
+    completeSession.host shouldBe "dummyhost"
+    completeSession.port shouldBe 1111
+    completeSession.key shouldBe cacheKey.get
+    secondserver.stop()
+    firstplan.close()
+    secondplan.close()
+  }
+
+  // Constructing a persistent GatewayPlan while ZooKeeper is down must fail.
+  "zookeeper connection failed" should "throw Exception" in {
+    zkServer.stop()
+    // intercept fails the test by itself when nothing is thrown; the previous
+    // try/fail()/catch form caught the exception raised by fail() as well and
+    // reported a misleading message mismatch instead.
+    val thrown = intercept[Exception] {
+      new GatewayPlan("example.com", 1237,
+        Array(), RunMode.Test,
+        sparkDistribution = "",
+        fatjar = "src/test/resources/processor-logfile.jar",
+        checkpointDir = "file:///tmp/spark", "localhost:9877", true, 16, 0, 0)
+    }
+    thrown.getMessage().shouldBe("failed to connected zookeeper")
+  }
+
+  // A first login against the persistent gateway is rejected when ZooKeeper is down.
+  "POST to /login stop zookeeper" should "failed login" in {
+    zkServer.stop()
+    val req = Http(persist_url.POST)
+    req.option.apply.get.getStatusCode.shouldBe(500)
+    req.option.apply.get.getResponseBody.shouldBe("Failed to create session")
+  }
+
+  // A reconnect against the persistent gateway is rejected when ZooKeeper is down.
+  "POST to /login stop zookeeper(reconnect)" should "failed login" in {
+    zkServer.stop()
+    val payloadData = ("session_id" -> "testSessionId")
+    val json: String = compact(render(payloadData))
+
+    val req = Http(persist_url.POST << json)
+    req.option.apply.get.getStatusCode.shouldBe(500)
+    req.option.apply.get.getResponseBody.shouldBe("Failed to get session")
+  }
+
+  // Two gateways with the same gateway id: a session created on one node
+  // must be usable through the other node.
+  "POST to /login clustering" should "succeed" in {
+    val url1 = :/("localhost", 9890) / "login"
+    val regist_url1 = :/("localhost", 9890) / "registration"
+    val url2 = :/("localhost", 9891) / "login"
+    val regist_url2 = :/("localhost", 9891) / "registration"
+    val clustername = "gwCluster"
+
+    val node1_plan = new GatewayPlan("example.com", 1290,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", clustername, true, 16, 0, 0)
+    val server1 = unfiltered.netty.Server.http(9890).plan(node1_plan)
+    val node2_plan = new GatewayPlan("example.com", 1291,
+      Array(), RunMode.Test,
+      sparkDistribution = "",
+      fatjar = "src/test/resources/processor-logfile.jar",
+      checkpointDir = "file:///tmp/spark", clustername, true, 16, 0, 0)
+    // The second server has to be backed by its own plan. Serving node1_plan
+    // on both ports made the two "nodes" share one in-memory session cache,
+    // so the cross-node assertions below could never fail.
+    val server2 = unfiltered.netty.Server.http(9891).plan(node2_plan)
+    server1.start()
+    server2.start()
+
+    try {
+      val connectedSessionId = connection(url1)
+      registration(regist_url1, node1_plan, connectedSessionId)
+      reconnection(url2, connectedSessionId)
+      val registeringId = connection(url2)
+      reconnectionNJ(url1, registeringId) shouldBe "This session has not been registered. Wait a second."
+    } finally {
+      // release the ports and the plans even when an assertion fails
+      server1.stop()
+      server2.stop()
+      node1_plan.close()
+      node2_plan.close()
+    }
+  }
+
+  "spark-submit command succeed in developmentMode" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Development(), dummySparkPath)
+    try {
+      val req = Http(url.POST OK as.String)
+      req.option.apply() should not be None
+    } finally {
+      // port 1877 is reused by the following tests, so always release it
+      server.stop()
+    }
+  }
+
+  "spark-submit command failed(fail dummySparkPath) in developmentMode" should "fail" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Development(), "failpath")
+    val req = Http(url.POST)
+    req.either.apply() match {
+      case Right(res) =>
+        server.stop()
+        res.getStatusCode shouldBe(500)
+        res.getResponseBody should startWith("Failed to start Spark")
+      case Left(t) =>
+        t.printStackTrace()
+        server.stop()
+        fail()
+    }
+  }
+
+  "spark-submit command succeed in productionMode" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath)
+    val req = Http(url.POST)
+    req.either.apply() match {
+      case Right(res) =>
+        server.stop()
+        res.getStatusCode shouldBe(200)
+        res.getResponseBody should not be None
+      case Left(t) =>
+        t.printStackTrace()
+        server.stop()
+        fail()
+    }
+
+  }
+
+  "spark-submit command failed(fail sparkpath) in productionMode" should "fail" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) =
startServer(RunMode.Production("localhost:2181"), "failpath")
+    val req = Http(url.POST)
+    req.either.apply() match {
+      case Right(res) =>
+        server.stop()
+        res.getStatusCode shouldBe(500)
+        res.getResponseBody should startWith("Failed to start Spark")
+      case Left(t) =>
+        t.printStackTrace()
+        // stop the server here as well; otherwise port 1877 stays bound and
+        // the tests that follow cannot start their own server
+        server.stop()
+        fail()
+    }
+  }
+
+  "spark-submit command failed(timeout) in productionMode" should "fail" in {
+    val url = :/("localhost", 1877) / "login"
+    System.setProperty("jubaql.gateway.submitTimeout", "10000")
+    try {
+      val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath, "file:///timeout")
+      val req = Http(url.POST)
+      req.either.apply() match {
+        case Right(res) =>
+          server.stop()
+          res.getStatusCode shouldBe(500)
+          res.getResponseBody should startWith("Failed to start Spark")
+        case Left(t) =>
+          t.printStackTrace()
+          server.stop()
+          fail()
+      }
+    } finally {
+      // the property is JVM-wide; clear it even when an assertion fails so it
+      // cannot leak into later tests
+      System.clearProperty("jubaql.gateway.submitTimeout")
+    }
+  }
+
+  "spark-submit command failed(return code: xx) in productionMode" should "fail" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath, "file:///return")
+    val req = Http(url.POST)
+    req.either.apply() match {
+      case Right(res) =>
+        server.stop()
+        res.getStatusCode shouldBe(500)
+        res.getResponseBody should startWith("Failed to start Spark")
+      case Left(t) =>
+        t.printStackTrace()
+        server.stop()
+        fail()
+    }
+  }
+
+  // Login succeeds, then the test waits 10 seconds and expects the session to
+  // be gone from the session manager.
+  "spark-submit command failed(after runnnig returun code:0) in productionMode" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath, "file:///after")
+    try {
+      val req = Http(url.POST)
+      val sessionId = req.either.apply() match {
+        case Right(res) =>
+          res.getStatusCode shouldBe(200)
+          val json = parseOpt(res.getResponseBody)
+          json.get.extractOpt[SessionId].get.session_id
+        case Left(t) =>
+          t.printStackTrace()
+          fail()
+      }
+      Thread.sleep(10000)
+      val state = plan.sessionManager.getSession(sessionId)
+      state.isSuccess shouldBe true
+      state.get shouldBe(SessionState.NotFound)
+    } finally {
+      // single stop on every path, including failed assertions
+      server.stop()
+    }
+  }
+
+  // Same scenario as above with the "file:///afterFailed" checkpoint directory.
+  "spark-submit command failed(after runnnig returun code:10) in productionMode" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath, "file:///afterFailed")
+    try {
+      val req = Http(url.POST)
+      val sessionId = req.either.apply() match {
+        case Right(res) =>
+          res.getStatusCode shouldBe(200)
+          val json = parseOpt(res.getResponseBody)
+          json.get.extractOpt[SessionId].get.session_id
+        case Left(t) =>
+          t.printStackTrace()
+          fail()
+      }
+      Thread.sleep(10000)
+      val state = plan.sessionManager.getSession(sessionId)
+      state.isSuccess shouldBe true
+      state.get shouldBe(SessionState.NotFound)
+    } finally {
+      server.stop()
+    }
+  }
+
+  // Same scenario as above with the "file:///exception" checkpoint directory.
+  "spark-submit command failed(after runnnig throw Exception) in productionMode" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val (server, plan) = startServer(RunMode.Production("localhost:2181"), dummySparkPath, "file:///exception")
+    try {
+      val req = Http(url.POST)
+      val sessionId = req.either.apply() match {
+        case Right(res) =>
+          res.getStatusCode shouldBe(200)
+          val json = parseOpt(res.getResponseBody)
+          json.get.extractOpt[SessionId].get.session_id
+        case Left(t) =>
+          t.printStackTrace()
+          fail()
+      }
+      Thread.sleep(10000)
+      val state = plan.sessionManager.getSession(sessionId)
+      state.isSuccess shouldBe true
+      state.get shouldBe(SessionState.NotFound)
+    } finally {
+      server.stop()
+    }
+  }
+
+  "resource for gateway(non DriverMemory/ExecutorMemory)" should "succeed" in {
+    val url = :/("localhost", 1877) / "login"
+    val plan = new GatewayPlan("localhost", 1877,
+      Array(), RunMode.Production("localhost:2181", sparkDriverMemory=None, sparkExecutorMemory=None),
+      sparkDistribution = dummySparkPath,
+      fatjar = "src/test/resources/processor-logfile.jar",
+      "file:///tmp/spark", "localhost:1877",
false, 16, 0, 0) + val server = unfiltered.netty.Server.http(1877).plan(plan) + server.start() + + val req = Http(url.POST) + req.either.apply() match { + case Right(res) => + server.stop() + res.getStatusCode shouldBe(200) + + plan.underlying.getMaxChannelMemorySize() shouldBe 0 + plan.underlying.getMaxTotalMemorySize() shouldBe 0 + plan.underlying.getCorePoolSize shouldBe 16 + + // sparkDriverMemory, sparkExecutorMemory は目視確認 + // executing: --driver-memoryなし, --executor-memoryなし + + case Left(t) => + t.printStackTrace() + server.stop() + fail() + } + } + + "resource for gateway(specified DriverMemory/ExecutorMemory)" should "succeed" in { + val url = :/("localhost", 1877) / "login" + val plan = new GatewayPlan("localhost", 1877, + Array(), RunMode.Production("localhost:2181", sparkDriverMemory=Some("256M"), sparkExecutorMemory=Some("2G")), + sparkDistribution = dummySparkPath, + fatjar = "src/test/resources/processor-logfile.jar", + "file:///tmp/spark", "localhost:1877", false, 32, 65536, 1048576) + val server = unfiltered.netty.Server.http(1877).plan(plan) + server.start() + + val req = Http(url.POST) + req.either.apply() match { + case Right(res) => + server.stop() + res.getStatusCode shouldBe(200) + res.getResponseBody should not be None + println(s"${res.getResponseBody}") + + plan.underlying.getMaxChannelMemorySize() shouldBe 65536 + plan.underlying.getMaxTotalMemorySize() shouldBe 1048576 + plan.underlying.getCorePoolSize shouldBe 32 + + // sparkDriverMemory, sparkExecutorMemory は目視確認 + // executing: --driver-memory 256M, --executor-memory 2G + + case Left(t) => + t.printStackTrace() + server.stop() + fail() + } + } + +// --------------------- + private def connection(url: Req): String = { + val req = Http(url.POST OK as.String) + req.option.apply() should not be None + val returnedString = req.option.apply.get + val maybeJson = parseOpt(returnedString) + maybeJson should not be None + val maybeSessionId = maybeJson.get.extractOpt[SessionId] + 
maybeSessionId should not be None + //セッションIDの返却チェック + maybeSessionId.get.session_id.length should be > 0 + maybeSessionId.get.session_id + } + + private def connectionNJ(url: Req): String = { + val req = Http(url.POST OK as.String) + req.option.apply() should not be None + val returnedString = req.option.apply.get + val maybeJson = parseOpt(returnedString) + if (maybeJson == None) { + "" + } else { + val maybeSessionId = maybeJson.get.extractOpt[SessionId] + if (maybeSessionId == None) { + "" + } else { + maybeSessionId.get.session_id + } + } + } + + private def reconnectionNJ(url: Req, sessionId: String):String = { + val payloadData = ("session_id" -> sessionId) + val json: String = compact(render(payloadData)) + val req = Http(url.POST << json) + req.option.apply match { + case Some(res) => + println(res.getResponseBody) + res.getResponseBody + case None => + "" + } + } + + private def reconnection(url:Req, sessionId: String) = { + val payloadData = ("session_id" -> sessionId) + val json: String = compact(render(payloadData)) + + val req = Http(url.POST << json) + + val returnedString = req.option.apply.get.getResponseBody + val maybeJson = parseOpt(returnedString) + maybeJson should not be None + val maybeSessionId = maybeJson.get.extractOpt[SessionId] + maybeSessionId should not be None + // セッションIDの返却チェック + maybeSessionId.get.session_id shouldBe (sessionId) + } + + private def registration(url: Req, plan: GatewayPlan, sessionId: String): Unit = { + val key = plan.sessionManager.session2key(sessionId) + val registJsonString = """{ "action": "register", "ip": "dummyhost", "port": 1111 }""" + Http(url./(key).POST << registJsonString).either.apply() match { + case Right(resultJson) => + println(resultJson.getResponseBody) + case Left(t) => + t.printStackTrace() + fail(t) + } + } + + private def startServer(runMode:RunMode, sparkPath:String, checkpointDir: String = "file:///tmp/spark"): (Server, GatewayPlan) = { + val plan = new GatewayPlan("localhost", 1877, + 
Array(), runMode, + sparkDistribution = sparkPath, + fatjar = "src/test/resources/processor-logfile.jar", + checkpointDir, "localhost:1877", false, 16, 0, 0) + val server = unfiltered.netty.Server.http(1877).plan(plan) + server.start() + (server, plan) } } diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ProcessorAndGatewayServer.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ProcessorAndGatewayServer.scala index 50d4f59..6019127 100644 --- a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ProcessorAndGatewayServer.scala +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ProcessorAndGatewayServer.scala @@ -39,9 +39,9 @@ trait ProcessorAndGatewayServer extends GatewayServer { ) override def beforeAll(): Unit = { - plan.session2key += (session -> key_) - plan.key2session += (key_ -> session) - plan.session2loc += (session -> loc) + plan.sessionManager.session2key += (session -> key_) + plan.sessionManager.key2session += (key_ -> session) + plan.sessionManager.session2loc += (session -> loc) super.beforeAll() processorMock.start() diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/RegisterSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/RegisterSpec.scala index c2cb8f7..eb168ff 100644 --- a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/RegisterSpec.scala +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/RegisterSpec.scala @@ -24,19 +24,26 @@ import org.json4s.Formats._ import org.json4s.native.Serialization.{read, write} import org.json4s.native.JsonMethods._ import EitherValues._ +import us.jubat.jubaql_server.gateway.json.SessionId -class RegisterSpec extends FlatSpec with Matchers with GatewayServer { +class RegisterSpec extends FlatSpec with Matchers with GatewayServer with BeforeAndAfter { implicit val formats = DefaultFormats val loginUrl = :/("localhost", 9877) / "login" val registrationUrl = :/("localhost", 9877) / "registration" + val persist_loginUrl = 
:/("localhost", 9880) / "login" + val persist_registrationUrl = :/("localhost", 9880) / "registration" - def requestAsJson(key: String) = { - val request = (registrationUrl / key).POST + def requestAsJson(key: String, url: Req = registrationUrl) = { + val request = (url / key).POST request.setContentType("application/json", "UTF-8") } + before { + zkServer.restart + } + override def beforeAll(): Unit = { super.beforeAll() // login twice @@ -49,8 +56,8 @@ class RegisterSpec extends FlatSpec with Matchers with GatewayServer { def keys = { var keys = List.empty[String] // This lock is required because the server is in the same process. - plan.session2key.synchronized { - for (key <- plan.key2session.keys) + plan.sessionManager.session2key.synchronized { + for (key <- plan.sessionManager.key2session.keys) keys = key :: keys } keys @@ -78,7 +85,7 @@ class RegisterSpec extends FlatSpec with Matchers with GatewayServer { val registerJson = write(Register("register", "8.8.8.8", 30)) val requestWithBody = requestAsJson(garbageKey) << registerJson.toString val result = Http(requestWithBody OK as.String).either.apply() - result.left.value.getMessage shouldBe "Unexpected response status: 401" + result.left.value.getMessage shouldBe "Unexpected response status: 500" } } @@ -87,10 +94,10 @@ class RegisterSpec extends FlatSpec with Matchers with GatewayServer { val registerJson = write(Register("register", "8.8.8.8", 30)) val requestWithBody = requestAsJson(key) << registerJson.toString Http(requestWithBody OK as.String).option.apply() should not be None - plan.session2key.synchronized { - val session = plan.key2session.get(key) + plan.sessionManager.session2key.synchronized { + val session = plan.sessionManager.key2session.get(key) session should not be None - val maybeLoc = plan.session2loc.get(session.get) + val maybeLoc = plan.sessionManager.session2loc.get(session.get) maybeLoc should not be None val loc = maybeLoc.get loc shouldBe ("8.8.8.8", 30) @@ -106,10 +113,10 @@ class 
RegisterSpec extends FlatSpec with Matchers with GatewayServer { val registerJson = write(Register("register", ip, port)) val requestWithBody = requestAsJson(key) << registerJson.toString Http(requestWithBody OK as.String).option.apply() should not be None - plan.session2key.synchronized { - val session = plan.key2session.get(key) + plan.sessionManager.session2key.synchronized { + val session = plan.sessionManager.key2session.get(key) session should not be None - val maybeLoc = plan.session2loc.get(session.get) + val maybeLoc = plan.sessionManager.session2loc.get(session.get) maybeLoc should not be None val loc = maybeLoc.get loc shouldBe (ip, port) @@ -117,4 +124,94 @@ class RegisterSpec extends FlatSpec with Matchers with GatewayServer { } } } + +// persist + "Registering an existing key in persist" should "succeed" in { + List(0,1).par.foreach(key => { + val req = Http(persist_loginUrl.POST OK as.String) + req.option.apply() + }) + + var keys = List.empty[String] + // This lock is required because the server is in the same process. 
+ persist_plan.sessionManager.session2key.synchronized { + for (key <- persist_plan.sessionManager.key2session.keys) + keys = key :: keys + } + keys.par.foreach(key => { + val registerJson = write(Register("register", s"8.8.8.8", 30)) + val requestWithBody = requestAsJson(key, persist_registrationUrl) << registerJson.toString + Http(requestWithBody OK as.String).option.apply() should not be None + persist_plan.sessionManager.session2key.synchronized { + val session = persist_plan.sessionManager.key2session.get(key) + session should not be None + val maybeLoc = persist_plan.sessionManager.session2loc.get(session.get) + maybeLoc should not be None + val loc = maybeLoc.get + loc shouldBe (s"8.8.8.8", 30) + } + }) + } + + "Registering an Unknown key in persist" should "fail" in { + val req = Http(persist_loginUrl.POST OK as.String) + var sessionId = req.option.apply() match { + case Some(res) => + parseOpt(res).get.extractOpt[SessionId].get.session_id + case None => fail() + } + + var keys = List.empty[String] + // This lock is required because the server is in the same process. 
+ persist_plan.sessionManager.session2key.synchronized { + for (key <- persist_plan.sessionManager.key2session.keys) + keys = key :: keys + } + //キー削除 + val key = persist_plan.sessionManager.session2key.get(sessionId).get + persist_plan.sessionManager.key2session -= key + + val registerJson = write(Register("register", s"8.8.8.8", 30)) + val requestWithBody = requestAsJson(key, persist_registrationUrl) << registerJson.toString + val res = Http(requestWithBody).option.apply() + res match { + case Some(res) => + res.getStatusCode.shouldBe(500) + res.getResponseBody.shouldBe(s"Failed to register key : ${key}") + case None => + fail() + } + + } + + "Registering connect zookeeper failed in persist" should "fail" in { + val req = Http(persist_loginUrl.POST OK as.String) + val sessionId = req.option.apply() match { + case Some(res) => + parseOpt(res).get.extractOpt[SessionId].get.session_id + case None => fail() + } + + var keys = List.empty[String] + // This lock is required because the server is in the same process. 
+ persist_plan.sessionManager.session2key.synchronized { + for (key <- persist_plan.sessionManager.key2session.keys) + keys = key :: keys + } + val key = persist_plan.sessionManager.session2key.get(sessionId).get + zkServer.stop() + + val registerJson = write(Register("register", s"8.8.8.8", 30)) + val requestWithBody = requestAsJson(key, persist_registrationUrl) << registerJson.toString + val res = Http(requestWithBody).option.apply() + res match { + case Some(res) => + println(res.getStatusCode) + println(res.getResponseBody) + res.getStatusCode.shouldBe(500) + res.getResponseBody.shouldBe(s"Failed to register key : ${key}") + case None => + fail() + } + } } diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/SessionManagerSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/SessionManagerSpec.scala new file mode 100644 index 0000000..10e3f19 --- /dev/null +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/SessionManagerSpec.scala @@ -0,0 +1,375 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. 
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import org.scalatest._
+import scala.util._
+
+
+class SessionManagerSpec extends FlatSpec with Matchers with BeforeAndAfterAll {
+  val sessionManager = new SessionManager("gatewayID", new DummyStore)
+
+  "initCache()" should "succeed" in {
+    // uses the session information that the constructor has already loaded
+    sessionManager.session2key.get("dummySession1").get shouldBe("dummyKey1")
+    sessionManager.session2key.get("dummySession2").get shouldBe("dummyKey2")
+    sessionManager.key2session.get("dummyKey1").get shouldBe("dummySession1")
+    sessionManager.key2session.get("dummyKey2").get shouldBe("dummySession2")
+    sessionManager.session2loc.get("dummySession1").get shouldBe(("dummyHost1",11111))
+    sessionManager.session2loc.get("dummySession2").get shouldBe(("dummyHost2",11112))
+  }
+
+  it should "fail" in {
+    // intercept fails the test by itself when the constructor does not throw;
+    // with try/fail()/catch the exception raised by fail() was caught too and
+    // surfaced as a confusing message mismatch.
+    val thrown = intercept[Exception] {
+      new SessionManager("addDeleteListenerFailed", new DummyStore)
+    }
+    thrown.getMessage shouldBe("addDeleteListenerFailed")
+  }
+
+  "createNewSession()" should "succeed" in {
+    // Create the sessions in parallel, but gather the outcomes through the
+    // parallel collection itself. Appending to a shared var from inside
+    // par.foreach is a data race: concurrent updates can be lost, and the
+    // cache check below would then silently cover fewer sessions.
+    val results = List(1,2,3,4,5).par.map(_ => sessionManager.createNewSession()).toList
+    val created = results.map {
+      case Success((id, key)) =>
+        id.length should be > 0
+        key.length should be > 0
+        id -> key
+      case Failure(t) =>
+        t.printStackTrace()
+        fail()
+    }
+    // every created session must be present in both cache directions
+    for ((id, key) <- created) {
+      sessionManager.session2key.get(id).get shouldBe key
+      sessionManager.key2session.get(key).get shouldBe id
+    }
+  }
+  it should "fail" in {
+    val sessionManager = new SessionManager("preregisterSessionFailed", new DummyStore)
+    val result = sessionManager.createNewSession()
+    result match {
+      case Success((id, key)) =>
+        fail()
+      case Failure(t) =>
+
t.getMessage shouldBe("preregisterSessionFailed") + sessionManager.session2key.size shouldBe 2 //dummysession + sessionManager.key2session.size shouldBe 2 //dummysession + sessionManager.session2loc.size shouldBe 2 //dummysession + } + } + + "getSession() find readyState in cache" should "succeed" in { + val readyResult = sessionManager.getSession("dummySession1") + readyResult match { + case Success(state) if state.isInstanceOf[SessionState.Ready] => + state.asInstanceOf[SessionState.Ready].host shouldBe("dummyHost1") + state.asInstanceOf[SessionState.Ready].port shouldBe(11111) + state.asInstanceOf[SessionState.Ready].key shouldBe("dummyKey1") + case Success(state) => + println(state) + fail() + case Failure(t) => + t.printStackTrace() + fail() + } + } + "getSession() find registeringState in cache" should "succeed" in { + //register registeringState + val (id,key) = sessionManager.createNewSession() match { + case Success((id, key)) => (id, key) + case Failure(t) => + t.printStackTrace() + fail() + } + println("--- get registeringState ---") + val registeringResult = sessionManager.getSession(id) + registeringResult match { + case Success(state) if state.isInstanceOf[SessionState.Registering] => + state.asInstanceOf[SessionState.Registering].key shouldBe(key) + case Success(state) => + println(state) + fail() + case Failure(t) => + t.printStackTrace() + fail() + } + } + "getSession() find inconsistantData in cache" should "succeed" in { + //register registeringState + val (id,key) = sessionManager.createNewSession() match { + case Success((id, key)) => (id, key) + case Failure(t) => + t.printStackTrace() + fail() + } + //不整合データ生成 + sessionManager.session2key -= id + sessionManager.session2key += id -> null + sessionManager.session2loc += id -> ("inconsistHost", 9999) + + println("--- get inconsitantData ---") + sessionManager.getSession(id) match { + case Success(state) => + fail() + case Failure(t) => + t.getMessage shouldBe(s"Inconsistent data. 
sessionId: ${id}") + } + } + + "getSession() find readyState in store" should "succeed" in { + val readyResult = sessionManager.getSession("ReadySession") + readyResult match { + case Success(state) if state.isInstanceOf[SessionState.Ready] => + state.asInstanceOf[SessionState.Ready].host shouldBe("readyHost") + state.asInstanceOf[SessionState.Ready].port shouldBe(12345) + state.asInstanceOf[SessionState.Ready].key shouldBe("readyKey") + sessionManager.session2key.get("ReadySession").get shouldBe("readyKey") + sessionManager.key2session.get("readyKey").get shouldBe("ReadySession") + sessionManager.session2loc.get("ReadySession").get shouldBe(("readyHost",12345)) + case Success(state) => + println(state) + fail() + case Failure(t) => + t.printStackTrace() + fail() + } + } + "getSession() find registeringState in store" should "succeed" in { + val registeringResult = sessionManager.getSession("RegisteringSession") + registeringResult match { + case Success(state) if state.isInstanceOf[SessionState.Registering] => + state.asInstanceOf[SessionState.Registering].key shouldBe("registeringKey") + sessionManager.session2key.get("RegisteringSession") shouldBe None + case Success(state) => + println(state) + fail() + case Failure(t) => + t.printStackTrace() + fail() + } + } + "getSession() find notfoundState in store" should "succeed" in { + val notfoundResult = sessionManager.getSession("NotfoundSession") + notfoundResult match { + case Success(state) => + state shouldBe SessionState.NotFound + sessionManager.session2key.get("NotfoundSession") shouldBe None + case Failure(t) => + t.printStackTrace() + fail() + } + } + "getSession() find inconsistantState in store" should "succeed" in { + val inconsistantResult = sessionManager.getSession("InconsistentSession") + inconsistantResult match { + case Success(state) => + fail() + case Failure(t) => + t.getMessage shouldBe("Inconsistent data. 
sessionId: InconsistentSession") + sessionManager.session2key.get("InconsistentSession") shouldBe None + } + } + "getSession() sessionStore throw Exception" should "fail" in { + val sessionManager = new SessionManager("getSessionFailed", new DummyStore) + val result = sessionManager.getSession("NonSession") + result match { + case Success(state) => + fail() + case Failure(t) => + t.getMessage shouldBe("getSessionFailed") + sessionManager.session2key.size shouldBe 2 //dummysession + sessionManager.key2session.size shouldBe 2 //dummysession + sessionManager.session2loc.size shouldBe 2 //dummysession + } + } + + "attachProcessorToSession() exist sessionId" should "succeed" in { + var resultMap = Map.empty[String, String] + List(1,2,3,4,5).par.foreach { x => + val (id, key) = sessionManager.createNewSession() match { + case Success((id, key)) => (id, key) + case Failure(t) => + t.printStackTrace() + fail() + } + val result = sessionManager.attachProcessorToSession("dummyHost" + x, x, key) + result match { + case Success(sessionId) => + sessionId shouldBe id + sessionManager.session2loc.get(id).get shouldBe ("dummyHost" + x, x) + case Failure(t) => + t.printStackTrace() + fail() + } + } + } + "attachProcessorToSession() non exist sessionId" should "fail" in { + val result = sessionManager.attachProcessorToSession("dummyHost", 9999, "NonKey") + result match { + case Success(sessionId) => + fail() + case Failure(t) => + t.getMessage shouldBe(s"non exist sessionId. 
key: NonKey") + sessionManager.key2session.get("NonKey") shouldBe None + } + } + "attachProcessorToSession() sessionStore throw Exception" should "fail" in { + val sessionManager = new SessionManager("registerSessionFailed", new DummyStore) + val (id, key) = sessionManager.createNewSession() match { + case Success((id, key)) => (id, key) + case Failure(t) => + t.printStackTrace() + fail() + } + val result = sessionManager.attachProcessorToSession("dummyHost", 9999, key) + result match { + case Success(sessionId) => + fail() + case Failure(t) => + t.getMessage shouldBe("registerSessionFailed") + sessionManager.session2loc.get(id) shouldBe None //dummysession + } + } + + "deleteSession() exist session" should "succeed" in { + val sessionManager = new SessionManager("test", new DummyStore) + val result = sessionManager.deleteSessionByKey("dummyKey1") + result match { + case Success((id, key)) => + id shouldBe("dummySession1") + key shouldBe("dummyKey1") + sessionManager.session2key.get("dummySession1") shouldBe None + sessionManager.key2session.get("dummyKey1") shouldBe None + sessionManager.session2loc.get("dummySession1") shouldBe None + case Failure(t) => + t.printStackTrace() + fail() + } + } + "deleteSession() non exist session" should "succeed" in { + val result = sessionManager.deleteSessionByKey("dummyKey3") + result match { + case Success((id, key)) => + id shouldBe(null) + key shouldBe("dummyKey3") + case Failure(t) => + t.printStackTrace() + fail() + } + } + "deleteSession() sessionStore throw Exception" should "fail" in { + val sessionManager = new SessionManager("deleteSessionFailed", new DummyStore) + val result = sessionManager.deleteSessionByKey("dummyKey1") + result match { + case Success((id, key)) => + fail() + case Failure(t) => + t.getMessage shouldBe("deleteSessionFailed") + sessionManager.session2key.get("dummySession1") should not be None //don't delete session cache + sessionManager.key2session.get("dummyKey1") should not be None //don't 
delete session cache + sessionManager.session2loc.get("dummySession1") should not be None //don't delete session cache + } + } + + "deleteSessionById() exist session" should "succeed" in { + val sessionManager = new SessionManager("test", new DummyStore) + val result = sessionManager.deleteSessionById("dummySession1") + result match { + case Success((id, key)) => + id shouldBe("dummySession1") + key shouldBe("dummyKey1") + sessionManager.session2key.get("dummySession1") shouldBe None + sessionManager.key2session.get("dummyKey1") shouldBe None + sessionManager.session2loc.get("dummySession1") shouldBe None + case Failure(t) => + t.printStackTrace() + fail() + } + } + "deleteSessionById() non exist session" should "succeed" in { + val result = sessionManager.deleteSessionById("dummySession3") + result match { + case Success((id, key)) => + id shouldBe("dummySession3") + key shouldBe(null) + case Failure(t) => + t.printStackTrace() + fail() + } + } + "deleteSessionById() sessionStore throw Exception" should "fail" in { + val sessionManager = new SessionManager("deleteSessionByIdFailed", new DummyStore) + val result = sessionManager.deleteSessionById("dummyKey1") + result match { + case Success((id, key)) => + fail() + case Failure(t) => + t.getMessage shouldBe("deleteSessionByIdFailed") + sessionManager.session2key.get("dummySession1") should not be None //don't delete session cache + sessionManager.key2session.get("dummyKey1") should not be None //don't delete session cache + sessionManager.session2loc.get("dummySession1") should not be None //don't delete session cache + } + } + + "deleteFunction() exist session" should "succeed" in { + val sessionManager = new SessionManager("test", new DummyStore) + sessionManager.deleteFunction("dummySession1") + + sessionManager.session2key.get("dummySession1") shouldBe None + sessionManager.key2session.get("dummyKey1") shouldBe None + sessionManager.session2loc.get("dummySession1") shouldBe None + } + "deleteFunction() non 
exist session" should "succeed" in { + val sessionManager = new SessionManager("test", new DummyStore) + sessionManager.deleteFunction("NonSession") + //check don't delete + sessionManager.session2key.size shouldBe 2 + sessionManager.key2session.size shouldBe 2 + sessionManager.session2loc.size shouldBe 2 + } + + "lock()" should "succeed" in { + val lock = sessionManager.lock() + lock.lockObject should not be None + sessionManager.unlock(lock) + } + it should "fail" in { + val sessionManager = new SessionManager("lockFailed", new DummyStore) + try { + sessionManager.lock() + fail() + } catch { + case e:Exception => + e.getMessage shouldBe("lockFailed") + } + } + + "unlock() sessionStore throw Exception" should "succeed" in { + val sessionManager = new SessionManager("unLockFailed", new DummyStore) + try { + val lock = sessionManager.lock() + sessionManager.unlock(lock) + } catch { + case e: Exception => + fail() + } + } + +} \ No newline at end of file diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/UnregisterSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/UnregisterSpec.scala index fb0b895..f9e2fe8 100644 --- a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/UnregisterSpec.scala +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/UnregisterSpec.scala @@ -33,12 +33,17 @@ class UnregisterSpec extends FlatSpec with Matchers with GatewayServer with Befo val registerJson = write(Register("register", "8.8.8.8", 30)).toString val unregisterJson = """{"action": "unregister"}""" - def requestAsJson(key: String) = { - val request = (registrationUrl / key).POST + val persist_loginUrl = :/("localhost", 9880) / "login" + val persist_registrationUrl = :/("localhost", 9880) / "registration" + + def requestAsJson(key: String, url: Req = registrationUrl) = { + val request = (url / key).POST request.setContentType("application/json", "UTF-8") } before { + zkServer.restart + // login twice for (i <- 0 until 2) { val req = 
Http(loginUrl.POST OK as.String) @@ -50,6 +55,19 @@ class UnregisterSpec extends FlatSpec with Matchers with GatewayServer with Befo val requestWithBody = requestAsJson(key) << registerJson Http(requestWithBody OK as.String).option.apply() } + + //persist gateway + // login twice + for (i <- 0 until 2) { + val req = Http(persist_loginUrl.POST OK as.String) + req.option.apply() + } + + // register all keys + for (key <- persist_keys) { + val requestWithBody = requestAsJson(key, persist_registrationUrl) << registerJson + Http(requestWithBody OK as.String).option.apply() + } } after { @@ -58,13 +76,28 @@ class UnregisterSpec extends FlatSpec with Matchers with GatewayServer with Befo val requestWithBody = requestAsJson(key) << unregisterJson Http(requestWithBody OK as.String).option.apply() } + + for (key <- persist_keys) { + val requestWithBody = requestAsJson(key, persist_registrationUrl) << unregisterJson + Http(requestWithBody OK as.String).option.apply() + } } def keys = { var keys = List.empty[String] // This lock is required because the server is in the same process. - plan.session2key.synchronized { - for (key <- plan.key2session.keys) + plan.sessionManager.session2key.synchronized { + for (key <- plan.sessionManager.key2session.keys) + keys = key :: keys + } + keys + } + + def persist_keys = { + var keys = List.empty[String] + // This lock is required because the server is in the same process. 
+ persist_plan.sessionManager.session2key.synchronized { + for (key <- persist_plan.sessionManager.key2session.keys) keys = key :: keys } keys @@ -98,16 +131,179 @@ class UnregisterSpec extends FlatSpec with Matchers with GatewayServer with Befo "Unregistering an existing key" should "succeed" in { for (key <- keys) { var sessionId = "" - plan.session2key.synchronized { - sessionId = plan.key2session.get(key).get + plan.sessionManager.session2key.synchronized { + sessionId = plan.sessionManager.key2session.get(key).get } val requestWithBody = requestAsJson(key) << unregisterJson Http(requestWithBody OK as.String).option.apply() should not be None - plan.session2key.synchronized { - plan.session2key.get(sessionId) shouldBe None - plan.key2session.get(key) shouldBe None - plan.session2loc.get(sessionId) shouldBe None + plan.sessionManager.session2key.synchronized { + plan.sessionManager.session2key.get(sessionId) shouldBe None + plan.sessionManager.key2session.get(key) shouldBe None + plan.sessionManager.session2loc.get(sessionId) shouldBe None } } } + + // persist + "Unregistering an existing key in persist" should "succeed" in { + persist_keys.par.foreach(key => { + var sessionId = "" + persist_plan.sessionManager.session2key.synchronized { + sessionId = persist_plan.sessionManager.key2session.get(key).get + } + val requestWithBody = requestAsJson(key, persist_registrationUrl) << unregisterJson + Http(requestWithBody OK as.String).option.apply() should not be None + persist_plan.sessionManager.session2key.synchronized { + persist_plan.sessionManager.session2key.get(sessionId) shouldBe None + persist_plan.sessionManager.key2session.get(key) shouldBe None + persist_plan.sessionManager.session2loc.get(sessionId) shouldBe None + persist_plan.sessionManager.getSessionFromStore(sessionId) shouldBe SessionState.NotFound + } + }) + } + + "Unregistering an already delete key in persist" should "succeed" in { + persist_keys.par.foreach(key => { + var sessionId = "" + 
persist_plan.sessionManager.session2key.synchronized { + sessionId = persist_plan.sessionManager.key2session.get(key).get + } + println("sessionId : " + sessionId+ ", key : " + key) + val requestWithBody = requestAsJson(key, persist_registrationUrl) << unregisterJson + val result = Http(requestWithBody OK as.String).option.apply() + result should not be None + result.get shouldBe "Successfully unregistered" + persist_plan.sessionManager.session2key.synchronized { + persist_plan.sessionManager.session2key.get(sessionId) shouldBe None + persist_plan.sessionManager.key2session.get(key) shouldBe None + persist_plan.sessionManager.session2loc.get(sessionId) shouldBe None + persist_plan.sessionManager.getSessionFromStore(sessionId) shouldBe SessionState.NotFound + } + + val alredyDeleteresult = Http(requestWithBody OK as.String).option.apply() + alredyDeleteresult should not be None + alredyDeleteresult.get shouldBe "Successfully unregistered" + persist_plan.sessionManager.session2key.synchronized { + persist_plan.sessionManager.session2key.get(sessionId) shouldBe None + persist_plan.sessionManager.key2session.get(key) shouldBe None + persist_plan.sessionManager.session2loc.get(sessionId) shouldBe None + persist_plan.sessionManager.getSessionFromStore(sessionId) shouldBe SessionState.NotFound + } + + }) + } + + "Unregistering connect zookeeper failed in persist" should "fail" in { + zkServer.stop + persist_keys.par.foreach(key => { + var sessionId = "" + persist_plan.sessionManager.session2key.synchronized { + sessionId = persist_plan.sessionManager.key2session.get(key).get + } + val requestWithBody = requestAsJson(key, persist_registrationUrl) << unregisterJson + val result = Http(requestWithBody).either.apply() + result match { + case Left(t) => + println("hogehoge---------------------------: " + t.getMessage) + fail() + case Right(res) => + res.getResponseBody shouldBe s"Failed to unregister key : ${key}" + persist_plan.sessionManager.session2key.synchronized { + 
persist_plan.sessionManager.session2key.get(sessionId) should not be None + persist_plan.sessionManager.key2session.get(key) should not be None + persist_plan.sessionManager.session2loc.get(sessionId) should not be None + } + } + }) + zkServer.restart + } + + "Unregistering an existing key in JubaQLGateway Cluster" should "succeed" in { + + println("--- gateway cluster test setting ---") + val gateway1_loginUrl = :/("localhost", 9980) / "login" + val gateway2_loginUrl = :/("localhost", 9981) / "login" + val gateway1_registrationUrl = :/("localhost", 9980) / "registration" + val gateway2_registrationUrl = :/("localhost", 9981) / "registration" + + //gateway startup + val gateway1_plan = new GatewayPlan("example.com", 2345, + Array(), RunMode.Test, + sparkDistribution = "", + fatjar = "src/test/resources/processor-logfile.jar", + checkpointDir = "file:///tmp/spark", "gateway_cluster", true, 16, 0, 0) + val gateway1_server = unfiltered.netty.Server.http(9980).plan(gateway1_plan) + gateway1_server.start + + // login + for (i <- 0 until 2) { + val req = Http(gateway1_loginUrl.POST OK as.String) + req.option.apply() + } + // get key + var keys = List.empty[String] + gateway1_plan.sessionManager.session2key.synchronized { + for (key <- gateway1_plan.sessionManager.key2session.keys) + keys = key :: keys + } + // register all keys + for (key <- keys) { + val requestWithBody = requestAsJson(key, gateway1_registrationUrl) << registerJson + val res = Http(requestWithBody OK as.String).option.apply() + } + + //gateway1でregistrationが完了した時点で同クラスタに所属するgateway2を起動 + val gateway2_plan = new GatewayPlan("example.com", 2346, + Array(), RunMode.Test, + sparkDistribution = "", + fatjar = "src/test/resources/processor-logfile.jar", + checkpointDir = "file:///tmp/spark", "gateway_cluster", true, 16, 0, 0) + val gateway2_server = unfiltered.netty.Server.http(9981).plan(gateway2_plan) + gateway2_server.start + + keys.par.foreach(key => { + var sessionId = "" + 
gateway1_plan.sessionManager.session2key.synchronized { + sessionId = gateway1_plan.sessionManager.key2session.get(key).get + } + gateway1_plan.sessionManager.session2key.synchronized { + gateway1_plan.sessionManager.session2key.get(sessionId) should not be None + gateway1_plan.sessionManager.key2session.get(key) should not be None + gateway1_plan.sessionManager.session2loc.get(sessionId) should not be None + gateway1_plan.sessionManager.getSessionFromStore(sessionId).isInstanceOf[SessionState.Ready] shouldBe true + } + gateway2_plan.sessionManager.session2key.synchronized { + gateway2_plan.sessionManager.session2key.get(sessionId) should not be None + gateway2_plan.sessionManager.key2session.get(key) should not be None + gateway2_plan.sessionManager.session2loc.get(sessionId) should not be None + gateway2_plan.sessionManager.getSessionFromStore(sessionId).isInstanceOf[SessionState.Ready] shouldBe true + } + }) + println("--- unregistring ---") + keys.par.foreach(key => { + var sessionId = "" + gateway1_plan.sessionManager.session2key.synchronized { + sessionId = gateway1_plan.sessionManager.key2session.get(key).get + } + // セッション情報を登録したgateway以外でunregeister実行 + val requestWithBody = requestAsJson(key, gateway2_registrationUrl) << unregisterJson + Http(requestWithBody OK as.String).option.apply() should not be None + // 同クラスタのすべてのgatewayから該当セッション情報が削除されることを確認 + gateway1_plan.sessionManager.session2key.synchronized { + gateway1_plan.sessionManager.session2key.get(sessionId) shouldBe None + gateway1_plan.sessionManager.key2session.get(key) shouldBe None + gateway1_plan.sessionManager.session2loc.get(sessionId) shouldBe None + gateway1_plan.sessionManager.getSessionFromStore(sessionId) shouldBe SessionState.NotFound + } + gateway2_plan.sessionManager.session2key.synchronized { + gateway2_plan.sessionManager.session2key.get(sessionId) shouldBe None + gateway2_plan.sessionManager.key2session.get(key) shouldBe None + 
gateway2_plan.sessionManager.session2loc.get(sessionId) shouldBe None + gateway2_plan.sessionManager.getSessionFromStore(sessionId) shouldBe SessionState.NotFound + } + }) + println("--- unregistred ---") + gateway1_server.stop + gateway2_server.stop + } } diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperServer.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperServer.scala new file mode 100644 index 0000000..4398af4 --- /dev/null +++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperServer.scala @@ -0,0 +1,36 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. 
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import org.scalatest.{Suite, BeforeAndAfterAll, BeforeAndAfter}
+import org.apache.curator.test.TestingServer
+
+/**
+ * Test fixture that provides an in-process ZooKeeper server (`zkServer`)
+ * to the suite it is mixed into.
+ *
+ * The server is restarted before every test and stopped after every test,
+ * and it is closed once the whole suite has finished.
+ *
+ * NOTE(review): the `before`/`after` hooks are registered here, so a suite
+ * mixing in this trait presumably cannot register its own `before`/`after`
+ * blocks (ScalaTest's BeforeAndAfter accepts one registration per suite) --
+ * confirm against the ScalaTest version in use.
+ */
+trait ZookeeperServer extends BeforeAndAfter with BeforeAndAfterAll {
+  this: Suite =>
+
+  // Fixed port 2181; the second constructor argument is curator-test's
+  // "start" flag -- presumably the server is started right away (TODO confirm).
+  val zkServer = new TestingServer(2181, true)
+
+  before {
+    zkServer.restart()
+  }
+  after {
+    zkServer.stop()
+  }
+
+  /**
+   * Closes the testing server once all tests of the suite have run.
+   *
+   * Chains to `super.afterAll()` first so that this trait stays stackable
+   * with other BeforeAndAfterAll mix-ins; the server is closed in `finally`
+   * so it is released even if an upstream `afterAll` throws.
+   */
+  override protected def afterAll(): Unit = {
+    try super.afterAll()
+    finally zkServer.close()
+  }
+}
\ No newline at end of file
diff --git a/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperStoreSpec.scala b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperStoreSpec.scala
new file mode 100644
index 0000000..3f1066c
--- /dev/null
+++ b/gateway/src/test/scala/us/jubat/jubaql_server/gateway/ZookeeperStoreSpec.scala
@@ -0,0 +1,468 @@
+// Jubatus: Online machine learning framework for distributed environment
+// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License version 2.1 as published by the Free Software Foundation.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+package us.jubat.jubaql_server.gateway
+
+import org.scalatest._
+import scala.util._
+import java.util.concurrent.locks.ReentrantLock
+
+/**
+ * Tests for ZookeeperStore against the in-process ZooKeeper server provided
+ * by the ZookeeperServer fixture (`zkServer`).
+ *
+ * Failure-path tests use `intercept`, so a call that unexpectedly does not
+ * throw fails the test instead of passing silently.
+ */
+class ZookeeperStoreSpec extends FlatSpec with Matchers with ZookeeperServer {
+  val zkStore = new ZookeeperStore()
+
+  // Written by the delete-listener callbacks and read by the test body.
+  // Presumably the callbacks run on another thread (the tests sleep to wait
+  // for them), hence @volatile so the test thread sees the update.
+  @volatile var res: String = ""
+  @volatile var res2: String = ""
+
+  "getAllSessions() NoNode" should "succeed" in {
+    val result = zkStore.getAllSessions("NoNode")
+    result.size shouldBe (0)
+  }
+
+  "getAllSessions() Empty" should "succeed" in {
+    zkStore.registerGateway("getAllSessions")
+    val result = zkStore.getAllSessions("getAllSessions")
+    result.size shouldBe (0)
+  }
+
+  "getAllSessions() Size=2" should "succeed" in {
+    zkStore.registerGateway("getAllSessions")
+    // set ReadySession
+    zkStore.preregisterSession("getAllSessions", "sessionID1", "DummyKey1")
+    zkStore.registerSession("getAllSessions", "sessionID1", "DummyHost1", 9999, "DummyKey1")
+    // set RegisteringSession
+    zkStore.preregisterSession("getAllSessions", "sessionID2", "DummyKey2")
+
+    val result = zkStore.getAllSessions("getAllSessions")
+    result.size shouldBe (2)
+    result.get("sessionID1") should not be None
+    result.get("sessionID1").get.isInstanceOf[SessionState.Ready] shouldBe (true)
+    val getSession1 = result.get("sessionID1").get.asInstanceOf[SessionState.Ready]
+    getSession1.host shouldBe "DummyHost1"
+    getSession1.port shouldBe 9999
+    getSession1.key shouldBe "DummyKey1"
+    result.get("sessionID2") should not be None
+    result.get("sessionID2").get.isInstanceOf[SessionState.Registering] shouldBe (true)
+    val getSession2 = result.get("sessionID2").get.asInstanceOf[SessionState.Registering]
+    getSession2.key shouldBe "DummyKey2"
+  }
+
+  "getAllSessions() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("getAllSessions")
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.getAllSessions("getAllSessions")
+    }
+    e.getMessage shouldBe ("failed to get all session.")
+  }
+
+  "getSession() non exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("getSession")
+    val result = zkStore.getSession("getSession", "NoSession")
+    result shouldBe (SessionState.NotFound)
+  }
+
+  "getSession() exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("getSession")
+    // set ReadySession
+    zkStore.preregisterSession("getSession", "sessionID1", "DummyKey1")
+    zkStore.registerSession("getSession", "sessionID1", "DummyHost1", 9999, "DummyKey1")
+    // set RegisteringSession
+    zkStore.preregisterSession("getSession", "sessionID2", "DummyKey2")
+    // set InconsistentSession (node data without a "key" field)
+    zkStore.zookeeperClient.create().forPath(s"${zkStore.zkSessionPath}/getSession/sessionID3", s"""{"host":"DummyHost3", "port": 9999}""".getBytes("UTF-8"))
+
+    val result1 = zkStore.getSession("getSession", "sessionID1")
+    val getSession1 = result1.asInstanceOf[SessionState.Ready]
+    getSession1.host shouldBe "DummyHost1"
+    getSession1.port shouldBe 9999
+    getSession1.key shouldBe "DummyKey1"
+    val result2 = zkStore.getSession("getSession", "sessionID2")
+    val getSession2 = result2.asInstanceOf[SessionState.Registering]
+    getSession2.key shouldBe "DummyKey2"
+    val result3 = zkStore.getSession("getSession", "sessionID3")
+    result3 shouldBe (SessionState.Inconsistent)
+    val result4 = zkStore.getSession("getSession", "sessionID4")
+    result4 shouldBe (SessionState.NotFound)
+  }
+
+  "getSession() illegal data" should "fail" in {
+    zkStore.registerGateway("getSession")
+    // "port" is a string here, which is the illegal part of the payload
+    zkStore.zookeeperClient.create().forPath(s"${zkStore.zkSessionPath}/getSession/IllegalData", s"""{"host":"DummyHost", "port": "number", "key": "DummyKey"}""".getBytes("UTF-8"))
+
+    val e = intercept[Exception] {
+      zkStore.getSession("getSession", "IllegalData")
+    }
+    e.getMessage shouldBe ("failed to get session. sessionId: IllegalData")
+  }
+
+  "preregisterSession() non exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("preregisterSession")
+    zkStore.preregisterSession("preregisterSession", "sessionID1", "DummyKey")
+    val result = zkStore.getSession("preregisterSession", "sessionID1")
+    result.isInstanceOf[SessionState.Registering] shouldBe (true)
+    val session = result.asInstanceOf[SessionState.Registering]
+    session.key shouldBe ("DummyKey")
+  }
+
+  "preregisterSession() already exist SessionNode" should "fail" in {
+    zkStore.registerGateway("preregisterSession")
+    zkStore.preregisterSession("preregisterSession", "sessionID2", "DummyKey")
+    val e = intercept[Exception] {
+      zkStore.preregisterSession("preregisterSession", "sessionID2", "DummyKey2")
+    }
+    e.getMessage shouldBe ("failed to pre-register session. sessionId: sessionID2")
+    // verify that the previously registered data is unaffected
+    val result = zkStore.getSession("preregisterSession", "sessionID2")
+    result.isInstanceOf[SessionState.Registering] shouldBe (true)
+    val session = result.asInstanceOf[SessionState.Registering]
+    session.key shouldBe ("DummyKey")
+  }
+
+  "preregisterSession() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("preregisterSession")
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.preregisterSession("preregisterSession", "sessionID3", "DummyKey")
+    }
+    e.getMessage shouldBe ("failed to pre-register session. sessionId: sessionID3")
+  }
+
+  "registerSession() non exist SessionNode" should "fail" in {
+    zkStore.registerGateway("registerSession")
+    val e = intercept[Exception] {
+      zkStore.registerSession("registerSession", "sessionID1", "DummyHost", 9999, "DummyKey")
+    }
+    e.getMessage shouldBe ("failed to register session. sessionId: sessionID1")
+    // the failed registration must not have created the node
+    val result = zkStore.getSession("registerSession", "sessionID1")
+    result shouldBe SessionState.NotFound
+  }
+
+  // Registering a pre-registered session is the regular success path.
+  "registerSession() exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("registerSession")
+    zkStore.preregisterSession("registerSession", "sessionID2", "DummyKey")
+    zkStore.registerSession("registerSession", "sessionID2", "DummyHost2", 9999, "DummyKey2")
+    val result = zkStore.getSession("registerSession", "sessionID2")
+    val getSession = result.asInstanceOf[SessionState.Ready]
+    getSession.host shouldBe "DummyHost2"
+    getSession.port shouldBe 9999
+    getSession.key shouldBe "DummyKey2"
+  }
+
+  "registerSession() alreadyRegistered SessionNode" should "fail" in {
+    zkStore.registerGateway("registerSession")
+    zkStore.preregisterSession("registerSession", "sessionID3", "DummyKey")
+    zkStore.registerSession("registerSession", "sessionID3", "DummyHost3", 9999, "DummyKey3")
+
+    val e = intercept[Exception] {
+      zkStore.registerSession("registerSession", "sessionID3", "DummyHost4", 9999, "DummyKey4")
+    }
+    e.getMessage shouldBe ("failed to register session. sessionId: sessionID3")
+    // the first registration must still be in place
+    val result = zkStore.getSession("registerSession", "sessionID3")
+    val getSession = result.asInstanceOf[SessionState.Ready]
+    getSession.host shouldBe "DummyHost3"
+    getSession.port shouldBe 9999
+    getSession.key shouldBe "DummyKey3"
+  }
+
+  "registerSession() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("registerSession")
+    zkStore.preregisterSession("registerSession", "sessionID4", "DummyKey")
+
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.registerSession("registerSession", "sessionID4", "DummyHost4", 9999, "DummyKey4")
+    }
+    e.getMessage shouldBe ("failed to register session. sessionId: sessionID4")
+  }
+
+  "deleteSession() non exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("deleteSession")
+    // any exception thrown here fails the test
+    zkStore.deleteSession("deleteSession", "sessionID1")
+  }
+
+  "deleteSession() exist Registering SessionNode" should "succeed" in {
+    zkStore.registerGateway("deleteSession")
+    zkStore.preregisterSession("deleteSession", "sessionID1", "DummyKey1")
+
+    zkStore.deleteSession("deleteSession", "sessionID1")
+    val result = zkStore.getSession("deleteSession", "sessionID1")
+    result shouldBe (SessionState.NotFound)
+  }
+
+  "deleteSession() exist Registered SessionNode" should "succeed" in {
+    zkStore.registerGateway("deleteSession")
+    zkStore.preregisterSession("deleteSession", "sessionID2", "DummyKey2")
+    zkStore.registerSession("deleteSession", "sessionID2", "DummyHost2", 9999, "DummyKey2")
+
+    zkStore.deleteSession("deleteSession", "sessionID2")
+    val result = zkStore.getSession("deleteSession", "sessionID2")
+    result shouldBe (SessionState.NotFound)
+  }
+
+  "deleteSession() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("deleteSession")
+    zkStore.preregisterSession("deleteSession", "sessionID3", "DummyKey3")
+    zkServer.stop()
+
+    val e = intercept[Exception] {
+      zkStore.deleteSession("deleteSession", "sessionID3")
+    }
+    e.getMessage shouldBe ("failed to delete session. sessionId: sessionID3")
+  }
+
+  "lock() non exist SessionNode" should "succeed" in {
+    val result = zkStore.lock("lock")
+    result should not be null
+    zkStore.unlock(result)
+  }
+
+  "lock() exist SessionNode" should "succeed" in {
+    zkStore.registerGateway("lock")
+    val result = zkStore.lock("lock")
+    result should not be null
+    zkStore.unlock(result)
+  }
+
+  // "lock() multilock timeout" should "succeed" in {
+  //   // Checks nested locking plus the lock timeout.
+  //   // This waits for the full lockTimeout duration, so change the constant before running it.
+  //   zkStore.registerGateway("lock")
+  //   val result = zkStore.lock("lock")
+  //   result should not be null
+  //   val result2 = zkStore.lock("lock")
+  //   result2 should not be null
+  //   zkStore.unlock(result2)
+  // }
+
+  "lock() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("lock")
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.lock("lock")
+    }
+    e.getMessage shouldBe ("failed to create lock object. gatewayId: lock")
+  }
+
+  "unlock() sessionLock null" should "fail" in {
+    zkStore.registerGateway("unlock")
+    val e = intercept[Exception] {
+      zkStore.unlock(null)
+    }
+    e.getMessage shouldBe ("failed to unlock. session lock: null")
+  }
+
+  "unlock() lockObject null" should "fail" in {
+    zkStore.registerGateway("unlock")
+    val sessionLock = new SessionLock(null)
+    val e = intercept[Exception] {
+      zkStore.unlock(sessionLock)
+    }
+    e.getMessage shouldBe ("failed to unlock. lock object: null")
+  }
+
+  "unlock() illegal parameter" should "fail" in {
+    zkStore.registerGateway("unlock")
+    // a lock object of a type the store did not hand out itself
+    val illegalLockObj = new ReentrantLock()
+    val sessionLock = new SessionLock(illegalLockObj)
+    val e = intercept[Exception] {
+      zkStore.unlock(sessionLock)
+    }
+    e.getMessage shouldBe ("failed to unlock. illegal lock object: class java.util.concurrent.locks.ReentrantLock")
+  }
+
+  "unlock() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("unlock")
+    val lockObj = zkStore.lock("unlock")
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.unlock(lockObj)
+    }
+    e.getMessage shouldBe ("failed to unlock")
+  }
+
+  "addDeleteListener() add Session" should "succeed" in {
+    res = ""
+    zkStore.registerGateway("addDeleteListener")
+    zkStore.preregisterSession("addDeleteListener", "sessionID1", "dummyKey1")
+
+    // The callback only records the id; assertions are made in the test body.
+    zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("sessionID1")
+  }
+
+  "addDeleteListener() deleted Session" should "succeed" in {
+    res = ""
+    zkStore.registerGateway("addDeleteListener")
+    zkStore.preregisterSession("addDeleteListener", "sessionID1", "dummyKey1")
+    zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("sessionID1")
+    // deleting an already deleted session must not throw
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+  }
+
+  "addDeleteListener() add Multi Session" should "succeed" in {
+    res = ""
+    res2 = ""
+    zkStore.registerGateway("addDeleteListener")
+    zkStore.preregisterSession("addDeleteListener", "sessionID1", "dummyKey1")
+    zkStore.preregisterSession("addDeleteListener", "sessionID2", "dummyKey2")
+    zkStore.registerSession("addDeleteListener", "sessionID2", "dummyHost2", 9999, "dummyKey2")
+
+    zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function1. sessionId: ${str}")
+    })
+    zkStore.addDeleteListener("addDeleteListener", "sessionID2", (str: String) => {
+      res2 = str
+      println(s"call Dummy Delete Function2. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    // only the listener of the deleted session may have fired
+    res shouldBe ("sessionID1")
+    res2 shouldBe ("")
+  }
+
+  "addDeleteListener() Re add Session" should "succeed" in {
+    // start from a clean marker so the first assertion cannot pass on a stale value
+    res = ""
+    zkStore.registerGateway("addDeleteListener")
+    zkStore.preregisterSession("addDeleteListener", "sessionID1", "dummyKey1")
+    zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("sessionID1")
+
+    //Re: add deleteListener
+    zkStore.preregisterSession("addDeleteListener", "sessionID1", "dummyKey1")
+    res = ""
+    zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "sessionID1")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("sessionID1")
+  }
+
+  "addDeleteListener() exclude SessionId" should "succeed" in {
+    // Nodes named "locks" and "leases" are expected not to trigger the callback.
+    res = ""
+    zkStore.registerGateway("addDeleteListener")
+    zkStore.zookeeperClient.create().forPath(s"${zkStore.zkSessionPath}/addDeleteListener/locks", """""".getBytes("UTF-8"))
+    zkStore.addDeleteListener("addDeleteListener", "locks", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "locks")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("")
+
+    zkStore.zookeeperClient.create().forPath(s"${zkStore.zkSessionPath}/addDeleteListener/leases", """""".getBytes("UTF-8"))
+    res = ""
+    zkStore.addDeleteListener("addDeleteListener", "leases", (str: String) => {
+      res = str
+      println(s"call Dummy Delete Function. sessionId: ${str}")
+    })
+    zkStore.deleteSession("addDeleteListener", "leases")
+    // wait for the delete event to be processed
+    Thread.sleep(1000)
+    res shouldBe ("")
+  }
+
+  "addDeleteListener() stop zookeeper" should "fail" in {
+    zkStore.registerGateway("addDeleteListener")
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.addDeleteListener("addDeleteListener", "sessionID1", (str: String) => { println(s"Dummy Delete Function:${str}") })
+    }
+    e.getMessage shouldBe ("failed to add delete listener. sessionId: sessionID1")
+  }
+
+  "registerGateway() stop zookeeper" should "fail" in {
+    zkServer.stop()
+    val e = intercept[Exception] {
+      zkStore.registerGateway("registerGateway")
+    }
+    e.getMessage shouldBe ("failed to registerGateway. gatewayId: registerGateway")
+  }
+
+}
\ No newline at end of file
diff --git a/processor/build.sbt b/processor/build.sbt index ce48bcc..498c185 100644 --- a/processor/build.sbt +++ b/processor/build.sbt @@ -31,7 +31,7 @@ libraryDependencies ++= Seq( "org.slf4j" % "slf4j-api" % "1.6.4", "org.slf4j" % "slf4j-log4j12" % "1.6.4", // Jubatus - "us.jubat" % "jubatus" % "0.7.1" + "us.jubat" % "jubatus" % "0.8.0" exclude("org.jboss.netty", "netty"), // jubatusonyarn "us.jubat" %% "jubatus-on-yarn-client" % "1.1" diff --git a/processor/src/main/resources/log4j.xml b/processor/src/main/resources/log4j.xml index 4aaa374..089d202 100644 --- a/processor/src/main/resources/log4j.xml +++ b/processor/src/main/resources/log4j.xml @@ -5,7 +5,7 @@ - + diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/HybridProcessor.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/HybridProcessor.scala index eb2b97b..3e78bfe 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/HybridProcessor.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/HybridProcessor.scala @@ -33,7 +33,7 @@ import org.apache.spark.streaming.kafka.KafkaUtils import
org.apache.spark.storage.StorageLevel import org.apache.spark.SparkContext._ import kafka.serializer.StringDecoder -import scala.collection.mutable.Queue +import scala.collection.mutable.{Queue, LinkedHashMap} import org.apache.spark.sql.{SchemaRDD, SQLContext} import org.apache.spark.sql.catalyst.types.StructType import org.json4s.JValue @@ -51,6 +51,17 @@ case object Running extends ProcessorState case object Finished extends ProcessorState +// an object describing the phase of the processing +sealed abstract class ProcessingPhase(val name: String) { + def getPhaseName(): String = { name } +} + +case object StopPhase extends ProcessingPhase("Stop") + +case object StoragePhase extends ProcessingPhase("Storage") + +case object StreamPhase extends ProcessingPhase("Stream") + class HybridProcessor(sc: SparkContext, sqlc: SQLContext, storageLocation: String, @@ -157,6 +168,26 @@ class HybridProcessor(sc: SparkContext, _state } + // phase of the processing + protected var _phase: ProcessingPhase = StopPhase + + protected def setPhase(newPhase: ProcessingPhase) = synchronized { + _phase = newPhase + } + + def phase: ProcessingPhase = synchronized { + _phase + } + + // latest time-stamp + var latestStaticTimestamp = sc.accumulator[Option[IdType]](None)(new MaxOptionAccumulatorParam[IdType]) + var latestStreamTimestamp = sc.accumulator[Option[IdType]](None)(new MaxOptionAccumulatorParam[IdType]) + + var storageCount = sc.accumulator(0L) + var streamCount = sc.accumulator(0L) + var staticStartTime = 0L + var streamStartTime = 0L + /** * Start hybrid processing using the given RDD[JValue] operation. 
* @@ -356,9 +387,10 @@ class HybridProcessor(sc: SparkContext, // keep track of the maximal ID seen during processing val maxStaticId = sc.accumulator[Option[IdType]](None)(new MaxOptionAccumulatorParam[IdType]) - val countStatic = sc.accumulator(0L) + latestStaticTimestamp = maxStaticId + storageCount = sc.accumulator(0L) val maxStreamId = sc.accumulator[Option[IdType]](None)(new MaxOptionAccumulatorParam[IdType]) - val countStream = sc.accumulator(0L) + streamCount = sc.accumulator(0L) // processing of static data val repartitionedData = if (_master == "yarn-cluster" && totalNumCores > 0) { @@ -379,7 +411,7 @@ class HybridProcessor(sc: SparkContext, }).foreachRDD(rdd => { val count = rdd.count() // we count the number of total input rows (on the driver) - countStatic += count + storageCount += count // stop processing of static data if there are no new files if (count == 0) { logger.info(s"processed $count (static) lines, looks like done") @@ -396,10 +428,11 @@ class HybridProcessor(sc: SparkContext, // start first StreamingContext logger.info("starting static data processing") - val staticStartTime = System.currentTimeMillis() + staticStartTime = System.currentTimeMillis() var staticRunTime = 0L - var streamStartTime = -1L var streamRunTime = 0L + setPhase(StoragePhase) + ssc_.start() val staticStreamingContext = ssc_ @@ -445,7 +478,7 @@ class HybridProcessor(sc: SparkContext, // NB. This is a separate thread. In functions that will be serialized, // you cannot necessarily use variables from outside this thread. val localExtractId = extractId - val localCountStream = countStream + val localCountStream = streamCount val localMaxStreamId = maxStreamId logger.debug("hello from thread to start stream processing") staticStreamingContext.awaitTermination() @@ -454,9 +487,11 @@ class HybridProcessor(sc: SparkContext, // to continue with real stream processing only if the static processing // was completed successfully. 
val largestStaticItemId = maxStaticId.value + latestStreamTimestamp = maxStreamId + staticRunTime = System.currentTimeMillis() - staticStartTime logger.debug("static processing ended after %d items and %s ms, largest seen ID: %s".format( - countStatic.value, staticRunTime, largestStaticItemId)) + storageCount.value, staticRunTime, largestStaticItemId)) logger.debug("sleeping a bit to allow Spark to settle") runMode match { case Development => @@ -535,31 +570,37 @@ class HybridProcessor(sc: SparkContext, } else { logger.info("starting stream processing") streamStartTime = System.currentTimeMillis() + setPhase(StreamPhase) ssc_.start() } } case Nil => logger.info("not starting stream processing " + "(no stream source given)") + setPhase(StopPhase) setState(Finished) case _ => logger.error("not starting stream processing " + "(multiple streams not implemented)") + setPhase(StopPhase) setState(Finished) } } else if (staticProcessingComplete && userStoppedProcessing) { logger.info("static processing was stopped by user, " + "not setting up stream") + setPhase(StopPhase) setState(Finished) } else { logger.warn("static processing did not complete successfully, " + "not setting up stream") + setPhase(StopPhase) setState(Finished) } logger.debug("bye from thread to start stream processing") } onFailure { case error: Throwable => logger.error("Error while setting up stream processing", error) + setPhase(StopPhase) setState(Finished) } @@ -578,11 +619,12 @@ class HybridProcessor(sc: SparkContext, streamRunTime = System.currentTimeMillis() - streamStartTime } logger.info(("processed %s items in %s ms (static) and %s items in " + - "%s ms (stream)").format(countStatic.value, staticRunTime, - countStream.value, streamRunTime)) + "%s ms (stream)").format(storageCount.value, staticRunTime, + streamCount.value, streamRunTime)) + setPhase(StopPhase) setState(Finished) - (ProcessingInformation(countStatic.value, staticRunTime, maxStaticId.value), - 
ProcessingInformation(countStream.value, streamRunTime, maxStreamId.value)) + (ProcessingInformation(storageCount.value, staticRunTime, maxStaticId.value), + ProcessingInformation(streamCount.value, streamRunTime, maxStreamId.value)) }, () => maxStaticId.value) } @@ -607,4 +649,34 @@ class HybridProcessor(sc: SparkContext, throw e } } + + def getStatus(): LinkedHashMap[String, Any] = { + var storageMap: LinkedHashMap[String, Any] = new LinkedHashMap() + storageMap.put("path", storageLocation) + storageMap.put("storage_start", staticStartTime) + storageMap.put("storage_count", storageCount.value) + + var streamMap: LinkedHashMap[String, Any] = new LinkedHashMap() + streamMap.put("path", streamLocations) + streamMap.put("stream_start", streamStartTime) + streamMap.put("stream_count", streamCount.value) + + var stsMap: LinkedHashMap[String, Any] = new LinkedHashMap() + stsMap.put("state", _state.toString()) + stsMap.put("process_phase", _phase.getPhaseName()) + + val timestamp = latestStreamTimestamp.value match { + case Some(value) => value + case None => + latestStaticTimestamp.value match { + case Some(value) => value + case None => "" + } + } + + stsMap.put("process_timestamp", timestamp) + stsMap.put("storage", storageMap) + stsMap.put("stream", streamMap) + stsMap + } } diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/JavaScriptUDFManager.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/JavaScriptUDFManager.scala index 1b2af8b..a5f016f 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/JavaScriptUDFManager.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/JavaScriptUDFManager.scala @@ -20,8 +20,9 @@ import scala.collection.JavaConversions import javax.script.{ScriptEngine, ScriptEngineManager, Invocable} import scala.util.{Failure, Success, Try} +import com.typesafe.scalalogging.slf4j.LazyLogging -class JavaScriptUDFManager { +class JavaScriptUDFManager extends LazyLogging { // The 
null is required. // See: http://stackoverflow.com/questions/20168226/sbt-0-13-scriptengine-is-null-for-getenginebyname-javascript private val scriptEngineManager = new ScriptEngineManager(null) @@ -60,10 +61,17 @@ class JavaScriptUDFManager { private def invoke(funcName: String, args: AnyRef*): AnyRef = { val inv = getInvocableEngine() - inv.invokeFunction(funcName, args: _*) + try { + inv.invokeFunction(funcName, args: _*) + } catch { + case e: Exception => + val errMsg = s"Failed to invoke function. functionName: ${funcName}, args: ${args.toString()}" + logger.error(errMsg, e) + throw new Exception(errMsg, e) + } } - def call[T](funcName: String, args: AnyRef*): Option[T] = { + def optionCall[T](funcName: String, args: AnyRef*): Option[T] = { Try { invoke(funcName, args:_*).asInstanceOf[T] } match { @@ -76,9 +84,19 @@ class JavaScriptUDFManager { invoke(funcName, args:_*).asInstanceOf[T] } - def registerAndCall[T](funcName: String, nargs: Int, funcBody: String, args: AnyRef*): Option[T] = { + def call[T](funcName: String, args: AnyRef*): T = { + Try { + invoke(funcName, args:_*).asInstanceOf[T] + } match { + case Success(value) => value + case Failure(err) => + throw err + } + } + + def registerAndOptionCall[T](funcName: String, nargs: Int, funcBody: String, args: AnyRef*): Option[T] = { register(funcName, nargs, funcBody) - call[T](funcName, args:_*) + optionCall[T](funcName, args:_*) } def registerAndTryCall[T](funcName: String, nargs: Int, funcBody: String, args: AnyRef*): Try[T] = { @@ -86,6 +104,11 @@ class JavaScriptUDFManager { tryCall[T](funcName, args:_*) } + def registerAndCall[T](funcName: String, nargs: Int, funcBody: String, args: AnyRef*): T = { + register(funcName, nargs, funcBody) + call[T](funcName, args:_*) + } + def getNumberOfArgsByFunctionName(fname: String): Option[Int] = funcs.synchronized { funcs.get(fname).map(_.nargs) } diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLAST.scala 
b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLAST.scala index 5bf3690..24782a3 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLAST.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLAST.scala @@ -31,19 +31,37 @@ CreateModel(algorithm: String, modelName: String, labelOrId: Option[(String, String)], featureExtraction: List[(FeatureFunctionParameters, String)], - configJson: String) extends JubaQLAST { - override def toString: String = "CreateModel(%s,%s,%s,%s,%s)".format( + configJson: String, + resConfigJson: Option[String] = None, + serverConfigJson: Option[String] = None, + proxyConfigJson: Option[String] = None) extends JubaQLAST { + override def toString: String = "CreateModel(%s,%s,%s,%s,%s,%s,%s,%s)".format( algorithm, modelName, labelOrId, featureExtraction, - if (configJson.size > 13) configJson.take(5) + "..." + configJson.takeRight(5) - else configJson + shorten(configJson), + resConfigJson match { + case Some(res) => shorten(res) + case None => resConfigJson + }, + serverConfigJson match { + case Some(server) => shorten(server) + case None => serverConfigJson + }, + proxyConfigJson match { + case Some(proxy) => shorten(proxy) + case None => proxyConfigJson + } ) + + def shorten(s: String): String = if (s.size < 13) s else (s.take(5) + "..." 
+ s.takeRight(5)) } case class Update(modelName: String, rpcName: String, source: String) extends JubaQLAST +case class UpdateWith(modelName: String, rpcName: String, learningData: String) extends JubaQLAST + case class CreateStreamFromSelect(streamName: String, selectPlan: LogicalPlan) extends JubaQLAST case class CreateStreamFromAnalyze(streamName: String, analyze: Analyze, newColumn: Option[String]) extends JubaQLAST @@ -75,3 +93,7 @@ case class CreateFeatureFunction(funcName: String, args: List[(String, String)], case class CreateTriggerFunction(funcName: String, args: List[(String, String)], lang: String, body: String) extends JubaQLAST + +case class SaveModel(modelName: String, modelPath: String, modelId: String) extends JubaQLAST + +case class LoadModel(modelName: String, modelPath: String, modelId: String) extends JubaQLAST diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLParser.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLParser.scala index 6156c1c..48d444b 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLParser.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLParser.scala @@ -111,6 +111,11 @@ class JubaQLParser extends SqlParser with LazyLogging { protected lazy val TIME = Keyword("TIME") protected lazy val TUPLES = Keyword("TUPLES") protected lazy val OVER = Keyword("OVER") + protected lazy val SAVE = Keyword("SAVE") + protected lazy val LOAD = Keyword("LOAD") + protected lazy val RESOURCE = Keyword("RESOURCE") + protected lazy val SERVER = Keyword("SERVER") + protected lazy val PROXY = Keyword("PROXY") override val lexical = new JubaQLLexical(reservedWords) @@ -222,9 +227,10 @@ class JubaQLParser extends SqlParser with LazyLogging { } CREATE ~> jubatusAlgorithm ~ MODEL ~ modelIdent ~ opt(labelOrId) ~ AS ~ - rep1sep(paramsAndFunction, ",") ~ CONFIG ~ stringLit ^^ { - case algorithm ~ _ ~ modelName ~ maybeLabelOrId ~ _ ~ l ~ _ ~ config => - 
CreateModel(algorithm, modelName, maybeLabelOrId, l, config) + rep1sep(paramsAndFunction, ",") ~ CONFIG ~ stringLit ~ opt(RESOURCE ~ CONFIG ~> stringLit) ~ + opt(SERVER ~ CONFIG ~> stringLit) ~ opt(PROXY ~ CONFIG ~> stringLit) ^^ { + case algorithm ~ _ ~ modelName ~ maybeLabelOrId ~ _ ~ l ~ _ ~ config ~ resourceConfig ~ serverConfig ~ proxyConfig => + CreateModel(algorithm, modelName, maybeLabelOrId, l, config, resourceConfig, serverConfig, proxyConfig) } } @@ -290,9 +296,11 @@ class JubaQLParser extends SqlParser with LazyLogging { } protected lazy val update: Parser[JubaQLAST] = { - UPDATE ~ MODEL ~> modelIdent ~ USING ~ funcIdent ~ FROM ~ streamIdent ^^ { - case modelName ~ _ ~ rpcName ~ _ ~ source => + UPDATE ~ MODEL ~> modelIdent ~ USING ~ funcIdent ~ (FROM ~ streamIdent | WITH ~ stringLit) ^^ { + case modelName ~ _ ~ rpcName ~ (fromOrWith ~ source) if fromOrWith.compareToIgnoreCase("FROM") == 0 => Update(modelName, rpcName, source) + case modelName ~ _ ~ rpcName ~ (fromOrWith ~ learningData) if fromOrWith.compareToIgnoreCase("WITH") == 0 => + UpdateWith(modelName, rpcName, learningData) } } @@ -369,6 +377,30 @@ class JubaQLParser extends SqlParser with LazyLogging { } } + protected lazy val saveModel: Parser[JubaQLAST] = { + SAVE ~ MODEL ~> modelIdent ~ USING ~ stringLit ~ AS ~ ident ^^ { + case modelName ~ _ ~ modelPath ~ _ ~ modelId => + modelPath match { + case "" => + null + case _ => + SaveModel(modelName, modelPath, modelId) + } + } + } + + protected lazy val loadModel: Parser[JubaQLAST] = { + LOAD ~ MODEL ~> modelIdent ~ USING ~ stringLit ~ AS ~ ident ^^ { + case modelName ~ _ ~ modelPath ~ _ ~ modelId => + modelPath match { + case "" => + null + case _ => + LoadModel(modelName, modelPath, modelId) + } + } + } + protected lazy val jubaQLQuery: Parser[JubaQLAST] = { createDatasource | createModel | @@ -385,7 +417,9 @@ class JubaQLParser extends SqlParser with LazyLogging { stopProcessing | createFunction | createFeatureFunction | - createTriggerFunction 
+ createTriggerFunction | + saveModel | + loadModel } // note: apply cannot override incompatible type with parent class diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLService.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLService.scala index ed80a19..fbbca09 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLService.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/JubaQLService.scala @@ -19,6 +19,7 @@ import java.net.InetAddress import java.text.SimpleDateFormat import java.util.Date import java.util.concurrent.ConcurrentHashMap +import collection.JavaConversions._ import com.twitter.finagle.Service import com.twitter.util.{Future => TwFuture, Promise => TwPromise} @@ -27,6 +28,7 @@ import io.netty.util.CharsetUtil import RunMode.{Production, Development} import us.jubat.jubaql_server.processor.json._ import us.jubat.jubaql_server.processor.updater._ +import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkFiles, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -44,16 +46,18 @@ import org.jboss.netty.handler.codec.http._ import org.json4s._ import org.json4s.native.{JsonMethods, Serialization} import org.json4s.JsonDSL._ +import org.apache.commons.io._ import sun.misc.Signal import us.jubat.anomaly.AnomalyClient -import us.jubat.classifier.ClassifierClient -import us.jubat.common.Datum +import us.jubat.classifier.{ClassifierClient, LabeledDatum} +import us.jubat.common.{Datum, ClientBase} import us.jubat.recommender.RecommenderClient -import us.jubat.yarn.client.{JubatusYarnApplication, JubatusYarnApplicationStatus, Resource} -import us.jubat.yarn.common.{LearningMachineType, Location} +import us.jubat.yarn.client.{JubatusYarnApplication, JubatusYarnApplicationStatus, Resource, JubatusClusterConfiguration} +import us.jubat.yarn.common._ import scala.collection._ import 
scala.collection.convert.decorateAsScala._ +import scala.collection.mutable.{LinkedHashMap, HashMap, ArrayBuffer} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await => ScAwait, Future => ScFuture, Promise => ScPromise, SyncVar} @@ -91,6 +95,9 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) val knownStreamNames: concurrent.Map[String, String] = new ConcurrentHashMap[String, String]().asScala + val streamStates: concurrent.Map[String, StreamState] = + new ConcurrentHashMap[String, StreamState]().asScala + // hold feature functions written in JavaScript. val featureFunctions: concurrent.Map[String, String] = new ConcurrentHashMap[String, String]().asScala @@ -258,6 +265,248 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) } } + protected def complementResource(resourceJsonString: Option[String]): Either[(Int, String), Resource] = { + + resourceJsonString match { + case Some(strResource) => + JsonMethods.parseOpt(strResource) match { + case Some(obj: JObject) => + val masterMemory = checkConfigByInt(obj, "applicationmaster_memory", Resource.defaultMasterMemory, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val proxyMemory = checkConfigByInt(obj, "jubatus_proxy_memory", Resource.defaultJubatusProxyMemory, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val masterCores = checkConfigByInt(obj, "applicationmaster_cores", Resource.defaultMasterCores, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val containerPriority = checkConfigByInt(obj, "container_priority", Resource.defaultPriority, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val containerMemory 
= checkConfigByInt(obj, "container_memory", Resource.defaultContainerMemory, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val serverMemory = checkConfigByInt(obj, "jubatus_server_memory", Resource.defaultJubatusServerMemory, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val containerCores = checkConfigByInt(obj, "container_cores", Resource.defaultContainerCores, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val containerNodes = checkConfigByStringList(obj, "container_nodes") match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val containerRacks = checkConfigByStringList(obj, "container_racks") match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + Right(Resource(containerPriority, serverMemory, containerCores, masterMemory, proxyMemory, masterCores, containerMemory, containerNodes, containerRacks)) + + case None => + Left((400, "Resource config is not a JSON")) + } + + case None => + Right(Resource()) + } + } + + protected def complementServerConfig(serverJsonString: Option[String]): Either[(Int, String), ServerConfig] = { + + serverJsonString match { + case Some(strServer) => + JsonMethods.parseOpt(strServer) match { + case Some(obj: JObject) => + val diffSet = obj.values.keySet diff Set("thread", "timeout", "mixer", "interval_sec", "interval_count", "zookeeper_timeout", "interconnect_timeout") + if (diffSet.size != 0) { + return Left(400, s"invalid server config elements (${diffSet.mkString(",")})") + } + + val thread = checkConfigByInt(obj, "thread", ServerConfig.defaultThread, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val timeout = 
checkConfigByInt(obj, "timeout", ServerConfig.defaultTimeout, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val mixer = checkConfigByString(obj, "mixer", ServerConfig.defaultMixer.name) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + try { + Mixer.valueOf(value) + } catch { + case e: Throwable => + return Left(400, s"invalid mixer specify 'linear_mixer' or 'random_mixer' or 'broadcast_mixer' or 'skip_mixer'") + } + } + + val intervalSec = checkConfigByInt(obj, "interval_sec", ServerConfig.defaultIntervalSec, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val intervalCount = checkConfigByInt(obj, "interval_count", ServerConfig.defaultIntervalCount, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val zookeeperTimeout = checkConfigByInt(obj, "zookeeper_timeout", ServerConfig.defaultZookeeperTimeout, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val interconnectTimeout = checkConfigByInt(obj, "interconnect_timeout", ServerConfig.defaultInterconnectTimeout, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + Right(ServerConfig(thread, timeout, mixer, intervalSec, intervalCount, zookeeperTimeout, interconnectTimeout)) + + case None => + Left((400, "Server config is not a JSON")) + } + + case None => + Right(ServerConfig()) + } + } + + protected def complementProxyConfig(proxyJsonString: Option[String]): Either[(Int, String), ProxyConfig] = { + + proxyJsonString match { + case Some(strProxy) => + JsonMethods.parseOpt(strProxy) match { + case Some(obj: JObject) => + val diffSet = obj.values.keySet diff Set("thread", "timeout", "zookeeper_timeout", "interconnect_timeout", 
"pool_expire", "pool_size") + if (diffSet.size != 0) { + return Left(400, s"invalid proxy config elements (${diffSet.mkString(",")})") + } + + val thread = checkConfigByInt(obj, "thread", ProxyConfig.defaultThread, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val timeout = checkConfigByInt(obj, "timeout", ProxyConfig.defaultTimeout, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val zookeeperTimeout = checkConfigByInt(obj, "zookeeper_timeout", ProxyConfig.defaultZookeeperTimeout, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val interconnectTimeout = checkConfigByInt(obj, "interconnect_timeout", ProxyConfig.defaultInterconnectTimeout, 1) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val poolExpire = checkConfigByInt(obj, "pool_expire", ProxyConfig.defaultPoolExpire, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + val poolSize = checkConfigByInt(obj, "pool_size", ProxyConfig.defaultPoolSize, 0) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + + case Right(value) => + value + } + + Right(ProxyConfig(thread, timeout, zookeeperTimeout, interconnectTimeout, poolExpire, poolSize)) + + case None => + Left((400, "Proxy config is not a JSON")) + } + + case None => + Right(ProxyConfig()) + } + } + protected def takeAction(ast: JubaQLAST): Either[(Int, String), JubaQLResponse] = { ast match { case anything if isAcceptingQueries.get == false => @@ -324,15 +573,44 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) compact(render(config)) } // TODO: location, resource - val resource = Resource(priority = 0, memory = 256, virtualCores = 1) + val resource = 
complementResource(cm.resConfigJson) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + case Right(value) => + value + } + + val serverConfig = complementServerConfig(cm.serverConfigJson) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + case Right(value) => + value + } + + val proxyConfig = complementProxyConfig(cm.proxyConfigJson) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + case Right(value) => + value + } + + var message: String = "" + val gatewayAddress = scala.util.Properties.propOrElse("jubaql.gateway.address","") + val sessionId = scala.util.Properties.propOrElse("jubaql.processor.sessionId","") + val applicationName = s"JubatusOnYarn:$gatewayAddress:$sessionId:${jubaType.name}:${cm.modelName}" + val juba: ScFuture[JubatusYarnApplication] = runMode match { case RunMode.Production(zookeeper) => val location = zookeeper.map { case (host, port) => Location(InetAddress.getByName(host), port) } - JubatusYarnApplication.start(cm.modelName, jubaType, location, configJsonStr, resource, 2) + val jubaClusterConfig = JubatusClusterConfiguration(cm.modelName, jubaType, location, configJsonStr, null, resource, 2, applicationName, serverConfig, proxyConfig) + JubatusYarnApplication.start(jubaClusterConfig) case RunMode.Development => - LocalJubatusApplication.start(cm.modelName, jubaType, configJsonStr) + if (cm.proxyConfigJson.isDefined) { + message = "(proxy setting has been ignored in Development mode)" + } + LocalJubatusApplication.start(cm.modelName, jubaType, configJsonStr, serverConfig) } // we keep a reference to the started instance so we can always check its status @@ -349,7 +627,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) t.printStackTrace() startedInstance.completeWith(juba) } - Right(StatementProcessed("CREATE MODEL (started)")) + Right(StatementProcessed(s"CREATE MODEL (started) $message")) case CreateStreamFromSelect(streamName, 
selectPlan) => if (knownStreamNames.contains(streamName)) { @@ -361,6 +639,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) withStreams(refStreams)(mainDataSource => { // register this stream internally knownStreamNames += ((streamName, mainDataSource)) + streamStates += ((streamName, new StreamState(sc, refStreams.toList))) preparedStatements.enqueue((mainDataSource, PreparedCreateStreamFromSelect(streamName, selectPlan, refStreams.toList))) Right(StatementProcessed("CREATE STREAM")) @@ -425,6 +704,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) withStreams(refStreams)(mainDataSource => { // register this stream internally knownStreamNames += ((streamName, mainDataSource)) + streamStates += ((streamName, new StreamState(sc, refStreams.toList))) val flattenedFuncs = checkedFuncSpecs.collect{ case Right(x) => x } // build the schema that will result from this statement // (add one additional column with the window timestamp if the @@ -490,6 +770,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case Right((modelFut, analyzerFut)) => // register this stream internally knownStreamNames += ((cs.streamName, mainDataSource)) + streamStates += ((cs.streamName, new StreamState(sc, List(cs.analyze.data)))) // put the UPDATE statement in the statement queue preparedStatements.enqueue((mainDataSource, PreparedCreateStreamFromAnalyze(cs.streamName, cs.analyze.modelName, modelFut, @@ -556,6 +837,14 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) } }) + case updateWith: UpdateWith => + queryUpdateWith(updateWith) match { + case Left(msgWithErrCode) => + Left(msgWithErrCode) + case Right(updateResult) => + Right(StatementProcessed(s"UPDATE MODEL ($updateResult)")) + } + case StartProcessing(sourceName) => sources.get(sourceName) match { case None => @@ -574,6 +863,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: 
String) logger.info(s"setting up processing pipeline for data source '$sourceName' " + s"with given schema $maybeSchema") + val readyStreamList = ArrayBuffer.empty[String] val rddOperations: mutable.Queue[Either[(Int, String), StreamingContext => Unit]] = preparedStatements.filter(_._1 == sourceName).map(_._2).map(stmt => { logger.debug(s"deal with $stmt") @@ -584,8 +874,16 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) logger.info(s"adding 'CREATE STREAM $streamName FROM SELECT ...' to pipeline") Right((ssc: StreamingContext) => { logger.debug(s"executing 'CREATE STREAM $streamName FROM SELECT ...'") - SchemaDStream.fromSQL(ssc, sqlc, - selectPlan, Some(streamName)) + val selectedStream = SchemaDStream.fromSQL(ssc, sqlc, selectPlan, Some(streamName)) + streamStates.get(streamName) match { + case Some(streamState) => + readyStreamList += streamName + selectedStream.foreachRDD({ rdd => + streamState.outputCount += rdd.count + }) + case None => + logger.warn(s"Stream(${streamName}) that counts the number of processing not found.") + } () }) @@ -654,8 +952,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) schemaRdd.where(postCond) }) }).getOrElse(outRowStream) - SchemaDStream(sqlc, filteredOutRowStream, outSchemaStream) - .registerStreamAsTable(streamName) + val filteredOutRowWithSchemaStream = SchemaDStream(sqlc, filteredOutRowStream, outSchemaStream) + streamStates.get(streamName) match { + case Some(streamState) => + readyStreamList += streamName + filteredOutRowWithSchemaStream.foreachRDD({ rdd => + streamState.outputCount += rdd.count + }) + case None => + logger.warn(s"Stream(${streamName}) that counts the number of processing not found.") + } + filteredOutRowWithSchemaStream.registerStreamAsTable(streamName) () } Right(fun) @@ -694,24 +1001,32 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) logger.info(s"adding 'CREATE STREAM $streamName FROM ANALYZE ...' 
to pipeline") Right((ssc: StreamingContext) => { logger.debug(s"executing 'CREATE STREAM $streamName FROM ANALYZE ...'") - SchemaDStream.fromRDDTransformation(ssc, sqlc, dataSourceName, tmpRdd => { - val rddSchema: StructType = tmpRdd.schema - val analyzeFun = UpdaterAnalyzeWrapper(rddSchema, statusUrl, - updater, rpcName) - val newSchema = StructType(rddSchema.fields :+ - StructField(newColumn.getOrElse(rpcName), - analyzeFun.dataType, nullable = false)) - val newRdd = sqlc.applySchema(tmpRdd.mapPartitionsWithIndex((idx, iter) => { - val formatter = new SimpleDateFormat("HH:mm:ss.SSS") - val hostname = InetAddress.getLocalHost().getHostName() - println("%s @ %s [%s] DEBUG analyzing model from partition %d".format( - formatter.format(new Date), hostname, Thread.currentThread().getName, idx - )) - iter - }).mapPartitions(analyzeFun.apply(_)), - newSchema) - newRdd - }, Some(streamName)) + val analyzedStream = SchemaDStream.fromRDDTransformation(ssc, sqlc, dataSourceName, tmpRdd => { + val rddSchema: StructType = tmpRdd.schema + val analyzeFun = UpdaterAnalyzeWrapper(rddSchema, statusUrl, + updater, rpcName) + val newSchema = StructType(rddSchema.fields :+ + StructField(newColumn.getOrElse(rpcName), + analyzeFun.dataType, nullable = false)) + val newRdd = sqlc.applySchema(tmpRdd.mapPartitionsWithIndex((idx, iter) => { + val formatter = new SimpleDateFormat("HH:mm:ss.SSS") + val hostname = InetAddress.getLocalHost().getHostName() + println("%s @ %s [%s] DEBUG analyzing model from partition %d".format( + formatter.format(new Date), hostname, Thread.currentThread().getName, idx)) + iter + }).mapPartitions(analyzeFun.apply(_)), + newSchema) + newRdd + }, Some(streamName)) + streamStates.get(streamName) match { + case Some(streamState) => + readyStreamList += streamName + analyzedStream.foreachRDD({ rdd => + streamState.outputCount += rdd.count() + }) + case None => + logger.warn(s"Stream(${streamName}) that counts the number of processing not found.") + } () }) } @@ 
-832,6 +1147,15 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) val stopFun = processor.startTableProcessingGeneral(transform, maybeSchema, sourceName)._1 stopUpdateFunc = Some(() => stopFun()) + // set start time to streamState + readyStreamList.foreach { startedStreamName => + streamStates.get(startedStreamName) match { + case Some(state) => + state.startTime = System.currentTimeMillis() + case None => + logger.warn(s"${startedStreamName} is undefined") + } + } Right(StatementProcessed("START PROCESSING")) } } @@ -845,12 +1169,15 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) } case s: Status => - val dsStatus = sources.mapValues(_._1.state.toString) - val jubaStatus = models.mapValues(_._1 match { - case dummy: LocalJubatusApplication => "OK" - case real => real.status.toString - }) - Right(StatusResponse("STATUS", dsStatus.toMap, jubaStatus.toMap)) + val dsStatus = getSourcesStatus() + val jubaStatus = getModelsStatus() + val proStatus = getProcessorStatus() + val streamStatus = getStreamStatus() + logger.debug(s"dataSourcesStatus: $dsStatus") + logger.debug(s"modelsStatus: $jubaStatus") + logger.debug(s"processorStatus: $proStatus") + logger.debug(s"streamStatus: $streamStatus") + Right(StatusResponse("STATUS", dsStatus, jubaStatus, proStatus, streamStatus)) case s: Shutdown => // first set a flag to stop further query processing @@ -934,20 +1261,20 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) // (0 until nParams).map(n => s"x$n: AnyRef").mkString(", ") // } // - // def caseTypeString(sqlType: String, scalaType: String, defaultValue: String, nArgs: Int): String = { + // def caseTypeString(sqlType: String, scalaType: String, nArgs: Int): String = { // val args = nArgsString(nArgs) // val params = nParamsString(nArgs) // s"""case "$sqlType" => // | sqlc.registerFunction(funcName, ($params) => { // | JavaScriptUDFManager.registerAndCall[$scalaType](funcName, 
- // | $nArgs, funcBody, $args).getOrElse($defaultValue) + // | $nArgs, funcBody, $args) // | })""".stripMargin // } // // def caseNArgs(nArgs: Int): String = { - // val numericCase = caseTypeString("numeric", "Double", "0.0", nArgs).split("\n").map(" " + _).mkString("\n") - // val stringCase = caseTypeString("string", "String", "\"\"", nArgs).split("\n").map(" " + _).mkString("\n") - // val booleanCase = caseTypeString("boolean", "Boolean", "false", nArgs).split("\n").map(" " + _).mkString("\n") + // val numericCase = caseTypeString("numeric", "Double", nArgs).split("\n").map(" " + _).mkString("\n") + // val stringCase = caseTypeString("string", "String", nArgs).split("\n").map(" " + _).mkString("\n") + // val booleanCase = caseTypeString("boolean", "Boolean", nArgs).split("\n").map(" " + _).mkString("\n") // s"""case $nArgs => // | returnType match { // |$numericCase @@ -964,17 +1291,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case "numeric" => sqlc.registerFunction(funcName, (x0: AnyRef) => { JavaScriptUDFManager.registerAndCall[Double](funcName, - 1, funcBody, x0).getOrElse(0.0) + 1, funcBody, x0) }) case "string" => sqlc.registerFunction(funcName, (x0: AnyRef) => { JavaScriptUDFManager.registerAndCall[String](funcName, - 1, funcBody, x0).getOrElse("") + 1, funcBody, x0) }) case "boolean" => sqlc.registerFunction(funcName, (x0: AnyRef) => { JavaScriptUDFManager.registerAndCall[Boolean](funcName, - 1, funcBody, x0).getOrElse(false) + 1, funcBody, x0) }) } Right(StatementProcessed("CREATE FUNCTION")) @@ -984,17 +1311,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case "numeric" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef) => { JavaScriptUDFManager.registerAndCall[Double](funcName, - 2, funcBody, x0, x1).getOrElse(0.0) + 2, funcBody, x0, x1) }) case "string" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef) => { 
JavaScriptUDFManager.registerAndCall[String](funcName, - 2, funcBody, x0, x1).getOrElse("") + 2, funcBody, x0, x1) }) case "boolean" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef) => { JavaScriptUDFManager.registerAndCall[Boolean](funcName, - 2, funcBody, x0, x1).getOrElse(false) + 2, funcBody, x0, x1) }) } Right(StatementProcessed("CREATE FUNCTION")) @@ -1004,17 +1331,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case "numeric" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef) => { JavaScriptUDFManager.registerAndCall[Double](funcName, - 3, funcBody, x0, x1, x2).getOrElse(0.0) + 3, funcBody, x0, x1, x2) }) case "string" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef) => { JavaScriptUDFManager.registerAndCall[String](funcName, - 3, funcBody, x0, x1, x2).getOrElse("") + 3, funcBody, x0, x1, x2) }) case "boolean" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef) => { JavaScriptUDFManager.registerAndCall[Boolean](funcName, - 3, funcBody, x0, x1, x2).getOrElse(false) + 3, funcBody, x0, x1, x2) }) } Right(StatementProcessed("CREATE FUNCTION")) @@ -1024,17 +1351,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case "numeric" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef) => { JavaScriptUDFManager.registerAndCall[Double](funcName, - 4, funcBody, x0, x1, x2, x3).getOrElse(0.0) + 4, funcBody, x0, x1, x2, x3) }) case "string" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef) => { JavaScriptUDFManager.registerAndCall[String](funcName, - 4, funcBody, x0, x1, x2, x3).getOrElse("") + 4, funcBody, x0, x1, x2, x3) }) case "boolean" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef) => { JavaScriptUDFManager.registerAndCall[Boolean](funcName, - 4, funcBody, x0, x1, x2, x3).getOrElse(false) + 4, funcBody, x0, x1, x2, x3) }) } 
Right(StatementProcessed("CREATE FUNCTION")) @@ -1044,17 +1371,17 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case "numeric" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef, x4: AnyRef) => { JavaScriptUDFManager.registerAndCall[Double](funcName, - 5, funcBody, x0, x1, x2, x3, x4).getOrElse(0.0) + 5, funcBody, x0, x1, x2, x3, x4) }) case "string" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef, x4: AnyRef) => { JavaScriptUDFManager.registerAndCall[String](funcName, - 5, funcBody, x0, x1, x2, x3, x4).getOrElse("") + 5, funcBody, x0, x1, x2, x3, x4) }) case "boolean" => sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef, x4: AnyRef) => { JavaScriptUDFManager.registerAndCall[Boolean](funcName, - 5, funcBody, x0, x1, x2, x3, x4).getOrElse(false) + 5, funcBody, x0, x1, x2, x3, x4) }) } Right(StatementProcessed("CREATE FUNCTION")) @@ -1121,7 +1448,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) // Returns an Int value because registerFunction does not accept a function which returns Unit. // The Int value is not used. sqlc.registerFunction(funcName, (x0: AnyRef) => { - JavaScriptUDFManager.registerAndCall[Int](funcName, + JavaScriptUDFManager.registerAndOptionCall[Int](funcName, 1, funcBody, x0).getOrElse(0) }) Right(StatementProcessed("CREATE TRIGGER FUNCTION")) @@ -1129,7 +1456,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case 2 => // Returns Int for the above reason. 
sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef) => { - JavaScriptUDFManager.registerAndCall[Int](funcName, + JavaScriptUDFManager.registerAndOptionCall[Int](funcName, 2, funcBody, x0, x1).getOrElse(0) }) Right(StatementProcessed("CREATE TRIGGER FUNCTION")) @@ -1137,7 +1464,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case 3 => // Returns Int for the above reason. sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef) => { - JavaScriptUDFManager.registerAndCall[Int](funcName, + JavaScriptUDFManager.registerAndOptionCall[Int](funcName, 3, funcBody, x0, x1, x2).getOrElse(0) }) Right(StatementProcessed("CREATE TRIGGER FUNCTION")) @@ -1145,7 +1472,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case 4 => // Returns Int for the above reason. sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef) => { - JavaScriptUDFManager.registerAndCall[Int](funcName, + JavaScriptUDFManager.registerAndOptionCall[Int](funcName, 4, funcBody, x0, x1, x2, x3).getOrElse(0) }) Right(StatementProcessed("CREATE TRIGGER FUNCTION")) @@ -1153,7 +1480,7 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) case 5 => // Returns Int for the above reason. 
sqlc.registerFunction(funcName, (x0: AnyRef, x1: AnyRef, x2: AnyRef, x3: AnyRef, x4: AnyRef) => { - JavaScriptUDFManager.registerAndCall[Int](funcName, + JavaScriptUDFManager.registerAndOptionCall[Int](funcName, 5, funcBody, x0, x1, x2, x3, x4).getOrElse(0) }) Right(StatementProcessed("CREATE TRIGGER FUNCTION")) @@ -1164,8 +1491,74 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) Left((400, msg)) } + case SaveModel(modelName, modelPath, modelId) => + models.get(modelName) match { + case Some((jubaApp, createModelStmt, machineType)) => + val chkResult = runMode match { + case RunMode.Production(zookeeper) => + modelPath.startsWith("hdfs://") + case RunMode.Development => + modelPath.startsWith("file://") + } + + if (chkResult) { + val juba = jubaApp.saveModel(new org.apache.hadoop.fs.Path(modelPath), modelId) + juba match { + case Failure(t) => + val msg = s"SAVE MODEL failed: ${t.getMessage}" + logger.error(msg, t) + Left((500, msg)) + + case _ => + Right(StatementProcessed("SAVE MODEL")) + } + } else { + val msg = s"invalid model path ($modelPath)" + logger.warn(msg) + Left((400, msg)) + } + + case None => + val msg = s"model '$modelName' does not exist" + logger.warn(msg) + Left((400, msg)) + } + + case LoadModel(modelName, modelPath, modelId) => + models.get(modelName) match { + case Some((jubaApp, createModelStmt, machineType)) => + val chkResult = runMode match { + case RunMode.Production(zookeeper) => + modelPath.startsWith("hdfs://") + case RunMode.Development => + modelPath.startsWith("file://") + } + + if (chkResult) { + val juba = jubaApp.loadModel(new org.apache.hadoop.fs.Path(modelPath), modelId) + juba match { + case Failure(t) => + val msg = s"LOAD MODEL failed: ${t.getMessage}" + logger.error(msg, t) + Left((500, msg)) + + case _ => + Right(StatementProcessed("LOAD MODEL")) + } + } else { + val msg = s"invalid model path ($modelPath)" + logger.warn(msg) + Left((400, msg)) + } + + case None => + val msg = s"model 
'$modelName' does not exist" + logger.warn(msg) + Left((400, msg)) + } + case other => - val msg = "no handler for " + other + val msg = s"no handler for $other" logger.error(msg) Left((500, msg)) } @@ -1458,6 +1851,274 @@ class JubaQLService(sc: SparkContext, runMode: RunMode, checkpointDir: String) Left((400, msg)) } } + + protected def checkConfigByInt(resObj: JObject, strKey: String, defValue: Int, minValue: Int = 0): Either[(Int, String), Int] = { + resObj.values.get(strKey) match { + case Some(value) => + try { + val numValue = value.asInstanceOf[Number] + val intValue = numValue.intValue() + if (intValue >= minValue && intValue <= Int.MaxValue) { + Right(intValue) + } else { + Left((400, s"invalid ${strKey} specified in ${minValue} or more and ${Int.MaxValue} or less")) + } + } catch { + case e: Throwable => + logger.error(e.getMessage(), e) + Left(400, s"invalid config (${strKey})") + } + + case None => + Right(defValue) + } + } + + protected def checkConfigByString(resObj: JObject, strKey: String, defValue: String): Either[(Int, String), String] = { + resObj.values.get(strKey) match { + case Some(value) => + try { + Right(value.asInstanceOf[String]) + } catch { + case e: Throwable => + logger.error(e.getMessage(), e) + Left(400, s"invalid config (${strKey})") + } + + case None => + Right(defValue) + } + } + + protected def checkConfigByStringList(resObj: JObject, strKey: String): Either[(Int, String), List[String]] = { + resObj.values.get(strKey) match { + case Some(value) => + try { + Right(value.asInstanceOf[List[String]]) + } catch { + case e: Throwable => + logger.error(e.getMessage(), e) + Left(400, s"invalid config (${strKey})") + } + + case None => + Right(null) + } + } + + protected def getSourcesStatus(): Map[String, Any] = { + var sourceMap: LinkedHashMap[String, Any] = new LinkedHashMap() + sources.foreach { + case (sourceName, (hybridProcessor, schema)) => + sourceMap.put(sourceName, hybridProcessor.getStatus()) + } + sourceMap + } + + 
protected def getModelsStatus(): Map[String, Any] = { + var jubaStatus: LinkedHashMap[String, Any] = new LinkedHashMap() + models.foreach { + case (modelName, (jubaApp, createModel, jubaType)) => + var configMap: LinkedHashMap[String, Any] = new LinkedHashMap() + configMap.put("jubatusConfig", createModel.configJson) + configMap.put("resourceConfig", createModel.resConfigJson.getOrElse("")) + configMap.put("serverConfig", createModel.serverConfigJson.getOrElse("")) + configMap.put("proxyConfig", createModel.proxyConfigJson.getOrElse("")) + + var jubatusAppStatusMap: LinkedHashMap[String, Any] = new LinkedHashMap() + val jubatusAppStatus = jubaApp.status + + var proxyStatus: Map[String, Map[String, String]] = new HashMap() + if (jubatusAppStatus.jubatusProxy != null) { + proxyStatus = jubatusAppStatus.jubatusProxy.asScala.mapValues(map => map.asScala) + } + var serversStatus: Map[String, Map[String, String]] = new HashMap() + if (jubatusAppStatus.jubatusServers != null) { + serversStatus = jubatusAppStatus.jubatusServers.asScala.mapValues(map => map.asScala) + } + val yarnAppStatus: LinkedHashMap[String, Any] = new LinkedHashMap() + if (jubatusAppStatus.yarnApplication != null) { + jubatusAppStatus.yarnApplication.foreach { + case (key, value) => + if (key == "applicationReport") { + yarnAppStatus.put(key, value.toString()) + } else { + yarnAppStatus.put(key, value) + } + } + } + + jubatusAppStatusMap.put("jubatusProxy", proxyStatus) + jubatusAppStatusMap.put("jubatusServers", serversStatus) + jubatusAppStatusMap.put("jubatusOnYarn", yarnAppStatus) + + var modelMap: LinkedHashMap[String, Any] = new LinkedHashMap() + modelMap.put("learningMachineType", jubaType.name) + modelMap.put("config", configMap) + modelMap.put("jubatusYarnApplicationStatus", jubatusAppStatusMap) + jubaStatus.put(modelName, modelMap) + } + jubaStatus + } + + protected def getProcessorStatus(): Map[String, Any] = { + val curTime = System.currentTimeMillis() + val opTime = curTime - sc.startTime 
+ val runtime = Runtime.getRuntime() + val usedMemory = runtime.totalMemory() - runtime.freeMemory() + + var proStatusMap: LinkedHashMap[String, Any] = new LinkedHashMap() + proStatusMap.put("applicationId", sc.applicationId) + proStatusMap.put("startTime", sc.startTime) + proStatusMap.put("currentTime", curTime) + proStatusMap.put("opratingTime", opTime) + proStatusMap.put("virtualMemory", runtime.totalMemory()) + proStatusMap.put("usedMemory", usedMemory) + + proStatusMap + } + + def getStreamStatus(): Map[String, Map[String, Any]] = { + // Map{key = streamName, value = Map{key = statusName, value = statusValue}} + var streamStatusMap: Map[String, Map[String, Any]] = Map.empty[String, Map[String, Any]] + streamStates.foreach(streamState => { + val stateValue = streamState._2 + var totalInputCount = 0L + // calculate input count of each stream + stateValue.inputStreamList.foreach { inputStreamName => + // add output count of datasource to totalInputCount + sources.get(inputStreamName) match { + case Some(source) => + totalInputCount += source._1.storageCount.value + totalInputCount += source._1.streamCount.value + case None => + logger.warn(s"input datasource(${inputStreamName}) of stream(${streamState._1}) was not found.") + } + // add output count of user define stream to totalInputCount + streamStates.get(inputStreamName) match { + case Some(inputStreamState) => + totalInputCount += inputStreamState.outputCount.value + case None => + logger.warn(s"input stream(${inputStreamName}) of stream(${streamState._1}) was not found.") + } + } + stateValue.inputCount = totalInputCount + val stateMap: scala.collection.mutable.Map[String, Any] = new scala.collection.mutable.LinkedHashMap[String, Any] + stateMap.put("stream_start", stateValue.startTime) + stateMap.put("input_count", totalInputCount) + stateMap.put("output_count", stateValue.outputCount.value) + streamStatusMap += streamState._1 -> stateMap + }) + streamStatusMap + } + + protected def 
queryUpdateWith(updateWith: UpdateWith): Either[(Int, String), String] = { + models.get(updateWith.modelName) match { + case Some((jubaApp, createModelStmt, machineType)) => + val host = jubaApp.jubatusProxy.hostAddress + val port = jubaApp.jubatusProxy.port + + machineType match { + case LearningMachineType.Anomaly if updateWith.rpcName == "add" => + val datum = DatumExtractor.extract(createModelStmt, updateWith.learningData, featureFunctions, logger) + val anomaly = new AnomalyClient(host, port, updateWith.modelName, 5) + try { + val result = anomaly.add(datum) + logger.info(s"anomaly.add result: ${result.toString()}") + Right(result.toString()) + } finally { + anomaly.getClient.close() + } + + case LearningMachineType.Classifier if updateWith.rpcName == "train" => + val datum = DatumExtractor.extract(createModelStmt, updateWith.learningData, featureFunctions, logger) + val label = createModelStmt.labelOrId match { + case Some(("label", value)) => + value + case _ => + val msg = s"no label for datum specified" + logger.warn(msg) + return Left((400, msg)) + } + + val labelValue = getItemValue(updateWith.learningData, label) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + case Right(value) => + value + } + + val labelDatum = new LabeledDatum(labelValue, datum) + val datumList = new java.util.LinkedList[LabeledDatum]() + datumList.add(labelDatum) + val classifier = new ClassifierClient(host, port, updateWith.modelName, 5) + try { + val result = classifier.train(datumList) + logger.info(s"classifier.train result: ${result}") + Right(result.toString()) + } finally { + classifier.getClient.close() + } + + case LearningMachineType.Recommender if updateWith.rpcName == "update_row" => + val datum = DatumExtractor.extract(createModelStmt, updateWith.learningData, featureFunctions, logger) + val idName = createModelStmt.labelOrId match { + case Some(("id", value)) => + value + case _ => + val msg = s"no id for datum specified" + logger.warn(msg) 
+ return Left((400, msg)) + } + + val idValue = getItemValue(updateWith.learningData, idName) match { + case Left((errCode, errMsg)) => + return Left((errCode, errMsg)) + case Right(value) => + value + } + + val recommender = new RecommenderClient(host, port, updateWith.modelName, 5) + try { + val result = recommender.updateRow(idValue, datum) + logger.info(s"recommender.update_row result: ${result}") + Right(result.toString()) + } finally { + recommender.getClient.close() + } + + case _ => + val msg = s"cannot use model '${updateWith.modelName}' with method '${updateWith.rpcName}'" + logger.warn(msg) + Left((400, msg)) + } + + case None => + val msg = s"model '${updateWith.modelName}' does not exist" + logger.warn(msg) + Left((400, msg)) + } + } + + protected def getItemValue(dataJsonString: String, itemName: String): Either[(Int, String), String] = { + JsonMethods.parseOpt(dataJsonString) match { + case Some(obj: JObject) => + obj.values.get(itemName) match { + case Some(value) => + Right(value.toString()) + + case None => + val msg = s"the given schema ${dataJsonString} does not contain a column named '${itemName}'" + logger.warn(msg) + Left((400, msg)) + } + + case _ => + val msg = "data is not JSON." 
+ logger.warn(msg) + Left((400, msg)) + } + } } sealed trait RunMode @@ -1468,12 +2129,14 @@ object RunMode { case object Development extends RunMode + case object Test extends RunMode } object LocalJubatusApplication extends LazyLogging { def start(aLearningMachineName: String, aLearningMachineType: LearningMachineType, - aConfigString: String): scala.concurrent.Future[us.jubat.yarn.client.JubatusYarnApplication] = { + aConfigString: String, + aServerConfig: ServerConfig = ServerConfig()): scala.concurrent.Future[us.jubat.yarn.client.JubatusYarnApplication] = { scala.concurrent.Future { val jubaCmdName = aLearningMachineType match { case LearningMachineType.Anomaly => @@ -1501,7 +2164,19 @@ object LocalJubatusApplication extends LazyLogging { val namedPipe = new java.io.File(namedPipePath) try { val rpcPort = findAvailablePort() - val jubatusProcess = runtime.exec(s"$jubaCmdName -p $rpcPort -f $namedPipePath") + val command = new StringBuilder + command.append(s"$jubaCmdName") + command.append(s" -p $rpcPort") + command.append(s" -f $namedPipePath") + command.append(s" -c ${aServerConfig.thread}") + command.append(s" -t ${aServerConfig.timeout}") + command.append(s" -x ${aServerConfig.mixer.name}") + command.append(s" -s ${aServerConfig.intervalSec}") + command.append(s" -i ${aServerConfig.intervalCount}") + command.append(s" -Z ${aServerConfig.zookeeperTimeout}") + command.append(s" -I ${aServerConfig.interconnectTimeout}") + logger.debug(s"command: ${command.result()}") + val jubatusProcess = runtime.exec(command.result()) handleSubProcessOutput(jubatusProcess.getInputStream, System.out, jubaCmdName) handleSubProcessOutput(jubatusProcess.getErrorStream, System.err, jubaCmdName) val namedPipeWriter = new java.io.PrintWriter(namedPipe) @@ -1511,8 +2186,8 @@ object LocalJubatusApplication extends LazyLogging { namedPipeWriter.close() } - new LocalJubatusApplication(jubatusProcess, aLearningMachineName, jubaCmdName, - rpcPort) + new 
LocalJubatusApplication(jubatusProcess, aLearningMachineName, aLearningMachineType, + jubaCmdName, rpcPort) } finally { namedPipe.delete() } @@ -1566,11 +2241,30 @@ object LocalJubatusApplication extends LazyLogging { } // LocalJubatusApplication is not a JubatusYarnApplication, but extends JubatusYarnApplication for implementation. -class LocalJubatusApplication(jubatus: Process, name: String, jubaCmdName: String, port: Int = 9199) +class LocalJubatusApplication(jubatus: Process, name: String, aLearningMachineType: LearningMachineType, jubaCmdName: String, port: Int = 9199) extends JubatusYarnApplication(Location(InetAddress.getLocalHost, port), List(), null) { + private val timeoutCount: Int = 180 + private val fileRe = """file://(.+)""".r + override def status: JubatusYarnApplicationStatus = { - throw new NotImplementedError("status is not implemented") + logger.info("status LocalJubatusApplication") + + val strHost = jubatusProxy.hostAddress + val strPort = jubatusProxy.port + val client: ClientBase = aLearningMachineType match { + case LearningMachineType.Anomaly => + new AnomalyClient(strHost, strPort, name, timeoutCount) + + case LearningMachineType.Classifier => + new ClassifierClient(strHost, strPort, name, timeoutCount) + + case LearningMachineType.Recommender => + new RecommenderClient(strHost, strPort, name, timeoutCount) + } + + val stsMap: java.util.Map[String, java.util.Map[String, String]] = client.getStatus() + JubatusYarnApplicationStatus(null, stsMap, null) } override def stop(): scala.concurrent.Future[Unit] = scala.concurrent.Future { @@ -1585,10 +2279,174 @@ class LocalJubatusApplication(jubatus: Process, name: String, jubaCmdName: Strin } override def loadModel(aModelPathPrefix: org.apache.hadoop.fs.Path, aModelId: String): Try[JubatusYarnApplication] = Try { - throw new NotImplementedError("loadModel is not implemented") + logger.info(s"loadModel path: $aModelPathPrefix, modelId: $aModelId") + + val strHost = jubatusProxy.hostAddress + val 
strPort = jubatusProxy.port + + val srcDir = aModelPathPrefix.toUri().toString() match { + case fileRe(filepath) => + val realpath = if (filepath.startsWith("/")) { + filepath + } else { + (new java.io.File(".")).getAbsolutePath + "/" + filepath + } + "file://" + realpath + } + logger.debug(s"convert srcDir: $srcDir") + + val localFileSystem = org.apache.hadoop.fs.FileSystem.getLocal(new Configuration()) + val srcDirectory = localFileSystem.pathToFile(new org.apache.hadoop.fs.Path(srcDir)) + val srcPath = new java.io.File(srcDirectory, aModelId) + if (!srcPath.exists()) { + val msg = s"model path does not exist ($srcPath)" + logger.error(msg) + throw new RuntimeException(msg) + } + + val srcFile = new java.io.File(srcPath, "0.jubatus") + if (!srcFile.exists()) { + val msg = s"model file does not exist ($srcFile)" + logger.error(msg) + throw new RuntimeException(msg) + } + + val client: ClientBase = aLearningMachineType match { + case LearningMachineType.Anomaly => + new AnomalyClient(strHost, strPort, name, timeoutCount) + + case LearningMachineType.Classifier => + new ClassifierClient(strHost, strPort, name, timeoutCount) + + case LearningMachineType.Recommender => + new RecommenderClient(strHost, strPort, name, timeoutCount) + } + + val stsMap: java.util.Map[String, java.util.Map[String, String]] = client.getStatus() + logger.debug(s"getStatus method result: $stsMap") + if (stsMap.size != 1) { + val msg = s"getStatus RPC failed (got ${stsMap.size} results)" + logger.error(msg) + throw new RuntimeException(msg) + } + + val strHostPort = stsMap.keys.head + logger.debug(s"key[Host_Port]: $strHostPort") + + val baseDir = localFileSystem.pathToFile(new org.apache.hadoop.fs.Path(stsMap.get(strHostPort).get("datadir"))) + val mType = stsMap.get(strHostPort).get("type") + val dstFile = new java.io.File(baseDir, s"${strHostPort}_${mType}_${aModelId}.jubatus") + + logger.debug(s"srcFile: $srcFile") + logger.debug(s"dstFile: $dstFile") + + FileUtils.copyFile(srcFile, 
dstFile, false) + + val ret = client.load(aModelId) + if (!ret) { + val msg = "load RPC failed" + logger.error(msg) + throw new RuntimeException(msg) + } + this } override def saveModel(aModelPathPrefix: org.apache.hadoop.fs.Path, aModelId: String): Try[JubatusYarnApplication] = Try { - throw new NotImplementedError("saveModel is not implemented") + logger.info(s"saveModel path: $aModelPathPrefix, modelId: $aModelId") + + val strHost = jubatusProxy.hostAddress + val strPort = jubatusProxy.port + + val strId = Math.abs(new Random().nextInt()).toString() + + val result: java.util.Map[String, String] = aLearningMachineType match { + case LearningMachineType.Anomaly => + val anomaly = new AnomalyClient(strHost, strPort, name, timeoutCount) + anomaly.save(strId) + + case LearningMachineType.Classifier => + val classifier = new ClassifierClient(strHost, strPort, name, timeoutCount) + classifier.save(strId) + + case LearningMachineType.Recommender => + val recommender = new RecommenderClient(strHost, strPort, name, timeoutCount) + recommender.save(strId) + } + + logger.debug(s"save method result: $result") + if (result.size != 1) { + val msg = s"save RPC failed (got ${result.size} results)" + logger.error(msg) + throw new RuntimeException(msg) + } + + val strSavePath = result.values.head + logger.debug(s"srcFile: $strSavePath") + + val dstDir = aModelPathPrefix.toUri().toString() match { + case fileRe(filepath) => + val realpath = if (filepath.startsWith("/")) { + filepath + } else { + s"${(new java.io.File(".")).getAbsolutePath}/$filepath" + } + s"file://$realpath" + } + logger.debug(s"convert dstDir: $dstDir") + + val localFileSystem = org.apache.hadoop.fs.FileSystem.getLocal(new Configuration()) + val dstDirectory = localFileSystem.pathToFile(new org.apache.hadoop.fs.Path(dstDir)) + val dstPath = new java.io.File(dstDirectory, aModelId) + val dstFile = new java.io.File(dstPath, "0.jubatus") + logger.debug(s"dstFile: $dstFile") + + if (!dstPath.exists()) { + 
dstPath.mkdirs() + } else { + if (dstFile.exists()) { + dstFile.delete() + } + } + + FileUtils.moveFile(new java.io.File(strSavePath), dstFile) + this + } +} + +object TestJubatusApplication extends LazyLogging { + def start(aLearningMachineName: String, + aLearningMachineType: LearningMachineType): scala.concurrent.Future[us.jubat.yarn.client.JubatusYarnApplication] = { + scala.concurrent.Future { + new TestJubatusApplication(aLearningMachineName, aLearningMachineType) + } + } +} + +class TestJubatusApplication(name: String, aLearningMachineType: LearningMachineType) + extends JubatusYarnApplication(null, List(), null) { + + override def status: JubatusYarnApplicationStatus = { + logger.info("status TestJubatusApplication") + + val dmyProxy: java.util.Map[String, java.util.Map[String, String]] = new java.util.HashMap() + val dmyProxySub: java.util.Map[String, String] = new java.util.HashMap() + dmyProxySub.put("PROGNAME", "jubaclassifier_proxy") + dmyProxy.put("dummyProxy", dmyProxySub) + val dmyServer: java.util.Map[String, java.util.Map[String, String]] = new java.util.HashMap() + val dmyServerSub: java.util.Map[String, String] = new java.util.HashMap() + dmyServerSub.put("PROGNAME", "jubaclassifier") + dmyServer.put("dummyServer", dmyServerSub) + val dmyApp: java.util.Map[String, Any] = new java.util.HashMap() + dmyApp.put("applicationReport", "applicationId{ id: 1 cluster_timestamp: 99999999}") + dmyApp.put("currentTime", System.currentTimeMillis()) + dmyApp.put("oparatingTime", 1000) + JubatusYarnApplicationStatus(dmyProxy, dmyServer, dmyApp) } } + +class StreamState(sc: SparkContext, inputStreams: List[String]) { + var inputCount = 0L + val outputCount = sc.accumulator(0L) + val inputStreamList = inputStreams + var startTime = 0L +} \ No newline at end of file diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/json/JubaQLResponse.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/json/JubaQLResponse.scala index 
5d5bdec..7bbfb9a 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/json/JubaQLResponse.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/json/JubaQLResponse.scala @@ -15,6 +15,8 @@ // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA package us.jubat.jubaql_server.processor.json +import scala.collection.Map + // We use a sealed trait to make sure we have all possible // response types in *this* file. sealed trait JubaQLResponse @@ -26,6 +28,9 @@ case class AnalyzeResultWrapper(result: AnalyzeResult) extends JubaQLResponse case class StatusResponse(result: String, - sources: Map[String, String], - models: Map[String, String]) + sources: Map[String, Any], + models: Map[String, Any], + processor: Map[String, Any], + streams: Map[String, Map[String, Any]]) + extends JubaQLResponse diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Anomaly.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Anomaly.scala index fead403..2b0b179 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Anomaly.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Anomaly.scala @@ -39,23 +39,31 @@ class Anomaly(val jubaHost: String, jubaPort: Int, cm: CreateModel, featureFunct logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.foreach { case (row, idx) => { - // create datum and send to Jubatus - val d = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - retry(2, logger)(client.add(d)) + var idx: Int = 0 + while (iter.hasNext && !stopped_?) 
{ + try { + val row = iter.next + // create datum and send to Jubatus + val d = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + retry(2, logger)(client.add(d)) - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'add' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) { - logger.debug("driver status is 'stopped', end processing") + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'add' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) { + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() + } catch { + case e: Exception => + logger.error(s"Failed to add row.", e) + } finally { + idx += 1 } } - } } override def analyzeMethod(rpcName: String) = { @@ -79,35 +87,46 @@ class Anomaly(val jubaHost: String, jubaPort: Int, cm: CreateModel, featureFunct logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.map { case (row, idx) => { - // convert to datum and compute score via Jubatus - val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - // if we return a Float here, this will result in casting exceptions - // during processing, so we convert to double - val score = retry(2, logger)(client.calcScore(datum).toDouble) - // we may get an Infinity result if this row is identical to too many - // other items, cf. . - // we assume 1.0 instead to avoid weird behavior in the future if the - // infinity value appeared in the row. 
- val adjustedScore = if (score.isInfinite) { - 1.0 - } else { - score - } + var resultSeq: Seq[Row] = List.empty[Row] + var idx: Int = 0 + while(iter.hasNext && !stopped_?) { + try { + val row = iter.next + // convert to datum and compute score via Jubatus + val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + // if we return a Float here, this will result in casting exceptions + // during processing, so we convert to double + val score = retry(2, logger)(client.calcScore(datum).toDouble) + // we may get an Infinity result if this row is identical to too many + // other items, cf. . + // we assume 1.0 instead to avoid weird behavior in the future if the + // infinity value appeared in the row. + val adjustedScore = if (score.isInfinite) { + 1.0 + } else { + score + } - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'calc_score' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) { - logger.debug("driver status is 'stopped', end processing") + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'calc_score' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) 
{ + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() - } - Row.fromSeq(row :+ adjustedScore) - } + val scoreRow = Row.fromSeq(row :+ adjustedScore) + resultSeq = resultSeq :+ scoreRow + } catch { + case e: Exception => + logger.error(s"Failed to calcScore row.", e) + } finally { + idx += 1 + } } + resultSeq.iterator } } diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Classifier.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Classifier.scala index 88290ac..c586a61 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Classifier.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Classifier.scala @@ -55,32 +55,40 @@ class Classifier(val jubaHost: String, jubaPort: Int, cm: CreateModel, logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() + var idx: Int = 0 // TODO we can make this more efficient using batch training - iter.takeWhile(_ => !stopped_?).zipWithIndex.foreach { case (row, idx) => { - if (!row.isNullAt(labelIdx)) { - // create datum and send to Jubatus - val labelValue = row.getString(labelIdx) - val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - val labelDatum = new LabeledDatum(labelValue, datum) - val datumList = new java.util.LinkedList[LabeledDatum]() - datumList.add(labelDatum) - retry(2, logger)(client.train(datumList)) + while (iter.hasNext && !stopped_?) 
{ + try { + val row = iter.next + if (!row.isNullAt(labelIdx)) { + // create datum and send to Jubatus + val labelValue = row.getString(labelIdx) + val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + val labelDatum = new LabeledDatum(labelValue, datum) + val datumList = new java.util.LinkedList[LabeledDatum]() + datumList.add(labelDatum) + retry(2, logger)(client.train(datumList)) - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'train' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) { - logger.debug("driver status is 'stopped', end processing") + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'train' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) 
{ + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() + } else { + logger.warn("row %s has a NULL label".format(row)) } - } else { - logger.warn("row %s has a NULL label".format(row)) + } catch { + case e: Exception => + logger.error(s"Failed to train row.", e) + } finally { + idx += 1 } } - } } } @@ -110,45 +118,60 @@ class Classifier(val jubaHost: String, jubaPort: Int, cm: CreateModel, logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.flatMap { case (row, idx) => { - // TODO we can make this more efficient using batch training - // convert to datum - val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - val datumList = new java.util.LinkedList[Datum]() - datumList.add(datum) - // classify - val maybeClassifierResult = retry(2, logger)(client.classify(datumList).toList) match { - case labeledDatumList :: rest => - if (!rest.isEmpty) { - logger.warn("received more than one result from classifier, " + - "ignoring all but the first") - } - if (labeledDatumList.isEmpty) { - logger.warn("got an empty classification list for datum") - } - Some(labeledDatumList.map(labeledDatum => { - ClassifierPrediction(labeledDatum.label, labeledDatum.score) - }).toList) - case Nil => - logger.error("received no result from classifier") - None - } - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'classify' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) { - logger.debug("driver status is 'stopped', end processing") + var resultSeq: Seq[Row] = List.empty[Row] + var idx: Int = 0 + while(iter.hasNext && !stopped_?) 
{ + try { + val row = iter.next + // TODO we can make this more efficient using batch training + // convert to datum + val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + val datumList = new java.util.LinkedList[Datum]() + datumList.add(datum) + // classify + val maybeClassifierResult = retry(2, logger)(client.classify(datumList).toList) match { + case labeledDatumList :: rest => + if (!rest.isEmpty) { + logger.warn("received more than one result from classifier, " + + "ignoring all but the first") + } + if (labeledDatumList.isEmpty) { + logger.warn("got an empty classification list for datum") + } + Some(labeledDatumList.map(labeledDatum => { + ClassifierPrediction(labeledDatum.label, labeledDatum.score) + }).toList) + case Nil => + logger.error("received no result from classifier") + None + } + + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'classify' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) 
{ + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() + } + maybeClassifierResult.map(classifierResult => { + Row.fromSeq(row :+ classifierResult.map(r => + Row.fromSeq(r.productIterator.toSeq))) + }) match { + case Some(classifier) => + resultSeq = resultSeq :+ classifier + case None => //nothing } - batchStartTime = System.currentTimeMillis() + } catch { + case e: Exception => + logger.error(s"Failed to classify row.", e) + } finally { + idx += 1 } - maybeClassifierResult.map(classifierResult => { - Row.fromSeq(row :+ classifierResult.map(r => - Row.fromSeq(r.productIterator.toSeq))) - }) - } } + resultSeq.iterator } } diff --git a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Recommender.scala b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Recommender.scala index b398df8..cf13f1d 100644 --- a/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Recommender.scala +++ b/processor/src/main/scala/us/jubat/jubaql_server/processor/updater/Recommender.scala @@ -54,28 +54,36 @@ class Recommender(val jubaHost: String, jubaPort: Int, cm: CreateModel, logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.foreach { case (row, idx) => { - if (!row.isNullAt(idIdx)) { - // create datum and send to Jubatus - val idValue = row.getString(idIdx) - val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - retry(2, logger)(client.updateRow(idValue, datum)) - - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'updateRow' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) 
{ - logger.debug("driver status is 'stopped', end processing") + var idx: Int = 0 + while (iter.hasNext && !stopped_?) { + try { + val row = iter.next + if (!row.isNullAt(idIdx)) { + // create datum and send to Jubatus + val idValue = row.getString(idIdx) + val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + retry(2, logger)(client.updateRow(idValue, datum)) + + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'updateRow' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) { + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() + } else { + logger.warn("row %s has a NULL id".format(row)) } - } else { - logger.warn("row %s has a NULL id".format(row)) + } catch { + case e: Exception => + logger.error(s"Failed to add row.", e) + } finally { + idx += 1 } } - } } } @@ -109,27 +117,38 @@ class Recommender(val jubaHost: String, jubaPort: Int, cm: CreateModel, logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.map { case (row, idx) => { - - val id = extractIdOrLabel(idColumnName, rowSchema, row, logger) - val fullDatum = retry(2, logger)(client.completeRowFromId(id)) - val wrappedFullDatum = datumToJson(fullDatum) - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) % 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'complete_row_from_id' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) 
{ - logger.debug("driver status is 'stopped', end processing") + var resultSeq: Seq[Row] = List.empty[Row] + var idx: Int = 0 + while (iter.hasNext && !stopped_?) { + try { + val row = iter.next + val id = extractIdOrLabel(idColumnName, rowSchema, row, logger) + val fullDatum = retry(2, logger)(client.completeRowFromId(id)) + val wrappedFullDatum = datumToJson(fullDatum) + + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'complete_row_from_id' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) { + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() - } - // we must add a nested row (not case class) to allow for nested queries - Row.fromSeq(row :+ Row.fromSeq(wrappedFullDatum.productIterator.toSeq)) - } + // we must add a nested row (not case class) to allow for nested queries + val fromIdRow = Row.fromSeq(row :+ Row.fromSeq(wrappedFullDatum.productIterator.toSeq)) + resultSeq = resultSeq :+ fromIdRow + } catch { + case e: Exception => + logger.error(s"Failed to completeRowFromId row.", e) + } finally { + idx += 1 + } } + resultSeq.iterator } def completeRowFromDatum(rowSchema: Map[String, (Int, DataType)], @@ -145,28 +164,39 @@ class Recommender(val jubaHost: String, jubaPort: Int, cm: CreateModel, logger.debug("driver status is 'stopped', skip processing") } var batchStartTime = System.currentTimeMillis() - iter.takeWhile(_ => !stopped_?).zipWithIndex.map { case (row, idx) => { - - // convert to datum - val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) - val fullDatum = retry(2, logger)(client.completeRowFromDatum(datum)) - val wrappedFullDatum = datumToJson(fullDatum) - - // every 1000 items, check if the Spark driver is still running - if ((idx + 1) 
% 1000 == 0) { - val duration = System.currentTimeMillis() - batchStartTime - logger.debug(s"processed 1000 items using 'complete_row_from_datum' method in $duration ms") - stopped_? = HttpClientPerJvm.stopped - if (stopped_?) { - logger.debug("driver status is 'stopped', end processing") + + var resultSeq: Seq[Row] = List.empty[Row] + var idx: Int = 0 + while (iter.hasNext && !stopped_?) { + try { + val row = iter.next + // convert to datum + val datum = DatumExtractor.extract(cm, rowSchema, row, featureFunctions, logger) + val fullDatum = retry(2, logger)(client.completeRowFromDatum(datum)) + val wrappedFullDatum = datumToJson(fullDatum) + + // every 1000 items, check if the Spark driver is still running + if ((idx + 1) % 1000 == 0) { + val duration = System.currentTimeMillis() - batchStartTime + logger.debug(s"processed 1000 items using 'complete_row_from_datum' method in $duration ms") + stopped_? = HttpClientPerJvm.stopped + if (stopped_?) { + logger.debug("driver status is 'stopped', end processing") + } + batchStartTime = System.currentTimeMillis() } - batchStartTime = System.currentTimeMillis() - } - // we must add a nested row (not case class) to allow for nested queries - Row.fromSeq(row :+ Row.fromSeq(wrappedFullDatum.productIterator.toSeq)) - } + // we must add a nested row (not case class) to allow for nested queries + val fromDatumRow = Row.fromSeq(row :+ Row.fromSeq(wrappedFullDatum.productIterator.toSeq)) + resultSeq = resultSeq :+ fromDatumRow + } catch { + case e: Exception => + logger.error(s"Failed to completeRowFromDatum row.", e) + } finally { + idx += 1 + } } + resultSeq.iterator } protected def datumToJson(datum: Datum): DatumResult = { diff --git a/processor/src/test/resources/data_1.json b/processor/src/test/resources/data_1.json new file mode 100644 index 0000000..c261800 --- /dev/null +++ b/processor/src/test/resources/data_1.json @@ -0,0 +1,4 @@ +{"label":"ashikaga","name":"takauji","jubaql_timestamp":"2010-11-11T11:11:11"} 
+{"label":"ashikaga","name":"yoshiakira","jubaql_timestamp":"2010-11-11T11:11:12"} +{"label":"ashikaga","name":"yoshimitsu","jubaql_timestamp":"2010-11-11T11:11:13"} +{"label":"ashikaga","name":"yoshinori","jubaql_timestamp":"2010-11-11T11:11:14"} \ No newline at end of file diff --git a/processor/src/test/resources/data_2.json b/processor/src/test/resources/data_2.json new file mode 100644 index 0000000..b26bd0f --- /dev/null +++ b/processor/src/test/resources/data_2.json @@ -0,0 +1,2 @@ +{"label":"hojo","name":"tokimasa","jubaql_timestamp":"2016-11-11T11:11:11"} +{"label":"hojo","name":"munetoki","jubaql_timestamp":"2016-11-11T11:11:12"} \ No newline at end of file diff --git a/processor/src/test/resources/kafka.xml.dist b/processor/src/test/resources/kafka.xml.dist index 634f690..cc0dc7b 100644 --- a/processor/src/test/resources/kafka.xml.dist +++ b/processor/src/test/resources/kafka.xml.dist @@ -3,4 +3,6 @@ [kafka path] + +localhost:9092 diff --git a/processor/src/test/resources/shogun_1.json b/processor/src/test/resources/shogun_1.json new file mode 100644 index 0000000..08682a4 --- /dev/null +++ b/processor/src/test/resources/shogun_1.json @@ -0,0 +1 @@ +{"label":"徳川","name":"家康"} \ No newline at end of file diff --git a/processor/src/test/resources/shogun_2.json b/processor/src/test/resources/shogun_2.json new file mode 100644 index 0000000..a69e830 --- /dev/null +++ b/processor/src/test/resources/shogun_2.json @@ -0,0 +1,2 @@ +{"label":"徳川","name":"家康"} +{"label":"足利","name":"義満"} \ No newline at end of file diff --git a/processor/src/test/resources/test_data.json b/processor/src/test/resources/test_data.json new file mode 100644 index 0000000..f8fe339 --- /dev/null +++ b/processor/src/test/resources/test_data.json @@ -0,0 +1,3 @@ +{"label":"tokugawa","name":"ieyasu","jubaql_timestamp":"2015-11-11T11:11:11"} +{"label":"tokugawa","name":"hidetada","jubaql_timestamp":"2015-11-11T11:11:12"} 
+{"label":"tokugawa","name":"iemitsu","jubaql_timestamp":"2015-11-11T11:11:13"} \ No newline at end of file diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/HasKafkaPath.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/HasKafkaPath.scala index 2d222bf..80fbc4d 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/HasKafkaPath.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/HasKafkaPath.scala @@ -37,4 +37,21 @@ trait HasKafkaPath extends ShouldMatchers { properties.getProperty("path") } + + lazy val kafkaServerAddress: String = { + val kafkaXmlPath = "src/test/resources/kafka.xml" + + val is = try { + Some(new FileInputStream(kafkaXmlPath)) + } catch { + case _: FileNotFoundException => + None + } + is shouldBe a[Some[_]] + + val properties = new Properties() + properties.loadFromXML(is.get) + + properties.getProperty("server_address") + } } diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/HybridProcessorSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/HybridProcessorSpec.scala index 8064621..cda5a81 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/HybridProcessorSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/HybridProcessorSpec.scala @@ -15,14 +15,17 @@ // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA package us.jubat.jubaql_server.processor -import java.io.{FileNotFoundException, FileInputStream} import java.util.Properties - import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.{SparkContext, SparkException} import org.scalatest._ +import scala.collection.mutable.LinkedHashMap +import kafka.producer.ProducerConfig +import kafka.producer.Producer +import kafka.producer.KeyedMessage + class HybridProcessorSpec extends FlatSpec @@ -100,6 +103,39 
@@ class HybridProcessorSpec } } + "getStatus()" should "return datasource status for local files" taggedAs (LocalTest) in { + val inPath = "file://src/test/resources/dummydata" + val processor = new HybridProcessor(sc, sqlc, inPath, Nil) + var status = processor.getStatus() + + // Initialized + status.get("state") shouldBe Some("Initialized") + + status.get("storage") match { + case Some(storage) => + val storageMap: LinkedHashMap[String, Any] = storage.asInstanceOf[LinkedHashMap[String, Any]] + storageMap.get("path") shouldBe Some(inPath) + case _ => fail() + } + + status.get("stream") match { + case Some(stream) => + val streamMap: LinkedHashMap[String, Any] = stream.asInstanceOf[LinkedHashMap[String, Any]] + streamMap.get("path") shouldBe Some(List()) + case _ => fail() + } + + // Running + processor.startJValueProcessing(rdd => rdd.count) + status = processor.getStatus() + status.get("state") shouldBe Some("Running") + + // Finished + processor.awaitTermination() + status = processor.getStatus() + status.get("state") shouldBe Some("Finished") + } + override def afterAll = { sc.stop() } @@ -187,6 +223,39 @@ class HDFSStreamSpec streamInfo.maxId shouldBe empty } + "getStatus()" should "return datasource status for hdfs files" taggedAs (HDFSTest) in { + val inPath = "hdfs:///user/empty" + val processor = new HybridProcessor(sc, sqlc, inPath, Nil) + var status = processor.getStatus() + + // Initialized + status.get("state") shouldBe Some("Initialized") + + status.get("storage") match { + case Some(storage) => + val storageMap: LinkedHashMap[String, Any] = storage.asInstanceOf[LinkedHashMap[String, Any]] + storageMap.get("path") shouldBe Some(inPath) + case _ => fail() + } + + status.get("stream") match { + case Some(stream) => + val streamMap: LinkedHashMap[String, Any] = stream.asInstanceOf[LinkedHashMap[String, Any]] + streamMap.get("path") shouldBe Some(List()) + case _ => fail() + } + + // Running + processor.startJValueProcessing(rdd => rdd.count) + status = 
processor.getStatus() + status.get("state") shouldBe Some("Running") + + // Finished + processor.awaitTermination() + status = processor.getStatus() + status.get("state") shouldBe Some("Finished") + } + override def afterAll = { sc.stop() } @@ -306,6 +375,42 @@ class KafkaStreamSpec streamInfo.maxId shouldBe empty } + "getStatus()" should "return datasource status for stream" taggedAs (KafkaTest) in { + val inPath = s"kafka://$kafkaPath/dummy/1" + val processor = new HybridProcessor(sc, sqlc, "empty", inPath :: Nil) + var status = processor.getStatus() + + // Initialized + status.get("state") shouldBe Some("Initialized") + + status.get("storage") match { + case Some(storage) => + val storageMap: LinkedHashMap[String, Any] = storage.asInstanceOf[LinkedHashMap[String, Any]] + storageMap.get("path") shouldBe Some("empty") + case _ => fail() + } + + status.get("stoream") match { + case Some(stream) => + val streamMap: LinkedHashMap[String, Any] = stream.asInstanceOf[LinkedHashMap[String, Any]] + streamMap.get("path") shouldBe Some(List(inPath)) + case _ => fail() + } + + // Running + val stopFun = processor.startJValueProcessing(rdd => rdd.count)._1 + status = processor.getStatus() + status.get("state") shouldBe Some("Running") + + Thread.sleep(1700) // if we stop during the first batch, something goes wrong + val (staticInfo, streamInfo) = stopFun() + + // Finished + processor.awaitTermination() + status = processor.getStatus() + status.get("state") shouldBe Some("Finished") + } + override def afterAll = { sc.stop() } @@ -316,7 +421,7 @@ class HDFSKafkaStreamSpec with ShouldMatchers with HasKafkaPath with BeforeAndAfterAll { - val sc = new SparkContext("local[3]", "KafkaStreamSpec") + val sc = new SparkContext("local[3]", "HDFSKafkaStreamSpec") val sqlc = new SQLContext(sc) "HDFS+Kafka processing" should "change processing smoothly" taggedAs (HDFSTest, KafkaTest) in { @@ -385,6 +490,157 @@ class HDFSKafkaStreamSpec } } +class FileKafkaStreamSpec + extends FlatSpec + 
with ShouldMatchers + with HasKafkaPath + with BeforeAndAfterAll { + val sc = new SparkContext("local[3]", "FileKafkaStreamSpec") + val sqlc = new SQLContext(sc) + + // テスト実行時の事前準備 + // ・kafka serverを起動する + // ・dummy1,dummy_1,dummy_2のtopicが存在しない場合作成する + "ProcessingStatus(phase/starTime/count/timeStamp)" should "check ProcessingStatus value" taggedAs (HDFSTest, KafkaTest) in { + val filePath = "file://src/test/resources/test_data.json" + val kafkaURI = s"kafka://$kafkaPath/dummy1/1" + val processor = new HybridProcessor(sc, sqlc, filePath, kafkaURI :: Nil) + + val testStartTime = System.currentTimeMillis() + + processor.getStatus.get("process_phase").get shouldBe "Stop" + var storageMap = processor.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMap.get("storage_start").get shouldBe 0L + storageMap.get("storage_count").get shouldBe 0L + var streamMap = processor.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMap.get("stream_start").get shouldBe 0L + streamMap.get("stream_count").get shouldBe 0L + processor.getStatus.get("process_timestamp").get shouldBe "" + + val stopFun = processor.startJValueProcessing(rdd => rdd.count)._1 + + sendKafkaMessage(s"$kafkaServerAddress", "dummy1", Array("""{"label":"tokugawa", "name":"test1", "jubaql_timestamp": "2016-11-11T11:11:11"}""", """{"label":"tokugawa", "name":"test2", "jubaql_timestamp": "2014-11-11T11:11:15"}""","""{"label":"tokugawa", "name":"test3", "jubaql_timestamp": "2016-11-11T11:11:10"}""")) + + processor.getStatus.get("process_phase").get shouldBe "Storage" + + while (processor.phase == StoragePhase) { + Thread.sleep(1000) + } + processor.getStatus.get("process_phase").get shouldBe "Stream" + storageMap = processor.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMap.get("storage_start").get.asInstanceOf[Long] should be > testStartTime + storageMap.get("storage_count").get shouldBe 3L + streamMap = 
processor.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMap.get("stream_start").get.asInstanceOf[Long] should be > storageMap.get("storage_start").get.asInstanceOf[Long] + streamMap.get("stream_count").get shouldBe 0L + processor.getStatus.get("process_timestamp").get shouldBe "2015-11-11T11:11:13" + + Thread.sleep(5000) + val (staticInfo, streamInfo) = stopFun() + + processor.getStatus.get("process_phase").get shouldBe "Stop" + storageMap = processor.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMap.get("storage_start").get.asInstanceOf[Long] should be > testStartTime + storageMap.get("storage_count").get shouldBe 3L + streamMap = processor.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMap.get("stream_start").get.asInstanceOf[Long] should be > storageMap.get("storage_start").get.asInstanceOf[Long] + streamMap.get("stream_count").get shouldBe 2L + processor.getStatus.get("process_timestamp").get shouldBe "2016-11-11T11:11:11" + } + + it should "Interference confirmation of ProcessingStatus values" taggedAs (HDFSTest, KafkaTest) in { + val filePathA = "file://src/test/resources/data_1.json" + val kafkaURIA = s"kafka://$kafkaPath/dummy_1/1" + val processorA = new HybridProcessor(sc, sqlc, filePathA, kafkaURIA :: Nil) + val filePathB = "file://src/test/resources/data_2.json" + val kafkaURIB = s"kafka://$kafkaPath/dummy_2/1" + val processorB = new HybridProcessor(sc, sqlc, filePathB, kafkaURIB :: Nil) + + sendKafkaMessage(s"$kafkaServerAddress", "dummy_1", Array("""{"label":"tokugawa", "name":"test1", "jubaql_timestamp": "2015-11-11T11:11:11"}""", """{"label":"tokugawa", "name":"test2", "jubaql_timestamp": "2015-11-11T11:11:12"}""")) + Thread.sleep(3000) + sendKafkaMessage(s"$kafkaServerAddress", "dummy_2", Array("""{"label":"tokugawa", "name":"test1", "jubaql_timestamp": "2016-11-11T11:11:11"}""", """{"label":"tokugawa", "name":"test1", "jubaql_timestamp": 
"2016-11-11T11:11:13"}""")) + + val stopFunA = processorA.startJValueProcessing(rdd => rdd.count)._1 + while (processorA.phase == StoragePhase) { + Thread.sleep(1000) + } + Thread.sleep(10000) + processorA.getStatus.get("process_phase").get shouldBe "Stream" + var storageMapA = processorA.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMapA.get("storage_start").get.asInstanceOf[Long] should be > 0L + storageMapA.get("storage_count").get shouldBe 4L + var streamMapA = processorA.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMapA.get("stream_start").get.asInstanceOf[Long] should be > 0L + streamMapA.get("stream_count").get.asInstanceOf[Long] shouldBe 2L + processorA.getStatus.get("process_timestamp").get shouldBe "2015-11-11T11:11:12" + + processorB.getStatus.get("process_phase").get shouldBe "Stop" + var storageMapB = processorB.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMapB.get("storage_start").get.asInstanceOf[Long] shouldBe 0L + storageMapB.get("storage_count").get shouldBe 0L + var streamMapB = processorB.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMapB.get("stream_start").get.asInstanceOf[Long] shouldBe 0L + streamMapB.get("stream_count").get.asInstanceOf[Long] shouldBe 0L + processorB.getStatus.get("process_timestamp").get shouldBe "" + + val (staticInfoA, streamInfoA) = stopFunA() + + val stopFunB = processorB.startJValueProcessing(rdd => rdd.count)._1 + while (processorB.phase == StoragePhase) { + Thread.sleep(1000) + } + Thread.sleep(10000) + processorA.getStatus.get("process_phase").get shouldBe "Stop" + storageMapA = processorA.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMapA.get("storage_start").get.asInstanceOf[Long] should be > 0L + storageMapA.get("storage_count").get shouldBe 4L + streamMapA = processorA.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + 
streamMapA.get("stream_start").get.asInstanceOf[Long] should be > 0L + streamMapA.get("stream_count").get.asInstanceOf[Long] shouldBe 2L + processorA.getStatus.get("process_timestamp").get shouldBe "2015-11-11T11:11:12" + + processorB.getStatus.get("process_phase").get shouldBe "Stream" + storageMapB = processorB.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMapB.get("storage_start").get.asInstanceOf[Long] should be > 0L + storageMapB.get("storage_count").get shouldBe 2L + streamMapB = processorB.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMapB.get("stream_start").get.asInstanceOf[Long] should be > 0L + streamMapB.get("stream_count").get.asInstanceOf[Long] shouldBe 1L + processorB.getStatus.get("process_timestamp").get shouldBe "2016-11-11T11:11:13" + + val (staticInfoB, streamInfoB) = stopFunB() + + processorA.getStatus.get("process_phase").get shouldBe "Stop" + processorB.getStatus.get("process_phase").get shouldBe "Stop" + + // + storageMapA = processorA.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + storageMapB = processorB.getStatus.get("storage").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMapA = processorA.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + streamMapB = processorB.getStatus.get("stream").get.asInstanceOf[LinkedHashMap[String,Any]] + + storageMapA.get("storage_start").get.asInstanceOf[Long] should not be storageMapB.get("storage_start").get.asInstanceOf[Long] + streamMapA.get("stream_start").get.asInstanceOf[Long] should not be streamMapB.get("stream_start").get.asInstanceOf[Long] + } + + override def afterAll = { + sc.stop() + } + + def sendKafkaMessage(address:String, topic: String, message: Array[String]):Unit = { + val prop = new Properties() + prop.put("metadata.broker.list", address) + prop.put("serializer.class","kafka.serializer.StringEncoder") + val producerConfig = new ProducerConfig(prop) + val producer = new 
Producer[String, String](producerConfig) + message.foreach({ line => + val message = new KeyedMessage[String, String](topic, line) + producer.send(message) + }) + producer.close() + } +} + class SQLSpec extends FeatureSpec with GivenWhenThen diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/JavaScriptSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/JavaScriptSpec.scala index b61d98e..92d709d 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/JavaScriptSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/JavaScriptSpec.scala @@ -25,6 +25,9 @@ import unfiltered.util.RunnableServer import scala.collection.JavaConversions._ import scala.util.Success +import org.scalatest.exceptions.TestFailedException +import javax.script.ScriptException +import scala.util.Failure class JavaScriptSpec extends FlatSpec with ShouldMatchers with MockServer { protected var wiser: Wiser = null @@ -105,10 +108,12 @@ class JavaScriptSpec extends FlatSpec with ShouldMatchers with MockServer { | return result.get(); """.stripMargin val funcBody = funcBodyTmpl.format(body) + val cores = Runtime.getRuntime().availableProcessors(); // up to 8 requests are processed in parallel, the 9th is // executed later (seems like 8 is the thread pool limit for // either dispatch or unfiltered) - val loop = (1 to 8).toList.par + // modify: get number of cores + val loop = (1 to cores).toList.par val startTime = System.currentTimeMillis() val resultOpts = loop.map(_ => { JavaScriptUDFManager.registerAndTryCall[String]("test", 0, funcBody) @@ -207,6 +212,89 @@ class JavaScriptSpec extends FlatSpec with ShouldMatchers with MockServer { mime.getContent.toString should include("よろしく") } + it should "registerAndCall: allow simple functions" taggedAs (LocalTest) in { + val body = "return x;" + val funcBody = funcBodyTmpl.format(body) + val result = JavaScriptUDFManager.registerAndCall[Double]("test", + 1, funcBody, Double.box(17.0)) 
+ result shouldBe 17.0 + } + + it should "registerAndTryCall: allow simple functions" taggedAs (LocalTest) in { + val body = "return x;" + val funcBody = funcBodyTmpl.format(body) + val resultTry = JavaScriptUDFManager.registerAndTryCall[Double]("test", + 1, funcBody, Double.box(17.0)) + resultTry shouldBe a[Success[_]] + resultTry.foreach(result => { + result shouldBe 17.0 + }) + } + + it should "registerAndOptionCall: allow simple functions" taggedAs (LocalTest) in { + val body = "return x;" + val funcBody = funcBodyTmpl.format(body) + val resultOpt = JavaScriptUDFManager.registerAndOptionCall[Double]("test", + 1, funcBody, Double.box(17.0)) + resultOpt shouldBe a[Some[_]] + resultOpt.foreach(result => { + result shouldBe 17.0 + }) + } + + "JavaScript throws Exception" should "registerAndCall(args = 1) throw Exception" taggedAs (LocalTest) in { + val body = "throw new Error('error Message');" + val funcBody = funcBodyTmpl.format(body) + try { + val result = JavaScriptUDFManager.registerAndCall[Double]("test", + 1, funcBody, Double.box(17.0)) + fail() + } catch { + case e: TestFailedException => + e.printStackTrace() + fail() + case e: Exception => + // invoke methodの出力メッセージ確認 + e.getMessage should startWith("Failed to invoke function. functionName: test, args: WrappedArray(17.0)") + } + } + + it should "registerAndCall(args = 0) throw Exception" taggedAs (LocalTest) in { + val body = "throw new Error('error Message');" + val funcBody = funcBodyTmpl.format(body) + try { + val result = JavaScriptUDFManager.registerAndCall[Double]("test", + 0, funcBody) + fail() + } catch { + case e: TestFailedException => + e.printStackTrace() + fail() + case e: Exception => + // invoke methodの出力メッセージ確認(パラメータなし) + e.getMessage should startWith("Failed to invoke function. 
functionName: test, args: WrappedArray()") + } + } + + it should "registerAndTryCall return Failure" taggedAs (LocalTest) in { + val body = "throw new Error('error Message');" + val funcBody = funcBodyTmpl.format(body) + val resultTry = JavaScriptUDFManager.registerAndTryCall[Double]("test", + 1, funcBody, Double.box(17.0)) + resultTry shouldBe a[Failure[_]] + resultTry.foreach(result => { + result shouldBe "Failed to invoke function. functionName: test, args: WrappedArray(17.0)" + }) + } + + it should "registerAndOptionCall return None" taggedAs (LocalTest) in { + val body = "throw new Error('error Message');" + val funcBody = funcBodyTmpl.format(body) + val resultTry = JavaScriptUDFManager.registerAndOptionCall[Double]("test", + 1, funcBody, Double.box(17.0)) + resultTry shouldBe None + } + // this server mocks the gateway protected val server: RunnableServer = { unfiltered.netty.Server.http(12345).plan( diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLParserSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLParserSpec.scala index e08e70b..8ec740d 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLParserSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLParserSpec.scala @@ -103,7 +103,7 @@ class JubaQLParserSpec extends FlatSpec { // use single quotation val result: Option[JubaQLAST] = parser.parse( """ - CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH id CONFIG '{"test": 123}' + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' """.stripMargin ) @@ -112,7 +112,7 @@ class JubaQLParserSpec extends FlatSpec { create.algorithm shouldBe "CLASSIFIER" create.modelName shouldBe "test1" create.labelOrId shouldBe Some(("label", "l")) - create.featureExtraction shouldBe List((WildcardAnyParameter, "id")) + create.featureExtraction shouldBe List((WildcardAnyParameter, "fex")) create.configJson shouldBe """{"test": 123}""" 
//create.specifier shouldBe List(("id", List("id")), ("datum", List("a", "b"))) } @@ -122,7 +122,7 @@ class JubaQLParserSpec extends FlatSpec { // use single quotation val result: Option[JubaQLAST] = parser.parse( """ - |CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH id CONFIG '{"test": + |CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": |123}' """.stripMargin ) @@ -132,7 +132,7 @@ class JubaQLParserSpec extends FlatSpec { create.algorithm shouldBe "CLASSIFIER" create.modelName shouldBe "test1" create.labelOrId shouldBe Some(("label", "l")) - create.featureExtraction shouldBe List((WildcardAnyParameter, "id")) + create.featureExtraction shouldBe List((WildcardAnyParameter, "fex")) create.configJson shouldBe "{\"test\":\n123}" //create.specifier shouldBe List(("id", List("id")), ("datum", List("a", "b"))) } @@ -175,6 +175,141 @@ class JubaQLParserSpec extends FlatSpec { //create.specifier shouldBe List(("id", List("id")), ("datum", List("a", "b"))) } + it should "recognize CREATE MODEL without recource config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.resConfigJson shouldBe None + } + + it should "recognize CREATE MODEL for recource config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' RESOURCE CONFIG '{"applicationmaster_memory": 256}' + """.stripMargin) + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.resConfigJson shouldBe Some("""{"applicationmaster_memory": 256}""") + } + + it should "not recognize CREATE MODEL for recource config without value" 
taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' RESOURCE CONFIG + """.stripMargin) + result shouldBe (None) + } + + it should "not recognize CREATE MODEL for recource config without 'CONFIG'" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' RESOURCE '{"applicationmaster_memory": 256}' + """.stripMargin) + result shouldBe (None) + } + + it should "recognize CREATE MODEL without server config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.serverConfigJson shouldBe None + } + + it should "recognize CREATE MODEL for server config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' SERVER CONFIG '{"thread": 3}' + """.stripMargin) + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.serverConfigJson shouldBe Some("""{"thread": 3}""") + } + + it should "not recognize CREATE MODEL for server config without value" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' SERVER CONFIG + """.stripMargin) + result shouldBe (None) + } + + it should "not recognize CREATE MODEL for server config without 'CONFIG'" 
taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' SERVER '{"thread": 3}' + """.stripMargin) + result shouldBe (None) + } + + it should "recognize CREATE MODEL without proxy config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.proxyConfigJson shouldBe None + } + + it should "recognize CREATE MODEL for proxy config" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + val result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' PROXY CONFIG '{"thread": 3}' + """.stripMargin) + result shouldNot be(None) + val create = result.get.asInstanceOf[CreateModel] + create.proxyConfigJson shouldBe Some("""{"thread": 3}""") + } + + it should "not recognize CREATE MODEL for proxy config without value" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' PROXY CONFIG + """.stripMargin) + result shouldBe (None) + } + + it should "not recognize CREATE MODEL for proxy config without 'CONFIG'" taggedAs (LocalTest) in { + val parser = new JubaQLParser + // use single quotation + var result: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH fex CONFIG '{"test": 123}' PROXY '{"thread": 3}' + """.stripMargin) + result shouldBe (None) + } + it should "recognize CREATE STREAM FROM SELECT" taggedAs (LocalTest) in { val parser = new 
JubaQLParser val result: Option[JubaQLAST] = parser.parse( @@ -224,7 +359,7 @@ class JubaQLParserSpec extends FlatSpec { | SLIDING WINDOW (SIZE 4 ADVANCE 3 TUPLES) | OVER source | WITH fourier(some_col) AS fourier_coeffs - | WHERE id % 2 = 0 + | WHERE fid % 2 = 0 | HAVING fourier_coeffs = 2 """.stripMargin) result shouldNot be(None) @@ -337,6 +472,68 @@ class JubaQLParserSpec extends FlatSpec { update.source shouldBe "source" } + it should "recognize UPDATE WITH for uppercase" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL juba_model USING train WITH '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val updateWith = result.get.asInstanceOf[UpdateWith] + updateWith.modelName shouldBe "juba_model" + updateWith.rpcName shouldBe "train" + updateWith.learningData shouldBe """{"test": 123}""" + } + + it should "recognize UPDATE WITH for lowercase" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL juba_model USING train with '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val updateWith = result.get.asInstanceOf[UpdateWith] + updateWith.modelName shouldBe "juba_model" + updateWith.rpcName shouldBe "train" + updateWith.learningData shouldBe """{"test": 123}""" + } + + it should "recognize UPDATE WITH for mixedcase" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL juba_model USING train With '{"test": 123}' + """.stripMargin) + + result shouldNot be(None) + val updateWith = result.get.asInstanceOf[UpdateWith] + updateWith.modelName shouldBe "juba_model" + updateWith.rpcName shouldBe "train" + updateWith.learningData shouldBe """{"test": 123}""" + } + + it should "not recognize UPDATE WITH without with" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + 
""" + UPDATE MODEL juba_model USING train '{"test": 123}' + """.stripMargin) + + result should be(None) + } + + it should "not recognize UPDATE WITH without learningData" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL juba_model USING train WITH + """.stripMargin) + + result should be(None) + } + // TODO write more ANALYZE tests it should "recognize ANALYZE" taggedAs (LocalTest) in { @@ -457,4 +654,184 @@ class JubaQLParserSpec extends FlatSpec { cf.lang shouldBe "JavaScript" cf.body shouldBe " var n = 1; return n; " } + + // TODO write more SAVE MODEL tests + + it should "recognize SAVE MODEL for Development Mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING "file:///home/data/models" AS test001 + """.stripMargin) + + result shouldNot be(None) + val sm = result.get.asInstanceOf[SaveModel] + sm.modelName shouldBe "test" + sm.modelPath shouldBe """file:///home/data/models""" + sm.modelId shouldBe "test001" + } + + it should "recognize SAVE MODEL for Production Mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING "hdfs:///data/models" AS id + """.stripMargin) + + result shouldNot be(None) + val sm = result.get.asInstanceOf[SaveModel] + sm.modelName shouldBe "test" + sm.modelPath shouldBe """hdfs:///data/models""" + sm.modelId shouldBe "id" + } + + it should "not recognize SAVE MODEL without ModelName" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL USING "hdfs:///data/models" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize SAVE MODEL ModelName is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL "" USING 
"hdfs:///data/models" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize SAVE MODEL without ModelPath" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize SAVE MODEL ModelPath is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING "" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize SAVE MODEL without ModelId" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING "hdfs:///data/models" AS + """.stripMargin) + + result should be(None) + } + + it should "not recognize SAVE MODEL ModelId is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + SAVE MODEL test USING "hdfs:///data/models" AS "" + """.stripMargin) + + result should be(None) + } + + // TODO write more LOAD MODEL tests + + it should "recognize LOAD MODEL Development Mode/file:scheme" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING "file:///home/data/models" AS test001 + """.stripMargin) + + result shouldNot be(None) + val sm = result.get.asInstanceOf[LoadModel] + sm.modelName shouldBe "test" + sm.modelPath shouldBe """file:///home/data/models""" + sm.modelId shouldBe "test001" + } + + it should "recognize LOAD MODEL Production Mode/hdfs:scheme" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING "hdfs:///data/models" AS id + """.stripMargin) + + result shouldNot be(None) + val sm = result.get.asInstanceOf[LoadModel] + sm.modelName shouldBe "test" + 
sm.modelPath shouldBe """hdfs:///data/models""" + sm.modelId shouldBe "id" + } + + it should "not recognize LOAD MODEL without ModelName" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL USING "hdfs:///data/models" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize LOAD MODEL ModelName is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL "" USING "hdfs:///data/models" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize LOAD MODEL without ModelPath" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize LOAD MODEL ModelPath is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING "" AS test001 + """.stripMargin) + + result should be(None) + } + + it should "not recognize LOAD MODEL without ModelId" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING "hdfs:///data/models" AS + """.stripMargin) + + result should be(None) + } + + it should "not recognize LOAD MODEL ModelId is Empty" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val result: Option[JubaQLAST] = parser.parse( + """ + LOAD MODEL test USING "hdfs:///data/models" AS "" + """.stripMargin) + + result should be(None) + } } diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLServiceHelperSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLServiceHelperSpec.scala index a59309e..b67d6cf 100644 --- 
a/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLServiceHelperSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/JubaQLServiceHelperSpec.scala @@ -17,7 +17,19 @@ package us.jubat.jubaql_server.processor import org.scalatest.{ShouldMatchers, BeforeAndAfterAll, FlatSpec} import org.scalatest.EitherValues._ +import org.scalatest.PrivateMethodTester._ import org.apache.spark.SparkContext +import org.apache.commons.io.FileExistsException +import us.jubat.jubaql_server.processor.json.{JubaQLResponse, StatementProcessed, StatusResponse} +import us.jubat.yarn.common.{LearningMachineType, ServerConfig, ProxyConfig, Mixer} +import us.jubat.yarn.client.{JubatusYarnApplication, Resource, JubatusYarnApplicationStatus} +import scala.collection.mutable.LinkedHashMap +import scala.collection.Map +import scala.concurrent._ +import scala.concurrent.duration.Duration + +import scala.util.{Try, Success} + /* This test case tests only the state-independent (helper) functions of * JubaQLService (such as `parseJson()` or `extractDatum()`). It does @@ -26,14 +38,71 @@ import org.apache.spark.SparkContext * (The reason being that we need to kill the JVM that is running * the JubaQLProcessor to exit cleanly.) 
*/ +object JubaQLServiceHelperSpec { + val throwExceptionName = "throwExceptionName" +} class JubaQLServiceHelperSpec extends FlatSpec with ShouldMatchers with BeforeAndAfterAll { private var sc: SparkContext = null private var service: JubaQLServiceTester = null + private var proService: JubaQLServiceProductionTester = null // create a subclass to test the protected methods class JubaQLServiceTester(sc: SparkContext) extends JubaQLService(sc, RunMode.Development, "file:///tmp/spark") { override def parseJson(in: String): Either[(Int, String), JubaQLAST] = super.parseJson(in) + + override def complementResource(resourceJsonString: Option[String]): Either[(Int, String), Resource] = + super.complementResource(resourceJsonString) + + override def complementServerConfig(serverJsonString: Option[String]): Either[(Int, String), ServerConfig] = + super.complementServerConfig(serverJsonString) + + override def complementProxyConfig(proxyJsonString: Option[String]): Either[(Int, String), ProxyConfig] = + super.complementProxyConfig(proxyJsonString) + + override def takeAction(ast: JubaQLAST): Either[(Int, String), JubaQLResponse] = + super.takeAction(ast) + + override def getSourcesStatus(): Map[String, Any] = + super.getSourcesStatus() + + override def getModelsStatus(): Map[String, Any] = + super.getModelsStatus() + + override def getProcessorStatus(): Map[String, Any] = + super.getProcessorStatus() + + override def queryUpdateWith(updateWith: UpdateWith): Either[(Int, String), String] = + super.queryUpdateWith(updateWith) + } + + // create a subclass to test the protected methods for Production Mode + class JubaQLServiceProductionTester(sc: SparkContext, runMode: RunMode) extends JubaQLService(sc, runMode, "file:///tmp/spark") { + override def takeAction(ast: JubaQLAST): Either[(Int, String), JubaQLResponse] = + super.takeAction(ast) + } + + // create a subclass to test the protected methods + class LocalJubatusApplicationTester(name: String) extends 
LocalJubatusApplication(null, name, LearningMachineType.Classifier, "jubaclassifier") { + override def saveModel(aModelPathPrefix: org.apache.hadoop.fs.Path, aModelId: String): Try[JubatusYarnApplication] = Try { + name match { + case JubaQLServiceHelperSpec.throwExceptionName => + throw new FileExistsException("exception for test") + + case _ => + this + } + } + + override def loadModel(aModelPathPrefix: org.apache.hadoop.fs.Path, aModelId: String): Try[JubatusYarnApplication] = Try { + name match { + case JubaQLServiceHelperSpec.throwExceptionName => + throw new FileExistsException("exception for test") + + case _ => + this + } + } } "parseJson()" should "be able to parse JSON" taggedAs (LocalTest) in { @@ -80,9 +149,2525 @@ class JubaQLServiceHelperSpec extends FlatSpec with ShouldMatchers with BeforeAn result.left.value._1 shouldBe 400 } - override protected def beforeAll(): Unit = { - sc = new SparkContext("local[3]", "JubaQL Processor Test") - service = new JubaQLServiceTester(sc) + // SaveModel test + "takeAction():SaveModel" should "return an success for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + val sp = result.right.value.asInstanceOf[StatementProcessed] + sp.result shouldBe "SAVE MODEL" + } + + it should "return error of non model for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "file:///home/data/models", "test001") + val juba = new LocalJubatusApplicationTester("test") + + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + + 
result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error invalid model path2 for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "file:/tmp/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of saveModel method for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester(JubaQLServiceHelperSpec.throwExceptionName) + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 500 + } + + it should "return success for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, 
null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + val sp = result.right.value.asInstanceOf[StatementProcessed] + sp.result shouldBe "SAVE MODEL" + } + + it should "return error of non model for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "hdfs:///home/data/models", "test001") + val juba = new LocalJubatusApplicationTester("test") + + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + + result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path2 for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "hdfs:/home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of saveModel method for Production mode" taggedAs (LocalTest) in { + val parser = new 
JubaQLParser + val ast: JubaQLAST = new SaveModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester(JubaQLServiceHelperSpec.throwExceptionName) + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 500 + } + + // LoadModel test + "takeAction():LoadModel" should "return an success for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + val sp = result.right.value.asInstanceOf[StatementProcessed] + sp.result shouldBe "LOAD MODEL" + } + + it should "return error of non model for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "file:///home/data/models", "test001") + val juba = new LocalJubatusApplicationTester("test") + + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + + result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = 
service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error invalid model path2 for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "file:/tmp/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of loadModel method for Development mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester(JubaQLServiceHelperSpec.throwExceptionName) + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast) + service.models.remove("test") + + result.left.value._1 shouldBe 500 + } + + it should "return success for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + val sp = result.right.value.asInstanceOf[StatementProcessed] + sp.result shouldBe "LOAD MODEL" + } + + it should "return error of non model for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST 
= new LoadModel("test", "hdfs:///home/data/models", "test001") + val juba = new LocalJubatusApplicationTester("test") + + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + + result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "file:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of invalid model path2 for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "hdfs:/home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester("test") + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "return error of loadModel method for Production mode" taggedAs (LocalTest) in { + val parser = new JubaQLParser + val ast: JubaQLAST = new LoadModel("test", "hdfs:///home/data/models", "test001") + val cm = new CreateModel("CLASSIFIER", "test", None, null, "") + val juba = new LocalJubatusApplicationTester(JubaQLServiceHelperSpec.throwExceptionName) + + proService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val result: Either[(Int, String), JubaQLResponse] = proService.takeAction(ast) + proService.models.remove("test") + + result.left.value._1 shouldBe 500 + } + + 
"complementResource()" should "success recource config" taggedAs (LocalTest) in { + // no resource config specified + var result = service.complementResource(None) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe 128 + value.proxyMemory shouldBe 32 + value.masterCores shouldBe 1 + value.priority shouldBe 0 + value.containerMemory shouldBe 128 + value.memory shouldBe 256 + value.virtualCores shouldBe 1 + value.containerNodes shouldBe null + value.containerRacks shouldBe null + + case _ => + fail() + } + // required keys absent + var resConfig = """{"test1": 256, "test2": 128}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe 128 + value.proxyMemory shouldBe 32 + value.masterCores shouldBe 1 + value.priority shouldBe 0 + value.containerMemory shouldBe 128 + value.memory shouldBe 256 + value.virtualCores shouldBe 1 + value.containerNodes shouldBe null + value.containerRacks shouldBe null + + case _ => + fail() + } + // required keys absent + resConfig = """{}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe 128 + value.proxyMemory shouldBe 32 + value.masterCores shouldBe 1 + value.priority shouldBe 0 + value.containerMemory shouldBe 128 + value.memory shouldBe 256 + value.virtualCores shouldBe 1 + value.containerNodes shouldBe null + value.containerRacks shouldBe null + + case _ => + fail() + } + } + + it should "error recource config for invalid format" taggedAs (LocalTest) in { + var resConfig = """""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + resConfig = """{"applicationmaster_memory", 
"jubatus_proxy_memory"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + resConfig = """{applicationmaster_memory: 256, jubatus_proxy_memory: 128}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + resConfig = """{"applicationmaster_memory":256, 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for applicationmaster_memory" taggedAs (LocalTest) in { + var resConfig = """{"applicationmaster_memory": 256}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe 256 + + case _ => + fail() + } + + // minimum value + resConfig = """{"applicationmaster_memory": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe 1 + + case _ => + fail() + } + + // maximum value + resConfig = s"""{"applicationmaster_memory": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterMemory shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for applicationmaster_memory" taggedAs (LocalTest) in { + // out of range + var resConfig = """{"applicationmaster_memory": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) 
+ result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"applicationmaster_memory": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"applicationmaster_memory": "256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for jubatus_proxy_memory" taggedAs (LocalTest) in { + var resConfig = """{"jubatus_proxy_memory": 16}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.proxyMemory shouldBe 16 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"jubatus_proxy_memory": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.proxyMemory shouldBe 1 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"jubatus_proxy_memory": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.proxyMemory shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for jubatus_proxy_memory" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"jubatus_proxy_memory": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode 
shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"jubatus_proxy_memory": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"jubatus_proxy_memory": "256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for applicationmaster_cores" taggedAs (LocalTest) in { + var resConfig = """{"applicationmaster_cores": 2}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterCores shouldBe 2 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"applicationmaster_cores": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterCores shouldBe 1 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"applicationmaster_cores": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.masterCores shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for applicationmaster_cores" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"applicationmaster_cores": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"applicationmaster_cores": ${Int.MaxValue + 
1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"applicationmaster_cores": "256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for container_priority" taggedAs (LocalTest) in { + var resConfig = """{"container_priority": 2}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.priority shouldBe 2 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"container_priority": 0}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.priority shouldBe 0 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"container_priority": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.priority shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for container_priority" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"container_priority": -1}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"container_priority": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case 
Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"container_priority": "1"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for container_memory" taggedAs (LocalTest) in { + var resConfig = """{"container_memory": 256}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.containerMemory shouldBe 256 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"container_memory": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.containerMemory shouldBe 1 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"container_memory": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.containerMemory shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for container_memory" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"container_memory": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"container_memory": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"container_memory": 
"256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for jubatus_server_memory" taggedAs (LocalTest) in { + var resConfig = """{"jubatus_server_memory": 512}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.memory shouldBe 512 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"jubatus_server_memory": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.memory shouldBe 1 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"jubatus_server_memory": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.memory shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for jubatus_server_memory" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"jubatus_server_memory": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"jubatus_server_memory": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"jubatus_server_memory": "256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() 
+ + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for container_cores" taggedAs (LocalTest) in { + var resConfig = """{"container_cores": 2}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.virtualCores shouldBe 2 + + case _ => + fail() + } + + // 最小値 + resConfig = """{"container_cores": 1}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.virtualCores shouldBe 1 + + case _ => + fail() + } + + // 最大値 + resConfig = s"""{"container_cores": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.virtualCores shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error recource config for container_cores" taggedAs (LocalTest) in { + // 範囲外 + var resConfig = """{"container_cores": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + resConfig = s"""{"container_cores": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + resConfig = """{"container_cores": "256"}""".stripMargin.trim + result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for container_nodes" taggedAs 
(LocalTest) in { + var resConfig = """{"container_nodes": ["1", "2"]}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.containerNodes shouldBe List("1", "2") + + case _ => + fail() + } + } + + it should "error recource config for container_nodes" taggedAs (LocalTest) in { + // 型違い + var resConfig = """{"container_nodes": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success recource config for container_racks" taggedAs (LocalTest) in { + var resConfig = """{"container_racks": ["1", "2"]}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + value.isInstanceOf[Resource] shouldBe true + value.containerRacks shouldBe List("1", "2") + + case _ => + fail() + } + } + + it should "error recource config for container_racks" taggedAs (LocalTest) in { + // 型違い + var resConfig = """{"container_racks": 0}""".stripMargin.trim + var result = service.complementResource(Option(resConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + "getSourceStatus()" should "success source status" taggedAs (LocalTest) in { + // データなし + var stsMap = service.getSourcesStatus() + stsMap.isEmpty shouldBe true + + // データ1件 + val processor = new HybridProcessor(sc, service.sqlc, "file://data/shogun_data.json", List("dummy"), + RunMode.Development, "") + service.sources.put("testData", (processor, None)) + stsMap = service.getSourcesStatus() + println(s"sourceStatus: $stsMap") + stsMap.size shouldBe 1 + stsMap.get("testData") match { + case Some(data) => checkSourceStatus(data.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => 
fail() + } + + // データ2件 + service.sources.putIfAbsent("testData2", (processor, None)) + stsMap = service.getSourcesStatus() + println(s"sourceStatus: $stsMap") + stsMap.size shouldBe 2 + stsMap.get("testData") match { + case Some(data) => checkSourceStatus(data.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => fail() + } + stsMap.get("testData2") match { + case Some(data) => checkSourceStatus(data.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => fail() + } + + service.sources.clear() + } + + "getModelsStatus()" should "success models status" taggedAs (LocalTest) in { + val testService = new JubaQLService(sc, RunMode.Test, "file:///tmp/spark") + val method = PrivateMethod[Map[String, Any]]('getModelsStatus) + + // モデルなし + var stsMap = testService invokePrivate method() + stsMap.isEmpty shouldBe true + + // モデル1件(resourceConfig/serverConfig/proxyConfigなし) + val cm = new CreateModel("CLASSIFIER", "test", None, null, "config", None, None, None) + val juba = new TestJubatusApplication("Test", LearningMachineType.Classifier) + testService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + stsMap = testService invokePrivate method() + println(s"modelStatus: $stsMap") + stsMap.size shouldBe 1 + stsMap.get("test") match { + case Some(model) => checkModelStatus(model.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => fail() + } + + // モデル2件(resourceConfig/serverConfig/proxyConfigあり) + val cm2 = new CreateModel("CLASSIFIER", "test2", None, null, "config", + Option("resourceConfig"), Option("serverConfig"), Option("proxyConfig")) + testService.models.put("test2", (juba, cm2, LearningMachineType.Classifier)) + stsMap = testService invokePrivate method() + println(s"modelStatus: $stsMap") + stsMap.size shouldBe 2 + stsMap.get("test") match { + case Some(model) => checkModelStatus(model.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => fail() + } + stsMap.get("test2") match { + case Some(model) => 
checkModelStatus(model.asInstanceOf[LinkedHashMap[String, Any]]) + case _ => fail() + } + } + + "getProcessorStatus()" should "success processor status" taggedAs (LocalTest) in { + val stsMap = service.getProcessorStatus() + stsMap.isEmpty shouldBe false + println(s"processorStatus: $stsMap") + stsMap.get("applicationId") should not be None + stsMap.get("startTime") should not be None + stsMap.get("currentTime") should not be None + stsMap.get("opratingTime") should not be None + stsMap.get("virtualMemory") should not be None + stsMap.get("usedMemory") should not be None + } + + private def checkSourceStatus(status: LinkedHashMap[String, Any]): Unit = { + status.get("state") should not be None + + status.get("storage") match { + case Some(storage) => + val storageMap = storage.asInstanceOf[LinkedHashMap[String, Any]] + storageMap.get("path") should not be None + case _ => fail() + } + + status.get("stream") match { + case Some(storage) => + val storageMap = storage.asInstanceOf[LinkedHashMap[String, Any]] + storageMap.get("path") should not be None + case _ => fail() + } + } + + private def checkModelStatus(status: LinkedHashMap[String, Any]): Unit = { + status.get("learningMachineType") should not be None + + status.get("config") match { + case Some(config) => + val configMap = config.asInstanceOf[LinkedHashMap[String, Any]] + configMap.get("jubatusConfig") should not be None + configMap.get("resourceConfig") should not be None + configMap.get("serverConfig") should not be None + configMap.get("proxyConfig") should not be None + case _ => fail() + } + + status.get("jubatusYarnApplicationStatus") match { + case Some(appStatus) => + val appMap = appStatus.asInstanceOf[LinkedHashMap[String, Any]] + appMap.get("jubatusProxy") should not be None + appMap.get("jubatusServers") should not be None + appMap.get("jubatusOnYarn") should not be None + case _ => fail() + } + } + + "takeAction():Status" should "return StatusResponse" taggedAs (LocalTest) in { + val testService 
= new JubaQLService(sc, RunMode.Test, "file:///tmp/spark") + val method = PrivateMethod[Either[(Int, String), JubaQLResponse]]('takeAction) + + val processor = new HybridProcessor(sc, testService.sqlc, "file://data/shogun_data.json", List("dummy"), + RunMode.Test, "") + testService.sources.put("testData", (processor, None)) + + val cm = new CreateModel("CLASSIFIER", "test", None, null, "config", None) + val juba = new TestJubatusApplication("Test", LearningMachineType.Classifier) + testService.models.put("test", (juba, cm, LearningMachineType.Classifier)) + + val ast: JubaQLAST = new Status() + val result = testService invokePrivate method(ast) + + val sp = result.right.value.asInstanceOf[StatusResponse] + println(s"result: $sp") + sp.result shouldBe "STATUS" + sp.sources.size shouldBe 1 + sp.models.size shouldBe 1 + sp.processor.isEmpty shouldBe false + } + + "complementServerConfig()" should "success server config" taggedAs (LocalTest) in { + // 指定なし + var result = service.complementServerConfig(None) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.thread shouldBe 2 + value.timeout shouldBe 10 + value.mixer shouldBe Mixer.Linear + value.intervalSec shouldBe 16 + value.intervalCount shouldBe 512 + value.zookeeperTimeout shouldBe 10 + value.interconnectTimeout shouldBe 10 + + case _ => + fail() + } + // 必要なキーなし + var serverConfig = """{}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.thread shouldBe 2 + value.timeout shouldBe 10 + value.mixer shouldBe Mixer.Linear + value.intervalSec shouldBe 16 + value.intervalCount shouldBe 512 + value.zookeeperTimeout shouldBe 10 + value.interconnectTimeout shouldBe 10 + + case _ => + fail() + } + } + + it should "error server config for invalid format" taggedAs (LocalTest) in { + var serverConfig = """""".stripMargin.trim + var result = 
service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + serverConfig = """{"thread", "timeout"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + serverConfig = """{thread: 3, timeout: 0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + serverConfig = """{"thread": 3, 0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + // 不正キー + serverConfig = """{"thread": 3, "test":0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for thread" taggedAs (LocalTest) in { + var serverConfig = """{"thread": 3}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.thread shouldBe 3 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"thread": 1}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.thread shouldBe 1 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"thread": ${Int.MaxValue}}""".stripMargin.trim + result = 
service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.thread shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error server config for thread" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"thread": 0}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + serverConfig = s"""{"thread": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"thread": "3"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for timeout" taggedAs (LocalTest) in { + var serverConfig = """{"timeout": 30}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.timeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"timeout": 0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.timeout shouldBe 0 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.timeout shouldBe Int.MaxValue + + 
case _ => + fail() + } + } + + it should "error server config for timeout" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"timeout": -1}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + serverConfig = s"""{"timeout": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"timeout": "30"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for mixer" taggedAs (LocalTest) in { + var serverConfig = """{"mixer": "linear_mixer"}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.mixer shouldBe Mixer.Linear + + case _ => + fail() + } + + serverConfig = """{"mixer": "random_mixer"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.mixer shouldBe Mixer.Random + + case _ => + fail() + } + + serverConfig = """{"mixer": "broadcast_mixer"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.mixer shouldBe Mixer.Broadcast + + case _ => + fail() + } + + serverConfig = """{"mixer": "skip_mixer"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + 
result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.mixer shouldBe Mixer.Skip + + case _ => + fail() + } + } + + it should "error server config for mixer" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = s"""{"mixer": "test"}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"mixer": random_mixer}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for interval_sec" taggedAs (LocalTest) in { + var serverConfig = """{"interval_sec": 10}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalSec shouldBe 10 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"interval_sec": 0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalSec shouldBe 0 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"interval_sec": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalSec shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error server config for interval_sec" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"interval_sec": -1}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + 
fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + serverConfig = s"""{"interval_sec": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"interval_sec": "10"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for interval_count" taggedAs (LocalTest) in { + var serverConfig = """{"interval_count": 1024}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalCount shouldBe 1024 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"interval_count": 0}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalCount shouldBe 0 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"interval_count": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.intervalCount shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error server config for interval_count" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"interval_count": -1}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 
+ serverConfig = s"""{"interval_count": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"interval_count": "1024"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for zookeeper_timeout" taggedAs (LocalTest) in { + var serverConfig = """{"zookeeper_timeout": 30}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.zookeeperTimeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"zookeeper_timeout": 1}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.zookeeperTimeout shouldBe 1 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"zookeeper_timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.zookeeperTimeout shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error server config for zookeeper_timeout" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"zookeeper_timeout": 0}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + serverConfig = s"""{"zookeeper_timeout": ${Int.MaxValue + 
1}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"zookeeper_timeout": "30"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success server config for interconnect_timeout" taggedAs (LocalTest) in { + var serverConfig = """{"interconnect_timeout": 30}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.interconnectTimeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + serverConfig = """{"interconnect_timeout": 1}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.interconnectTimeout shouldBe 1 + + case _ => + fail() + } + + // 最大値 + serverConfig = s"""{"interconnect_timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + value.isInstanceOf[ServerConfig] shouldBe true + value.interconnectTimeout shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error server config for interconnect_timeout" taggedAs (LocalTest) in { + // 範囲外 + var serverConfig = """{"interconnect_timeout": 0}""".stripMargin.trim + var result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + serverConfig = s"""{"interconnect_timeout": ${Int.MaxValue + 1}}""".stripMargin.trim + result = 
service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + serverConfig = """{"interconnect_timeout": "30"}""".stripMargin.trim + result = service.complementServerConfig(Option(serverConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + "complementProxyConfig()" should "success proxy config" taggedAs (LocalTest) in { + // 指定なし + var result = service.complementProxyConfig(None) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.thread shouldBe 4 + value.timeout shouldBe 10 + value.zookeeperTimeout shouldBe 10 + value.interconnectTimeout shouldBe 10 + value.poolExpire shouldBe 60 + value.poolSize shouldBe 0 + + case _ => + fail() + } + // 必要なキーなし + var proxyConfig = """{}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.thread shouldBe 4 + value.timeout shouldBe 10 + value.zookeeperTimeout shouldBe 10 + value.interconnectTimeout shouldBe 10 + value.poolExpire shouldBe 60 + value.poolSize shouldBe 0 + + case _ => + fail() + } + } + + it should "error proxy config for invalid format" taggedAs (LocalTest) in { + var proxyConfig = """""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + proxyConfig = """{"thread", "timeout"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + proxyConfig = """{thread: 3, timeout: 
0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + proxyConfig = """{"thread": 3, 0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + // 不正キー + proxyConfig = """{"thread": 3, "test":0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for thread" taggedAs (LocalTest) in { + var proxyConfig = """{"thread": 3}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.thread shouldBe 3 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"thread": 1}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.thread shouldBe 1 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"thread": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.thread shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for thread" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"thread": 0}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + 
proxyConfig = s"""{"thread": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = """{"thread": "3"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for timeout" taggedAs (LocalTest) in { + var proxyConfig = """{"timeout": 30}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.timeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"timeout": 0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.timeout shouldBe 0 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.timeout shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for timeout" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"timeout": -1}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + proxyConfig = s"""{"timeout": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case 
Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = """{"timeout": "30"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for zookeeper_timeout" taggedAs (LocalTest) in { + var proxyConfig = """{"zookeeper_timeout": 30}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.zookeeperTimeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"zookeeper_timeout": 1}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.zookeeperTimeout shouldBe 1 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"zookeeper_timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.zookeeperTimeout shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for zookeeper_timeout" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"zookeeper_timeout": 0}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + proxyConfig = s"""{"zookeeper_timeout": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = 
"""{"zookeeper_timeout": "30"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for interconnect_timeout" taggedAs (LocalTest) in { + var proxyConfig = """{"interconnect_timeout": 30}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.interconnectTimeout shouldBe 30 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"interconnect_timeout": 1}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.interconnectTimeout shouldBe 1 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"interconnect_timeout": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.interconnectTimeout shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for interconnect_timeout" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"interconnect_timeout": 0}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + proxyConfig = s"""{"interconnect_timeout": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = """{"interconnect_timeout": "30"}""".stripMargin.trim + result 
= service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for pool_expire" taggedAs (LocalTest) in { + var proxyConfig = """{"pool_expire": 30}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolExpire shouldBe 30 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"pool_expire": 0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolExpire shouldBe 0 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"pool_expire": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolExpire shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for pool_expire" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"pool_expire": -1}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + proxyConfig = s"""{"pool_expire": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = """{"pool_expire": "30"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + 
println(s"$errMsg") + errCode shouldBe 400 + } + } + + it should "success proxy config for pool_size" taggedAs (LocalTest) in { + var proxyConfig = """{"pool_size": 10}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolSize shouldBe 10 + + case _ => + fail() + } + + // 最小値 + proxyConfig = """{"pool_size": 0}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolSize shouldBe 0 + + case _ => + fail() + } + + // 最大値 + proxyConfig = s"""{"pool_size": ${Int.MaxValue}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + value.isInstanceOf[ProxyConfig] shouldBe true + value.poolSize shouldBe Int.MaxValue + + case _ => + fail() + } + } + + it should "error proxy config for pool_size" taggedAs (LocalTest) in { + // 範囲外 + var proxyConfig = """{"pool_size": -1}""".stripMargin.trim + var result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 範囲外 + proxyConfig = s"""{"pool_size": ${Int.MaxValue + 1}}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + + // 型違い + proxyConfig = """{"pool_size": "10"}""".stripMargin.trim + result = service.complementProxyConfig(Option(proxyConfig)) + result match { + case Right(value) => + fail() + + case Left((errCode, errMsg)) => + println(s"$errMsg") + errCode shouldBe 400 + } + } + + "takeAction():CreateModel" should "return an success for Development mode" taggedAs (LocalTest) in { + val parser = new 
JubaQLParser + + // プロキシ設定なし + var ast: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH unigram CONFIG '{"method": "AROW", "parameter": {"regularization_weight" : 1.0}}' + SERVER CONFIG '{"thread": 2}' + """.stripMargin) + var cm = ast.get.asInstanceOf[CreateModel] + var result: Either[(Int, String), JubaQLResponse] = service.takeAction(ast.get) + result match { + case Right(value) => + val sp = value.asInstanceOf[StatementProcessed] + sp.result shouldBe "CREATE MODEL (started) " + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + Await.ready(jubaFut, Duration.Inf) + jubaFut.value match { + case Some(Success(j)) => + Await.ready(j.stop(), Duration.Inf) + } + } + case _ => + fail() + } + + // プロキシ設定あり + ast = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: l) AS * WITH unigram CONFIG '{"method": "AROW", "parameter": {"regularization_weight" : 1.0}}' + SERVER CONFIG '{"thread": 2}' PROXY CONFIG '{"thread": 2}' + """.stripMargin) + cm = ast.get.asInstanceOf[CreateModel] + result = service.takeAction(ast.get) + result match { + case Right(value) => + val sp = value.asInstanceOf[StatementProcessed] + sp.result shouldBe "CREATE MODEL (started) (proxy setting has been ignored in Development mode)" + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + Await.ready(jubaFut, Duration.Inf) + jubaFut.value match { + case Some(Success(j)) => + Await.ready(j.stop(), Duration.Inf) + } + } + case _ => + fail() + } + } + + // UpdateWith test + "takeAction():UpdateWith" should "return an success for Anomaly" taggedAs (LocalTest) in { + val parser = new JubaQLParser + + val cmAst: Option[JubaQLAST] = parser.parse( + """ + CREATE ANOMALY MODEL test1 AS * CONFIG '{"method": "lof", "parameter": {"nearest_neighbor_num" : 10, + "reverse_nearest_neighbor_num": 30, "method": "euclid_lsh", "parameter": {"hash_num": 64, + "table_num": 4, "probe_num": 64, "bin_width": 
100, "seed": 1091, "retain_projection": false}}}' + """.stripMargin) + + val cmResult: Either[(Int, String), JubaQLResponse] = service.takeAction(cmAst.get) + cmResult match { + case Right(value) => + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + Await.ready(jubaFut, Duration.Inf) + } + case _ => + fail() + } + + val upAst: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL test1 USING add WITH '{"test1": 0, "test2": "aaaa", "test3": 1}' + """.stripMargin) + + val upResult: Either[(Int, String), JubaQLResponse] = service.takeAction(upAst.get) + upResult match { + case Right(value) => + val sp = value.asInstanceOf[StatementProcessed] + sp.result shouldBe "UPDATE MODEL (id_with_score{id: 0, score: Infinity})" + + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + jubaFut.value match { + case Some(Success(j)) => + Await.ready(j.stop(), Duration.Inf) + } + } + case _ => + fail() + } + } + + it should "return an success for Classifier" taggedAs (LocalTest) in { + val parser = new JubaQLParser + + val cmAst: Option[JubaQLAST] = parser.parse( + """ + CREATE CLASSIFIER MODEL test1 (label: label) AS name WITH unigram + CONFIG '{"method": "AROW", "parameter": {"regularization_weight" : 1.0}}' + """.stripMargin) + + val cmResult: Either[(Int, String), JubaQLResponse] = service.takeAction(cmAst.get) + cmResult match { + case Right(value) => + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + Await.ready(jubaFut, Duration.Inf) + } + case _ => + fail() + } + + val upAst: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL test1 USING train WITH '{"label": "label1", "name": "name1"}' + """.stripMargin) + + val upResult: Either[(Int, String), JubaQLResponse] = service.takeAction(upAst.get) + upResult match { + case Right(value) => + val sp = value.asInstanceOf[StatementProcessed] + sp.result shouldBe "UPDATE MODEL (1)" + + 
service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + jubaFut.value match { + case Some(Success(j)) => + Await.ready(j.stop(), Duration.Inf) + } + } + case _ => + fail() + } + } + + it should "return an success for Recommender" taggedAs (LocalTest) in { + val parser = new JubaQLParser + + val cmAst: Option[JubaQLAST] = parser.parse( + """ + CREATE RECOMMENDER MODEL test1 (id: pname) AS * CONFIG '{"method": "inverted_index", "parameter": {}}' + """.stripMargin) + + val cmResult: Either[(Int, String), JubaQLResponse] = service.takeAction(cmAst.get) + cmResult match { + case Right(value) => + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + Await.ready(jubaFut, Duration.Inf) + } + case _ => + fail() + } + + val upAst: Option[JubaQLAST] = parser.parse( + """ + UPDATE MODEL test1 USING update_row WITH '{"pname": "name1", "team": "aaa", "test": 1}' + """.stripMargin) + + val upResult: Either[(Int, String), JubaQLResponse] = service.takeAction(upAst.get) + upResult match { + case Right(value) => + val sp = value.asInstanceOf[StatementProcessed] + sp.result shouldBe "UPDATE MODEL (true)" + + service.startedJubatusInstances.get("test1") match { + case Some((jubaFut, _, _)) => + jubaFut.value match { + case Some(Success(j)) => + Await.ready(j.stop(), Duration.Inf) + } + } + case _ => + fail() + } + } + + "queryUpdateWith()" should "error without model" taggedAs (LocalTest) in { + val updateWith = new UpdateWith("test", "train", """{"label": "label1", "name": "namme1"}""") + val result = service.queryUpdateWith(updateWith) + result.left.value._1 shouldBe 400 + } + + it should "error model and method mismatch for Anomaly" taggedAs (LocalTest) in { + val cm = new CreateModel("ANOMALY", "test", None, List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Anomaly)) + val updateWith = new UpdateWith("test", "train", """{"label": "label1", 
"name": "namme1"}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error model and method mismatch for Classifier" taggedAs (LocalTest) in { + val cm = new CreateModel("CLASSIFIER", "test", None, List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val updateWith = new UpdateWith("test", "add", """{"label": "label1", "name": "namme1"}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error model and method mismatch for Recommender" taggedAs (LocalTest) in { + val cm = new CreateModel("RECOMMENDER", "test", None, List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Recommender)) + val updateWith = new UpdateWith("test", "add", """{"label": "label1", "name": "namme1"}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error no 'label' in CreateModel for Classifier" taggedAs (LocalTest) in { + val cm = new CreateModel("CLASSIFIER", "test", None, List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val updateWith = new UpdateWith("test", "train", """{"label": "label1", "name": "namme1"}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error no 'id' in CreateModel for Recommender" taggedAs (LocalTest) in { + val cm = new CreateModel("RECOMMENDER", "test", None, List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Recommender)) + val updateWith = new UpdateWith("test", 
"update_row", """{"pname": "name1", "team": "aaa", "test": 1}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error no 'label' in learningData for Classifier" taggedAs (LocalTest) in { + val cm = new CreateModel("CLASSIFIER", "test", Some(("label", "label")), List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Classifier)) + val updateWith = new UpdateWith("test", "train", """{"name": "name1"}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + it should "error no 'id' in learningData for Recommender" taggedAs (LocalTest) in { + val cm = new CreateModel("RECOMMENDER", "test", Some(("id", "pname")), List(), "") + val juba = new LocalJubatusApplicationTester("test") + + service.models.put("test", (juba, cm, LearningMachineType.Recommender)) + val updateWith = new UpdateWith("test", "update_row", """{"team": "aaa", "test": 1}""") + val result = service.queryUpdateWith(updateWith) + service.models.remove("test") + + result.left.value._1 shouldBe 400 + } + + override protected def beforeAll(): Unit = { + sc = new SparkContext("local[3]", "JubaQL Processor Test") + service = new JubaQLServiceTester(sc) + + val hosts: List[(String, Int)] = List(("localhost", 1111)) + proService = new JubaQLServiceProductionTester(sc, RunMode.Production(hosts)) } override protected def afterAll(): Unit = { diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/LocalJubatusApplicationSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/LocalJubatusApplicationSpec.scala index cb1f76a..e84c79c 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/LocalJubatusApplicationSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/LocalJubatusApplicationSpec.scala @@ -16,13 +16,19 
@@ package us.jubat.jubaql_server.processor import org.scalatest._ -import us.jubat.yarn.common.LearningMachineType +import us.jubat.yarn.common.{LearningMachineType, ServerConfig, Mixer} import scala.concurrent._ import ExecutionContext.Implicits.global import scala.concurrent.duration.Duration -import scala.util.Success +import scala.util.{Success, Failure} +import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import java.net.InetSocketAddress -class LocalJubatusApplicationSpec extends FlatSpec with ShouldMatchers { +class LocalJubatusApplicationSpec extends FlatSpec with ShouldMatchers with BeforeAndAfterAll { + + private var dummyJubaServer: DummyJubatusServer = null val anomalyConfig = """{ "method" : "lof", @@ -179,4 +185,515 @@ class LocalJubatusApplicationSpec extends FlatSpec with ShouldMatchers { case _ => } } + + "start" should "return success for default server config" taggedAs (LocalTest, JubatusTest) in { + // デフォルト + val serverConfig = ServerConfig() + val f = LocalJubatusApplication.start("test", LearningMachineType.Classifier, classifierConfig, serverConfig) + Await.ready(f, Duration.Inf) + val result = f.value.get + result shouldBe a[Success[_]] + result match { + case Success(app) => + Await.ready(app.stop(), Duration.Inf) + case _ => + } + // コマンドのパラメータを目視確認 + // -c 2 -t 10 -x linear_mixer -s 16 -i 512 -Z 10 -I 10 + } + + it should "return success for customized server config" taggedAs (LocalTest, JubatusTest) in { + // 値指定あり + val serverConfig = ServerConfig(3, 30, Mixer.Broadcast, 0, 1024, 30, 40) + val f = LocalJubatusApplication.start("test", LearningMachineType.Classifier, classifierConfig, serverConfig) + Await.ready(f, Duration.Inf) + val result = f.value.get + result shouldBe a[Success[_]] + result match { + case Success(app) => + Await.ready(app.stop(), Duration.Inf) + case _ => + } + // コマンドのパラメータを目視確認 + // -c 3 -t 30 -x broadcast_mixer -s 0 -i 1024 -Z 30 -I 40 + } + + 
"saveModel" should "saveModel for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test001", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val dstFile = new java.io.File("/tmp/t1/data/classifier/test001/0.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + val dstPath = new java.io.File("/tmp/t1/data/classifier/test001") + if (dstPath.exists()) { + dstPath.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.saveModel(modelPath, "test001") + + retApp shouldBe a[Success[_]] + dstFile.exists() shouldBe true + } + + it should "saveModel a file doesn't exist for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test002", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val dstFile = new java.io.File("/tmp/t1/data/classifier/test002/0.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + val dstPath = new java.io.File("/tmp/t1/data/classifier/test002") + if (!dstPath.exists()) { + dstPath.mkdirs() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.saveModel(modelPath, "test002") + + retApp shouldBe a[Success[_]] + dstFile.exists() shouldBe true + } + + it should "saveModel a file exists for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test003", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val dstPath = new java.io.File("/tmp/t1/data/classifier/test003") + if (!dstPath.exists()) { + dstPath.mkdirs() + } + + val dstFile = new java.io.File("/tmp/t1/data/classifier/test003/0.jubatus") + if (!dstFile.exists()) { + dstFile.createNewFile() + } + + val writer = new java.io.FileWriter(dstFile, true) + writer.write("test") + writer.close() + val beforLen = dstFile.length() + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.saveModel(modelPath, "test003") + + retApp 
shouldBe a[Success[_]] + dstFile.exists() shouldBe true + dstFile.length() should not be beforLen + } + + it should "saveModel no write permission for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test004", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val dstPath = new java.io.File("/tmp/t1/data/classifier/test004") + if (!dstPath.exists()) { + dstPath.mkdirs() + } + val dstFile = new java.io.File("/tmp/t1/data/classifier/test004/0.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + dstPath.setReadOnly() + + val modelPath = new Path("file:///t1/data/classifier") + val retApp = juba.saveModel(modelPath, "test004") + + retApp should not be a[Success[_]] + dstFile.exists() shouldBe false + } + + it should "saveModel relative path for classifier" taggedAs (LocalTest, JubatusTest) in { + saveModelForRelativePath("tmp/data/classifier", "test005") + } + + it should "saveModel relative path2 for classifier" taggedAs (LocalTest, JubatusTest) in { + saveModelForRelativePath("./tmp/data/classifier", "test006") + } + + it should "saveModel relative path3 for classifier" taggedAs (LocalTest, JubatusTest) in { + saveModelForRelativePath("../tmp/data/classifier", "test007") + } + + it should "saveModel save RPC result.size is 0" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, dummyJubaServer.JubaServer.resultSize0, LearningMachineType.Classifier, "jubaclassifier", 9300) + val modelPath = new Path("file:///tmp/data/classifier") + val retApp = juba.saveModel(modelPath, "test001") + + retApp shouldBe a[Failure[_]] + + retApp match { + case Failure(t) => + t.printStackTrace() + + case Success(_) => + printf("save model success") + } + } + + it should "saveModel save RPC result.size is 2" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, dummyJubaServer.JubaServer.resultSize2, LearningMachineType.Classifier, "jubaclassifier", 9300) + 
val modelPath = new Path("file:///tmp/data/classifier") + val retApp = juba.saveModel(modelPath, "test001") + + retApp shouldBe a[Failure[_]] + + retApp match { + case Failure(t) => + t.printStackTrace() + + case Success(_) => + printf("save model success") + } + } + + private def saveModelForRelativePath(path: String, id: String) { + val juba = new LocalJubatusApplication(null, id, LearningMachineType.Classifier, "jubaclassifier", 9300) + + val dstFile = new java.io.File(s"$path/$id/0.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + + val df = new java.io.File(s"/tmp/data/classifier/$id/0.jubatus") + if (df.exists()) { + df.delete() + } + + val modelPath = new Path(s"file://$path") + val retApp = juba.saveModel(modelPath, id) + + retApp shouldBe a[Success[_]] + dstFile.exists() shouldBe true + df.exists() shouldBe false + } + + "loadModel" should "loadModel for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test001", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test001") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File("/tmp/t1/data/classifier/test001/0.jubatus") + if (!srcFile.exists()) { + srcFile.createNewFile() + } + val dstFile = new java.io.File("/tmp/dummyHost_port_classifier_test001.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test001") + + retApp shouldBe a[Success[_]] + srcFile.exists() shouldBe true + dstFile.exists() shouldBe true + } + + it should "loadModel a dstFile exists for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test002", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test002") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val 
srcFile = new java.io.File("/tmp/t1/data/classifier/test002/0.jubatus") + if (!srcFile.exists()) { + srcFile.createNewFile() + } + val dstFile = new java.io.File("/tmp/dummyHost_port_classifier_test002.jubatus") + if (!dstFile.exists()) { + dstFile.createNewFile() + } + val writer = new java.io.FileWriter(dstFile, true) + writer.write("test") + writer.close() + val beforLen = dstFile.length() + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test002") + + retApp shouldBe a[Success[_]] + srcFile.exists() shouldBe true + dstFile.exists() shouldBe true + dstFile.length() should not be beforLen + } + + it should "loadModel a srcFolder doesn't exists for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test003", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test003") + if (srcPath.exists()) { + FileUtils.cleanDirectory(srcPath) + srcPath.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test003") + + retApp shouldBe a[Failure[_]] + retApp match { + case Failure(t) => + t.printStackTrace() + case Success(_) => + printf("load model success") + } + } + + it should "loadModel a srcFile doesn't exists for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test004", LearningMachineType.Classifier, "jubaclassifier", 9300) + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test004") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File("/tmp/t1/data/classifier/test004/0.jubatus") + if (srcFile.exists()) { + srcFile.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test004") + + retApp shouldBe a[Failure[_]] + retApp match { + case Failure(t) => + t.printStackTrace() + case Success(_) => + 
printf("load model success") + } + } + + it should "loadModel relative path for classifier" taggedAs (LocalTest, JubatusTest) in { + loadModelForRelativePath("tmp/t1/data/classifier", "test005") + } + + it should "loadModel relative path2 for classifier" taggedAs (LocalTest, JubatusTest) in { + loadModelForRelativePath("./tmp/t1/data/classifier", "test006") + } + + it should "loadModel relative path3 for classifier" taggedAs (LocalTest, JubatusTest) in { + loadModelForRelativePath("../tmp/t1/data/classifier", "test007") + } + + it should "loadModel getStatus RPC result.size is 0" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "errTest001", LearningMachineType.Classifier, "jubaclassifier", 9300) + dummyJubaServer.statusType = dummyJubaServer.JubaServer.resultSize0 + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test008") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File("/tmp/t1/data/classifier/test008/0.jubatus") + if (!srcFile.exists()) { + srcFile.createNewFile() + } + val dstFile = new java.io.File("/tmp/dummyHost_port_classifier_test008.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test008") + + retApp shouldBe a[Failure[_]] + retApp match { + case Failure(t) => + t.printStackTrace() + case Success(_) => + printf("load model success") + } + } + + it should "loadModel getStatus RPC result.size is 2" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "errTest001", LearningMachineType.Classifier, "jubaclassifier", 9300) + dummyJubaServer.statusType = dummyJubaServer.JubaServer.resultSize2 + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test008") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File("/tmp/t1/data/classifier/test008/0.jubatus") + if (!srcFile.exists()) { + 
srcFile.createNewFile() + } + val dstFile = new java.io.File("/tmp/dummyHost_port_classifier_test008.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test008") + + retApp shouldBe a[Failure[_]] + retApp match { + case Failure(t) => + t.printStackTrace() + case Success(_) => + printf("load model success") + } + } + + it should "loadModel load RPC result error" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "errTest001", LearningMachineType.Classifier, "jubaclassifier", 9300) + dummyJubaServer.statusType = dummyJubaServer.JubaServer.resultSize1 + + val srcPath = new java.io.File("/tmp/t1/data/classifier/test008") + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File("/tmp/t1/data/classifier/test008/0.jubatus") + if (!srcFile.exists()) { + srcFile.createNewFile() + } + val dstFile = new java.io.File("/tmp/dummyHost_port_classifier_test008.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + + val modelPath = new Path("file:///tmp/t1/data/classifier") + val retApp = juba.loadModel(modelPath, "test008") + + retApp shouldBe a[Failure[_]] + retApp match { + case Failure(t) => + t.printStackTrace() + case Success(_) => + printf("load model success") + } + } + + private def loadModelForRelativePath(path: String, id: String) { + val juba = new LocalJubatusApplication(null, id, LearningMachineType.Classifier, "jubaclassifier", 9300) + + val localFileSystem = org.apache.hadoop.fs.FileSystem.getLocal(new Configuration()) + val srcDirectory = localFileSystem.pathToFile(new org.apache.hadoop.fs.Path(path)) + + val srcPath = new java.io.File(srcDirectory, id) + if (!srcPath.exists()) { + srcPath.mkdirs() + } + val srcFile = new java.io.File(srcPath, "0.jubatus") + if (!srcFile.exists()) { + srcFile.createNewFile() + } + val dstFile = new 
java.io.File(s"/tmp/dummyHost_port_classifier_$id.jubatus") + if (dstFile.exists()) { + dstFile.delete() + } + val df = new java.io.File(s"/tmp/t1/data/classifier/$id/0.jubatus") + if (df.exists()) { + df.delete() + } + + val modelPath = new Path(s"file://$path") + val retApp = juba.loadModel(modelPath, id) + + retApp shouldBe a[Success[_]] + srcFile.exists() shouldBe true + dstFile.exists() shouldBe true + df.exists() shouldBe false + } + + "status" should "status for classifier" taggedAs (LocalTest, JubatusTest) in { + val juba = new LocalJubatusApplication(null, "Test001", LearningMachineType.Classifier, "jubaclassifier", 9300) + dummyJubaServer.statusType = dummyJubaServer.JubaServer.resultSize1 + + val retStatus = juba.status + retStatus.jubatusProxy shouldBe null + retStatus.jubatusServers should not be null + retStatus.yarnApplication shouldBe null + + retStatus.jubatusServers.size() shouldBe 1 + } + + override protected def beforeAll(): Unit = { + dummyJubaServer = new DummyJubatusServer + dummyJubaServer.start(9300) + } + + override protected def afterAll(): Unit = { + dummyJubaServer.stop() + } + +} + +class DummyJubatusServer { + var server: org.msgpack.rpc.Server = null + var statusType: String = JubaServer.resultSize1 + + object JubaServer { + val resultSize0: String = "resultSize0" + val resultSize1: String = "resultSize1" + val resultSize2: String = "resultSize2" + } + class JubaServer { + def save(strId: String): java.util.Map[String, String] = { + var ret: java.util.Map[String, String] = new java.util.HashMap() + + strId match { + case JubaServer.resultSize0 => // return 0 + ret + + case JubaServer.resultSize2 => // return 2 + ret.put("key1", "value1") + ret.put("key2", "value2") + ret + + case _ => + val file = new java.io.File("/tmp/test.jubatus") + if (!file.exists()) { + file.createNewFile() + } + ret.put("key1", "/tmp/test.jubatus") + ret + } + } + + def load(strId: String): Boolean = { + + strId match { + case "errTest001" => + false + + case 
_ => + true + } + } + + def get_status(): java.util.Map[String, java.util.Map[String, String]] = { + var ret: java.util.Map[String, java.util.Map[String, String]] = new java.util.HashMap() + var ret2: java.util.Map[String, String] = new java.util.HashMap() + statusType match { + case JubaServer.resultSize0 => + ret + + case JubaServer.resultSize2 => + ret2.put("datadir", "file:///tmp") + ret2.put("type", "classifier") + ret.put("key1", ret2) + ret.put("key2", ret2) + ret + + case _ => + ret2.put("datadir", "file:///tmp") + ret2.put("type", "classifier") + ret.put("dummyHost_port", ret2) + ret + } + } + } + + def start(id: Int) { + server = new org.msgpack.rpc.Server() + server.serve(new JubaServer()) + server.listen(new InetSocketAddress(id)) + println("*** DummyJubatusServer start ***") + } + + def stop() { + server.close() + println("*** DummyJubatusServer stop ***") + } } diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/ProcessUtil.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/ProcessUtil.scala index 57567b0..b54d684 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/ProcessUtil.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/ProcessUtil.scala @@ -3,10 +3,8 @@ package us.jubat.jubaql_server.processor import scala.sys.process.{Process, ProcessBuilder} object ProcessUtil { - /** - * Returns a ProcessBuilder with an environment variable for checkpointDir. 
- */ - def commandToProcessBuilder(command: Seq[String]): ProcessBuilder = { - Process(command, None, "JAVA_OPTS" -> "-Djubaql.checkpointdir=file:///tmp/spark") + + def commandToProcessBuilder(command: Seq[String], env: String = "-Djubaql.checkpointdir=file:///tmp/spark"): ProcessBuilder = { + Process(command, None, "JAVA_OPTS" -> env) } } diff --git a/processor/src/test/scala/us/jubat/jubaql_server/processor/integration/JubaQLProcessorSpec.scala b/processor/src/test/scala/us/jubat/jubaql_server/processor/integration/JubaQLProcessorSpec.scala index 01017fd..070e143 100644 --- a/processor/src/test/scala/us/jubat/jubaql_server/processor/integration/JubaQLProcessorSpec.scala +++ b/processor/src/test/scala/us/jubat/jubaql_server/processor/integration/JubaQLProcessorSpec.scala @@ -26,7 +26,7 @@ import org.json4s.JsonDSL._ import org.json4s.native.JsonMethods._ import scala.util.{Success, Failure, Try} import us.jubat.jubaql_server.processor._ -import us.jubat.jubaql_server.processor.json.ClassifierResult +import us.jubat.jubaql_server.processor.json.{ClassifierResult, AnomalyScore} /** Tests the correct behavior as viewed from the outside. 
*/ @@ -82,7 +82,7 @@ trait ProcessorTestManager while (now - start < waitMax && state != "Finished") { sendJubaQL("STATUS") match { case Success(x) => - x._2 \ "sources" \ dsName match { + x._2 \ "sources" \ dsName \ "state" match { case JString(currentState) => state = currentState case other => @@ -109,6 +109,16 @@ trait ProcessorTestManager (process, stdoutBuffer, sendJubaQLTo(port)) } + protected def startProcessor(env: String): (Process, + StringBuffer, String => Try[(Int, JValue)]) = { + val command = Seq("./start-script/run") + val pb = commandToProcessBuilder(command, env) + val (logger, stdoutBuffer, stderrBuffer) = getProcessLogger() + val process = pb run logger + val port = getServerPort(stdoutBuffer) + (process, stdoutBuffer, sendJubaQLTo(port)) + } + protected def getProcessLogger(): (ProcessLogger, StringBuffer, StringBuffer) = { val stdoutBuffer = new StringBuffer() val stderrBuffer = new StringBuffer() @@ -150,6 +160,30 @@ trait ProcessorTestManager } } } + + case class StreamStatus(inputCount: Long, outputCount: Long, startTime: Long) + + protected def getStreamStatus(streamStatusJValue: JValue): StreamStatus = { + val startTime = streamStatusJValue \ "stream_start" match { + case JInt(value) => + value.toLong + case other => + throw new Exception("failed json parse:" + other.toString) + } + val inputCount = streamStatusJValue \ "input_count" match { + case JInt(value) => + value.toLong + case other => + throw new Exception("failed json parse:" + other.toString) + } + val outputCount = streamStatusJValue \ "output_count" match { + case JInt(value) => + value.toLong + case other => + throw new Exception("failed json parse:" + other.toString) + } + StreamStatus(inputCount, outputCount, startTime) + } } class CreateDataSourceSpec @@ -267,6 +301,55 @@ class CreateModelSpec val exitValue = process.exitValue() exitValue shouldBe 0 } + + "CREATE MODEL(Production Mode)" should "return HTTP 200 on correct syntax and application-name of the as expected" 
taggedAs (JubatusTest) in { + // override before() processing + if (process != null) process.destroy() + // start production-mode processor + val startResult = startProcessor("-Drun.mode=production -Djubaql.zookeeper=127.0.0.1:2181 -Djubaql.checkpointdir=file:///tmp/spark -Djubaql.gateway.address=testAddress:1234 -Djubaql.processor.sessionId=1234567890abcdeABCDE") + process = startResult._1 + stdout = startResult._2 + sendJubaQL = startResult._3 + + val cmResult = sendJubaQL(goodCmStmt) + cmResult shouldBe a[Success[_]] + cmResult.get._1 shouldBe 200 + cmResult.get._2 \ "result" shouldBe JString("CREATE MODEL (started)") + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + + // check application-name + stdout.toString should include("starting JubatusOnYarn:testAddress:1234:1234567890abcdeABCDE:classifier:test1") + } + + "CREATE MODEL(Production Mode) without SystemProperty" should "return HTTP 200 on correct syntax and application-name of the as expected" taggedAs (JubatusTest) in { + + // override before() processing + if (process != null) process.destroy() + // start production-mode processor (without System Property) + val startResult = startProcessor("-Drun.mode=production -Djubaql.zookeeper=127.0.0.1:2181 -Djubaql.checkpointdir=file:///tmp/spark") + process = startResult._1 + stdout = startResult._2 + sendJubaQL = startResult._3 + + val cmResult = sendJubaQL(goodCmStmt) + cmResult shouldBe a[Success[_]] + cmResult.get._1 shouldBe 200 + cmResult.get._2 \ "result" shouldBe JString("CREATE MODEL (started)") + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + + // check application-name + stdout.toString should include("starting JubatusOnYarn:::classifier:test1") + } } class CreateStreamFromSelectSpec @@ -525,6 +608,159 
@@ class CreateStreamFromSelectSpec val exitValue = process.exitValue() exitValue shouldBe 0 } + + it should "StreamStatus(starTime/outputCount/inputCount) from datasource" taggedAs (LocalTest) in { + val cdResult = sendJubaQL(goodCdStmt) + cdResult shouldBe a[Success[_]] + val csResult = sendJubaQL("""CREATE STREAM test FROM SELECT label FROM ds1 WHERE label = '徳川'""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val beforeStartTime = System.currentTimeMillis() + + //before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "test" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + val sp1Result = sendJubaQL("START PROCESSING ds1") + sp1Result shouldBe a[Success[_]] + sp1Result.get._1 shouldBe 200 + sp1Result.get._2 \ "result" shouldBe JString("START PROCESSING") + waitUntilDone("ds1", 6000) + + // get status(startTime) + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "test" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 44L + result.outputCount shouldBe 14L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } + + it should "StreamStatus(starTime/outputCount/inputCount) from stream" taggedAs (LocalTest) in { + val cdResult = sendJubaQL(goodCdStmt) + cdResult shouldBe a[Success[_]] + val cs1Result = sendJubaQL("""CREATE STREAM ds2 FROM SELECT label, name FROM ds1 WHERE label = '徳川'""") + cs1Result shouldBe 
a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM test FROM SELECT label, name FROM ds2 WHERE name like '家%'""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val beforeStartTime = System.currentTimeMillis() + + //before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "test" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + val sp1Result = sendJubaQL("START PROCESSING ds1") + sp1Result shouldBe a[Success[_]] + sp1Result.get._1 shouldBe 200 + sp1Result.get._2 \ "result" shouldBe JString("START PROCESSING") + waitUntilDone("ds1", 6000) + + // get status(startTime) + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "test" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 14L + result.outputCount shouldBe 11L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } + + it should "Interference confirmation of StreamStatus values" taggedAs (LocalTest) in { + val cdResult = sendJubaQL(goodCdStmt) + cdResult shouldBe a[Success[_]] + + val cs1Result = sendJubaQL("""CREATE STREAM ds2 FROM SELECT label, name FROM ds1 WHERE label = '徳川'""") + cs1Result shouldBe a[Success[_]] + + val csResult1 = sendJubaQL("""CREATE STREAM test FROM SELECT label, name FROM ds2 WHERE name like '家%'""") + csResult1 shouldBe a[Success[_]] + csResult1.get._1 shouldBe 200 + csResult1.get._2 \ "result" shouldBe JString("CREATE STREAM") + + 
val csResult2 = sendJubaQL("""CREATE STREAM test2 FROM SELECT label, name FROM ds1 WHERE label = '北条'""") + csResult2 shouldBe a[Success[_]] + csResult2.get._1 shouldBe 200 + csResult2.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val beforeStartTime = System.currentTimeMillis() + + val sp1Result = sendJubaQL("START PROCESSING ds1") + sp1Result shouldBe a[Success[_]] + sp1Result.get._1 shouldBe 200 + sp1Result.get._2 \ "result" shouldBe JString("START PROCESSING") + waitUntilDone("ds1", 6000) + + // get status(startTime) + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus1 = stResult.get._2 \ "streams" \ "test" + val streamStatus2 = stResult.get._2 \ "streams" \ "test2" + + val result1 = getStreamStatus(streamStatus1) + result1.inputCount shouldBe 14L + result1.outputCount shouldBe 11L + result1.startTime should be > beforeStartTime + + val result2 = getStreamStatus(streamStatus2) + result2.inputCount shouldBe 44L + result2.outputCount shouldBe 15L + result2.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatus1の目視確認用デバッグ出力 + println(streamStatus1) + // streamStatus2の目視確認用デバッグ出力 + println(streamStatus2) + } } class AggregatesInSlidingWindowSpec @@ -1025,6 +1261,108 @@ class CreateStreamFromSlidingWindowSpec val outputString = correctValues.map(l => l._1 + " | " + l._2).mkString("\n") stdout.toString should include(headerRow + "\n" + outputString) } + + it should "StreamStatus(starTime/outputCount/inputCount) from datasource" taggedAs (LocalTest) in { + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csStmt = """CREATE STREAM ds2 FROM SLIDING WINDOW (SIZE 4 ADVANCE 1 TUPLES) """ + + """OVER ds1 WITH avg(age) AS avg_age, maxelem(gender)""" + val csResult = sendJubaQL(csStmt) + csResult 
shouldBe a[Success[_]] + if (csResult.get._1 != 200) + println(stdout.toString) + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val beforeStartTime = System.currentTimeMillis() + + //before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "ds2" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + sendJubaQL("START PROCESSING ds1") shouldBe a[Success[_]] + waitUntilDone("ds1", 30000) + + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "ds2" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 12L + result.outputCount shouldBe 8L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } + + it should "StreamStatus(starTime/outputCount/inputCount) from stream" taggedAs (LocalTest) in { + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val cs1Result = sendJubaQL("CREATE STREAM stream1 FROM SELECT gender, age, jubaql_timestamp FROM ds1 WHERE age > 20") + cs1Result shouldBe a[Success[_]] + cs1Result.get._1 shouldBe 200 + + val csStmt = """CREATE STREAM ds2 FROM SLIDING WINDOW (SIZE 4 ADVANCE 1 TUPLES) """ + + """OVER stream1 WITH avg(age) AS avg_age, maxelem(gender)""" + val csResult = sendJubaQL(csStmt) + csResult shouldBe a[Success[_]] + if (csResult.get._1 != 200) + println(stdout.toString) + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val beforeStartTime = System.currentTimeMillis() + + 
//before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "ds2" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + sendJubaQL("START PROCESSING ds1") shouldBe a[Success[_]] + waitUntilDone("ds1", 30000) + + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "ds2" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 10L + result.outputCount shouldBe 6L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } } @@ -1174,24 +1512,43 @@ class CreateStreamFromAnalyzeSpec exitValue shouldBe 0 } - it should "work correctly with ANOMALY" taggedAs (LocalTest, JubatusTest) in { + it should "work correctly with CLASSIFIER and use error function" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(label string) RETURNS string LANGUAGE JavaScript AS $$ + |if (label == '徳川') { + | return label + "ABC"; + |} else { + | throw new Error("Error Message"); + |} + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cmStmt = """CREATE DATASOURCE ds (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] - val config = Source.fromFile("src/test/resources/lof.json").getLines().mkString("") - val cdStmt = s"""CREATE ANOMALY MODEL test AS name 
WITH unigram CONFIG '$config'""" + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csfaStmt = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING calc_score AS newcol""" + val csfsStmt = """CREATE STREAM ds2 FROM SELECT addABC(label) AS lable, name FROM ds""" + val csfsResult = sendJubaQL(csfsStmt) + csfsResult shouldBe a[Success[_]] + + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds2 BY MODEL test USING classify AS newcol""" val csfaResult = sendJubaQL(csfaStmt) csfaResult shouldBe a[Success[_]] // executed before UPDATE sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] - val umStmt = """UPDATE MODEL test USING add FROM ds""" + val umStmt = """UPDATE MODEL test USING train FROM ds""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -1200,11 +1557,11 @@ class CreateStreamFromAnalyzeSpec val spResult = sendJubaQL("START PROCESSING ds") spResult shouldBe a[Success[_]] - waitUntilDone("ds", 6000) + waitUntilDone("ds", 30000) // before the first update: - stdout.toString should include("徳川 | 家康 | 1.0") + stdout.toString should include("徳川ABC | 家康 | List()") // after the first update: - stdout.toString should include("徳川 | 家康 | 0.9990") + stdout.toString should include regex("徳川ABC \\| 家康 \\| .*徳川,0.93333333") // shut down val sdResult = sendJubaQL("SHUTDOWN") @@ -1214,24 +1571,24 @@ class CreateStreamFromAnalyzeSpec exitValue shouldBe 0 } - it should "work correctly with RECOMMENDER/from_id" taggedAs (LocalTest, JubatusTest) in { - val cmStmt = """CREATE DATASOURCE ds FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + it should "work correctly with ANOMALY" taggedAs (LocalTest, JubatusTest) in { + val cmStmt = """CREATE DATASOURCE ds (label string, name string) FROM (STORAGE: 
"file://src/test/resources/shogun_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] - val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") - val cdStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val config = Source.fromFile("src/test/resources/lof.json").getLines().mkString("") + val cdStmt = s"""CREATE ANOMALY MODEL test AS name WITH unigram CONFIG '$config'""" val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csfaStmt = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING complete_row_from_id AS newcol""" + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING calc_score AS newcol""" val csfaResult = sendJubaQL(csfaStmt) csfaResult shouldBe a[Success[_]] // executed before UPDATE sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] - val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val umStmt = """UPDATE MODEL test USING add FROM ds""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -1242,9 +1599,9 @@ class CreateStreamFromAnalyzeSpec spResult shouldBe a[Success[_]] waitUntilDone("ds", 6000) // before the first update: - stdout.toString should include regex("長野久義 .+Map\\(\\),Map\\(\\)") + stdout.toString should include("徳川 | 家康 | 1.0") // after the first update: - stdout.toString should include regex("長野久義 .+Map\\(\\),Map\\(.*OPS -> 0.6804.*\\)") + stdout.toString should include("徳川 | 家康 | 0.9990") // shut down val sdResult = sendJubaQL("SHUTDOWN") @@ -1254,24 +1611,41 @@ class CreateStreamFromAnalyzeSpec exitValue shouldBe 0 } - it should "work correctly with RECOMMENDER/from_data" taggedAs (LocalTest, JubatusTest) in { - val cmStmt = """CREATE DATASOURCE ds FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" - val cmResult = sendJubaQL(cmStmt) - cmResult shouldBe a[Success[_]] + it should "work correctly with ANOMALY and 
use error function" taggedAs (LocalTest, JubatusTest) in { + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(label string) RETURNS string LANGUAGE JavaScript AS $$ + |if (label == '徳川') { + | return label + "ABC"; + |} else { + | throw new Error("Error Message"); + |} + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") - val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") - val cdStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cdStmt = """CREATE DATASOURCE ds (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val aStmt2 = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING complete_row_from_datum AS newcol""" - val aResult2 = sendJubaQL(aStmt2) - aResult2 shouldBe a[Success[_]] + val config = Source.fromFile("src/test/resources/lof.json").getLines().mkString("") + val cmStmt = s"""CREATE ANOMALY MODEL test AS name WITH unigram CONFIG '$config'""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] - // executed before UPDATE - sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + val csfsStmt = """CREATE STREAM ds2 FROM SELECT addABC(label) AS lable, name FROM ds""" + val csfsResult = sendJubaQL(csfsStmt) + csfsResult shouldBe a[Success[_]] - val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds2 BY MODEL test USING calc_score AS newcol""" + val csfaResult = sendJubaQL(csfaStmt) + csfaResult shouldBe a[Success[_]] + + // executed before UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING add FROM ds""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -1280,7 +1654,144 @@ class 
CreateStreamFromAnalyzeSpec val spResult = sendJubaQL("START PROCESSING ds") spResult shouldBe a[Success[_]] - waitUntilDone("ds", 6000) + waitUntilDone("ds", 30000) + // before the first update: + stdout.toString should include("徳川ABC | 家康 | 1.0") + // after the first update: + stdout.toString should include("徳川ABC | 家康 | 0.9990") + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "work correctly with RECOMMENDER/from_id" taggedAs (LocalTest, JubatusTest) in { + val cdStmt = """CREATE DATASOURCE ds FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cmStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING complete_row_from_id AS newcol""" + val csfaResult = sendJubaQL(csfaStmt) + csfaResult shouldBe a[Success[_]] + + // executed before UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + // executed after UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe a[Success[_]] + waitUntilDone("ds", 30000) + // before the first update: + stdout.toString should include regex("長野久義 .+Map\\(\\),Map\\(\\)") + // after the first update: + stdout.toString should include regex("長野久義 .+Map\\(\\),Map\\(.*OPS -> 0.6804.*\\)") + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + 
// wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "work correctly with RECOMMENDER/from_id and use error function" taggedAs (LocalTest, JubatusTest) in { + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(team string) RETURNS string LANGUAGE JavaScript AS $$ + |if (team == '巨人') { + | return team + "ABC"; + |} else { + | throw new Error("Error Message"); + |} + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cdStmt = """CREATE DATASOURCE ds (id string, team string, 打率 numeric, 試合数 numeric, 打席 numeric, 打数 numeric, 安打 numeric, 本塁打 numeric, 打点 numeric, 盗塁 numeric, 四球 numeric, 死球 numeric, 三振 numeric, 犠打 numeric, 併殺打 numeric, 長打率 numeric, 出塁率 numeric, OPS numeric, RC27 numeric, XR27 numeric) FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cmStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val csfsStmt = """CREATE STREAM ds2 FROM SELECT id, addABC(team) AS team, 打率, 試合数, 打席, 打数, 安打, 本塁打, 打点, 盗塁, 四球, 死球, 三振, 犠打, 併殺打, 長打率, 出塁率, OPS, RC27, XR27 FROM ds""" + val csfsResult = sendJubaQL(csfsStmt) + csfsResult shouldBe a[Success[_]] + + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds2 BY MODEL test USING complete_row_from_id AS newcol""" + val csfaResult = sendJubaQL(csfaStmt) + csfaResult shouldBe a[Success[_]] + + // executed before UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + // executed after UPDATE + 
sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe a[Success[_]] + waitUntilDone("ds", 30000) + // before the first update: + stdout.toString should include regex("長野久義 \\| 巨人ABC .+Map\\(\\),Map\\(\\)") + // after the first update: + stdout.toString should include regex("長野久義 \\| 巨人ABC .+Map\\(\\),Map\\(.*OPS -> 0.6804.*\\)") + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "work correctly with RECOMMENDER/from_data" taggedAs (LocalTest, JubatusTest) in { + val cmStmt = """CREATE DATASOURCE ds FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cdStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val aStmt2 = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING complete_row_from_datum AS newcol""" + val aResult2 = sendJubaQL(aStmt2) + aResult2 shouldBe a[Success[_]] + + // executed before UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + // executed after UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe a[Success[_]] + waitUntilDone("ds", 30000) // before the first update: stdout.toString should include regex("長野久義 .+Map\\(\\),Map\\(\\)") // after the first update: @@ -1294,6 +1805,63 @@ class CreateStreamFromAnalyzeSpec exitValue shouldBe 0 } + it should "work correctly with 
RECOMMENDER/from_data and use error function" taggedAs (LocalTest, JubatusTest) in { + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(name string, team string) RETURNS string LANGUAGE JavaScript AS $$ + |if (team == '巨人') { + | return name + "ABC"; + |} else { + | throw new Error("Error Message"); + |} + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds (id string, team string, 打率 numeric, 試合数 numeric, 打席 numeric, 打数 numeric, 安打 numeric, 本塁打 numeric, 打点 numeric, 盗塁 numeric, 四球 numeric, 死球 numeric, 三振 numeric, 犠打 numeric, 併殺打 numeric, 長打率 numeric, 出塁率 numeric, OPS numeric, RC27 numeric, XR27 numeric) FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cdStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val aStmt1 = """CREATE STREAM ds2 FROM SELECT addABC(id, team) AS id, team, 打率, 試合数, 打席, 打数, 安打, 本塁打, 打点, 盗塁, 四球, 死球, 三振, 犠打, 併殺打, 長打率, 出塁率, OPS, RC27, XR27 FROM ds""" + val aResult1 = sendJubaQL(aStmt1) + aResult1 shouldBe a[Success[_]] + + val aStmt2 = """CREATE STREAM output FROM ANALYZE ds2 BY MODEL test USING complete_row_from_datum AS newcol""" + val aResult2 = sendJubaQL(aStmt2) + aResult2 shouldBe a[Success[_]] + + // executed before UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING update_row FROM ds""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + // executed after UPDATE + sendJubaQL("LOG STREAM output") shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe 
a[Success[_]] + waitUntilDone("ds", 30000) + // before the first update: + stdout.toString should include regex("長野久義ABC .+Map\\(\\),Map\\(\\)") + // after the first update: + stdout.toString should include regex("長野久義ABC .+Map\\(\\),Map\\(.*OPS -> .*\\)") + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + it should "return HTTP 400 if referenced data source does not exist" taggedAs (LocalTest, JubatusTest) in { val cmResult = sendJubaQL(goodCmStmt) cmResult shouldBe a[Success[_]] @@ -1403,6 +1971,124 @@ class CreateStreamFromAnalyzeSpec val exitValue = process.exitValue() exitValue shouldBe 0 } + + it should "StreamStatus(starTime/outputCount/inputCount) from datasource with CLASSIFIER" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cmStmt = """CREATE DATASOURCE ds (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csfaStmt = """CREATE STREAM output FROM ANALYZE ds BY MODEL test USING classify AS newcol""" + val csfaResult = sendJubaQL(csfaStmt) + csfaResult shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING train FROM ds""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val beforeStartTime = System.currentTimeMillis() + + //before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "output" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount 
shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe a[Success[_]] + waitUntilDone("ds", 6000) + + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "output" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 44L + result.outputCount shouldBe 44L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } + + it should "StreamStatus(starTime/outputCount/inputCount) from stream with CLASSIFIER" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cmStmt = """CREATE DATASOURCE ds (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val csStmt = """CREATE STREAM stream1 FROM SELECT * FROM ds WHERE label = '徳川'""" + val csResult = sendJubaQL(csStmt) + csResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csfaStmt = """CREATE STREAM output FROM ANALYZE stream1 BY MODEL test USING classify AS newcol""" + val csfaResult = sendJubaQL(csfaStmt) + csfaResult shouldBe a[Success[_]] + + val umStmt = """UPDATE MODEL test USING train FROM stream1""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val beforeStartTime = System.currentTimeMillis() + + //before start + val stBefore = sendJubaQL("STATUS") + stBefore shouldBe 
a[Success[_]] + stBefore.get._1 shouldBe 200 + val beforeStreamStatus = stBefore.get._2 \ "streams" \ "output" + + val before = getStreamStatus(beforeStreamStatus) + before.inputCount shouldBe 0L + before.outputCount shouldBe 0L + before.startTime shouldBe 0L + + val spResult = sendJubaQL("START PROCESSING ds") + spResult shouldBe a[Success[_]] + waitUntilDone("ds", 6000) + + val stResult = sendJubaQL("STATUS") + stResult shouldBe a[Success[_]] + stResult.get._1 shouldBe 200 + val streamStatus = stResult.get._2 \ "streams" \ "output" + + val result = getStreamStatus(streamStatus) + result.inputCount shouldBe 14L + result.outputCount shouldBe 14L + result.startTime should be > beforeStartTime + + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + // streamStatusの目視確認用デバッグ出力 + println(streamStatus) + } } class UpdateModelSpec @@ -1493,10 +2179,11 @@ class UpdateModelSpec waitUntilDone("ds", 6000) stdout.toString should include("column named 'doesnotexist' not found") - stdout.toString should - include("HybridProcessor - Error while waiting for static processing end") // we log once - stdout.toString should - include("HybridProcessor - Error while setting up stream processing") // ... and again +// エラー終了しないように改修したため、以下の判定は削除 +// stdout.toString should +// include("HybridProcessor - Error while waiting for static processing end") // we log once +// stdout.toString should +// include("HybridProcessor - Error while setting up stream processing") // ... 
and again // shut down val sdResult = sendJubaQL("SHUTDOWN") @@ -2602,14 +3289,334 @@ class CreateFunctionSpec implicit val formats = DefaultFormats val cfResult = sendJubaQL( - """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg + "ABC"; + """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg + "ABC"; + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(label) AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + // analyze + val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aResult = sendJubaQL(aStmt) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { + case Some(pred) => + val scores = pred.predictions.map(res => (res.label, 
res.score)).toMap + // the order of entries differs per machine/OS, so we use this + // slightly complicated way of checking equality + scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") + Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 + scores("足利ABC") shouldBe 0.0 + scores("北条ABC") shouldBe 0.0 + case None => + fail("Failed to parse returned content as a classifier result") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "make callable a function which takes two string arguments" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION concat(arg1 string, arg2 string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg1 + arg2; + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat(label, "ABC") AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + // analyze + val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test 
USING classify""" + val aResult = sendJubaQL(aStmt) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { + case Some(pred) => + val scores = pred.predictions.map(res => (res.label, res.score)).toMap + // the order of entries differs per machine/OS, so we use this + // slightly complicated way of checking equality + scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") + Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 + scores("足利ABC") shouldBe 0.0 + scores("北条ABC") shouldBe 0.0 + case None => + fail("Failed to parse returned content as a classifier result") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + // TODO: generate tests which take many arguments + it should "make callable a function which takes three string arguments" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION concat3(arg1 string, arg2 string, arg3 string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg1 + arg2 + arg3; + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM 
SELECT concat3(label, "AB", "C") AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + // analyze + val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aResult = sendJubaQL(aStmt) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { + case Some(pred) => + val scores = pred.predictions.map(res => (res.label, res.score)).toMap + // the order of entries differs per machine/OS, so we use this + // slightly complicated way of checking equality + scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") + Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 + scores("足利ABC") shouldBe 0.0 + scores("北条ABC") shouldBe 0.0 + case None => + fail("Failed to parse returned content as a classifier result") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "make callable a function which takes four string arguments" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION concat4(arg1 string, arg2 string, arg3 string, arg4 string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg1 + arg2 + arg3 + arg4; + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label 
string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat4(label, "A", "B", "C") AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + // analyze + val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aResult = sendJubaQL(aStmt) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { + case Some(pred) => + val scores = pred.predictions.map(res => (res.label, res.score)).toMap + // the order of entries differs per machine/OS, so we use this + // slightly complicated way of checking equality + scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") + Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 + scores("足利ABC") shouldBe 0.0 + scores("北条ABC") shouldBe 0.0 + case None => + fail("Failed to parse returned content as a classifier result") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "make callable a function which takes five string 
arguments" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION concat5(arg1 string, arg2 string, arg3 string, arg4 string, arg5 string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg1 + arg2 + arg3 + arg4 + arg5; + |$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") + val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat5(label, "A", "B", "C", "D") AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + // analyze + val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aResult = sendJubaQL(aStmt) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { + case Some(pred) => + val scores = pred.predictions.map(res => (res.label, res.score)).toMap + // the order of entries differs per machine/OS, so 
we use this + // slightly complicated way of checking equality + scores.keys.toList should contain only("徳川ABCD", "足利ABCD", "北条ABCD") + Math.abs(scores("徳川ABCD") - 0.07692306488752365) should be < 0.00001 + scores("足利ABCD") shouldBe 0.0 + scores("北条ABCD") shouldBe 0.0 + case None => + fail("Failed to parse returned content as a classifier result") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "make callable a function which takes arguments of different type" taggedAs (LocalTest, JubatusTest) in { + implicit val formats = DefaultFormats + + val cfResult = sendJubaQL( + """CREATE FUNCTION multiply(n numeric, s string) RETURNS string LANGUAGE JavaScript AS $$ + |return Array(n + 1).join(s); |$$ """.stripMargin) cfResult shouldBe a[Success[_]] cfResult.get._1 shouldBe 200 cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cfResult2 = sendJubaQL( + """CREATE FUNCTION concat(arg1 string, arg2 string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg1 + arg2; + |$$ + """.stripMargin) + cfResult2 shouldBe a[Success[_]] + cfResult2.get._1 shouldBe 200 + cfResult2.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] @@ -2619,7 +3626,7 @@ class CreateFunctionSpec val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(label) AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat(multiply(3, label), "ABC") AS label, name FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") @@ -2648,10 +3655,10 @@ class CreateFunctionSpec val scores = pred.predictions.map(res => (res.label, res.score)).toMap // the 
order of entries differs per machine/OS, so we use this // slightly complicated way of checking equality - scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") - Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 - scores("足利ABC") shouldBe 0.0 - scores("北条ABC") shouldBe 0.0 + scores.keys.toList should contain only("徳川徳川徳川ABC", "足利足利足利ABC", "北条北条北条ABC") + Math.abs(scores("徳川徳川徳川ABC") - 0.07692306488752365) should be < 0.00001 + scores("足利足利足利ABC") shouldBe 0.0 + scores("北条北条北条ABC") shouldBe 0.0 case None => fail("Failed to parse returned content as a classifier result") } @@ -2660,12 +3667,16 @@ class CreateFunctionSpec exitValue shouldBe 0 } - it should "make callable a function which takes two string arguments" taggedAs (LocalTest, JubatusTest) in { + it should "make callable a throw error function which takes a string argument" taggedAs (LocalTest, JubatusTest) in { implicit val formats = DefaultFormats val cfResult = sendJubaQL( - """CREATE FUNCTION concat(arg1 string, arg2 string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg1 + arg2; + """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return arg + "ABC"; + |} else { + | throw new Error('error Message'); + |} |$$ """.stripMargin) cfResult shouldBe a[Success[_]] @@ -2681,7 +3692,7 @@ class CreateFunctionSpec val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat(label, "ABC") AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(label) AS label, name FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") @@ -2710,10 +3721,9 @@ class CreateFunctionSpec val scores = pred.predictions.map(res => (res.label, res.score)).toMap // the order of entries differs per machine/OS, so we use this // slightly complicated way of checking 
equality - scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") + scores.keys.size shouldBe 1 + scores.keys.toList should contain only("徳川ABC") Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 - scores("足利ABC") shouldBe 0.0 - scores("北条ABC") shouldBe 0.0 case None => fail("Failed to parse returned content as a classifier result") } @@ -2722,19 +3732,35 @@ class CreateFunctionSpec exitValue shouldBe 0 } - // TODO: generate tests which take many arguments - it should "make callable a function which takes three string arguments" taggedAs (LocalTest, JubatusTest) in { + it should "make callable a throw error function which takes two string argument returns numeric" taggedAs (LocalTest, JubatusTest) in { implicit val formats = DefaultFormats val cfResult = sendJubaQL( - """CREATE FUNCTION concat3(arg1 string, arg2 string, arg3 string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg1 + arg2 + arg3; + """CREATE FUNCTION addABC(arg string, arg2 numeric) RETURNS string LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return arg + "ABC : " + arg2; + |} else { + | throw new Error('error Message'); + |} |$$ """.stripMargin) cfResult shouldBe a[Success[_]] cfResult.get._1 shouldBe 200 cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cfResultNum = sendJubaQL( + """CREATE FUNCTION getLength(arg string) RETURNS numeric LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return parseFloat(arg.length); + |} else { + | throw new Error('error Message'); + |} + |$$ + """.stripMargin) + cfResultNum shouldBe a[Success[_]] + cfResultNum.get._1 shouldBe 200 + cfResultNum.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] @@ -2744,12 +3770,17 @@ class CreateFunctionSpec val cdResult = sendJubaQL(cdStmt) cdResult shouldBe 
a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat3(label, "AB", "C") AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT label, getLength(label) AS labelLength, name FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") - val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val csResult2 = sendJubaQL("""CREATE STREAM ds3 FROM SELECT addABC(label, labelLength) AS label, name FROM ds2""") + csResult2 shouldBe a[Success[_]] + csResult2.get._1 shouldBe 200 + csResult2.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds3""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -2773,10 +3804,9 @@ class CreateFunctionSpec val scores = pred.predictions.map(res => (res.label, res.score)).toMap // the order of entries differs per machine/OS, so we use this // slightly complicated way of checking equality - scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") - Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 - scores("足利ABC") shouldBe 0.0 - scores("北条ABC") shouldBe 0.0 + scores.keys.size shouldBe 1 + scores.keys.toList should contain only("徳川ABC : 2") + Math.abs(scores("徳川ABC : 2") - 0.07692306488752365) should be < 0.00001 case None => fail("Failed to parse returned content as a classifier result") } @@ -2785,18 +3815,35 @@ class CreateFunctionSpec exitValue shouldBe 0 } - it should "make callable a function which takes four string arguments" taggedAs (LocalTest, JubatusTest) in { + it should "make callable a throw error function which takes two string argument returns boolean" taggedAs (LocalTest, JubatusTest) in { implicit val formats = DefaultFormats val cfResult = sendJubaQL( - """CREATE FUNCTION concat4(arg1 string, arg2 string, arg3 string, arg4 string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg1 + 
arg2 + arg3 + arg4; + """CREATE FUNCTION addABC(arg string, arg2 boolean) RETURNS string LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return arg + "ABC : " + arg2; + |} else { + | throw new Error('error Message'); + |} |$$ """.stripMargin) cfResult shouldBe a[Success[_]] cfResult.get._1 shouldBe 200 cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cfResultNum = sendJubaQL( + """CREATE FUNCTION isTokugawa(arg string) RETURNS boolean LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return true; + |} else { + | throw new Error('error Message'); + |} + |$$ + """.stripMargin) + cfResultNum shouldBe a[Success[_]] + cfResultNum.get._1 shouldBe 200 + cfResultNum.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] @@ -2806,12 +3853,17 @@ class CreateFunctionSpec val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat4(label, "A", "B", "C") AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT label, isTokugawa(label) AS tokugawa, name FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") - val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val csResult2 = sendJubaQL("""CREATE STREAM ds3 FROM SELECT addABC(label, tokugawa) AS label, name FROM ds2""") + csResult2 shouldBe a[Success[_]] + csResult2.get._1 shouldBe 200 + csResult2.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING train FROM ds3""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -2835,10 +3887,9 @@ class CreateFunctionSpec val scores = pred.predictions.map(res => (res.label, res.score)).toMap // the order of entries 
differs per machine/OS, so we use this // slightly complicated way of checking equality - scores.keys.toList should contain only("徳川ABC", "足利ABC", "北条ABC") - Math.abs(scores("徳川ABC") - 0.07692306488752365) should be < 0.00001 - scores("足利ABC") shouldBe 0.0 - scores("北条ABC") shouldBe 0.0 + scores.keys.size shouldBe 1 + scores.keys.toList should contain only("徳川ABC : true") + Math.abs(scores("徳川ABC : true") - 0.07692306488752365) should be < 0.00001 case None => fail("Failed to parse returned content as a classifier result") } @@ -2847,33 +3898,31 @@ class CreateFunctionSpec exitValue shouldBe 0 } - it should "make callable a function which takes five string arguments" taggedAs (LocalTest, JubatusTest) in { - implicit val formats = DefaultFormats - + it should "work correctly with ANOMALY use function" taggedAs (LocalTest, JubatusTest) in { val cfResult = sendJubaQL( - """CREATE FUNCTION concat5(arg1 string, arg2 string, arg3 string, arg4 string, arg5 string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg1 + arg2 + arg3 + arg4 + arg5; + """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg + "ABC"; |$$ """.stripMargin) cfResult shouldBe a[Success[_]] cfResult.get._1 shouldBe 200 cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") - val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_1.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] - val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") - val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val config = Source.fromFile("src/test/resources/lof.json").getLines().mkString("") + val cdStmt = s"""CREATE ANOMALY MODEL test AS name WITH unigram CONFIG '$config'""" 
val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat5(label, "A", "B", "C", "D") AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT label, addABC(name) AS name FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") - val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umStmt = """UPDATE MODEL test USING add FROM ds2""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -2881,9 +3930,10 @@ class CreateFunctionSpec spResult shouldBe a[Success[_]] waitUntilDone("ds1", 6000) - // analyze - val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aStmt = """ANALYZE '{"label": "徳川","name": "家康ABC"}' BY MODEL test USING calc_score""" val aResult = sendJubaQL(aStmt) + val aStmt2 = """ANALYZE '{"label": "徳川","name": "ABC"}' BY MODEL test USING calc_score""" + val aResult2 = sendJubaQL(aStmt2) // shut down val sdResult = sendJubaQL("SHUTDOWN") sdResult shouldBe a[Success[_]] @@ -2892,59 +3942,194 @@ class CreateFunctionSpec if (aResult.get._1 != 200) println(stdout.toString) aResult.get._1 shouldBe 200 - (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { - case Some(pred) => - val scores = pred.predictions.map(res => (res.label, res.score)).toMap + (aResult.get._2 \ "result").extractOpt[AnomalyScore] match { + case Some(scoreRes) => + val scores = scoreRes.score // the order of entries differs per machine/OS, so we use this // slightly complicated way of checking equality - scores.keys.toList should contain only("徳川ABCD", "足利ABCD", "北条ABCD") - Math.abs(scores("徳川ABCD") - 0.07692306488752365) should be < 0.00001 - scores("足利ABCD") shouldBe 0.0 - scores("北条ABCD") shouldBe 0.0 + scores shouldBe 1.0F case None => - fail("Failed to parse returned content as a classifier result") + fail("Failed to parse returned content 
as a calc_score result") } + // now check the result + + // TODO 結果がInfinityのため、レスポンスはパースエラーとなる + aResult2 shouldBe a[Failure[_]] + // wait until shutdown val exitValue = process.exitValue() exitValue shouldBe 0 } - it should "make callable a function which takes arguments of different type" taggedAs (LocalTest, JubatusTest) in { - implicit val formats = DefaultFormats + it should "work correctly with ANOMALY use error function" taggedAs (LocalTest, JubatusTest) in { + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ + |if (arg == '徳川') { + | return arg + "ABC"; + |} else { + | throw new Error('error Message'); + |}$$ + """.stripMargin) + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + + val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_2.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/lof.json").getLines().mkString("") + val cdStmt = s"""CREATE ANOMALY MODEL test AS name WITH unigram CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(label) AS label, name FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING add FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + val aStmt = """ANALYZE '{"label": "徳川ABC","name": "家康"}' BY MODEL test USING calc_score""" + val aResult = sendJubaQL(aStmt) + + val aStmt2 = """ANALYZE '{"label": "足利","name": "義満"}' BY MODEL test USING calc_score""" + val aResult2 = 
sendJubaQL(aStmt2) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + (aResult.get._2 \ "result").extractOpt[AnomalyScore] match { + case Some(scoreRes) => + val scores = scoreRes.score + // the order of entries differs per machine/OS, so we use this + // slightly complicated way of checking equality + scores shouldBe 1.0F + case None => + fail("Failed to parse returned content as a calc_score result") + } + // now check the result + + // TODO 結果がInfinityのため、レスポンスはパースエラーとなる + aResult2 shouldBe a[Failure[_]] + + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + it should "recommender use function" taggedAs (LocalTest, JubatusTest) in { val cfResult = sendJubaQL( - """CREATE FUNCTION multiply(n numeric, s string) RETURNS string LANGUAGE JavaScript AS $$ - |return Array(n + 1).join(s); + """CREATE FUNCTION addABC(arg string) RETURNS string LANGUAGE JavaScript AS $$ + |return arg + "ABC"; |$$ """.stripMargin) cfResult shouldBe a[Success[_]] cfResult.get._1 shouldBe 200 cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") - val cfResult2 = sendJubaQL( - """CREATE FUNCTION concat(arg1 string, arg2 string) RETURNS string LANGUAGE JavaScript AS $$ - |return arg1 + arg2; + val cmStmt = """CREATE DATASOURCE ds1 (id string, team string, 打率 numeric, 試合数 numeric, 打席 numeric, 打数 numeric, 安打 numeric, 本塁打 numeric, 打点 numeric, 盗塁 numeric, 四球 numeric, 死球 numeric, 三振 numeric, 犠 打 numeric, 併殺打 numeric, 長打率 numeric, 出塁率 numeric, OPS numeric, RC27 numeric, XR27 numeric) FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" + val cmResult = sendJubaQL(cmStmt) + cmResult shouldBe a[Success[_]] + + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cdStmt = s"""CREATE RECOMMENDER 
MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" + val cdResult = sendJubaQL(cdStmt) + cdResult shouldBe a[Success[_]] + + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(id) AS id, team, 打率, 試合数, 打席, 打数, 安打, 本塁打, 打点, 盗塁, 四球, 死球, 三振, 犠打, 併殺打, 長打率, 出塁率, OPS, RC27, XR27 FROM ds1""") + csResult shouldBe a[Success[_]] + csResult.get._1 shouldBe 200 + csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") + + val umStmt = """UPDATE MODEL test USING update_row FROM ds2""" + val umResult = sendJubaQL(umStmt) + umResult shouldBe a[Success[_]] + + val spResult = sendJubaQL("START PROCESSING ds1") + spResult shouldBe a[Success[_]] + waitUntilDone("ds1", 6000) + + val aStmt = """ANALYZE '荻野貴司ABC' BY MODEL test USING complete_row_from_id""" + val aResult = sendJubaQL(aStmt) + val aStmt2 = """ANALYZE '荻野貴司' BY MODEL test USING complete_row_from_id""" + val aResult2 = sendJubaQL(aStmt2) + // shut down + val sdResult = sendJubaQL("SHUTDOWN") + sdResult shouldBe a[Success[_]] + + // now check the result + aResult shouldBe a[Success[_]] + if (aResult.get._1 != 200) + println(stdout.toString) + aResult.get._1 shouldBe 200 + aResult.get._2 \ "result" \ "num_values" match { + case JObject(list) => + val vals = list.collect({ + case (s, JDouble(j)) => (s, j) + }).toMap + vals.size shouldBe 18 + case _ => + fail("there was no 'num_values' key") + } + aResult2 shouldBe a[Success[_]] + if (aResult2.get._1 != 200) + println(stdout.toString) + aResult2.get._1 shouldBe 200 + aResult2.get._2 \ "result" \ "num_values" match { + case JObject(list) => + val vals = list.collect({ + case (s, JDouble(j)) => (s, j) + }).toMap + + vals.size shouldBe 0 + case _ => + fail("there was no 'num_values' key") + } + // wait until shutdown + val exitValue = process.exitValue() + exitValue shouldBe 0 + } + + it should "recommender use error function" taggedAs (LocalTest, JubatusTest) in { + val cfResult = sendJubaQL( + """CREATE FUNCTION addABC(name string, team 
string) RETURNS string LANGUAGE JavaScript AS $$ + |if (team == '巨人') { + | return name + "ABC"; + |} else { + | throw new Error("Error Message"); + |} |$$ """.stripMargin) - cfResult2 shouldBe a[Success[_]] - cfResult2.get._1 shouldBe 200 - cfResult2.get._2 \ "result" shouldBe JString("CREATE FUNCTION") + cfResult shouldBe a[Success[_]] + cfResult.get._1 shouldBe 200 + cfResult.get._2 \ "result" shouldBe JString("CREATE FUNCTION") - val cmStmt = """CREATE DATASOURCE ds1 (label string, name string) FROM (STORAGE: "file://src/test/resources/shogun_data.json")""" + val cmStmt = """CREATE DATASOURCE ds1 (id string, team string, 打率 numeric, 試合数 numeric, 打席 numeric, 打数 numeric, 安打 numeric, 本塁打 numeric, 打点 numeric, 盗塁 numeric, 四球 numeric, 死球 numeric, 三振 numeric, 犠 打 numeric, 併殺打 numeric, 長打率 numeric, 出塁率 numeric, OPS numeric, RC27 numeric, XR27 numeric) FROM (STORAGE: "file://src/test/resources/npb_similar_player_data.json")""" val cmResult = sendJubaQL(cmStmt) cmResult shouldBe a[Success[_]] - val config = Source.fromFile("src/test/resources/shogun.json").getLines().mkString("") - val cdStmt = s"""CREATE CLASSIFIER MODEL test (label: label) AS name WITH unigram CONFIG '$config'""" + val config = Source.fromFile("src/test/resources/npb_similar_player.json").getLines().mkString("") + val cdStmt = s"""CREATE RECOMMENDER MODEL test (id: id) AS team WITH unigram, * WITH id CONFIG '$config'""" val cdResult = sendJubaQL(cdStmt) cdResult shouldBe a[Success[_]] - val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT concat(multiply(3, label), "ABC") AS label, name FROM ds1""") + val csResult = sendJubaQL("""CREATE STREAM ds2 FROM SELECT addABC(id, team) AS id, team, 打率, 試合数, 打席, 打数, 安打, 本塁打, 打点, 盗塁, 四球, 死球, 三振, 犠打, 併殺打, 長打率, 出塁率, OPS, RC27, XR27 FROM ds1""") csResult shouldBe a[Success[_]] csResult.get._1 shouldBe 200 csResult.get._2 \ "result" shouldBe JString("CREATE STREAM") - val umStmt = """UPDATE MODEL test USING train FROM ds2""" + val umStmt = """UPDATE MODEL test 
USING update_row FROM ds2""" val umResult = sendJubaQL(umStmt) umResult shouldBe a[Success[_]] @@ -2952,28 +4137,41 @@ class CreateFunctionSpec spResult shouldBe a[Success[_]] waitUntilDone("ds1", 6000) - // analyze - val aStmt = """ANALYZE '{"name": "慶喜"}' BY MODEL test USING classify""" + val aStmt = """ANALYZE '阿部慎之助ABC' BY MODEL test USING complete_row_from_id""" val aResult = sendJubaQL(aStmt) + val aStmt2 = """ANALYZE '内川聖一ABC' BY MODEL test USING complete_row_from_id""" + val aResult2 = sendJubaQL(aStmt2) // shut down val sdResult = sendJubaQL("SHUTDOWN") sdResult shouldBe a[Success[_]] + // now check the result aResult shouldBe a[Success[_]] if (aResult.get._1 != 200) println(stdout.toString) aResult.get._1 shouldBe 200 - (aResult.get._2 \ "result").extractOpt[ClassifierResult] match { - case Some(pred) => - val scores = pred.predictions.map(res => (res.label, res.score)).toMap - // the order of entries differs per machine/OS, so we use this - // slightly complicated way of checking equality - scores.keys.toList should contain only("徳川徳川徳川ABC", "足利足利足利ABC", "北条北条北条ABC") - Math.abs(scores("徳川徳川徳川ABC") - 0.07692306488752365) should be < 0.00001 - scores("足利足利足利ABC") shouldBe 0.0 - scores("北条北条北条ABC") shouldBe 0.0 - case None => - fail("Failed to parse returned content as a classifier result") + aResult.get._2 \ "result" \ "num_values" match { + case JObject(list) => + val vals = list.collect({ + case (s, JDouble(j)) => (s, j) + }).toMap + vals.size shouldBe 18 + case _ => + fail("there was no 'num_values' key") + } + aResult2 shouldBe a[Success[_]] + if (aResult2.get._1 != 200) + println(stdout.toString) + aResult2.get._1 shouldBe 200 + aResult2.get._2 \ "result" \ "num_values" match { + case JObject(list) => + val vals = list.collect({ + case (s, JDouble(j)) => (s, j) + }).toMap + + vals.size shouldBe 0 + case _ => + fail("there was no 'num_values' key") } // wait until shutdown val exitValue = process.exitValue() @@ -3836,4 +5034,4 @@ class 
CreateFeatureFunctionSpec val exitValue = process.exitValue() exitValue shouldBe 0 } -} +} \ No newline at end of file