diff --git a/src/client.rs b/src/client.rs index 6bc0a8b8..f3a9347d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -33,6 +33,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::RwLock; use tracing::{debug, trace, warn}; const MIME_TYPES_DISTRIBUTION_MANIFEST: &[&str] = &[ @@ -205,6 +206,8 @@ impl TryFrom for ConfigFile { #[derive(Clone)] pub struct Client { config: Arc, + // Registry -> RegistryAuth + auth_store: Arc>>, tokens: TokenCache, client: reqwest::Client, push_chunk_size: usize, @@ -213,9 +216,10 @@ pub struct Client { impl Default for Client { fn default() -> Self { Self { - config: Arc::new(ClientConfig::default()), - tokens: TokenCache::new(), - client: reqwest::Client::new(), + config: Arc::default(), + auth_store: Arc::default(), + tokens: TokenCache::default(), + client: reqwest::Client::default(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, } } @@ -257,9 +261,9 @@ impl TryFrom for Client { Ok(Self { config: Arc::new(config), - tokens: TokenCache::new(), client: client_builder.build()?, push_chunk_size: PUSH_CHUNK_MAX_SIZE, + ..Default::default() }) } } @@ -271,10 +275,8 @@ impl Client { warn!("Cannot create OCI client from config: {:?}", err); warn!("Creating client with default configuration"); Self { - config: Arc::new(ClientConfig::default()), - tokens: TokenCache::new(), - client: reqwest::Client::new(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, + ..Default::default() } }) } @@ -284,6 +286,41 @@ impl Client { Self::new(config_source.client_config()) } + async fn store_auth(&self, registry: &str, auth: RegistryAuth) { + self.auth_store + .write() + .await + .insert(registry.to_string(), auth); + } + + async fn is_stored_auth(&self, registry: &str) -> bool { + self.auth_store.read().await.contains_key(registry) + } + + async fn store_auth_if_needed(&self, registry: &str, auth: &RegistryAuth) { + if !self.is_stored_auth(registry).await { + self.store_auth(registry, auth.clone()).await; + } + } + + /// Checks if we got a token, if we don't - create it and store it in cache. + async fn get_auth_token( + &self, + reference: &Reference, + op: RegistryOperation, + ) -> Option { + let registry = reference.resolve_registry(); + let auth = self.auth_store.read().await.get(registry)?.clone(); + match self.tokens.get(reference, op).await { + Some(token) => Some(token), + None => { + let token = self._auth(reference, &auth, op).await.ok()??; + self.tokens.insert(reference, op, token.clone()).await; + Some(token) + } + } + } + /// Fetches the available Tags for the given Reference /// /// The client will check if it's already been authenticated and if @@ -298,9 +335,8 @@ impl Client { let op = RegistryOperation::Pull; let url = self.to_list_tags_url(image); - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let request = self.client.get(&url); let request = if let Some(num) = n { @@ -342,10 +378,8 @@ impl Client { accepted_media_types: Vec<&str>, ) -> Result { debug!("Pulling image: {:?}", image); - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let (manifest, digest, config) = self._pull_manifest_and_config(image).await?; @@ -400,10 +434,8 @@ impl Client { manifest: Option, ) -> Result { debug!("Pushing image: {:?}", image_ref); - let op = RegistryOperation::Push; - if !self.tokens.contains_key(image_ref, op).await { - self.auth(image_ref, auth, op).await?; - } + self.store_auth_if_needed(image_ref.resolve_registry(), auth) + .await; let manifest: OciImageManifest = match manifest { Some(m) => m, @@ -502,6 +534,38 @@ impl Client { authentication: &RegistryAuth, operation: RegistryOperation, ) -> Result> { + self.store_auth_if_needed(image.resolve_registry(), authentication) + .await; + // preserve old caching behavior + match self._auth(image, authentication, operation).await { + Ok(Some(RegistryTokenType::Bearer(token))) => { + self.tokens + .insert(image, operation, RegistryTokenType::Bearer(token.clone())) + .await; + Ok(Some(token.token().to_string())) + } + Ok(Some(RegistryTokenType::Basic(username, password))) => { + self.tokens + .insert( + image, + operation, + RegistryTokenType::Basic(username, password), + ) + .await; + Ok(None) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } + + /// Internal auth that retrieves token. + async fn _auth( + &self, + image: &Reference, + authentication: &RegistryAuth, + operation: RegistryOperation, + ) -> Result> { debug!("Authorizing for image: {:?}", image); // The version request will tell us where to go. let url = format!( @@ -521,13 +585,10 @@ impl Client { Err(e) => { debug!(error = ?e, "Falling back to HTTP Basic Auth"); if let RegistryAuth::Basic(username, password) = authentication { - self.tokens - .insert( - image, - operation, - RegistryTokenType::Basic(username.to_string(), password.to_string()), - ) - .await; + return Ok(Some(RegistryTokenType::Basic( + username.to_string(), + password.to_string(), + ))); } return Ok(None); } @@ -566,11 +627,7 @@ impl Client { let token: RegistryToken = serde_json::from_str(&text) .map_err(|e| OciDistributionError::RegistryTokenDecodeError(e.to_string()))?; debug!("Successfully authorized for image '{:?}'", image); - let oauth_token = token.token().to_string(); - self.tokens - .insert(image, operation, RegistryTokenType::Bearer(token)) - .await; - Ok(Some(oauth_token)) + Ok(Some(RegistryTokenType::Bearer(token))) } _ => { let reason = auth_res.text().await?; @@ -593,10 +650,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let url = self.to_v2_manifest_url(image); debug!("HEAD image manifest from {}", url); @@ -670,10 +725,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_image_manifest(image).await } @@ -690,10 +743,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciManifest, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_manifest(image).await } @@ -811,10 +862,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_manifest_and_config(image) .await @@ -855,7 +904,8 @@ impl Client { auth: &RegistryAuth, manifest: OciImageIndex, ) -> Result { - self.auth(reference, auth, RegistryOperation::Push).await?; + self.store_auth_if_needed(reference.resolve_registry(), auth) + .await; self.push_manifest(reference, &OciManifest::ImageIndex(manifest)) .await } @@ -1418,7 +1468,7 @@ impl<'a> RequestBuilderWrapper<'a> { ) -> Result { let mut headers = HeaderMap::new(); - if let Some(token) = self.client.tokens.get(image, op).await { + if let Some(token) = self.client.get_auth_token(image, op).await { match token { RegistryTokenType::Bearer(token) => { debug!("Using bearer token authentication."); @@ -1816,6 +1866,14 @@ mod test { .as_str() .to_string(); + // we have to have it in the stored auth so we'll get to the token cache check. + client + .store_auth( + &Reference::try_from(HELLO_IMAGE_TAG)?.resolve_registry(), + RegistryAuth::Anonymous, + ) + .await; + client .tokens .insert( diff --git a/src/token_cache.rs b/src/token_cache.rs index 2b682801..b4f153d1 100644 --- a/src/token_cache.rs +++ b/src/token_cache.rs @@ -59,21 +59,25 @@ pub enum RegistryOperation { Pull, } -type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>; +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct TokenCacheKey { + registry: String, + repository: String, + operation: RegistryOperation, +} + +struct TokenCacheValue { + token: RegistryTokenType, + expiration: u64, +} #[derive(Default, Clone)] pub(crate) struct TokenCache { // (registry, repository, scope) -> (token, expiration) - tokens: Arc>, + tokens: Arc>>, } impl TokenCache { - pub(crate) fn new() -> Self { - TokenCache { - tokens: Arc::new(RwLock::new(BTreeMap::new())), - } - } - pub(crate) async fn insert( &self, reference: &Reference, @@ -119,10 +123,14 @@ impl TokenCache { let registry = reference.resolve_registry().to_string(); let repository = reference.repository().to_string(); debug!(%registry, %repository, ?op, %expiration, "Inserting token"); - self.tokens - .write() - .await - .insert((registry, repository, op), (token, expiration)); + self.tokens.write().await.insert( + TokenCacheKey { + registry, + repository, + operation: op, + }, + TokenCacheValue { token, expiration }, + ); } pub(crate) async fn get( @@ -132,34 +140,33 @@ impl TokenCache { ) -> Option { let registry = reference.resolve_registry().to_string(); let repository = reference.repository().to_string(); - match self - .tokens - .read() - .await - .get(&(registry.clone(), repository.clone(), op)) - { - Some((ref token, expiration)) => { + let key = TokenCacheKey { + registry, + repository, + operation: op, + }; + match self.tokens.read().await.get(&key) { + Some(TokenCacheValue { + ref token, + expiration, + }) => { let now = SystemTime::now(); let epoch = now .duration_since(UNIX_EPOCH) .expect("Time went backwards") .as_secs(); if epoch > *expiration { - debug!(%registry, %repository, ?op, %expiration, miss=false, expired=true, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token"); None } else { - debug!(%registry, %repository, ?op, %expiration, miss=false, expired=false, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token"); Some(token.clone()) } } None => { - debug!(%registry, %repository, ?op, miss=true, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token"); None } } } - - pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool { - self.get(reference, op).await.is_some() - } }