From 30105360331b8b3a14ef0fab5fdd5d52e760ed04 Mon Sep 17 00:00:00 2001
From: Monius
Date: Sun, 5 May 2024 08:22:57 +0800
Subject: [PATCH] rm comments

---
 hfd/src/api/tokio.rs | 162 ++-----------------------------------------
 hfd/src/lib.rs       |  55 +--------------
 2 files changed, 6 insertions(+), 211 deletions(-)

diff --git a/hfd/src/api/tokio.rs b/hfd/src/api/tokio.rs
index 2e67dfd..e95bd72 100644
--- a/hfd/src/api/tokio.rs
+++ b/hfd/src/api/tokio.rs
@@ -17,59 +17,44 @@ use thiserror::Error;
 use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
 use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
 
-/// Current version (used in user-agent)
 const VERSION: &str = env!("CARGO_PKG_VERSION");
-/// Current name (used in user-agent)
+
 const NAME: &str = env!("CARGO_PKG_NAME");
 
 #[derive(Debug, Error)]
-/// All errors the API can throw
+
 pub enum ApiError {
-    /// Api expects certain header to be present in the results to derive some information
     #[error("Header {0} is missing")]
     MissingHeader(HeaderName),
 
-    /// The header exists, but the value is not conform to what the Api expects.
     #[error("Header {0} is invalid")]
     InvalidHeader(HeaderName),
 
-    /// The value cannot be used as a header during request header construction
     #[error("Invalid header value {0}")]
     InvalidHeaderValue(#[from] InvalidHeaderValue),
 
-    /// The header value is not valid utf-8
     #[error("header value is not a string")]
     ToStr(#[from] ToStrError),
 
-    /// Error in the request
     #[error("request error: {0}")]
     RequestError(#[from] ReqwestError),
 
-    /// Error parsing some range value
    #[error("Cannot parse int")]
    ParseIntError(#[from] ParseIntError),
 
-    /// I/O Error
    #[error("I/O error {0}")]
    IoError(#[from] std::io::Error),
 
-    /// We tried to download chunk too many times
    #[error("Too many retries: {0}")]
    TooManyRetries(Box<ApiError>),
 
-    /// Semaphore cannot be acquired
    #[error("Try acquire: {0}")]
    TryAcquireError(#[from] TryAcquireError),
 
-    /// Semaphore cannot be acquired
    #[error("Acquire: {0}")]
    AcquireError(#[from] AcquireError),
-    // /// Semaphore cannot be acquired
-    // #[error("Invalid Response: {0:?}")]
-    // InvalidResponse(Response),
 }
 
-/// Helper to create [`Api`] with all the options.
 #[derive(Debug)]
 pub struct ApiBuilder {
     endpoint: String,
@@ -90,23 +75,11 @@ impl Default for ApiBuilder {
     }
 }
 
 impl ApiBuilder {
-    /// Default api builder
-    /// ```
-    /// use hf_hub::api::tokio::ApiBuilder;
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// ```
     pub fn new() -> Self {
         let cache = Cache::default();
         Self::from_cache(cache)
     }
 
-    /// From a given cache
-    /// ```
-    /// use hf_hub::{api::tokio::ApiBuilder, Cache};
-    /// let path = std::path::PathBuf::from("/tmp");
-    /// let cache = Cache::new(path);
-    /// let api = ApiBuilder::from_cache(cache).build().unwrap();
-    /// ```
     pub fn from_cache(cache: Cache) -> Self {
         let token = cache.token();
@@ -125,19 +98,16 @@ impl ApiBuilder {
         }
     }
 
-    /// Wether to show a progressbar
     pub fn with_progress(mut self, progress: bool) -> Self {
         self.progress = progress;
         self
     }
 
-    /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`.
     pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
         self.cache = Cache::new(cache_dir);
         self
     }
 
-    /// Sets the token to be used in the API
     pub fn with_token(mut self, token: Option<String>) -> Self {
         self.token = token;
         self
@@ -156,27 +126,21 @@ impl ApiBuilder {
         Ok(headers)
     }
 
-    /// Consumes the builder and builds the final [`Api`]
     pub fn build(self) -> Result<Api, ApiError> {
         let headers = self.build_headers()?;
         let client = Client::builder().default_headers(headers.clone()).build()?;
 
-        // Policy: only follow relative redirects
-        // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
         let relative_redirect_policy = Policy::custom(|attempt| {
-            // Follow redirects up to a maximum of 10.
             if attempt.previous().len() > 10 {
                 return attempt.error("too many redirects");
             }
 
             if let Some(last) = attempt.previous().last() {
-                // If the url is not relative
                 if last.make_relative(attempt.url()).is_none() {
                     return attempt.stop();
                 }
             }
 
-            // Follow redirect
             attempt.follow()
         });
 
@@ -206,9 +170,6 @@ struct Metadata {
     size: usize,
 }
 
-/// The actual Api used to interact with the hub.
-/// You can inspect repos with [`Api::info`]
-/// or download files with [`Api::download`]
 #[derive(Clone, Debug)]
 pub struct Api {
     endpoint: String,
@@ -239,9 +200,6 @@ fn make_relative(src: &Path, dst: &Path) -> PathBuf {
         match (ita.next(), itb.next()) {
             (Some(a), Some(b)) if a == b => (),
             (some_a, _) => {
-                // Ignoring b, because 1 component is the filename
-                // for which we don't need to go back up for relative
-                // filename to work.
                 let mut new_path = PathBuf::new();
                 for _ in itb {
                     new_path.push(Component::ParentDir);
                 }
@@ -286,13 +244,10 @@ fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
 }
 
 impl Api {
-    /// Creates a default Api, for Api options See [`ApiBuilder`]
     pub fn new() -> Result<Api, ApiError> {
         ApiBuilder::new().build()
     }
 
-    /// Get the underlying api client
-    /// Allows for lower level access
     pub fn client(&self) -> &Client {
         &self.client
     }
@@ -316,7 +271,7 @@ impl Api {
                 .get(&header_etag)
                 .ok_or(ApiError::MissingHeader(header_etag))?,
         };
-        // Cleaning extra quotes
+
         let etag = etag.to_str()?.to_string().replace('"', "");
         let commit_hash = headers
             .get(&header_commit)
@@ -324,8 +279,6 @@ impl Api {
             .ok_or(ApiError::MissingHeader(header_commit))?
             .to_str()?
             .to_string();
 
-        // The response was redirected o S3 most likely which will
-        // know about the size of the file
         let response = if response.status().is_redirection() {
             self.client
                 .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
@@ -353,47 +306,23 @@ impl Api {
         })
     }
 
-    /// Creates a new handle [`ApiRepo`] which contains operations
-    /// on a particular [`Repo`]
     pub fn repo(&self, repo: Repo) -> ApiRepo {
         ApiRepo::new(self.clone(), repo)
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let api = Api::new().unwrap();
-    /// let api = api.repo(Repo::new(model_id, RepoType::Model));
-    /// ```
     pub fn model(&self, model_id: String) -> ApiRepo {
         self.repo(Repo::new(model_id, RepoType::Model))
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let api = Api::new().unwrap();
-    /// let api = api.repo(Repo::new(model_id, RepoType::Dataset));
-    /// ```
     pub fn dataset(&self, model_id: String) -> ApiRepo {
         self.repo(Repo::new(model_id, RepoType::Dataset))
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let api = Api::new().unwrap();
-    /// let api = api.repo(Repo::new(model_id, RepoType::Space));
-    /// ```
     pub fn space(&self, model_id: String) -> ApiRepo {
         self.repo(Repo::new(model_id, RepoType::Space))
     }
 }
 
-/// Shorthand for accessing things within a particular repo
 #[derive(Debug)]
 pub struct ApiRepo {
     api: Api,
@@ -407,13 +336,6 @@ impl ApiRepo {
     }
 }
 impl ApiRepo {
-    /// Get the fully qualified URL of the remote filename
-    /// ```
-    /// # use hf_hub::api::tokio::Api;
-    /// let api = Api::new().unwrap();
-    /// let url = api.model("gpt2".to_string()).url("model.safetensors");
-    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
-    /// ```
     pub fn url(&self, filename: &str) -> String {
         let endpoint = &self.api.endpoint;
         let revision = &self.repo.url_revision();
@@ -436,7 +358,6 @@ impl ApiRepo {
         let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
         let filename = self.api.cache.temp_path();
 
-        // Create the file and set everything properly
         tokio::fs::File::create(&filename)
             .await?
             .set_len(length as u64)
@@ -482,7 +403,6 @@ impl ApiRepo {
             }));
         }
 
-        // Output the chained result
         let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
             futures::future::join_all(handles).await;
         let results: Result<(), ApiError> = results.into_iter().flatten().collect();
@@ -500,7 +420,6 @@ impl ApiRepo {
         start: usize,
         stop: usize,
     ) -> Result<(), ApiError> {
-        // Process each socket concurrently.
         let range = format!("bytes={start}-{stop}");
         let mut file = tokio::fs::OpenOptions::new()
             .write(true)
@@ -518,14 +437,6 @@ impl ApiRepo {
         Ok(())
     }
 
-    /// This will attempt the fetch the file locally first, then [`Api.download`]
-    /// if the file is not present.
-    /// ```no_run
-    /// # use hf_hub::api::tokio::Api;
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").await.unwrap();
-    /// # })
     pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
         if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
             Ok(path)
@@ -534,17 +445,6 @@ impl ApiRepo {
         } else {
             self.download(filename).await
         }
     }
 
-    /// Downloads a remote file (if not already present) into the cache directory
-    /// to be used locally.
-    /// This functions require internet access to verify if new versions of the file
-    /// exist, even if a file is already on disk at location.
-    /// ```no_run
-    /// # use hf_hub::api::tokio::Api;
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").await.unwrap();
-    /// # })
-    /// ```
     pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
         let url = self.url(filename);
         let metadata = self.api.metadata(&url).await?;
@@ -559,7 +459,7 @@ impl ApiRepo {
                 ProgressStyle::with_template(
                     "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
                 )
-                .unwrap(), // .progress_chars("━ "),
+                .unwrap(),
             );
             let maxlength = 30;
             let message = if filename.len() > maxlength {
@@ -589,67 +489,13 @@ impl ApiRepo {
         Ok(pointer_path)
     }
 
-    /// Get information about the Repo
-    /// ```
-    /// # use hf_hub::api::tokio::Api;
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// api.model("gpt2".to_string()).info();
-    /// # })
-    /// ```
     pub async fn info(&self) -> Result<RepoInfo, ApiError> {
         Ok(self.info_request().send().await?.json().await?)
     }
 
-    /// Get the raw [`reqwest::RequestBuilder`] with the url and method already set
-    /// ```
-    /// # use hf_hub::api::tokio::Api;
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// api.model("gpt2".to_owned())
-    ///     .info_request()
-    ///     .query(&[("blobs", "true")])
-    ///     .send()
-    ///     .await;
-    /// # })
-    /// ```
     pub fn info_request(&self) -> RequestBuilder {
         let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
         self.api.client.get(url)
     }
 }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::api::Siblings;
-    use hex_literal::hex;
-    use rand::distributions::Alphanumeric;
-    use serde_json::{json, Value};
-    use sha2::{Digest, Sha256};
-
-    struct TempDir {
-        path: PathBuf,
-    }
-
-    impl TempDir {
-        pub fn new() -> Self {
-            let s: String = rand::thread_rng()
-                .sample_iter(&Alphanumeric)
-                .take(7)
-                .map(char::from)
-                .collect();
-            let mut path = std::env::temp_dir();
-            path.push(s);
-            std::fs::create_dir(&path).unwrap();
-            Self { path }
-        }
-    }
-
-    impl Drop for TempDir {
-        fn drop(&mut self) {
-            std::fs::remove_dir_all(&self.path).unwrap();
-        }
-    }
-
-}
\ No newline at end of file
diff --git a/hfd/src/lib.rs b/hfd/src/lib.rs
index 89264cc..f873469 100644
--- a/hfd/src/lib.rs
+++ b/hfd/src/lib.rs
@@ -2,49 +2,36 @@ use rand::{distributions::Alphanumeric, Rng};
 use std::io::Write;
 use std::path::PathBuf;
 
-/// The actual Api to interact with the hub.
 pub mod api;
 
-/// The type of repo to interact with
 #[derive(Debug, Clone, Copy)]
 pub enum RepoType {
-    /// This is a model, usually it consists of weight files and some configuration
-    /// files
     Model,
-    /// This is a dataset, usually contains data within parquet files
     Dataset,
-    /// This is a space, usually a demo showcashing a given model or dataset
     Space,
 }
 
-/// A local struct used to fetch information from the cache folder.
 #[derive(Clone, Debug)]
 pub struct Cache {
     path: PathBuf,
 }
 
 impl Cache {
-    /// Creates a new cache object location
     pub fn new(path: PathBuf) -> Self {
         Self { path }
     }
 
-    /// Creates a new cache object location
     pub fn path(&self) -> &PathBuf {
         &self.path
     }
 
-    /// Returns the location of the token file
     pub fn token_path(&self) -> PathBuf {
         let mut path = self.path.clone();
-        // Remove `"hub"`
         path.pop();
         path.push("token");
         path
     }
 
-    /// Returns the token value if it exists in the cache
-    /// Use `huggingface-cli login` to set it up.
     pub fn token(&self) -> Option<String> {
         let token_filename = self.token_path();
         if !token_filename.exists() {
@@ -63,41 +50,18 @@ impl Cache {
         }
     }
 
-    /// Creates a new handle [`CacheRepo`] which contains operations
-    /// on a particular [`Repo`]
     pub fn repo(&self, repo: Repo) -> CacheRepo {
         CacheRepo::new(self.clone(), repo)
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{Cache, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let cache = Cache::new("/tmp/".into());
-    /// let cache = cache.repo(Repo::new(model_id, RepoType::Model));
-    /// ```
     pub fn model(&self, model_id: String) -> CacheRepo {
         self.repo(Repo::new(model_id, RepoType::Model))
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{Cache, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let cache = Cache::new("/tmp/".into());
-    /// let cache = cache.repo(Repo::new(model_id, RepoType::Dataset));
-    /// ```
     pub fn dataset(&self, model_id: String) -> CacheRepo {
         self.repo(Repo::new(model_id, RepoType::Dataset))
     }
 
-    /// Simple wrapper over
-    /// ```
-    /// # use hf_hub::{Cache, Repo, RepoType};
-    /// # let model_id = "gpt2".to_string();
-    /// let cache = Cache::new("/tmp/".into());
-    /// let cache = cache.repo(Repo::new(model_id, RepoType::Space));
-    /// ```
     pub fn space(&self, model_id: String) -> CacheRepo {
         self.repo(Repo::new(model_id, RepoType::Space))
     }
@@ -117,7 +81,6 @@ impl Cache {
     }
 }
 
-/// Shorthand for accessing things within a particular repo
 #[derive(Debug)]
 pub struct CacheRepo {
     cache: Cache,
@@ -128,8 +91,7 @@ impl CacheRepo {
     fn new(cache: Cache, repo: Repo) -> Self {
         Self { cache, repo }
     }
-    /// This will get the location of the file within the cache for the remote
-    /// `filename`. Will return `None` if file is not already present in cache.
+
     pub fn get(&self, filename: &str) -> Option<PathBuf> {
         let commit_path = self.ref_path();
         let commit_hash = std::fs::read_to_string(commit_path).ok()?;
@@ -155,11 +117,9 @@ impl CacheRepo {
         ref_path
     }
 
-    /// Creates a reference in the cache directory that points branches to the correct
-    /// commits within the blobs.
     pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
         let ref_path = self.ref_path();
-        // Needs to be done like this because revision might contain `/` creating subfolders here.
+
         std::fs::create_dir_all(ref_path.parent().unwrap())?;
         let mut file = std::fs::OpenOptions::new()
             .write(true)
@@ -201,7 +161,6 @@ impl Default for Cache {
     }
 }
 
-/// The representation of a repo on the hub.
 #[derive(Clone, Debug)]
 pub struct Repo {
     repo_id: String,
@@ -210,12 +169,10 @@ pub struct Repo {
 }
 
 impl Repo {
-    /// Repo with the default branch ("main").
     pub fn new(repo_id: String, repo_type: RepoType) -> Self {
         Self::with_revision(repo_id, repo_type, "main".to_string())
     }
 
-    /// fully qualified Repo
     pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
         Self {
             repo_id,
@@ -224,22 +181,18 @@ impl Repo {
         }
     }
 
-    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
     pub fn model(repo_id: String) -> Self {
         Self::new(repo_id, RepoType::Model)
     }
 
-    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
     pub fn dataset(repo_id: String) -> Self {
         Self::new(repo_id, RepoType::Dataset)
     }
 
-    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
     pub fn space(repo_id: String) -> Self {
         Self::new(repo_id, RepoType::Space)
     }
 
-    /// The normalized folder nameof the repo within the cache directory
     pub fn folder_name(&self) -> String {
         let prefix = match self.repo_type {
             RepoType::Model => "models",
@@ -249,12 +202,10 @@ impl Repo {
         format!("{prefix}--{}", self.repo_id).replace('/', "--")
     }
 
-    /// The revision
     pub fn revision(&self) -> &str {
         &self.revision
     }
 
-    /// The actual URL part of the repo
     pub fn url(&self) -> String {
         match self.repo_type {
             RepoType::Model => self.repo_id.to_string(),
@@ -267,12 +218,10 @@ impl Repo {
         }
     }
 
-    /// Revision needs to be url escaped before being used in a URL
     pub fn url_revision(&self) -> String {
         self.revision.replace('/', "%2F")
     }
 
-    /// Used to compute the repo's url part when accessing the metadata of the repo
     pub fn api_url(&self) -> String {
         let prefix = match self.repo_type {
             RepoType::Model => "models",