diff --git a/src/client.rs b/src/client.rs index d4447e37..5fe5cc20 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,8 +19,7 @@ use crate::Reference; use crate::errors::{OciDistributionError, Result}; use crate::token_cache::{RegistryOperation, RegistryToken, RegistryTokenType, TokenCache}; use futures_util::future; -use futures_util::stream::{self, StreamExt, TryStreamExt}; -use futures_util::Stream; +use futures_util::stream::{self, Stream, StreamExt, TryStreamExt}; use http::HeaderValue; use http_auth::{parser::ChallengeParser, ChallengeRef}; use olpc_cjson::CanonicalFormatter; @@ -31,6 +30,7 @@ use serde::Serialize; use sha2::Digest; use std::collections::HashMap; use std::convert::TryFrom; +use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tracing::{debug, trace, warn}; @@ -201,8 +201,9 @@ impl TryFrom for ConfigFile { /// /// For true anonymous access, you can skip `auth()`. This is not recommended /// unless you are sure that the remote registry does not require Oauth2. +#[derive(Clone)] pub struct Client { - config: ClientConfig, + config: Arc, tokens: TokenCache, client: reqwest::Client, push_chunk_size: usize, @@ -211,7 +212,7 @@ pub struct Client { impl Default for Client { fn default() -> Self { Self { - config: ClientConfig::default(), + config: Arc::new(ClientConfig::default()), tokens: TokenCache::new(), client: reqwest::Client::new(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, @@ -254,7 +255,7 @@ impl TryFrom for Client { } Ok(Self { - config, + config: Arc::new(config), tokens: TokenCache::new(), client: client_builder.build()?, push_chunk_size: PUSH_CHUNK_MAX_SIZE, @@ -269,7 +270,7 @@ impl Client { warn!("Cannot create OCI client from config: {:?}", err); warn!("Creating client with default configuration"); Self { - config: ClientConfig::default(), + config: Arc::new(ClientConfig::default()), tokens: TokenCache::new(), client: reqwest::Client::new(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, @@ -287,7 +288,7 @@ impl Client { /// The client will check if it's already been authenticated and if /// not will attempt to do. pub async fn list_tags( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, n: Option, @@ -296,7 +297,7 @@ impl Client { let op = RegistryOperation::Pull; let url = self.to_list_tags_url(image); - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -316,7 +317,8 @@ impl Client { request_builder: request, }; let res = request - .apply_auth(image, op)? + .apply_auth(image, op) + .await? .into_request_builder() .send() .await?; @@ -333,14 +335,14 @@ impl Client { /// The client will check if it's already been authenticated and if /// not will attempt to do. pub async fn pull( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, accepted_media_types: Vec<&str>, ) -> Result { debug!("Pulling image: {:?}", image); let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -351,7 +353,7 @@ impl Client { let layers = stream::iter(&manifest.layers) .map(|layer| { - // This avoids moving `self` which is &mut Self + // This avoids moving `self` which is &Self // into the async block. We only want to capture // as &Self let this = &self; @@ -389,7 +391,7 @@ impl Client { /// /// Returns pullable URL for the image pub async fn push( - &mut self, + &self, image_ref: &Reference, layers: &[ImageLayer], config: Config, @@ -398,7 +400,7 @@ impl Client { ) -> Result { debug!("Pushing image: {:?}", image_ref); let op = RegistryOperation::Push; - if !self.tokens.contains_key(image_ref, op) { + if !self.tokens.contains_key(image_ref, op).await { self.auth(image_ref, auth, op).await?; } @@ -410,7 +412,7 @@ impl Client { // Upload layers stream::iter(layers) .map(|layer| { - // This avoids moving `self` which is &mut Self + // This avoids moving `self` which is &Self // into the async block. We only want to capture // as &Self let this = &self; @@ -489,12 +491,41 @@ impl Client { .await } + /// Pushes a blob to the registry as a stream, chunking it + /// upstream. + pub async fn push_blob_stream( + &self, + image: &Reference, + mut blob_stream: impl Stream> + + std::marker::Unpin, + blob_digest: &str, + ) -> Result { + let mut location = self.begin_push_chunked_session(image).await?; + let mut start: usize = 0; + let mut blob_data = Vec::new(); + let mut done: bool = false; + loop { + if let Some(bytes) = blob_stream.next().await { + blob_data.extend(&bytes?); + } else { + done = true; + } + // gonna break when push chunk finishes + (location, start) = self.push_chunk(&location, image, &blob_data, start).await?; + if done && start >= blob_data.len() { + break; + } + } + self.end_push_chunked_session(&location, image, blob_digest) + .await + } + /// Perform an OAuth v2 auth request if necessary. /// /// This performs authorization and then stores the token internally to be used /// on other requests. pub async fn auth( - &mut self, + &self, image: &Reference, authentication: &RegistryAuth, operation: RegistryOperation, @@ -518,11 +549,13 @@ impl Client { Err(e) => { debug!(error = ?e, "Falling back to HTTP Basic Auth"); if let RegistryAuth::Basic(username, password) = authentication { - self.tokens.insert( - image, - operation, - RegistryTokenType::Basic(username.to_string(), password.to_string()), - ); + self.tokens + .insert( + image, + operation, + RegistryTokenType::Basic(username.to_string(), password.to_string()), + ) + .await; } return Ok(None); } @@ -563,7 +596,8 @@ impl Client { debug!("Successfully authorized for image '{:?}'", image); let oauth_token = token.token().to_string(); self.tokens - .insert(image, operation, RegistryTokenType::Bearer(token)); + .insert(image, operation, RegistryTokenType::Bearer(token)) + .await; Ok(Some(oauth_token)) } _ => { @@ -583,12 +617,12 @@ impl Client { /// HEAD request. If this header is not present, will make a second GET /// request and return the SHA256 of the response body. pub async fn fetch_manifest_digest( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, ) -> Result { let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -596,7 +630,8 @@ impl Client { debug!("HEAD image manifest from {}", url); let res = RequestBuilderWrapper::from_client(self, |client| client.head(&url)) .apply_accept(MIME_TYPES_DISTRIBUTION_MANIFEST)? - .apply_auth(image, RegistryOperation::Pull)? + .apply_auth(image, RegistryOperation::Pull) + .await? .into_request_builder() .send() .await?; @@ -606,7 +641,8 @@ impl Client { debug!("GET image manifest from {}", url); let res = RequestBuilderWrapper::from_client(self, |client| client.get(&url)) .apply_accept(MIME_TYPES_DISTRIBUTION_MANIFEST)? - .apply_auth(image, RegistryOperation::Pull)? + .apply_auth(image, RegistryOperation::Pull) + .await? .into_request_builder() .send() .await?; @@ -658,12 +694,12 @@ impl Client { /// If a multi-platform Image Index manifest is encountered, a platform-specific /// Image manifest will be selected using the client's default platform resolution. pub async fn pull_image_manifest( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String)> { let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -678,12 +714,12 @@ impl Client { /// A Tuple is returned containing the [Manifest](crate::manifest::OciImageManifest) /// and the manifest content digest hash. pub async fn pull_manifest( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, ) -> Result<(OciManifest, String)> { let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -747,7 +783,8 @@ impl Client { let res = RequestBuilderWrapper::from_client(self, |client| client.get(&url)) .apply_accept(MIME_TYPES_DISTRIBUTION_MANIFEST)? - .apply_auth(image, RegistryOperation::Pull)? + .apply_auth(image, RegistryOperation::Pull) + .await? .into_request_builder() .send() .await?; @@ -798,12 +835,12 @@ impl Client { /// the manifest content digest hash and the contents of the manifests config layer /// as a String. pub async fn pull_manifest_and_config( - &mut self, + &self, image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String, String)> { let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op) { + if !self.tokens.contains_key(image, op).await { self.auth(image, auth, op).await?; } @@ -824,7 +861,7 @@ impl Client { } async fn _pull_manifest_and_config( - &mut self, + &self, image: &Reference, ) -> Result<(OciImageManifest, String, Config)> { let (manifest, digest) = self._pull_image_manifest(image).await?; @@ -842,7 +879,7 @@ impl Client { /// /// This pushes a manifest list to an OCI registry. pub async fn push_manifest_list( - &mut self, + &self, reference: &Reference, auth: &RegistryAuth, manifest: OciImageIndex, @@ -868,7 +905,8 @@ impl Client { let url = self.to_v2_blob_url(image.resolve_registry(), image.repository(), digest); let mut stream = RequestBuilderWrapper::from_client(self, |client| client.get(&url)) .apply_accept(MIME_TYPES_DISTRIBUTION_MANIFEST)? - .apply_auth(image, RegistryOperation::Pull)? + .apply_auth(image, RegistryOperation::Pull) + .await? .into_request_builder() .send() .await? @@ -894,7 +932,8 @@ impl Client { let url = self.to_v2_blob_url(image.resolve_registry(), image.repository(), digest); let stream = RequestBuilderWrapper::from_client(self, |client| client.get(&url)) .apply_accept(MIME_TYPES_DISTRIBUTION_MANIFEST)? - .apply_auth(image, RegistryOperation::Pull)? + .apply_auth(image, RegistryOperation::Pull) + .await? .into_request_builder() .send() .await? @@ -911,7 +950,8 @@ impl Client { let url = &self.to_v2_blob_upload_url(image); debug!(?url, "begin_push_monolithical_session"); let res = RequestBuilderWrapper::from_client(self, |client| client.post(url)) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .send() .await?; @@ -928,7 +968,8 @@ impl Client { let url = &self.to_v2_blob_upload_url(image); debug!(?url, "begin_push_session"); let res = RequestBuilderWrapper::from_client(self, |client| client.post(url)) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .header("Content-Length", 0) .send() @@ -951,7 +992,8 @@ impl Client { let url = Url::parse_with_params(location, &[("digest", digest)]) .map_err(|e| OciDistributionError::GenericError(Some(e.to_string())))?; let res = RequestBuilderWrapper::from_client(self, |client| client.put(url.clone())) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .header("Content-Length", 0) .send() @@ -986,7 +1028,8 @@ impl Client { headers.insert("Content-Type", "application/octet-stream".parse().unwrap()); let res = RequestBuilderWrapper::from_client(self, |client| client.put(&url)) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .headers(headers) .body(layer.to_vec()) @@ -1037,7 +1080,8 @@ impl Client { ); let res = RequestBuilderWrapper::from_client(self, |client| client.patch(location)) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .headers(headers) .body(body) @@ -1067,7 +1111,8 @@ impl Client { .map_err(|e| OciDistributionError::UrlParseError(e.to_string()))?; let res = RequestBuilderWrapper::from_client(self, |client| client.post(url.clone())) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .send() .await?; @@ -1117,7 +1162,8 @@ impl Client { let manifest_hash = sha256_digest(&body); let res = RequestBuilderWrapper::from_client(self, |client| client.put(url.clone())) - .apply_auth(image, RegistryOperation::Push)? + .apply_auth(image, RegistryOperation::Push) + .await? .into_request_builder() .headers(headers) .body(body) @@ -1343,14 +1389,14 @@ impl<'a> RequestBuilderWrapper<'a> { /// Authorization header. It will also set the Accept header, which must /// be set on all OCI Registry requests. If the struct has HTTP Basic Auth /// credentials, these will be configured. - fn apply_auth( + async fn apply_auth( &self, image: &Reference, op: RegistryOperation, ) -> Result { let mut headers = HeaderMap::new(); - if let Some(token) = self.client.tokens.get(image, op) { + if let Some(token) = self.client.tokens.get(image, op).await { match token { RegistryTokenType::Bearer(token) => { debug!("Using bearer token authentication."); @@ -1698,15 +1744,16 @@ mod test { Ok(()) } - #[test] - fn test_apply_auth_no_token() -> anyhow::Result<()> { + #[tokio::test] + async fn test_apply_auth_no_token() -> anyhow::Result<()> { assert!( !RequestBuilderWrapper::from_client(&Client::default(), |client| client .get("https://example.com/some/module.wasm")) .apply_auth( &Reference::try_from(HELLO_IMAGE_TAG)?, RegistryOperation::Pull - )? + ) + .await? .into_request_builder() .build()? .headers() @@ -1716,12 +1763,12 @@ mod test { Ok(()) } - #[test] - fn test_apply_auth_bearer_token() -> anyhow::Result<()> { + #[tokio::test] + async fn test_apply_auth_bearer_token() -> anyhow::Result<()> { use hmac::{Hmac, Mac}; use jwt::SignWithKey; use sha2::Sha256; - let mut client = Client::default(); + let client = Client::default(); let header = jwt::header::Header { algorithm: jwt::algorithm::AlgorithmType::Hs256, key_id: None, @@ -1735,20 +1782,24 @@ mod test { .as_str() .to_string(); - client.tokens.insert( - &Reference::try_from(HELLO_IMAGE_TAG)?, - RegistryOperation::Pull, - RegistryTokenType::Bearer(RegistryToken::Token { - token: token.clone(), - }), - ); + client + .tokens + .insert( + &Reference::try_from(HELLO_IMAGE_TAG)?, + RegistryOperation::Pull, + RegistryTokenType::Bearer(RegistryToken::Token { + token: token.clone(), + }), + ) + .await; assert_eq!( RequestBuilderWrapper::from_client(&client, |client| client .get("https://example.com/some/module.wasm")) .apply_auth( &Reference::try_from(HELLO_IMAGE_TAG)?, RegistryOperation::Pull - )? + ) + .await? .into_request_builder() .build()? .headers()["Authorization"], @@ -2029,7 +2080,7 @@ mod test { async fn test_auth() { for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); - let mut c = Client::default(); + let c = Client::default(); let token = c .auth( &reference, @@ -2045,6 +2096,7 @@ mod test { let tok = c .tokens .get(&reference, RegistryOperation::Pull) + .await .expect("token is available"); // We test that the token is longer than a minimal hash. if let RegistryTokenType::Bearer(tok) = tok { @@ -2064,7 +2116,7 @@ mod test { let auth = RegistryAuth::Basic(HTPASSWD_USERNAME.to_string(), HTPASSWD_PASSWORD.to_string()); - let mut client = Client::new(ClientConfig { + let client = Client::new(ClientConfig { protocol: ClientProtocol::HttpsExcept(vec![format!("localhost:{}", port)]), ..Default::default() }); @@ -2126,7 +2178,7 @@ mod test { .expect_err("pull manifest should fail"); // But this should pass - let mut c = Client::default(); + let c = Client::default(); c.auth( &reference, &RegistryAuth::Anonymous, @@ -2149,7 +2201,7 @@ mod test { async fn test_pull_manifest_public() { for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); - let mut c = Client::default(); + let c = Client::default(); let (manifest, _) = c .pull_image_manifest(&reference, &RegistryAuth::Anonymous) .await @@ -2165,7 +2217,7 @@ mod test { async fn pull_manifest_and_config_public() { for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); - let mut c = Client::default(); + let c = Client::default(); let (manifest, _, config) = c .pull_manifest_and_config(&reference, &RegistryAuth::Anonymous) .await @@ -2180,7 +2232,7 @@ mod test { #[tokio::test] async fn test_fetch_digest() { - let mut c = Client::default(); + let c = Client::default(); for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); @@ -2190,7 +2242,7 @@ mod test { // This should pass let reference = Reference::try_from(image).expect("failed to parse reference"); - let mut c = Client::default(); + let c = Client::default(); c.auth( &reference, &RegistryAuth::Anonymous, @@ -2212,7 +2264,7 @@ mod test { #[tokio::test] async fn test_pull_blob() { - let mut c = Client::default(); + let c = Client::default(); for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); @@ -2259,7 +2311,7 @@ mod test { #[tokio::test] async fn test_pull_blob_stream() { - let mut c = Client::default(); + let c = Client::default(); for &image in TEST_IMAGES { let reference = Reference::try_from(image).expect("failed to parse reference"); @@ -2395,7 +2447,7 @@ mod test { let test_container = docker.run(registry_image()); let port = test_container.get_host_port_ipv4(5000); - let mut c = Client::new(ClientConfig { + let c = Client::new(ClientConfig { protocol: ClientProtocol::Http, ..Default::default() }); @@ -2433,6 +2485,39 @@ mod test { assert_eq!(layer_location, format!("http://localhost:{}/v2/hello-wasm/blobs/sha256:6165c4ad43c0803798b6f2e49d6348c915d52c999a5f890846cee77ea65d230b", port)); } + #[tokio::test] + #[cfg(feature = "test-registry")] + async fn can_push_stream() { + let docker = clients::Cli::default(); + let test_container = docker.run(registry_image()); + let port = test_container.get_host_port_ipv4(5000); + + let c = Client::new(ClientConfig { + protocol: ClientProtocol::Http, + ..Default::default() + }); + let url = format!("localhost:{}/hello-wasm:v1", port); + let image: Reference = url.parse().unwrap(); + + c.auth(&image, &RegistryAuth::Anonymous, RegistryOperation::Push) + .await + .expect("result from auth request"); + + let image_data: Vec> = vec![b"iamawebassemblymodule".to_vec()]; + let digest = sha256_digest(&image_data[0]); + let layer_location = c + .push_blob_stream( + &image, + stream::iter(image_data) + .map(|chunk| Ok::<_, std::io::Error>(bytes::Bytes::from(chunk))), + &digest, + ) + .await + .expect("failed to blob stream push"); + + assert_eq!(layer_location, format!("http://localhost:{}/v2/hello-wasm/blobs/sha256:6165c4ad43c0803798b6f2e49d6348c915d52c999a5f890846cee77ea65d230b", port)); + } + #[tokio::test] #[cfg(feature = "test-registry")] async fn can_push_multiple_chunks() { @@ -2510,7 +2595,7 @@ mod test { let _ = tracing_subscriber::fmt::try_init(); let port = test_container.get_host_port_ipv4(5000); - let mut c = Client::new(ClientConfig { + let c = Client::new(ClientConfig { protocol: ClientProtocol::HttpsExcept(vec![format!("localhost:{}", port)]), ..Default::default() }); @@ -2617,7 +2702,7 @@ mod test { async fn test_platform_resolution() { // test that we get an error when we pull a manifest list let reference = Reference::try_from(DOCKER_IO_IMAGE).expect("failed to parse reference"); - let mut c = Client::new(ClientConfig { + let c = Client::new(ClientConfig { platform_resolver: None, ..Default::default() }); @@ -2647,7 +2732,7 @@ mod test { #[tokio::test] async fn test_pull_ghcr_io() { let reference = Reference::try_from(GHCR_IO_IMAGE).expect("failed to parse reference"); - let mut c = Client::default(); + let c = Client::default(); let (manifest, _manifest_str) = c .pull_image_manifest(&reference, &RegistryAuth::Anonymous) .await @@ -2659,7 +2744,7 @@ mod test { #[ignore] async fn test_roundtrip_multiple_layers() { let _ = tracing_subscriber::fmt::try_init(); - let mut c = Client::new(ClientConfig { + let c = Client::new(ClientConfig { protocol: ClientProtocol::HttpsExcept(vec!["oci.registry.local".to_string()]), ..Default::default() }); diff --git a/src/token_cache.rs b/src/token_cache.rs index cd1be726..2b682801 100644 --- a/src/token_cache.rs +++ b/src/token_cache.rs @@ -2,7 +2,9 @@ use crate::reference::Reference; use serde::Deserialize; use std::collections::BTreeMap; use std::fmt; +use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::RwLock; use tracing::{debug, warn}; /// A token granted during the OAuth2-like workflow for OCI registries. @@ -29,7 +31,7 @@ impl fmt::Debug for RegistryToken { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum RegistryTokenType { Bearer(RegistryToken), Basic(String, String), @@ -57,21 +59,23 @@ pub enum RegistryOperation { Pull, } -#[derive(Default)] +type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>; + +#[derive(Default, Clone)] pub(crate) struct TokenCache { // (registry, repository, scope) -> (token, expiration) - tokens: BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>, + tokens: Arc>, } impl TokenCache { pub(crate) fn new() -> Self { TokenCache { - tokens: BTreeMap::new(), + tokens: Arc::new(RwLock::new(BTreeMap::new())), } } - pub(crate) fn insert( - &mut self, + pub(crate) async fn insert( + &self, reference: &Reference, op: RegistryOperation, token: RegistryTokenType, @@ -116,17 +120,24 @@ impl TokenCache { let repository = reference.repository().to_string(); debug!(%registry, %repository, ?op, %expiration, "Inserting token"); self.tokens + .write() + .await .insert((registry, repository, op), (token, expiration)); } - pub(crate) fn get( + pub(crate) async fn get( &self, reference: &Reference, op: RegistryOperation, - ) -> Option<&RegistryTokenType> { + ) -> Option { let registry = reference.resolve_registry().to_string(); let repository = reference.repository().to_string(); - match self.tokens.get(&(registry.clone(), repository.clone(), op)) { + match self + .tokens + .read() + .await + .get(&(registry.clone(), repository.clone(), op)) + { Some((ref token, expiration)) => { let now = SystemTime::now(); let epoch = now @@ -138,7 +149,7 @@ impl TokenCache { None } else { debug!(%registry, %repository, ?op, %expiration, miss=false, expired=false, "Fetching token"); - Some(token) + Some(token.clone()) } } None => { @@ -148,7 +159,7 @@ impl TokenCache { } } - pub(crate) fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool { - self.get(reference, op).is_some() + pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool { + self.get(reference, op).await.is_some() } }