Skip to content

Commit

Permalink
Add user collation support
Browse files Browse the repository at this point in the history
Add support to create and delete [custom
collations](https://www.sqlite.org/c3ref/create_collation.html).

The functionality is similair to user functions, but there's a fixed arity (two
strings) and a fixed return value (type) so the code is quite simple compared to the
implementation of user functions.

Automated test is included to ensure that the functionality actually works as expected.
  • Loading branch information
Gusted committed Jan 20, 2025
1 parent 40dc885 commit 29d38f6
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 1 deletion.
1 change: 0 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
- Add the following wrappers:

- Better SQL Function support (aggregating ones)
- Collation support
- Authorization support
- Trace/profile support
- progress handler support
Expand Down
8 changes: 8 additions & 0 deletions lib/sqlite3.ml
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,14 @@ module Aggregate = struct
value final
end

(* Collation registration *)

external create_collation : db -> string -> (string -> string -> int) -> unit
= "caml_sqlite3_create_collation"

external delete_collation : db -> string -> unit
= "caml_sqlite3_delete_collation"

module Backup = struct
module Raw = struct
type t
Expand Down
14 changes: 14 additions & 0 deletions lib/sqlite3.mli
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,20 @@ module Aggregate : sig
@raise SqliteError if an invalid database handle is passed. *)
end

val create_collation : db -> string -> (string -> string -> int) -> unit
(** [create_collation db name func] creates a collation with [name] in database
handle [db]. [func] is called when the collation is needed, it must return
an integer that is negative, zero, or positive if the first string is less
than, equal to, or greater than the second, respectively
@raise SqliteError if an invalid database handle is passed. *)

val delete_collation : db -> string -> unit
(** [delete_collation db name] deletes collation with name [name] from database
handle [db].
@raise SqliteError if an invalid database handle is passed. *)

module Backup : sig
type t
(** Type of a backup between two databases *)
Expand Down
94 changes: 94 additions & 0 deletions lib/sqlite3_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,17 @@ typedef struct user_function {
struct user_function *next;
} user_function;

typedef struct user_collation {
value v_fun;
struct user_collation *next;
} user_collation;

typedef struct db_wrap {
sqlite3 *db;
int rc;
int ref_count;
user_function *user_functions;
user_collation *user_collations;
} db_wrap;

typedef struct stmt_wrap {
Expand Down Expand Up @@ -360,6 +366,13 @@ static inline void ref_count_finalize_dbw(db_wrap *dbw) {
caml_stat_free(link);
}
dbw->user_functions = NULL;
user_collation *link_c, *next_c;
for (link_c = dbw->user_collations; link_c != NULL; link_c = next_c) {
caml_remove_generational_global_root(&link_c->v_fun);
next_c = link_c->next;
caml_stat_free(link);
}
dbw->user_collations = NULL;
my_sqlite3_close(dbw->db);
caml_stat_free(dbw);
}
Expand Down Expand Up @@ -514,6 +527,7 @@ CAMLprim value caml_sqlite3_open(value v_mode, value v_uri, value v_memory,
dbw->rc = SQLITE_OK;
dbw->ref_count = 1;
dbw->user_functions = NULL;
dbw->user_collations = NULL;
Sqlite3_val(v_res) = dbw;
return v_res;
}
Expand Down Expand Up @@ -1525,6 +1539,86 @@ CAMLprim value caml_sqlite3_delete_function(value v_db, value v_name) {
return Val_unit;
}

/* User defined collations */

static inline void unregister_user_collation(db_wrap *db_data, value v_name) {
user_collation *prev = NULL, *link = db_data->user_collations;
const char *name = String_val(v_name);

while (link != NULL) {
if (strcmp(String_val(Field(link->v_fun, 0)), name) == 0) {
if (prev == NULL)
db_data->user_collations = link->next;
else
prev->next = link->next;
caml_remove_generational_global_root(&link->v_fun);
caml_stat_free(link);
break;
}
prev = link;
link = link->next;
}
}

static inline user_collation *
register_user_collation(db_wrap *db_data, value v_name, value v_fun) {
user_collation *link;
value v_cell = caml_alloc_small(2, 0);
Field(v_cell, 0) = v_name;
Field(v_cell, 1) = v_fun;

/* Assume parameters are already protected */
link = caml_stat_alloc(sizeof *link);
link->v_fun = v_cell;
link->next = db_data->user_collations;
caml_register_generational_global_root(&link->v_fun);
db_data->user_collations = link;
return link;
}

