diff --git a/c_src/adbc_nif.cpp b/c_src/adbc_nif.cpp index 63da108..7fadb3a 100644 --- a/c_src/adbc_nif.cpp +++ b/c_src/adbc_nif.cpp @@ -96,7 +96,7 @@ template static ERL_NIF_TERM strings_from_buffer( return enif_make_list_from_array(env, values.data(), (unsigned)values.size()); } -static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error); +static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error, bool *end_of_series = nullptr); static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &children, ERL_NIF_TERM &error) { ERL_NIF_TERM children_term{}; @@ -113,9 +113,9 @@ static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * return 1; } - children.resize(schema->n_children); - if (schema->n_children > 0) { - for (int64_t child_i = 0; child_i < schema->n_children; child_i++) { + children.resize(values->n_children); + if (values->n_children > 0) { + for (int64_t child_i = 0; child_i < values->n_children; child_i++) { struct ArrowSchema * child_schema = schema->children[child_i]; struct ArrowArray * child_values = values->children[child_i]; std::vector childrens; @@ -277,7 +277,7 @@ static ERL_NIF_TERM get_arrow_array_list_children(ErlNifEnv *env, struct ArrowSc return enif_make_list_from_array(env, children.data(), (unsigned)items_values->n_children); } -int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error) { +int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error, bool *end_of_series) { if (schema == nullptr) { error = erlang::nif::error(env, "invalid ArrowSchema (nullptr) when invoking next"); return 1; @@ -430,10 +430,18 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct if (strncmp("+s", format, 2) == 0) { // only handle and return children if this is a struct is_struct = true; - if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - return 1; + + if (schema->n_children == values->n_children) { + if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { + return 1; + } + children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); + } else { + if (end_of_series) { + *end_of_series = true; + } + children_term = erlang::nif::atom(env, "end_of_series"); } - children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); } else if (strncmp("+m", format, 2) == 0) { children_term = get_arrow_array_map_children(env, schema, values, level); } else if (strncmp("+l", format, 2) == 0 || strncmp("+L", format, 2) == 0) { @@ -1071,28 +1079,38 @@ static ERL_NIF_TERM adbc_arrow_array_stream_next(ErlNifEnv *env, int argc, const return erlang::nif::error(env, reason ? reason : "unknown error"); } - if (res->private_data == nullptr) { - res->private_data = enif_alloc(sizeof(struct ArrowSchema)); - memset(res->private_data, 0, sizeof(struct ArrowSchema)); - int code = res->val.get_schema(&res->val, (struct ArrowSchema *)res->private_data); - if (code != 0) { - const char * reason = res->val.get_last_error(&res->val); - enif_free(res->private_data); - res->private_data = nullptr; - return erlang::nif::error(env, reason ? reason : "unknown error"); - } + if (res->private_data != nullptr) { + enif_free(res->private_data); + res->private_data = nullptr; + } + + res->private_data = enif_alloc(sizeof(struct ArrowSchema)); + memset(res->private_data, 0, sizeof(struct ArrowSchema)); + code = res->val.get_schema(&res->val, (struct ArrowSchema *)res->private_data); + if (code != 0) { + const char * reason = res->val.get_last_error(&res->val); + enif_free(res->private_data); + res->private_data = nullptr; + return erlang::nif::error(env, reason ? reason : "unknown error"); } std::vector out_terms; auto schema = (struct ArrowSchema*)res->private_data; - if (arrow_array_to_nif_term(env, schema, &out, 0, out_terms, error) == 1) { + bool end_of_series = false; + if (arrow_array_to_nif_term(env, schema, &out, 0, out_terms, error, &end_of_series) == 1) { if (out.release) out.release(&out); return error; } if (out_terms.size() == 1) { ret = out_terms[0]; + if (end_of_series) { + if (out.release) { + out.release(&out); + } + return ret; + } } else { ret = enif_make_tuple2(env, out_terms[0], out_terms[1]); } diff --git a/lib/adbc_connection.ex b/lib/adbc_connection.ex index e740303..7ab01e5 100644 --- a/lib/adbc_connection.ex +++ b/lib/adbc_connection.ex @@ -286,13 +286,12 @@ defmodule Adbc.Connection do defp stream_results(reference, acc, num_rows) do case Adbc.Nif.adbc_arrow_array_stream_next(reference) do - {:ok, results, done} -> + {:ok, results, _done} -> acc = Map.merge(acc, Map.new(results), fn _k, v1, v2 -> v1 ++ v2 end) + stream_results(reference, acc, num_rows) - case done do - 0 -> stream_results(reference, acc, num_rows) - 1 -> {:ok, %Adbc.Result{data: acc, num_rows: num_rows}} - end + :end_of_series -> + {:ok, %Adbc.Result{data: acc, num_rows: num_rows}} {:error, reason} -> {:error, error_to_exception(reason)} diff --git a/test/adbc_test.exs b/test/adbc_test.exs index 2758db6..dd1a1a0 100644 --- a/test/adbc_test.exs +++ b/test/adbc_test.exs @@ -35,6 +35,20 @@ defmodule AdbcTest do Connection.query(conn, "SELECT 123 as num") end + test "getting all chunks", %{conn: conn} do + query = """ + SELECT * FROM generate_series('2000-03-01 00:00'::timestamp, '2100-03-04 12:00'::timestamp, '15 minutes') + """ + + %Adbc.Result{ + data: %{ + "generate_series" => generate_series + } + } = Connection.query!(conn, query) + + assert Enum.count(generate_series) == 3_506_641 + end + test "select with temporal types", %{conn: conn} do query = """ select