Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added local oauth #29

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
- kind: default-features
features: default
- kind: full-features
features: cache,macros,metrics,replay,serialize
features: cache,macros,metrics,replay,serialize,local_oauth

steps:
- name: Checkout project
Expand Down Expand Up @@ -100,7 +100,7 @@ jobs:
cargo hack check
--feature-powerset --no-dev-deps
--optional-deps metrics
--group-features default,cache,macros,deny_unknown_fields
--group-features default,cache,macros,local_oauth,deny_unknown_fields

readme:
name: Readme
Expand Down Expand Up @@ -150,4 +150,4 @@ jobs:
--filter-expr 'not binary(requests)'

- name: Run doctests
run: cargo test --doc --all-features
run: cargo test --doc --all-features
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cache = ["dashmap"]
macros = ["rosu-mods/macros"]
replay = ["osu-db"]
serialize = []
local_oauth = ["tokio/net"]
deny_unknown_fields = []

# --- Dependencies ---
Expand Down
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ async fn main() {

### Features

| Flag | Description | Dependencies
| ----------- | ---------------------------------------- | ------------
| `default` | Enable the `cache` and `macros` features |
| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`]
| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`]
| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization |
| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`]
| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`]
| Flag | Description | Dependencies
| ------------- | ---------------------------------------- | ------------
| `default` | Enable the `cache` and `macros` features |
| `cache` | Cache username-user_id pairs so that usernames can be used on all user endpoints instead of only user ids | [`dashmap`]
| `macros` | Re-exports `rosu-mods`'s `mods!` macro to easily create mods for a given mode | [`paste`]
| `serialize` | Implement `serde::Serialize` for most types, allowing for manual serialization |
| `metrics` | Uses the global metrics registry to store response time for each endpoint | [`metrics`]
| `replay` | Enables the method `Osu::replay` to parse a replay. Note that `Osu::replay_raw` is available without this feature but provides raw bytes instead of a parsed replay | [`osu-db`]
| `local_oauth` | Enables the method `OsuBuilder::with_local_authorization` to perform the full OAuth procedure | `tokio/net` feature

[osu!api v2]: https://osu.ppy.sh/docs/index.html
[`rosu`]: https://github.com/MaxOhn/rosu
Expand Down
64 changes: 49 additions & 15 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Authorization, AuthorizationKind, Osu, OsuRef, Token};
use super::{token::AuthorizationBuilder, Authorization, AuthorizationKind, Osu, OsuRef, Token};
use crate::{error::OsuError, OsuResult};

