Skip to content

Commit

Permalink
feat: added local oauth (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxOhn authored Jul 6, 2024
1 parent 78bb705 commit e86bbe7
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 34 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
- kind: default-features
features: default
- kind: full-features
features: cache,macros,metrics,replay,serialize
features: cache,macros,metrics,replay,serialize,local_oauth

steps:
- name: Checkout project
Expand Down Expand Up @@ -100,7 +100,7 @@ jobs:
cargo hack check
--feature-powerset --no-dev-deps
--optional-deps metrics
--group-features default,cache,macros,deny_unknown_fields
--group-features default,cache,macros,local_oauth,deny_unknown_fields
readme:
name: Readme
Expand Down Expand Up @@ -150,4 +150,4 @@ jobs:
--filter-expr 'not binary(requests)'
- name: Run doctests
run: cargo test --doc --all-features
run: cargo test --doc --all-features
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cache = ["dashmap"]
macros = ["rosu-mods/macros"]
replay = ["osu-db"]
serialize = []
local_oauth = ["tokio/net"]
deny_unknown_fields = []

# --- Dependencies ---
Expand Down
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ async fn main() {

### Features

| Flag | Description | Dependencies
| ----------- | ---------------------------------------- | ------------
| `default` | Enable the `cache` and `macros` features |
| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`]
| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`]
| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization |
| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`]
| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`]
| Flag | Description | Dependencies
| ------------- | ---------------------------------------- | ------------
| `default` | Enable the `cache` and `macros` features |
| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`]
| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`]
| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization |
| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`]
| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`]
| `local_oauth` | Enables the method `OsuBuilder::with_local_authorization` to perform the full OAuth procedure | `tokio/net` feature

[osu!api v2]: https://osu.ppy.sh/docs/index.html
[`rosu`]: https://github.com/MaxOhn/rosu
Expand Down
64 changes: 49 additions & 15 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Authorization, AuthorizationKind, Osu, OsuRef, Token};
use super::{token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Token};
use crate::{error::OsuError, OsuResult};

