diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 0eae5ee..f6f7a15 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -37,7 +37,7 @@ jobs: - kind: default-features features: default - kind: full-features - features: cache,macros,metrics,replay,serialize + features: cache,macros,metrics,replay,serialize,local_oauth steps: - name: Checkout project @@ -100,7 +100,7 @@ jobs: cargo hack check --feature-powerset --no-dev-deps --optional-deps metrics - --group-features default,cache,macros,deny_unknown_fields + --group-features default,cache,macros,local_oauth,deny_unknown_fields readme: name: Readme @@ -150,4 +150,4 @@ jobs: --filter-expr 'not binary(requests)' - name: Run doctests - run: cargo test --doc --all-features \ No newline at end of file + run: cargo test --doc --all-features diff --git a/Cargo.toml b/Cargo.toml index cdd7d91..0a39c43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ cache = ["dashmap"] macros = ["rosu-mods/macros"] replay = ["osu-db"] serialize = [] +local_oauth = ["tokio/net"] deny_unknown_fields = [] # --- Dependencies --- diff --git a/README.md b/README.md index d17f453..71697ab 100644 --- a/README.md +++ b/README.md @@ -104,14 +104,15 @@ async fn main() { ### Features -| Flag | Description | Dependencies -| ----------- | ---------------------------------------- | ------------ -| `default` | Enable the `cache` and `macros` features | -| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`] -| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`] -| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization | -| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`] -| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`] +| Flag | Description | Dependencies +| ------------- | ---------------------------------------- | ------------ +| `default` | Enable the `cache` and `macros` features | +| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`] +| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`] +| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization | +| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`] +| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`] +| `local_oauth` | Enables the method `OsuBuilder::with_local_authorization` to perform the full OAuth procedure | `tokio/net` feature [osu!api v2]: https://osu.ppy.sh/docs/index.html [`rosu`]: https://github.com/MaxOhn/rosu diff --git a/src/client/builder.rs b/src/client/builder.rs index 161ede1..6d2abc7 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -1,4 +1,4 @@ -use super::{Authorization, AuthorizationKind, Osu, OsuRef, Token}; +use super::{token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Token}; use crate::{error::OsuError, OsuResult}; use hyper::client::Builder; @@ -17,7 +17,7 @@ use dashmap::DashMap; /// For more info, check out #[must_use] pub struct OsuBuilder { - auth_kind: Option, + auth: Option, client_id: Option, client_secret: Option, retries: usize, @@ -26,10 +26,9 @@ pub struct OsuBuilder { } impl Default for OsuBuilder { - #[inline] fn default() -> Self { Self { - auth_kind: None, + auth: None, client_id: None, client_secret: None, retries: 2, @@ -41,12 +40,11 @@ impl Default for OsuBuilder { impl OsuBuilder { /// Create a new [`OsuBuilder`](crate::OsuBuilder) - #[inline] pub fn new() -> Self { Self::default() } - /// Build an [`Osu`](crate::Osu) client. + /// Build an [`Osu`] client. /// /// To build the client, the client id and secret are being used /// to acquire a token from the API which expires after a certain time. @@ -62,6 +60,17 @@ impl OsuBuilder { let client_id = self.client_id.ok_or(OsuError::BuilderMissingId)?; let client_secret = self.client_secret.ok_or(OsuError::BuilderMissingSecret)?; + let auth_kind = match self.auth { + Some(AuthorizationBuilder::Kind(kind)) => kind, + #[cfg(feature = "local_oauth")] + Some(AuthorizationBuilder::LocalOauth { redirect_uri }) => { + AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id) + .await + .map(AuthorizationKind::User)? + } + None => AuthorizationKind::default(), + }; + let connector = HttpsConnectorBuilder::new() .with_native_roots() .https_or_http() @@ -86,7 +95,7 @@ impl OsuBuilder { http, ratelimiter, timeout: self.timeout, - auth_kind: self.auth_kind.unwrap_or_default(), + auth_kind, token: RwLock::new(Token::default()), retries: self.retries, }); @@ -119,7 +128,6 @@ impl OsuBuilder { /// Set the client id of the application. /// /// For more info, check out - #[inline] pub const fn client_id(mut self, client_id: u64) -> Self { self.client_id = Some(client_id); @@ -129,17 +137,45 @@ impl OsuBuilder { /// Set the client secret of the application. /// /// For more info, check out - #[inline] pub fn client_secret(mut self, client_secret: impl Into) -> Self { self.client_secret = Some(client_secret.into()); self } + /// Upon calling [`build`], `rosu-v2` will print a url to authorize a local + /// osu! profile. + /// + /// Be sure that the specified client id matches the OAuth application's + /// redirect uri. + /// + /// If the authorization code has already been acquired, use + /// [`with_authorization`] instead. + /// + /// For more info, check out + /// + /// + /// [`build`]: OsuBuilder::build + /// [`with_authorization`]: OsuBuilder::with_authorization + #[cfg(feature = "local_oauth")] + #[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))] + pub fn with_local_authorization(mut self, redirect_uri: impl Into) -> Self { + self.auth = Some(AuthorizationBuilder::LocalOauth { + redirect_uri: redirect_uri.into(), + }); + + self + } + /// After acquiring the authorization code from a user through OAuth, /// use this method to provide the given code, and specified redirect uri. /// - /// For more info, check out + /// To perform the full OAuth procedure for a local osu! profile, enable the + /// `local_oauth` feature and use `OsuBuilder::with_local_authorization` + /// instead. + /// + /// For more info, check out + /// pub fn with_authorization( mut self, code: impl Into, @@ -150,13 +186,14 @@ impl OsuBuilder { redirect_uri: redirect_uri.into(), }; - self.auth_kind = Some(AuthorizationKind::User(authorization)); + self.auth = Some(AuthorizationBuilder::Kind(AuthorizationKind::User( + authorization, + ))); self } /// In case the request times out, retry up to this many times, defaults to 2. - #[inline] pub const fn retries(mut self, retries: usize) -> Self { self.retries = retries; @@ -164,7 +201,6 @@ impl OsuBuilder { } /// Set the timeout for requests, defaults to 10 seconds. - #[inline] pub const fn timeout(mut self, duration: Duration) -> Self { self.timeout = duration; @@ -177,8 +213,6 @@ impl OsuBuilder { /// Check out the osu!api's [terms of use] for acceptable values. /// /// [terms of use]: https://osu.ppy.sh/docs/index.html#terms-of-use - - #[inline] pub fn ratelimit(mut self, reqs_per_sec: u32) -> Self { self.per_second = reqs_per_sec.clamp(1, 20); diff --git a/src/client/token.rs b/src/client/token.rs index cd7b3bc..63a5ee3 100644 --- a/src/client/token.rs +++ b/src/client/token.rs @@ -107,6 +107,130 @@ fn adjust_token_expire(expires_in: i64) -> i64 { expires_in - (expires_in as f64 * 0.05) as i64 } +pub(super) enum AuthorizationBuilder { + Kind(AuthorizationKind), + #[cfg(feature = "local_oauth")] + LocalOauth { + redirect_uri: String, + }, +} + +impl AuthorizationBuilder { + #[cfg(feature = "local_oauth")] + pub(super) async fn perform_local_oauth( + redirect_uri: String, + client_id: u64, + ) -> Result { + use std::{ + fmt::Write, + io::{Error as IoError, ErrorKind}, + str::from_utf8 as str_from_utf8, + }; + use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, + }; + + use crate::error::OAuthError; + + let port: u16 = redirect_uri + .rsplit_once(':') + .and_then(|(prefix, suffix)| { + suffix + .split('/') + .next() + .filter(|_| prefix.ends_with("localhost")) + }) + .map(str::parse) + .and_then(Result::ok) + .ok_or(OAuthError::Url)?; + + let listener = TcpListener::bind(("localhost", port)) + .await + .map_err(OAuthError::Listener)?; + + let mut url = format!( + "https://osu.ppy.sh/oauth/authorize?\ + client_id={client_id}\ + &redirect_uri={redirect_uri}\ + &response_type=code", + ); + + let mut scopes = [Scope::Identify, Scope::Public].iter(); + + if let Some(scope) = scopes.next() { + let _ = write!(url, "&scopes=%22{scope}"); + + for scope in scopes { + let _ = write!(url, "+{scope}"); + } + + url.push_str("%22"); + } + + println!("Authorize yourself through the following url:\n{url}"); + info!("Awaiting manual authorization..."); + + let (mut stream, _) = listener.accept().await.map_err(OAuthError::Accept)?; + let mut data = Vec::new(); + + loop { + stream.readable().await.map_err(OAuthError::Read)?; + + match stream.try_read_buf(&mut data) { + Ok(0) => break, + Ok(_) => { + if data.ends_with(b"\n\n") || data.ends_with(b"\r\n\r\n") { + break; + } + } + Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue, + Err(e) => return Err(OAuthError::Read(e)), + } + } + + let code = str_from_utf8(&data) + .ok() + .and_then(|data| { + const KEY: &str = "code="; + + if let Some(mut start) = data.find(KEY) { + start += KEY.len(); + + if let Some(end) = data[start..].find(char::is_whitespace) { + return Some(data[start..][..end].to_owned()); + } + } + + None + }) + .ok_or(OAuthError::NoCode { data })?; + + info!("Authorization succeeded"); + + #[allow(clippy::items_after_statements)] + async fn respond(stream: &mut TcpStream) -> Result<(), IoError> { + let response = b"HTTP/1.0 200 OK +Content-Type: text/html + + +

