Skip to content

Commit

Permalink
refactor: authorization
Browse files Browse the repository at this point in the history
  • Loading branch information
henry40408 committed Jan 26, 2025
1 parent 9de9658 commit 7f170d1
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 56 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ base64 = "0.22.0"
bcrypt = "0.16.0"
chrono = "0.4.35"
clap = { version = "4.5.2", features = ["derive", "env"] }
http = "1.2.0"
imsz = "0.3.1"
parking_lot = "0.12.1"
rand = "0.8.5"
Expand Down
107 changes: 51 additions & 56 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum Commands {
pub enum MyError {
#[error("bcrypt error: {0}")]
Bcrypt(#[from] bcrypt::BcryptError),
#[error("decode error: {0}")]
Decode(#[from] base64::DecodeError),
#[error("directory is empty: {0}")]
EmptyDirectory(PathBuf),
#[error("imsz error: {0}")]
Expand All @@ -101,6 +103,10 @@ pub enum MyError {
PasswordMismatched,
#[error("failed to strip prefix")]
StripPrefixError(#[from] path::StripPrefixError),
#[error("failed to convert to string: {0}")]
ToString(#[from] http::header::ToStrError),
#[error("failed to convert to UTF-8: {0}")]
UTF8(#[from] std::string::FromUtf8Error),
}

type MyResult<T> = Result<T, MyError>;
Expand Down Expand Up @@ -292,65 +298,53 @@ enum AuthState {
Failed,
}

fn authenticate(request: &Request) -> AuthState {
let expected = match get_expected_credentials() {
None => return AuthState::Public,
fn authenticate(request: &Request) -> Result<AuthState, MyError> {
let (expected_username, expected_password) = match get_expected_credentials() {
None => return Ok(AuthState::Public),
Some(e) => e,
};

let header_value = request.headers().get("authorization");
if header_value.is_none() {
return AuthState::Request;
let header_value = match request.headers().get(header::AUTHORIZATION) {
None => return Ok(AuthState::Request),
Some(v) => v,
};
let header_str = header_value.to_str()?;
let parts: Vec<&str> = header_str.split_ascii_whitespace().collect();
let digest = match (parts.first().map(|s| s.to_ascii_lowercase()), parts.get(1)) {
(Some(scheme), Some(digest)) if scheme == "basic" => digest,
_ => return Ok(AuthState::Failed),
};
let decoded = BASE64_ENGINE.decode(digest)?;
let decoded_str = String::from_utf8(decoded)?;
let actual: Vec<String> = decoded_str.split(':').map(String::from).collect();
let (username, password) = match (actual.first(), actual.get(1)) {
(Some(u), Some(p)) if &**u == &*expected_username => (u, p),
_ => return Ok(AuthState::Failed),
};
match (
username == &*expected_username,
bcrypt::verify(password, &expected_password),
) {
(true, Ok(true)) => Ok(AuthState::Success),
(true, Ok(false)) => Ok(AuthState::Failed),
(true, Err(err)) => {
error!(?err, "failed to verify password");
Err(MyError::Bcrypt(err))
}
(false, _) => Ok(AuthState::Failed),
}

header_value
.and_then(|v| v.to_str().ok())
.map(|s| s.split_ascii_whitespace().collect::<Vec<&str>>())
.and_then(|splitted| {
match (
&splitted.first().map(|s| s.to_ascii_lowercase()),
splitted.get(1).copied(),
) {
(Some(scheme), Some(digest)) if scheme == "basic" => Some(digest),
_ => None,
}
})
.and_then(|digest| BASE64_ENGINE.decode(digest).ok())
.and_then(|decoded| String::from_utf8(decoded).ok())
.map(|decoded| {
decoded
.split(':')
.map(String::from)
.collect::<Vec<String>>()
})
.map_or_else(
|| AuthState::Failed,
|splitted| match (splitted.first(), splitted.get(1)) {
(Some(u), Some(p)) if &*u == &*expected.0 => bcrypt::verify(p, &expected.1)
.map_err(|err| error!(?err, "failed to verify password"))
.ok()
.map_or_else(
|| AuthState::Failed,
|matched| {
if matched {
AuthState::Success
} else {
AuthState::Failed
}
},
),
_ => AuthState::Failed,
},
)
}

async fn auth_middleware_fn(request: Request, next: Next) -> impl IntoResponse {
match authenticate(&request) {
AuthState::Public | AuthState::Success => next.run(request).await,
AuthState::Failed => StatusCode::UNAUTHORIZED.into_response(),
AuthState::Request => {
Ok(AuthState::Public | AuthState::Success) => next.run(request).await,
Ok(AuthState::Failed) => StatusCode::UNAUTHORIZED.into_response(),
Ok(AuthState::Request) => {
(StatusCode::UNAUTHORIZED, WWW_AUTHENTICATE_HEADER, "").into_response()
}
Err(err) => {
error!(%err, "failed to authenticate");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}

Expand Down Expand Up @@ -508,11 +502,12 @@ pub fn init_route(cli: &Cli, tx: Sender<()>) -> MyResult<Router> {
let data_dir = &cli.data_dir;

let seed = cli.seed.unwrap_or_else(|| {
warn!("no seed provided, use seconds since UNIX epoch as seed");
SystemTime::now()
let seed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs()
.as_secs();
warn!(%seed, "no seed provided, use seconds since UNIX epoch as seed");
seed
});
let state = AppState {
data_dir: data_dir.clone(),
Expand Down Expand Up @@ -554,14 +549,14 @@ pub fn init_route(cli: &Cli, tx: Sender<()>) -> MyResult<Router> {
}
};

let books = &new_scan.books.len();
let pages = &new_scan.pages_map.len();
let total_books = &new_scan.books.len();
let total_pages = &new_scan.pages_map.len();
let duration = new_scan
.scan_duration
.to_std()
.map(|d| format!("{d:?}"))
.unwrap_or(String::new());
info!(books, pages, %duration, "initial scan finished");
info!(total_books, total_pages, %duration, "initial scan finished");

*state.scan.lock() = Some(new_scan);
}
Expand Down

0 comments on commit 7f170d1

Please sign in to comment.