From d9d6ea9107353261b17ce84a47e66ee2b623e373 Mon Sep 17 00:00:00 2001 From: kozabrada123 <59031733+kozabrada123@users.noreply.github.com> Date: Sun, 24 Nov 2024 07:53:16 +0100 Subject: [PATCH] Tracking for Release v0.18.0 (#576) Release tracker for v0.18.0 of chorus, set to release on November 24th, 2024. ## Public API changes - #570: Various entity public api changes - 644d3beb90bc1aedd2dd222cab199457d4dffddd, 85e922bc5039bd2af091b4a95d4b7f7d82ade7e2: Add type `OneOrMoreSnowflakes`, allow `GatewayRequestGuildMembers` to request multiple guild and user ids - f65b9c1: Differentiate `PresenceUpdate` and `GatewayPresenceUpdate` - 0e5fd86: Temporarily fix `PresenceUpdate` for Spacebar Client by making `user` optional - 61ac7d1465f2509bed01ddefc31228c31d893704: Updated `LazyRequest` (op 14) to use the `Snowflake` type for ids instead of just `String` ## Additions - #564: MFA implementation, by @xystrive and @kozabrada123 - 4ed68ce7a506d0354730b4b4f0b7d4267c5e2a50: Added [Last Messages request](https://docs.discord.sex/topics/gateway-events#request-last-messages) and [response](https://docs.discord.sex/topics/gateway-events#last-messages) - b23fb68: Add `ReadState` to `GatewayReady` - #571: Gateway Opcode enum - #573: Gateway Disconnect Opcode enums ## Bugfixes - #565: Fix sqlx En-/Decoding of `PremiumType` - 7460d3f: Fix `GatewayIdentifyConnectionProps` for Spacebar Client by deriving default on all fields, since the client does not send it - 3d9460f: Derive Default for `MessageReferenceType`, assume default reference_type if none is provided - 4baecf978415b680d5ef13061c24e5e363d19988: Fixed a deserialization error related to `presences` in `GuildMembersChunk` being an array, not a single value - 1b201020e9dec23e2dfcd467e49c58b151d88e0e: Fixed a deserialization error with deserializing `activities` in `PresenceUpdate` as an empty array when they are sent as `null` - 7feb57186a8bb3f873cc4d5935280de0c8687511: Fixed a deserialization error on discord.com related to experiments (they are not implemented yet, see #578) - fb94afa390e41caeec3883d5feaf8aa6dc9b8985: Fixed a deserialization error on discord.com related to `last_viewed` in `ReadState` being a version / counter, not a `DateTime` ## Internal changes - 40754c5: bump sqlx-pg-uint to v0.8.0 - #575: Refactor of gateway close code handling - 4ed68ce7a506d0354730b4b4f0b7d4267c5e2a50: Refactored the gateway to fully use the `Opcode` enum instead of constants - #579 --------- Co-authored-by: bitfl0wer Co-authored-by: Flori <39242991+bitfl0wer@users.noreply.github.com> Co-authored-by: xystrive --- .github/workflows/build_and_test.yml | 16 +- Cargo.lock | 102 +++- Cargo.toml | 16 +- README.md | 8 +- src/api/auth/login.rs | 86 ++- src/api/auth/mod.rs | 10 +- src/api/auth/register.rs | 12 +- src/api/channels/channels.rs | 6 - src/api/channels/messages.rs | 11 - src/api/channels/permissions.rs | 1 - src/api/channels/reactions.rs | 6 - src/api/guilds/guilds.rs | 30 +- src/api/guilds/roles.rs | 1 - src/api/users/mfa.rs | 372 +++++++++++++ src/api/users/mod.rs | 2 + src/api/users/users.rs | 26 +- src/errors.rs | 61 ++- src/gateway/backends/tungstenite.rs | 33 +- src/gateway/backends/wasm.rs | 11 +- src/gateway/events.rs | 9 + src/gateway/gateway.rs | 217 +++++--- src/gateway/handle.rs | 50 +- src/gateway/heartbeat.rs | 12 +- src/gateway/message.rs | 55 +- src/gateway/mod.rs | 49 +- src/instance.rs | 74 ++- src/ratelimiter.rs | 14 +- src/types/entities/channel.rs | 134 ++++- src/types/entities/message.rs | 5 +- src/types/entities/mfa_token.rs | 61 +++ src/types/entities/mod.rs | 10 + src/types/entities/relationship.rs | 10 +- src/types/entities/user.rs | 11 +- src/types/entities/user_settings.rs | 4 +- src/types/events/call.rs | 62 ++- src/types/events/guild.rs | 2 +- src/types/events/identify.rs | 13 +- src/types/events/lazy_request.rs | 3 +- src/types/events/message.rs | 25 + src/types/events/mfa.rs | 31 ++ src/types/events/mod.rs | 2 + src/types/events/presence.rs | 30 +- src/types/events/ready.rs | 60 ++- src/types/events/request_members.rs | 40 +- src/types/schema/auth.rs | 22 +- src/types/schema/channel.rs | 1 + src/types/schema/mfa.rs | 426 +++++++++++++++ src/types/schema/mod.rs | 2 + src/types/utils/mod.rs | 7 +- src/types/utils/opcode.rs | 308 +++++++++++ src/types/utils/snowflake.rs | 92 ++++ src/voice/gateway/backends/mod.rs | 1 - src/voice/gateway/backends/tungstenite.rs | 28 +- src/voice/gateway/backends/wasm.rs | 3 +- src/voice/gateway/gateway.rs | 140 ++++- src/voice/gateway/message.rs | 48 +- tests/auth.rs | 606 +++++++++++++++++++++- tests/channels.rs | 12 +- tests/common/mod.rs | 256 ++++++++- tests/gateway.rs | 177 ++++--- tests/relationships.rs | 41 +- 61 files changed, 3429 insertions(+), 534 deletions(-) create mode 100644 src/api/users/mfa.rs create mode 100644 src/types/entities/mfa_token.rs create mode 100644 src/types/events/mfa.rs create mode 100644 src/types/schema/mfa.rs create mode 100644 src/types/utils/opcode.rs diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index b9cebea5..b7a125a5 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -93,6 +93,12 @@ jobs: npm run setup npm run start & working-directory: ./server + # Note: see + # https://github.com/polyphony-chat/chorus/pull/579, + # https://github.com/rustwasm/wasm-bindgen/issues/4274 + # https://github.com/rustwasm/wasm-bindgen/issues/4274#issuecomment-2493497388 + - name: Install rust toolchain 1.81.0 + uses: dtolnay/rust-toolchain@1.81.0 - uses: Swatinem/rust-cache@v2 with: cache-all-crates: "true" @@ -102,7 +108,7 @@ jobs: run: | rustup target add wasm32-unknown-unknown curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash - cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.93" --force + cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.95" --force GECKODRIVER=$(which geckodriver) cargo test --target wasm32-unknown-unknown --no-default-features --features="client, rt, voice_gateway" wasm-chrome: runs-on: ubuntu-latest @@ -123,6 +129,12 @@ jobs: npm run setup npm run start & working-directory: ./server + # Note: see + # https://github.com/polyphony-chat/chorus/pull/579, + # https://github.com/rustwasm/wasm-bindgen/issues/4274 + # https://github.com/rustwasm/wasm-bindgen/issues/4274#issuecomment-2493497388 + - name: Install rust toolchain 1.81.0 + uses: dtolnay/rust-toolchain@1.81.0 - uses: Swatinem/rust-cache@v2 with: cache-all-crates: "true" @@ -132,5 +144,5 @@ jobs: run: | rustup target add wasm32-unknown-unknown curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash - cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.93" --force + cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.95" --force CHROMEDRIVER=$(which chromedriver) cargo test --target wasm32-unknown-unknown --no-default-features --features="client, rt, voice_gateway" diff --git a/Cargo.lock b/Cargo.lock index 30d7b2a3..826762a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -230,7 +241,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chorus" -version = "0.17.0" +version = "0.18.0" dependencies = [ "async-trait", "base64 0.21.7", @@ -245,9 +256,11 @@ dependencies = [ "getrandom", "hostname", "http 0.2.12", + "httptest", "jsonwebtoken", "lazy_static", "log", + "pharos", "poem", "pubserve", "rand", @@ -383,6 +396,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.11" @@ -959,6 +981,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "httptest" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae0fc8d140f1f0f3e7f821c8eff55cd13966db7a3370b2d9a7b08e9ec8ee8786" +dependencies = [ + "bstr", + "bytes", + "crossbeam-channel", + "form_urlencoded", + "futures", + "http 1.1.0", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", +] + [[package]] name = "hyper" version = "0.14.30" @@ -1125,9 +1171,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2340,9 +2386,9 @@ dependencies = [ [[package]] name = "sqlx-pg-uint" -version = "0.7.2" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e5ec2fd2d274ebf9ad6b44b3986f9bcdbb554bb162c4b1ac4af05a439c66f2" +checksum = "0da63696a67d81d916818ff62b9581f72cde00fa95b224944bab607a27594751" dependencies = [ "bigdecimal", "serde", @@ -2353,9 +2399,9 @@ dependencies = [ [[package]] name = "sqlx-pg-uint-macros" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e527060e9f43479e5b386e4237ab320a36fce39394f6ed73c8870f4637f2e5f" +checksum = "087085ea4ec76bb36e7737c7f8c73ea63ea31243301ed00860ba0758f277dcf2" dependencies = [ "quote", "syn 2.0.79", @@ -2885,9 +2931,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -2896,9 +2942,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -2911,9 +2957,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -2923,9 +2969,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2933,9 +2979,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -2946,15 +2992,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-bindgen-test" -version = "0.3.43" +version = "0.3.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68497a05fb21143a08a7d24fc81763384a3072ee43c44e86aad1744d6adef9d9" +checksum = "d381749acb0943d357dcbd8f0b100640679883fcdeeef04def49daf8d33a5426" dependencies = [ "console_error_panic_hook", "js-sys", @@ -2967,9 +3013,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.43" +version = "0.3.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8220be1fa9e4c889b30fd207d4906657e7e90b12e0e6b0c8b8d8709f5de021" +checksum = "c97b2ef2c8d627381e51c071c2ab328eac606d3f69dd82bcbca20a9e389d95f0" dependencies = [ "proc-macro2", "quote", @@ -2978,9 +3024,9 @@ dependencies = [ [[package]] name = "wasmtimer" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f656cd8858a5164932d8a90f936700860976ec21eb00e0fe2aa8cab13f6b4cf" +checksum = "bb4f099acbc1043cc752b91615b24b02d7f6fcd975bd781fed9f50b3c3e15bf7" dependencies = [ "futures", "js-sys", @@ -2992,9 +3038,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -3053,7 +3099,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 675a4662..eced124e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "chorus" description = "A library for interacting with multiple Spacebar-compatible Instances at once." -version = "0.17.0" +version = "0.18.0" license = "MPL-2.0" edition = "2021" repository = "https://github.com/polyphony-chat/chorus" @@ -67,7 +67,7 @@ rand = "0.8.5" flate2 = { version = "1.0.33", optional = true } webpki-roots = "0.26.3" pubserve = { version = "1.1.0", features = ["async", "send"] } -sqlx-pg-uint = { version = "0.7.2", features = ["serde"], optional = true } +sqlx-pg-uint = { version = "0.8.0", features = ["serde"], optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] rustls = "0.21.12" @@ -80,14 +80,22 @@ getrandom = { version = "0.2.15" } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.15", features = ["js"] } ws_stream_wasm = "0.7.4" +pharos = "*" # This is a dependency of ws_stream_wasm, we are including it to interface with that library wasm-bindgen-futures = "0.4.43" -wasmtimer = "0.2.0" +wasmtimer = "0.4.0" [dev-dependencies] lazy_static = "1.5.0" wasm-bindgen-test = "0.3.43" -wasm-bindgen = "0.2.93" +wasm-bindgen = "0.2.95" simple_logger = { version = "5.0.0", default-features = false } +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +httptest = "0.16.1" + [lints.rust] unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] } + +# See https://github.com/whizsid/wasmtimer-rs/issues/18#issuecomment-2420144039 +[package.metadata.wasm-pack.profile.dev.wasm-bindgen] +debug-js-glue = false diff --git a/README.md b/README.md index cdac65d4..af5eb2ff 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ To get started with Chorus, import it into your project by adding the following ```toml [dependencies] -chorus = "0.17.0" +chorus = "0.18.0" ``` ### Establishing a Connection @@ -143,7 +143,11 @@ to run the tests for wasm. ## Versioning -This crate uses Semantic Versioning 2.0.0 as its versioning scheme. You can read the specification [here](https://semver.org/spec/v2.0.0.html). +Like other cargo crates, this crate uses Semantic Versioning 2.0.0 as its versioning scheme. +You can read the specification [here](https://semver.org/spec/v2.0.0.html). + +Code gated behind the `backend` feature is not considered part of the public API and can change without +affecting semver compatibility. The `backend` feature is explicitly meant for use in [`symfonia`](https://github.com/polyphony-chat/symfonia) ## Contributing diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index ab78a995..749ed239 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -11,7 +11,10 @@ use crate::errors::ChorusResult; use crate::gateway::Gateway; use crate::instance::{ChorusUser, Instance}; use crate::ratelimiter::ChorusRequest; -use crate::types::{GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema, User}; +use crate::types::{ + MfaAuthenticationType, GatewayIdentifyPayload, LimitType, LoginResult, LoginSchema, + SendMfaSmsResponse, SendMfaSmsSchema, User, VerifyMFALoginResponse, VerifyMFALoginSchema, +}; impl Instance { /// Logs into an existing account on the spacebar server. @@ -27,6 +30,7 @@ impl Instance { .header("Content-Type", "application/json"), limit_type: LimitType::AuthLogin, }; + // We do not have a user yet, and the UserRateLimits will not be affected by a login // request (since login is an instance wide limit), which is why we are just cloning the // instances' limits to pass them on as user_rate_limits later. @@ -35,16 +39,82 @@ impl Instance { let login_result = chorus_request .deserialize_response::(&mut user) .await?; - user.set_token(&login_result.token); - user.settings = login_result.settings; - let object = User::get_current(&mut user).await?; - *user.object.write().unwrap() = object; + user.update_with_login_data(login_result.token, Some(login_result.settings)) + .await?; + + Ok(user) + } + + /// Verifies a multi-factor authentication login + /// + /// # Reference + /// See + pub async fn verify_mfa_login( + &mut self, + authenticator: MfaAuthenticationType, + schema: VerifyMFALoginSchema, + ) -> ChorusResult { + let endpoint_url = self.urls.api.clone() + "/auth/mfa/" + &authenticator.to_string(); + + let chorus_request = ChorusRequest { + request: Client::new() + .post(endpoint_url) + .header("Content-Type", "application/json") + .json(&schema), + limit_type: LimitType::AuthLogin, + }; + + let mut user = ChorusUser::shell(Arc::new(RwLock::new(self.clone())), "None").await; - let mut identify = GatewayIdentifyPayload::common(); - identify.token = user.token(); - user.gateway.send_identify(identify).await; + match chorus_request + .deserialize_response::(&mut user) + .await? + { + VerifyMFALoginResponse::Success { + token, + user_settings, + } => { + user.update_with_login_data(token, Some(user_settings)) + .await?; + } + VerifyMFALoginResponse::UserSuspended { + suspended_user_token, + } => { + return Err(crate::errors::ChorusError::SuspendUser { + token: suspended_user_token, + }) + } + } Ok(user) } + + /// Sends a multi-factor authentication code to the user's phone number + /// + /// # Reference + /// See + // FIXME: This uses ChorusUser::shell, when it *really* shoudln't, but + // there is no other way to send a ratelimited request + pub async fn send_mfa_sms( + &mut self, + schema: SendMfaSmsSchema, + ) -> ChorusResult { + let endpoint_url = self.urls.api.clone() + "/auth/mfa/sms/send"; + let chorus_request = ChorusRequest { + request: Client::new() + .post(endpoint_url) + .header("Content-Type", "application/json") + .json(&schema), + limit_type: LimitType::Ip, + }; + + let mut chorus_user = ChorusUser::shell(Arc::new(RwLock::new(self.clone())), "None").await; + + let send_mfa_sms_response = chorus_request + .deserialize_response::(&mut chorus_user) + .await?; + + Ok(send_mfa_sms_response) + } } diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 96491351..91d6d73e 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -25,15 +25,7 @@ impl Instance { pub async fn login_with_token(&mut self, token: &str) -> ChorusResult { let mut user = ChorusUser::shell(Arc::new(RwLock::new(self.clone())), token).await; - let object = User::get_current(&mut user).await?; - let settings = User::get_settings(&mut user).await?; - - *user.object.write().unwrap() = object; - *user.settings.write().unwrap() = settings; - - let mut identify = GatewayIdentifyPayload::common(); - identify.token = user.token(); - user.gateway.send_identify(identify).await; + user.update_with_login_data(token.to_string(), None).await?; Ok(user) } diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index d978e0ff..db230f22 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -43,18 +43,8 @@ impl Instance { .deserialize_response::(&mut user) .await? .token; - - user.set_token(&token); - let object = User::get_current(&mut user).await?; - let settings = User::get_settings(&mut user).await?; - - *user.object.write().unwrap() = object; - *user.settings.write().unwrap() = settings; - - let mut identify = GatewayIdentifyPayload::common(); - identify.token = user.token(); - user.gateway.send_identify(identify).await; + user.update_with_login_data(token, None).await?; Ok(user) } diff --git a/src/api/channels/channels.rs b/src/api/channels/channels.rs index 6c415763..7cb9ee3f 100644 --- a/src/api/channels/channels.rs +++ b/src/api/channels/channels.rs @@ -30,7 +30,6 @@ impl Channel { ), None, None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -61,7 +60,6 @@ impl Channel { &url, None, audit_log_reason.as_deref(), - None, Some(user), LimitType::Channel(self.id), ); @@ -101,7 +99,6 @@ impl Channel { &url, Some(to_string(&modify_data).unwrap()), audit_log_reason.as_deref(), - None, Some(user), LimitType::Channel(channel_id), ); @@ -134,7 +131,6 @@ impl Channel { &url, None, None, - None, Some(user), Default::default(), ); @@ -196,7 +192,6 @@ impl Channel { &url, None, None, - None, Some(user), LimitType::Channel(self.id), ); @@ -225,7 +220,6 @@ impl Channel { &url, Some(to_string(&schema).unwrap()), None, - None, Some(user), LimitType::Guild(guild_id), ); diff --git a/src/api/channels/messages.rs b/src/api/channels/messages.rs index a682c21c..53fe64e6 100644 --- a/src/api/channels/messages.rs +++ b/src/api/channels/messages.rs @@ -152,7 +152,6 @@ impl Message { .as_str(), None, None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -183,7 +182,6 @@ impl Message { .as_str(), None, audit_log_reason, - None, Some(user), LimitType::Channel(channel_id), ); @@ -210,7 +208,6 @@ impl Message { .as_str(), None, audit_log_reason, - None, Some(user), LimitType::Channel(channel_id), ); @@ -259,7 +256,6 @@ impl Message { .as_str(), Some(to_string(&schema).unwrap()), None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -293,7 +289,6 @@ impl Message { .as_str(), Some(to_string(&schema).unwrap()), None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -322,7 +317,6 @@ impl Message { .as_str(), None, None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -349,7 +343,6 @@ impl Message { &url, None, None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -383,7 +376,6 @@ impl Message { &url, Some(to_string(&schema).unwrap()), None, - None, Some(user), LimitType::Channel(channel_id), ); @@ -410,7 +402,6 @@ impl Message { &url, None, audit_log_reason.as_deref(), - None, Some(user), LimitType::Channel(channel_id), ); @@ -448,7 +439,6 @@ impl Message { .as_str(), Some(to_string(&messages).unwrap()), audit_log_reason.as_deref(), - None, Some(user), LimitType::Channel(channel_id), ); @@ -473,7 +463,6 @@ impl Message { .as_str(), None, None, - None, Some(user), LimitType::Channel(channel_id), ); diff --git a/src/api/channels/permissions.rs b/src/api/channels/permissions.rs index 03465b80..19492cd6 100644 --- a/src/api/channels/permissions.rs +++ b/src/api/channels/permissions.rs @@ -83,7 +83,6 @@ impl types::Channel { &url, None, None, - None, Some(user), LimitType::Channel(channel_id), ); diff --git a/src/api/channels/reactions.rs b/src/api/channels/reactions.rs index f2de33d7..e9bec801 100644 --- a/src/api/channels/reactions.rs +++ b/src/api/channels/reactions.rs @@ -36,7 +36,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); @@ -65,7 +64,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); @@ -96,7 +94,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); @@ -130,7 +127,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); @@ -159,7 +155,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); @@ -196,7 +191,6 @@ impl ReactionMeta { &url, None, None, - None, Some(user), LimitType::Channel(self.channel_id), ); diff --git a/src/api/guilds/guilds.rs b/src/api/guilds/guilds.rs index e2ff9bad..6469834d 100644 --- a/src/api/guilds/guilds.rs +++ b/src/api/guilds/guilds.rs @@ -63,6 +63,9 @@ impl Guild { /// /// Returns the updated guild. /// + /// # Notes + /// This route requires MFA. + /// /// # Reference /// pub async fn modify( @@ -81,7 +84,9 @@ impl Guild { .header("Content-Type", "application/json") .body(to_string(&schema).unwrap()), limit_type: LimitType::Guild(guild_id), - }; + } + .with_maybe_mfa(&user.mfa_token); + let response = chorus_request.deserialize_response::(user).await?; Ok(response) } @@ -90,12 +95,14 @@ impl Guild { /// /// User must be the owner. /// + /// # Notes + /// This route requires MFA. + /// /// # Example /// /// ```rs - /// let mut user = User::new(); - /// let mut instance = Instance::new(); - /// let guild_id = String::from("1234567890"); + /// let mut user: ChorusUser; + /// let guild_id = Snowflake::from(1234567890); /// /// match Guild::delete(&mut user, guild_id) { /// Err(e) => println!("Error deleting guild: {:?}", e), @@ -111,13 +118,16 @@ impl Guild { user.belongs_to.read().unwrap().urls.api, guild_id ); + let chorus_request = ChorusRequest { request: Client::new() .post(url.clone()) .header("Authorization", user.token.clone()) .header("Content-Type", "application/json"), limit_type: LimitType::Global, - }; + } + .with_maybe_mfa(&user.mfa_token); + chorus_request.handle_request_as_result(user).await } @@ -220,7 +230,6 @@ impl Guild { .as_str(), None, None, - None, Some(user), LimitType::Guild(guild_id), ); @@ -246,7 +255,6 @@ impl Guild { .as_str(), None, None, - None, Some(user), LimitType::Guild(guild_id), ); @@ -279,7 +287,6 @@ impl Guild { .as_str(), None, audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); @@ -309,7 +316,6 @@ impl Guild { .as_str(), Some(to_string(&schema).unwrap()), audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); @@ -336,7 +342,6 @@ impl Guild { .as_str(), Some(to_string(&schema).unwrap()), audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); @@ -362,7 +367,6 @@ impl Guild { .as_str(), Some(to_string(&schema).unwrap()), None, - None, Some(user), LimitType::Guild(guild_id), ); @@ -393,7 +397,6 @@ impl Guild { &url, None, None, - None, Some(user), LimitType::Guild(guild_id), ); @@ -426,7 +429,6 @@ impl Guild { &url, None, None, - None, Some(user), LimitType::Guild(guild_id), ); @@ -456,7 +458,6 @@ impl Guild { .as_str(), Some(to_string(&schema).unwrap()), audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); @@ -487,7 +488,6 @@ impl Guild { &url, None, audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); diff --git a/src/api/guilds/roles.rs b/src/api/guilds/roles.rs index 6100a48d..7e760103 100644 --- a/src/api/guilds/roles.rs +++ b/src/api/guilds/roles.rs @@ -188,7 +188,6 @@ impl types::RoleObject { &url, None, audit_log_reason.as_deref(), - None, Some(user), LimitType::Guild(guild_id), ); diff --git a/src/api/users/mfa.rs b/src/api/users/mfa.rs new file mode 100644 index 00000000..621f0692 --- /dev/null +++ b/src/api/users/mfa.rs @@ -0,0 +1,372 @@ +use reqwest::Client; + +use crate::{ + errors::ChorusResult, + instance::{ChorusUser, Token}, + ratelimiter::ChorusRequest, + types::{ + BeginWebAuthnAuthenticatorCreationReturn, EnableTotpMfaResponse, EnableTotpMfaSchema, + FinishWebAuthnAuthenticatorCreationReturn, FinishWebAuthnAuthenticatorCreationSchema, + GetBackupCodesSchema, LimitType, MfaAuthenticator, MfaBackupCode, + ModifyWebAuthnAuthenticatorSchema, SendBackupCodesChallengeReturn, + SendBackupCodesChallengeSchema, SmsMfaRouteSchema, Snowflake, + }, +}; + +impl ChorusUser { + /// Enables TOTP based multi-factor authentication for the current user. + /// + /// # Notes + /// Fires a [`UserUpdate`](crate::types::UserUpdate) gateway event. + /// + /// Updates the authorization token for the current session. + /// + /// # Reference + /// See + pub async fn enable_totp_mfa( + &mut self, + schema: EnableTotpMfaSchema, + ) -> ChorusResult { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/totp/enable", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + + let response: EnableTotpMfaResponse = chorus_request.deserialize_response(self).await?; + + self.token = response.token.clone(); + + Ok(response) + } + + /// Disables TOTP based multi-factor authentication for the current user. + /// + /// Updates the authorization token for the current session and returns the new token. + /// + /// # Notes + /// Requires MFA. + /// + /// MFA cannot be disabled for administrators of guilds with published creator monetization listings. + /// + /// Fires a [`UserUpdate`](crate::types::UserUpdate) gateway event. + /// + /// # Reference + /// See + pub async fn disable_totp_mfa(&mut self) -> ChorusResult { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/totp/disable", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + let response: Token = chorus_request.deserialize_response(self).await?; + + self.token = response.token.clone(); + + Ok(response) + } + + /// Enables SMS based multi-factor authentication for the current user. + /// + /// Requires that TOTP based MFA is already enabled and the user has a verified phone number. + /// + /// # Notes + /// Requires MFA. + /// + /// Fires a [`UserUpdate`](crate::types::UserUpdate) gateway event. + /// + /// # Reference + /// See + pub async fn enable_sms_mfa(&mut self, schema: SmsMfaRouteSchema) -> ChorusResult<()> { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/sms/enable", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.handle_request_as_result(self).await + } + + /// Disables SMS based multi-factor authentication for the current user. + /// + /// # Notes + /// Requires MFA. + /// + /// Fires a [`UserUpdate`](crate::types::UserUpdate) gateway event. + /// + /// # Reference + /// See + pub async fn disable_sms_mfa(&mut self, schema: SmsMfaRouteSchema) -> ChorusResult<()> { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/sms/disable", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.handle_request_as_result(self).await + } + + /// Fetches a list of [WebAuthn](crate::types::MfaAuthenticatorType::WebAuthn) + /// [MfaAuthenticator]s for the current user. + /// + /// # Reference + /// See + pub async fn get_webauthn_authenticators(&mut self) -> ChorusResult> { + let request = Client::new() + .get(format!( + "{}/users/@me/mfa/webauthn/credentials", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + + chorus_request.deserialize_response(self).await + } + + /// Begins creation of a [WebAuthn](crate::types::MfaAuthenticatorType::WebAuthn) + /// [MfaAuthenticator] for the current user. + /// + /// Returns [BeginWebAuthnAuthenticatorCreationReturn], which includes the MFA ticket + /// and a stringified JSON object of the public key credential challenge. + /// + /// Once you have obtained the credential from the user, you should call + /// [ChorusUser::finish_webauthn_authenticator_creation] + /// + /// # Notes + /// Requires MFA. + /// + /// # Reference + /// See + /// + /// Note: for an easier to use API, we've split this one route into two methods + pub async fn begin_webauthn_authenticator_creation( + &mut self, + ) -> ChorusResult { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/webauthn/credentials", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.deserialize_response(self).await + } + + /// Finishes creation of a [WebAuthn](crate::types::MfaAuthenticatorType::WebAuthn) + /// [MfaAuthenticator] for the current user. + /// + /// Returns [FinishWebAuthnAuthenticatorCreationReturn], which includes the created + /// authenticator and a list of backup codes. + /// + /// To create a Webauthn authenticator from start to finish, call + /// [ChorusUser::begin_webauthn_authenticator_creation] first. + /// + /// # Notes + /// Requires MFA. + /// + /// Fires [AuthenticatorCreate](crate::types::AuthenticatorCreate) and + /// [UserUpdate](crate::types::UserUpdate) events. + /// + /// # Reference + /// See + /// + /// Note: for an easier to use API, we've split this one route into two methods + pub async fn finish_webauthn_authenticator_creation( + &mut self, + schema: FinishWebAuthnAuthenticatorCreationSchema, + ) -> ChorusResult { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/webauthn/credentials", + self.belongs_to.read().unwrap().urls.api + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.deserialize_response(self).await + } + + /// Modifies a [WebAuthn](crate::types::MfaAuthenticatorType::WebAuthn) + /// [MfaAuthenticator] (currently just renames) for the current user. + /// + /// Returns the updated authenticator. + /// + /// # Notes + /// Requires MFA. + /// + /// Fires an [AuthenticatorUpdate](crate::types::AuthenticatorUpdate) event. + /// + /// # Reference + /// See + pub async fn modify_webauthn_authenticator( + &mut self, + authenticator_id: Snowflake, + schema: ModifyWebAuthnAuthenticatorSchema, + ) -> ChorusResult { + let request = Client::new() + .patch(format!( + "{}/users/@me/mfa/webauthn/credentials/{}", + self.belongs_to.read().unwrap().urls.api, + authenticator_id + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.deserialize_response(self).await + } + + /// Deletes a [WebAuthn](crate::types::MfaAuthenticatorType::WebAuthn) + /// [MfaAuthenticator] for the current user. + /// + /// # Notes + /// Requires MFA. + /// + /// Fires [AuthenticatorDelete](crate::types::AuthenticatorDelete) and + /// [UserUpdate](crate::types::UserUpdate) events. + /// + /// If this is the last remaining authenticator, this disables MFA for the current user. + /// + /// MFA cannot be disabled for administrators of guilds with published creator monetization listings. + /// + /// # Reference + /// See + pub async fn delete_webauthn_authenticator( + &mut self, + authenticator_id: Snowflake, + ) -> ChorusResult<()> { + let request = Client::new() + .delete(format!( + "{}/users/@me/mfa/webauthn/credentials/{}", + self.belongs_to.read().unwrap().urls.api, + authenticator_id + )) + .header("Authorization", self.token()); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + } + .with_maybe_mfa(&self.mfa_token); + + chorus_request.handle_request_as_result(self).await + } + + /// Sends an email to the current user with a verification code + /// that allows them to view or regenerate their backup codes. + /// + /// For the request to actually view the backup codes, see [ChorusUser::get_backup_codes]. + /// + /// # Notes + /// The two returned nonces can only be used once and expire after 30 minutes. + /// + /// # Reference + /// See + pub async fn send_backup_codes_challenge( + &mut self, + schema: SendBackupCodesChallengeSchema, + ) -> ChorusResult { + let request = Client::new() + .post(format!( + "{}/auth/verify/view-backup-codes-challenge", + self.belongs_to.read().unwrap().urls.api, + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + + chorus_request.deserialize_response(self).await + } + + /// Fetches the user's [MfaBackupCode]s. + /// + /// Before using this endpoint, you must use [ChorusUser::send_backup_codes_challenge] and + /// obtain a key from the user's email. + /// + /// # Notes + /// The nonces in the schema are returned by the [ChorusUser::send_backup_codes_challenge] + /// endpoint. + /// + /// If regenerate is set to true, the nonce in the schema must be the regenerate_nonce. + /// Otherwise it should be the view_nonce. + /// + /// Each nonce can only be used once and expires after 30 minutes. + /// + /// # Reference + /// See + pub async fn get_backup_codes( + &mut self, + schema: GetBackupCodesSchema, + ) -> ChorusResult> { + let request = Client::new() + .post(format!( + "{}/users/@me/mfa/codes-verification", + self.belongs_to.read().unwrap().urls.api, + )) + .header("Authorization", self.token()) + .json(&schema); + + let chorus_request = ChorusRequest { + request, + limit_type: LimitType::default(), + }; + + chorus_request.deserialize_response(self).await + } +} diff --git a/src/api/users/mod.rs b/src/api/users/mod.rs index 702233cb..b6b6ac02 100644 --- a/src/api/users/mod.rs +++ b/src/api/users/mod.rs @@ -6,11 +6,13 @@ pub use channels::*; pub use connections::*; pub use guilds::*; +pub use mfa::*; pub use relationships::*; pub use users::*; pub mod channels; pub mod connections; pub mod guilds; +pub mod mfa; pub mod relationships; pub mod users; diff --git a/src/api/users/users.rs b/src/api/users/users.rs index 483a85c1..a430d545 100644 --- a/src/api/users/users.rs +++ b/src/api/users/users.rs @@ -87,6 +87,9 @@ impl ChorusUser { /// Modifies the current user's representation. (See [`User`]) /// + /// # Notes + /// This route requires MFA. + /// /// # Reference /// See pub async fn modify(&mut self, modify_schema: UserModifySchema) -> ChorusResult { @@ -109,10 +112,13 @@ impl ChorusUser { .body(to_string(&modify_schema).unwrap()) .header("Authorization", self.token()) .header("Content-Type", "application/json"); + let chorus_request = ChorusRequest { request, limit_type: LimitType::default(), - }; + } + .with_maybe_mfa(&self.mfa_token); + chorus_request.deserialize_response::(self).await } @@ -123,7 +129,7 @@ impl ChorusUser { /// Requires the user's current password (if any) /// /// # Notes - /// Requires MFA + /// This route requires MFA. /// /// # Reference /// See @@ -135,10 +141,13 @@ impl ChorusUser { )) .header("Authorization", self.token()) .json(&schema); + let chorus_request = ChorusRequest { request, limit_type: LimitType::default(), - }; + } + .with_maybe_mfa(&self.mfa_token); + chorus_request.handle_request_as_result(self).await } @@ -147,7 +156,7 @@ impl ChorusUser { /// Requires the user's current password (if any) /// /// # Notes - /// Requires MFA + /// This route requires MFA. /// /// # Reference /// See @@ -159,10 +168,13 @@ impl ChorusUser { )) .header("Authorization", self.token()) .json(&schema); + let chorus_request = ChorusRequest { request, limit_type: LimitType::default(), - }; + } + .with_maybe_mfa(&self.mfa_token); + chorus_request.handle_request_as_result(self).await } @@ -524,9 +536,9 @@ impl ChorusUser { /// Returns a mapping of user IDs ([Snowflake]s) to notes ([String]s) for the current user. /// - /// # Notes + /// # Notes /// As of 2024/08/21, Spacebar does not yet implement this endpoint. - /// + /// /// # Reference /// See pub async fn get_user_notes(&mut self) -> ChorusResult> { diff --git a/src/errors.rs b/src/errors.rs index 0d130dd0..84b3f1a6 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -5,7 +5,7 @@ //! Contains all the errors that can be returned by the library. use custom_error::custom_error; -use crate::types::WebSocketEvent; +use crate::types::{CloseCode, MfaRequiredSchema, VoiceCloseCode, WebSocketEvent}; use chorus_macros::WebSocketEvent; custom_error! { @@ -46,7 +46,16 @@ custom_error! { /// Malformed or unexpected response. InvalidResponse{error: String} = "The response is malformed and cannot be processed. Error: {error}", /// Invalid, insufficient or too many arguments provided. - InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}" + InvalidArguments{error: String} = "Invalid arguments were provided. Error: {error}", + /// The request requires MFA verification. + /// + /// This error type contains an [crate::types::MfaChallenge], which can be completed + /// with [crate::instance::ChorusUser::complete_mfa_challenge]. + /// + /// After verifying, the same request can be retried. + MfaRequired {error: MfaRequiredSchema} = "Mfa verification is required to perform this action", + /// The user's account is suspended + SuspendUser { token: String } = "Your account has been suspended" } impl From for ChorusError { @@ -100,6 +109,32 @@ custom_error! { UnexpectedOpcodeReceived{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}", } +impl From for GatewayError { + fn from(value: CloseCode) -> Self { + match value { + CloseCode::UnknownError => GatewayError::Unknown, + CloseCode::UnknownOpcode => GatewayError::UnknownOpcode, + CloseCode::DecodeError => GatewayError::Decode, + CloseCode::NotAuthenticated => GatewayError::NotAuthenticated, + CloseCode::AuthenticationFailed => GatewayError::AuthenticationFailed, + CloseCode::AlreadyAuthenticated => GatewayError::AlreadyAuthenticated, + CloseCode::InvalidSeq => GatewayError::InvalidSequenceNumber, + CloseCode::RateLimited => GatewayError::RateLimited, + CloseCode::SessionTimeout => GatewayError::SessionTimedOut, + // Note: this case is + // deprecated, it + // should never actually + // be received anymore + CloseCode::SessionNoLongerValid => GatewayError::SessionTimedOut, + CloseCode::InvalidShard => GatewayError::InvalidShard, + CloseCode::ShardingRequired => GatewayError::ShardingRequired, + CloseCode::InvalidApiVersion => GatewayError::InvalidAPIVersion, + CloseCode::InvalidIntents => GatewayError::InvalidIntents, + CloseCode::DisallowedIntents => GatewayError::DisallowedIntents, + } + } +} + custom_error! { /// Voice Gateway errors /// @@ -116,7 +151,7 @@ custom_error! { AuthenticationFailed = "The token you sent in your identify payload is incorrect", AlreadyAuthenticated = "You sent more than one identify payload", SessionNoLongerValid = "Your session is no longer valid", - SessionTimeout = "Your session has timed out", + SessionTimedOut = "Your session has timed out", ServerNotFound = "We can't find the server you're trying to connect to", UnknownProtocol = "We didn't recognize the protocol you sent", Disconnected = "Channel was deleted, you were kicked, voice server changed, or the main gateway session was dropped. Should not reconnect.", @@ -131,6 +166,25 @@ custom_error! { UnexpectedOpcodeReceived{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}", } +impl From for VoiceGatewayError { + fn from(value: VoiceCloseCode) -> Self { + match value { + VoiceCloseCode::UnknownOpcode => VoiceGatewayError::UnknownOpcode, + VoiceCloseCode::FailedToDecodePayload => VoiceGatewayError::FailedToDecodePayload, + VoiceCloseCode::NotAuthenticated => VoiceGatewayError::NotAuthenticated, + VoiceCloseCode::AuthenticationFailed => VoiceGatewayError::AuthenticationFailed, + VoiceCloseCode::AlreadyAuthenticated => VoiceGatewayError::AlreadyAuthenticated, + VoiceCloseCode::SessionTimeout => VoiceGatewayError::SessionTimedOut, + VoiceCloseCode::SessionNoLongerValid => VoiceGatewayError::SessionNoLongerValid, + VoiceCloseCode::ServerNotFound => VoiceGatewayError::ServerNotFound, + VoiceCloseCode::UnknownProtocol => VoiceGatewayError::UnknownProtocol, + VoiceCloseCode::DisconnectedChannelDeletedOrKicked => VoiceGatewayError::Disconnected, + VoiceCloseCode::VoiceServerCrashed => VoiceGatewayError::VoiceServerCrashed, + VoiceCloseCode::UnknownEncryptionMode => VoiceGatewayError::UnknownEncryptionMode, + } + } +} + custom_error! { /// Voice UDP errors. #[derive(Clone, PartialEq, Eq, WebSocketEvent)] @@ -151,4 +205,3 @@ custom_error! { CannotBind{error: String} = "Cannot bind socket due to a UDP error: {error}", CannotConnect{error: String} = "Cannot connect due to a UDP error: {error}", } - diff --git a/src/gateway/backends/tungstenite.rs b/src/gateway/backends/tungstenite.rs index 4464f8dd..40d30827 100644 --- a/src/gateway/backends/tungstenite.rs +++ b/src/gateway/backends/tungstenite.rs @@ -9,12 +9,16 @@ use futures_util::{ }; use tokio::net::TcpStream; use tokio_tungstenite::{ - connect_async_tls_with_config, connect_async_with_config, tungstenite, Connector, - MaybeTlsStream, WebSocketStream, + connect_async_tls_with_config, connect_async_with_config, + tungstenite::self, + Connector, MaybeTlsStream, WebSocketStream, }; use url::Url; -use crate::gateway::{GatewayMessage, RawGatewayMessage}; +use crate::{ + gateway::{GatewayCommunication, GatewayMessage, RawGatewayMessage}, + types::CloseCode, +}; #[derive(Debug, Clone, Copy)] pub struct TungsteniteBackend; @@ -112,12 +116,27 @@ impl From for tungstenite::Message { } } -impl From for RawGatewayMessage { +impl From for GatewayCommunication { fn from(value: tungstenite::Message) -> Self { match value { - tungstenite::Message::Binary(bytes) => RawGatewayMessage::Bytes(bytes), - tungstenite::Message::Text(text) => RawGatewayMessage::Text(text), - _ => RawGatewayMessage::Text(value.to_string()), + tungstenite::Message::Binary(bytes) => { + GatewayCommunication::Message(RawGatewayMessage::Bytes(bytes)) + } + tungstenite::Message::Text(text) => { + GatewayCommunication::Message(RawGatewayMessage::Text(text)) + } + tungstenite::Message::Close(close_frame) => { + if close_frame.is_none() { + return GatewayCommunication::Error(CloseCode::UnknownError); + } + + let close_code = u16::from(close_frame.unwrap().code); + + GatewayCommunication::Error( + CloseCode::try_from(close_code).unwrap_or(CloseCode::UnknownError), + ) + } + _ => GatewayCommunication::Error(CloseCode::UnknownError), } } } diff --git a/src/gateway/backends/wasm.rs b/src/gateway/backends/wasm.rs index e0fd9c68..53fb87be 100644 --- a/src/gateway/backends/wasm.rs +++ b/src/gateway/backends/wasm.rs @@ -11,20 +11,23 @@ use ws_stream_wasm::*; use crate::gateway::{GatewayMessage, RawGatewayMessage}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct WasmBackend; // These could be made into inherent associated types when that's stabilized pub type WasmSink = SplitSink; -pub type WasmStream = SplitStream; +/// Note: this includes WsMeta so we can subscribe to events as well +pub type WasmStream = (SplitStream, WsMeta); impl WasmBackend { pub async fn connect( websocket_url: &str, ) -> Result<(WasmSink, WasmStream), ws_stream_wasm::WsErr> { - let (_, websocket_stream) = WsMeta::connect(websocket_url, None).await?; + let (meta, websocket_stream) = WsMeta::connect(websocket_url, None).await?; - Ok(websocket_stream.split()) + let (sink, stream) = websocket_stream.split(); + + Ok((sink, (stream, meta))) } } diff --git a/src/gateway/events.rs b/src/gateway/events.rs index 4663fe1a..41780503 100644 --- a/src/gateway/events.rs +++ b/src/gateway/events.rs @@ -25,6 +25,7 @@ pub struct Events { pub call: Call, pub voice: Voice, pub webhooks: Webhooks, + pub mfa: Mfa, pub gateway_identify_payload: Publisher, pub gateway_resume: Publisher, pub error: Publisher, @@ -71,6 +72,7 @@ pub struct Message { pub reaction_remove_emoji: Publisher, pub recent_mention_delete: Publisher, pub ack: Publisher, + pub last_messages: Publisher, } #[derive(Default, Debug)] @@ -169,3 +171,10 @@ pub struct Voice { pub struct Webhooks { pub update: Publisher, } + +#[derive(Default, Debug)] +pub struct Mfa { + pub authenticator_create: Publisher, + pub authenticator_update: Publisher, + pub authenticator_delete: Publisher, +} diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 20f86407..98ac9cea 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -16,11 +16,15 @@ use super::*; use super::{Sink, Stream}; use crate::types::{ self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, - ChannelUpdate, GatewayInvalidSession, GatewayReconnect, Guild, GuildRoleCreate, - GuildRoleUpdate, JsonField, RoleObject, SourceUrlField, ThreadUpdate, UpdateMessage, + ChannelUpdate, CloseCode, GatewayInvalidSession, GatewayReconnect, Guild, GuildRoleCreate, + GuildRoleUpdate, JsonField, Opcode, RoleObject, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent, }; +// Needed to observe close codes +#[cfg(target_arch = "wasm32")] +use pharos::Observable; + /// Tells us we have received enough of the buffer to decompress it const ZLIB_SUFFIX: [u8; 4] = [0, 0, 255, 255]; @@ -72,9 +76,21 @@ impl Gateway { // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread #[cfg(not(target_arch = "wasm32"))] - let received: RawGatewayMessage = websocket_receive.next().await.unwrap().unwrap().into(); + let received: RawGatewayMessage = { + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them differently. + // + // Hence why wasm receives straight RawGatewayMessages, and tungstenite receives + // GatewayCommunications. + let communication: GatewayCommunication = + websocket_receive.next().await.unwrap().unwrap().into(); + + match communication { + GatewayCommunication::Message(message) => message, + GatewayCommunication::Error(error) => return Err(error.into()), + } + }; #[cfg(target_arch = "wasm32")] - let received: RawGatewayMessage = websocket_receive.next().await.unwrap().into(); + let received: RawGatewayMessage = websocket_receive.0.next().await.unwrap().into(); let message: GatewayMessage; @@ -101,13 +117,14 @@ impl Gateway { let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&message.0).unwrap(); - if gateway_payload.op_code != GATEWAY_HELLO { + if gateway_payload.op_code != (Opcode::Hello as u8) { + warn!("GW: Received a non-hello opcode ({}) on gateway init", gateway_payload.op_code); return Err(GatewayError::NonHelloOnInitiate { opcode: gateway_payload.op_code, }); } - info!("GW: Received Hello"); + debug!("GW: Received Hello"); let gateway_hello: types::HelloData = serde_json::from_str(gateway_payload.event_data.unwrap().get()).unwrap(); @@ -138,11 +155,11 @@ impl Gateway { // Now we can continuously check for messages in a different task, since we aren't going to receive another hello #[cfg(not(target_arch = "wasm32"))] task::spawn(async move { - gateway.gateway_listen_task().await; + gateway.gateway_listen_task_tungstenite().await; }); #[cfg(target_arch = "wasm32")] wasm_bindgen_futures::spawn_local(async move { - gateway.gateway_listen_task().await; + gateway.gateway_listen_task_wasm().await; }); Ok(GatewayHandle { @@ -154,8 +171,9 @@ impl Gateway { }) } - /// The main gateway listener task; - async fn gateway_listen_task(&mut self) { + /// The main gateway listener task for a tungstenite based gateway; + #[cfg(not(target_arch = "wasm32"))] + async fn gateway_listen_task_tungstenite(&mut self) { loop { let msg; @@ -169,13 +187,75 @@ impl Gateway { } } - // PRETTYFYME: Remove inline conditional compiling - #[cfg(not(target_arch = "wasm32"))] + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them differently. + // + // Hence why wasm receives straight RawGatewayMessages, and tungstenite receives + // GatewayCommunications. if let Some(Ok(message)) = msg { - self.handle_raw_message(message.into()).await; + let communication: GatewayCommunication = message.into(); + + match communication { + GatewayCommunication::Message(raw_message) => { + self.handle_raw_message(raw_message).await + } + GatewayCommunication::Error(close_code) => { + self.handle_close_code(close_code).await + } + } + continue; } - #[cfg(target_arch = "wasm32")] + + // We couldn't receive the next message or it was an error, something is wrong with the websocket, close + warn!("GW: Websocket is broken, stopping gateway"); + break; + } + } + + /// The main gateway listener task for a wasm based gateway; + /// + /// Wasm handles close codes and events differently, and so we must change the listener logic a + /// bit + #[cfg(target_arch = "wasm32")] + async fn gateway_listen_task_wasm(&mut self) { + // Initiate the close event listener + let mut close_events = self + .websocket_receive + .1 + .observe(pharos::Filter::Pointer(ws_stream_wasm::WsEvent::is_closed).into()) + .await + .unwrap(); + + loop { + let msg; + + tokio::select! { + Ok(_) = self.kill_receive.recv() => { + log::trace!("GW: Closing listener task"); + break; + } + message = self.websocket_receive.0.next() => { + msg = message; + } + maybe_event = close_events.next() => { + if let Some(event) = maybe_event { + match event { + ws_stream_wasm::WsEvent::Closed(closed_event) => { + let close_code = CloseCode::try_from(closed_event.code).unwrap_or(CloseCode::UnknownError); + self.handle_close_code(close_code).await; + break; + } + _ => unreachable!() // Should be impossible, we filtered close events + } + } + continue; + } + } + + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them as a seperate receiver. + // + // Hence why wasm receives RawGatewayMessages, and tungstenite receives + // GatewayCommunications. if let Some(message) = msg { self.handle_raw_message(message.into()).await; continue; @@ -193,6 +273,17 @@ impl Gateway { self.websocket_send.lock().await.close().await.unwrap(); } + /// Handles receiving a [CloseCode]. + /// + /// Closes the connection and publishes an error event. + async fn handle_close_code(&mut self, code: CloseCode) { + let error = GatewayError::from(code); + + warn!("GW: Received error {:?}, connection will close..", error); + self.close().await; + self.events.lock().await.error.publish(error).await; + } + /// Deserializes and updates a dispatched event, when we already know its type; /// (Called for every event in handle_message) #[allow(dead_code)] // TODO: Remove this allow annotation @@ -211,7 +302,7 @@ impl Gateway { } /// Takes a [RawGatewayMessage], converts it to [GatewayMessage] based - /// of connection options and calls handle_message + /// of connection options and calls [Self::handle_message] async fn handle_raw_message(&mut self, raw_message: RawGatewayMessage) { let message; @@ -251,29 +342,32 @@ impl Gateway { } let Ok(gateway_payload) = msg.payload() else { - if let Some(error) = msg.error() { - warn!("GW: Received error {:?}, connection will close..", error); - self.close().await; - self.events.lock().await.error.publish(error).await; - } else { - warn!( - "Message unrecognised: {:?}, please open an issue on the chorus github", - msg.0 - ); - } + warn!( + "GW: Message unrecognised: {:?}, please open an issue on the chorus github", + msg.0 + ); return; }; - // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes - match gateway_payload.op_code { + let op_code_res = Opcode::try_from(gateway_payload.op_code); + + if op_code_res.is_err() { + warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); + trace!("Event data: {:?}", gateway_payload); + return; + } + + let op_code = op_code_res.unwrap(); + + match op_code { // An event was dispatched, we need to look at the gateway event name t - GATEWAY_DISPATCH => { + Opcode::Dispatch => { let Some(event_name) = gateway_payload.event_name else { - warn!("Gateway dispatch op without event_name"); + warn!("GW: Received dispatch without event_name"); return; }; - trace!("Gateway: Received {event_name}"); + trace!("GW: Received {event_name}"); macro_rules! handle { ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { @@ -354,6 +448,9 @@ impl Gateway { "AUTO_MODERATION_RULE_UPDATE" =>auto_moderation.rule_update AutoModerationRuleUpdate: AutoModerationRule, "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, + "AUTHENTICATOR_CREATE" => mfa.authenticator_create, // TODO + "AUTHENTICATOR_UPDATE" => mfa.authenticator_update, // TODO + "AUTHENTICATOR_DELETE" => mfa.authenticator_delete, // TODO "CHANNEL_CREATE" => channel.create ChannelCreate: Guild, "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, "CHANNEL_UNREAD_UPDATE" => channel.unread_update, @@ -396,6 +493,7 @@ impl Gateway { "INTERACTION_CREATE" => interaction.create, // TODO "INVITE_CREATE" => invite.create, // TODO "INVITE_DELETE" => invite.delete, // TODO + "LAST_MESSAGES" => message.last_messages, "MESSAGE_CREATE" => message.create, "MESSAGE_UPDATE" => message.update, // TODO "MESSAGE_DELETE" => message.delete, @@ -424,14 +522,28 @@ impl Gateway { } // We received a heartbeat from the server // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." - GATEWAY_HEARTBEAT => { + Opcode::Heartbeat => { trace!("GW: Received Heartbeat // Heartbeat Request"); // Tell the heartbeat handler it should send a heartbeat right away + let heartbeat_communication = HeartbeatThreadCommunication { + sequence_number: gateway_payload.sequence_number, + op_code: Some(Opcode::Heartbeat), + }; + self.heartbeat_handler + .send + .send(heartbeat_communication) + .await + .unwrap(); + } + Opcode::HeartbeatAck => { + trace!("GW: Received Heartbeat ACK"); + + // Tell the heartbeat handler we received an ack let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT), + op_code: Some(Opcode::HeartbeatAck), }; self.heartbeat_handler @@ -440,7 +552,7 @@ impl Gateway { .await .unwrap(); } - GATEWAY_RECONNECT => { + Opcode::Reconnect => { trace!("GW: Received Reconnect"); let reconnect = GatewayReconnect {}; @@ -453,7 +565,7 @@ impl Gateway { .publish(reconnect) .await; } - GATEWAY_INVALID_SESSION => { + Opcode::InvalidSession => { trace!("GW: Received Invalid Session"); let mut resumable: bool = false; @@ -479,44 +591,19 @@ impl Gateway { .await; } // Starts our heartbeat - // We should have already handled this in gateway init - GATEWAY_HELLO => { + // We should have already handled this + Opcode::Hello => { warn!("Received hello when it was unexpected"); } - GATEWAY_HEARTBEAT_ACK => { - trace!("GW: Received Heartbeat ACK"); - - // Tell the heartbeat handler we received an ack - - let heartbeat_communication = HeartbeatThreadCommunication { - sequence_number: gateway_payload.sequence_number, - op_code: Some(GATEWAY_HEARTBEAT_ACK), - }; - - self.heartbeat_handler - .send - .send(heartbeat_communication) - .await - .unwrap(); - } - GATEWAY_IDENTIFY - | GATEWAY_UPDATE_PRESENCE - | GATEWAY_UPDATE_VOICE_STATE - | GATEWAY_RESUME - | GATEWAY_REQUEST_GUILD_MEMBERS - | GATEWAY_CALL_SYNC - | GATEWAY_LAZY_REQUEST => { - info!( - "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", + _ => { + warn!( + "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation, but you can open an issue on the chorus github anyway", gateway_payload.op_code ); } - _ => { - warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); - } } - // If we we received a seq number we should let it know + // If we we received a sequence number we should let the heartbeat thread know if let Some(seq_num) = gateway_payload.sequence_number { let heartbeat_communication = HeartbeatThreadCommunication { sequence_number: Some(seq_num), diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs index 98a77314..f165d40b 100644 --- a/src/gateway/handle.rs +++ b/src/gateway/handle.rs @@ -8,7 +8,7 @@ use log::*; use std::fmt::Debug; use super::{events::Events, *}; -use crate::types::{self, Composite, Shared}; +use crate::types::{self, Composite, Opcode, Shared}; /// Represents a handle to a Gateway connection. /// @@ -105,70 +105,90 @@ impl GatewayHandle { object } - /// Sends an identify event to the gateway + /// Sends an identify event ([types::GatewayIdentifyPayload]) to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); trace!("GW: Sending Identify.."); - self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; + self.send_json_event(Opcode::Identify as u8, to_send_value).await; } - /// Sends a resume event to the gateway + /// Sends a resume event ([types::GatewayResume]) to the gateway pub async fn send_resume(&self, to_send: types::GatewayResume) { let to_send_value = serde_json::to_value(&to_send).unwrap(); trace!("GW: Sending Resume.."); - self.send_json_event(GATEWAY_RESUME, to_send_value).await; + self.send_json_event(Opcode::Resume as u8, to_send_value).await; } - /// Sends an update presence event to the gateway + /// Sends an update presence event ([types::UpdatePresence]) to the gateway pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { let to_send_value = serde_json::to_value(&to_send).unwrap(); trace!("GW: Sending Update Presence.."); - self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) + self.send_json_event(Opcode::PresenceUpdate as u8, to_send_value) .await; } - /// Sends a request guild members to the server + /// Sends a request guild members ([types::GatewayRequestGuildMembers]) to the server pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { let to_send_value = serde_json::to_value(&to_send).unwrap(); trace!("GW: Sending Request Guild Members.."); - self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) + self.send_json_event(Opcode::RequestGuildMembers as u8, to_send_value) .await; } - /// Sends an update voice state to the server + /// Sends an update voice state ([types::UpdateVoiceState]) to the server pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { let to_send_value = serde_json::to_value(to_send).unwrap(); trace!("GW: Sending Update Voice State.."); - self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) + self.send_json_event(Opcode::VoiceStateUpdate as u8, to_send_value) .await; } - /// Sends a call sync to the server + /// Sends a call sync ([types::CallSync]) to the server pub async fn send_call_sync(&self, to_send: types::CallSync) { let to_send_value = serde_json::to_value(to_send).unwrap(); trace!("GW: Sending Call Sync.."); - self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; + self.send_json_event(Opcode::CallConnect as u8, to_send_value).await; } - /// Sends a Lazy Request + /// Sends a request call connect event (aka [types::CallSync]) to the server + /// + /// # Notes + /// Alias of [Self::send_call_sync] + pub async fn send_request_call_connect(&self, to_send: types::CallSync) { + self.send_call_sync(to_send).await + } + + /// Sends a Lazy Request ([types::LazyRequest]) to the server pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { let to_send_value = serde_json::to_value(&to_send).unwrap(); trace!("GW: Sending Lazy Request.."); - self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) + self.send_json_event(Opcode::GuildSubscriptions as u8, to_send_value) + .await; + } + + /// Sends a Request Last Messages ([types::RequestLastMessages]) to the server + /// + /// The server should respond with a [types::LastMessages] event + pub async fn send_request_last_messages(&self, to_send: types::RequestLastMessages) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Request Last Messages.."); + + self.send_json_event(Opcode::RequestLastMessages as u8, to_send_value) .await; } diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs index 5dcc98db..2a0a6f97 100644 --- a/src/gateway/heartbeat.rs +++ b/src/gateway/heartbeat.rs @@ -23,13 +23,13 @@ use tokio::sync::mpsc::{Receiver, Sender}; use tokio::task; use super::*; -use crate::types; +use crate::types::{self, Opcode}; /// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms pub const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; /// Handles sending heartbeats to the gateway in another thread -#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used +#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is "used" #[derive(Debug)] pub(super) struct HeartbeatHandler { /// How ofter heartbeats need to be sent at a minimum @@ -98,11 +98,11 @@ impl HeartbeatHandler { if let Some(op_code) = communication.op_code { match op_code { - GATEWAY_HEARTBEAT => { + Opcode::Heartbeat => { // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately should_send = true; } - GATEWAY_HEARTBEAT_ACK => { + Opcode::HeartbeatAck => { // The server received our heartbeat last_heartbeat_acknowledged = true; } @@ -120,7 +120,7 @@ impl HeartbeatHandler { trace!("GW: Sending Heartbeat.."); let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, + op: (Opcode::Heartbeat as u8), d: last_seq_number, }; @@ -147,7 +147,7 @@ impl HeartbeatHandler { #[derive(Clone, Copy, Debug)] pub(super) struct HeartbeatThreadCommunication { /// The opcode for the communication we received, if relevant - pub(super) op_code: Option, + pub(super) op_code: Option, /// The sequence number we got from discord, if any pub(super) sequence_number: Option, } diff --git a/src/gateway/message.rs b/src/gateway/message.rs index 7d44af6b..cae45af3 100644 --- a/src/gateway/message.rs +++ b/src/gateway/message.rs @@ -4,10 +4,27 @@ use std::string::FromUtf8Error; -use crate::types; +use crate::types::{CloseCode, GatewayReceivePayload}; use super::*; +#[derive(Clone, Debug, PartialEq, Eq)] +/// Defines a communication received from the gateway, being either an optionally compressed +/// [RawGatewayMessage] or a [CloseCode]. +/// +/// Used only for a tungstenite gateway, since our underlying wasm backend handles close codes +/// differently. +pub(crate) enum GatewayCommunication { + Message(RawGatewayMessage), + Error(CloseCode), +} + +impl From for GatewayCommunication { + fn from(value: RawGatewayMessage) -> Self { + Self::Message(value) + } +} + /// Defines a raw gateway message, being either string json or bytes /// /// This is used as an intermediary type between types from different websocket implementations @@ -36,42 +53,17 @@ impl RawGatewayMessage { } /// Represents a json message received from the gateway. -/// This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. +/// +/// This will usually be a [GatewayReceivePayload]. +/// /// This struct is used internally when handling messages. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct GatewayMessage(pub String); impl GatewayMessage { - /// Parses the message as an error; - /// Returns the error if successfully parsed, None if the message isn't an error - pub fn error(&self) -> Option { - // Some error strings have dots on the end, which we don't care about - let processed_content = self.0.to_lowercase().replace('.', ""); - - match processed_content.as_str() { - "unknown error" | "4000" => Some(GatewayError::Unknown), - "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), - "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), - "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), - "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), - "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), - "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), - "rate limited" | "4008" => Some(GatewayError::RateLimited), - "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), - "invalid shard" | "4010" => Some(GatewayError::InvalidShard), - "sharding required" | "4011" => Some(GatewayError::ShardingRequired), - "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), - "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), - "disallowed intent(s)" | "disallowed intents" | "4014" => { - Some(GatewayError::DisallowedIntents) - } - _ => None, - } - } - /// Parses the message as a payload; /// Returns a result of deserializing - pub fn payload(&self) -> Result { + pub fn payload(&self) -> Result { serde_json::from_str(&self.0) } @@ -90,7 +82,6 @@ impl GatewayMessage { bytes: &[u8], inflate: &mut flate2::Decompress, ) -> Result { - // Note: is there a better way to handle the size of this output buffer? // // This used to be 10, I measured it at 11.5, so a safe bet feels like 20 diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6ee06788..98d34a9b 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -2,6 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +#![allow(deprecated)] // Since Opcode variants marked as deprecated are being used here, we need to suppress the warnings about them being deprecated pub mod backends; pub mod events; @@ -27,54 +28,6 @@ use std::sync::{Arc, RwLock}; use tokio::sync::Mutex; -// Gateway opcodes -/// Opcode received when the server dispatches a [crate::types::WebSocketEvent] -const GATEWAY_DISPATCH: u8 = 0; -/// Opcode sent when sending a heartbeat -const GATEWAY_HEARTBEAT: u8 = 1; -/// Opcode sent to initiate a session -/// -/// See [types::GatewayIdentifyPayload] -const GATEWAY_IDENTIFY: u8 = 2; -/// Opcode sent to update our presence -/// -/// See [types::GatewayUpdatePresence] -const GATEWAY_UPDATE_PRESENCE: u8 = 3; -/// Opcode sent to update our state in vc -/// -/// Like muting, deafening, leaving, joining.. -/// -/// See [types::UpdateVoiceState] -const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; -/// Opcode sent to resume a session -/// -/// See [types::GatewayResume] -const GATEWAY_RESUME: u8 = 6; -/// Opcode received to tell the client to reconnect -const GATEWAY_RECONNECT: u8 = 7; -/// Opcode sent to request guild member data -/// -/// See [types::GatewayRequestGuildMembers] -const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; -/// Opcode received to tell the client their token / session is invalid -const GATEWAY_INVALID_SESSION: u8 = 9; -/// Opcode received when initially connecting to the gateway, starts our heartbeat -/// -/// See [types::HelloData] -const GATEWAY_HELLO: u8 = 10; -/// Opcode received to acknowledge a heartbeat -const GATEWAY_HEARTBEAT_ACK: u8 = 11; -/// Opcode sent to get the voice state of users in a given DM/group channel -/// -/// See [types::CallSync] -const GATEWAY_CALL_SYNC: u8 = 13; -/// Opcode sent to get data for a server (Lazy Loading request) -/// -/// Sent by the official client when switching to a server -/// -/// See [types::LazyRequest] -const GATEWAY_LAZY_REQUEST: u8 = 14; - pub type ObservableObject = dyn Send + Sync + Any; /// Note: this is a reexport of [pubserve::Subscriber], diff --git a/src/instance.rs b/src/instance.rs index 3956dd06..aa891ede 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -8,7 +8,9 @@ use std::collections::HashMap; use std::fmt; use std::sync::{Arc, RwLock}; +use std::time::Duration; +use chrono::Utc; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -17,7 +19,8 @@ use crate::gateway::{Gateway, GatewayHandle, GatewayOptions}; use crate::ratelimiter::ChorusRequest; use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::{ - GeneralConfiguration, Limit, LimitType, LimitsConfiguration, Shared, User, UserSettings, + GatewayIdentifyPayload, GeneralConfiguration, Limit, LimitType, LimitsConfiguration, MfaToken, + MfaTokenSchema, MfaVerifySchema, Shared, User, UserSettings, }; use crate::UrlBundle; @@ -258,6 +261,7 @@ impl fmt::Display for Token { pub struct ChorusUser { pub belongs_to: Shared, pub token: String, + pub mfa_token: Option, pub limits: Option>, pub settings: Shared, pub object: Shared, @@ -289,6 +293,7 @@ impl ChorusUser { ChorusUser { belongs_to, token, + mfa_token: None, limits, settings, object, @@ -296,6 +301,34 @@ impl ChorusUser { } } + /// Updates a shell user after the login process. + /// + /// Fetches all the other required data from the api. + /// + /// If the received_settings can be None, since not all login methods + /// return user settings. If this is the case, we'll fetch them via an api route. + pub(crate) async fn update_with_login_data( + &mut self, + token: String, + received_settings: Option>, + ) -> ChorusResult<()> { + self.token = token.clone(); + + let mut identify = GatewayIdentifyPayload::common(); + identify.token = token; + self.gateway.send_identify(identify).await; + + *self.object.write().unwrap() = self.get_current_user().await?; + + if let Some(passed_settings) = received_settings { + self.settings = passed_settings; + } else { + *self.settings.write().unwrap() = self.get_settings().await?; + } + + Ok(()) + } + /// Creates a new 'shell' of a user. The user does not exist as an object, and exists so that you have /// a ChorusUser object to make Rate Limited requests with. This is useful in scenarios like /// registering or logging in to the Instance, where you do not yet have a User object, but still @@ -304,12 +337,15 @@ impl ChorusUser { pub(crate) async fn shell(instance: Shared, token: &str) -> ChorusUser { let settings = Arc::new(RwLock::new(UserSettings::default())); let object = Arc::new(RwLock::new(User::default())); + let wss_url = &instance.read().unwrap().urls.wss.clone(); let gateway_options = instance.read().unwrap().gateway_options; + // Dummy gateway object let gateway = Gateway::spawn(wss_url, gateway_options).await.unwrap(); ChorusUser { token: token.to_string(), + mfa_token: None, belongs_to: instance.clone(), limits: instance .read() @@ -322,4 +358,40 @@ impl ChorusUser { gateway, } } + + /// Sends a request to complete an MFA challenge. + /// + /// If successful, the MFA verification JWT returned is set on the current [ChorusUser] executing the + /// request. + /// + /// The JWT token expires after 5 minutes. + /// + /// This route is usually used in response to [ChorusError::MfaRequired](crate::ChorusError::MfaRequired). + /// + /// # Reference + /// See + pub async fn complete_mfa_challenge( + &mut self, + mfa_verify_schema: MfaVerifySchema, + ) -> ChorusResult<()> { + let endpoint_url = self.belongs_to.read().unwrap().urls.api.clone() + "/mfa/finish"; + let chorus_request = ChorusRequest { + request: Client::new() + .post(endpoint_url) + .header("Authorization", self.token()) + .json(&mfa_verify_schema), + limit_type: LimitType::Global, + }; + + let mfa_token_schema = chorus_request + .deserialize_response::(self) + .await?; + + self.mfa_token = Some(MfaToken { + token: mfa_token_schema.token, + expires_at: Utc::now() + Duration::from_secs(60 * 5), + }); + + Ok(()) + } } diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs index 5ffcca66..785c57db 100644 --- a/src/ratelimiter.rs +++ b/src/ratelimiter.rs @@ -13,7 +13,7 @@ use serde_json::from_str; use crate::{ errors::{ChorusError, ChorusResult}, instance::ChorusUser, - types::{types::subconfigs::limits::rates::RateLimits, Limit, LimitType, LimitsConfiguration}, + types::{types::subconfigs::limits::rates::RateLimits, Limit, LimitType, LimitsConfiguration, MfaRequiredSchema}, }; /// Chorus' request struct. This struct is used to send rate-limited requests to the Spacebar server. @@ -34,13 +34,12 @@ impl ChorusRequest { /// * [`http::Method::DELETE`] /// * [`http::Method::PATCH`] /// * [`http::Method::HEAD`] - #[allow(unused_variables)] // TODO: Add mfa_token to request, once we figure out *how* to do so correctly + #[allow(unused_variables)] pub fn new( method: http::Method, url: &str, body: Option, audit_log_reason: Option<&str>, - mfa_token: Option<&str>, chorus_user: Option<&mut ChorusUser>, limit_type: LimitType, ) -> ChorusRequest { @@ -266,7 +265,14 @@ impl ChorusRequest { async fn interpret_error(response: reqwest::Response) -> ChorusError { match response.status().as_u16() { - 401..=403 | 407 => ChorusError::NoPermission, + 401 => { + let response = response.text().await.unwrap(); + match serde_json::from_str::(&response) { + Ok(response) => ChorusError::MfaRequired { error: response }, + Err(_) => ChorusError::NoPermission, + } + } + 402..=403 | 407 => ChorusError::NoPermission, 404 => ChorusError::NotFound { error: response.text().await.unwrap(), }, diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index df301533..02cdbb3f 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -8,6 +8,7 @@ use serde_repr::{Deserialize_repr, Serialize_repr}; use std::fmt::{Debug, Formatter}; use std::str::FromStr; +use crate::errors::ChorusError; use crate::types::{ entities::{GuildMember, User}, utils::Snowflake, @@ -31,7 +32,7 @@ use serde::de::{Error, Visitor}; #[cfg(feature = "sqlx")] use sqlx::types::Json; -use super::{option_arc_rwlock_ptr_eq, option_vec_arc_rwlock_ptr_eq}; +use super::{option_arc_rwlock_ptr_eq, option_vec_arc_rwlock_ptr_eq, Emoji}; #[derive(Default, Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] @@ -42,17 +43,23 @@ use super::{option_arc_rwlock_ptr_eq, option_vec_arc_rwlock_ptr_eq}; /// See pub struct Channel { pub application_id: Option, - pub applied_tags: Option>, + #[cfg(not(feature = "sqlx"))] + pub applied_tags: Option>, + #[cfg(feature = "sqlx")] + pub applied_tags: Option>>, + #[cfg(not(feature = "sqlx"))] pub available_tags: Option>, + #[cfg(feature = "sqlx")] + pub available_tags: Option>>, pub bitrate: Option, #[serde(rename = "type")] + #[cfg_attr(feature = "sqlx", sqlx(rename = "type"))] pub channel_type: ChannelType, pub created_at: Option>, pub default_auto_archive_duration: Option, - pub default_forum_layout: Option, - // DefaultReaction could be stored in a separate table. However, there are a lot of default emojis. How would we handle that? + pub default_forum_layout: Option, pub default_reaction_emoji: Option, - pub default_sort_order: Option, + pub default_sort_order: Option, pub default_thread_rate_limit_per_user: Option, pub flags: Option, pub guild_id: Option, @@ -325,6 +332,15 @@ pub struct DefaultReaction { pub emoji_name: Option, } +impl From for DefaultReaction { + fn from(value: Emoji) -> Self { + Self { + emoji_id: Some(value.id), + emoji_name: value.name, + } + } +} + #[derive( Default, Clone, @@ -401,3 +417,111 @@ pub struct FollowedChannel { pub channel_id: Snowflake, pub webhook_id: Snowflake, } + +#[derive( + Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord, Default, +)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u8)] +pub enum DefaultForumLayout { + #[default] + Default = 0, + List = 1, + Grid = 2, +} + +impl TryFrom for DefaultForumLayout { + type Error = ChorusError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(DefaultForumLayout::Default), + 1 => Ok(DefaultForumLayout::List), + 2 => Ok(DefaultForumLayout::Grid), + _ => Err(ChorusError::InvalidArguments { + error: "Value is not a valid DefaultForumLayout".to_string(), + }), + } + } +} + +#[cfg(feature = "sqlx")] +impl sqlx::Type for DefaultForumLayout { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +#[cfg(feature = "sqlx")] +impl<'q> sqlx::Encode<'q, sqlx::Postgres> for DefaultForumLayout { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::from(*self as u8); + sqlx_pg_uint.encode_by_ref(buf) + } +} + +#[cfg(feature = "sqlx")] +impl<'r> sqlx::Decode<'r, sqlx::Postgres> for DefaultForumLayout { + fn decode( + value: ::ValueRef<'r>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::decode(value)?; + DefaultForumLayout::try_from(sqlx_pg_uint.to_uint()).map_err(|e| e.into()) + } +} + +#[derive( + Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord, Default, +)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +#[repr(u8)] +pub enum DefaultSortOrder { + #[default] + LatestActivity = 0, + CreationTime = 1, +} + +impl TryFrom for DefaultSortOrder { + type Error = ChorusError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(DefaultSortOrder::LatestActivity), + 1 => Ok(DefaultSortOrder::CreationTime), + _ => Err(ChorusError::InvalidArguments { + error: "Value is not a valid DefaultSearchOrder".to_string(), + }), + } + } +} + +#[cfg(feature = "sqlx")] +impl sqlx::Type for DefaultSortOrder { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +#[cfg(feature = "sqlx")] +impl<'q> sqlx::Encode<'q, sqlx::Postgres> for DefaultSortOrder { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::from(*self as u8); + sqlx_pg_uint.encode_by_ref(buf) + } +} + +#[cfg(feature = "sqlx")] +impl<'r> sqlx::Decode<'r, sqlx::Postgres> for DefaultSortOrder { + fn decode( + value: ::ValueRef<'r>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::decode(value)?; + DefaultSortOrder::try_from(sqlx_pg_uint.to_uint()).map_err(|e| e.into()) + } +} diff --git a/src/types/entities/message.rs b/src/types/entities/message.rs index b63eb3bf..9d93e138 100644 --- a/src/types/entities/message.rs +++ b/src/types/entities/message.rs @@ -132,6 +132,7 @@ impl PartialEq for Message { /// See pub struct MessageReference { #[serde(rename = "type")] + #[serde(default)] pub reference_type: MessageReferenceType, pub message_id: Snowflake, pub channel_id: Snowflake, @@ -139,9 +140,11 @@ pub struct MessageReference { pub fail_if_not_exists: Option, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq, Ord, PartialOrd, Copy)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq, Ord, PartialOrd, Copy, Default)] +#[repr(u8)] pub enum MessageReferenceType { /// A standard reference used by replies and system messages + #[default] Default = 0, /// A reference used to point to a message at a point in time Forward = 1, diff --git a/src/types/entities/mfa_token.rs b/src/types/entities/mfa_token.rs new file mode 100644 index 00000000..4aee344c --- /dev/null +++ b/src/types/entities/mfa_token.rs @@ -0,0 +1,61 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use chrono::{DateTime, Utc}; +use reqwest::RequestBuilder; + +use crate::ratelimiter::ChorusRequest; + +#[derive(Debug, Clone)] +/// A Token used to bypass mfa for five minutes. +pub struct MfaToken { + pub token: String, + pub expires_at: DateTime, +} + +impl MfaToken { + /// Add the MFA bypass token to a reqwest request builder. + /// + /// This is used to provide the token in requests that require MFA. + pub fn add_to_request_builder(&self, request: RequestBuilder) -> RequestBuilder { + request.header("X-Discord-MFA-Authorization", &self.token) + } + + /// Add the MFA bypass token to a [ChorusRequest]. + /// + /// This is used to provide the token in requests that require MFA. + pub fn add_to_request(&self, request: ChorusRequest) -> ChorusRequest { + let mut request = request; + + let request_builder = request.request; + + request.request = self.add_to_request_builder(request_builder); + request + } + + /// Returns whether or not the token is still valid + pub fn is_valid(&self) -> bool { + Utc::now() < self.expires_at + } +} + +impl ChorusRequest { + /// Adds an [MfaToken] to the request. + /// + /// Used for requests that need MFA. + pub fn with_mfa(self, token: &MfaToken) -> ChorusRequest { + token.add_to_request(self) + } + + /// Adds an [MfaToken] to the request, if the token is [Some]. + /// + /// Used for requests that need MFA, when we might or might not have a token already + pub fn with_maybe_mfa(self, token: &Option) -> ChorusRequest { + if let Some(mfa_token) = token { + return mfa_token.add_to_request(self); + } + + self + } +} diff --git a/src/types/entities/mod.rs b/src/types/entities/mod.rs index 969d731a..c3fe572f 100644 --- a/src/types/entities/mod.rs +++ b/src/types/entities/mod.rs @@ -29,6 +29,9 @@ pub use user_settings::*; pub use voice_state::*; pub use webhook::*; +#[cfg(feature = "client")] +pub use mfa_token::*; + use crate::types::Shared; #[cfg(feature = "client")] use std::sync::{Arc, RwLock}; @@ -72,6 +75,13 @@ mod user_settings; mod voice_state; mod webhook; +// Note: this is a purely client side version of the mfa token. +// +// For the server, you'd likely only store when it expires somewhere, +// and give the JWT to the client to store +#[cfg(feature = "client")] +mod mfa_token; + #[cfg(feature = "client")] #[async_trait(?Send)] pub trait Composite { diff --git a/src/types/entities/relationship.rs b/src/types/entities/relationship.rs index da2a9dc7..cc143405 100644 --- a/src/types/entities/relationship.rs +++ b/src/types/entities/relationship.rs @@ -9,7 +9,7 @@ use serde_repr::{Deserialize_repr, Serialize_repr}; use crate::errors::ChorusError; use crate::types::{Shared, Snowflake}; -use super::{arc_rwlock_ptr_eq, PublicUser}; +use super::{option_arc_rwlock_ptr_eq, PublicUser}; #[derive(Debug, Deserialize, Serialize, Clone, Default)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] @@ -25,7 +25,11 @@ pub struct Relationship { pub nickname: Option, #[cfg_attr(feature = "sqlx", sqlx(skip))] // Can be derived from the user id /// The target user - pub user: Shared, + /// + /// Note: on Discord.com, this is not sent in select places, such as READY payload. + /// + /// In such a case, you should refer to the id field and seperately fetch the user's data + pub user: Option>, /// When the user requested a relationship pub since: Option>, } @@ -36,7 +40,7 @@ impl PartialEq for Relationship { self.id == other.id && self.relationship_type == other.relationship_type && self.nickname == other.nickname - && arc_rwlock_ptr_eq(&self.user, &other.user) + && option_arc_rwlock_ptr_eq(&self.user, &other.user) && self.since == other.since } } diff --git a/src/types/entities/user.rs b/src/types/entities/user.rs index 5e761307..8491e7c6 100644 --- a/src/types/entities/user.rs +++ b/src/types/entities/user.rs @@ -4,7 +4,7 @@ use crate::errors::ChorusError; use crate::types::utils::Snowflake; -use crate::{UInt32, UInt8}; +use crate::UInt32; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_aux::prelude::{deserialize_default_from_null, deserialize_option_number_from_string}; @@ -826,11 +826,14 @@ pub struct MutualGuild { #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] pub struct UserNote { /// Actual note contents; max 256 characters - pub note: String, + #[serde(rename = "note")] + pub content: String, /// The ID of the user the note is on - pub note_user_id: Snowflake, + #[serde(rename = "note_user_id")] + pub target_id: Snowflake, /// The ID of the user who created the note (always the current user) - pub user_id: Snowflake, + #[serde(rename = "user_id")] + pub author_id: Snowflake, } /// Structure which defines an affinity the local user has with another user. diff --git a/src/types/entities/user_settings.rs b/src/types/entities/user_settings.rs index 1f4c1769..2d3f6baf 100644 --- a/src/types/entities/user_settings.rs +++ b/src/types/entities/user_settings.rs @@ -133,7 +133,9 @@ pub struct CustomStatus { pub text: Option, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord, Hash)] +#[derive( + Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy, PartialOrd, Ord, Hash, +)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow, sqlx::Type))] pub struct FriendSourceFlags { pub all: bool, diff --git a/src/types/events/call.rs b/src/types/events/call.rs index 7efdce10..b3f69a21 100644 --- a/src/types/events/call.rs +++ b/src/types/events/call.rs @@ -11,32 +11,42 @@ use chorus_macros::WebSocketEvent; /// Officially Undocumented; /// Is sent to a client by the server to signify a new call being created; /// -/// Ex: {"t":"CALL_CREATE","s":2,"op":0,"d":{"voice_states":[],"ringing":[],"region":"milan","message_id":"1107187514906775613","embedded_activities":[],"channel_id":"837609115475771392"}} +/// # Reference +/// See pub struct CallCreate { - pub voice_states: Vec, - /// Seems like a vec of channel ids - pub ringing: Vec, - pub region: String, - // milan - pub message_id: Snowflake, - /// What is this? - pub embedded_activities: Vec, + /// Id of the private channel this call is in pub channel_id: Snowflake, + /// Id of the messsage which created the call + pub message_id: Snowflake, + + /// The IDs of users that are being rung to join the call + pub ringing: Vec, + + // milan + pub region: String, + + /// The voice states of the users already in the call + pub voice_states: Vec, + // What is this? + //pub embedded_activities: Vec, } #[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, Eq, WebSocketEvent)] -/// Officially Undocumented; -/// Updates the client on which calls are ringing, along with a specific call?; +/// Updates the client when metadata about a call changes. /// -/// Ex: {"t":"CALL_UPDATE","s":5,"op":0,"d":{"ringing":["837606544539254834"],"region":"milan","message_id":"1107191540234846308","guild_id":null,"channel_id":"837609115475771392"}} +/// # Reference +/// See pub struct CallUpdate { - /// Seems like a vec of channel ids + /// Id of the private channel this call is in + pub channel_id: Snowflake, + /// Id of the messsage which created the call + pub message_id: Snowflake, + + /// The IDs of users that are being rung to join the call pub ringing: Vec, - pub region: String, + // milan - pub message_id: Snowflake, - pub guild_id: Option, - pub channel_id: Snowflake, + pub region: String, } #[derive( @@ -52,11 +62,14 @@ pub struct CallUpdate { PartialOrd, Ord, )] -/// Officially Undocumented; -/// Deletes a ringing call; -/// Ex: {"t":"CALL_DELETE","s":8,"op":0,"d":{"channel_id":"837609115475771392"}} +/// Sent when a call is deleted, or becomes unavailable due to an outage. +/// +/// # Reference +/// See pub struct CallDelete { pub channel_id: Snowflake, + /// Whether the call is unavailable due to an outage + pub unavailable: Option, } #[derive( @@ -72,10 +85,13 @@ pub struct CallDelete { PartialOrd, Ord, )] -/// Officially Undocumented; -/// See ; +/// Used to request a private channel's pre-existing call data, +/// created before the connection was established. +/// +/// Fires a [CallCreate] event if a call is found. /// -/// Ex: {"op":13,"d":{"channel_id":"837609115475771392"}} +/// # Reference +/// See ; pub struct CallSync { pub channel_id: Snowflake, } diff --git a/src/types/events/guild.rs b/src/types/events/guild.rs index f599d1fa..595cdf2a 100644 --- a/src/types/events/guild.rs +++ b/src/types/events/guild.rs @@ -244,7 +244,7 @@ pub struct GuildMembersChunk { pub chunk_index: u16, pub chunk_count: u16, pub not_found: Option>, - pub presences: Option, + pub presences: Option>, pub nonce: Option, } diff --git a/src/types/events/identify.rs b/src/types/events/identify.rs index 28297066..f6d0e071 100644 --- a/src/types/events/identify.rs +++ b/src/types/events/identify.rs @@ -2,10 +2,12 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::types::events::{PresenceUpdate, WebSocketEvent}; +use crate::types::events::WebSocketEvent; use serde::{Deserialize, Serialize}; use serde_with::serde_as; +use super::GatewayIdentifyPresenceUpdate; + #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, WebSocketEvent)] pub struct GatewayIdentifyPayload { pub token: String, @@ -18,7 +20,7 @@ pub struct GatewayIdentifyPayload { #[serde(skip_serializing_if = "Option::is_none")] pub shard: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub presence: Option, + pub presence: Option, // What is the difference between these two? // Intents is documented, capabilities is used in users // I wonder if these are interchangeable... @@ -78,10 +80,12 @@ pub struct GatewayIdentifyConnectionProps { /// ex: "Linux", "Windows", "Mac OS X" /// /// ex (mobile): "Windows Mobile", "iOS", "Android", "BlackBerry" + #[serde(default)] pub os: String, /// Almost always sent /// /// ex: "Firefox", "Chrome", "Opera Mini", "Opera", "Blackberry", "Facebook Mobile", "Chrome iOS", "Mobile Safari", "Safari", "Android Chrome", "Android Mobile", "Edge", "Konqueror", "Internet Explorer", "Mozilla", "Discord Client" + #[serde(default)] pub browser: String, /// Sometimes not sent, acceptable to be "" /// @@ -94,14 +98,17 @@ pub struct GatewayIdentifyConnectionProps { /// Almost always sent, most commonly en-US /// /// ex: "en-US" + #[serde(default)] pub system_locale: String, /// Almost always sent /// /// ex: any user agent, most common is "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36" + #[serde(default)] pub browser_user_agent: String, /// Almost always sent /// /// ex: "113.0.0.0" + #[serde(default)] pub browser_version: String, /// Sometimes not sent, acceptable to be "" /// @@ -118,8 +125,10 @@ pub struct GatewayIdentifyConnectionProps { #[serde_as(as = "NoneAsEmptyString")] pub referrer_current: Option, /// Almost always sent, most commonly "stable" + #[serde(default)] pub release_channel: String, /// Almost always sent, identifiable if default is 0, should be around 199933 + #[serde(default)] pub client_build_number: u64, //pub client_event_source: Option } diff --git a/src/types/events/lazy_request.rs b/src/types/events/lazy_request.rs index 6e17b8ed..ccf2d191 100644 --- a/src/types/events/lazy_request.rs +++ b/src/types/events/lazy_request.rs @@ -20,6 +20,7 @@ use super::WebSocketEvent; /// /// {"op":14,"d":{"guild_id":"848582562217590824","typing":true,"activities":true,"threads":true}} pub struct LazyRequest { + /// The guild id to request pub guild_id: Snowflake, pub typing: bool, pub activities: bool, @@ -27,6 +28,6 @@ pub struct LazyRequest { #[serde(skip_serializing_if = "Option::is_none")] pub members: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub channels: Option>>>, + pub channels: Option>>>, } diff --git a/src/types/events/message.rs b/src/types/events/message.rs index d2ca3082..cffda72a 100644 --- a/src/types/events/message.rs +++ b/src/types/events/message.rs @@ -178,3 +178,28 @@ pub struct MessageACK { pub flags: Option, pub channel_id: Snowflake, } + +#[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] +/// Used to request the last messages from channels. +/// +/// Fires a [LastMessages] events with up to 100 messages that match the request. +/// +/// # Reference +/// See +pub struct RequestLastMessages { + /// The ID of the guild the channels are in + pub guild_id: Snowflake, + /// The IDs of the channels to request last messages for (max 100) + pub channel_ids: Vec +} + +#[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] +/// Sent as a response to [RequestLastMessages]. +/// +/// # Reference +/// See +pub struct LastMessages { + /// The ID of the guild the channels are in + pub guild_id: Snowflake, + pub messages: Vec +} diff --git a/src/types/events/mfa.rs b/src/types/events/mfa.rs new file mode 100644 index 00000000..00dee345 --- /dev/null +++ b/src/types/events/mfa.rs @@ -0,0 +1,31 @@ +use crate::types::{MfaAuthenticator, MfaAuthenticatorType, Snowflake, WebSocketEvent}; +use chorus_macros::WebSocketEvent; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq, WebSocketEvent)] +/// See ; +/// +/// Sent when an [MfaAuthenticator] is created. +pub struct AuthenticatorCreate { + #[serde(flatten)] + pub authenticator: MfaAuthenticator, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq, WebSocketEvent)] +/// See ; +/// +/// Sent when an [MfaAuthenticator] is modified. +pub struct AuthenticatorUpdate { + #[serde(flatten)] + pub authenticator: MfaAuthenticator, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, WebSocketEvent)] +/// See ; +/// +/// Sent when an [MfaAuthenticator] is deleted. +pub struct AuthenticatorDelete { + pub id: Snowflake, + #[serde(rename = "type")] + pub authenticator_type: MfaAuthenticatorType, +} diff --git a/src/types/events/mod.rs b/src/types/events/mod.rs index 280ebb80..68e4aa1b 100644 --- a/src/types/events/mod.rs +++ b/src/types/events/mod.rs @@ -18,6 +18,7 @@ pub use invalid_session::*; pub use invite::*; pub use lazy_request::*; pub use message::*; +pub use mfa::*; pub use passive_update::*; pub use presence::*; pub use ready::*; @@ -68,6 +69,7 @@ mod invalid_session; mod invite; mod lazy_request; mod message; +mod mfa; mod passive_update; mod presence; mod ready; diff --git a/src/types/events/presence.rs b/src/types/events/presence.rs index d96d984f..5783cdb4 100644 --- a/src/types/events/presence.rs +++ b/src/types/events/presence.rs @@ -5,12 +5,14 @@ use crate::types::{events::WebSocketEvent, UserStatus}; use crate::types::{Activity, ClientStatusObject, PublicUser, Snowflake}; use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DefaultOnNull}; #[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] /// Sent by the client to update its status and presence; /// See pub struct UpdatePresence { - /// Unix time of when the client went idle, or none if client is not idle. + /// Unix time of when the client went idle, or none + /// if client is not idle. pub since: Option, /// the client's status (online, invisible, offline, dnd, idle..) pub status: UserStatus, @@ -18,8 +20,10 @@ pub struct UpdatePresence { pub afk: bool, } +#[serde_as] #[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, WebSocketEvent)] -/// Received to tell the client that a user updated their presence / status +/// Received to tell the client that a user updated their presence / status. If you are looking for +/// the PresenceUpdate used in the IDENTIFY gateway event, see /// /// See /// (Same structure as ) @@ -28,8 +32,28 @@ pub struct PresenceUpdate { #[serde(default)] pub guild_id: Option, pub status: UserStatus, - #[serde(default)] + // This will just result in an empty array, I guess we could also use option + #[serde_as(deserialize_as = "DefaultOnNull")] pub activities: Vec, pub client_status: ClientStatusObject, } +#[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, WebSocketEvent)] +/// Sent to the gateway as part of [GatewayIdentifyPayload](crate::types::GatewayIdentifyPayload) +pub struct GatewayIdentifyPresenceUpdate { + #[serde(default)] + pub guild_id: Option, + pub status: UserStatus, + #[serde(default)] + pub activities: Vec, +} + +impl From for GatewayIdentifyPresenceUpdate { + fn from(value: PresenceUpdate) -> Self { + Self { + guild_id: value.guild_id, + status: value.status, + activities: value.activities, + } + } +} diff --git a/src/types/events/ready.rs b/src/types/events/ready.rs index d8c11de1..db28725b 100644 --- a/src/types/events/ready.rs +++ b/src/types/events/ready.rs @@ -4,13 +4,14 @@ use std::collections::HashMap; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::types::entities::{Guild, User}; use crate::types::events::{Session, WebSocketEvent}; use crate::types::{ - Activity, Channel, ClientStatusObject, GuildMember, PresenceUpdate, Relationship, Snowflake, - UserSettings, VoiceState, + Activity, Channel, ClientStatusObject, GuildMember, MfaAuthenticatorType, PresenceUpdate, + Relationship, Snowflake, UserSettings, VoiceState, }; use crate::{UInt32, UInt64, UInt8}; @@ -38,6 +39,8 @@ pub struct GatewayReady { pub guilds: Vec, /// The presences of the user's non-offline friends and implicit relationships (depending on the `NO_AFFINE_USER_IDS` Gateway capability). pub presences: Option>, + /// Undocumented. Seems to be a list of sessions the user is currently connected with. + /// On Discord.com, this includes the current session. pub sessions: Option>, /// Unique session ID, used for resuming connections pub session_id: String, @@ -64,6 +67,9 @@ pub struct GatewayReady { pub notes: HashMap, /// The presences of the user's non-offline friends and implicit relationships (depending on the `NO_AFFINE_USER_IDS` Gateway capability), and any guild presences sent at startup pub merged_presences: Option, + /// The members of the user's guilds, in the same order as the `guilds` array + #[serde(default)] + pub merged_members: Option>>, #[serde(default)] /// The deduped users across all objects in the event pub users: Vec, @@ -71,7 +77,7 @@ pub struct GatewayReady { pub auth_token: Option, #[serde(default)] /// The types of multi-factor authenticators the user has enabled - pub authenticator_types: Vec, + pub authenticator_types: Vec, /// The action a user is required to take before continuing to use Discord pub required_action: Option, #[serde(default)] @@ -84,12 +90,22 @@ pub struct GatewayReady { pub api_code_version: UInt8, #[serde(default)] /// User experiment rollouts for the user + /// /// TODO: Make User Experiments into own struct - pub experiments: Vec, + // Note: this is a pain to parse! We need a way to parse arrays into structs via the index of + // their feilds + // + // ex: [4130837190, 0, 10, -1, 0, 1932, 0, 0] + // needs to be parsed into a struct with fields corresponding to the first, second.. value in + // the array + pub experiments: Vec, #[serde(default)] /// Guild experiment rollouts for the user + /// /// TODO: Make Guild Experiments into own struct - pub guild_experiments: Vec, + // Note: this is a pain to parse! See the above TODO + pub guild_experiments: Vec, + pub read_state: ReadState, } #[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] @@ -125,7 +141,7 @@ pub struct GatewayReadyBot { pub users: Vec, #[serde(default)] /// The types of multi-factor authenticators the user has enabled - pub authenticator_types: Vec, + pub authenticator_types: Vec, #[serde(default)] /// A geo-ordered list of RTC regions that can be used when when setting a voice channel's `rtc_region` or updating the client's voice state pub geo_ordered_rtc_regions: Vec, @@ -160,14 +176,6 @@ impl GatewayReady { self.into() } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)] -#[cfg_attr(not(feature = "sqlx"), repr(u8))] -#[cfg_attr(feature = "sqlx", repr(i16))] -pub enum AuthenticatorType { - WebAuthn = 1, - Totp = 2, - Sms = 3, -} #[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] /// Officially Undocumented; @@ -227,3 +235,27 @@ pub struct SupplementalGuild { /// Field not documented even unofficially pub embedded_activities: Vec, } + +#[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// Not documented even unofficially. Information about this type is likely to be partially incorrect. +pub struct ReadState { + pub entries: Vec, + pub partial: bool, + pub version: u32, +} + +#[derive( + Debug, Deserialize, Serialize, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, +)] +/// Not documented even unofficially. Information about this type is likely to be partially incorrect. +pub struct ReadStateEntry { + /// Spacebar servers do not have flags in this entity at all (??) + pub flags: Option, + pub id: Snowflake, + pub last_message_id: Option, + pub last_pin_timestamp: Option>, + /// A value that is incremented each time the read state is read + pub last_viewed: Option, + // Temporary adding Option to fix Spacebar servers, they have mention count as a nullable + pub mention_count: Option, +} diff --git a/src/types/events/request_members.rs b/src/types/events/request_members.rs index 5228170d..1c0beb02 100644 --- a/src/types/events/request_members.rs +++ b/src/types/events/request_members.rs @@ -2,17 +2,43 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::types::{events::WebSocketEvent, Snowflake}; +use crate::types::{events::WebSocketEvent, OneOrMoreSnowflakes}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize, Serialize, Default, WebSocketEvent, Clone)] -/// See +#[derive(Debug, Deserialize, Default, Serialize, WebSocketEvent, Clone)] +/// Used to request members for a guild or a list of guilds. +/// +/// Fires multiple [crate::types::events::GuildMembersChunk] events (each with up to 1000 members) +/// until all members that match the request have been sent. +/// +/// # Notes +/// One of `query` or `user_ids` is required. +/// +/// If `query` is set, `limit` is required (if requesting all members, set `limit` to 0) +/// +/// # Reference +/// See pub struct GatewayRequestGuildMembers { - pub guild_id: Snowflake, + /// Id(s) of the guild(s) to get members for + pub guild_id: OneOrMoreSnowflakes, + + /// The user id(s) to request (0 - 100) + pub user_ids: Option, + + /// String that the username / nickname starts with, or an empty string for all members pub query: Option, - pub limit: u64, + + /// Maximum number of members to send matching the query (0 - 100) + /// + /// Must be 0 with an empty query + pub limit: u8, + + /// Whether to return the [Presence](crate::types::events::PresenceUpdate) of the matched + /// members pub presences: Option, - // TODO: allow array - pub user_ids: Option, + + /// Unique string to identify the received event for this specific request. + /// + /// Up to 32 bytes. If you send a longer nonce, it will be ignored pub nonce: Option, } diff --git a/src/types/schema/auth.rs b/src/types/schema/auth.rs index 83c88dc0..808ff7c7 100644 --- a/src/types/schema/auth.rs +++ b/src/types/schema/auth.rs @@ -5,6 +5,8 @@ use chrono::NaiveDate; use serde::{Deserialize, Serialize}; +use crate::types::{Shared, UserSettings}; + #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub struct RegisterSchema { @@ -36,11 +38,19 @@ pub struct LoginSchema { pub gift_code_sku_id: Option, } -#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -pub struct TotpSchema { - code: String, - ticket: String, - gift_code_sku_id: Option, - login_source: Option, +pub struct VerifyMFALoginSchema { + pub ticket: String, + pub code: String, + pub login_source: Option, + pub gift_code_sku_id: Option, } + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum VerifyMFALoginResponse { + Success { token: String, user_settings: Shared }, + UserSuspended { suspended_user_token: String } +} + diff --git a/src/types/schema/channel.rs b/src/types/schema/channel.rs index 62b3b53a..abc57ff3 100644 --- a/src/types/schema/channel.rs +++ b/src/types/schema/channel.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::types::{entities::PermissionOverwrite, ChannelType, DefaultReaction, Snowflake}; +// TODO: Needs updating #[derive(Debug, Deserialize, Serialize, Default, PartialEq, PartialOrd)] #[serde(rename_all = "snake_case")] pub struct ChannelCreateSchema { diff --git a/src/types/schema/mfa.rs b/src/types/schema/mfa.rs new file mode 100644 index 00000000..460bec3c --- /dev/null +++ b/src/types/schema/mfa.rs @@ -0,0 +1,426 @@ +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; + +use crate::{types::Snowflake, errors::ChorusError}; + +#[cfg(feature = "client")] +use crate::{instance::ChorusUser, ChorusResult}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +/// Error received when mfa is required +pub struct MfaRequiredSchema { + pub message: String, + pub code: i32, + #[serde(rename = "mfa")] + pub mfa_challenge: MfaChallenge, +} + +impl Display for MfaRequiredSchema { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MfaRequired") + .field("message", &self.message) + .field("code", &self.code) + .field("mfa", &self.mfa_challenge) + .finish() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +/// A challenge to verify the local user's identity with mfa. +/// +/// (Normally returned in [MfaRequiredSchema] as [ChorusError::MfaRequired]) +/// +/// To complete the challenge, see [ChorusUser::complete_mfa_challenge]. +pub struct MfaChallenge { + /// A unique ticket which identifies this challenge + pub ticket: String, + /// The ways we can verify the user's identity + pub methods: Vec, +} + +#[cfg(feature = "client")] +impl MfaChallenge { + /// Attempts to complete the [MfaChallenge] with authentication data from the user. + /// + /// If successful, the MFA verification JWT returned is set on the provided [ChorusUser]. + /// + /// The JWT token expires after 5 minutes. + /// + /// # Arguments + /// `authentication_type` is the way the user has chosen to authenticate. + /// + /// It must be the type of one of the provided `methods` in the challenge. + /// + /// `data` is specific to the `authentication_type`. + /// + /// For example, a totp authenticator uses a 6 digit code as the `data`. + /// + /// # Notes + /// Alias of [ChorusUser::complete_mfa_challenge] + pub async fn complete( + self, + user: &mut ChorusUser, + authentication_type: MfaAuthenticationType, + data: String, + ) -> ChorusResult<()> { + let schema = + MfaVerifySchema::from_challenge_and_verification_data(self, authentication_type, data); + + user.complete_mfa_challenge(schema).await + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +/// A way we can verify the user's identity, found in [MfaChallenge] +pub struct MfaMethod { + /// The type of authentication we can perform + #[serde(rename = "type")] + pub kind: MfaAuthenticationType, + + /// A challenge string unique to the authentication type, [None] if the type does not need a challenge string + #[serde(skip_serializing_if = "Option::is_none")] + pub challenge: Option, + + /// Whether or not we can use a backup code for this authentication type + #[serde(skip_serializing_if = "Option::is_none")] + pub backup_codes_allowed: Option, +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +/// A multi-factor authentication authenticator. +/// +/// # Reference +/// See +pub struct MfaAuthenticator { + pub id: Snowflake, + #[serde(rename = "type")] + pub authenticator_type: MfaAuthenticatorType, + pub name: String, +} + +#[derive(Serialize_repr, Deserialize_repr, Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u8)] +#[serde(rename_all = "lowercase")] +/// Types of [MfaAuthenticator]s. +/// +/// Not to be confused with [MfaAuthenticationType], which covers other cases of authentication as well. (Such as backup codes or a password) +pub enum MfaAuthenticatorType { + #[default] + WebAuthn = 1, + TOTP = 2, + SMS = 3, +} + +impl TryFrom for MfaAuthenticatorType { + type Error = ChorusError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::WebAuthn), + 2 => Ok(Self::TOTP), + 3 => Ok(Self::SMS), + _ => Err(ChorusError::InvalidArguments { + error: "Value is not a valid MfaAuthenticatorType".to_string(), + }), + } + } +} + +#[cfg(feature = "sqlx")] +impl sqlx::Type for MfaAuthenticatorType { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +#[cfg(feature = "sqlx")] +impl<'q> sqlx::Encode<'q, sqlx::Postgres> for MfaAuthenticatorType { + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::from(*self as u8); + sqlx_pg_uint.encode_by_ref(buf) + } +} + +#[cfg(feature = "sqlx")] +impl<'r> sqlx::Decode<'r, sqlx::Postgres> for MfaAuthenticatorType { + fn decode( + value: ::ValueRef<'r>, + ) -> Result { + let sqlx_pg_uint = sqlx_pg_uint::PgU8::decode(value)?; + MfaAuthenticatorType::try_from(sqlx_pg_uint.to_uint()).map_err(|e| e.into()) + } +} + +impl MfaAuthenticatorType { + /// Converts self into [MfaAuthenticationType] + pub fn into_authentication_type(self) -> MfaAuthenticationType { + match self { + Self::WebAuthn => MfaAuthenticationType::WebAuthn, + Self::TOTP => MfaAuthenticationType::TOTP, + Self::SMS => MfaAuthenticationType::SMS, + } + } +} + +impl From for MfaAuthenticationType { + fn from(value: MfaAuthenticatorType) -> Self { + value.into_authentication_type() + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "lowercase")] +/// Ways to perform multi factor authentication. +pub enum MfaAuthenticationType { + WebAuthn, + TOTP, + SMS, + Backup, + Password, +} + +impl Display for MfaAuthenticationType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + MfaAuthenticationType::TOTP => "totp", + MfaAuthenticationType::SMS => "sms", + MfaAuthenticationType::Backup => "backup", + MfaAuthenticationType::WebAuthn => "webauthn", + MfaAuthenticationType::Password => "password", + } + ) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +/// An mfa backup code. +/// +/// # Reference +/// See +pub struct MfaBackupCode { + pub user_id: Snowflake, + pub code: String, + /// Whether or not the backup code has been used + pub consumed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +/// A schema used for the [ChorusUser::complete_mfa_challenge] route. +pub struct MfaVerifySchema { + /// Usually obtained from [MfaChallenge] + pub ticket: String, + /// The way the user has chosen to authenticate + /// + /// Must be one of the available methods in from the [MfaChallenge] + pub mfa_type: MfaAuthenticationType, + /// Data unique to the authentication type (ex. a 6 digit totp code for totp, a password) + pub data: String, +} + +#[cfg(feature = "client")] +impl MfaVerifySchema { + /// Creates the verify schema from an [MfaChallenge] and data needed to complete it. + /// + /// Shorthand for initializing [Self] with mfa_type, data and ticket = challenge.ticket + pub fn from_challenge_and_verification_data( + challenge: MfaChallenge, + mfa_type: MfaAuthenticationType, + data: String, + ) -> Self { + Self { + ticket: challenge.ticket, + mfa_type, + data, + } + } + + /// Uses the verification schema to attempt to complete an [MfaChallenge]. + /// + /// If successful, the MFA verification JWT returned is set on the provided [ChorusUser]. + /// + /// The JWT token expires after 5 minutes. + /// + /// # Notes + /// Alias of [ChorusUser::complete_mfa_challenge] + pub async fn verify_mfa(self, user: &mut ChorusUser) -> ChorusResult<()> { + user.complete_mfa_challenge(self).await + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// An MFA token generated by the server after completing and mfa challenge ([crate::instance::ChorusUser::complete_mfa_challenge]) +pub struct MfaTokenSchema { + pub token: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Schema for the Send Mfa SMS route ([crate::instance::Instance::send_mfa_sms]) +pub struct SendMfaSmsSchema { + pub ticket: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Return type for the Send Mfa SMS route ([crate::instance::Instance::send_mfa_sms]) +pub struct SendMfaSmsResponse { + pub phone: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Json schema for the Enable TOTP MFA route +/// +/// # Notes +/// Secret and code are optional so that clients +/// may first verify the password is correct before +/// letting the user save the secrets. +/// +/// If the password is valid, the request will fail with a 60005 +/// json error code. However note that JSON error codes are not yet +/// implemented in chorus. () +/// To implement this kind of check, you would need to manually deserialize into +/// the json error code object. +// TODO: Json error codes +/// +/// # Reference +/// See +pub struct EnableTotpMfaSchema { + pub password: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub secret: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// Response type for the Enable TOTP MFA route +/// +/// # Reference +/// See +pub struct EnableTotpMfaResponse { + pub token: String, + pub backup_codes: Vec, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A schema for SMS MFA Enable and Disable routes. +/// +/// # Reference +/// See and +/// +pub struct SmsMfaRouteSchema { + /// The user's current password + pub password: String, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A return type for the [ChorusUser::begin_webauthn_authenticator_creation] route (Create WebAuthn Authenticator with no arguments). +/// +/// Includes the MFA ticket and a stringified JSON object of the public key credential challenge. +/// +/// # Reference +/// See +pub struct BeginWebAuthnAuthenticatorCreationReturn { + pub ticket: String, + /// Stringified JSON public key credential request options challenge + pub challenge: String, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A schema for the [ChorusUser::finish_webauthn_authenticator_creation] route (Create WebAuthn Authenticator). +/// +/// # Reference +/// See +pub struct FinishWebAuthnAuthenticatorCreationSchema { + /// Name of the authenticator to create (1 - 32 characters) + pub name: String, + /// The MFA ticket returned by the (begin creation)[ChorusUser::begin_webauthn_authenticator_creation] endpoint + pub ticket: String, + /// A stringified JSON object of the public key credential response. + pub credential: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A return type for the [ChorusUser::finish_webauthn_authenticator_creation] route (Create WebAuthn Authenticator). +/// +/// Includes the MFA ticket and a stringified JSON object of the public key credential challenge. +/// +/// # Reference +/// See +pub struct FinishWebAuthnAuthenticatorCreationReturn { + #[serde(flatten)] + /// The created authenticator object + pub authenticator: MfaAuthenticator, + /// A list of MFA backup codes + pub backup_codes: Vec, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A schema for the Modify WebAuthn Authenticator route. +/// +/// # Reference +/// See +pub struct ModifyWebAuthnAuthenticatorSchema { + #[serde(skip_serializing_if = "Option::is_none")] + /// New name of the authenticator (1 - 32 characters) + pub name: Option, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A schema for the Send Backup Codes Challenge route. +/// +/// # Reference +/// See +pub struct SendBackupCodesChallengeSchema { + /// The user's current password + pub password: String, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A return type for the Send Backup Codes Challenge route. +/// +/// # Reference +/// See +pub struct SendBackupCodesChallengeReturn { + /// A one-time verification nonce used to view the backup codes + /// + /// Send this in the [ChorusUser::get_backup_codes] endpoint as the nonce if you want to view + /// the existing codes + #[serde(rename = "nonce")] + pub view_nonce: String, + /// A one-time verification nonce used to regenerate the backup codes + /// + /// Send this in the [ChorusUser::get_backup_codes] endpoint as the nonce if you want to + /// regenerate the backup codes + pub regenerate_nonce: String, +} + +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)] +/// A schema for the Get Backup Codes route. +/// +/// # Reference +/// See +pub struct GetBackupCodesSchema { + /// The one-time verification nonce used to view or regenerate the backup codes. + /// + /// Obtained from the [ChorusUser::send_backup_codes_challenge] route. + pub nonce: String, + /// The backup verification key received in the email + pub key: String, + /// Whether or not to regenerate the backup codes + /// + /// If set to true, nonce should be the regenerate_nonce + /// otherwise it should be the view_nonce + pub regenerate: bool, +} diff --git a/src/types/schema/mod.rs b/src/types/schema/mod.rs index 2888046e..fdebb582 100644 --- a/src/types/schema/mod.rs +++ b/src/types/schema/mod.rs @@ -5,6 +5,7 @@ pub use apierror::*; pub use audit_log::*; pub use auth::*; +pub use mfa::*; pub use channel::*; pub use guild::*; pub use message::*; @@ -18,6 +19,7 @@ pub use instance::*; mod apierror; mod audit_log; mod auth; +mod mfa; mod channel; mod guild; mod message; diff --git a/src/types/utils/mod.rs b/src/types/utils/mod.rs index 5608fe74..ff6b20ec 100644 --- a/src/types/utils/mod.rs +++ b/src/types/utils/mod.rs @@ -3,13 +3,14 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. #![allow(unused_imports)] +pub use opcode::*; pub use regexes::*; pub use rights::Rights; -pub use snowflake::Snowflake; +pub use snowflake::{Snowflake, OneOrMoreSnowflakes}; pub mod jwt; +pub mod opcode; mod regexes; mod rights; -mod snowflake; pub mod serde; - +mod snowflake; diff --git a/src/types/utils/opcode.rs b/src/types/utils/opcode.rs new file mode 100644 index 00000000..3c533c3e --- /dev/null +++ b/src/types/utils/opcode.rs @@ -0,0 +1,308 @@ +#![allow(deprecated)] // Required to suppress warnings about deprecated opcodes + +use serde::{Deserialize, Serialize}; +#[cfg(not(target_arch = "wasm32"))] +use tokio_tungstenite::tungstenite::protocol::CloseFrame; + +use crate::errors::ChorusError; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)] +#[non_exhaustive] +#[repr(u8)] +/// Gateway opcodes used in the Spacebar Gateway Protocol. +pub enum Opcode { + /// An event was dispatched. + Dispatch = 0, + /// Keep the WebSocket connection alive. + Heartbeat = 1, + /// Start a new session during the initial handshake. + Identify = 2, + /// Update the client's presence. + PresenceUpdate = 3, + /// Join/leave or move between voice channels and calls. + VoiceStateUpdate = 4, + /// Ping the Discord voice servers. + VoiceServerPing = 5, + /// Resume a previous session that was disconnected. + Resume = 6, + /// You should attempt to reconnect and resume immediately. + Reconnect = 7, + /// Request information about guild members. + RequestGuildMembers = 8, + /// The session has been invalidated. You should reconnect and identify/resume accordingly. + InvalidSession = 9, + /// Sent immediately after connecting, contains the heartbeat_interval to use. + Hello = 10, + /// Acknowledge a received heartbeat. + HeartbeatAck = 11, + /// Request all members and presences for guilds. + #[deprecated] + GuildSync = 12, + /// Request a private channel's pre-existing call data. + CallConnect = 13, + /// Update subscriptions for a guild. + GuildSubscriptions = 14, + /// Join a lobby. + LobbyConnect = 15, + /// Leave a lobby. + LobbyDisconnect = 16, + /// Update the client's voice state in a lobby. + LobbyVoiceStates = 17, + /// Create a stream for the client. + StreamCreate = 18, + /// End a client stream. + StreamDelete = 19, + /// Watch a user's stream. + StreamWatch = 20, + /// Ping a user stream's voice server. + StreamPing = 21, + /// Pause/resume a client stream. + StreamSetPaused = 22, + /// Update subscriptions for an LFG lobby. + #[deprecated] + LfgSubscriptions = 23, + /// Request guild application commands. + #[deprecated] + RequestGuildApplicationCommands = 24, + /// Launch an embedded activity in a voice channel or call. + EmbeddedActivityCreate = 25, + /// Stop an embedded activity. + EmbeddedActivityDelete = 26, + /// Update an embedded activity. + EmbeddedActivityUpdate = 27, + /// Request forum channel unread counts. + RequestForumUnreads = 28, + /// Send a remote command to an embedded (Xbox, PlayStation) voice session. + RemoteCommand = 29, + /// Request deleted entity IDs not matching a given hash for a guild. + RequestDeletedEntityIDs = 30, + /// Request soundboard sounds for guilds. + RequestSoundboardSounds = 31, + /// Create a voice speed test. + SpeedTestCreate = 32, + /// Delete a voice speed test. + SpeedTestDelete = 33, + /// Request last messages for a guild's channels. + RequestLastMessages = 34, + /// Request information about recently-joined guild members. + SearchRecentMembers = 35, + /// Request voice channel statuses for a guild. + RequestChannelStatuses = 36, +} + +impl TryFrom for Opcode { + type Error = ChorusError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Dispatch), + 1 => Ok(Self::Heartbeat), + 2 => Ok(Self::Identify), + 3 => Ok(Self::PresenceUpdate), + 4 => Ok(Self::VoiceStateUpdate), + 5 => Ok(Self::VoiceServerPing), + 6 => Ok(Self::Resume), + 7 => Ok(Self::Reconnect), + 8 => Ok(Self::RequestGuildMembers), + 9 => Ok(Self::InvalidSession), + 10 => Ok(Self::Hello), + 11 => Ok(Self::HeartbeatAck), + 12 => Ok(Self::GuildSync), + 13 => Ok(Self::CallConnect), + 14 => Ok(Self::GuildSubscriptions), + 15 => Ok(Self::LobbyConnect), + 16 => Ok(Self::LobbyDisconnect), + 17 => Ok(Self::LobbyVoiceStates), + 18 => Ok(Self::StreamCreate), + 19 => Ok(Self::StreamDelete), + 20 => Ok(Self::StreamWatch), + 21 => Ok(Self::StreamPing), + 22 => Ok(Self::StreamSetPaused), + 23 => Ok(Self::LfgSubscriptions), + 24 => Ok(Self::RequestGuildApplicationCommands), + 25 => Ok(Self::EmbeddedActivityCreate), + 26 => Ok(Self::EmbeddedActivityDelete), + 27 => Ok(Self::EmbeddedActivityUpdate), + 28 => Ok(Self::RequestForumUnreads), + 29 => Ok(Self::RemoteCommand), + 30 => Ok(Self::RequestDeletedEntityIDs), + 31 => Ok(Self::RequestSoundboardSounds), + 32 => Ok(Self::SpeedTestCreate), + 33 => Ok(Self::SpeedTestDelete), + 34 => Ok(Self::RequestLastMessages), + 35 => Ok(Self::SearchRecentMembers), + 36 => Ok(Self::RequestChannelStatuses), + e => Err(ChorusError::InvalidArguments { + error: format!("Provided value {e} is not a valid opcode"), + }), + } + } +} + +#[repr(u16)] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +/// When the gateway server closes your connection, it tells you what happened throught a close code. +/// +/// # Reference +/// See +pub enum CloseCode { + UnknownError = 4000, + UnknownOpcode = 4001, + DecodeError = 4002, + NotAuthenticated = 4003, + AuthenticationFailed = 4004, + AlreadyAuthenticated = 4005, + SessionNoLongerValid = 4006, + InvalidSeq = 4007, + RateLimited = 4008, + SessionTimeout = 4009, + InvalidShard = 4010, + ShardingRequired = 4011, + InvalidApiVersion = 4012, + InvalidIntents = 4013, + DisallowedIntents = 4014, +} + +#[cfg(not(target_arch = "wasm32"))] +impl CloseCode { + /// Convert `&self` to a `tokio_tungstenite` [CloseFrame]. + pub fn as_tungstenite_close_frame<'a>(&'a self, reason: &'a str) -> CloseFrame { + CloseFrame { + code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library( + *self as u16, + ), + reason: reason.into(), + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl TryFrom for CloseCode { + type Error = ChorusError; + + fn try_from(value: tokio_tungstenite::tungstenite::Message) -> Result { + match value { + tokio_tungstenite::tungstenite::Message::Close(close_frame) => { + if close_frame.is_none() { + return Err(ChorusError::InvalidArguments { + error: "No close_frame provided".to_string(), + }); + } + let close_frame = close_frame.unwrap(); + let close_code = u16::from(close_frame.code); + CloseCode::try_from(close_code) + } + _ => Err(ChorusError::InvalidArguments { + error: "value is not a valid CloseCode".to_string(), + }), + } + } +} + +impl TryFrom for CloseCode { + type Error = ChorusError; + + fn try_from(value: u16) -> Result { + match value { + 4000 => Ok(CloseCode::UnknownError), + 4001 => Ok(CloseCode::UnknownOpcode), + 4002 => Ok(CloseCode::DecodeError), + 4003 => Ok(CloseCode::NotAuthenticated), + 4004 => Ok(CloseCode::AuthenticationFailed), + 4005 => Ok(CloseCode::AlreadyAuthenticated), + 4006 => Ok(CloseCode::SessionNoLongerValid), + 4007 => Ok(CloseCode::InvalidSeq), + 4008 => Ok(CloseCode::RateLimited), + 4009 => Ok(CloseCode::SessionTimeout), + 4010 => Ok(CloseCode::InvalidShard), + 4011 => Ok(CloseCode::ShardingRequired), + 4012 => Ok(CloseCode::InvalidApiVersion), + 4013 => Ok(CloseCode::InvalidIntents), + 4014 => Ok(CloseCode::DisallowedIntents), + e => Err(ChorusError::InvalidArguments { + error: format!("{e} is not a valid CloseCode"), + }), + } + } +} + +#[repr(u16)] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +/// When the voice gateway server closes your connection, it tells you what happened throught a close code. +/// +/// # Reference +/// See +pub enum VoiceCloseCode { + UnknownOpcode = 4001, + FailedToDecodePayload = 4002, + NotAuthenticated = 4003, + AuthenticationFailed = 4004, + AlreadyAuthenticated = 4005, + SessionNoLongerValid = 4006, + SessionTimeout = 4009, + ServerNotFound = 4011, + UnknownProtocol = 4012, + DisconnectedChannelDeletedOrKicked = 4014, + VoiceServerCrashed = 4015, + UnknownEncryptionMode = 4016, +} + +#[cfg(not(target_arch = "wasm32"))] +impl VoiceCloseCode { + /// Convert `&self` to a `tokio_tungstenite` [CloseFrame]. + pub fn as_tungstenite_close_frame<'a>(&'a self, reason: &'a str) -> CloseFrame { + CloseFrame { + code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library( + *self as u16, + ), + reason: reason.into(), + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl TryFrom for VoiceCloseCode { + type Error = ChorusError; + + fn try_from(value: tokio_tungstenite::tungstenite::Message) -> Result { + match value { + tokio_tungstenite::tungstenite::Message::Close(close_frame) => { + if close_frame.is_none() { + return Err(ChorusError::InvalidArguments { + error: "No close_frame provided".to_string(), + }); + } + let close_frame = close_frame.unwrap(); + let close_code = u16::from(close_frame.code); + VoiceCloseCode::try_from(close_code) + } + _ => Err(ChorusError::InvalidArguments { + error: "value is not a valid VoiceCloseCode".to_string(), + }), + } + } +} + +impl TryFrom for VoiceCloseCode { + type Error = ChorusError; + + fn try_from(value: u16) -> Result { + match value { + 4001 => Ok(VoiceCloseCode::UnknownOpcode), + 4002 => Ok(VoiceCloseCode::FailedToDecodePayload), + 4003 => Ok(VoiceCloseCode::NotAuthenticated), + 4004 => Ok(VoiceCloseCode::AuthenticationFailed), + 4005 => Ok(VoiceCloseCode::AlreadyAuthenticated), + 4006 => Ok(VoiceCloseCode::SessionNoLongerValid), + 4009 => Ok(VoiceCloseCode::SessionTimeout), + 4011 => Ok(VoiceCloseCode::ServerNotFound), + 4012 => Ok(VoiceCloseCode::UnknownProtocol), + 4014 => Ok(VoiceCloseCode::DisconnectedChannelDeletedOrKicked), + 4015 => Ok(VoiceCloseCode::VoiceServerCrashed), + 4016 => Ok(VoiceCloseCode::UnknownEncryptionMode), + e => Err(ChorusError::InvalidArguments { + error: format!("{e} is not a valid VoiceCloseCode"), + }), + } + } +} diff --git a/src/types/utils/snowflake.rs b/src/types/utils/snowflake.rs index e19f7667..ab6f338c 100644 --- a/src/types/utils/snowflake.rs +++ b/src/types/utils/snowflake.rs @@ -7,6 +7,8 @@ use std::{ sync::atomic::{AtomicUsize, Ordering}, }; +use serde::{Serialize, Deserialize}; + use chrono::{DateTime, TimeZone, Utc}; /// 2015-01-01 @@ -138,10 +140,75 @@ impl<'d> sqlx::Decode<'d, sqlx::Postgres> for Snowflake { } } +/// A type representing either a single [Snowflake] or a [Vec] of [Snowflake]s. +/// +/// Useful for e.g. [RequestGuildMembers](crate::types::events::GatewayRequestGuildMembers), to +/// select either one specific user or multiple users. +/// +/// Should (de)serialize either as a single [Snowflake] or as an array. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[serde(untagged)] +pub enum OneOrMoreSnowflakes { + One(Snowflake), + More(Vec) +} + +// Note: allows us to have Default on the events +// that use this type +impl Default for OneOrMoreSnowflakes { + fn default() -> Self { + Snowflake::default().into() + } +} + +impl Display for OneOrMoreSnowflakes { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OneOrMoreSnowflakes::One(snowflake) => write!(f, "{}", snowflake.0), + // Display as you would debug a vec of u64s + OneOrMoreSnowflakes::More(snowflake_vec) => write!(f, "{:?}", snowflake_vec.iter().map(|x| x.0)), + } + } +} + +impl From for OneOrMoreSnowflakes { + fn from(item: Snowflake) -> Self { + Self::One(item) + } +} + +impl From> for OneOrMoreSnowflakes { + fn from(item: Vec) -> Self { + if item.len() == 1 { + return Self::One(item[0]); + } + + Self::More(item) + } +} + +impl From for OneOrMoreSnowflakes { + fn from(item: u64) -> Self { + Self::One(item.into()) + } +} + +impl From> for OneOrMoreSnowflakes { + fn from(item: Vec) -> Self { + if item.len() == 1 { + return Self::One(item[0].into()); + } + + Self::More(item.into_iter().map(|x| x.into()).collect()) + } +} + #[cfg(test)] mod test { use chrono::{DateTime, Utc}; + use crate::types::utils::snowflake::OneOrMoreSnowflakes; + use super::Snowflake; #[test] @@ -157,4 +224,29 @@ mod test { let timestamp = "2016-04-30 11:18:25.796Z".parse::>().unwrap(); assert_eq!(snow.timestamp(), timestamp); } + + #[test] + fn serialize() { + let snowflake = Snowflake(1303390110099968072_u64); + let serialized = serde_json::to_string(&snowflake).unwrap(); + + assert_eq!(serialized, "\"1303390110099968072\"".to_string()); + } + + #[test] + fn serialize_one_or_more() { + let snowflake = Snowflake(1303390110099968072_u64); + let one_snowflake: OneOrMoreSnowflakes = snowflake.into(); + + let serialized = serde_json::to_string(&one_snowflake).unwrap(); + + assert_eq!(serialized, "\"1303390110099968072\"".to_string()); + + let more_snowflakes: OneOrMoreSnowflakes = vec![snowflake, snowflake, snowflake].into(); + + let serialized = serde_json::to_string(&more_snowflakes).unwrap(); + + assert_eq!(serialized, "[\"1303390110099968072\",\"1303390110099968072\",\"1303390110099968072\"]".to_string()); + + } } diff --git a/src/voice/gateway/backends/mod.rs b/src/voice/gateway/backends/mod.rs index 23f2767d..3a488859 100644 --- a/src/voice/gateway/backends/mod.rs +++ b/src/voice/gateway/backends/mod.rs @@ -7,4 +7,3 @@ pub mod tungstenite; #[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))] pub mod wasm; - diff --git a/src/voice/gateway/backends/tungstenite.rs b/src/voice/gateway/backends/tungstenite.rs index 599274d1..8989deb6 100644 --- a/src/voice/gateway/backends/tungstenite.rs +++ b/src/voice/gateway/backends/tungstenite.rs @@ -2,7 +2,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::voice::gateway::VoiceGatewayMessage; +use crate::types::VoiceCloseCode; +use crate::voice::gateway::{VoiceGatewayCommunication, VoiceGatewayMessage}; impl From for tokio_tungstenite::tungstenite::Message { fn from(message: VoiceGatewayMessage) -> Self { @@ -15,3 +16,28 @@ impl From for VoiceGatewayMessage { Self(value.to_string()) } } + +impl From for VoiceGatewayCommunication { + fn from(value: tokio_tungstenite::tungstenite::Message) -> Self { + match value { + tokio_tungstenite::tungstenite::Message::Text(text) => { + VoiceGatewayCommunication::Message(VoiceGatewayMessage(text)) + } + tokio_tungstenite::tungstenite::Message::Close(close_frame) => { + if close_frame.is_none() { + // Note: there is no unknown error. This case shouldn't happen, so I'm just + // going to delegate it to this error + return VoiceGatewayCommunication::Error(VoiceCloseCode::FailedToDecodePayload); + } + + let close_code = u16::from(close_frame.unwrap().code); + + VoiceGatewayCommunication::Error( + VoiceCloseCode::try_from(close_code) + .unwrap_or(VoiceCloseCode::FailedToDecodePayload), + ) + } + _ => VoiceGatewayCommunication::Error(VoiceCloseCode::FailedToDecodePayload), + } + } +} diff --git a/src/voice/gateway/backends/wasm.rs b/src/voice/gateway/backends/wasm.rs index 7b069c60..18265c82 100644 --- a/src/voice/gateway/backends/wasm.rs +++ b/src/voice/gateway/backends/wasm.rs @@ -2,8 +2,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use ws_stream_wasm::WsMessage; use crate::voice::gateway::VoiceGatewayMessage; +use ws_stream_wasm::WsMessage; impl From for WsMessage { fn from(message: VoiceGatewayMessage) -> Self { @@ -23,4 +23,3 @@ impl From for VoiceGatewayMessage { } } } - diff --git a/src/voice/gateway/gateway.rs b/src/voice/gateway/gateway.rs index 1b5981cb..aba34776 100644 --- a/src/voice/gateway/gateway.rs +++ b/src/voice/gateway/gateway.rs @@ -18,17 +18,24 @@ use crate::gateway::WebSocketBackend; use crate::{ errors::VoiceGatewayError, types::{ - VoiceGatewayReceivePayload, VoiceHelloData, WebSocketEvent, VOICE_BACKEND_VERSION, - VOICE_CLIENT_CONNECT_FLAGS, VOICE_CLIENT_CONNECT_PLATFORM, VOICE_CLIENT_DISCONNECT, - VOICE_HEARTBEAT, VOICE_HEARTBEAT_ACK, VOICE_HELLO, VOICE_IDENTIFY, VOICE_MEDIA_SINK_WANTS, - VOICE_READY, VOICE_RESUME, VOICE_SELECT_PROTOCOL, VOICE_SESSION_DESCRIPTION, - VOICE_SESSION_UPDATE, VOICE_SPEAKING, VOICE_SSRC_DEFINITION, + VoiceCloseCode, VoiceGatewayReceivePayload, VoiceHelloData, WebSocketEvent, + VOICE_BACKEND_VERSION, VOICE_CLIENT_CONNECT_FLAGS, VOICE_CLIENT_CONNECT_PLATFORM, + VOICE_CLIENT_DISCONNECT, VOICE_HEARTBEAT, VOICE_HEARTBEAT_ACK, VOICE_HELLO, VOICE_IDENTIFY, + VOICE_MEDIA_SINK_WANTS, VOICE_READY, VOICE_RESUME, VOICE_SELECT_PROTOCOL, + VOICE_SESSION_DESCRIPTION, VOICE_SESSION_UPDATE, VOICE_SPEAKING, VOICE_SSRC_DEFINITION, + }, + voice::gateway::{ + heartbeat::VoiceHeartbeatThreadCommunication, VoiceGatewayCommunication, + VoiceGatewayMessage, }, - voice::gateway::{heartbeat::VoiceHeartbeatThreadCommunication, VoiceGatewayMessage}, }; use super::{events::VoiceEvents, heartbeat::VoiceHeartbeatHandler, VoiceGatewayHandle}; +// Needed to observe close codes +#[cfg(target_arch = "wasm32")] +use pharos::Observable; + #[derive(Debug)] pub struct VoiceGateway { events: Arc>, @@ -64,9 +71,22 @@ impl VoiceGateway { // Wait for the first hello and then spawn both tasks so we avoid nested tasks // This automatically spawns the heartbeat task, but from the main thread #[cfg(not(target_arch = "wasm32"))] - let msg: VoiceGatewayMessage = websocket_receive.next().await.unwrap().unwrap().into(); + let msg: VoiceGatewayMessage = { + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them differently. + // + // Hence why wasm receives straight VoiceGatewayMessages, and tungstenite receives + // VoiceGatewayCommunications. + let communication: VoiceGatewayCommunication = + websocket_receive.next().await.unwrap().unwrap().into(); + + match communication { + VoiceGatewayCommunication::Message(message) => message, + VoiceGatewayCommunication::Error(error) => return Err(error.into()), + } + }; + #[cfg(target_arch = "wasm32")] - let msg: VoiceGatewayMessage = websocket_receive.next().await.unwrap().into(); + let msg: VoiceGatewayMessage = websocket_receive.0.next().await.unwrap().into(); let gateway_payload: VoiceGatewayReceivePayload = serde_json::from_str(&msg.0).unwrap(); if gateway_payload.op_code != VOICE_HELLO { @@ -102,11 +122,11 @@ impl VoiceGateway { // Now we can continuously check for messages in a different task, since we aren't going to receive another hello #[cfg(not(target_arch = "wasm32"))] tokio::task::spawn(async move { - gateway.gateway_listen_task().await; + gateway.gateway_listen_task_tungstenite().await; }); #[cfg(target_arch = "wasm32")] wasm_bindgen_futures::spawn_local(async move { - gateway.gateway_listen_task().await; + gateway.gateway_listen_task_wasm().await; }); Ok(VoiceGatewayHandle { @@ -117,8 +137,9 @@ impl VoiceGateway { }) } - /// The main gateway listener task; - pub async fn gateway_listen_task(&mut self) { + /// The main gateway listener task for a tungstenite based gateway; + #[cfg(not(target_arch = "wasm32"))] + async fn gateway_listen_task_tungstenite(&mut self) { loop { let msg; @@ -132,13 +153,75 @@ impl VoiceGateway { } } - // PRETTYFYME: Remove inline conditional compiling - #[cfg(not(target_arch = "wasm32"))] + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them differently. + // + // Hence why wasm receives straight RawGatewayMessages, and tungstenite receives + // GatewayCommunications. if let Some(Ok(message)) = msg { - self.handle_message(message.into()).await; + let communication: VoiceGatewayCommunication = message.into(); + + match communication { + VoiceGatewayCommunication::Message(message) => { + self.handle_message(message).await + } + VoiceGatewayCommunication::Error(close_code) => { + self.handle_close_code(close_code).await + } + } + continue; } - #[cfg(target_arch = "wasm32")] + + // We couldn't receive the next message or it was an error, something is wrong with the websocket, close + warn!("VGW: Websocket is broken, stopping gateway"); + break; + } + } + + /// The main gateway listener task for a wasm based gateway; + /// + /// Wasm handles close codes and events differently, and so we must change the listener logic a + /// bit + #[cfg(target_arch = "wasm32")] + async fn gateway_listen_task_wasm(&mut self) { + // Initiate the close event listener + let mut close_events = self + .websocket_receive + .1 + .observe(pharos::Filter::Pointer(ws_stream_wasm::WsEvent::is_closed).into()) + .await + .unwrap(); + + loop { + let msg; + + tokio::select! { + Ok(_) = self.kill_receive.recv() => { + log::trace!("VGW: Closing listener task"); + break; + } + message = self.websocket_receive.0.next() => { + msg = message; + } + maybe_event = close_events.next() => { + if let Some(event) = maybe_event { + match event { + ws_stream_wasm::WsEvent::Closed(closed_event) => { + let close_code = VoiceCloseCode::try_from(closed_event.code).unwrap_or(VoiceCloseCode::FailedToDecodePayload); + self.handle_close_code(close_code).await; + break; + } + _ => unreachable!() // Should be impossible, we filtered close events + } + } + continue; + } + } + + // Note: The tungstenite backend handles close codes as messages, while the ws_stream_wasm one handles them as a seperate receiver. + // + // Hence why wasm receives VoiceGatewayMessages, and tungstenite receives + // VoiceGatewayCommunications. if let Some(message) = msg { self.handle_message(message.into()).await; continue; @@ -156,6 +239,17 @@ impl VoiceGateway { self.websocket_send.lock().await.close().await.unwrap(); } + /// Handles receiving a [VoiceCloseCode]. + /// + /// Closes the connection and publishes an error event. + async fn handle_close_code(&mut self, code: VoiceCloseCode) { + let error = VoiceGatewayError::from(code); + + warn!("VGW: Received error {:?}, connection will close..", error); + self.close().await; + self.events.lock().await.error.publish(error).await; + } + /// Deserializes and updates a dispatched event, when we already know its type; /// (Called for every event in handle_message) async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( @@ -179,16 +273,10 @@ impl VoiceGateway { } let Ok(gateway_payload) = msg.payload() else { - if let Some(error) = msg.error() { - warn!("GW: Received error {:?}, connection will close..", error); - self.close().await; - self.events.lock().await.error.publish(error).await; - } else { - warn!( - "Message unrecognised: {:?}, please open an issue on the chorus github", - msg.0 - ); - } + warn!( + "VGW: Message unrecognised: {:?}, please open an issue on the chorus github", + msg.0 + ); return; }; diff --git a/src/voice/gateway/message.rs b/src/voice/gateway/message.rs index 4b40f35c..79e9b33a 100644 --- a/src/voice/gateway/message.rs +++ b/src/voice/gateway/message.rs @@ -2,42 +2,34 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::{errors::VoiceGatewayError, types::VoiceGatewayReceivePayload}; +use crate::types::{VoiceGatewayReceivePayload, VoiceCloseCode}; + +#[derive(Clone, Debug, PartialEq, Eq)] +/// Defines a communication received from the gateway, being either an optionally compressed +/// [RawGatewayMessage] or a [CloseCode]. +/// +/// Used only for a tungstenite gateway, since our underlying wasm backend handles close codes +/// differently. +pub(crate) enum VoiceGatewayCommunication { + Message(VoiceGatewayMessage), + Error(VoiceCloseCode), +} + +impl From for VoiceGatewayCommunication { + fn from(value: VoiceGatewayMessage) -> Self { + Self::Message(value) + } +} /// Represents a message received from the voice websocket connection. /// -/// This will be either a [VoiceGatewayReceivePayload], containing voice gateway events, or a [VoiceGatewayError]. +/// This should be a [VoiceGatewayReceivePayload]. /// /// This struct is used internally when handling messages. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct VoiceGatewayMessage(pub String); impl VoiceGatewayMessage { - /// Parses the message as an error; - /// Returns the error if successfully parsed, None if the message isn't an error - pub fn error(&self) -> Option { - // Some error strings have dots on the end, which we don't care about - let processed_content = self.0.to_lowercase().replace('.', ""); - - match processed_content.as_str() { - "unknown opcode" | "4001" => Some(VoiceGatewayError::UnknownOpcode), - "decode error" | "failed to decode payload" | "4002" => { - Some(VoiceGatewayError::FailedToDecodePayload) - } - "not authenticated" | "4003" => Some(VoiceGatewayError::NotAuthenticated), - "authentication failed" | "4004" => Some(VoiceGatewayError::AuthenticationFailed), - "already authenticated" | "4005" => Some(VoiceGatewayError::AlreadyAuthenticated), - "session is no longer valid" | "4006" => Some(VoiceGatewayError::SessionNoLongerValid), - "session timeout" | "4009" => Some(VoiceGatewayError::SessionTimeout), - "server not found" | "4011" => Some(VoiceGatewayError::ServerNotFound), - "unknown protocol" | "4012" => Some(VoiceGatewayError::UnknownProtocol), - "disconnected" | "4014" => Some(VoiceGatewayError::Disconnected), - "voice server crashed" | "4015" => Some(VoiceGatewayError::VoiceServerCrashed), - "unknown encryption mode" | "4016" => Some(VoiceGatewayError::UnknownEncryptionMode), - _ => None, - } - } - /// Parses the message as a payload; /// Returns a result of deserializing pub fn payload(&self) -> Result { diff --git a/tests/auth.rs b/tests/auth.rs index 9f0cbfdb..b05c917b 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -4,7 +4,18 @@ use std::str::FromStr; -use chorus::types::{LoginSchema, RegisterSchema}; +use chorus::types::{ + LoginSchema, MfaAuthenticationType, MfaVerifySchema, RegisterSchema, SendMfaSmsSchema, +}; + +#[cfg(not(target_arch = "wasm32"))] +use httptest::{ + matchers::{all_of, contains, eq, json_decoded, request}, + responders::json_encoded, + Expectation, +}; + +use serde_json::json; #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::*; #[cfg(target_arch = "wasm32")] @@ -101,3 +112,596 @@ async fn test_login_with_invalid_token() { common::teardown(bundle).await; } + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_complete_mfa_challenge_totp() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/mfa/finish"), + request::body(json_decoded(eq( + json!({"ticket": "testticket", "mfa_type": "totp", "data": "testdata"}) + ))), + request::headers(contains(("authorization", "faketoken"))) + ]) + .respond_with(json_encoded(json!({"token": "testtoken"}))), + ); + + let schema = MfaVerifySchema { + ticket: "testticket".to_string(), + mfa_type: MfaAuthenticationType::TOTP, + data: "testdata".to_string(), + }; + + let result = bundle.user.complete_mfa_challenge(schema).await; + + assert!(result.is_ok()); + assert_eq!( + bundle.user.mfa_token.unwrap().token, + "testtoken".to_string() + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_complete_mfa_challenge_sms() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/mfa/finish"), + request::body(json_decoded(eq( + json!({"ticket": "testticket", "mfa_type": "sms", "data": "testdata"}) + ))), + request::headers(contains(("authorization", "faketoken"))) + ]) + .respond_with(json_encoded(json!({"token": "testtoken"}))), + ); + + let schema = MfaVerifySchema { + ticket: "testticket".to_string(), + mfa_type: MfaAuthenticationType::SMS, + data: "testdata".to_string(), + }; + + let result = bundle.user.complete_mfa_challenge(schema).await; + + assert!(result.is_ok()); + assert_eq!( + bundle.user.mfa_token.unwrap().token, + "testtoken".to_string() + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_verify_mfa_login_webauthn() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/mfa/finish"), + request::body(json_decoded(eq( + json!({"ticket": "testticket", "mfa_type": "webauthn", "data": "testdata"}) + ))), + request::headers(contains(("authorization", "faketoken"))) + ]) + .respond_with(json_encoded(json!({"token": "testtoken"}))), + ); + + let schema = MfaVerifySchema { + ticket: "testticket".to_string(), + mfa_type: MfaAuthenticationType::WebAuthn, + data: "testdata".to_string(), + }; + + let result = bundle.user.complete_mfa_challenge(schema).await; + + assert!(result.is_ok()); + assert_eq!( + bundle.user.mfa_token.unwrap().token, + "testtoken".to_string() + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_complete_mfa_challenge_backup() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/mfa/finish"), + request::body(json_decoded(eq( + json!({"ticket": "testticket", "mfa_type": "backup", "data": "testdata"}) + ))), + request::headers(contains(("authorization", "faketoken"))) + ]) + .respond_with(json_encoded(json!({"token": "testtoken"}))), + ); + + let schema = MfaVerifySchema { + ticket: "testticket".to_string(), + mfa_type: MfaAuthenticationType::Backup, + data: "testdata".to_string(), + }; + + let result = bundle.user.complete_mfa_challenge(schema).await; + + assert!(result.is_ok()); + assert_eq!( + bundle.user.mfa_token.unwrap().token, + "testtoken".to_string() + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_complete_mfa_challenge_password() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/mfa/finish"), + request::body(json_decoded(eq( + json!({"ticket": "testticket", "mfa_type": "password", "data": "testdata"}) + ))), + request::headers(contains(("authorization", "faketoken"))) + ]) + .respond_with(json_encoded(json!({"token": "testtoken"}))), + ); + + let schema = MfaVerifySchema { + ticket: "testticket".to_string(), + mfa_type: MfaAuthenticationType::Password, + data: "testdata".to_string(), + }; + + let result = bundle.user.complete_mfa_challenge(schema).await; + + assert!(result.is_ok()) +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_send_mfa_sms() { + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/auth/mfa/sms/send"), + request::body(json_decoded(eq(json!({"ticket": "testticket"})))), + ]) + .respond_with(json_encoded(json!({"phone": "+*******0085"}))), + ); + + let schema = SendMfaSmsSchema { + ticket: "testticket".to_string(), + }; + + let result = bundle.instance.send_mfa_sms(schema).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().phone, "+*******0085".to_string()); +} + +// Note: user mfa routes are also here, because the other mfa routes were already here +// TODO: Test also not having an mfa token and trying to make a request that needs mfa +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_enable_totp_mfa() { + use chorus::types::{EnableTotpMfaSchema, MfaBackupCode, Snowflake}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + // TODO: Once json response codes are implemented, add the case where we can validate a user's + // password + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/totp/enable"), + request::body(json_decoded(eq(json!({"password": "test_password", "secret":"testsecret", "code":"testcode"})))), + request::headers(contains(("authorization", "faketoken"))), + ]) + .respond_with(json_encoded(json!({"token": "testtoken", "backup_codes": [{"user_id": "852892297661906993", "code": "zqs8oqxk", "consumed": false}]}))), + ); + + let schema = EnableTotpMfaSchema { + code: Some("testcode".to_string()), + password: "test_password".to_string(), + secret: Some("testsecret".to_string()), + }; + + let result = bundle.user.enable_totp_mfa(schema).await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap().backup_codes, + vec![MfaBackupCode { + user_id: Snowflake(852892297661906993), + code: "zqs8oqxk".to_string(), + consumed: false + }] + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_disable_totp_mfa() { + use chrono::{TimeDelta, Utc}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/totp/disable"), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + request::headers(contains(("authorization", "faketoken"))), + ]) + .respond_with(json_encoded(json!({"token": "testmfatoken"}))), + ); + + let result = bundle.user.disable_totp_mfa().await; + + assert!(result.is_ok()); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_enable_sms_mfa() { + use chrono::{TimeDelta, Utc}; + use httptest::responders::status_code; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/sms/enable"), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + request::headers(contains(("authorization", "faketoken"))), + request::body(json_decoded(eq(json!({"password": "test_password"})))), + ]) + .respond_with(status_code(204)), + ); + + let schema = chorus::types::SmsMfaRouteSchema { + password: "test_password".to_string(), + }; + + let result = bundle.user.enable_sms_mfa(schema).await; + + assert!(result.is_ok()); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_disable_sms_mfa() { + use chrono::{TimeDelta, Utc}; + use httptest::responders::status_code; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/sms/disable"), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + request::headers(contains(("authorization", "faketoken"))), + request::body(json_decoded(eq(json!({"password": "test_password"})))), + ]) + .respond_with(status_code(204)), + ); + + let schema = chorus::types::SmsMfaRouteSchema { + password: "test_password".to_string(), + }; + + let result = bundle.user.disable_sms_mfa(schema).await; + + assert!(result.is_ok()); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_get_mfa_webauthn_authenticators() { + use chorus::types::{MfaAuthenticator, Snowflake}; + use chrono::{TimeDelta, Utc}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/users/@me/mfa/webauthn/credentials"), + request::headers(contains(("authorization", "faketoken"))), + ]) + .respond_with(json_encoded( + json!([{"id": "1219430671865610261", "type": 1, "name": "Alienkey"}]), + )), + ); + + let result = bundle.user.get_webauthn_authenticators().await; + + assert_eq!( + result.unwrap(), + vec![MfaAuthenticator { + id: Snowflake(1219430671865610261), + name: "Alienkey".to_string(), + authenticator_type: chorus::types::MfaAuthenticatorType::WebAuthn + }] + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_create_mfa_webauthn_authenticator() { + use chorus::types::{ + FinishWebAuthnAuthenticatorCreationReturn, FinishWebAuthnAuthenticatorCreationSchema, + MfaAuthenticator, MfaBackupCode, Snowflake, + }; + use chrono::{TimeDelta, Utc}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + // Begin creation + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/webauthn/credentials"), + request::headers(contains(("authorization", "faketoken"))), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + ]) + .respond_with(json_encoded(json!({"ticket": "ODUyODkyMjk3NjYxOTA2OTkz.WrhGhYEhM3lHUPN61xF6JcQKwVutk8fBvcoHjo", "challenge": "{\"publicKey\":{\"challenge\":\"a8a1cHP7_zYheggFG68zKUkl8DwnEqfKvPE-GOMvhss\",\"timeout\":60000,\"rpId\":\"discord.com\",\"allowCredentials\":[{\"type\":\"public-key\",\"id\":\"izrvF80ogrfg9dC3RmWWwW1VxBVBG0TzJVXKOJl__6FvMa555dH4Trt2Ub8AdHxNLkQsc0unAGcn4-hrJHDKSO\"}],\"userVerification\":\"preferred\"}}"}))), + ); + + // Finish creation + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/webauthn/credentials"), + request::headers(contains(("authorization", "faketoken"))), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + request::body(json_decoded(eq(json!({"name": "AlienKey", "ticket": "ODUyODkyMjk3NjYxOTA2OTkz.WrhGhYEhM3lHUPN61xF6JcQKwVutk8fBvcoHjo", "credential": "{\"test\": \"lest\"}"})))), + ]) + .respond_with(json_encoded(json!({ "id": "1219430671865610261", + "type": 1, + "name": "AlienKey", + "backup_codes": [ + { + "user_id": "852892297661906993", + "code": "zqs8oqxk", + "consumed": false + } + ]}))), + ); + + let result = bundle + .user + .begin_webauthn_authenticator_creation() + .await + .unwrap(); + + let schema = FinishWebAuthnAuthenticatorCreationSchema { + name: "AlienKey".to_string(), + ticket: result.ticket, + credential: "{\"test\": \"lest\"}".to_string(), + }; + + let result = bundle + .user + .finish_webauthn_authenticator_creation(schema) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + FinishWebAuthnAuthenticatorCreationReturn { + backup_codes: vec![MfaBackupCode { + user_id: Snowflake(852892297661906993), + code: "zqs8oqxk".to_string(), + consumed: false + }], + authenticator: MfaAuthenticator { + name: "AlienKey".to_string(), + id: Snowflake(1219430671865610261), + authenticator_type: chorus::types::MfaAuthenticatorType::WebAuthn + } + } + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_modify_mfa_webauthn_authenticator() { + use chorus::types::{MfaAuthenticator, ModifyWebAuthnAuthenticatorSchema, Snowflake}; + use chrono::{TimeDelta, Utc}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("PATCH"), + request::path("/api/users/@me/mfa/webauthn/credentials/1219430671865610261"), + request::headers(contains(("authorization", "faketoken"))), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + request::body(json_decoded(eq(json!({"name": "Alienkey Pro Ultra SE+"})))), + ]) + .respond_with(json_encoded( + json!({ "id": "1219430671865610261", "type": 1, "name": "Alienkey Pro Ultra SE+" }), + )), + ); + + let schema = ModifyWebAuthnAuthenticatorSchema { + name: Some("Alienkey Pro Ultra SE+".to_string()), + }; + + let result = bundle + .user + .modify_webauthn_authenticator(Snowflake(1219430671865610261), schema) + .await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + MfaAuthenticator { + name: "Alienkey Pro Ultra SE+".to_string(), + id: Snowflake(1219430671865610261), + authenticator_type: chorus::types::MfaAuthenticatorType::WebAuthn + } + ); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +async fn test_delete_mfa_webauthn_authenticator() { + use chorus::types::Snowflake; + use chrono::{TimeDelta, Utc}; + use httptest::responders::status_code; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + bundle.user.mfa_token = Some(chorus::types::MfaToken { + token: "fakemfatoken".to_string(), + expires_at: Utc::now() + TimeDelta::minutes(5), + }); + + server.expect( + Expectation::matching(all_of![ + request::method("DELETE"), + request::path("/api/users/@me/mfa/webauthn/credentials/1219430671865610261"), + request::headers(contains(("authorization", "faketoken"))), + request::headers(contains(("x-discord-mfa-authorization", "fakemfatoken"))), + ]) + .respond_with(status_code(204)), + ); + + let result = bundle + .user + .delete_webauthn_authenticator(Snowflake(1219430671865610261)) + .await; + + assert!(result.is_ok()); +} + +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +#[cfg(not(target_arch = "wasm32"))] +// Tests the send backup codes challenge and get backup codes endpoints +async fn test_send_mfa_backup_codes() { + use chorus::types::{MfaBackupCode, SendBackupCodesChallengeReturn, Snowflake}; + + let server = common::create_mock_server(); + let mut bundle = common::setup_with_mock_server(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/auth/verify/view-backup-codes-challenge"), + request::headers(contains(("authorization", "faketoken"))), + request::body(json_decoded(eq(json!({"password": "test_password"})))), + ]) + .times(1) + .respond_with(json_encoded(json!({"nonce": "test_view_nonce", "regenerate_nonce": "test_regenerate_nonce"}))), + ); + + let schema = chorus::types::SendBackupCodesChallengeSchema { password: "test_password".to_string() }; + + let result = bundle + .user + .send_backup_codes_challenge(schema) + .await.unwrap(); + + assert_eq!(result, SendBackupCodesChallengeReturn {view_nonce: "test_view_nonce".to_string(), regenerate_nonce: "test_regenerate_nonce".to_string() }); + + // View routes, assume we got an email key of "test_key" + // View nonce, regenerate = false + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/codes-verification"), + request::headers(contains(("authorization", "faketoken"))), + request::body(json_decoded(eq(json!({"key": "test_key", "nonce": "test_view_nonce", "regenerate": false})))), + ]) + .times(1) + .respond_with(json_encoded(json!([{"user_id": "852892297661906993", "code": "zqs8oqxk", "consumed": false}]))), + ); + + // Regenerate nonce, regenerate = true + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/mfa/codes-verification"), + request::headers(contains(("authorization", "faketoken"))), + request::body(json_decoded(eq(json!({"key": "test_key", "nonce": "test_regenerate_nonce", "regenerate": true})))), + ]) + .times(1) + .respond_with(json_encoded(json!([{"user_id": "852892297661906993", "code": "oqxk8zqs", "consumed": false}]))), + ); + + let schema_view = chorus::types::GetBackupCodesSchema { nonce: result.view_nonce, key: "test_key".to_string(), regenerate: false }; + + let schema_regenerate = chorus::types::GetBackupCodesSchema { nonce: result.regenerate_nonce, key: "test_key".to_string(), regenerate: true }; + + let result_view = bundle.user.get_backup_codes(schema_view).await.unwrap(); + + assert_eq!(result_view, vec![MfaBackupCode {user_id: Snowflake(852892297661906993), code: "zqs8oqxk".to_string(), consumed: false}]); + + let result_regenerate = bundle.user.get_backup_codes(schema_regenerate).await.unwrap(); + + assert_ne!(result_view, result_regenerate); + assert_eq!(result_regenerate, vec![MfaBackupCode {user_id: Snowflake(852892297661906993), code: "oqxk8zqs".to_string(), consumed: false}]); +} diff --git a/tests/channels.rs b/tests/channels.rs index eb1c1200..091e3a32 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -2,7 +2,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use chorus::types::{self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags, PermissionOverwrite, PermissionOverwriteType, PrivateChannelCreateSchema, RelationshipType, Snowflake}; +use chorus::types::{ + self, Channel, GetChannelMessagesSchema, MessageSendSchema, PermissionFlags, + PermissionOverwrite, PermissionOverwriteType, PrivateChannelCreateSchema, RelationshipType, + Snowflake, +}; mod common; @@ -168,7 +172,8 @@ async fn create_dm() { dm_channel .recipients .as_ref() - .unwrap().first() + .unwrap() + .first() .unwrap() .read() .unwrap() @@ -241,7 +246,8 @@ async fn remove_add_person_from_to_dm() { dm_channel .recipients .as_ref() - .unwrap().first() + .unwrap() + .first() .unwrap() .read() .unwrap() diff --git a/tests/common/mod.rs b/tests/common/mod.rs index ac85b859..4de7c68d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use chorus::gateway::{Gateway, GatewayOptions}; -use chorus::types::{DeleteDisableUserSchema, IntoShared, PermissionFlags}; +use chorus::types::{DeleteDisableUserSchema, IntoShared, PermissionFlags, Snowflake}; use chorus::{ instance::{ChorusUser, Instance}, types::{ @@ -17,6 +17,13 @@ use chorus::{ use chrono::NaiveDate; +#[cfg(not(target_arch = "wasm32"))] +use httptest::{ + matchers::{all_of, contains, request}, + responders::{json_encoded, status_code}, + Expectation, +}; + #[allow(dead_code)] #[derive(Debug)] pub(crate) struct TestBundle { @@ -47,6 +54,7 @@ impl TestBundle { ChorusUser { belongs_to: self.user.belongs_to.clone(), token: self.user.token.clone(), + mfa_token: None, limits: self.user.limits.clone(), settings: self.user.settings.clone(), object: self.user.object.clone(), @@ -57,7 +65,9 @@ impl TestBundle { } } -// Set up a test by creating an Instance and a User. Reduces Test boilerplate. +/// Set up a test by creating an [Instance] and a User for a real, +/// running server at localhost:3001. Reduces Test boilerplate. +#[allow(dead_code)] pub(crate) async fn setup() -> TestBundle { // So we can get logs when tests fail let _ = simple_logger::SimpleLogger::with_level( @@ -141,6 +151,74 @@ pub(crate) async fn setup() -> TestBundle { } } +/// Set up a test by creating an [Instance] and a User for a mocked +/// server with httptest. Reduces Test boilerplate. +/// +/// Note: httptest does not work on wasm! +/// +/// This test server will always provide snowflake ids as 123456789101112131 +/// and auth tokens as "faketoken" +#[allow(dead_code)] +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn setup_with_mock_server(server: &httptest::Server) -> TestBundle { + // So we can get logs when tests fail + let _ = simple_logger::SimpleLogger::with_level( + simple_logger::SimpleLogger::new(), + log::LevelFilter::Debug, + ) + .init(); + + let instance = Instance::new(server.url_str("/api").as_str(), None) + .await + .unwrap(); + + // Requires the existence of the below user. + let reg = RegisterSchema { + username: "integrationtestuser".into(), + consent: true, + date_of_birth: Some(NaiveDate::from_str("2000-01-01").unwrap()), + ..Default::default() + }; + let user = instance.clone().register_account(reg).await.unwrap(); + + let guild = Guild { + id: Snowflake(123456789101112131), + name: Some("Test-Guild!".to_string()), + ..Default::default() + }; + + let channel = Channel { + id: Snowflake(123456789101112131), + name: Some("testchannel".to_string()), + channel_type: chorus::types::ChannelType::GuildText, + nsfw: Some(false), + flags: Some(0), + default_thread_rate_limit_per_user: Some(0), + ..Default::default() + }; + + let role = chorus::types::RoleObject { + id: Snowflake(123456789101112131), + name: "Bundle role".to_string(), + permissions: PermissionFlags::from_bits(8).unwrap(), + hoist: true, + unicode_emoji: Some(String::new()), + mentionable: true, + ..Default::default() + }; + + let urls = instance.urls.clone(); + + TestBundle { + urls, + user, + instance, + guild: guild.into_shared(), + role: role.into_shared(), + channel: channel.into_shared(), + } +} + // Teardown method to clean up after a test. #[allow(dead_code)] pub(crate) async fn teardown(mut bundle: TestBundle) { @@ -152,3 +230,177 @@ pub(crate) async fn teardown(mut bundle: TestBundle) { .await .unwrap() } + +/// Creates a mock http server at localhost:3001 with the basic routes +/// needed to run TestBundle setup and teardown +/// +/// Note: httptest does not work on wasm! +/// +/// This test server will always provide snowflake ids as 123456789101112131 +/// and auth tokens as "faketoken" +#[allow(dead_code)] +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn create_mock_server() -> httptest::Server { + let server = httptest::Server::run(); + + let api_url = server.url("/api"); + let cdn_url = server.url("/cdn"); + + // Just redirect it to the one we're running for spacebar tests + // We're using this just for the api anyway, so it can break after identifying + let gateway_url = "ws://localhost:3001"; + + // Mock the instance/domains endpoint, so we can create from a single url + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/policies/instance/domains") + ]) + .times(0..100) + .respond_with(json_encoded( + chorus::types::types::domains_configuration::Domains { + api_endpoint: api_url.to_string(), + cdn: cdn_url.to_string(), + gateway: gateway_url.to_string(), + default_api_version: "v9".to_string(), + }, + )), + ); + + // The following routes are mocked so that login and register work: + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/auth/register") + ]) + .times(0..100) + .respond_with(json_encoded(chorus::instance::Token { + token: "faketoken".to_string(), + })), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/auth/login") + ]) + .times(0..100) + .respond_with(json_encoded(chorus::types::LoginResult { + token: "faketoken".to_string(), + settings: chorus::types::UserSettings { + ..Default::default() + } + .into_shared(), + })), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/users/@me"), + request::headers(contains(("authorization", "faketoken"))) + ]) + .times(0..100) + .respond_with(json_encoded(chorus::types::User { + id: chorus::types::Snowflake(123456789101112131), + username: "integrationtestuser".to_string(), + discriminator: "1234".to_string(), + mfa_enabled: Some(true), + locale: Some(String::from("en-us")), + disabled: Some(false), + ..Default::default() + })), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/users/@me/settings"), + request::headers(contains(("authorization", "faketoken"))) + ]) + .times(0..100) + .respond_with(json_encoded(chorus::types::UserSettings { + status: chorus::types::UserStatus::Online.into_shared(), + ..Default::default() + })), + ); + + // The folowing routes are mocked so that teardown works: + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + // Can we have wildcards here? + request::path("/api/guilds/123456789101112131/delete"), + request::headers(contains(("authorization", "faketoken"))) + ]) + .times(0..100) + .respond_with(status_code(200)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/api/users/@me/delete"), + request::headers(contains(("authorization", "faketoken"))) + ]) + .times(0..100) + .respond_with(status_code(200)), + ); + + // The following should just return a 404, and it's normal that we're getting them + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/.well-known/spacebar") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/api/policies/instance/domains") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/policies/instance/limits") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/policies/instance/") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/version") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path("/api/ping") + ]) + .times(0..100) + .respond_with(status_code(404)), + ); + + server +} diff --git a/tests/gateway.rs b/tests/gateway.rs index ccd96eff..a48dc9db 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -49,6 +49,18 @@ impl Subscriber for GatewayReadyObserver { } } +#[derive(Debug)] +struct GatewayErrorObserver { + channel: tokio::sync::mpsc::Sender, +} + +#[async_trait] +impl Subscriber for GatewayErrorObserver { + async fn update(&self, data: &GatewayError) { + self.channel.send(data.clone()).await.unwrap(); + } +} + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] /// Tests establishing a connection and authenticating @@ -91,6 +103,82 @@ async fn test_gateway_authenticate() { common::teardown(bundle).await } +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] +#[cfg_attr(not(target_arch = "wasm32"), tokio::test)] +/// Tests establishing a connection and receiving errors +async fn test_gateway_errors() { + let bundle = common::setup().await; + + // FIXME: Without this, this test does not work + // + // Can WASM handle multiple gateway existing simulatinously? + // This reminds of when my very old laptop couldn't handle multiple connections + // with the ammount of tasks + // + // Anyway, if you have a free weekend to spend debugging wasm, you're welcome to have a crack + // at this + bundle.user.gateway.close().await; + + let gateway: GatewayHandle = Gateway::spawn(&bundle.urls.wss, GatewayOptions::default()) + .await + .unwrap(); + + // First we'll authenticate, wait for ready, and then authenticate again to get AlreadyAuthenticated + let (ready_send, mut ready_receive) = tokio::sync::mpsc::channel(1); + + let observer = Arc::new(GatewayReadyObserver { + channel: ready_send, + }); + + gateway + .events + .lock() + .await + .session + .ready + .subscribe(observer); + + let (error_send, mut error_receive) = tokio::sync::mpsc::channel(1); + + let observer = Arc::new(GatewayErrorObserver { + channel: error_send, + }); + + gateway.events.lock().await.error.subscribe(observer); + + let mut identify = types::GatewayIdentifyPayload::common(); + identify.token = bundle.user.token.clone(); + + // Identify and wait to receive ready + gateway.send_identify(identify.clone()).await; + + tokio::select! { + // Fail, we timed out waiting for it + () = sleep(Duration::from_secs(20)) => { + println!("Timed out waiting for ready, failing.."); + assert!(false); + } + // Success, we have received it + Some(_) = ready_receive.recv() => {} + } + + // Identify again, so we should receive already authenticated + gateway.send_identify(identify).await; + + tokio::select! { + // Fail, we timed out waiting for it + () = sleep(Duration::from_secs(20)) => { + assert!(false); + } + // Success, we have received it + Some(error) = error_receive.recv() => { + assert_eq!(error, GatewayError::AlreadyAuthenticated); + } + } + + common::teardown(bundle).await; +} + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] async fn test_self_updating_structs() { @@ -210,92 +298,3 @@ async fn test_recursive_self_updating_structs() { assert_eq!(guild_role_inner.name, "yippieee".to_string()); common::teardown(bundle).await; } - -#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] -#[cfg_attr(not(target_arch = "wasm32"), test)] -fn test_error() { - let error = GatewayMessage("4000".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::Unknown); - let error = GatewayMessage("4001".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::UnknownOpcode); - let error = GatewayMessage("4002".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::Decode); - let error = GatewayMessage("4003".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::NotAuthenticated); - let error = GatewayMessage("4004".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::AuthenticationFailed); - let error = GatewayMessage("4005".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::AlreadyAuthenticated); - let error = GatewayMessage("4007".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidSequenceNumber); - let error = GatewayMessage("4008".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::RateLimited); - let error = GatewayMessage("4009".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::SessionTimedOut); - let error = GatewayMessage("4010".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidShard); - let error = GatewayMessage("4011".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::ShardingRequired); - let error = GatewayMessage("4012".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidAPIVersion); - let error = GatewayMessage("4013".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidIntents); - let error = GatewayMessage("4014".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::DisallowedIntents); -} - -#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] -#[cfg_attr(not(target_arch = "wasm32"), test)] -fn test_error_message() { - let error = GatewayMessage("Unknown Error".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::Unknown); - let error = GatewayMessage("Unknown Opcode".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::UnknownOpcode); - let error = GatewayMessage("Decode Error".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::Decode); - let error = GatewayMessage("Not Authenticated".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::NotAuthenticated); - let error = GatewayMessage("Authentication Failed".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::AuthenticationFailed); - let error = GatewayMessage("Already Authenticated".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::AlreadyAuthenticated); - let error = GatewayMessage("Invalid Seq".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidSequenceNumber); - let error = GatewayMessage("Rate Limited".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::RateLimited); - let error = GatewayMessage("Session Timed Out".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::SessionTimedOut); - let error = GatewayMessage("Invalid Shard".to_string()).error().unwrap(); - assert_eq!(error, GatewayError::InvalidShard); - let error = GatewayMessage("Sharding Required".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::ShardingRequired); - let error = GatewayMessage("Invalid API Version".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::InvalidAPIVersion); - let error = GatewayMessage("Invalid Intent(s)".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::InvalidIntents); - let error = GatewayMessage("Disallowed Intent(s)".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::DisallowedIntents); - // Also test the dot thing - let error = GatewayMessage("Invalid Intent(s).".to_string()) - .error() - .unwrap(); - assert_eq!(error, GatewayError::InvalidIntents); -} diff --git a/tests/relationships.rs b/tests/relationships.rs index 4e57b22e..618f433c 100644 --- a/tests/relationships.rs +++ b/tests/relationships.rs @@ -16,8 +16,19 @@ async fn test_get_mutual_relationships() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let username = user.object.read().unwrap().username.clone(); - let discriminator = user.object.read().unwrap().discriminator.clone(); + + let username = user + .object + .read() + .unwrap() + .username + .clone(); + let discriminator = user + .object + .read() + .unwrap() + .discriminator + .clone(); let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; let friend_request_schema = types::FriendRequestSendSchema { username, @@ -38,8 +49,18 @@ async fn test_get_relationships() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let username = user.object.read().unwrap().username.clone(); - let discriminator = user.object.read().unwrap().discriminator.clone(); + let username = user + .object + .read() + .unwrap() + .username + .clone(); + let discriminator = user + .object + .read() + .unwrap() + .discriminator + .clone(); let friend_request_schema = types::FriendRequestSendSchema { username, discriminator: Some(discriminator), @@ -62,8 +83,8 @@ async fn test_modify_relationship_friends() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let user_id: types::Snowflake = user.object.read().unwrap().id; - let other_user_id: types::Snowflake = other_user.object.read().unwrap().id; + let user_id: types::Snowflake = user.object.as_ref().read().unwrap().id; + let other_user_id: types::Snowflake = other_user.object.as_ref().read().unwrap().id; other_user .modify_user_relationship(user_id, types::RelationshipType::Friends) @@ -72,7 +93,7 @@ async fn test_modify_relationship_friends() { let relationships = user.get_relationships().await.unwrap(); assert_eq!( relationships.first().unwrap().id, - other_user.object.read().unwrap().id + other_user.object.as_ref().read().unwrap().id ); assert_eq!( relationships.first().unwrap().relationship_type, @@ -81,7 +102,7 @@ async fn test_modify_relationship_friends() { let relationships = other_user.get_relationships().await.unwrap(); assert_eq!( relationships.first().unwrap().id, - user.object.read().unwrap().id + user.object.as_ref().read().unwrap().id ); assert_eq!( relationships.first().unwrap().relationship_type, @@ -114,7 +135,7 @@ async fn test_modify_relationship_block() { let mut bundle = common::setup().await; let mut other_user = bundle.create_user("integrationtestuser2").await; let user = &mut bundle.user; - let user_id: types::Snowflake = user.object.read().unwrap().id; + let user_id: types::Snowflake = user.object.as_ref().read().unwrap().id; other_user .modify_user_relationship(user_id, types::RelationshipType::Blocked) @@ -125,7 +146,7 @@ async fn test_modify_relationship_block() { let relationships = other_user.get_relationships().await.unwrap(); assert_eq!( relationships.first().unwrap().id, - user.object.read().unwrap().id + user.object.as_ref().read().unwrap().id ); assert_eq!( relationships.first().unwrap().relationship_type,