int caml_sqlite3_user_collation(void *ctx, int nLeft, const void *zLeft,
int nRight, const void *zRight) {
user_collation *data = ctx;
value v_res, v_left, v_right;
int v_return;
caml_leave_blocking_section();
v_left = caml_alloc_initialized_string(nLeft, zLeft);
v_right = caml_alloc_initialized_string(nRight, zRight);
v_res = caml_callback2_exn(Field(data->v_fun, 1), v_left, v_right);
v_return = Int_val(v_res);
caml_enter_blocking_section();
return v_return;
}

CAMLprim value caml_sqlite3_create_collation(value v_db, value v_name,
value v_fun) {
CAMLparam3(v_db, v_name, v_fun);
user_collation *param;
int rc;
db_wrap *dbw = Sqlite3_val(v_db);
check_db(dbw, "create_collation");
param = register_user_collation(dbw, v_name, v_fun);
rc = sqlite3_create_collation(dbw->db, String_val(v_name), SQLITE_UTF8, param,
caml_sqlite3_user_collation);
if (rc != SQLITE_OK) {
unregister_user_collation(dbw, v_name);
raise_sqlite3_current(dbw->db, "create_collation");
}
CAMLreturn(Val_unit);
}

CAMLprim value caml_sqlite3_delete_collation(value v_db, value v_name) {
int rc;
db_wrap *dbw = Sqlite3_val(v_db);
check_db(dbw, "delete_collation");
rc = sqlite3_create_collation(dbw->db, String_val(v_name), SQLITE_UTF8, NULL,
NULL);
if (rc != SQLITE_OK)
raise_sqlite3_current(dbw->db, "delete_collation");
unregister_user_collation(dbw, v_name);
return Val_unit;
}

CAMLprim value caml_sqlite3_busy_timeout(value v_db, intnat ms) {
int rc;
db_wrap *dbw = Sqlite3_val(v_db);
Expand Down
58 changes: 58 additions & 0 deletions test/test_collation.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
open Sqlite3

let assert_ok rc = assert (rc = Rc.OK)
let assert_error rc = assert (rc = Rc.ERROR)

let assert_rows_equal expected_rows db sql =
let actual_rows = ref [] in
let _ =
assert_ok
(exec db sql ~cb:(fun row _ ->
match row.(0) with
| Some a -> actual_rows := a :: !actual_rows
| _ -> ()))
in
let actual_rows = List.sort compare (List.rev !actual_rows) in
assert (actual_rows = expected_rows)

let%test "test_collation" =
let db = db_open "t_collation" in
create_collation db "FIRST_CHAR" (fun left right ->
compare (String.get left 0) (String.get right 0));

let found_first_char = ref false in
let _ =
assert_ok
(exec db "PRAGMA collation_list" ~cb:(fun row _ ->
match row.(1) with
| Some a -> found_first_char := !found_first_char || a = "FIRST_CHAR"
| _ -> ()))
in
assert !found_first_char;

assert_ok (exec db "DROP TABLE IF EXISTS tbl");
assert_ok (exec db "CREATE TABLE tbl (a varchar(10) COLLATE FIRST_CHAR)");
assert_ok (exec db "INSERT INTO tbl VALUES ('pippo')");
assert_ok (exec db "INSERT INTO tbl VALUES ('pippo2')");
assert_ok (exec db "INSERT INTO tbl VALUES ('atypical')");
assert_rows_equal [ "pippo"; "pippo2" ] db
"SELECT * FROM tbl WHERE a = 'pippo'";
assert_rows_equal [ "pippo"; "pippo2" ] db
"SELECT * FROM tbl WHERE a = 'papa'";
assert_rows_equal [ "atypical" ] db
"SELECT * FROM tbl WHERE a = 'asymmetrical'";
assert_rows_equal [ "atypical" ] db "SELECT * FROM tbl WHERE a = 'atypical'";
assert_rows_equal [] db "SELECT * FROM tbl WHERE a = 'border'";

assert_ok (exec db "DROP TABLE IF EXISTS tbl");
assert_ok (exec db "CREATE TABLE tbl (a varchar(10))");
assert_ok (exec db "INSERT INTO tbl VALUES ('pippo')");
assert_ok (exec db "INSERT INTO tbl VALUES ('pippo2')");
assert_rows_equal [ "pippo" ] db "SELECT * FROM tbl WHERE a = 'pippo'";
assert_rows_equal [ "pippo"; "pippo2" ] db
"SELECT * FROM tbl WHERE a = 'pippo' COLLATE FIRST_CHAR";

delete_collation db "FIRST_CHAR";
assert_error
(exec db "SELECT * FROM tbl WHERE a = 'pippo' COLLATE FIRST_CHAR");
true

0 comments on commit 29d38f6

Please sign in to comment.