Skip to content

Commit

Permalink
Add API for prepared queries (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlindenkreuz authored Jan 7, 2024
1 parent 3e6ebdc commit 317cf10
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
47 changes: 39 additions & 8 deletions lib/adbc_connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,36 @@ defmodule Adbc.Connection do
@doc """
Runs the given `query` with `params`.
"""
@spec query(t(), binary, [term]) :: {:ok, result_set} | {:error, Exception.t()}
def query(conn, query, params \\ []) when is_binary(query) and is_list(params) do
@spec query(t(), binary | reference, [term]) :: {:ok, result_set} | {:error, Exception.t()}
def query(conn, query, params \\ []) when (is_binary(query) or is_reference(query)) and is_list(params) do
stream_lock(conn, {:query, query, params}, &stream_results/2)
end

@doc """
Same as `query/3` but raises an exception on error.
"""
@spec query!(t(), binary, [term]) :: result_set
def query!(conn, query, params \\ []) when is_binary(query) and is_list(params) do
@spec query!(t(), binary | reference, [term]) :: result_set
def query!(conn, query, params \\ []) when (is_binary(query) or is_reference(query)) and is_list(params) do
case query(conn, query, params) do
{:ok, result} -> result
{:error, reason} -> raise reason
end
end

@doc """
Prepares the given `query`.
"""
@spec prepare(t(), binary) :: {:ok, reference} | {:error, Exception.t()}
def prepare(conn, query) when is_binary(query) do
with stmt <- create_statement(conn, query),
:ok <- Adbc.Nif.adbc_statement_prepare(stmt) do
{:ok, stmt}
else
{:error, reason} ->
{:error, error_to_exception(reason)}
end
end

@doc """
Runs the given `query` with `params` and
pass the ArrowStream pointer to the given function.
Expand All @@ -92,7 +106,7 @@ defmodule Adbc.Connection do
native code that consumes the ArrowStream accordingly.
"""
def query_pointer(conn, query, params \\ [], fun)
when is_binary(query) and is_list(params) and is_function(fun) do
when (is_binary(query) or is_reference(query)) and is_list(params) and is_function(fun) do
stream_lock(conn, {:query, query, params}, fn stream_ref, rows_affected ->
{:ok, fun.(Adbc.Nif.adbc_arrow_array_stream_get_pointer(stream_ref), rows_affected)}
end)
Expand Down Expand Up @@ -365,20 +379,37 @@ defmodule Adbc.Connection do

defp maybe_dequeue(state), do: state

defp handle_command({:query, query, params}, conn) do
with {:ok, stmt} <- Adbc.Nif.adbc_statement_new(conn),
:ok <- Adbc.Nif.adbc_statement_set_sql_query(stmt, query),
defp handle_command({:query, query, params}, conn) when is_binary(query) do
with stmt <- create_statement(conn, query),
:ok <- maybe_bind(stmt, params) do
Adbc.Nif.adbc_statement_execute_query(stmt)
end
end

defp handle_command({:query, stmt, params}, _conn) when is_reference(stmt) do
with :ok <- maybe_bind(stmt, params) do
Adbc.Nif.adbc_statement_execute_query(stmt)
end
end

defp handle_command({name, args}, conn) do
with {:ok, stream_ref} <- apply(Adbc.Nif, name, [conn | args]) do
{:ok, stream_ref, -1}
end
end

defp create_statement(conn, query) when is_pid(conn) do
%{conn: conn} = :sys.get_state(conn)
create_statement(conn, query)
end

defp create_statement(conn, query) when is_reference(conn) do
with {:ok, stmt} <- Adbc.Nif.adbc_statement_new(conn),
:ok <- Adbc.Nif.adbc_statement_set_sql_query(stmt, query) do
stmt
end
end

defp maybe_bind(_stmt, []), do: :ok
defp maybe_bind(stmt, params), do: Adbc.Nif.adbc_statement_bind(stmt, params)
end
37 changes: 37 additions & 0 deletions test/adbc_connection_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ defmodule Adbc.Connection.Test do
assert {:error, %Adbc.Error{} = error} = Connection.query(conn, "NOT VALID SQL")
assert Exception.message(error) =~ "[SQLite] Failed to prepare query"
end

test "select with prepared query", %{db: db} do
conn = start_supervised!({Connection, database: db})
assert {:ok, ref} = Connection.prepare(conn, "SELECT 123 + ? as num")
assert {:ok, %Adbc.Result{data: %{"num" => [579]}}} =
Connection.query(conn, ref, [456])
end

test "select with multiple prepared queries", %{db: db} do
conn = start_supervised!({Connection, database: db})
assert {:ok, ref_a} = Connection.prepare(conn, "SELECT 123 + ? as num")
assert {:ok, ref_b} = Connection.prepare(conn, "SELECT 1000 + ? as num")
assert {:ok, %Adbc.Result{data: %{"num" => [579]}}} =
Connection.query(conn, ref_a, [456])
assert {:ok, %Adbc.Result{data: %{"num" => [1456]}}} =
Connection.query(conn, ref_b, [456])
end
end

describe "query!" do
Expand Down Expand Up @@ -216,6 +233,14 @@ defmodule Adbc.Connection.Test do
end
end

describe "prepared queries" do
test "prepare", %{db: db} do
conn = start_supervised!({Connection, database: db})
assert {:ok, ref} = Connection.prepare(conn, "SELECT 123 + ? as num")
assert is_reference(ref)
end
end

describe "query_pointer" do
test "select", %{db: db} do
conn = start_supervised!({Connection, database: db})
Expand All @@ -226,6 +251,17 @@ defmodule Adbc.Connection.Test do
:from_pointer
end)
end

test "prepared query", %{db: db} do
conn = start_supervised!({Connection, database: db})

{:ok, ref} = Connection.prepare(conn, "SELECT 123 + ? as num")
assert {:ok, :from_pointer} =
Connection.query_pointer(conn, ref, [456], fn
pointer, nil when is_integer(pointer) ->
:from_pointer
end)
end
end

describe "lock" do
Expand Down Expand Up @@ -270,6 +306,7 @@ defmodule Adbc.Connection.Test do
test "commands that error do not lock", %{db: db} do
conn = start_supervised!({Connection, database: db})
{:error, %Adbc.Error{}} = Connection.query(conn, "NOT VALID SQL")
{:error, %Adbc.Error{}} = Connection.prepare(conn, "NOT VALID SQL")
run_anything(conn)
end

Expand Down

0 comments on commit 317cf10

Please sign in to comment.