rosu-v2 authentication succeeded

+You may close this tab +"; + + stream.writable().await?; + stream.write_all(response).await?; + stream.shutdown().await?; + + Ok(()) + } + + respond(&mut stream).await.map_err(OAuthError::Write)?; + + Ok(Authorization { code, redirect_uri }) + } +} + pub(super) enum AuthorizationKind { User(Authorization), Client(Scope), diff --git a/src/error.rs b/src/error.rs index 522aa7e..d4ed6e3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,6 +5,24 @@ use serde::Deserialize; use serde_json::Error as SerdeError; use std::fmt; +#[cfg(feature = "local_oauth")] +#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))] +#[derive(Debug, thiserror::Error)] +pub enum OAuthError { + #[error("failed to accept request")] + Accept(#[source] tokio::io::Error), + #[error("failed to create tcp listener")] + Listener(#[source] tokio::io::Error), + #[error("missing code in request")] + NoCode { data: Vec }, + #[error("failed to read data")] + Read(#[source] tokio::io::Error), + #[error("redirect uri must contain localhost and a port number")] + Url, + #[error("failed to write data")] + Write(#[source] tokio::io::Error), +} + /// The API response was of the form `{ "error": ... }` #[derive(Debug, Deserialize, thiserror::Error)] pub struct ApiError { @@ -60,6 +78,14 @@ pub enum OsuError { This should only occur during an extended downtime of the osu!api." )] NoToken, + #[cfg(feature = "local_oauth")] + #[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))] + /// Failed to perform OAuth + #[error("failed to perform oauth")] + OAuth { + #[from] + source: OAuthError, + }, #[cfg(feature = "replay")] #[cfg_attr(docsrs, doc(cfg(feature = "replay")))] /// There was an error while trying to use osu-db diff --git a/src/lib.rs b/src/lib.rs index 02b7a84..0b80991 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -103,14 +103,15 @@ //! //! ## Features //! -//! | Flag | Description | Dependencies -//! | ----------- | ---------------------------------------- | ------------ -//! | `default` | Enable the `cache` and `macros` features | -//! | `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`] -//! | `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`] -//! | `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization | -//! | `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`] -//! | `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`] +//! | Flag | Description | Dependencies +//! | ------------- | ---------------------------------------- | ------------ +//! | `default` | Enable the `cache` and `macros` features | +//! | `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`] +//! | `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`] +//! | `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization | +//! | `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`] +//! | `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`] +//! | `local_oauth` | Enables the method `OsuBuilder::with_local_authorization` to perform the full OAuth procedure | `tokio/net` feature //! //! [osu!api v2]: https://osu.ppy.sh/docs/index.html //! [`rosu`]: https://github.com/MaxOhn/rosu