diff --git a/README.md b/README.md index cff000f..2d2e4a2 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,14 @@ [![Maven Central](https://img.shields.io/maven-central/v/io.github.moreirasantos/pgkn)](https://central.sonatype.com/artifact/io.github.moreirasantos/pgkn/) [![Kotlin](https://img.shields.io/badge/kotlin-1.9.0-blue.svg?logo=kotlin)](http://kotlinlang.org) - # pgkn PostgreSQL Kotlin/Native Driver ## Usage ``` -// Show full structure of a kotlin native project implementation("io.github.moreirasantos:pgkn:1.0.0") ``` -``` +```kotlin fun main() { val driver = PostgresDriver( host = "host.docker.internal", @@ -44,3 +42,19 @@ fun main() { } } ``` +## Features +## Named Parameters +```kotlin +driver.execute( + "select name from my_table where name = :one OR email = :other", + mapOf("one" to "your_name", "other" to "your@email.com") +) { it.getString(0) } +``` +Named Parameters provides an alternative to the traditional syntax using `?` to specify parameters. +Under the hood, it substitutes the named parameters to a query placeholder. + +In JDBC, the placeholder would be `?` but with libpq, we will pass `$1`, `$2`, etc as stated here: +[31.3.1. Main Functions - PQexecParams](https://www.postgresql.org/docs/9.5/libpq-exec.html) + +This feature implementation tries to follow Spring's `NamedParameterJdbcTemplate` as close as possible. +[NamedParameterJdbcTemplate](https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.html) diff --git a/build.gradle.kts b/build.gradle.kts index abe203b..f29079a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -13,7 +13,7 @@ plugins { } group = "io.github.moreirasantos" -version = "1.0.1" +version = "1.0.2" repositories { mavenCentral() diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PgKnException.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PgKnException.kt deleted file mode 100644 index f704cee..0000000 --- a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PgKnException.kt +++ /dev/null @@ -1,4 +0,0 @@ -@file:Suppress("MatchingDeclarationName") -package io.github.moreirasantos.pgkn - -class SQLException : Exception() diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PostgresDriver.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PostgresDriver.kt index ee2afc1..a293f89 100644 --- a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PostgresDriver.kt +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/PostgresDriver.kt @@ -1,14 +1,25 @@ package io.github.moreirasantos.pgkn -import kotlinx.cinterop.* -import libpq.* +import io.github.moreirasantos.pgkn.paramsource.MapSqlParameterSource +import io.github.moreirasantos.pgkn.paramsource.SqlParameterSource import io.github.moreirasantos.pgkn.resultset.PostgresResultSet import io.github.moreirasantos.pgkn.resultset.ResultSet +import io.github.moreirasantos.pgkn.sql.buildValueArray +import io.github.moreirasantos.pgkn.sql.parseSql +import io.github.moreirasantos.pgkn.sql.substituteNamedParameters +import kotlinx.cinterop.* +import libpq.* +/** + * Executes given query with given named parameters. + * If you pass a handler, you will receive a list of result data. + * You can pass an [SqlParameterSource] to register your own Postgres types. + */ sealed interface PostgresDriver { - fun execute(sql: String, handler: (ResultSet) -> T): List - - fun execute(sql: String): Long + fun execute(sql: String, namedParameters: Map = emptyMap(), handler: (ResultSet) -> T): List + fun execute(sql: String, paramSource: SqlParameterSource, handler: (ResultSet) -> T): List + fun execute(sql: String, namedParameters: Map = emptyMap()): Long + fun execute(sql: String, paramSource: SqlParameterSource): Long } @OptIn(ExperimentalForeignApi::class) @@ -45,24 +56,60 @@ private class PostgresDriverImpl( pgtty = null ).apply { require(ConnStatusType.CONNECTION_OK == PQstatus(this)) }!! - override fun execute(sql: String, handler: (ResultSet) -> T): List = doExecute(sql).let { - val rs = PostgresResultSet(it) + override fun execute(sql: String, namedParameters: Map, handler: (ResultSet) -> T) = + if (namedParameters.isEmpty()) doExecute(sql).handleResults(handler) + else execute(sql, MapSqlParameterSource(namedParameters), handler) + + override fun execute(sql: String, paramSource: SqlParameterSource, handler: (ResultSet) -> T) = + doExecute(sql, paramSource).handleResults(handler) + + override fun execute(sql: String, namedParameters: Map) = + if (namedParameters.isEmpty()) doExecute(sql).returnCount() + else execute(sql, MapSqlParameterSource(namedParameters)) + + override fun execute(sql: String, paramSource: SqlParameterSource) = + doExecute(sql, paramSource).returnCount() + + private fun CPointer.handleResults(handler: (ResultSet) -> T): List { + val rs = PostgresResultSet(this) val list: MutableList = mutableListOf() while (rs.next()) { list.add(handler(rs)) } - PQclear(it) + PQclear(this) return list } - override fun execute(sql: String): Long = doExecute(sql).let { - val rows = PQcmdTuples(it)!!.toKString() - PQclear(it) + private fun CPointer.returnCount(): Long { + val rows = PQcmdTuples(this)!!.toKString() + PQclear(this) return rows.toLongOrNull() ?: 0 } + private fun doExecute(sql: String, paramSource: SqlParameterSource): CPointer { + val parsedSql = parseSql(sql) + val sqlToUse: String = substituteNamedParameters(parsedSql, paramSource) + val params: Array = buildValueArray(parsedSql, paramSource) + + return memScoped { + PQexecParams( + connection, + command = sqlToUse, + nParams = params.size, + paramValues = createValues(params.size) { + println(params[it]?.toString()?.cstr) + value = params[it]?.toString()?.cstr?.getPointer(this@memScoped) + }, + paramLengths = params.map { it?.toString()?.length ?: 0 }.toIntArray().refTo(0), + paramFormats = IntArray(params.size) { TEXT_RESULT_FORMAT }.refTo(0), + paramTypes = parsedSql.parameterNames.map(paramSource::getSqlType).toUIntArray().refTo(0), + resultFormat = TEXT_RESULT_FORMAT + ) + }.check() + } + private fun doExecute(sql: String) = memScoped { PQexecParams( connection, @@ -74,8 +121,7 @@ private class PostgresDriverImpl( paramTypes = createValues(0) {}, resultFormat = TEXT_RESULT_FORMAT ) - } - .check() + }.check() private fun CPointer?.check(): CPointer { val status = PQresultStatus(this) @@ -90,5 +136,6 @@ private class PostgresDriverImpl( private fun CPointer?.error(): String = PQerrorMessage(this)!!.toKString().also { PQfinish(this) } private const val TEXT_RESULT_FORMAT = 0 + @Suppress("UnusedPrivateProperty") private const val BINARY_RESULT_FORMAT = 1 diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/exception/PgKnException.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/exception/PgKnException.kt new file mode 100644 index 0000000..252460b --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/exception/PgKnException.kt @@ -0,0 +1,9 @@ +package io.github.moreirasantos.pgkn.exception + +sealed class SQLException(message: String? = null, cause: Throwable? = null) : Exception(message, cause) + +class InvalidDataAccessApiUsageException(message: String, cause: Throwable? = null) : SQLException(message, cause) + +class AnonymousClassException : SQLException("Class must not be anonymous", null) + +class GetColumnValueException(columnIndex: Int) : SQLException("Error getting column $columnIndex value", null) diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/AbstractSqlParameterSource.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/AbstractSqlParameterSource.kt new file mode 100644 index 0000000..30b3cdf --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/AbstractSqlParameterSource.kt @@ -0,0 +1,128 @@ +package io.github.moreirasantos.pgkn.paramsource + +import io.github.moreirasantos.pgkn.exception.AnonymousClassException +import io.github.moreirasantos.pgkn.paramsource.SqlParameterSource.Companion.TYPE_UNKNOWN +import kotlinx.datetime.Instant +import kotlinx.datetime.LocalDate +import kotlinx.datetime.LocalDateTime +import kotlinx.datetime.LocalTime +import kotlin.reflect.KClass + + +/** + * Abstract base class for [SqlParameterSource] implementations. + * Provides registration of SQL types per parameter and a friendly + * [toString] representation. + * Concrete subclasses must implement [hasValue] and [getValue]. + */ +abstract class AbstractSqlParameterSource : SqlParameterSource { + private val sqlTypes: MutableMap = HashMap() + private val typeNames: MutableMap = HashMap() + + /** + * Register an SQL type for the given parameter. + * @param paramName the name of the parameter + * @param sqlType the SQL type of the parameter + */ + fun registerSqlType(paramName: String, sqlType: UInt) { + sqlTypes[paramName] = sqlType + } + + fun registerSqlType(paramName: String, value: Any?) { + registerSqlType( + paramName = paramName, + sqlType = value?.let { oidMap[it::class.simpleName ?: throw AnonymousClassException()] } + ?: TYPE_UNKNOWN + ) + } + + /** + * Register an SQL type for the given parameter. + * @param paramName the name of the parameter + * @param typeName the type name of the parameter + */ + fun registerTypeName(paramName: String, typeName: String) { + typeNames[paramName] = typeName + } + + /** + * Return the SQL type for the given parameter, if registered. + * @param paramName the name of the parameter + * @return the SQL type of the parameter, + * or `TYPE_UNKNOWN` if not registered + */ + override fun getSqlType(paramName: String) = sqlTypes[paramName] ?: TYPE_UNKNOWN + + /** + * Return the type name for the given parameter, if registered. + * @param paramName the name of the parameter + * @return the type name of the parameter, + * or `null` if not registered + */ + fun getTypeName(paramName: String) = typeNames[paramName] + + /** + * Enumerate the parameter names and values with their corresponding SQL type if available, + * or just return the simple `SqlParameterSource` implementation class name otherwise. + * @since 5.2 + * @see .getParameterNames + */ + @Suppress("NestedBlockDepth") + override fun toString(): String { + val parameterNames: Array? = parameterNames + return if (parameterNames != null) { + val array = ArrayList(parameterNames.size) + for (parameterName in parameterNames) { + val value = getValue(parameterName) + /* + if (value is SqlParameterValue) { + value = (value as SqlParameterValue?).getValue() + } + */ + var typeName = getTypeName(parameterName) + if (typeName == null) { + val sqlType = getSqlType(parameterName) + if (sqlType != TYPE_UNKNOWN) { + typeName = sqlTypeNames[sqlType] + if (typeName == null) { + typeName = sqlType.toString() + } + } + } + val entry = StringBuilder() + entry.append(parameterName).append('=').append(value) + if (typeName != null) { + entry.append(" (type:").append(typeName).append(')') + } + array.add(entry.toString()) + } + array.joinToString( + separator = ", ", + prefix = this::class.simpleName + " {", + postfix = "}" + ) + } else { + this::class.simpleName!! + } + } +} + +// Full list here: https://jdbc.postgresql.org/documentation/publicapi/constant-values.html +private val oidMap: Map = hashMapOf( + Boolean::class.namedClassName to 16u, + ByteArray::class.namedClassName to 17u, + Long::class.namedClassName to 20u, + Int::class.namedClassName to 23u, + String::class.namedClassName to 25u, + Double::class.namedClassName to 701u, + LocalDate::class.namedClassName to 1082u, + LocalTime::class.namedClassName to 1083u, + LocalDateTime::class.namedClassName to 1114u, + Instant::class.namedClassName to 1184u, + // intervalOid = 1186u + // uuidOid = 2950u +) + +private val sqlTypeNames: Map = oidMap.entries.associateBy({ it.value }) { it.key } + +private val KClass<*>.namedClassName get() = this.simpleName!! diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/MapSqlParameterSource.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/MapSqlParameterSource.kt new file mode 100644 index 0000000..1ab0c82 --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/MapSqlParameterSource.kt @@ -0,0 +1,96 @@ +package io.github.moreirasantos.pgkn.paramsource + + +/** + * [SqlParameterSource] implementation that holds a given Map of parameters. + * + * The [addValue] methods on this class will make adding several values + * easier. The methods return a reference to the [MapSqlParameterSource] + * itself, so you can chain several method calls together within a single statement. + */ +class MapSqlParameterSource : AbstractSqlParameterSource { + private val values: MutableMap = LinkedHashMap() + + /** + * Create a new MapSqlParameterSource based on a Map. + * @param values a Map holding existing parameter values (can be `null`) + */ + constructor(values: Map?) { + addValues(values) + } + + /** + * Add a parameter to this parameter source. + * @param paramName the name of the parameter + * @param value the value of the parameter + * @return a reference to this parameter source, + * so it's possible to chain several calls together + */ + fun addValue(paramName: String, value: Any?): MapSqlParameterSource { + this.values[paramName] = value + registerSqlType(paramName, value) + return this + } + + /** + * Add a parameter to this parameter source. + * @param paramName the name of the parameter + * @param value the value of the parameter + * @param sqlType the SQL type of the parameter + * @return a reference to this parameter source, + * so it's possible to chain several calls together + */ + fun addValue(paramName: String, value: Any?, sqlType: UInt): MapSqlParameterSource { + this.values[paramName] = value + registerSqlType(paramName = paramName, sqlType = sqlType) + return this + } + + /** + * Add a parameter to this parameter source. + * @param paramName the name of the parameter + * @param value the value of the parameter + * @param sqlType the SQL type of the parameter + * @param typeName the type name of the parameter + * @return a reference to this parameter source, + * so it's possible to chain several calls together + */ + fun addValue(paramName: String, value: Any?, sqlType: UInt, typeName: String?): MapSqlParameterSource { + this.values[paramName] = value + registerSqlType(paramName = paramName, sqlType = sqlType) + registerTypeName(paramName, typeName!!) + return this + } + + /** + * Add a Map of parameters to this parameter source. + * @param values a Map holding existing parameter values (can be `null`) + * @return a reference to this parameter source, + * so it's possible to chain several calls together + */ + fun addValues(values: Map?): MapSqlParameterSource { + values?.forEach { (key, value) -> + this.values[key] = value + registerSqlType(paramName = key, value = value) + } + return this + } + + /** + * Expose the current parameter values as read-only Map. + */ + fun getValues(): Map = this.values + + override fun hasValue(paramName: String) = this.values.containsKey(paramName) + + @Suppress("UseRequire") + override fun getValue(paramName: String): Any? { + if (!hasValue(paramName)) { + throw IllegalArgumentException("No value registered for key '$paramName'") + } + return this.values[paramName] + } + + override val parameterNames: Array get() = this.values.keys.toTypedArray() +} + diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/SqlParameterSource.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/SqlParameterSource.kt new file mode 100644 index 0000000..6dc3234 --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/paramsource/SqlParameterSource.kt @@ -0,0 +1,63 @@ +package io.github.moreirasantos.pgkn.paramsource + +/** + * Interface that defines common functionality for objects that can + * offer parameter values for named SQL parameters. + * + * This interface allows for the specification of SQL type in addition + * to parameter values. All parameter values and types are identified by + * specifying the name of the parameter. + * + */ +interface SqlParameterSource { + /** + * Determine whether there is a value for the specified named parameter. + * @param paramName the name of the parameter + * @return whether there is a value defined + */ + fun hasValue(paramName: String): Boolean + + /** + * Return the parameter value for the requested named parameter. + * @param paramName the name of the parameter + * @return the value of the specified parameter + * @throws IllegalArgumentException if there is no value for the requested parameter + */ + @Throws(IllegalArgumentException::class) + fun getValue(paramName: String): Any? + + /** + * Determine the SQL type for the specified named parameter. + * @param paramName the name of the parameter + * @return the SQL type of the specified parameter, + * or `TYPE_UNKNOWN` if not known + * @see .TYPE_UNKNOWN + */ + fun getSqlType(paramName: String) = TYPE_UNKNOWN + + /** + * Determine the type name for the specified named parameter. + * @param paramName the name of the parameter + * @return the type name of the specified parameter, + * or `null` if not known + */ + fun getTypeName(paramName: String?): String? = null + + /** + * Enumerate all available parameter names if possible. + */ + val parameterNames: Array? + get() = null + + companion object { + /** + * Constant that indicates an unknown (or unspecified) SQL type. + * To be returned from `getType` when no specific SQL type known. + * @see getSqlType + * + */ + val TYPE_UNKNOWN: UInt = UInt.MIN_VALUE + // TODO check if libpq has a default param type oid + } +} + diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/PostgresResultSet.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/PostgresResultSet.kt index 4be4d92..f60bab3 100644 --- a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/PostgresResultSet.kt +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/PostgresResultSet.kt @@ -1,8 +1,7 @@ package io.github.moreirasantos.pgkn.resultset import io.github.moreirasantos.pgkn.KLogger -import io.github.moreirasantos.pgkn.PgknMarker -import io.github.moreirasantos.pgkn.SQLException +import io.github.moreirasantos.pgkn.exception.GetColumnValueException import kotlinx.cinterop.ByteVar import kotlinx.cinterop.CPointer import kotlinx.cinterop.ExperimentalForeignApi @@ -44,10 +43,10 @@ internal class PostgresResultSet(val internal: CPointer) : ResultSet { private fun getPointer(columnIndex: Int): CPointer? { if (isNull(columnIndex)) return null - return PQgetvalue(res = internal, tup_num = currentRow, field_num = columnIndex) ?: throw SQLException() + return PQgetvalue(res = internal, tup_num = currentRow, field_num = columnIndex) + ?: throw GetColumnValueException(columnIndex) } - /** * Are all non-binary columns returned as text? * https://www.postgresql.org/docs/9.5/libpq-exec.html#LIBPQ-EXEC-SELECT-INFO diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/ResultSet.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/ResultSet.kt index 2c1958d..72f497c 100644 --- a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/ResultSet.kt +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/resultset/ResultSet.kt @@ -1,10 +1,10 @@ package io.github.moreirasantos.pgkn.resultset +import io.github.moreirasantos.pgkn.exception.SQLException import kotlinx.datetime.Instant import kotlinx.datetime.LocalDate import kotlinx.datetime.LocalDateTime import kotlinx.datetime.LocalTime -import io.github.moreirasantos.pgkn.SQLException @Suppress("TooManyFunctions") sealed interface ResultSet { diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/ParsedSql.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/ParsedSql.kt new file mode 100644 index 0000000..14605da --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/ParsedSql.kt @@ -0,0 +1,285 @@ +package io.github.moreirasantos.pgkn.sql + +import io.github.moreirasantos.pgkn.exception.InvalidDataAccessApiUsageException + + +/** + * Heavily Based on: + * https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/jdbc/core/namedparam/NamedParameterUtils.html + */ +@Suppress( + "ComplexCondition", "LoopWithTooManyJumpStatements", + "CyclomaticComplexMethod", "LongMethod", "NestedBlockDepth" +) +internal fun parseSql(sql: String): ParsedSql { + + val namedParameters: MutableSet = HashSet() + val sqlToUse = StringBuilder(sql) + val parameterList: MutableList = ArrayList() + + val statement: CharArray = sql.toCharArray() + var namedParameterCount = 0 + var unnamedParameterCount = 0 + var totalParameterCount = 0 + + var escapes = 0 + var i = 0 + while (i < statement.size) { + var skipToPosition: Int + while (i < statement.size) { + skipToPosition = skipCommentsAndQuotes(statement, i) + i = if (i == skipToPosition) { + break + } else { + skipToPosition + } + } + if (i >= statement.size) { + break + } + val c = statement[i] + if (c == ':' || c == '&') { + var j = i + 1 + if (c == ':' && j < statement.size && statement[j] == ':') { + // Postgres-style "::" casting operator should be skipped + i += 2 + continue + } + var parameter: String? + if (c == ':' && j < statement.size && statement[j] == '{') { + // :{x} style parameter + while (statement[j] != '}') { + j++ + if (j >= statement.size) { + throw InvalidDataAccessApiUsageException( + "Non-terminated named parameter declaration at position $i in statement: $sql" + ) + } + if (statement[j] == ':' || statement[j] == '{') { + throw InvalidDataAccessApiUsageException( + "Parameter name contains invalid character '${statement[j]}' " + + "at position $i in statement: $sql" + ) + } + } + if (j - i > 2) { + parameter = sql.substring(i + 2, j) + namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter) + totalParameterCount = addNamedParameter( + parameterList, totalParameterCount, escapes, i, j + 1, parameter + ) + } + j++ + } else { + while (j < statement.size && !isParameterSeparator(statement[j])) { + j++ + } + if (j - i > 1) { + parameter = sql.substring(i + 1, j) + namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter) + totalParameterCount = addNamedParameter( + parameterList, totalParameterCount, escapes, i, j, parameter + ) + } + } + i = j - 1 + } else { + if (c == '\\') { + val j = i + 1 + if (j < statement.size && statement[j] == ':') { + // escaped ":" should be skipped + sqlToUse.deleteAt(i - escapes) + escapes++ + i += 2 + continue + } + } + if (c == '?') { + val j = i + 1 + if (j < statement.size && (statement[j] == '?' || statement[j] == '|' || statement[j] == '&')) { + // Postgres-style "??", "?|", "?&" operator should be skipped + i += 2 + continue + } + unnamedParameterCount++ + totalParameterCount++ + } + } + i++ + } + val parsedSql = ParsedSql(sqlToUse.toString()) + for (ph in parameterList) { + parsedSql.addNamedParameter(ph.parameterName, ph.startIndex, ph.endIndex) + } + parsedSql.namedParameterCount = namedParameterCount + parsedSql.unnamedParameterCount = unnamedParameterCount + parsedSql.totalParameterCount = totalParameterCount + return parsedSql +} + + +/** + * Holds information about a parsed SQL statement. + */ +internal class ParsedSql(val originalSql: String) { + + /** + * Return all the parameters (bind variables) in the parsed SQL statement. + * Repeated occurrences of the same parameter name are included here. + */ + val parameterNames: MutableList = ArrayList() + val parameterIndexes: MutableList = ArrayList() + /** + * Return the count of named parameters in the SQL statement. + * Each parameter name counts once; repeated occurrences do not count here. + */ + /** + * Set the count of named parameters in the SQL statement. + * Each parameter name counts once; repeated occurrences do not count here. + */ + var namedParameterCount = 0 + /** + * Return the count of all the unnamed parameters in the SQL statement. + */ + /** + * Set the count of all the unnamed parameters in the SQL statement. + */ + var unnamedParameterCount = 0 + /** + * Return the total count of all the parameters in the SQL statement. + * Repeated occurrences of the same parameter name do count here. + */ + /** + * Set the total count of all the parameters in the SQL statement. + * Repeated occurrences of the same parameter name do count here. + */ + var totalParameterCount = 0 + + /** + * Add a named parameter parsed from this SQL statement. + * @param parameterName the name of the parameter + * @param startIndex the start index in the original SQL String + * @param endIndex the end index in the original SQL String + */ + fun addNamedParameter(parameterName: String, startIndex: Int, endIndex: Int) { + parameterNames.add(parameterName) + parameterIndexes.add(intArrayOf(startIndex, endIndex)) + } + + /** + * Exposes the original SQL String. + */ + override fun toString() = originalSql +} + + +private class ParameterHolder(val parameterName: String, val startIndex: Int, val endIndex: Int) + +/** + * Skip over comments and quoted names present in an SQL statement. + * @param statement character array containing SQL statement + * @param position current position of statement + * @return next position to process after any comments or quotes are skipped + */ +@Suppress("NestedBlockDepth", "ReturnCount") +private fun skipCommentsAndQuotes(statement: CharArray, position: Int): Int { + for (i in START_SKIP.indices) { + if (statement[position] == START_SKIP[i][0]) { + var match = true + for (j in 1 until START_SKIP[i].length) { + if (statement[position + j] != START_SKIP[i][j]) { + match = false + break + } + } + if (match) { + val offset: Int = START_SKIP[i].length + for (m in position + offset until statement.size) { + if (statement[m] == STOP_SKIP[i][0]) { + var endMatch = true + var endPos = m + for (n in 1 until STOP_SKIP[i].length) { + if (m + n >= statement.size) { + // last comment not closed properly + return statement.size + } + if (statement[m + n] != STOP_SKIP[i][n]) { + endMatch = false + break + } + endPos = m + n + } + if (endMatch) { + // found character sequence ending comment or quote + return endPos + 1 + } + } + } + // character sequence ending comment or quote not found + return statement.size + } + } + } + return position +} + +@Suppress("LongParameterList") +private fun addNamedParameter( + parameterList: MutableList, + totalParameterCount: Int, + escapes: Int, + i: Int, + j: Int, + parameter: String +): Int { + var count = totalParameterCount + parameterList.add(ParameterHolder(parameter, i - escapes, j - escapes)) + count++ + return count +} + +private fun addNewNamedParameter( + namedParameters: MutableSet, + namedParameterCount: Int, + parameter: String +): Int { + var count = namedParameterCount + if (!namedParameters.contains(parameter)) { + namedParameters.add(parameter) + count++ + } + return count +} + +/** + * Determine whether a parameter name ends at the current position, + * that is, whether the given character qualifies as a separator. + */ +@Suppress("MagicNumber") +private fun isParameterSeparator(c: Char) = c.code < 128 && separatorIndex[c.code] || c.isWhitespace() + + +/** + * Set of characters that qualify as comment or quotes starting characters. + */ +private val START_SKIP = arrayOf("'", "\"", "--", "/*") + +/** + * Set of characters that at are the corresponding comment or quotes ending characters. + */ +private val STOP_SKIP = arrayOf("'", "\"", "\n", "*/") + +/** + * Set of characters that qualify as parameter separators, + * indicating that a parameter name in an SQL String has ended. + */ +private const val PARAMETER_SEPARATORS = "\"':&,;()|=+-*%/\\<>^" + +/** + * An index with separator flags per character code. + * Technically only needed between 34 and 124 at this point. + */ +@Suppress("MagicNumber") +private val separatorIndex = BooleanArray(128).apply { + PARAMETER_SEPARATORS.toCharArray().forEach { this[it.code] = true } +} diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SqlParameter.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SqlParameter.kt new file mode 100644 index 0000000..8ee46ac --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SqlParameter.kt @@ -0,0 +1,34 @@ +package io.github.moreirasantos.pgkn.sql + +import io.github.moreirasantos.pgkn.exception.InvalidDataAccessApiUsageException +import io.github.moreirasantos.pgkn.paramsource.SqlParameterSource + +/** + * Convert a Map of named parameter values to a corresponding array. + * @param parsedSql the parsed SQL statement + * @param paramSource the source for named parameters + * @return the array of values + */ +@Suppress("NestedBlockDepth", "SwallowedException") +internal fun buildValueArray( + parsedSql: ParsedSql, + paramSource: SqlParameterSource +): Array { + if (parsedSql.namedParameterCount > 0 && parsedSql.unnamedParameterCount > 0) { + throw InvalidDataAccessApiUsageException( + "Not allowed to mix named and traditional ? placeholders. You have " + + parsedSql.namedParameterCount + " named parameter(s) and " + + parsedSql.unnamedParameterCount + " traditional placeholder(s) in statement: " + + parsedSql.originalSql + ) + } + val paramArray = arrayOfNulls(parsedSql.totalParameterCount) + parsedSql.parameterNames.forEachIndexed { index, paramName -> + try { + paramArray[index] = paramSource.getValue(paramName) + } catch (ex: IllegalArgumentException) { + throw InvalidDataAccessApiUsageException("No value supplied for the SQL parameter '$paramName'", ex) + } + } + return paramArray +} diff --git a/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SubstituteNamedParameters.kt b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SubstituteNamedParameters.kt new file mode 100644 index 0000000..4544ac0 --- /dev/null +++ b/src/commonMain/kotlin/io/github/moreirasantos/pgkn/sql/SubstituteNamedParameters.kt @@ -0,0 +1,83 @@ +package io.github.moreirasantos.pgkn.sql + +import io.github.moreirasantos.pgkn.paramsource.SqlParameterSource + +/** + * Heavily Based on: + * https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/jdbc/core/namedparam/NamedParameterUtils.html + * Parse the SQL statement and locate any placeholders or named parameters. Named + * parameters are substituted for a placeholder, and any select list is expanded + * to the required number of placeholders. Select lists may contain an array of + * objects, and in that case the placeholders will be grouped and enclosed with + * parentheses. This allows for the use of "expression lists" in the SQL statement + * like:



+ * `select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))` + * + * The parameter values passed in are used to determine the number of placeholders to + * be used for a select list. Select lists should be limited to 100 or fewer elements. + * A larger number of elements is not guaranteed to be supported by the database and + * is strictly vendor-dependent. + * @param parsedSql the parsed representation of the SQL statement + * @param paramSource the source for named parameters + * @return the SQL statement with substituted parameters + */ +@Suppress("NestedBlockDepth") +internal fun substituteNamedParameters(parsedSql: ParsedSql, paramSource: SqlParameterSource?): String { + val originalSql: String = parsedSql.originalSql + val paramNames: List = parsedSql.parameterNames + if (paramNames.isEmpty()) { + return originalSql + } + val actualSql = StringBuilder(originalSql.length) + var lastIndex = 0 + var parameterNumber = 1 + for (i in paramNames.indices) { + val paramName = paramNames[i] + val indexes: IntArray = parsedSql.parameterIndexes[i] + val startIndex = indexes[0] + val endIndex = indexes[1] + actualSql.append(originalSql, lastIndex, startIndex) + if (paramSource != null && paramSource.hasValue(paramName)) { + val value: Any = paramSource.getValue(paramName)!! + /* + if (value is SqlParameterValue) { + value = (value as SqlParameterValue).getValue() + } + */ + if (value is Iterable<*>) { + val entryIter = value.iterator() + var k = 0 + while (entryIter.hasNext()) { + if (k > 0) { + actualSql.append(", ") + } + k++ + val entryItem = entryIter.next()!! + if (entryItem is Array<*>) { + actualSql.append('(') + for (m in entryItem.indices) { + if (m > 0) { + actualSql.append(", ") + } + actualSql.append('?') + } + actualSql.append(')') + } else { + actualSql.append('?') + } + } + } else { + // actualSql.append('?') + actualSql.append("\$$parameterNumber") + parameterNumber++ + } + } else { + // actualSql.append('?') + actualSql.append("\$$parameterNumber") + parameterNumber++ + } + lastIndex = endIndex + } + actualSql.append(originalSql, lastIndex, originalSql.length) + return actualSql.toString() +} diff --git a/src/commonTest/kotlin/io/github/moreirasantos/pgkn/NamedParametersTest.kt b/src/commonTest/kotlin/io/github/moreirasantos/pgkn/NamedParametersTest.kt new file mode 100644 index 0000000..a055333 --- /dev/null +++ b/src/commonTest/kotlin/io/github/moreirasantos/pgkn/NamedParametersTest.kt @@ -0,0 +1,75 @@ +package io.github.moreirasantos.pgkn + +import kotlin.test.Test +import kotlin.test.assertEquals + +class NamedParametersTest { + val driver = PostgresDriver( + host = "localhost", + port = 5678, + database = "postgres", + user = "postgres", + password = "postgres", + ) + + private fun createTable(name: String) = """ + create table $name + ( + id integer not null constraint id primary key, + name text, + email text, + int integer default 4 + ) + """.trimIndent() + + @Test + fun `should select with named params`() { + val t = "named_params" + driver.execute("drop table if exists $t") + driver.execute(createTable(t)) + assertEquals(0, driver.execute("select * from $t") {}.size) + + driver.execute("insert into $t(id, name, email, int) values(1, 'john', 'mail@mail.com', 10)") + + assertEquals(listOf("john"), driver.execute( + "select name from $t where name = :one", + mapOf("one" to "john") + ) { it.getString(0) }) + + assertEquals(listOf("john"), driver.execute( + "select name from $t where name = :one OR name = :other", + mapOf("one" to "john", "other" to "another") + ) { it.getString(0) }) + + assertEquals(emptyList(), driver.execute( + "select name from $t where name = :one", + mapOf("one" to "wrong") + ) { it.getString(0) }) + + driver.execute("drop table $t") + } + + @Test + fun `should update with named params`() { + val t = "named_params_update" + driver.execute("drop table if exists $t") + driver.execute(createTable(t)) + assertEquals(0, driver.execute("select * from $t") {}.size) + + driver.execute("insert into $t(id, name, email, int) values(1, 'john', 'mail@mail.com', 10)") + + assertEquals( + 1, driver.execute( + "update $t set int = :number where name = :one", + mapOf("one" to "john", "number" to 15) + ) + ) + + assertEquals(listOf("john"), driver.execute( + "select name from $t where int = :number", + mapOf("number" to 15) + ) { it.getString(0) }) + + driver.execute("drop table $t") + } +} diff --git a/src/commonTest/kotlin/io/github/moreirasantos/pgkn/PostgresDriverTest.kt b/src/commonTest/kotlin/io/github/moreirasantos/pgkn/PostgresDriverTest.kt index 19cdf36..521a27c 100644 --- a/src/commonTest/kotlin/io/github/moreirasantos/pgkn/PostgresDriverTest.kt +++ b/src/commonTest/kotlin/io/github/moreirasantos/pgkn/PostgresDriverTest.kt @@ -8,8 +8,7 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith class PostgresDriverTest { - - private val driver = PostgresDriver( + val driver = PostgresDriver( host = "localhost", port = 5678, database = "postgres",