Skip to content

Commit

Permalink
RD-15131: CRUD support in SqlCompilerService (#537)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaidioz authored Nov 20, 2024
1 parent f222e30 commit 61a3b47
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ import scala.collection.mutable

class NamedParametersPreparedStatementException(val errors: List[ErrorMessage]) extends Exception

abstract class NamedParametersPreparedStatementExecutionResult
case class NamedParametersPreparedStatementResultSet(rs: ResultSet)
extends NamedParametersPreparedStatementExecutionResult
case class NamedParametersPreparedStatementUpdate(count: Int) extends NamedParametersPreparedStatementExecutionResult

// A postgres type, described by its JDBC enum type + the regular postgres type name.
// The postgres type is a string among the many types that exist in postgres.
case class PostgresType(jdbcType: Int, nullable: Boolean, typeName: String)
Expand Down Expand Up @@ -402,8 +407,14 @@ class NamedParametersPreparedStatement(
// The query output type is obtained using JDBC's `metadata`
private val queryOutputType: Either[String, PostgresRowType] = {
val metadata = stmt.getMetaData // SQLException at that point would be a bug.
if (metadata == null) Left("non-executable code")
else {
if (metadata == null) {
// an UPDATE/INSERT. We'll return a single row with a count column
Right(
PostgresRowType(
Seq(PostgresColumn("update_count", PostgresType(java.sql.Types.INTEGER, nullable = false, "integer")))
)
)
} else {
val columns = (1 to metadata.getColumnCount).map { i =>
val name = metadata.getColumnName(i)
val tipe = metadata.getColumnType(i)
Expand Down Expand Up @@ -452,7 +463,9 @@ class NamedParametersPreparedStatement(
def errorPosition(p: Position): ErrorPosition = ErrorPosition(p.line, p.column)
ErrorRange(errorPosition(position), errorPosition(position1))
}
def executeWith(parameters: Seq[(String, RawValue)]): Either[String, ResultSet] = {
def executeWith(
parameters: Seq[(String, RawValue)]
): Either[String, NamedParametersPreparedStatementExecutionResult] = {
val mandatoryParameters = {
for (
(name, diagnostic) <- declaredTypeInfo
Expand All @@ -468,7 +481,14 @@ class NamedParametersPreparedStatement(
if (mandatoryParameters.nonEmpty) Left(s"no value was specified for ${mandatoryParameters.mkString(", ")}")
else
try {
Right(stmt.executeQuery())
val isResultSet = stmt.execute()
if (isResultSet) Right(NamedParametersPreparedStatementResultSet(stmt.getResultSet))
else {
// successful execution of an UPDATE/INSERT (empty queries also get there)
Right(
NamedParametersPreparedStatementUpdate(stmt.getUpdateCount)
)
}
} catch {
// We'd catch here user-visible PSQL runtime errors (e.g. schema not found, table not found,
// wrong credentials) hit _at runtime_ because the user FDW schema.table maps to a datasource
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.rawlabs.compiler._
import com.rawlabs.sql.compiler.antlr4.{ParseProgramResult, SqlIdnNode, SqlParamUseNode, SqlSyntaxAnalyzer}
import com.rawlabs.sql.compiler.metadata.UserMetadataCache
import com.rawlabs.sql.compiler.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import com.rawlabs.sql.compiler.writers.{
StatusCsvWriter,
StatusJsonWriter,
TypedResultSetCsvWriter,
TypedResultSetJsonWriter
}
import com.rawlabs.utils.core.{RawSettings, RawUtils}
import org.bitbucket.inkytonik.kiama.util.Positions

Expand Down Expand Up @@ -185,7 +190,13 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
case Right(tipe) =>
val arguments = environment.maybeArguments.getOrElse(Array.empty)
pstmt.executeWith(arguments) match {
case Right(r) => render(environment, tipe, r, outputStream, maxRows)
case Right(result) => result match {
case NamedParametersPreparedStatementResultSet(rs) =>
resultSetRendering(environment, tipe, rs, outputStream, maxRows)
case NamedParametersPreparedStatementUpdate(count) =>
// No ResultSet, it was an update. Return a status in the expected format.
updateResultRendering(environment, outputStream, count, maxRows)
}
case Left(error) => ExecutionRuntimeFailure(error)
}
case Left(errors) => ExecutionRuntimeFailure(errors.mkString(", "))
Expand All @@ -209,7 +220,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
}
}

private def render(
private def resultSetRendering(
environment: ProgramEnvironment,
tipe: RawType,
v: ResultSet,
Expand Down Expand Up @@ -255,6 +266,44 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends

}

private def updateResultRendering(
environment: ProgramEnvironment,
stream: OutputStream,
count: Int,
maybeLong: Option[Long]
) = {
environment.options
.get("output-format")
.map(_.toLowerCase) match {
case Some("csv") =>
val windowsLineEnding = environment.options.get("windows-line-ending") match {
case Some("true") => true
case _ => false //settings.config.getBoolean("raw.compiler.windows-line-ending")
}
val lineSeparator = if (windowsLineEnding) "\r\n" else "\n"
val writer = new StatusCsvWriter(stream, lineSeparator)
try {
writer.write(count)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
RawUtils.withSuppressNonFatalException(writer.close())
}
case Some("json") =>
val w = new StatusJsonWriter(stream)
try {
w.write(count)
ExecutionSuccess(true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
RawUtils.withSuppressNonFatalException(w.close())
}
case _ => ExecutionRuntimeFailure("unknown output format")
}
ExecutionSuccess(true)
}

override def formatCode(
source: String,
environment: ProgramEnvironment,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2023 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package com.rawlabs.sql.compiler.writers

import com.fasterxml.jackson.core.{JsonEncoding, JsonParser}
import com.fasterxml.jackson.dataformat.csv.CsvGenerator.Feature.STRICT_CHECK_FOR_QUOTING
import com.fasterxml.jackson.dataformat.csv.{CsvFactory, CsvSchema}

import java.io.{Closeable, IOException, OutputStream}

class StatusCsvWriter(os: OutputStream, lineSeparator: String) extends Closeable {

final private val gen =
try {
val factory = new CsvFactory
factory.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE) // Don't close file descriptors automatically
factory.createGenerator(os, JsonEncoding.UTF8)
} catch {
case e: IOException => throw new RuntimeException(e)
}

private val schemaBuilder = CsvSchema.builder()
schemaBuilder.setColumnSeparator(',')
schemaBuilder.setUseHeader(true)
schemaBuilder.setLineSeparator(lineSeparator)
schemaBuilder.setQuoteChar('"')
schemaBuilder.setNullValue("")
schemaBuilder.addColumn("update_count")
gen.setSchema(schemaBuilder.build)
gen.enable(STRICT_CHECK_FOR_QUOTING)

@throws[IOException]
def write(count: Int): Unit = {
gen.writeStartObject()
gen.writeFieldName("update_count")
gen.writeNumber(count)
gen.writeEndObject()
}

def flush(): Unit = {
gen.flush()
}

override def close(): Unit = {
gen.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package com.rawlabs.sql.compiler.writers

import com.fasterxml.jackson.core.{JsonEncoding, JsonFactory, JsonParser}

import java.io.{IOException, OutputStream}

class StatusJsonWriter(os: OutputStream) {

final private val gen =
try {
val factory = new JsonFactory
factory.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE) // Don't close file descriptors automatically
factory.createGenerator(os, JsonEncoding.UTF8)
} catch {
case e: IOException => throw new RuntimeException(e)
}

@throws[IOException]
def write(count: Int): Unit = {
gen.writeStartArray()
gen.writeStartObject()
gen.writeFieldName("update_count")
gen.writeNumber(count)
gen.writeEndObject()
gen.writeEndArray()
}

def close(): Unit = {
gen.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ class TestNamedParametersStatement
var rs: ResultSet = null
try {
statement = mkPreparedStatement(conn, code)
rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get
rs = statement
.executeWith(Seq("v1" -> RawString("Hello!")))
.right
.get
.asInstanceOf[NamedParametersPreparedStatementResultSet]
.rs
rs.next()
assert(rs.getString("arg") == "Hello!")
} finally {
Expand All @@ -77,7 +82,12 @@ class TestNamedParametersStatement
var rs: ResultSet = null
try {
statement = mkPreparedStatement(conn, code)
rs = statement.executeWith(Seq("v" -> RawString("Hello!"))).right.get
rs = statement
.executeWith(Seq("v" -> RawString("Hello!")))
.right
.get
.asInstanceOf[NamedParametersPreparedStatementResultSet]
.rs

rs.next()
assert(rs.getString("greeting") == "Hello!")
Expand All @@ -98,7 +108,12 @@ class TestNamedParametersStatement
val metadata = statement.queryMetadata.right.get
assert(metadata.parameters.keys == Set("v1", "v2"))

rs = statement.executeWith(Seq("v1" -> RawString("Lisbon"), "v2" -> RawInt(1))).right.get
rs = statement
.executeWith(Seq("v1" -> RawString("Lisbon"), "v2" -> RawInt(1)))
.right
.get
.asInstanceOf[NamedParametersPreparedStatementResultSet]
.rs
rs.next()
assert(rs.getString(1) == "Lisbon")
assert(rs.getInt(2) == 1)
Expand All @@ -120,7 +135,12 @@ class TestNamedParametersStatement
var rs: ResultSet = null
try {
statement = mkPreparedStatement(conn, code)
rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get
rs = statement
.executeWith(Seq("v1" -> RawString("Hello!")))
.right
.get
.asInstanceOf[NamedParametersPreparedStatementResultSet]
.rs
rs.next()
assert(rs.getString("arg") == "Hello!")
} finally {
Expand All @@ -139,7 +159,12 @@ class TestNamedParametersStatement
statement = mkPreparedStatement(conn, code)
val metadata = statement.queryMetadata.right.get
assert(metadata.parameters.keys == Set("bar"))
rs = statement.executeWith(Seq("bar" -> RawString("Hello!"))).right.get
rs = statement
.executeWith(Seq("bar" -> RawString("Hello!")))
.right
.get
.asInstanceOf[NamedParametersPreparedStatementResultSet]
.rs

rs.next()
assert(rs.getString("v1") == ":foo")
Expand All @@ -161,7 +186,7 @@ class TestNamedParametersStatement
val metadata = statement.queryMetadata
assert(metadata.isRight)
assert(metadata.right.get.parameters.isEmpty)
rs = statement.executeWith(Seq.empty).right.get
rs = statement.executeWith(Seq.empty).right.get.asInstanceOf[NamedParametersPreparedStatementResultSet].rs

rs.next()
assert(rs.getString("arg") == """[1, 2, "3", {"a": "Hello"}]""")
Expand Down
Loading

0 comments on commit 61a3b47

Please sign in to comment.