use hyper::client::Builder;
Expand All @@ -17,7 +17,7 @@ use dashmap::DashMap;
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[must_use]
pub struct OsuBuilder {
auth_kind: Option<AuthorizationKind>,
auth: Option<AuthorizationBuilder>,
client_id: Option<u64>,
client_secret: Option<String>,
retries: usize,
Expand All @@ -26,10 +26,9 @@ pub struct OsuBuilder {
}

impl Default for OsuBuilder {
#[inline]
fn default() -> Self {
Self {
auth_kind: None,
auth: None,
client_id: None,
client_secret: None,
retries: 2,
Expand All @@ -41,12 +40,11 @@ impl Default for OsuBuilder {

impl OsuBuilder {
/// Create a new [`OsuBuilder`](crate::OsuBuilder)
#[inline]
pub fn new() -> Self {
Self::default()
}

/// Build an [`Osu`](crate::Osu) client.
/// Build an [`Osu`] client.
///
/// To build the client, the client id and secret are being used
/// to acquire a token from the API which expires after a certain time.
Expand All @@ -62,6 +60,17 @@ impl OsuBuilder {
let client_id = self.client_id.ok_or(OsuError::BuilderMissingId)?;
let client_secret = self.client_secret.ok_or(OsuError::BuilderMissingSecret)?;

let auth_kind = match self.auth {
Some(AuthorizationBuilder::Kind(kind)) => kind,
#[cfg(feature = "local_oauth")]
Some(AuthorizationBuilder::LocalOauth { redirect_uri }) => {
AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id)
.await
.map(AuthorizationKind::User)?
}
None => AuthorizationKind::default(),
};

let connector = HttpsConnectorBuilder::new()
.with_native_roots()
.https_or_http()
Expand All @@ -86,7 +95,7 @@ impl OsuBuilder {
http,
ratelimiter,
timeout: self.timeout,
auth_kind: self.auth_kind.unwrap_or_default(),
auth_kind,
token: RwLock::new(Token::default()),
retries: self.retries,
});
Expand Down Expand Up @@ -119,7 +128,6 @@ impl OsuBuilder {
/// Set the client id of the application.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[inline]
pub const fn client_id(mut self, client_id: u64) -> Self {
self.client_id = Some(client_id);

Expand All @@ -129,17 +137,45 @@ impl OsuBuilder {
/// Set the client secret of the application.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[inline]
pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
self.client_secret = Some(client_secret.into());

self
}

/// Upon calling [`build`], `rosu-v2` will print a url to authorize a local
/// osu! profile.
///
/// Be sure that the specified client id matches the OAuth application's
/// redirect uri.
///
/// If the authorization code has already been acquired, use
/// [`with_authorization`] instead.
///
/// For more info, check out
/// <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
///
/// [`build`]: OsuBuilder::build
/// [`with_authorization`]: OsuBuilder::with_authorization
#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
pub fn with_local_authorization(mut self, redirect_uri: impl Into<String>) -> Self {
self.auth = Some(AuthorizationBuilder::LocalOauth {
redirect_uri: redirect_uri.into(),
});

self
}

/// After acquiring the authorization code from a user through OAuth,
/// use this method to provide the given code, and specified redirect uri.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
/// To perform the full OAuth procedure for a local osu! profile, enable the
/// `local_oauth` feature and use `OsuBuilder::with_local_authorization`
/// instead.
///
/// For more info, check out
/// <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
pub fn with_authorization(
mut self,
code: impl Into<String>,
Expand All @@ -150,21 +186,21 @@ impl OsuBuilder {
redirect_uri: redirect_uri.into(),
};

self.auth_kind = Some(AuthorizationKind::User(authorization));
self.auth = Some(AuthorizationBuilder::Kind(AuthorizationKind::User(
authorization,
)));

self
}

/// In case the request times out, retry up to this many times, defaults to 2.
#[inline]
pub const fn retries(mut self, retries: usize) -> Self {
self.retries = retries;

self
}

/// Set the timeout for requests, defaults to 10 seconds.
#[inline]
pub const fn timeout(mut self, duration: Duration) -> Self {
self.timeout = duration;

Expand All @@ -177,8 +213,6 @@ impl OsuBuilder {
/// Check out the osu!api's [terms of use] for acceptable values.
///
/// [terms of use]: https://osu.ppy.sh/docs/index.html#terms-of-use
#[inline]
pub fn ratelimit(mut self, reqs_per_sec: u32) -> Self {
self.per_second = reqs_per_sec.clamp(1, 20);

Expand Down
124 changes: 124 additions & 0 deletions src/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,130 @@ fn adjust_token_expire(expires_in: i64) -> i64 {
expires_in - (expires_in as f64 * 0.05) as i64
}

pub(super) enum AuthorizationBuilder {
Kind(AuthorizationKind),
#[cfg(feature = "local_oauth")]
LocalOauth {
redirect_uri: String,
},
}

impl AuthorizationBuilder {
#[cfg(feature = "local_oauth")]
pub(super) async fn perform_local_oauth(
redirect_uri: String,
client_id: u64,
) -> Result<Authorization, crate::error::OAuthError> {
use std::{
fmt::Write,
io::{Error as IoError, ErrorKind},
str::from_utf8 as str_from_utf8,
};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
};

use crate::error::OAuthError;

let port: u16 = redirect_uri
.rsplit_once(':')
.and_then(|(prefix, suffix)| {
suffix
.split('/')
.next()
.filter(|_| prefix.ends_with("localhost"))
})
.map(str::parse)
.and_then(Result::ok)
.ok_or(OAuthError::Url)?;

let listener = TcpListener::bind(("localhost", port))
.await
.map_err(OAuthError::Listener)?;

let mut url = format!(
"https://osu.ppy.sh/oauth/authorize?\
client_id={client_id}\
&redirect_uri={redirect_uri}\
&response_type=code",
);

let mut scopes = [Scope::Identify, Scope::Public].iter();

if let Some(scope) = scopes.next() {
let _ = write!(url, "&scopes=%22{scope}");

for scope in scopes {
let _ = write!(url, "+{scope}");
}

url.push_str("%22");
}

println!("Authorize yourself through the following url:\n{url}");
info!("Awaiting manual authorization...");

let (mut stream, _) = listener.accept().await.map_err(OAuthError::Accept)?;
let mut data = Vec::new();

loop {
stream.readable().await.map_err(OAuthError::Read)?;

match stream.try_read_buf(&mut data) {
Ok(0) => break,
Ok(_) => {
if data.ends_with(b"\n\n") || data.ends_with(b"\r\n\r\n") {
break;
}
}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue,
Err(e) => return Err(OAuthError::Read(e)),
}
}

let code = str_from_utf8(&data)
.ok()
.and_then(|data| {
const KEY: &str = "code=";

if let Some(mut start) = data.find(KEY) {
start += KEY.len();

if let Some(end) = data[start..].find(char::is_whitespace) {
return Some(data[start..][..end].to_owned());
}
}

None
})
.ok_or(OAuthError::NoCode { data })?;

info!("Authorization succeeded");

#[allow(clippy::items_after_statements)]
async fn respond(stream: &mut TcpStream) -> Result<(), IoError> {
let response = b"HTTP/1.0 200 OK
Content-Type: text/html
<html><body>
<h2>rosu-v2 authentication succeeded</h2>
You may close this tab
</body></html>";

stream.writable().await?;
stream.write_all(response).await?;
stream.shutdown().await?;

Ok(())
}

respond(&mut stream).await.map_err(OAuthError::Write)?;

Ok(Authorization { code, redirect_uri })
}
}

pub(super) enum AuthorizationKind {
User(Authorization),
Client(Scope),
Expand Down
26 changes: 26 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@ use serde::Deserialize;
use serde_json::Error as SerdeError;
use std::fmt;

#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
#[error("failed to accept request")]
Accept(#[source] tokio::io::Error),
#[error("failed to create tcp listener")]
Listener(#[source] tokio::io::Error),
#[error("missing code in request")]
NoCode { data: Vec<u8> },
#[error("failed to read data")]
Read(#[source] tokio::io::Error),
#[error("redirect uri must contain localhost and a port number")]
Url,
#[error("failed to write data")]
Write(#[source] tokio::io::Error),
}

/// The API response was of the form `{ "error": ... }`
#[derive(Debug, Deserialize, thiserror::Error)]
pub struct ApiError {
Expand Down Expand Up @@ -60,6 +78,14 @@ pub enum OsuError {
This should only occur during an extended downtime of the osu!api."
)]
NoToken,
#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
/// Failed to perform OAuth
#[error("failed to perform oauth")]
OAuth {
#[from]
source: OAuthError,
},
#[cfg(feature = "replay")]
#[cfg_attr(docsrs, doc(cfg(feature = "replay")))]
/// There was an error while trying to use osu-db
Expand Down
Loading

0 comments on commit e86bbe7

Please sign in to comment.