diff --git a/.zetch.lock b/.zetch.lock index cefb38af..2eb69c2f 100644 --- a/.zetch.lock +++ b/.zetch.lock @@ -18,6 +18,7 @@ "py_rust/LICENSE.zetch.md": "d2c12e539d357957b950a54a5477c3a9f87bd2b3ee707be7a4db7adaf5aacc2b", "py_rust/README.zetch.md": "8a874afc6629819156bb8871a329b9f2b670d6610261d60e5de994d13e9cc30f", "rust/LICENSE.zetch.md": "d2c12e539d357957b950a54a5477c3a9f87bd2b3ee707be7a4db7adaf5aacc2b", - "rust/README.zetch.md": "8a874afc6629819156bb8871a329b9f2b670d6610261d60e5de994d13e9cc30f" + "rust/README.zetch.md": "8a874afc6629819156bb8871a329b9f2b670d6610261d60e5de994d13e9cc30f", + "rust/pkg/LICENSE.zetch.md": "d2c12e539d357957b950a54a5477c3a9f87bd2b3ee707be7a4db7adaf5aacc2b" } } \ No newline at end of file diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 7c367d41..6f8bfc8d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -17,6 +17,42 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm-siv" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae0784134ba9375416d469ec31e7c5f9fa94405049cf08c5ce5b4698be673e0d" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "polyval", + "subtle", + "zeroize", +] + [[package]] name = "ahash" version = "0.8.11" @@ -71,6 +107,24 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "async-recursion" version = "1.1.1" @@ -293,6 +347,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -306,8 +366,12 @@ dependencies = [ name = "bitbazaar" version = "0.1.3" dependencies = [ + "aes-gcm-siv", + "arc-swap", + "argon2", "async-semaphore", "axum-extra", + "bincode", "chrono", "chrono-humanize", "colored", @@ -321,6 +385,7 @@ dependencies = [ "homedir", "hostname", "http 1.1.0", + "indexmap 2.2.6", "itertools 0.12.1", "leptos", "leptos_axum", @@ -332,6 +397,7 @@ dependencies = [ "opentelemetry-semantic-conventions", "opentelemetry_sdk", "parking_lot", + "paste", "portpicker", "rand", "rayon", @@ -372,6 +438,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = 
"0.10.4" @@ -500,6 +575,16 @@ dependencies = [ "half 2.4.1", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "2.34.0" @@ -748,6 +833,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] @@ -772,6 +858,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.14.4" @@ -878,6 +973,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -1726,6 +1822,15 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.13" @@ -2289,6 +2394,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opentelemetry" version = "0.21.0" @@ -2444,6 +2555,17 @@ dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -2532,6 +2654,18 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portpicker" version = "0.1.1" @@ -3278,6 +3412,12 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "subtle" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" + [[package]] name = "syn" version = "1.0.109" @@ -3801,6 +3941,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "url" version = "2.5.1" @@ -4347,6 +4497,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + [[package]] name = "zerovec" version = "0.10.2" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index ca5a5500..805fc2da 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -28,16 +28,20 @@ colored = '2' once_cell = '1' tracing-subscriber = { version = "0.3", features = ["fmt", "std", "time"] } serde = { version = "1", 
features = ["derive", "rc"] } +serde_json = { version = "1" } time = { version = "0.3", features = ["local-offset"] } futures = { version = "0.3", features = [] } async-semaphore = "1.2" gloo-timers = { version = "0.3", features = ["futures"] } itertools = "0.12" tracing-core = "0.1" +chrono = { version = '0.4', features = ["serde"] } +chrono-humanize = { version = "0.2" } +arc-swap = "1" +paste = "1" # Not in default, but randomly useful in features: strum = { version = "0.25", features = ["derive"], optional = true } -serde_json = { version = "1.0", optional = true } rand = { version = "0.8", optional = true } uuid = { version = "1.6", features = ["v4"], optional = true } axum-extra = { version = "0.9", features = [], optional = true } @@ -46,12 +50,14 @@ leptos_axum = { version = "0.6", optional = true } http = { version = "1", optional = true } portpicker = { version = '0.1', optional = true } -# FEAT: hash: -sha2 = { version = "0.10", optional = true } +# FEAT: indexmap: +indexmap = { version = "2", optional = true } -# FEAT: chrono: (but also sometimes enabled by other features) -chrono = { version = '0.4', optional = true } -chrono-humanize = { version = "0.2", optional = true } +# FEAT: crypto: +sha2 = { version = "0.10", optional = true } +aes-gcm-siv = { version = "0.11", optional = true } +argon2 = { version = "0.5", optional = true } +bincode = { version = "1", optional = true } # FEAT: log-filter: regex = { version = '1', optional = true } @@ -115,20 +121,18 @@ tokio = { version = '1', features = ["full"] } # harness = false [features] +indexmap = ["dep:indexmap"] log-filter = ["dep:regex"] -hash = ['dep:sha2'] -chrono = ['dep:chrono', 'dep:chrono-humanize'] -timing = ['dep:comfy-table', 'chrono'] -cli = ['dep:normpath', 'dep:conch-parser', 'dep:homedir', 'chrono', 'dep:strum'] +crypto = ['dep:sha2', 'dep:aes-gcm-siv', 'dep:argon2', 'dep:bincode'] +timing = ['dep:comfy-table'] +cli = ['dep:normpath', 'dep:conch-parser', 'dep:homedir', 'dep:strum'] 
system = ['dep:sysinfo'] redis = [ 'dep:deadpool-redis', 'dep:redis', "dep:redis-macros", 'dep:sha1_smol', - 'dep:serde_json', 'dep:rand', - 'chrono', 'dep:uuid', 'dep:portpicker', ] @@ -159,16 +163,8 @@ opentelemetry-http = [ # In general there's no point with this currently, made f rayon = ['dep:rayon'] # Cookie deps depending on wasm or not: -cookies_ssr = [ - 'chrono', - 'dep:http', - 'dep:serde_json', - 'dep:axum-extra', - 'axum-extra/cookie', - 'dep:leptos', - 'dep:leptos_axum', -] -cookies_wasm = ['chrono', 'dep:http', 'dep:serde_json', 'dep:wasm-cookies'] +cookies_ssr = ['dep:http', 'dep:axum-extra', 'axum-extra/cookie', 'dep:leptos', 'dep:leptos_axum'] +cookies_wasm = ['dep:http', 'dep:wasm-cookies'] [profile.release] strip = "debuginfo" # Note: true or "symbols" seems to break static c linking e.g. with ffmpeg. diff --git a/rust/bitbazaar/crypto/encrypt/encrypt.rs b/rust/bitbazaar/crypto/encrypt/encrypt.rs new file mode 100644 index 00000000..3d3a9a6d --- /dev/null +++ b/rust/bitbazaar/crypto/encrypt/encrypt.rs @@ -0,0 +1,135 @@ +use aes_gcm_siv::{ + aead::{generic_array::GenericArray, rand_core::RngCore, Aead, OsRng}, + Aes256GcmSiv, KeyInit, Nonce, +}; + +use crate::prelude::*; + +#[derive(serde::Serialize, serde::Deserialize)] +struct PrecryptorFile { + data: Vec, + nonce: [u8; 12], + salt: [u8; 32], +} + +/// Encrypts some data using a password, internally also a nonce and salt. +/// Uses a secure AES256-GCM-SIV algorithm (very safe in 2024). 
+/// +/// # Examples +/// +/// ```no_run +/// use bitbazaar::crypto::encrypt; +/// +/// let encrypted_data = encrypt::encrypt_aes256(b"example text", b"example password").expect("Failed to encrypt"); +/// // and now you can write it to a file: +/// // fs::write("encrypted_text.txt", encrypted_data).expect("Failed to write to file"); +/// ``` +/// +pub fn encrypt_aes256(data: &[u8], password: &[u8]) -> RResult<Vec<u8>, AnyErr> { + // Generating salt: + let mut salt = [0u8; 32]; + OsRng.fill_bytes(&mut salt); + + // Generating key: + // https://docs.rs/argon2/0.5.3/argon2/#key-derivation + let mut output_key_material = [0u8; 32]; // Can be any desired size + argon2::Argon2::default() + .hash_password_into(password, &salt, &mut output_key_material) + .map_err(|e| anyerr!("Failed to hash password: {:?}", e))?; + + let key = GenericArray::from_slice(&output_key_material); + let cipher = Aes256GcmSiv::new(key); + + // Generating nonce: + let mut nonce_rand = [0u8; 12]; + OsRng.fill_bytes(&mut nonce_rand); + let nonce = Nonce::from_slice(&nonce_rand); + + // Encrypting: + let ciphertext = match cipher.encrypt(nonce, data.as_ref()) { + Ok(ciphertext) => ciphertext, + Err(_) => return Err(anyerr!("Failed to encrypt data")), + }; + + let file = PrecryptorFile { + data: ciphertext, + nonce: nonce_rand, + salt, + }; + + // Encoding: + let encoded: Vec<u8> = bincode::serialize(&file).change_context(AnyErr)?; + + Ok(encoded) +} + +/// Decrypts some data and returns the result +/// +/// # Examples +/// +/// ```no_run +/// use bitbazaar::crypto::encrypt; +/// +/// let encrypted_data = encrypt::encrypt_aes256(b"example text", b"example password").expect("Failed to encrypt"); +/// +/// let data = encrypt::decrypt_aes256(&encrypted_data, b"example password").expect("Failed to decrypt"); +/// // and now you can print it to stdout: +/// // println!("data: {}", String::from_utf8(data.clone()).expect("Data is not a utf8 string")); +/// // or you can write it to a file: +/// //
fs::write("text.txt", data).expect("Failed to write to file"); +/// ``` +/// +pub fn decrypt_aes256(data: &[u8], password: &[u8]) -> RResult<Vec<u8>, AnyErr> { + // Decoding: + let decoded: PrecryptorFile = bincode::deserialize(data).change_context(AnyErr)?; + + // Generating key: + // https://docs.rs/argon2/0.5.3/argon2/#key-derivation + let mut output_key_material = [0u8; 32]; // Can be any desired size + argon2::Argon2::default() + .hash_password_into(password, &decoded.salt, &mut output_key_material) + .map_err(|e| anyerr!("Failed to hash password: {:?}", e))?; + + let key = GenericArray::from_slice(&output_key_material); + let cipher = Aes256GcmSiv::new(key); + let nonce = Nonce::from_slice(&decoded.nonce); + + // Decrypting: + let text = match cipher.decrypt(nonce, decoded.data.as_ref()) { + Ok(ciphertext) => ciphertext, + Err(_) => return Err(anyerr!("Failed to decrypt data -> invalid password")), + }; + + Ok(text) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encrypt_decrypt_aes256() -> RResult<(), AnyErr> { + let data = b"example text"; + let password_1 = b"example password"; + let password_2 = b"example password 2"; + + // 1. Simple encrypt/decrypt works. + let encrypted_data = encrypt_aes256(data, password_1)?; + let decrypted_data = decrypt_aes256(&encrypted_data, password_1)?; + assert_eq!(data, decrypted_data.as_slice()); + + // 2. Wrong password fails to decrypt. + let decrypted_data = decrypt_aes256(&encrypted_data, password_2); + assert!(decrypted_data.is_err()); + + // 3. Different passwords using same content leads to different encrypted data.
+ let encrypted_data_2 = encrypt_aes256(data, password_2)?; + assert_ne!(encrypted_data, encrypted_data_2); + assert_eq!( + data, + decrypt_aes256(&encrypted_data_2, password_2)?.as_slice() + ); + + Ok(()) + } +} diff --git a/rust/bitbazaar/crypto/encrypt/mod.rs b/rust/bitbazaar/crypto/encrypt/mod.rs new file mode 100644 index 00000000..1d815f6c --- /dev/null +++ b/rust/bitbazaar/crypto/encrypt/mod.rs @@ -0,0 +1,3 @@ +mod encrypt; + +pub use encrypt::*; diff --git a/rust/bitbazaar/hash/fnv1a.rs b/rust/bitbazaar/crypto/hash/fnv1a.rs similarity index 100% rename from rust/bitbazaar/hash/fnv1a.rs rename to rust/bitbazaar/crypto/hash/fnv1a.rs diff --git a/rust/bitbazaar/hash/mod.rs b/rust/bitbazaar/crypto/hash/mod.rs similarity index 89% rename from rust/bitbazaar/hash/mod.rs rename to rust/bitbazaar/crypto/hash/mod.rs index 6b3ca4b1..d6c685fd 100644 --- a/rust/bitbazaar/hash/mod.rs +++ b/rust/bitbazaar/crypto/hash/mod.rs @@ -1,6 +1,8 @@ mod fnv1a; +mod password; -pub use fnv1a::fnv1a; +pub use fnv1a::*; +pub use password::*; /// SHA256 hash function. /// diff --git a/rust/bitbazaar/crypto/hash/password.rs b/rust/bitbazaar/crypto/hash/password.rs new file mode 100644 index 00000000..a2e90ee2 --- /dev/null +++ b/rust/bitbazaar/crypto/hash/password.rs @@ -0,0 +1,76 @@ +use argon2::password_hash::{PasswordHasher, PasswordVerifier}; + +use crate::prelude::*; + +/// Hash a password to a "PHC string" for intermediary password storage. +/// +/// Uses argon2, should be very secure. +pub fn password_hash_argon2id_19(password: &str) -> RResult { + // https://docs.rs/argon2/0.5.3/argon2/#password-hashing + let salt = + argon2::password_hash::SaltString::generate(&mut argon2::password_hash::rand_core::OsRng); + let argon2 = argon2id_19_config(); + Ok(argon2 + .hash_password(password.as_bytes(), &salt) + .map_err(|e| anyerr!("Failed to hash password: {:?}", e))? + .to_string()) +} + +/// Verify an entered password matches a stored "PHC string" password hash. 
+/// +/// Uses argon2, should be very secure. +pub fn password_verify_argon2id_19( + entered_pswd: &str, + real_pswd_hash: &str, +) -> RResult { + // https://docs.rs/argon2/0.5.3/argon2/#password-hashing + let parsed_hash = argon2::password_hash::PasswordHash::new(real_pswd_hash) + .map_err(|e| anyerr!("Failed to parse password hash: {:?}", e))?; + match argon2id_19_config().verify_password(entered_pswd.as_bytes(), &parsed_hash) { + Ok(_noop) => Ok(true), + Err(e) => match e { + argon2::password_hash::Error::Password => Ok(false), + _ => Err(anyerr!("Failed to verify password: {:?}", e)), + }, + } +} + +fn argon2id_19_config() -> argon2::Argon2<'static> { + // These are the defaults currently, just future proofing if they change. + argon2::Argon2::new( + argon2::Algorithm::Argon2id, + // This is v19 in hex. + argon2::Version::V0x13, + // Keeping default as don't think this will break hashes. + argon2::Params::default(), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_password_hash_argon2id_19() -> RResult<(), AnyErr> { + let pswd = "password"; + let hash = password_hash_argon2id_19(pswd)?; + assert_ne!(pswd, &hash); + assert!(password_verify_argon2id_19(pswd, &hash)?); + assert!(!password_verify_argon2id_19("wrong", &hash)?); + + let pswd_2 = "password2"; + let hash_2 = password_hash_argon2id_19(pswd_2)?; + assert_ne!(pswd, &hash_2); + assert_ne!(hash, hash_2); + assert!(password_verify_argon2id_19(pswd_2, &hash_2)?); + assert!(!password_verify_argon2id_19("wrong", &hash_2)?); + + // Hashing the same password twice should produce 2 different hashes because of the in-built salt: + let hash_3 = password_hash_argon2id_19(pswd)?; + assert_ne!(hash, hash_3); + assert!(password_verify_argon2id_19(pswd, &hash)?); + assert!(password_verify_argon2id_19(pswd, &hash_3)?); + + Ok(()) + } +} diff --git a/rust/bitbazaar/crypto/mod.rs b/rust/bitbazaar/crypto/mod.rs new file mode 100644 index 00000000..28af4bea --- /dev/null +++ 
b/rust/bitbazaar/crypto/mod.rs @@ -0,0 +1,5 @@ +/// Encryption utilities. +pub mod encrypt; + +/// Hashing utilities. +pub mod hash; diff --git a/rust/bitbazaar/lib.rs b/rust/bitbazaar/lib.rs index 68102245..17fa1466 100644 --- a/rust/bitbazaar/lib.rs +++ b/rust/bitbazaar/lib.rs @@ -11,17 +11,16 @@ mod prelude; /// Command line interface utilities. pub mod cli; -#[cfg(feature = "chrono")] /// Chrono utilities pub mod chrono; #[cfg(any(feature = "cookies_ssr", feature = "cookies_wasm"))] /// Setting/getting cookies in wasm or ssr. pub mod cookies; +#[cfg(feature = "crypto")] +/// Hashing & encryption utilities. +pub mod crypto; /// Error handling utilities. pub mod errors; -#[cfg(feature = "hash")] -/// Hashing utilities. -pub mod hash; /// Logging utilities pub mod log; /// Completely miscellaneous utilities diff --git a/rust/bitbazaar/misc/binary_search.rs b/rust/bitbazaar/misc/binary_search.rs index 219333db..ec947cb4 100644 --- a/rust/bitbazaar/misc/binary_search.rs +++ b/rust/bitbazaar/misc/binary_search.rs @@ -1,24 +1,127 @@ use std::cmp::Ordering; +/// A trait allowing arbitrary types to be binary searchable. +pub trait BinarySearchable { + /// The type of the items in the array. + type Item; + + /// Get the length of the array. + fn len(&self) -> usize; + + /// Get the item at the given index. + fn get(&self, index: usize) -> &Self::Item; + + /// Check if the array is empty. 
+ fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl<T> BinarySearchable for Vec<T> { + type Item = T; + + fn len(&self) -> usize { + self.len() + } + + fn get(&self, index: usize) -> &T { + &self[index] + } +} + +impl<T, const N: usize> BinarySearchable for [T; N] { + type Item = T; + + fn len(&self) -> usize { + N + } + + fn get(&self, index: usize) -> &T { + &self[index] + } +} + +impl<T> BinarySearchable for [T] { + type Item = T; + + fn len(&self) -> usize { + self.len() + } + + fn get(&self, index: usize) -> &T { + &self[index] + } +} + +#[cfg(feature = "indexmap")] +impl<K, V> BinarySearchable for indexmap::IndexMap<K, V> { + type Item = V; + + fn len(&self) -> usize { + self.len() + } + + fn get(&self, index: usize) -> &V { + &self[index] + } +} + /// Binary search for the desired item in the array. /// Returns None when not exactly found. /// /// Arguments: -/// - arr: The array to search. -/// - comparer: Fn(&Item) that should return `$target.cmp($item_val)`. (i.e. "where to go from here") +/// - arr_like: The array-like structure to search. +/// - comparer: `Fn(&Item)` that should return `$target.cmp($item_val)`. (i.e. "where to go from here") /// /// Returns: /// - Some(index) the index of the item in the array. /// - None if the item is not found. pub fn binary_search_exact<Item>( - arr: &[Item], + arr_like: &impl BinarySearchable<Item = Item>, comparer: impl Fn(&Item) -> Ordering, ) -> Option<usize> { let mut low = 0; - let mut high = arr.len() - 1; + let mut high = arr_like.len() - 1; while low <= high { let mid = (low + high) / 2; - match comparer(&arr[mid]) { + match comparer(arr_like.get(mid)) { + Ordering::Greater => low = mid + 1, + Ordering::Less => high = mid - 1, + Ordering::Equal => return Some(mid), + } + } + None +} + +/// Binary search for the desired item in the array, but also passes the prior and next items into the comparison function, if they exist. +/// Returns None when not exactly found. +/// +/// Arguments: +/// - arr_like: The array-like structure to search.
+/// - comparer: `Fn(Option<&Item>, &Item, Option<&Item>)` (prior, target, next) that should return `$target.cmp($item_val)`. (i.e. "where to go from here") +/// +/// Returns: +/// - Some(index) the index of the item in the array. +/// - None if the item is not found. +pub fn binary_search_exact_with_siblings( + arr_like: &impl BinarySearchable, + comparer: impl Fn(Option<&Item>, &Item, Option<&Item>) -> Ordering, +) -> Option { + let mut low = 0; + let mut high = arr_like.len() - 1; + while low <= high { + let mid = (low + high) / 2; + let prior = if mid == 0 { + None + } else { + Some(arr_like.get(mid - 1)) + }; + let next = if mid == arr_like.len() - 1 { + None + } else { + Some(arr_like.get(mid + 1)) + }; + match comparer(prior, arr_like.get(mid), next) { Ordering::Greater => low = mid + 1, Ordering::Less => high = mid - 1, Ordering::Equal => return Some(mid), @@ -31,25 +134,25 @@ pub fn binary_search_exact( /// Only returns None when the array is empty. /// /// Arguments: -/// - arr: The array to search. -/// - comparer: Fn(&Item) that should return `$target.cmp($item_val)`. (i.e. "where to go from here") +/// - arr_like: The array-like structure to search. +/// - comparer: `Fn(&Item)` that should return `$target.cmp($item_val)`. (i.e. "where to go from here") /// /// Returns: /// - Some(index) the index of the item in the array, OR THE CLOSEST. /// - None if the item is not found. 
pub fn binary_search_soft( - arr: &[Item], + arr_like: &impl BinarySearchable, comparer: impl Fn(&Item) -> Ordering, ) -> Option { - if arr.is_empty() { + if arr_like.is_empty() { None } else { let mut low = 0; - let mut high = arr.len() - 1; + let mut high = arr_like.len() - 1; let mut mid = 0; while low <= high { mid = (low + high) / 2; - match comparer(&arr[mid]) { + match comparer(arr_like.get(mid)) { Ordering::Greater => low = mid + 1, Ordering::Less => high = mid - 1, Ordering::Equal => return Some(mid), @@ -72,6 +175,32 @@ mod tests { assert_eq!(binary_search_exact(&arr, |x| 10.cmp(x)), Some(8)); assert_eq!(binary_search_exact(&arr, |x| 7.cmp(x)), None); + // With siblings just check they're being passed around correctly: + assert_eq!( + binary_search_exact_with_siblings(&arr, |prior, x, next| { + if *x == 1 { + assert_eq!(prior, None); + assert_eq!(next, Some(&2)); + } else if *x == 10 { + assert_eq!(prior, Some(&9)); + assert_eq!(next, None); + } else if *x == 9 { + assert_eq!(prior, Some(&8)); + assert_eq!(next, Some(&10)); + } else if *x == 8 { + assert_eq!(prior, Some(&6)); + assert_eq!(next, Some(&9)); + } else { + // These are now all prior the gap so generic mapping: + assert_eq!(prior, Some(&arr[x - 2])); + assert_eq!(next, Some(&arr[*x])); + } + 5.cmp(x) + }), + Some(4) + ); + + // Nonexact: assert_eq!(binary_search_soft(&arr, |x| 5.cmp(x)), Some(4)); assert_eq!(binary_search_soft(&arr, |x| 1.cmp(x)), Some(0)); assert_eq!(binary_search_soft(&arr, |x| 10.cmp(x)), Some(8)); diff --git a/rust/bitbazaar/misc/mod.rs b/rust/bitbazaar/misc/mod.rs index fd49f8b7..f15bb677 100644 --- a/rust/bitbazaar/misc/mod.rs +++ b/rust/bitbazaar/misc/mod.rs @@ -6,7 +6,10 @@ mod flexi_logger; mod in_ci; mod is_tcp_port_listening; mod periodic_updater; +#[cfg(feature = "redis")] +mod refreshable; mod retry_backoff; +mod serde_migratable; mod sleep_compat; pub use binary_search::*; @@ -14,5 +17,8 @@ pub use flexi_logger::*; pub use in_ci::in_ci; pub use 
is_tcp_port_listening::is_tcp_port_listening; pub use periodic_updater::*; +#[cfg(feature = "redis")] +pub use refreshable::*; pub use retry_backoff::*; +pub use serde_migratable::*; pub use sleep_compat::*; diff --git a/rust/bitbazaar/misc/refreshable.rs b/rust/bitbazaar/misc/refreshable.rs new file mode 100644 index 00000000..4af497fe --- /dev/null +++ b/rust/bitbazaar/misc/refreshable.rs @@ -0,0 +1,219 @@ +use std::sync::{atomic::AtomicU64, Arc, OnceLock}; + +use arc_swap::ArcSwap; +use futures::{future::BoxFuture, Future, FutureExt}; + +pub use arc_swap::Guard as RefreshableGuard; + +use crate::{ + prelude::*, + redis::{Redis, RedisBatchFire, RedisBatchReturningOps, RedisConn}, +}; + +/// A data wrapper that automatically updates the data given out when deemed stale. +/// The data is set to refresh at a certain interval (triggered on access), or can be forcefully refreshed. +pub struct Refreshable { + redis: Redis, + redis_namespace: String, + redis_key: String, + redis_mutate_key: String, + mutate_id: AtomicU64, + // Don't want to hold lock when giving out data, so opposite to normal pattern: + data: OnceLock>, + // To prevent stopping Send/Sync working for this struct, these need to be both: + getter: Arc BoxFuture<'static, RResult> + Sync + Send>>, + setter: Arc BoxFuture<'static, RResult<(), AnyErr>> + Sync + Send>>, + on_mutate: Option RResult<(), AnyErr> + 'static + Send + Sync>>>, + last_updated_utc_ms: AtomicU64, + force_refresh_every_ms: u64, +} + +impl Refreshable { + /// Creates a new refreshable data wrapper. + /// This will only call the getter on first access, not on instanstiation. + /// + /// Arguments: + /// - `redis`: The redis wrapper itself needed for dlocking. + /// - `redis_namespace`: The namespace to use for the redis key when locking during setting. + /// - `redis_key`: The key to use for the redis key when locking during setting. + /// - `force_refresh_every`: The interval for forceful data should be refreshed. 
For when something other than a `Refreshable` container updates, but still good as a backup. + /// - `getter`: A function that returns a future that resolves to the data. + /// - `setter`: A function that updates the source with new data. + pub fn new< + FutGet: Future> + Send + 'static, + FutSet: Future> + Send + 'static, + >( + redis: &Redis, + redis_namespace: impl Into, + redis_key: impl Into, + force_refresh_every: std::time::Duration, + getter: impl Fn() -> FutGet + 'static + Sync + Send, + setter: impl Fn(T) -> FutSet + 'static + Sync + Send, + ) -> RResult { + let redis_key = redis_key.into(); + Ok(Self { + redis: redis.clone(), + redis_namespace: redis_namespace.into(), + redis_mutate_key: format!("{}_mutater", redis_key), + redis_key, + mutate_id: AtomicU64::new(0), + data: OnceLock::new(), + getter: Arc::new(Box::new(move || getter().boxed())), + setter: Arc::new(Box::new(move |data| setter(data).boxed())), + on_mutate: None, + last_updated_utc_ms: AtomicU64::new(utc_now_ms()), + force_refresh_every_ms: force_refresh_every.as_millis() as u64, + }) + } + + /// Do something whenever a mutation happens, useful to hook in other mutations while a mut self is available. + pub fn on_mutate( + mut self, + on_mutate: impl Fn(&mut T) -> RResult<(), AnyErr> + 'static + Send + Sync, + ) -> Self { + self.on_mutate = Some(Arc::new(Box::new(on_mutate))); + self + } + + /// Update stored data: + async fn set_data(&self, new_data: T) -> RResult<(), AnyErr> { + self.last_updated_utc_ms + .store(utc_now_ms(), std::sync::atomic::Ordering::Relaxed); + self.data().await?.store(Arc::new(new_data)); + Ok(()) + } + + /// Internal of `sync`, doesn't set the new data so can be used in mutator. 
+ async fn sync_no_set(&self, conn: &mut RedisConn<'_>) -> RResult, AnyErr> { + let mutate_id_changed = { + if let Some(Some(current_mutate_id)) = conn + .batch() + .get::(&self.redis_namespace, &self.redis_mutate_key) + .fire() + .await + { + // Check if different, simultaneously setting the new value: + current_mutate_id + != self + .mutate_id + .swap(current_mutate_id, std::sync::atomic::Ordering::Relaxed) + } else { + false + } + }; + if mutate_id_changed + || utc_now_ms() + - self + .last_updated_utc_ms + .load(std::sync::atomic::Ordering::Relaxed) + > self.force_refresh_every_ms + { + Ok(Some((self.getter)().await?)) + } else { + Ok(None) + } + } + + /// Resyncs the data with the source, when it becomes stale (hard refresh), or redis mutate id changes, meaning a different node has updated the data. + async fn sync(&self, conn: &mut RedisConn<'_>) -> RResult<(), AnyErr> { + if let Some(new_data) = self.sync_no_set(conn).await? { + self.set_data(new_data).await?; + } + Ok(()) + } + + /// Access the currently stored data, initializing the `OnceLock` if empty. + async fn data(&self) -> RResult<&ArcSwap, AnyErr> { + if let Some(val) = self.data.get() { + Ok(val) + } else { + let new_data = (self.getter)().await?; + let _ = self.data.set(ArcSwap::from(Arc::new(new_data))); + Ok(self + .data + .get() + .ok_or_else(|| anyerr!("Failed to set data"))?) + } + } + + /// Update the data in the refreshable with key features: + /// - Locks the source using a redis dlock for the duration of the update. + /// - Refreshes the data before the update inside the locked section, + /// to make sure you're doing the update on the latest data and not overwriting changes. + /// - Updates the setter, thanks to above guaranteed no sibling node overwrites etc. + /// + /// NOTE: returns a double result to allow custom internal error types to be passed out. 
+ pub async fn mutate( + &self, + conn: &mut RedisConn<'_>, + mutator: impl FnOnce(&mut T) -> Result<(), E>, + ) -> RResult, AnyErr> { + self.redis + .dlock_for_fut( + &self.redis_namespace, + &self.redis_key, + // Really don't want to miss updates: + Some(std::time::Duration::from_secs(30)), + async move { + // Make sure working with up-to-date data: + let mut data = if let Some(data) = self.sync_no_set(conn).await? { + data + } else { + (**self.data().await?.load()).clone() + }; + // Mutate the up-to-date data: + match mutator(&mut data) { + Ok(_) => { + // Run the on_mutate hook if there: + if let Some(on_mutate) = &self.on_mutate { + on_mutate(&mut data)?; + } + + // Update the source with the new data: + (self.setter)(data.clone()).await?; + // Update the mutate id to signal to other nodes that the data has changed: + let new_mutate_id = rand::random(); + conn.batch() + .set( + &self.redis_namespace, + &self.redis_mutate_key, + new_mutate_id, + None, + ) + .fire() + .await; + self.mutate_id + .store(new_mutate_id, std::sync::atomic::Ordering::Relaxed); + self.set_data(data).await?; + Ok::<_, Report>(Ok(())) + } + Err(e) => Ok(Err(e)), + } + }, + ) + .await + .change_context(AnyErr) + } + + /// Force a refresh of the data. + pub async fn refresh(&self) -> RResult<(), AnyErr> { + let new_data = (self.getter)().await?; + self.set_data(new_data).await?; + Ok(()) + } + + /// Get the underlying data for use. + /// If the data is stale, it will be refreshed before returning. + /// + /// NOTE: the implementation of the guards means not too many should be alive at once, and keeping across await points should be discouraged. + /// If you need long access to the underlying data, consider cloning it. 
+ pub async fn get(&self, conn: &mut RedisConn<'_>) -> RResult>, AnyErr> { + // Refresh if stale or mutate id in redis changes: + self.sync(conn).await?; + Ok(self.data().await?.load()) + } +} + +fn utc_now_ms() -> u64 { + chrono::Utc::now().timestamp_millis() as u64 +} diff --git a/rust/bitbazaar/misc/serde_migratable.rs b/rust/bitbazaar/misc/serde_migratable.rs new file mode 100644 index 00000000..e51873ab --- /dev/null +++ b/rust/bitbazaar/misc/serde_migratable.rs @@ -0,0 +1,181 @@ +use crate::prelude::*; + +macro_rules! try_fallbacks { + ($e1:expr, $value:expr) => { + match Self::from_last(&$value) { + Ok(deserialized) => Ok(deserialized), + Err(e2) => match Self::from_next(&$value) { + Ok(deserialized) => Ok(deserialized), + Err(e3) => Err(anyerr!( + "Failed to deserialize to target directly or through either migratable type." + ) + .attach_printable(format!("direct error: {:?}", $e1)) + .attach_printable(format!("from_last error: {:?}", e2)) + .attach_printable(format!("from_next error: {:?}", e3))), + }, + } + }; +} + +/// A trait to help with migrating data structures when they change. +pub trait SerdeMigratable: Sized + serde::de::DeserializeOwned { + /// How to convert from the last version to the current version. + fn from_last(last: &serde_json::Value) -> RResult; + + /// Optional, how to convert from the next back to this, can be useful if rolling back. + fn from_next(_next: &serde_json::Value) -> RResult { + Err(anyerr!( + "Not implemented! Need to implement 'from_next' method on the 'SerdeMigratable' trait for rollback to work." + )) + } + + /// Deserialize from a string, trying to convert from legacy types if needed. 
+ fn from_str(src: &str) -> RResult { + match serde_json::from_str::(src).change_context(AnyErr) { + Ok(deserialized) => Ok(deserialized), + Err(e1) => { + let value: serde_json::Value = serde_json::from_str(src).change_context(AnyErr)?; + try_fallbacks!(e1, value) + } + } + } + + /// Deserialize from a slice, trying to convert from legacy types if needed. + fn from_slice(src: &[u8]) -> RResult { + match serde_json::from_slice::(src).change_context(AnyErr) { + Ok(deserialized) => Ok(deserialized), + Err(e1) => { + let value: serde_json::Value = + serde_json::from_slice(src).change_context(AnyErr)?; + try_fallbacks!(e1, value) + } + } + } + + /// Deserialize from a value, trying to convert from legacy types if needed. + fn from_value(src: serde_json::Value) -> RResult { + match serde_json::from_value::(src.clone()).change_context(AnyErr) { + Ok(deserialized) => Ok(deserialized), + Err(e1) => try_fallbacks!(e1, src), + } + } + + /// Deserialize from a reader, trying to convert from legacy types if needed. 
+ fn from_reader(mut src: R) -> RResult { + let mut buffer = Vec::new(); + src.read_to_end(&mut buffer).change_context(AnyErr)?; + match serde_json::from_slice::(&buffer).change_context(AnyErr) { + Ok(deserialized) => Ok(deserialized), + Err(e1) => { + let value: serde_json::Value = + serde_json::from_slice(&buffer).change_context(AnyErr)?; + try_fallbacks!(e1, value) + } + } + } +} + +#[cfg(test)] +mod tests { + use serde::{Deserialize, Serialize}; + + use super::*; + use crate::testing::prelude::*; + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct V1 { + v1: String, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct V2 { + v2: String, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct V3 { + v3: String, + } + + impl SerdeMigratable for V2 { + fn from_last(last: &serde_json::Value) -> RResult { + let v1: V1 = serde_json::from_value(last.clone()).change_context(AnyErr)?; + Ok(Self { v2: v1.v1 }) + } + + fn from_next(next: &serde_json::Value) -> RResult { + let v3: V3 = serde_json::from_value(next.clone()).change_context(AnyErr)?; + Ok(Self { v2: v3.v3 }) + } + } + + fn setup() -> RResult<(V2, String, String, String), AnyErr> { + let v1 = V1 { + v1: "hello".to_string(), + }; + let v1_str = serde_json::to_string(&v1).change_context(AnyErr)?; + let v2 = V2 { + v2: "hello".to_string(), + }; + let v2_str = serde_json::to_string(&v2).change_context(AnyErr)?; + let v3 = V3 { + v3: "hello".to_string(), + }; + let v3_str = serde_json::to_string(&v3).change_context(AnyErr)?; + Ok((v2, v1_str, v2_str, v3_str)) + } + + #[rstest] + fn test_serde_migratable_from_str() -> RResult<(), AnyErr> { + let (v2, v1_str, v2_str, v3_str) = setup()?; + assert_eq!(v2, V2::from_str(&v2_str)?); + assert_eq!(v2, V2::from_str(&v1_str)?); + assert_eq!(v2, V2::from_str(&v3_str)?); + Ok(()) + } + + #[rstest] + fn test_serde_migratable_from_slice() -> RResult<(), AnyErr> { + let (v2, v1_str, v2_str, v3_str) = setup()?; + assert_eq!(v2, 
V2::from_slice(v2_str.as_bytes())?); + assert_eq!(v2, V2::from_slice(v1_str.as_bytes())?); + assert_eq!(v2, V2::from_slice(v3_str.as_bytes())?); + Ok(()) + } + + #[rstest] + fn test_serde_migratable_from_value() -> RResult<(), AnyErr> { + let (v2, v1_str, v2_str, v3_str) = setup()?; + assert_eq!( + v2, + V2::from_value(serde_json::from_str(&v2_str).change_context(AnyErr)?)? + ); + assert_eq!( + v2, + V2::from_value(serde_json::from_str(&v1_str).change_context(AnyErr)?)? + ); + assert_eq!( + v2, + V2::from_value(serde_json::from_str(&v3_str).change_context(AnyErr)?)? + ); + Ok(()) + } + + #[rstest] + fn test_serde_migratable_from_reader() -> RResult<(), AnyErr> { + let (v2, v1_str, v2_str, v3_str) = setup()?; + assert_eq!( + v2, + V2::from_reader(std::io::Cursor::new(v2_str.as_bytes()))? + ); + assert_eq!( + v2, + V2::from_reader(std::io::Cursor::new(v1_str.as_bytes()))? + ); + assert_eq!( + v2, + V2::from_reader(std::io::Cursor::new(v3_str.as_bytes()))? + ); + Ok(()) + } +} diff --git a/rust/bitbazaar/redis/batch.rs b/rust/bitbazaar/redis/batch.rs index aaee0290..c4888176 100644 --- a/rust/bitbazaar/redis/batch.rs +++ b/rust/bitbazaar/redis/batch.rs @@ -106,11 +106,11 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { /// Expire an existing key with a new/updated ttl. /// /// https://redis.io/commands/pexpire/ - pub fn expire(mut self, namespace: &str, key: &str, ttl: std::time::Duration) -> Self { + pub fn expire(mut self, namespace: &str, key: &str, ttl: chrono::Duration) -> Self { self.pipe .pexpire( self.redis_conn.final_key(namespace, key.into()), - ttl.as_millis() as i64, + ttl.num_milliseconds(), ) // Ignoring so it doesn't take up a space in the tuple response. 
.ignore(); @@ -136,7 +136,7 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { mut self, set_namespace: &str, set_key: &str, - set_ttl: Option, + set_ttl: Option, score: i64, value: impl ToRedisArgs, ) -> Self { @@ -205,7 +205,7 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { mut self, set_namespace: &str, set_key: &str, - set_ttl: Option, + set_ttl: Option, items: &[(i64, impl ToRedisArgs)], ) -> Self { self.pipe @@ -261,16 +261,16 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { namespace: &str, key: &str, value: T, - expiry: Option, + expiry: Option, ) -> Self { let final_key = self.redis_conn.final_key(namespace, key.into()); if let Some(expiry) = expiry { - // If expiry is weirdly 0 don't send to prevent redis error: - if (expiry) > std::time::Duration::from_millis(0) { + // If expiry is 0 or negative don't send to prevent redis error: + if expiry > chrono::Duration::zero() { // Ignoring so it doesn't take up a space in the tuple response. 
self.pipe - .pset_ex(final_key, value, expiry.as_millis() as u64) + .pset_ex(final_key, value, expiry.num_milliseconds() as u64) .ignore(); } } else { @@ -293,7 +293,7 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { mut self, namespace: &str, pairs: impl IntoIterator, Value)>, - expiry: Option, + expiry: Option, ) -> Self { let final_pairs = pairs .into_iter() @@ -307,10 +307,10 @@ impl<'a, 'b, 'c, ReturnType> RedisBatch<'a, 'b, 'c, ReturnType> { if let Some(expiry) = expiry { // If expiry is weirdly 0 don't send to prevent redis error: - if (expiry) > std::time::Duration::from_millis(0) { + if (expiry) > chrono::Duration::milliseconds(0) { let mut invoker = MSET_WITH_EXPIRY_SCRIPT .invoker() - .arg(expiry.as_millis() as u64); + .arg(expiry.num_milliseconds() as u64); for (key, value) in final_pairs { invoker = invoker.key(key).arg(value); } diff --git a/rust/bitbazaar/redis/conn.rs b/rust/bitbazaar/redis/conn.rs index 7f21ed6d..b3455cd4 100644 --- a/rust/bitbazaar/redis/conn.rs +++ b/rust/bitbazaar/redis/conn.rs @@ -1,9 +1,10 @@ use std::{borrow::Cow, future::Future}; use deadpool_redis::redis::{FromRedisValue, ToRedisArgs}; +use once_cell::sync::Lazy; use super::batch::{RedisBatch, RedisBatchFire, RedisBatchReturningOps}; -use crate::errors::prelude::*; +use crate::{errors::prelude::*, redis::RedisScript}; /// Wrapper around a lazy redis connection. pub struct RedisConn<'a> { @@ -52,6 +53,57 @@ impl<'a> RedisConn<'a> { } } + /// A simple ratelimit/backoff helper. + /// Can be used to protect against repeated attempts in quick succession. + /// Once `start_delaying_after_attempt` is hit, the operation delay will multiplied by the multiplier each time. + /// Only once no call is made for the duration of the current delay (so current delay doubled) will the attempt number reset to zero. + /// + /// Arguments: + /// - `namespace`: A unique identifier for the endpoint, e.g. user-login. 
+ /// - `caller_identifier`: A unique identifier for the caller, e.g. a user id. + /// - `start_delaying_after_attempt`: The number of attempts before the delays start being imposed. + /// - `initial_delay`: The initial delay to impose. + /// - `multiplier`: The multiplier to apply, `(attempt-start_delaying_after_attempt) * multiplier * initial_delay = delay`. + /// + /// Returns: + /// - `None`: Continue with the operation. + /// - `Some`: Retry after the duration. + pub async fn backoff_protector( + &mut self, + namespace: &str, + caller_identifier: &str, + start_delaying_after_attempt: usize, + initial_delay: chrono::Duration, + multiplier: f64, + ) -> Option { + static LUA_BACKOFF_SCRIPT: Lazy = + Lazy::new(|| RedisScript::new(include_str!("lua_scripts/backoff_protector.lua"))); + + let final_key = self.final_key(namespace, caller_identifier.into()); + let result = self + .batch() + .script::( + LUA_BACKOFF_SCRIPT + .invoker() + .key(final_key) + .arg(start_delaying_after_attempt) + .arg(initial_delay.num_milliseconds()) + .arg(multiplier), + ) + .fire() + .await; + + if let Some(result) = result { + if result > 0 { + Some(chrono::Duration::milliseconds(result)) + } else { + None + } + } else { + None + } + } + /// Get a new [`RedisBatch`] for this connection that commands can be piped together with. pub fn batch<'ref_lt>(&'ref_lt mut self) -> RedisBatch<'ref_lt, 'a, '_, ()> { RedisBatch::new(self) @@ -81,7 +133,7 @@ impl<'a> RedisConn<'a> { &mut self, namespace: &str, key: K, - expiry: Option, + expiry: Option, cb: impl FnOnce() -> Fut, ) -> RResult where diff --git a/rust/bitbazaar/redis/dlock.rs b/rust/bitbazaar/redis/dlock.rs index e98405cd..f43d3c30 100644 --- a/rust/bitbazaar/redis/dlock.rs +++ b/rust/bitbazaar/redis/dlock.rs @@ -86,7 +86,7 @@ impl<'a> RedisLock<'a> { /// Creates a new lock, use [`super::Redis::dlock`] instead. 
pub(crate) async fn new( redis: &'a super::Redis, - namespace: &'static str, + namespace: &str, lock_key: &str, ttl: Duration, wait_up_to: Option, @@ -389,7 +389,7 @@ pub async fn redis_dlock_tests(r: &super::Redis) -> RResult<(), AnyErr> { macro_rules! assert_td_in_range { ($td:expr, $range:expr) => { assert!( - $range.contains(&$td), + $td >= $range.start && $td <= $range.end, "Expected '{}' to be in range '{}' - '{}'.", chrono_format_td($td, true), chrono_format_td($range.start, true), diff --git a/rust/bitbazaar/redis/lua_scripts/backoff_protector.lua b/rust/bitbazaar/redis/lua_scripts/backoff_protector.lua new file mode 100644 index 00000000..c42d5701 --- /dev/null +++ b/rust/bitbazaar/redis/lua_scripts/backoff_protector.lua @@ -0,0 +1,43 @@ +local start_delaying_after_attempt = tonumber(ARGV[1]) +local initial_delay_ms = tonumber(ARGV[2]) +local multiplier = tonumber(ARGV[3]) +local key = KEYS[1] + +-- Attempt to get the value of the key +local value = redis.call("GET", key) +local past_attempts +local last_attempt_at_utc_ms +if value then + local tuple = cjson.decode(value) + past_attempts = tuple[1] + last_attempt_at_utc_ms = tuple[2] +else + past_attempts = 0 + last_attempt_at_utc_ms = nil +end + +local active_delay_ms +local next_delay_ms +if past_attempts >= start_delaying_after_attempt then + local applicable_attempts = past_attempts - start_delaying_after_attempt + active_delay_ms = initial_delay_ms * math.pow(multiplier, applicable_attempts) + next_delay_ms = initial_delay_ms * math.pow(multiplier, applicable_attempts + 1) +else + active_delay_ms = 0 + next_delay_ms = initial_delay_ms +end + +local now_utc_ms = redis.call("TIME")[1] * 1000 + redis.call("TIME")[2] / 1000 +if last_attempt_at_utc_ms then + if now_utc_ms - last_attempt_at_utc_ms < active_delay_ms then + return active_delay_ms - (now_utc_ms - last_attempt_at_utc_ms) + else + -- Expiry delay * 2 as this is when the attempt count resets completely. 
+ redis.call("SET", key, cjson.encode({past_attempts + 1, now_utc_ms}), "PX", next_delay_ms * 2) + return 0 + end +else -- First attempt + -- Expiry delay * 2 as this is when the attempt count resets completely. + redis.call("SET", key, cjson.encode({1, now_utc_ms}), "PX", next_delay_ms * 2) + return 0 +end \ No newline at end of file diff --git a/rust/bitbazaar/redis/mod.rs b/rust/bitbazaar/redis/mod.rs index 91bbb692..37a186d5 100644 --- a/rust/bitbazaar/redis/mod.rs +++ b/rust/bitbazaar/redis/mod.rs @@ -25,15 +25,13 @@ pub use wrapper::Redis; #[cfg(test)] mod tests { - use std::{ - sync::{atomic::AtomicU8, Arc}, - time::Duration, - }; + use std::sync::{atomic::AtomicU8, Arc}; use rstest::*; use super::*; use crate::{ + chrono::chrono_format_td, errors::prelude::*, log::GlobalLog, redis::{dlock::redis_dlock_tests, temp_list::redis_temp_list_tests}, @@ -304,7 +302,7 @@ mod tests { .cached_fn( "my_fn_ex_group", "foo", - Some(Duration::from_millis(15)), + Some(chrono::Duration::milliseconds(15)), || async { // Add one to the call count, should only be called once: called.fetch_add(1, std::sync::atomic::Ordering::SeqCst); @@ -327,17 +325,17 @@ mod tests { // <--- set/mset with expiry: work_conn .batch() - .set("e1", "foo", "foo", Some(Duration::from_millis(15))) - .set("e1", "bar", "bar", Some(Duration::from_millis(30))) + .set("e1", "foo", "foo", Some(chrono::Duration::milliseconds(15))) + .set("e1", "bar", "bar", Some(chrono::Duration::milliseconds(30))) .mset( "e2", [("foo", "foo"), ("bar", "bar")], - Some(Duration::from_millis(15)), + Some(chrono::Duration::milliseconds(15)), ) .mset( "e2", [("baz", "baz"), ("qux", "qux")], - Some(Duration::from_millis(30)), + Some(chrono::Duration::milliseconds(30)), ) .fire() .await; @@ -383,7 +381,13 @@ mod tests { .batch() .zadd("z1", "myset", None, 3, "foo") // By setting an expiry time here, the set itself will now expire after 30ms: - .zadd("z1", "myset", Some(Duration::from_millis(30)), 1, "bar") + .zadd( + "z1", + 
"myset", + Some(chrono::Duration::milliseconds(30)), + 1, + "bar" + ) .zadd_multi( "z1", "myset", @@ -464,4 +468,74 @@ mod tests { Ok(()) } + + #[rstest] + #[tokio::test] + async fn test_redis_backoff(#[allow(unused_variables)] logging: ()) -> RResult<(), AnyErr> { + // Redis can't be run on windows, skip if so: + if cfg!(windows) { + return Ok(()); + } + + // TODO using now in here and dlock, should be some test utils we can use cross crate. + macro_rules! assert_td_in_range { + ($td:expr, $range:expr) => { + assert!( + $td >= $range.start && $td <= $range.end, + "Expected '{}' to be in range '{}' - '{}'.", + chrono_format_td($td, true), + chrono_format_td($range.start, true), + chrono_format_td($range.end, true), + ); + }; + } + + let rs = RedisStandalone::new().await?; + + let r = rs.instance()?; + let mut rconn = r.conn(); + + macro_rules! call { + () => { + rconn + .backoff_protector("n1", "caller1", 2, chrono::Duration::milliseconds(100), 1.5) + .await + }; + } + assert_eq!(call!(), None); + assert_eq!(call!(), None); + assert_td_in_range!( + call!().unwrap(), + chrono::Duration::milliseconds(90)..chrono::Duration::milliseconds(100) + ); + // Wait to allow call again, but will x1.5 wait for next time: + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + assert_eq!(call!(), None); + assert_td_in_range!( + call!().unwrap(), + chrono::Duration::milliseconds(130)..chrono::Duration::milliseconds(150) + ); + // Just check double call too: + assert_td_in_range!( + call!().unwrap(), + chrono::Duration::milliseconds(125)..chrono::Duration::milliseconds(150) + ); + tokio::time::sleep(std::time::Duration::from_millis(150)).await; + assert_eq!(call!(), None); + // Should now 1.5x again: + assert_td_in_range!( + call!().unwrap(), + chrono::Duration::milliseconds(190)..chrono::Duration::milliseconds(225) + ); + // By waiting over 2x the current delay, should all reset: + tokio::time::sleep(std::time::Duration::from_millis(450)).await; + assert_eq!(call!(), 
None); + assert_eq!(call!(), None); + assert_td_in_range!( + call!().unwrap(), + chrono::Duration::milliseconds(90)..chrono::Duration::milliseconds(100) + ); + + Ok(()) + } } diff --git a/rust/bitbazaar/redis/temp_list.rs b/rust/bitbazaar/redis/temp_list.rs index c8a6cd43..3d1d2476 100644 --- a/rust/bitbazaar/redis/temp_list.rs +++ b/rust/bitbazaar/redis/temp_list.rs @@ -1,7 +1,6 @@ use std::{ borrow::Cow, sync::{atomic::AtomicI64, Arc}, - time::Duration, }; use futures::{future::BoxFuture, FutureExt}; @@ -262,31 +261,31 @@ pub struct RedisTempList { pub key: String, /// If the list hasn't been read or written to in this time, it will be expired. - pub list_inactive_ttl: Duration, + pub list_inactive_ttl: std::time::Duration, /// If an item hasn't been read or written to in this time, it will be expired. - pub item_inactive_ttl: Duration, + pub item_inactive_ttl: std::time::Duration, /// Used to prevent overlap between push() and extend() calls using the same ts by accident. #[serde(skip)] last_extension_ts_millis: Arc, } -/// A managed list entry in redis that will: -/// - Auto expire on inactivity -/// - Return items newest to oldest -/// - Items are set with their own ttl, so the members themselves expire separate from the full list -/// - Optionally prevents duplicate values. When enabled, if a duplicate is added, the item will be bumped to the front & old discarded. impl RedisTempList { - pub(crate) fn new( + /// Create/connect to a managed list entry in redis that will: + /// - Auto expire on inactivity + /// - Return items newest to oldest + /// - Items are set with their own ttl, so the members themselves expire separate from the full list + /// - Optionally prevents duplicate values. When enabled, if a duplicate is added, the item will be bumped to the front & old discarded. 
+ pub fn new( namespace: &'static str, - key: String, - list_inactive_ttl: Duration, - item_inactive_ttl: Duration, + key: impl Into, + list_inactive_ttl: std::time::Duration, + item_inactive_ttl: std::time::Duration, ) -> Arc { Arc::new(Self { namespace: Cow::Borrowed(namespace), - key, + key: key.into(), list_inactive_ttl, item_inactive_ttl, last_extension_ts_millis: Arc::new(AtomicI64::new(0)), @@ -341,7 +340,7 @@ impl RedisTempList { .zadd_multi( &self.namespace, &self.key, - Some(self.list_inactive_ttl), // This will auto reset the expire time of the list as a whole + Some(self.list_inactive_ttl_chrono()), // This will auto reset the expire time of the list as a whole items_with_uids .iter() .map(|(uid, _)| (score, uid)) @@ -354,7 +353,7 @@ impl RedisTempList { items_with_uids .into_iter() .map(|(uid, item)| (uid, RedisJsonBorrowed(item))), - Some(self.item_inactive_ttl), + Some(self.item_inactive_ttl_chrono()), ) // Cleanup old members that have now expired: // (set member expiry is a logical process, not currently part of redis but could be soon) @@ -494,7 +493,7 @@ impl RedisTempList { &item_info.iter().map(|(uid, _)| uid).collect::>(), ) // Unlike our zadd during setting, need to manually refresh the expire time of the list here: - .expire(&self.namespace, &self.key, self.list_inactive_ttl) + .expire(&self.namespace, &self.key, self.list_inactive_ttl_chrono()) .fire() .await; @@ -570,7 +569,7 @@ impl RedisTempList { ) .get::>(&self.namespace, uid) // Unlike our zadd during setting, need to manually refresh the expire time of the list here: - .expire(&self.namespace, &self.key, self.list_inactive_ttl) + .expire(&self.namespace, &self.key, self.list_inactive_ttl_chrono()) .fire() .await; @@ -621,7 +620,7 @@ impl RedisTempList { .as_slice(), ) // Unlike our zadd during setting, need to manually refresh the expire time of the list here: - .expire(&self.namespace, &self.key, self.list_inactive_ttl) + .expire(&self.namespace, &self.key, 
self.list_inactive_ttl_chrono()) .fire() .await; } @@ -646,7 +645,7 @@ impl RedisTempList { .zadd( &self.namespace, &self.key, - Some(self.list_inactive_ttl), + Some(self.list_inactive_ttl_chrono()), new_score, uid, ) @@ -655,7 +654,7 @@ impl RedisTempList { &self.namespace, uid, RedisJsonBorrowed(item), - Some(self.item_inactive_ttl), + Some(self.item_inactive_ttl_chrono()), ) // Cleanup old members that have now expired: // (member expiry is a logical process, not currently part of redis but could be soon) @@ -679,6 +678,14 @@ impl RedisTempList { .fire() .await; } + + fn list_inactive_ttl_chrono(&self) -> chrono::Duration { + chrono::Duration::milliseconds(self.list_inactive_ttl.as_millis() as i64) + } + + fn item_inactive_ttl_chrono(&self) -> chrono::Duration { + chrono::Duration::milliseconds(self.item_inactive_ttl.as_millis() as i64) + } } #[cfg(test)] @@ -699,20 +706,20 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { static NS: &str = "templist_tests"; - let li1 = r.templist( + let li1 = RedisTempList::new( NS, "t1", - Duration::from_millis(100), - Duration::from_millis(60), + std::time::Duration::from_millis(100), + std::time::Duration::from_millis(60), ); li1.extend( &mut conn, vec!["i1".to_string(), "i2".to_string(), "i3".to_string()], ) .await; - tokio::time::sleep(Duration::from_millis(20)).await; + tokio::time::sleep(std::time::Duration::from_millis(20)).await; li1.push(&mut conn, "i4".to_string()).await; - tokio::time::sleep(Duration::from_millis(20)).await; + tokio::time::sleep(std::time::Duration::from_millis(20)).await; li1.push(&mut conn, "i5".to_string()).await; // Keys are ordered from recent to old: assert_eq!( @@ -724,20 +731,20 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { RedisTempListItem::vec_items(li1.read_multi::(&mut conn, Some(3)).await), vec!["i5", "i4", "i3"] ); - tokio::time::sleep(Duration::from_millis(20)).await; + 
tokio::time::sleep(std::time::Duration::from_millis(20)).await; // First batch should have expired now (3 20ms waits) // i4 should be expired, it had a ttl of 20ms: assert_eq!( RedisTempListItem::vec_items(li1.read_multi::(&mut conn, None).await), vec!["i5", "i4"] ); - tokio::time::sleep(Duration::from_millis(20)).await; + tokio::time::sleep(std::time::Duration::from_millis(20)).await; // i4 should be gone now: assert_eq!( RedisTempListItem::vec_items(li1.read_multi::(&mut conn, None).await), vec!["i5"] ); - tokio::time::sleep(Duration::from_millis(20)).await; + tokio::time::sleep(std::time::Duration::from_millis(20)).await; // i5 should be gone now: assert_eq!( RedisTempListItem::vec_items(li1.read_multi::(&mut conn, None).await), @@ -745,34 +752,34 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { ); // Let's create a strange list, one with a short list lifetime but effectively infinite item lifetime: - let li2 = r.templist( + let li2 = RedisTempList::new( NS, "t2", - Duration::from_millis(50), - Duration::from_millis(1000), + std::time::Duration::from_millis(50), + std::time::Duration::from_millis(1000), ); - tokio::time::sleep(Duration::from_millis(20)).await; + tokio::time::sleep(std::time::Duration::from_millis(20)).await; let i1 = li2.push(&mut conn, "i1".to_string()).await; // Li should still be there after another 40ms, as the last push should have updated the list's ttl: - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; assert_eq!( RedisTempListItem::vec_items(li2.read_multi::(&mut conn, None).await), vec!["i1"] ); - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; li2.extend( &mut conn, vec!["i2".to_string(), "i3".to_string(), "i4".to_string()], ) .await; // Li should still be there after another 40ms, as the last push should have updated the list's ttl: - 
tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; assert_eq!( RedisTempListItem::vec_items(li2.read_multi::(&mut conn, None).await), vec!["i4", "i3", "i2", "i1"] ); // Above read_multi should have also updated the list's ttl, so should be there after another 40ms: - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; assert_eq!( li2.read::(&mut conn, i1.uid().unwrap()) .await @@ -780,7 +787,7 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { Some("i1".to_string()) ); // Above direct read should have also updated the list's ttl, so should be there after another 40ms: - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; assert_eq!( li2.read::(&mut conn, i1.uid().unwrap()) .await @@ -788,19 +795,19 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { Some("i1".to_string()) ); let i5 = li2.push(&mut conn, "i5".to_string()).await; - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; li2.delete(&mut conn, i1.uid().unwrap()).await; // Above delete should have updated the list's ttl, so should be there after another 40ms: - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; assert_eq!( RedisTempListItem::vec_items(li2.read_multi::(&mut conn, None).await), vec!["i5", "i4", "i3", "i2"] ); - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; li2.update(&mut conn, i5.uid().unwrap(), &"i5-updated") .await; // Above update should have updated the list's ttl, so should be there after another 40ms: - tokio::time::sleep(Duration::from_millis(40)).await; + tokio::time::sleep(std::time::Duration::from_millis(40)).await; 
assert_eq!( li2.read::(&mut conn, i5.uid().unwrap()) .await @@ -808,18 +815,18 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { Some("i5-updated".to_string()) ); // When no reads or writes, the list should expire after 50ms: - tokio::time::sleep(Duration::from_millis(60)).await; + tokio::time::sleep(std::time::Duration::from_millis(60)).await; assert_eq!( RedisTempListItem::vec_items(li2.read_multi::(&mut conn, None).await), Vec::::new() ); // Make sure manual clear() works on a new list too: - let li3 = r.templist( + let li3 = RedisTempList::new( NS, "t3", - Duration::from_millis(100), - Duration::from_millis(30), + std::time::Duration::from_millis(100), + std::time::Duration::from_millis(30), ); li3.extend( &mut conn, @@ -837,11 +844,11 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { ); // Try with arb value, e.g. vec of (i32, String): - let li4 = r.templist( + let li4 = RedisTempList::new( NS, "t4", - Duration::from_millis(100), - Duration::from_millis(30), + std::time::Duration::from_millis(100), + std::time::Duration::from_millis(30), ); li4.push(&mut conn, (1, "a".to_string())).await; li4.push(&mut conn, (2, "b".to_string())).await; @@ -861,11 +868,11 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { ); // Try with json value: - let li5 = r.templist( + let li5 = RedisTempList::new( NS, "t5", - Duration::from_millis(100), - Duration::from_millis(30), + std::time::Duration::from_millis(100), + std::time::Duration::from_millis(30), ); li5.extend( &mut conn, @@ -896,11 +903,11 @@ pub async fn redis_temp_list_tests(r: &super::Redis) -> Result<(), AnyErr> { ); // Make sure duplicate values don't break the list and are still kept: - let li6 = r.templist( + let li6 = RedisTempList::new( NS, "t6", - Duration::from_millis(100), - Duration::from_millis(30), + std::time::Duration::from_millis(100), + std::time::Duration::from_millis(30), ); li6.extend( &mut conn, diff --git 
a/rust/bitbazaar/redis/wrapper.rs b/rust/bitbazaar/redis/wrapper.rs index 718628c6..bef6f812 100644 --- a/rust/bitbazaar/redis/wrapper.rs +++ b/rust/bitbazaar/redis/wrapper.rs @@ -1,9 +1,9 @@ -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use deadpool_redis::{Config, Runtime}; use futures::Future; -use super::{RedisConn, RedisLock, RedisLockErr, RedisTempList}; +use super::{RedisConn, RedisLock, RedisLockErr}; use crate::errors::prelude::*; /// A wrapper around redis to make it more concise to use and not need redis in the downstream Cargo.toml. @@ -68,7 +68,7 @@ impl Redis { /// - `wait_up_to`: if the lock is busy elsewhere, wait this long trying to get it, before giving up and returning [`RedisLockErr::Unavailable`]. pub async fn dlock_for_fut>>( &self, - namespace: &'static str, + namespace: &str, lock_key: &str, wait_up_to: Option, fut: Fut, @@ -89,24 +89,6 @@ impl Redis { result } - /// Connect up to a magic redis list that: - /// - Has an expiry on the list itself, resetting on each read or write. (each change lives again for `expire_after` time) - /// - Each item in the list has it's own expiry, so the list is always clean of old items. - /// - Each item has a generated unique key, this key can be used to update or delete specific items directly. - /// - Returned items are returned newest/last-updated to oldest - /// This makes this distributed data structure perfect for stuff like: - /// - recent/temporary logs/events of any sort. - /// - pending actions, that can be updated in-place by the creator, but read as part of a list by a viewer etc. - pub fn templist( - &self, - namespace: &'static str, - key: impl Into, - list_inactive_ttl: Duration, - item_inactive_ttl: Duration, - ) -> Arc { - RedisTempList::new(namespace, key.into(), list_inactive_ttl, item_inactive_ttl) - } - /// Escape hatch, access the inner deadpool_redis pool. pub fn get_inner_pool(&self) -> &deadpool_redis::Pool { &self.pool