Skip to content

Commit

Permalink
Fix RD-14971: Make the inferrer thread pool dynamic + fix inferrer cache reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelbranco80 authored Oct 31, 2024
1 parent 153f13f commit 4863b20
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 115 deletions.
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object Dependencies {

val utilsCore = "com.raw-labs" %% "utils-core" % "0.50.0"

val utilsSources = "com.raw-labs" %% "utils-sources" % "0.50.0"
val utilsSources = "com.raw-labs" %% "utils-sources" % "0.51.0"

val protocolRaw = "com.raw-labs" %% "protocol-raw" % "0.50.0"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect

private val (engine, initedEngine) = engineDefinition

private val inferrer = InferrerServiceProvider()

// The default constructor allows an Engine to be specified, plus a flag to indicate whether it was created here
// or externally. That's necessary for the test framework.
// This is actually the "default constructor" which obtains a new engine or reuses an existing one.
Expand All @@ -125,18 +127,10 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect

private def getCompilerContext(user: RawUid): CompilerContext = {
compilerContextCachesLock.synchronized {
compilerContextCaches.getOrElseUpdate(user, createCompilerContext(user, "snapi"))
compilerContextCaches.getOrElseUpdate(user, new CompilerContext(user, inferrer))
}
}

private def createCompilerContext(user: RawUid, language: String): CompilerContext = {
// Initialize inferrer
val inferrer = InferrerServiceProvider()

// Initialize compiler context
new CompilerContext(language, user, inferrer)
}

private def getProgramContext(user: RawUid, environment: ProgramEnvironment): ProgramContext = {
val compilerContext = getCompilerContext(user)
new ProgramContext(environment, compilerContext)
Expand Down Expand Up @@ -646,7 +640,11 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect
}

override def doStop(): Unit = {
compilerContextCaches.values.foreach(compilerContext => compilerContext.inferrer.stop())
compilerContextCachesLock.synchronized {
compilerContextCaches.values.foreach(_.stop())
compilerContextCaches.clear()
}
inferrer.stop()
if (initedEngine) {
CompilerService.releaseEngine
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class SnowflakePackageTest extends SnapiTestContext {
|)""".stripMargin
) { it =>
it should runErrorAs(
"""IO error connecting to does-not-exist: JDBC driver encountered communication error. Message: HTTP status=513.""".stripMargin
"""IO error connecting to does-not-exist: JDBC driver encountered communication error. Message: HTTP status=404.""".stripMargin
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ package com.rawlabs.snapi.compiler.tests.regressions
import com.rawlabs.snapi.compiler.tests.SnapiTestContext

class RD5775Test extends SnapiTestContext {
test("""let data = Http.Get("https://jira.atlassian.com/rest/api/latest/search", args=[{"jql", "fixVersion=9.0.0"}]),
| r = Json.InferAndRead(data),
| issues = Collection.Transform(r.issues, i -> {
| i.key, i.fields.summary, i.fields.status.name, i.fields.created, i.fields.resolutiondate
| })
|in Collection.OrderBy(issues, i -> i.resolutiondate, "DESC")
|""".stripMargin)(_ should run)

test("Collection")(_ should (tipe and runErrorAs("unsupported type")))
test("Collection.GroupBy")(_ should (tipe and runErrorAs("unsupported type")))
test("""Http.Get("http://www.raw-labs.com/not-found")""")(_ should (tipe and runErrorAs("unsupported type")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import com.rawlabs.snapi.compiler.tests.SnapiTestContext

class RD9485Test extends SnapiTestContext {

property("raw.snapi.frontend.inferrer.local.xml.sample-size", "2")

val xmlList = tempFile("""<?xml version="1.0" encoding="UTF-8"?>
|<top>
| <a>1</a>
Expand All @@ -26,7 +24,7 @@ class RD9485Test extends SnapiTestContext {
| <a/>
|</top>""".stripMargin)

test(s"""Xml.InferAndRead("file://$xmlList")""") { it =>
test(s"""Xml.InferAndRead("file://$xmlList", sampleSize = 2)""") { it =>
it should evaluateTo(s"""Xml.Read("file://$xmlList", type record(a: collection(double)))""")
it should evaluateTo("""{
| a: [
Expand All @@ -39,7 +37,7 @@ class RD9485Test extends SnapiTestContext {
|""".stripMargin)
}

test(s"""Xml.InferAndRead("file://$xmlList", preferNulls=false)""") { it =>
test(s"""Xml.InferAndRead("file://$xmlList", sampleSize = 2, preferNulls=false)""") { it =>
it should evaluateTo("""{
| a: [
| 1.0,
Expand Down Expand Up @@ -68,7 +66,7 @@ class RD9485Test extends SnapiTestContext {
| <a arg1="where is the other one?"></a>
|</top>""".stripMargin)

test(s"""Xml.InferAndRead("file://$xmlRecords")""") { it =>
test(s"""Xml.InferAndRead("file://$xmlRecords", sampleSize = 2)""") { it =>
it should evaluateTo("""{
| a: [
| {
Expand Down Expand Up @@ -100,7 +98,7 @@ class RD9485Test extends SnapiTestContext {
|""".stripMargin)
}

test(s"""Xml.InferAndRead("file://$xmlRecords", preferNulls=false)""") { it =>
test(s"""Xml.InferAndRead("file://$xmlRecords", sampleSize = 2, preferNulls=false)""") { it =>
it should evaluateTo("""{
| a: [
| {
Expand Down
2 changes: 0 additions & 2 deletions snapi-frontend/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ raw.snapi.frontend {
check-syntax-analyzers = false

inferrer {
# Number of inferrer threads.
thread-pool-size = 8
# Number of entries in the inferrer cache.
cache-size = 100
# Timeout for an inference request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@

package com.rawlabs.snapi.frontend.base

import com.rawlabs.utils.core.{RawSettings, RawUid}
import com.rawlabs.utils.core.{RawService, RawSettings, RawUid}
import com.typesafe.scalalogging.StrictLogging
import com.rawlabs.snapi.frontend.inferrer.api.{InferrerInput, InferrerOutput, InferrerService}
import com.rawlabs.utils.core._

/**
* Contains state that is shared between different programs.
*/
class CompilerContext(
val language: String,
val user: RawUid,
val inferrer: InferrerService
)(
class CompilerContext(val user: RawUid, val inferrer: InferrerService)(
implicit val settings: RawSettings
) extends StrictLogging {
) extends RawService
with StrictLogging {

def infer(properties: InferrerInput): Either[String, InferrerOutput] = {
inferrer.inferWithExpiry(properties)
def infer(inferrerInput: InferrerInput): Either[String, InferrerOutput] = {
inferrer.inferWithCache(inferrerInput)
}

override protected def doStop(): Unit = {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ object InferrerService {
private val INFERRER_EXPIRY = "raw.snapi.frontend.inferrer.expiry"

private val INFERRER_CACHE_SIZE = "raw.snapi.frontend.inferrer.cache-size"

private val INFERRER_THREAD_POOL_SIZE = "raw.snapi.frontend.inferrer.thread-pool-size"
}

abstract class InferrerService(implicit settings: RawSettings) extends RawService {
Expand All @@ -36,32 +34,17 @@ abstract class InferrerService(implicit settings: RawSettings) extends RawServic

private val inferrerCacheSize = settings.getInt(INFERRER_CACHE_SIZE)

private val inferrerThreadPoolSize = settings.getInt(INFERRER_THREAD_POOL_SIZE)
private val inferrerThreadPool =
Executors.newFixedThreadPool(inferrerThreadPoolSize, RawUtils.newThreadFactory("inferrer-service"))

// The main entrypoint for the inferrer.
// Using an exception for inference is reasonable because we often want inference to exit early.
@throws[RawException]
def infer(properties: InferrerInput): InferrerOutput

// Inferrer that uses internal cache and expiry.
// Instead of an exception, it returns an Either (since the timeout error is returned as a Left).
// @param timeout How long the inference can take before aborting with an exception.
// @param expiry How long an old result of the inference is still accepted as valid.
final def inferWithExpiry(properties: InferrerInput): Either[String, InferrerOutput] = {
inferCache.get(properties)
}
private val inferrerThreadPool = Executors.newCachedThreadPool(RawUtils.newThreadFactory("inferrer-thread"))

private val inferCache: LoadingCache[InferrerInput, Either[String, InferrerOutput]] = CacheBuilder
.newBuilder()
.maximumSize(inferrerCacheSize)
.expireAfterAccess(inferrerExpirySeconds, TimeUnit.SECONDS)
.build(new CacheLoader[InferrerInput, Either[String, InferrerOutput]] {
def load(properties: InferrerInput): Either[String, InferrerOutput] = {
def load(inferrerInput: InferrerInput): Either[String, InferrerOutput] = {
val inferrerFuture = inferrerThreadPool.submit(() => {
try {
Right(infer(properties))
Right(infer(inferrerInput))
} catch {
case ex: InferrerException => Left(ex.getMessage)
}
Expand All @@ -84,6 +67,22 @@ abstract class InferrerService(implicit settings: RawSettings) extends RawServic
}
})

/**
 * Cached entry point for inference: returns the value held by `inferCache` for `inferrerInput`.
 *
 * On a cache miss the cache's loader (defined with `inferCache`) submits a call to `infer` on
 * `inferrerThreadPool`; an `InferrerException` raised by `infer` is turned into a `Left` carrying
 * the exception message, a successful result into a `Right`.
 * The cache is bounded by `inferrerCacheSize` entries and entries expire `inferrerExpirySeconds`
 * seconds after their last access.
 *
 * NOTE(review): because the loader's result type is the `Either` itself, a `Left` returned by the
 * loader appears to be cached like any other value -- confirm this is intended. How timeouts are
 * reported is decided in a part of the loader not shown here.
 *
 * @param inferrerInput The inference request; also used as the cache key.
 * @return `Right` with the inferred output, or `Left` with an error message.
 */
def inferWithCache(inferrerInput: InferrerInput): Either[String, InferrerOutput] = {
inferCache.get(inferrerInput)
}

/**
* Infers the schema of a data source.
* It throws an exception if the inference fails.
* (We prefer in this case to use exceptions instead of Option or Try because we often want to exit early.)
*
* @param properties The properties of the data source.
* @throws RawException if the inference fails.
* @return The inferred schema.
*/
@throws[RawException]
def infer(properties: InferrerInput): InferrerOutput

final override def doStop(): Unit = {
RawUtils.withSuppressNonFatalException {
inferrerThreadPool.shutdownNow()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,8 @@ class ProgramContext(

private val stageCompilerCache = new mutable.HashMap[SnapiProgram, Either[ErrorCompilerMessage, SnapiValue]]

def infer(
inferrerProperties: InferrerInput
): Either[String, InferrerOutput] = {
inferCache.getOrElseUpdate(
inferrerProperties,
compilerContext.infer(inferrerProperties)
)
/**
 * Runs (or reuses) an inference for this program.
 *
 * The result is first looked up in this program's own `inferCache`; when absent, the request is
 * delegated to `compilerContext.infer` and the outcome is stored under `inferrerProperties`.
 * NOTE(review): `inferCache` is declared outside this view -- presumably a mutable map, in which
 * case `compilerContext.infer` is only evaluated on a miss; confirm.
 *
 * @param inferrerProperties The inference request; also used as the cache key.
 * @return `Right` with the inferred output, or `Left` with an error message.
 */
def infer(inferrerProperties: InferrerInput): Either[String, InferrerOutput] = {
inferCache.getOrElseUpdate(inferrerProperties, compilerContext.infer(inferrerProperties))
}

def getPackage(name: String): Option[PackageExtension] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ class RD3852Test extends RawTestSuite with SettingsTestContext with StrictLoggin
val p = RawUtils.getResource("data/students/students.csv")
val l1 = new LocalPath(p)
try {

val TextInputStreamInferrerOutput(_, _, format) = inferrer.infer(AutoInferrerInput(l1, None))
assert(format.isInstanceOf[CsvFormatDescriptor])

} finally {
RawUtils.withSuppressNonFatalException(inferrer.stop())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,31 @@
import com.rawlabs.snapi.frontend.inferrer.api.InferrerServiceProvider;
import com.rawlabs.utils.core.RawSettings;
import com.rawlabs.utils.core.RawUid;
import com.rawlabs.utils.core.RawUtils;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import scala.runtime.BoxedUnit;

public class SnapiLanguageCache {

private final Object activeContextsLock = new Object();
private final Set<SnapiContext> activeContexts = new HashSet<SnapiContext>();

private final ConcurrentHashMap<RawUid, Value> map = new ConcurrentHashMap<>();
private final ConcurrentHashMap<RawUid, InferrerService> inferrerServiceCache =
new ConcurrentHashMap<>();

private static class Value {
private final CompilerContext compilerContext;
private final InferrerService inferrer;

Value(CompilerContext compilerContext, InferrerService inferrer) {
this.compilerContext = compilerContext;
this.inferrer = inferrer;
}

public CompilerContext getCompilerContext() {
return compilerContext;
}

public InferrerService getInferrer() {
return inferrer;
}
}
private final ConcurrentHashMap<RawUid, CompilerContext> compilerContextsCache =
new ConcurrentHashMap<>();

@CompilerDirectives.TruffleBoundary
private Value get(RawUid user, RawSettings rawSettings) {
return map.computeIfAbsent(
user,
k -> {
InferrerService inferrer = InferrerServiceProvider.apply(rawSettings);
CompilerContext compilerContext =
new CompilerContext("snapi", user, inferrer, rawSettings);
return new Value(compilerContext, inferrer);
});
/**
 * Returns the {@code InferrerService} cached for {@code user}, creating it with
 * {@code InferrerServiceProvider.apply(rawSettings)} the first time that user is seen.
 *
 * <p>Only the settings passed on the first call for a given user are used to build the service;
 * later calls return the cached instance regardless of {@code rawSettings}.
 *
 * <p>NOTE(review): the visible part of {@code releaseContext} stops and clears
 * {@code compilerContextsCache} only. Nothing shown here stops the services held in
 * {@code inferrerServiceCache} or removes them from the map -- confirm they are shut down
 * somewhere, otherwise their thread pools outlive the last released context.
 *
 * @param user the user the inferrer is cached under
 * @param rawSettings settings used to create the inferrer on a cache miss
 * @return the cached (or newly created) inferrer service for {@code user}
 */
public InferrerService getInferrer(RawUid user, RawSettings rawSettings) {
return inferrerServiceCache.computeIfAbsent(
user, k -> InferrerServiceProvider.apply(rawSettings));
}

@CompilerDirectives.TruffleBoundary
public CompilerContext getCompilerContext(RawUid user, RawSettings rawSettings) {
return get(user, rawSettings).getCompilerContext();
}

public InferrerService getInferrer(RawUid user, RawSettings rawSettings) {
return get(user, rawSettings).getInferrer();
return compilerContextsCache.computeIfAbsent(
user, k -> new CompilerContext(user, getInferrer(user, rawSettings), rawSettings));
}

@CompilerDirectives.TruffleBoundary
Expand All @@ -81,18 +57,9 @@ public void releaseContext(SnapiContext context) {
synchronized (activeContextsLock) {
activeContexts.remove(context);
if (activeContexts.isEmpty()) {
// Close all inferrer services.
map.values()
.forEach(
v -> {
RawUtils.withSuppressNonFatalException(
() -> {
v.getInferrer().stop();
return BoxedUnit.UNIT;
},
true);
});
map.clear();
// Close all compiler contexts.
compilerContextsCache.values().forEach(CompilerContext::stop);
compilerContextsCache.clear();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ protected Object doDescribe(
return record;
} catch (RawException ex) {
return new ErrorObject(ex.getMessage());
} finally {
inferrer.stop();
}
}
}

0 comments on commit 4863b20

Please sign in to comment.