diff --git a/backend/database_pools.ml b/backend/database_pools.ml index 1bcd59c..ddb836e 100644 --- a/backend/database_pools.ml +++ b/backend/database_pools.ml @@ -1,3 +1,5 @@ +open CCFun.Infix + exception Exception of string module type Sig = Database_pools_sig.Sig @@ -18,251 +20,247 @@ module LogTag = struct ;; end -type connection_type = - | SinglePool of string - | MultiPools of (string * string) list - -let with_log ?tags ?(log_level = Logs.Error) ?(msg_prefix = "Error") err = - let msg = Caqti_error.show err in - Logs.msg ~src log_level (fun m -> m ?tags "%s: %s" msg_prefix msg); - msg -;; - -let get_or_raise ?ctx ?tags ?log_level ?msg_prefix () = - let tags = CCOption.or_ ~else_:(LogTag.ctx_opt ?ctx ()) tags in - function - | Ok result -> result - | Error error -> failwith (with_log ?tags ?log_level ?msg_prefix error) -;; - -let map_or_raise ?ctx ?tags ?log_level ?msg_prefix fcn result = - result - |> CCResult.map fcn - |> get_or_raise ?ctx ?tags ?log_level ?msg_prefix () -;; - module type ConfigSig = sig - val database : connection_type + val database : string * string val database_pool_size : int + val expected_databases : int end module DefaultConfig : ConfigSig = struct - let database = SinglePool "mariadb://root@database:3306/test" + let database = "main", "mariadb://root@database:3306/test" let database_pool_size = 5 + let expected_databases = 1 end module Make (Config : ConfigSig) = struct - let main_pool_ref - : (Caqti_lwt.connection, Caqti_error.t) Caqti_lwt_unix.Pool.t option ref - = - ref None - ;; + module Config = Config - let pools - : ( string - , (Caqti_lwt.connection, Caqti_error.t) Caqti_lwt_unix.Pool.t ) - Hashtbl.t - = - let spare_for_pools = 5 in - Hashtbl.create - (match Config.database with - | SinglePool _ -> 1 - | MultiPools pools -> CCList.length pools + spare_for_pools) - ;; + type connection = + | Close + | Open of (Caqti_lwt.connection, Caqti_error.t) Caqti_lwt_unix.Pool.t + | Fail of Caqti_error.t + + module Pool = struct + type t = + { database_label : string + ; database_url : string + ; required : bool + ; connection : connection [@opaque] + ; n_retries : int + } + [@@deriving fields] + + let create ?(required = false) database_label database_url = + { database_label + ; database_url + ; required + ; connection = Close + ; n_retries = 0 + } + ;; + + let reset_retry pool = { pool with n_retries = 0 } + let increment_retry pool = { pool with n_retries = pool.n_retries + 1 } - let print_pool_usage ?tags pool = - let n_connections = Caqti_lwt_unix.Pool.size pool in - let max_connections = Config.database_pool_size in - Logs.debug ~src (fun m -> - m ?tags "Pool usage: %i/%i" n_connections max_connections) + let connect_pool = + Uri.of_string + %> Caqti_lwt_unix.connect_pool + ~pool_config: + (Caqti_pool_config.create ~max_size:Config.database_pool_size ()) + ;; + + let connect + ?(retries = 2) + ({ database_label; database_url; required; _ } as pool) + = + let tags = database_label |> LogTag.create in + CCResult.retry retries (fun () -> database_url |> connect_pool) + |> (function + | Error [] -> raise (Exception "Failed to connect: empty error") + | Error (err :: _) when required -> raise (Caqti_error.Exn err) + | Error (err :: _ as errors) -> + Logs.warn ~src (fun m -> + m + ~tags + "Failed to connect: %s (%s)" + database_label + ([%show: Caqti_error.t list] errors)); + Fail err + | Ok con -> Open con) + |> fun connection -> { pool with connection } + ;; + end + + module Cache = struct + module Hashtbl = CCHashtbl.Make (CCString) + + let pools : Pool.t Hashtbl.t = + Hashtbl.create (max 1 Config.expected_databases) + ;; + + let add = Hashtbl.add pools + let remove = Hashtbl.remove pools + let find_opt = Hashtbl.find_opt pools + let replace pool = Hashtbl.replace pools (Pool.database_label pool) pool + end + + let print_pool_usage ?tags = + Pool.connection + %> function + | Open pool -> + let n_connections = Caqti_lwt_unix.Pool.size pool in + Logs.debug ~src (fun m -> + m ?tags "Pool usage: %i/%i" n_connections Config.database_pool_size) + | Close | Fail _ -> + Logs.debug ~src (fun m -> m ?tags "Pool usage: No connection found") ;; - let connect_or_failwith - ?(pool_size = Config.database_pool_size) - ok_fun - database_url - = - database_url - |> Uri.of_string - |> Caqti_lwt_unix.connect_pool - ~pool_config:(Caqti_pool_config.create ~max_size:pool_size ()) - |> map_or_raise ~msg_prefix:"Failed to connect to DB pool" ok_fun + let drain_opt = + Pool.connection + %> function + | Open pool -> Caqti_lwt_unix.Pool.drain pool + | Close | Fail _ -> Lwt.return_unit ;; - let add_pool ?pool_size name database_url = - let tags = LogTag.create name in - match Config.database, Hashtbl.find_opt pools name with - | SinglePool _, _ -> - failwith "SinglePool is selected: Switch to 'MultiPools' first" - | MultiPools _, Some _ -> + let add_pool ?required database_label database_url = + match Cache.find_opt database_label with + | Some _ -> let msg = - Format.asprintf - "Failed to create pool: Connection pool with name '%s' exists already" - name + [%string "Failed to add pool: Pool already exists %{database_label}"] in - Logs.err ~src (fun m -> m ~tags "%s" msg); + Logs.err ~src (fun m -> + m ~tags:(database_label |> LogTag.create) "%s" msg); failwith msg - | MultiPools _, None -> - database_url |> connect_or_failwith ?pool_size (Hashtbl.add pools name) + | None -> + Pool.create ?required database_label database_url + |> Cache.add database_label ;; let drop_pool name = - let tags = LogTag.create name in - match Config.database, Hashtbl.find_opt pools name with - | SinglePool _, _ -> - failwith "SinglePool is selected: Switch to 'MultiPools' first" - | MultiPools _, None -> + match Cache.find_opt name with + | None -> let msg = - Format.asprintf - "Failed to drop pool: connection to '%s' doesn't exist" - name + [%string "Failed to drop pool: connection to '%{name}' doesn't exist"] in - Logs.warn ~src (fun m -> m ~tags "%s" msg); - Lwt.return_unit - | MultiPools _, Some connection -> - let () = Hashtbl.remove pools name in - let%lwt () = Caqti_lwt_unix.Pool.drain connection in + Logs.info ~src (fun m -> m ~tags:(LogTag.create name) "%s" msg); Lwt.return_unit + | Some pool -> + let%lwt () = drain_opt pool in + Cache.remove name |> Lwt.return ;; - let initialize () = - match Config.database with - | SinglePool database_url when CCOption.is_none !main_pool_ref -> - database_url - |> connect_or_failwith (fun pool -> - main_pool_ref := Some pool; - ()) - | SinglePool _ -> () - | MultiPools pools' -> - pools' - |> CCList.filter (fun (name, _) -> - CCOption.is_none (Hashtbl.find_opt pools name)) - |> CCList.iter (fun (name, url) -> - url |> connect_or_failwith (Hashtbl.add pools name)) + let initialize ?(additinal_pools : (string * string) list = []) () : unit = + Config.database :: additinal_pools + |> CCList.filter (fst %> Cache.find_opt %> CCOption.is_none) + |> CCList.iter (CCFun.uncurry (Pool.create ~required:true) %> Cache.replace) ;; - let fetch_pool ?(ctx = []) () = - let open CCOption in - let () = initialize () in - match Config.database with - | SinglePool _ -> - !main_pool_ref |> get_exn_or "Initialization missed: run 'initialize'" - | MultiPools _ -> - find_pool_name ctx - >>= Hashtbl.find_opt pools - |> (function - | Some pool -> pool - | None -> failwith "Unknown Pool: Please 'add_pool' first!") + let connect = + Cache.find_opt + %> function + | Some pool -> + let rec connect pool = + match pool.Pool.connection with + | Fail err -> Error (Caqti_error.show err) + | Close -> Pool.connect pool |> connect + | Open _ -> Ok () + in + connect pool + | None -> Error "Database not found" ;; - let transaction ?ctx f = - let open Lwt.Infix in - let pool = fetch_pool ?ctx () in - print_pool_usage pool; - Caqti_lwt_unix.Pool.use - (fun connection -> - Logs.debug ~src (fun m -> - m ?tags:(LogTag.ctx_opt ?ctx ()) "Fetched connection from pool"); - let (module Connection : Caqti_lwt.CONNECTION) = connection in - let open Caqti_error in - match%lwt Connection.start () with - | Error msg -> - Logs.debug ~src (fun m -> - m - ?tags:(LogTag.ctx_opt ?ctx ()) - "Failed to start transaction: %s" - (show msg)); - Lwt.return_error msg - | Ok () -> - Logs.debug ~src (fun m -> - m ?tags:(LogTag.ctx_opt ?ctx ()) "Started transaction"); - Lwt.catch - (fun () -> - match%lwt Connection.commit () with - | Ok () -> - Logs.debug ~src (fun m -> - m - ?tags:(LogTag.ctx_opt ?ctx ()) - "Successfully committed transaction"); - f connection |> Lwt_result.return - | Error error -> - Exception - (with_log - ?tags:(LogTag.ctx_opt ?ctx ()) - ~msg_prefix:"Failed to commit transaction" - error) - |> Lwt.fail) - (fun e -> - match%lwt Connection.rollback () with - | Ok () -> - Logs.debug ~src (fun m -> - m - ?tags:(LogTag.ctx_opt ?ctx ()) - "Successfully rolled back transaction"); - Lwt.fail e - | Error error -> - Exception - (with_log - ?tags:(LogTag.ctx_opt ?ctx ()) - ~msg_prefix:"Failed to rollback transaction" - error) - |> Lwt.fail)) - pool - >|= get_or_raise ?ctx () + let disconnect ?error = + Cache.find_opt + %> function + | Some pool -> + let%lwt () = drain_opt pool in + Cache.replace + { pool with + Pool.connection = + CCOption.map_or ~default:Close (fun err -> Fail err) error + } + |> Lwt.return + | None -> Lwt.return_unit ;; - let transaction' ?ctx f = transaction ?ctx f |> Lwt.map (get_or_raise ?ctx ()) + let raise_caqti_error label input = + let open Caqti_error in + match%lwt input with + | Ok resp -> Lwt.return resp + | Error `Unsupported -> raise (Exception "Caqti error") + | Error (#load_or_connect as err) -> + let%lwt () = disconnect ~error:err label in + raise (Exn err) + | Error (#t as err) -> raise (Exn err) + ;; - let exec_with_connection - (request : ('a, unit, [< `Zero ]) Caqti_request.t) - (input : 'a) - (connection : (module Caqti_lwt.CONNECTION)) - : unit Lwt.t - = - let open CCFun in - let (module Connection : Caqti_lwt.CONNECTION) = connection in - Connection.exec request input - |> Lwt.map CCResult.(map_err Caqti_error.show %> get_or_failwith) + let rec fetch_pool ?(ctx = []) ?(retries = 2) () = + match ctx |> find_pool_name |> CCFun.flip CCOption.bind Cache.find_opt with + | Some pool -> + (match Pool.connection pool with + | Fail err when pool.Pool.n_retries >= retries -> + raise_caqti_error + (Pool.database_label pool) + (Error err |> Lwt_result.lift) + | Fail _ -> + let () = Pool.connect pool |> Pool.increment_retry |> Cache.replace in + fetch_pool ~ctx ~retries () + | Close -> + let () = Pool.connect pool |> Cache.replace in + fetch_pool ~ctx ~retries () + | Open connection when pool.Pool.n_retries > 0 -> + let () = Pool.reset_retry pool |> Cache.replace in + print_pool_usage ?tags:(LogTag.ctx_opt ~ctx ()) pool; + Lwt.return connection + | Open connection -> + print_pool_usage ?tags:(LogTag.ctx_opt ~ctx ()) pool; + Lwt.return connection) + | None -> + Exception + (Format.asprintf + "Unknown Pool: Please 'add_pool' first! (%s)" + CCOption.(find_pool_name ctx |> value ~default:"-")) + |> raise ;; - let query ?ctx f = - let open Lwt.Infix in - let pool = fetch_pool ?ctx () in - print_pool_usage pool; - Caqti_lwt_unix.Pool.use - (fun connection -> f connection >|= CCResult.return) - pool - >|= get_or_raise ?ctx () + let map_fetched ?ctx ?retries (fcn : 'a -> ('b, 'e) Lwt_result.t) = + let label = + CCOption.(bind ctx find_pool_name |> get_exn_or "Unknown pool") + in + let%lwt connection = fetch_pool ?ctx ?retries () in + fcn connection |> raise_caqti_error label ;; - let query' ?ctx f = query ?ctx f |> Lwt.map (get_or_raise ?ctx ()) + let query ?ctx f = + Caqti_lwt_unix.Pool.use (fun connection -> f connection) |> map_fetched ?ctx + ;; let find_opt ?ctx request input = - query' ?ctx (fun connection -> + query ?ctx (fun connection -> let module Connection = (val connection : Caqti_lwt.CONNECTION) in Connection.find_opt request input) ;; let find ?ctx request input = - query' ?ctx (fun connection -> + query ?ctx (fun connection -> let module Connection = (val connection : Caqti_lwt.CONNECTION) in Connection.find request input) ;; let collect ?ctx request input = - query' ?ctx (fun connection -> + query ?ctx (fun connection -> let module Connection = (val connection : Caqti_lwt.CONNECTION) in Connection.collect_list request input) ;; let exec ?ctx request input = - query' ?ctx (fun connection -> + query ?ctx (fun connection -> let module Connection = (val connection : Caqti_lwt.CONNECTION) in Connection.exec request input) ;; let populate ?ctx table columns request input = - query' ?ctx (fun connection -> + query ?ctx (fun connection -> let module Connection = (val connection : Caqti_lwt.CONNECTION) in Connection.populate ~table @@ -271,4 +269,65 @@ module Make (Config : ConfigSig) = struct (Caqti_lwt.Stream.of_list input) |> Lwt.map Caqti_error.uncongested) ;; + + let exec_each connection = + let open CCFun.Infix in + Lwt_list.map_s (fun request -> request connection) + %> Lwt.map CCResult.flatten_l + %> Lwt_result.map (fun (_ : unit list) -> ()) + ;; + + let rollback ?ctx connection error = + let (module Connection : Caqti_lwt.CONNECTION) = connection in + let label = + CCOption.(bind ctx find_pool_name |> get_exn_or "Unknown pool") + in + let%lwt () = + Connection.rollback () + |> Lwt_result.map + (CCFun.tap (fun _ -> + Logs.debug (fun m -> m "Successfully rolled back transaction"))) + |> raise_caqti_error label + in + Lwt.fail error + ;; + + let transaction + ?ctx + ?(setup : (Caqti_lwt.connection -> (unit, Caqti_error.t) Lwt_result.t) list = + []) + ?(cleanup : + (Caqti_lwt.connection -> (unit, Caqti_error.t) Lwt_result.t) list = + []) + (f : Caqti_lwt.connection -> ('a, Caqti_error.t) Lwt_result.t) + : 'a Lwt.t + = + let open Lwt_result.Syntax in + Caqti_lwt_unix.Pool.use (fun connection -> + let (module Connection : Caqti_lwt.CONNECTION) = connection in + let* () = Connection.start () in + Lwt.catch + (fun () -> + let* () = exec_each connection setup in + let* result = f connection in + let* () = exec_each connection cleanup in + match%lwt Connection.commit () with + | Ok () -> Lwt.return_ok result + | Error error -> Lwt.return_error error) + (rollback ?ctx connection)) + |> map_fetched ?ctx + ;; + + let transaction_iter ?ctx queries = + let open Lwt_result.Syntax in + Caqti_lwt_unix.Pool.use (fun connection -> + let (module Connection : Caqti_lwt.CONNECTION) = connection in + let* () = Connection.start () in + Lwt.catch + (fun () -> + let* () = exec_each connection queries in + Connection.commit ()) + (rollback ?ctx connection)) + |> map_fetched ?ctx + ;; end diff --git a/backend/database_pools.mli b/backend/database_pools.mli index 7ae5c47..6306a81 100644 --- a/backend/database_pools.mli +++ b/backend/database_pools.mli @@ -1,21 +1,9 @@ exception Exception of string -val get_or_raise - : ?ctx:(string * string) list - -> ?tags:Logs.Tag.set - -> ?log_level:Logs.level - -> ?msg_prefix:string - -> unit - -> ('a, [< Caqti_error.t ]) result - -> 'a - -type connection_type = - | SinglePool of string - | MultiPools of (string * string) list - module type ConfigSig = sig - val database : connection_type + val database : string * string val database_pool_size : int + val expected_databases : int end module DefaultConfig : ConfigSig @@ -25,19 +13,9 @@ module type Sig = Database_pools_sig.Sig module Make : functor (Config : ConfigSig) -> sig include Sig - val transaction' - : ?ctx:(string * string) list - -> (Caqti_lwt.connection -> ('a, [< Caqti_error.t ]) result) - -> 'a Lwt.t - val query : ?ctx:(string * string) list - -> (Caqti_lwt.connection -> 'a Lwt.t) - -> 'a Lwt.t - - val query' - : ?ctx:(string * string) list - -> (Caqti_lwt.connection -> ('a, [< Caqti_error.t ]) result Lwt.t) + -> (Caqti_lwt.connection -> ('a, Caqti_error.t) result Lwt.t) -> 'a Lwt.t end [@@warning "-67"] diff --git a/backend/database_pools_sig.ml b/backend/database_pools_sig.ml index ae22f9f..0c64d50 100644 --- a/backend/database_pools_sig.ml +++ b/backend/database_pools_sig.ml @@ -1,13 +1,16 @@ module type Sig = sig - val initialize : unit -> unit + val initialize : ?additinal_pools:(string * string) list -> unit -> unit val fetch_pool : ?ctx:(string * string) list + -> ?retries:int -> unit - -> (Caqti_lwt.connection, Caqti_error.t) Caqti_lwt_unix.Pool.t + -> (Caqti_lwt.connection, Caqti_error.t) Caqti_lwt_unix.Pool.t Lwt.t - val add_pool : ?pool_size:int -> string -> string -> unit + val add_pool : ?required:bool -> string -> string -> unit val drop_pool : string -> unit Lwt.t + val connect : string -> (unit, string) result + val disconnect : ?error:Caqti_error.t -> string -> unit Lwt.t val find : ?ctx:(string * string) list @@ -33,17 +36,6 @@ module type Sig = sig -> 'a -> unit Lwt.t - val transaction - : ?ctx:(string * string) list - -> (Caqti_lwt.connection -> 'a) - -> 'a Lwt.t - - val exec_with_connection - : ('a, unit, [< `Zero ]) Caqti_request.t - -> 'a - -> (module Caqti_lwt.CONNECTION) - -> unit Lwt.t - val populate : ?ctx:(string * string) list -> string @@ -51,4 +43,17 @@ module type Sig = sig -> 'a Caqti_type.t -> 'a list -> unit Lwt.t + + val transaction + : ?ctx:(string * string) list + -> ?setup:(Caqti_lwt.connection -> (unit, Caqti_error.t) Lwt_result.t) list + -> ?cleanup: + (Caqti_lwt.connection -> (unit, Caqti_error.t) Lwt_result.t) list + -> (Caqti_lwt.connection -> ('a, Caqti_error.t) Lwt_result.t) + -> 'a Lwt.t + + val transaction_iter + : ?ctx:(string * string) list + -> (Caqti_lwt.connection -> (unit, Caqti_error.t) result Lwt.t) list + -> unit Lwt.t end diff --git a/backend/dune b/backend/dune index 80f25ea..71aa5e4 100644 --- a/backend/dune +++ b/backend/dune @@ -2,6 +2,7 @@ (public_name guardian.backend) (name Guardian_backend) (libraries + base caqti caqti-driver-mariadb caqti-lwt @@ -20,4 +21,5 @@ ppx_deriving.make ppx_deriving.ord ppx_deriving.show + ppx_fields_conv ppx_string))) diff --git a/backend/mariadb_backend.ml b/backend/mariadb_backend.ml index 6292ac8..40cb99f 100644 --- a/backend/mariadb_backend.ml +++ b/backend/mariadb_backend.ml @@ -981,14 +981,15 @@ struct ;; let delete ?ctx ?comment role = - let exec = Database.exec_with_connection in - (fun conn -> - let%lwt () = - exec delete_add_history_request (role, comment) conn - in - exec delete_remove_request role conn) - |> Database.transaction ?ctx - |> combine_lwt + let with_connection request input connection = + let (module Connection : Caqti_lwt.CONNECTION) = connection in + Connection.exec request input + in + Database.transaction_iter + ?ctx + [ with_connection delete_add_history_request (role, comment) + ; with_connection delete_remove_request role + ] ;; end diff --git a/test/main.ml b/test/main.ml index c91e87e..fa38125 100644 --- a/test/main.ml +++ b/test/main.ml @@ -821,20 +821,20 @@ let () = include DefaultConfig let database = - MultiPools - [ ( test_database - , Sys.getenv_opt "DATABASE_URL" - |> CCOption.get_or ~default:"mariadb://root@database:3306/test" ) - ] + ( test_database + , Sys.getenv_opt "DATABASE_URL" + |> CCOption.get_or ~default:"mariadb://root@database:3306/test" ) ;; end in + let module Database = Make (MariaConfig) in let module Maria = Guardian_backend.MariaDb.Make (Role.Actor) (Role.Role) (Role.Target) - (Make (MariaConfig)) + (Database) in Lwt_main.run @@ + let () = Database.initialize () in let%lwt () = Maria.delete ~ctx () in let%lwt () = Maria.migrate ~ctx () in let%lwt () = Maria.clean ~ctx () in