diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala index f618d4275..4610d03ec 100644 --- a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala @@ -51,6 +51,11 @@ import scala.collection.mutable class NamedParametersPreparedStatementException(val errors: List[ErrorMessage]) extends Exception +abstract class NamedParametersPreparedStatementExecutionResult +case class NamedParametersPreparedStatementResultSet(rs: ResultSet) + extends NamedParametersPreparedStatementExecutionResult +case class NamedParametersPreparedStatementUpdate(count: Int) extends NamedParametersPreparedStatementExecutionResult + // A postgres type, described by its JDBC enum type + the regular postgres type name. // The postgres type is a string among the many types that exist in postgres. case class PostgresType(jdbcType: Int, nullable: Boolean, typeName: String) @@ -402,8 +407,14 @@ class NamedParametersPreparedStatement( // The query output type is obtained using JDBC's `metadata` private val queryOutputType: Either[String, PostgresRowType] = { val metadata = stmt.getMetaData // SQLException at that point would be a bug. - if (metadata == null) Left("non-executable code") - else { + if (metadata == null) { + // an UPDATE/INSERT. We'll return a single row with a count column + Right( + PostgresRowType( + Seq(PostgresColumn("update_count", PostgresType(java.sql.Types.INTEGER, nullable = false, "integer"))) + ) + ) + } else { val columns = (1 to metadata.getColumnCount).map { i => val name = metadata.getColumnName(i) val tipe = metadata.getColumnType(i) @@ -452,7 +463,9 @@ class NamedParametersPreparedStatement( def errorPosition(p: Position): ErrorPosition = ErrorPosition(p.line, p.column) ErrorRange(errorPosition(position), errorPosition(position1)) } - def executeWith(parameters: Seq[(String, RawValue)]): Either[String, ResultSet] = { + def executeWith( + parameters: Seq[(String, RawValue)] + ): Either[String, NamedParametersPreparedStatementExecutionResult] = { val mandatoryParameters = { for ( (name, diagnostic) <- declaredTypeInfo @@ -468,7 +481,14 @@ class NamedParametersPreparedStatement( if (mandatoryParameters.nonEmpty) Left(s"no value was specified for ${mandatoryParameters.mkString(", ")}") else try { - Right(stmt.executeQuery()) + val isResultSet = stmt.execute() + if (isResultSet) Right(NamedParametersPreparedStatementResultSet(stmt.getResultSet)) + else { + // successful execution of an UPDATE/INSERT (empty queries also get there) + Right( + NamedParametersPreparedStatementUpdate(stmt.getUpdateCount) + ) + } } catch { // We'd catch here user-visible PSQL runtime errors (e.g. schema not found, table not found, // wrong credentials) hit _at runtime_ because the user FDW schema.table maps to a datasource diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala index a31de3de4..bed44dbdb 100644 --- a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala @@ -16,7 +16,12 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import com.rawlabs.compiler._ import com.rawlabs.sql.compiler.antlr4.{ParseProgramResult, SqlIdnNode, SqlParamUseNode, SqlSyntaxAnalyzer} import com.rawlabs.sql.compiler.metadata.UserMetadataCache -import com.rawlabs.sql.compiler.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter} +import com.rawlabs.sql.compiler.writers.{ + StatusCsvWriter, + StatusJsonWriter, + TypedResultSetCsvWriter, + TypedResultSetJsonWriter +} import com.rawlabs.utils.core.{RawSettings, RawUtils} import org.bitbucket.inkytonik.kiama.util.Positions @@ -185,7 +190,13 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends case Right(tipe) => val arguments = environment.maybeArguments.getOrElse(Array.empty) pstmt.executeWith(arguments) match { - case Right(r) => render(environment, tipe, r, outputStream, maxRows) + case Right(result) => result match { + case NamedParametersPreparedStatementResultSet(rs) => + resultSetRendering(environment, tipe, rs, outputStream, maxRows) + case NamedParametersPreparedStatementUpdate(count) => + // No ResultSet, it was an update. Return a status in the expected format. + updateResultRendering(environment, outputStream, count, maxRows) + } case Left(error) => ExecutionRuntimeFailure(error) } case Left(errors) => ExecutionRuntimeFailure(errors.mkString(", ")) @@ -209,7 +220,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends } } - private def render( + private def resultSetRendering( environment: ProgramEnvironment, tipe: RawType, v: ResultSet, @@ -255,6 +266,44 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends } + private def updateResultRendering( + environment: ProgramEnvironment, + stream: OutputStream, + count: Int, + maybeLong: Option[Long] + ) = { + environment.options + .get("output-format") + .map(_.toLowerCase) match { + case Some("csv") => + val windowsLineEnding = environment.options.get("windows-line-ending") match { + case Some("true") => true + case _ => false //settings.config.getBoolean("raw.compiler.windows-line-ending") + } + val lineSeparator = if (windowsLineEnding) "\r\n" else "\n" + val writer = new StatusCsvWriter(stream, lineSeparator) + try { + writer.write(count) + } catch { + case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + } finally { + RawUtils.withSuppressNonFatalException(writer.close()) + } + case Some("json") => + val w = new StatusJsonWriter(stream) + try { + w.write(count) + ExecutionSuccess(true) + } catch { + case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + } finally { + RawUtils.withSuppressNonFatalException(w.close()) + } + case _ => ExecutionRuntimeFailure("unknown output format") + } + ExecutionSuccess(true) + } + override def formatCode( source: String, environment: ProgramEnvironment, diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusCsvWriter.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusCsvWriter.scala new file mode 100644 index 000000000..56dc5590e --- /dev/null +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusCsvWriter.scala @@ -0,0 +1,57 @@ +/* + * Copyright 2023 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.sql.compiler.writers + +import com.fasterxml.jackson.core.{JsonEncoding, JsonParser} +import com.fasterxml.jackson.dataformat.csv.CsvGenerator.Feature.STRICT_CHECK_FOR_QUOTING +import com.fasterxml.jackson.dataformat.csv.{CsvFactory, CsvSchema} + +import java.io.{Closeable, IOException, OutputStream} + +class StatusCsvWriter(os: OutputStream, lineSeparator: String) extends Closeable { + + final private val gen = + try { + val factory = new CsvFactory + factory.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE) // Don't close file descriptors automatically + factory.createGenerator(os, JsonEncoding.UTF8) + } catch { + case e: IOException => throw new RuntimeException(e) + } + + private val schemaBuilder = CsvSchema.builder() + schemaBuilder.setColumnSeparator(',') + schemaBuilder.setUseHeader(true) + schemaBuilder.setLineSeparator(lineSeparator) + schemaBuilder.setQuoteChar('"') + schemaBuilder.setNullValue("") + schemaBuilder.addColumn("update_count") + gen.setSchema(schemaBuilder.build) + gen.enable(STRICT_CHECK_FOR_QUOTING) + + @throws[IOException] + def write(count: Int): Unit = { + gen.writeStartObject() + gen.writeFieldName("update_count") + gen.writeNumber(count) + gen.writeEndObject() + } + + def flush(): Unit = { + gen.flush() + } + + override def close(): Unit = { + gen.close() + } +} diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusJsonWriter.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusJsonWriter.scala new file mode 100644 index 000000000..4c34641f5 --- /dev/null +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/StatusJsonWriter.scala @@ -0,0 +1,43 @@ +/* + * Copyright 2024 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.sql.compiler.writers + +import com.fasterxml.jackson.core.{JsonEncoding, JsonFactory, JsonParser} + +import java.io.{IOException, OutputStream} + +class StatusJsonWriter(os: OutputStream) { + + final private val gen = + try { + val factory = new JsonFactory + factory.disable(JsonParser.Feature.AUTO_CLOSE_SOURCE) // Don't close file descriptors automatically + factory.createGenerator(os, JsonEncoding.UTF8) + } catch { + case e: IOException => throw new RuntimeException(e) + } + + @throws[IOException] + def write(count: Int): Unit = { + gen.writeStartArray() + gen.writeStartObject() + gen.writeFieldName("update_count") + gen.writeNumber(count) + gen.writeEndObject() + gen.writeEndArray() + } + + def close(): Unit = { + gen.close() + } +} diff --git a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestNamedParametersStatement.scala b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestNamedParametersStatement.scala index 430600fd2..d416ad22e 100644 --- a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestNamedParametersStatement.scala +++ b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestNamedParametersStatement.scala @@ -60,7 +60,12 @@ class TestNamedParametersStatement var rs: ResultSet = null try { statement = mkPreparedStatement(conn, code) - rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get + rs = statement + .executeWith(Seq("v1" -> RawString("Hello!"))) + .right + .get + .asInstanceOf[NamedParametersPreparedStatementResultSet] + .rs rs.next() assert(rs.getString("arg") == "Hello!") } finally { @@ -77,7 +82,12 @@ class TestNamedParametersStatement var rs: ResultSet = null try { statement = mkPreparedStatement(conn, code) - rs = statement.executeWith(Seq("v" -> RawString("Hello!"))).right.get + rs = statement + .executeWith(Seq("v" -> RawString("Hello!"))) + .right + .get + .asInstanceOf[NamedParametersPreparedStatementResultSet] + .rs rs.next() assert(rs.getString("greeting") == "Hello!") @@ -98,7 +108,12 @@ class TestNamedParametersStatement val metadata = statement.queryMetadata.right.get assert(metadata.parameters.keys == Set("v1", "v2")) - rs = statement.executeWith(Seq("v1" -> RawString("Lisbon"), "v2" -> RawInt(1))).right.get + rs = statement + .executeWith(Seq("v1" -> RawString("Lisbon"), "v2" -> RawInt(1))) + .right + .get + .asInstanceOf[NamedParametersPreparedStatementResultSet] + .rs rs.next() assert(rs.getString(1) == "Lisbon") assert(rs.getInt(2) == 1) @@ -120,7 +135,12 @@ class TestNamedParametersStatement var rs: ResultSet = null try { statement = mkPreparedStatement(conn, code) - rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get + rs = statement + .executeWith(Seq("v1" -> RawString("Hello!"))) + .right + .get + .asInstanceOf[NamedParametersPreparedStatementResultSet] + .rs rs.next() assert(rs.getString("arg") == "Hello!") } finally { @@ -139,7 +159,12 @@ class TestNamedParametersStatement statement = mkPreparedStatement(conn, code) val metadata = statement.queryMetadata.right.get assert(metadata.parameters.keys == Set("bar")) - rs = statement.executeWith(Seq("bar" -> RawString("Hello!"))).right.get + rs = statement + .executeWith(Seq("bar" -> RawString("Hello!"))) + .right + .get + .asInstanceOf[NamedParametersPreparedStatementResultSet] + .rs rs.next() assert(rs.getString("v1") == ":foo") @@ -161,7 +186,7 @@ class TestNamedParametersStatement val metadata = statement.queryMetadata assert(metadata.isRight) assert(metadata.right.get.parameters.isEmpty) - rs = statement.executeWith(Seq.empty).right.get + rs = statement.executeWith(Seq.empty).right.get.asInstanceOf[NamedParametersPreparedStatementResultSet].rs rs.next() assert(rs.getString("arg") == """[1, 2, "3", {"a": "Hello"}]""") diff --git a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala index 3de0c3e12..f46a6e7aa 100644 --- a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala +++ b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala @@ -1137,20 +1137,21 @@ class TestSqlCompilerServiceAirports |-- SELECT :p + 10; |""".stripMargin) { t => val v = compilerService.validate(t.q, asJson()) - assert(v.messages.size == 1) - assert(v.messages(0).message == "non-executable code") - } - - test("""CREATE TABLE Persons ( - | ID int NOT NULL, - | LastName varchar(255) NOT NULL, - | FirstName varchar(255), - | Age int, - | PRIMARY KEY (ID) - |);""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) - assert(v.messages.size == 1) - assert(v.messages(0).message == "non-executable code") + assert(v.messages.isEmpty) + val baos = new ByteArrayOutputStream() + assert( + compilerService.execute( + t.q, + asJson(Map("p" -> RawInt(5))), + None, + baos + ) == ExecutionSuccess(true) + ) + // The code does nothing, but we don't get an error when running it in Postgres. + assert( + baos.toString() === + """[{"update_count":0}]""".stripMargin + ) } test("""select @@ -1398,4 +1399,107 @@ class TestSqlCompilerServiceAirports |""".stripMargin) } + + test("INSERT") { _ => + // An SQL INSERT statement adding a fake city to the airports table. + val q = + """INSERT INTO example.airports (airport_id, name, city, country, iata_faa, icao, latitude, longitude, altitude, timezone, dst, tz) + |VALUES (8108, :airport, :city, :country, 'FC', 'FC', 0.0, 0.0, 0.0, 0, 'U', 'UTC') + |""".stripMargin + val baos = new ByteArrayOutputStream() + assert( + compilerService.execute( + q, + asCsv(params = + Map("airport" -> RawString("FAKE"), "city" -> RawString("Fake City"), "country" -> RawString("Fake Country")) + ), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """update_count + |1 + |""".stripMargin + ) + baos.reset() + assert( + compilerService.execute( + "SELECT city, country FROM example.airports WHERE name = :a", + asCsv(params = Map("a" -> RawString("FAKE"))), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """city,country + |Fake City,Fake Country + |""".stripMargin + ) + + } + + test("UPDATE (CSV output)") { _ => + val baos = new ByteArrayOutputStream() + assert( + compilerService.execute( + "UPDATE example.airports SET city = :newName WHERE country = :c", + asCsv(params = Map("newName" -> RawString("La Roche sur Foron"), "c" -> RawString("Portugal"))), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """update_count + |39 + |""".stripMargin + ) + baos.reset() + assert( + compilerService.execute( + "SELECT DISTINCT city FROM example.airports WHERE country = :c", + asCsv(params = Map("c" -> RawString("Portugal"))), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """city + |La Roche sur Foron + |""".stripMargin + ) + } + + test("UPDATE (Json output)") { _ => + val baos = new ByteArrayOutputStream() + assert( + compilerService.execute( + "UPDATE example.airports SET city = :newName WHERE country = :c", + asJson(params = Map("newName" -> RawString("Lausanne"), "c" -> RawString("Portugal"))), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """[{"update_count":39}]""".stripMargin + ) + baos.reset() + assert( + compilerService.execute( + "SELECT DISTINCT city FROM example.airports WHERE country = :c", + asJson(params = Map("c" -> RawString("Portugal"))), + None, + baos + ) == ExecutionSuccess(true) + ) + assert( + baos.toString() === + """[{"city":"Lausanne"}]""" + ) + } }