use hyper::client::Builder;
Expand All @@ -17,7 +17,7 @@ use dashmap::DashMap;
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[must_use]
pub struct OsuBuilder {
auth_kind: Option<AuthorizationKind>,
auth: Option<AuthorizationBuilder>,
client_id: Option<u64>,
client_secret: Option<String>,
retries: usize,
Expand All @@ -26,10 +26,9 @@ pub struct OsuBuilder {
}

impl Default for OsuBuilder {
#[inline]
fn default() -> Self {
Self {
auth_kind: None,
auth: None,
client_id: None,
client_secret: None,
retries: 2,
Expand All @@ -41,12 +40,11 @@ impl Default for OsuBuilder {

impl OsuBuilder {
/// Create a new [`OsuBuilder`](crate::OsuBuilder)
#[inline]
pub fn new() -> Self {
Self::default()
}

/// Build an [`Osu`](crate::Osu) client.
/// Build an [`Osu`] client.
///
/// To build the client, the client id and secret are being used
/// to acquire a token from the API which expires after a certain time.
Expand All @@ -62,6 +60,17 @@ impl OsuBuilder {
let client_id = self.client_id.ok_or(OsuError::BuilderMissingId)?;
let client_secret = self.client_secret.ok_or(OsuError::BuilderMissingSecret)?;

let auth_kind = match self.auth {
Some(AuthorizationBuilder::Kind(kind)) => kind,
#[cfg(feature = "local_oauth")]
Some(AuthorizationBuilder::LocalOauth { redirect_uri }) => {
AuthorizationBuilder::perform_local_oauth(redirect_uri, client_id)
.await
.map(AuthorizationKind::User)?
}
None => AuthorizationKind::default(),
};

let connector = HttpsConnectorBuilder::new()
.with_native_roots()
.https_or_http()
Expand All @@ -86,7 +95,7 @@ impl OsuBuilder {
http,
ratelimiter,
timeout: self.timeout,
auth_kind: self.auth_kind.unwrap_or_default(),
auth_kind,
token: RwLock::new(Token::default()),
retries: self.retries,
});
Expand Down Expand Up @@ -119,7 +128,6 @@ impl OsuBuilder {
/// Set the client id of the application.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[inline]
pub const fn client_id(mut self, client_id: u64) -> Self {
self.client_id = Some(client_id);

Expand All @@ -129,17 +137,45 @@ impl OsuBuilder {
/// Set the client secret of the application.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#client-credentials-grant>
#[inline]
pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
self.client_secret = Some(client_secret.into());

self
}

/// Upon calling [`build`], `rosu-v2` will print a url to authorize a local
/// osu! profile.
///
/// Be sure that the specified client id matches the OAuth application's
/// redirect uri.
///
/// If the authorization code has already been acquired, use
/// [`with_authorization`] instead.
///
/// For more info, check out
/// <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
///
/// [`build`]: OsuBuilder::build
/// [`with_authorization`]: OsuBuilder::with_authorization
#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
pub fn with_local_authorization(mut self, redirect_uri: impl Into<String>) -> Self {
self.auth = Some(AuthorizationBuilder::LocalOauth {
redirect_uri: redirect_uri.into(),
});

self
}

/// After acquiring the authorization code from a user through OAuth,
/// use this method to provide the given code, and specified redirect uri.
///
/// For more info, check out <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
/// To perform the full OAuth procedure for a local osu! profile, enable the
/// `local_oauth` feature and use `OsuBuilder::with_local_authorization`
/// instead.
///
/// For more info, check out
/// <https://osu.ppy.sh/docs/index.html#authorization-code-grant>
pub fn with_authorization(
mut self,
code: impl Into<String>,
Expand All @@ -150,21 +186,21 @@ impl OsuBuilder {
redirect_uri: redirect_uri.into(),
};

self.auth_kind = Some(AuthorizationKind::User(authorization));
self.auth = Some(AuthorizationBuilder::Kind(AuthorizationKind::User(
authorization,
)));

self
}

/// In case the request times out, retry up to this many times, defaults to 2.
#[inline]
pub const fn retries(mut self, retries: usize) -> Self {
self.retries = retries;

self
}

/// Set the timeout for requests, defaults to 10 seconds.
#[inline]
pub const fn timeout(mut self, duration: Duration) -> Self {
self.timeout = duration;

Expand All @@ -177,8 +213,6 @@ impl OsuBuilder {
/// Check out the osu!api's [terms of use] for acceptable values.
///
/// [terms of use]: https://osu.ppy.sh/docs/index.html#terms-of-use

#[inline]
pub fn ratelimit(mut self, reqs_per_sec: u32) -> Self {
self.per_second = reqs_per_sec.clamp(1, 20);

Expand Down
124 changes: 124 additions & 0 deletions src/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,130 @@ fn adjust_token_expire(expires_in: i64) -> i64 {
expires_in - (expires_in as f64 * 0.05) as i64
}

pub(super) enum AuthorizationBuilder {
Kind(AuthorizationKind),
#[cfg(feature = "local_oauth")]
LocalOauth {
redirect_uri: String,
},
}

impl AuthorizationBuilder {
#[cfg(feature = "local_oauth")]
pub(super) async fn perform_local_oauth(
redirect_uri: String,
client_id: u64,
) -> Result<Authorization, crate::error::OAuthError> {
use std::{
fmt::Write,
io::{Error as IoError, ErrorKind},
str::from_utf8 as str_from_utf8,
};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
};

use crate::error::OAuthError;

let port: u16 = redirect_uri
.rsplit_once(':')
.and_then(|(prefix, suffix)| {
suffix
.split('/')
.next()
.filter(|_| prefix.ends_with("localhost"))
})
.map(str::parse)
.and_then(Result::ok)
.ok_or(OAuthError::Url)?;

let listener = TcpListener::bind(("localhost", port))
.await
.map_err(OAuthError::Listener)?;

let mut url = format!(
"https://osu.ppy.sh/oauth/authorize?\
client_id={client_id}\
&redirect_uri={redirect_uri}\
&response_type=code",
);

let mut scopes = [Scope::Identify, Scope::Public].iter();

if let Some(scope) = scopes.next() {
let _ = write!(url, "&scopes=%22{scope}");

for scope in scopes {
let _ = write!(url, "+{scope}");
}

url.push_str("%22");
}

println!("Authorize yourself through the following url:\n{url}");
info!("Awaiting manual authorization...");

let (mut stream, _) = listener.accept().await.map_err(OAuthError::Accept)?;
let mut data = Vec::new();

loop {
stream.readable().await.map_err(OAuthError::Read)?;

match stream.try_read_buf(&mut data) {
Ok(0) => break,
Ok(_) => {
if data.ends_with(b"\n\n") || data.ends_with(b"\r\n\r\n") {
break;
}
}
Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue,
Err(e) => return Err(OAuthError::Read(e)),
}
}

let code = str_from_utf8(&data)
.ok()
.and_then(|data| {
const KEY: &str = "code=";

if let Some(mut start) = data.find(KEY) {
start += KEY.len();

if let Some(end) = data[start..].find(char::is_whitespace) {
return Some(data[start..][..end].to_owned());
}
}

None
})
.ok_or(OAuthError::NoCode { data })?;

info!("Authorization succeeded");

#[allow(clippy::items_after_statements)]
async fn respond(stream: &mut TcpStream) -> Result<(), IoError> {
let response = b"HTTP/1.0 200 OK
Content-Type: text/html

<html><body>
<h2>rosu-v2 authentication succeeded</h2>
You may close this tab
</body></html>";

stream.writable().await?;
stream.write_all(response).await?;
stream.shutdown().await?;

Ok(())
}

respond(&mut stream).await.map_err(OAuthError::Write)?;

Ok(Authorization { code, redirect_uri })
}
}

pub(super) enum AuthorizationKind {
User(Authorization),
Client(Scope),
Expand Down
26 changes: 26 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@ use serde::Deserialize;
use serde_json::Error as SerdeError;
use std::fmt;

#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
#[error("failed to accept request")]
Accept(#[source] tokio::io::Error),
#[error("failed to create tcp listener")]
Listener(#[source] tokio::io::Error),
#[error("missing code in request")]
NoCode { data: Vec<u8> },
#[error("failed to read data")]
Read(#[source] tokio::io::Error),
#[error("redirect uri must contain localhost and a port number")]
Url,
#[error("failed to write data")]
Write(#[source] tokio::io::Error),
}

/// The API response was of the form `{ "error": ... }`
#[derive(Debug, Deserialize, thiserror::Error)]
pub struct ApiError {
Expand Down Expand Up @@ -60,6 +78,14 @@ pub enum OsuError {
This should only occur during an extended downtime of the osu!api."
)]
NoToken,
#[cfg(feature = "local_oauth")]
#[cfg_attr(docsrs, doc(cfg(feature = "local_oauth")))]
/// Failed to perform OAuth
#[error("failed to perform oauth")]
OAuth {
#[from]
source: OAuthError,
},
#[cfg(feature = "replay")]
#[cfg_attr(docsrs, doc(cfg(feature = "replay")))]
/// There was an error while trying to use osu-db
Expand Down
Loading