diff --git a/Cargo.toml b/Cargo.toml
index c301146..1f37fcf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,6 @@
 [workspace]
-members = ["hfd", "hfd-cli"]
+# members = ["hfd", "hfd-cli"]
+members = ["hfd-cli"]
 resolver = "2"
 
 [profile.release]
diff --git a/README.md b/README.md
index 8846b17..e95e080 100644
--- a/README.md
+++ b/README.md
@@ -11,4 +11,34 @@
 [![GitHub release (with filter)](https://img.shields.io/github/v/release/AUTOM77/hfd?logo=github)](https://github.com/AUTOM77/hfd/releases)
 
-🎈Rust-based interface for Huggingface 🤗 download
+🎈Rust-based interface for Huggingface 🤗 download.
+
+`./hfd "https://huggingface.co/deepseek-ai/DeepSeek-V2"`
+
+For a more convenient user experience, execute:
+
+```bash
+cat <
+async fn download_chunk(u: hyper::Uri, s: usize, e: usize) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    let host = u.host().expect("no host");
+    let port = u.port_u16().unwrap_or(443);
+    let addr = format!("{}:{}", host, port).to_socket_addrs()?.next().unwrap();
+
+    let conf = std::sync::Arc::new({
+        let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
+        let mut c = rustls::ClientConfig::builder()
+            .with_root_certificates(root_store)
+            .with_no_client_auth();
+        c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned());
+        c
+    });
+
+    let tcp = tokio::net::TcpStream::connect(&addr).await?;
+    let domain = rustls_pki_types::ServerName::try_from(host)?.to_owned();
+    let connector = tokio_rustls::TlsConnector::from(conf);
+
+    let stream = connector.connect(domain, tcp).await?;
+    let _io = hyper_util::rt::TokioIo::new(stream);
+    let exec = hyper_util::rt::tokio::TokioExecutor::new();
+
+    let (mut client, mut h2) = hyper::client::conn::http2::handshake(exec, _io).await?;
+    tokio::spawn(async move {
+        if let Err(e) = h2.await {
+            println!("Error: {:?}", e);
+        }
+    });
+
+    let range = format!("bytes={s}-{e}");
+
+    let req = hyper::Request::builder()
+        .uri(u)
+        .header("user-agent", "hyper-client-http2")
+        .header(RANGE, range)
+        .version(hyper::Version::HTTP_2)
+        .body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
+
+    let mut response = client.send_request(req).await?;
+
+    let mut file = tokio::fs::OpenOptions::new().write(true).open(FILE).await?;
+    file.seek(SeekFrom::Start(s as u64)).await?;
+    while let Some(chunk) = response.frame().await {
+        let chunk = chunk?;
+        if let Some(c) = chunk.data_ref() {
+            tokio::io::copy(&mut c.as_ref(), &mut file).await?;
+        }
+    }
+    Ok(())
+}
+
+async fn download() -> Result<(), Box<dyn std::error::Error>> {
+    let mut url: hyper::Uri = URL.parse()?;
+    let host = url.host().expect("no host");
+    let port = url.port_u16().unwrap_or(443);
+    let addr = format!("{}:{}", host, port).to_socket_addrs()?.next().unwrap();
+
+    let conf = std::sync::Arc::new({
+        let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
+        let mut c = rustls::ClientConfig::builder()
+            .with_root_certificates(root_store)
+            .with_no_client_auth();
+        c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned());
+        c
+    });
+
+    let tcp = tokio::net::TcpStream::connect(&addr).await?;
+    let domain = rustls_pki_types::ServerName::try_from(host)?.to_owned();
+    let connector = tokio_rustls::TlsConnector::from(conf);
+
+    let stream = connector.connect(domain, tcp).await?;
+    let _io = hyper_util::rt::TokioIo::new(stream);
+    let exec = hyper_util::rt::tokio::TokioExecutor::new();
+
+    let (mut client, mut h2) = hyper::client::conn::http2::handshake(exec, _io).await?;
+    tokio::spawn(async move {
+        if let Err(e) = h2.await {
+            println!("Error: {:?}", e);
+        }
+    });
+
+    let req = hyper::Request::builder()
+        .uri(url.clone())
.header("user-agent", "hyper-client-http2") + .header(RANGE, "bytes=0-0") + .version(hyper::Version::HTTP_2) + .body(http_body_util::Empty::::new())?; + + let mut response = client.send_request(req).await?; + while let Some(location) = response.headers().get(LOCATION) { + let _cdn: hyper::Uri = location.to_str()?.parse()?; + let _req = hyper::Request::builder() + .uri(_cdn.clone()) + .header("user-agent", "hyper-client-http2") + .version(hyper::Version::HTTP_2) + .body(http_body_util::Empty::::new())?; + response = client.send_request(_req).await?; + url = _cdn; + } + + println!("{:?}", url); + let req = hyper::Request::builder() + .uri(url.clone()) + .header("user-agent", "hyper-client-http2") + .header(RANGE, "bytes=0-0") + .version(hyper::Version::HTTP_2) + .body(http_body_util::Empty::::new())?; + let response = client.send_request(req).await?; + + println!("{:?}", response); + let length: usize = response + .headers() + .get(CONTENT_RANGE) + .ok_or("Content-Length not found")? + .to_str()?.rsplit('/').next() + .and_then(|s| s.parse().ok()) + .ok_or("Failed to parse size")?; + + let _ = tokio::fs::File::create(FILE).await?.set_len(length as u64).await?; + let tasks: Vec<_> = (0..length) + .into_iter() + .step_by(CHUNK_SIZE) + .map(|s| { + let _url = url.clone(); + let e = std::cmp::min(s + CHUNK_SIZE - 1, length); + tokio::spawn(async move { download_chunk(_url, s, e).await }) + }) + .collect(); + + for task in tasks { + let _ = task.await.unwrap(); + } + Ok(()) +} + +fn main() { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); + let start_time = std::time::Instant::now(); + let _ = rt.block_on(download()); + println!("Processing time: {:?}", start_time.elapsed()); +} + diff --git a/example/mirror.rs b/example/mirror.rs index 83fb2e4..ab9637d 100644 --- a/example/mirror.rs +++ b/example/mirror.rs @@ -9,7 +9,7 @@ async fn main() { let _filename = api .model("ByteDance/Hyper-SD".to_string()) - .get("Hyper-SDXL-8steps-lora.safetensors") + .get("Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors") .await .unwrap(); } \ No newline at end of file diff --git a/example/rdd/Cargo.toml b/example/rdd/Cargo.toml new file mode 100644 index 0000000..d5fb21d --- /dev/null +++ b/example/rdd/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "rdd" +version = "0.1.0" +edition = "2021" + +[dependencies] +reqwest = { version = "0.12.4", default-features = false, features = ["stream", "http2", "rustls-tls"] } +tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] } +tokio-stream = "0.1.15" diff --git a/example/rdd/src/main.rs b/example/rdd/src/main.rs new file mode 100644 index 0000000..3d746dd --- /dev/null +++ b/example/rdd/src/main.rs @@ -0,0 +1,63 @@ +use tokio; +use reqwest::header::{RANGE, CONTENT_RANGE}; +use tokio::io::{AsyncSeekExt, SeekFrom}; +use tokio_stream::StreamExt; + +const CHUNK_SIZE: usize = 10_000_000; +const URL: &str = "https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors"; +const FILE: &str = "fp16.safetensors"; + +async fn download_chunk(s: usize, e: usize) -> Result<(), Box> { + let client = reqwest::Client::builder() + .http2_keep_alive_timeout(tokio::time::Duration::from_secs(15)).build()?; + let range = format!("bytes={s}-{e}"); + + let response = client.get(URL).header(RANGE, range).send().await?; + let mut stream = response.bytes_stream(); + + let mut file = tokio::fs::OpenOptions::new().write(true).open(FILE).await?; + file.seek(SeekFrom::Start(s as u64)).await?; + while 
+        let chunk = chunk?;
+        tokio::io::copy(&mut chunk.as_ref(), &mut file).await?;
+    }
+    Ok(())
+}
+
+async fn download() -> Result<(), Box<dyn std::error::Error>> {
+    let client = reqwest::Client::builder()
+        .http2_keep_alive_timeout(tokio::time::Duration::from_secs(15)).build()?;
+
+    let response = client.get(URL).header(RANGE, "bytes=0-0").send().await?;
+    let length: usize = response
+        .headers()
+        .get(CONTENT_RANGE)
+        .ok_or("Content-Range header not found")?
+        .to_str()?.rsplit('/').next()
+        .and_then(|s| s.parse().ok())
+        .ok_or("Failed to parse size")?;
+
+    let _ = tokio::fs::File::create(FILE).await?.set_len(length as u64).await?;
+
+    let tasks: Vec<_> = (0..length)
+        .into_iter()
+        .step_by(CHUNK_SIZE)
+        .map(|s| {
+            let e = std::cmp::min(s + CHUNK_SIZE - 1, length);
+            tokio::spawn(async move { download_chunk(s, e).await })
+        })
+        .collect();
+
+    for task in tasks {
+        let _ = task.await.unwrap();
+    }
+    Ok(())
+}
+
+fn main() {
+    let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
+
+    let start_time = std::time::Instant::now();
+    let _ = rt.block_on(download());
+    println!("Processing time: {:?}", start_time.elapsed());
+}
diff --git a/example/simple.rs b/example/simple.rs
new file mode 100644
index 0000000..3fa63a8
--- /dev/null
+++ b/example/simple.rs
@@ -0,0 +1,15 @@
+use tokio;
+
+#[tokio::main]
+async fn main() {
+    let start_time = std::time::Instant::now();
+
+    let api = libhfd::api::tokio::Api::new().unwrap();
+
+    let _filename = api
+        .model("ByteDance/Hyper-SD".to_string())
+        .get("Hyper-SDXL-8steps-lora.safetensors")
+        .await
+        .unwrap();
+    println!("Processing time: {:?}", start_time.elapsed());
+}
\ No newline at end of file
diff --git a/hfd-cli/Cargo.toml b/hfd-cli/Cargo.toml
index 94a1e2b..ccbdd1e 100644
--- a/hfd-cli/Cargo.toml
+++ b/hfd-cli/Cargo.toml
@@ -1,13 +1,14 @@
 [package]
 name = "hfd-cli"
-version = "0.1.8"
+version = "0.2.0"
 edition = "2021"
 
 [dependencies]
-tokio = { version = "1.37.0", default-features = false, features = ["rt", "rt-multi-thread", "fs"] }
 clap = { version= "4.5.4", features=["derive"] }
+reqwest = { version = "0.12.4", default-features = false, features = ["stream", "http2", "json", "rustls-tls"] }
+tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] }
+serde_json = { version = "1.0.116" }
 tokio-stream = "0.1.15"
-libhfd = { path = "../hfd" }
 
 [[bin]]
 name = "hfd"
diff --git a/hfd-cli/src/cli.rs b/hfd-cli/src/cli.rs
index 6fd2d12..6ec8c7f 100644
--- a/hfd-cli/src/cli.rs
+++ b/hfd-cli/src/cli.rs
@@ -1,16 +1,33 @@
-use tokio;
-// use tokio_stream::StreamExt;
+use clap::{Args, Parser};
+use hfd_cli::_rt;
 
-#[tokio::main]
-async fn main() {
-    let start_time = std::time::Instant::now();
+#[derive(Args)]
+#[group(required = false, multiple = true)]
+struct Opts {
+    #[arg(short = 't', long, name = "TOKEN")]
+    token: Option<String>,
+
+    #[arg(short = 'd', long, name = "DIR", help = "Save it to `$DIR` or `.` ")]
+    dir: Option<String>,
+    #[arg(short = 'm', long, name = "MIRROR", help = "Not yet applied")]
+    mirror: Option<String>,
+    #[arg(short = 'p', long, name = "PROXY", help = "Not yet applied")]
+    proxy: Option<String>,
+}
+
+#[derive(Parser)]
+struct Cli {
+    url: String,
 
-    let api = libhfd::api::tokio::Api::new().unwrap();
+    #[command(flatten)]
+    opt: Opts,
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let start_time = std::time::Instant::now();
 
-    let _filename = api
-        .model("ByteDance/Hyper-SD".to_string())
-        .get("Hyper-SDXL-8steps-lora.safetensors")
-        .await
-        .unwrap();
+    let cli = Cli::parse();
+    let _ = _rt(&cli.url, cli.opt.token.as_deref(), cli.opt.dir.as_deref());
     println!("Processing time: {:?}", start_time.elapsed());
-}
\ No newline at end of file
+    Ok(())
+}
diff --git a/hfd-cli/src/lib.rs b/hfd-cli/src/lib.rs
new file mode 100644
index 0000000..3afa3c7
--- /dev/null
+++ b/hfd-cli/src/lib.rs
@@ -0,0 +1,248 @@
+use std::str::FromStr;
+use std::path::PathBuf;
+use reqwest::header::{HeaderMap, AUTHORIZATION, CONTENT_RANGE, RANGE, USER_AGENT};
+use tokio::time::Duration;
+use tokio::io::{AsyncSeekExt, SeekFrom};
+use tokio_stream::StreamExt;
+
+const CHUNK_SIZE: usize = 10_000_000;
+
+#[derive(Debug)]
+pub struct HfURL {
+    endpoint: String,
+    repo_type: Option<String>,
+    repo_id: String,
+}
+
+impl HfURL {
+    pub fn new(endpoint: String, repo_type: Option<String>, repo_id: String) -> Self {
+        Self { endpoint, repo_type, repo_id }
+    }
+
+    pub fn with_endpoint(mut self, endpoint: &str) -> Self {
+        self.endpoint = endpoint.to_string();
+        self
+    }
+
+    pub fn api(&self) -> String {
+        let repo_path = match &self.repo_type {
+            Some(repo_type) => repo_type.clone(),
+            _ => "models".to_string(),
+        };
+        format!("https://{}/api/{}/{}", self.endpoint, repo_path, self.repo_id)
+    }
+
+    pub fn path(&self, fname: &str) -> String {
+        let repo_path = match &self.repo_type {
+            Some(repo_type) => format!("{}/{}", repo_type, self.repo_id),
+            _ => self.repo_id.clone(),
+        };
+        format!("https://{}/{}/resolve/main/{}", self.endpoint, repo_path, fname)
+    }
+}
+
+impl FromStr for HfURL {
+    type Err = &'static str;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let mut parts = s.split('/').skip(2);
+        let endpoint = match parts.next() {
+            Some(ep) => ep.to_string(),
+            None => return Err("Missing endpoint"),
+        };
+
+        let mut repo_type = None;
+
+        if let Some(next_part) = parts.clone().next() {
+            repo_type = match next_part {
+                "datasets" | "spaces" => Some(next_part.to_string()),
+                _ => None,
+            };
+
+            if repo_type.is_some() {
+                parts.next();
+            }
+        }
+
+        let owner = parts.next().ok_or("Missing owner")?;
+        let repo = parts.next().ok_or("Missing repo")?;
+        let repo_id = format!("{}/{}", owner, repo);
+
+        Ok(HfURL::new(endpoint, repo_type, repo_id))
+    }
+}
+
+#[derive(Debug)]
+pub struct HfClient {
+    headers: HeaderMap,
+    hf_url: HfURL,
+
+    root: PathBuf,
+}
+
+async fn download_chunk(
+    headers: &HeaderMap,
+    url: &str,
+    path: &PathBuf,
+    s: usize,
+    e: usize
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    let client = reqwest::Client::builder()
+        .default_headers(headers.clone())
+        .pool_idle_timeout(Duration::from_millis(50))
+        .pool_max_idle_per_host(2)
+        .timeout(Duration::from_secs(30))
+        .build()?;
+    let range = format!("bytes={s}-{e}");
+
+    let response = client.get(url).header(RANGE, range).send().await?;
+    let mut stream = response.bytes_stream();
+
+    let mut file = tokio::fs::OpenOptions::new().write(true).open(path).await?;
+    file.seek(SeekFrom::Start(s as u64)).await?;
+    while let Some(chunk) = stream.next().await {
+        let chunk = chunk?;
+        tokio::io::copy(&mut chunk.as_ref(), &mut file).await?;
+    }
+    Ok(())
+}
+
+async fn download(
+    headers: HeaderMap,
+    url: String,
+    path: PathBuf,
+    chunk_size: usize
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    // let url = self.hf_url.path(&file);
+    // let path = self.root.join(&file);
+    let client = reqwest::Client::builder()
+        .default_headers(headers.clone())
+        .http2_keep_alive_timeout(Duration::from_secs(15)).build()?;
+
+    let response = client.get(&url).header(RANGE, "bytes=0-0").send().await?;
+    let length: usize = response
+        .headers()
+        .get(CONTENT_RANGE)
+        .ok_or("Content-Range header not found")?
+        .to_str()?.rsplit('/').next()
+        .and_then(|s| s.parse().ok())
+        .ok_or("Failed to parse size")?;
+
+    let _ = tokio::fs::File::create(&path).await?.set_len(length as u64).await?;
+
+    let tasks: Vec<_> = (0..length)
+        .into_iter()
+        .step_by(chunk_size)
+        .map(|s| {
+            let _url = url.clone();
+            let _path = path.clone();
+            let headers = headers.clone();
+            let e = std::cmp::min(s + chunk_size - 1, length);
+            tokio::spawn(async move { download_chunk(&headers, &_url, &_path, s, e).await })
+        })
+        .collect();
+
+    for task in tasks {
+        let _ = task.await.unwrap();
+    }
+    Ok(())
+}
+
+impl HfClient {
+    pub fn new(headers: HeaderMap, hf_url: HfURL) -> Self {
+        let default = match std::env::var("HF_HOME") {
+            Ok(home) => home,
+            Err(_) => ".".to_string()
+        };
+
+        let root = PathBuf::from(default).join(hf_url.repo_id.clone());
+        Self { headers, hf_url, root }
+    }
+
+    pub fn build(url: &str) -> Result<Self, Box<dyn std::error::Error>> {
+        let hf_url = url.parse()?;
+        let mut headers = HeaderMap::new();
+        headers.insert(USER_AGENT, "hyper-client-http2".parse()?);
+        Ok(Self::new(headers, hf_url))
+    }
+
+    pub fn apply_token(mut self, _token: Option<&str>) -> Self {
+        if let Some(token) = _token {
+            self.headers.insert(AUTHORIZATION, format!("Bearer {token}").parse().unwrap());
+        }
+        self
+    }
+
+    pub fn apply_root(mut self, _root: Option<&str>) -> Self {
+        if let Some(root) = _root {
+            self.root = PathBuf::from(root).join(self.hf_url.repo_id.clone());
+        }
+        self
+    }
+
+    async fn list_files(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
+        let client = reqwest::Client::builder()
+            .http2_keep_alive_timeout(Duration::from_secs(15)).build()?;
+        let api = self.hf_url.api();
+        let response = client.get(api)
+            .headers(self.headers.clone())
+            .send().await?
+            .json::<serde_json::Value>()
+            .await?;
+
+        let mut files: Vec<String> = Vec::new();
+
+        if let Some(siblings) = response["siblings"].as_array() {
+            let mut _files: Vec<String> = siblings.into_iter()
+                .map(|f| f.get("rfilename").expect("filename").as_str())
+                .flatten()
+                .map(|x| x.into())
+                .collect();
+            files.append(&mut _files);
+        }
+        Ok(files)
+    }
+
+    fn create_dir_all(&self, files: Vec<String>) -> Result<(), Box<dyn std::error::Error>> {
+        for file in files {
+            if let Some(parent) = self.root.join(file).parent() {
+                let _ = std::fs::create_dir_all(parent)?;
+            }
+        }
+        Ok(())
+    }
+
+    pub async fn download_all(&self) -> Result<(), Box<dyn std::error::Error>> {
+        let files = self.list_files().await?;
+        let _ = self.create_dir_all(files.clone());
+        let file_chunks: Vec<_> = files
+            .chunks(30)
+            .map(|chunk| chunk.to_owned())
+            .collect();
+
+        for fc in file_chunks {
+            let tasks: Vec<_> = fc.into_iter()
+                .map(|f| {
+                    let url = self.hf_url.path(&f);
+                    let path = self.root.join(&f);
+                    let headers = self.headers.clone();
+                    tokio::spawn(async move { let _ = download(headers, url, path, CHUNK_SIZE).await; })
+                })
+                .collect();
+
+            for task in tasks {
+                task.await.unwrap();
+            }
+        }
+        Ok(())
+    }
+}
+
+pub fn _rt(_url: &str, _token: Option<&str>, _dir: Option<&str>) -> Result<(), Box<dyn std::error::Error>> {
+    let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
+    let hfc = HfClient::build(_url)?
+        .apply_token(_token)
+        .apply_root(_dir);
+
+    let _ = rt.block_on(hfc.download_all());
+    Ok(())
+}
\ No newline at end of file
diff --git a/hfd/Cargo.toml b/hfd/Cargo.toml
index e73219c..fb44500 100644
--- a/hfd/Cargo.toml
+++ b/hfd/Cargo.toml
@@ -4,15 +4,15 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-tokio = { version = "1.37.0", features = ["full"] }
-reqwest = { version = "0.12.4", default-features = false, features = ["json", "charset", "http2", "macos-system-configuration", "rustls-tls"] }
+tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] }
+reqwest = { version = "0.12.4", default-features = false, features = ["stream", "json", "charset", "http2", "macos-system-configuration", "rustls-tls"] }
 indicatif = { version = "0.17.8" }
 serde = { version = "1.0.200", features = ["derive"] }
 serde_json = { version = "1.0.116" }
+tokio-stream = "0.1.15"
 futures = "0.3.30"
 thiserror = "1.0.59"
-dirs = "5.0.1"
 rand = "0.8.5"
 num_cpus = "1.16.0"
diff --git a/hfd/src/api/mod.rs b/hfd/src/api/mod.rs
index cca0cd2..f8a22b4 100644
--- a/hfd/src/api/mod.rs
+++ b/hfd/src/api/mod.rs
@@ -1,15 +1 @@
-use serde::Deserialize;
-
-pub mod tokio;
-
-#[derive(Debug, Clone, Deserialize, PartialEq)]
-pub struct Siblings {
-    pub rfilename: String,
-}
-
-#[derive(Debug, Clone, Deserialize, PartialEq)]
-pub struct RepoInfo {
-    pub siblings: Vec<Siblings>,
-
-    pub sha: String,
-}
\ No newline at end of file
+pub mod tokio;
\ No newline at end of file
diff --git a/hfd/src/api/tokio.rs b/hfd/src/api/tokio.rs
index 8b7e21d..e93788a 100644
--- a/hfd/src/api/tokio.rs
+++ b/hfd/src/api/tokio.rs
@@ -16,6 +16,7 @@ use std::sync::Arc;
 use thiserror::Error;
 use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
 use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
+use tokio_stream::StreamExt;
 
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 
@@ -426,6 +427,7 @@ impl ApiRepo {
         stop: usize,
     ) -> Result<(), ApiError> {
         let range = format!("bytes={start}-{stop}");
+
         let mut file = tokio::fs::OpenOptions::new()
             .write(true)
             .open(filename)
             .await?;
@@ -435,10 +437,14 @@
             .get(url)
             .header(RANGE, range)
             .send()
-            .await?
-            .error_for_status()?;
-        let content = response.bytes().await?;
-        file.write_all(&content).await?;
+            .await?;
+
+        let mut stream = response.bytes_stream();
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            tokio::io::copy(&mut chunk.as_ref(), &mut file).await?;
+        }
+
         Ok(())
     }
diff --git a/hfd/src/api/tokio_scope.rs b/hfd/src/api/tokio_scope.rs
new file mode 100644
index 0000000..f33eb8b
--- /dev/null
+++ b/hfd/src/api/tokio_scope.rs
@@ -0,0 +1,504 @@
+use super::RepoInfo;
+use crate::{Cache, Repo, RepoType};
+use indicatif::{ProgressBar, ProgressStyle};
+use rand::Rng;
+use reqwest::{
+    header::{
+        HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
+        CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
+    },
+    redirect::Policy,
+    Client, Error as ReqwestError, RequestBuilder,
+};
+use std::num::ParseIntError;
+use std::path::{Component, Path, PathBuf};
+use std::sync::Arc;
+use thiserror::Error;
+use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
+use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
+
+const VERSION: &str = env!("CARGO_PKG_VERSION");
+
+const NAME: &str = env!("CARGO_PKG_NAME");
+
+#[derive(Debug, Error)]
+pub enum ApiError {
+    #[error("Header {0} is missing")]
+    MissingHeader(HeaderName),
+
+    #[error("Header {0} is invalid")]
+    InvalidHeader(HeaderName),
+
+    #[error("Invalid header value {0}")]
+    InvalidHeaderValue(#[from] InvalidHeaderValue),
+
+    #[error("header value is not a string")]
+    ToStr(#[from] ToStrError),
+
+    #[error("request error: {0}")]
+    RequestError(#[from] ReqwestError),
+
+    #[error("Cannot parse int")]
+    ParseIntError(#[from] ParseIntError),
+
+    #[error("I/O error {0}")]
+    IoError(#[from] std::io::Error),
+
+    #[error("Too many retries: {0}")]
+    TooManyRetries(Box<ApiError>),
+
+    #[error("Try acquire: {0}")]
+    TryAcquireError(#[from] TryAcquireError),
+
+    #[error("Acquire: {0}")]
+    AcquireError(#[from] AcquireError),
+}
+
+#[derive(Debug)]
+pub struct ApiBuilder {
+    endpoint: String,
+    cache: Cache,
+    url_template: String,
+    token: Option<String>,
+    max_files: usize,
+    chunk_size: usize,
+    parallel_failures: usize,
+    max_retries: usize,
+    progress: bool,
+}
+
+impl ApiBuilder {
+    pub fn new() -> Self {
+        let cache = Cache::default();
+        Self::from_cache(cache)
+    }
+
+    pub fn from_cache(cache: Cache) -> Self {
+        let token = cache.token();
+
+        let progress = true;
+
+        Self {
+            endpoint: "https://huggingface.co".to_string(),
+            url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
+            cache,
+            token,
+            max_files: num_cpus::get(),
+            chunk_size: 10_000_000,
+            parallel_failures: 0,
+            max_retries: 0,
+            progress,
+        }
+    }
+
+    pub fn with_endpoint(mut self, endpoint: &str) -> Self {
+        self.endpoint = endpoint.to_string();
+        self
+    }
+
+    pub fn with_progress(mut self, progress: bool) -> Self {
+        self.progress = progress;
+        self
+    }
+
+    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
+        self.cache = Cache::new(cache_dir);
+        self
+    }
+
+    pub fn with_token(mut self, token: &str) -> Self {
+        self.token = Some(token.to_string());
+        self
+    }
+
+    fn build_headers(&self) -> Result<HeaderMap, ApiError> {
+        let mut headers = HeaderMap::new();
+        let user_agent = format!("unknown/None; {NAME}/{VERSION}; rust/unknown");
+        headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
+        if let Some(token) = &self.token {
+            headers.insert(
+                AUTHORIZATION,
+                HeaderValue::from_str(&format!("Bearer {token}"))?,
+            );
+        }
+        Ok(headers)
+    }
+
+    pub fn build(self) -> Result<Api, ApiError> {
+        let headers = self.build_headers()?;
+        let client = Client::builder().default_headers(headers.clone()).build()?;
+
+        let relative_redirect_policy = Policy::custom(|attempt| {
+            if attempt.previous().len() > 10 {
+                return attempt.error("too many redirects");
+            }
+
+            if let Some(last) = attempt.previous().last() {
+                if last.make_relative(attempt.url()).is_none() {
+                    return attempt.stop();
+                }
+            }
+
+            attempt.follow()
+        });
+
+        let relative_redirect_client = Client::builder()
+            .redirect(relative_redirect_policy)
+            .default_headers(headers.clone())
+            .build()?;
+        Ok(Api {
+            endpoint: self.endpoint,
+            url_template: self.url_template,
+            cache: self.cache,
+            client,
+            headers,
+            relative_redirect_client,
+            max_files: self.max_files,
+            chunk_size: self.chunk_size,
+            parallel_failures: self.parallel_failures,
+            max_retries: self.max_retries,
+            progress: self.progress,
+        })
+    }
+}
+
+#[derive(Debug)]
+struct Metadata {
+    commit_hash: String,
+    etag: String,
+    size: usize,
+}
+
+#[derive(Clone, Debug)]
+pub struct Api {
+    endpoint: String,
+    url_template: String,
+    cache: Cache,
+    client: Client,
+    headers: HeaderMap,
+    relative_redirect_client: Client,
+    max_files: usize,
+    chunk_size: usize,
+    parallel_failures: usize,
+    max_retries: usize,
+    progress: bool,
+}
+
+fn make_relative(src: &Path, dst: &Path) -> PathBuf {
+    let path = src;
+    let base = dst;
+
+    assert_eq!(
+        path.is_absolute(),
+        base.is_absolute(),
+        "This function is made to look at absolute paths only"
+    );
+    let mut ita = path.components();
+    let mut itb = base.components();
+
+    loop {
+        match (ita.next(), itb.next()) {
+            (Some(a), Some(b)) if a == b => (),
+            (some_a, _) => {
+                let mut new_path = PathBuf::new();
+                for _ in itb {
+                    new_path.push(Component::ParentDir);
+                }
+                if let Some(a) = some_a {
+                    new_path.push(a);
+                    for comp in ita {
+                        new_path.push(comp);
+                    }
+                }
+                return new_path;
+            }
+        }
+    }
+}
+
+fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
+    if dst.exists() {
+        return Ok(());
+    }
+
+    let rel_src = make_relative(src, dst);
+    #[cfg(target_os = "windows")]
+    {
+        if std::os::windows::fs::symlink_file(rel_src, dst).is_err() {
+            std::fs::rename(src, dst)?;
+        }
+    }
+
+    #[cfg(target_family = "unix")]
+    std::os::unix::fs::symlink(rel_src, dst)?;
+
+    Ok(())
+}
+
+fn jitter() -> usize {
+    rand::thread_rng().gen_range(0..=500)
+}
+
+fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
+    (base_wait_time + n.pow(2) + jitter()).min(max)
+}
+
+impl Api {
+    pub fn new() -> Result<Api, ApiError> {
+        ApiBuilder::new().build()
+    }
+
+    pub fn client(&self) -> &Client {
+        &self.client
+    }
+
+    async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
+        let response = self
+            .relative_redirect_client
+            .get(url)
+            .header(RANGE, "bytes=0-0")
+            .send()
+            .await?;
+        let response = response.error_for_status()?;
+        let headers = response.headers();
+        let header_commit = HeaderName::from_static("x-repo-commit");
+        let header_linked_etag = HeaderName::from_static("x-linked-etag");
+        let header_etag = HeaderName::from_static("etag");
+
+        let etag = match headers.get(&header_linked_etag) {
+            Some(etag) => etag,
+            None => headers
+                .get(&header_etag)
+                .ok_or(ApiError::MissingHeader(header_etag))?,
+        };
+
+        let etag = etag.to_str()?.to_string().replace('"', "");
+        let commit_hash = headers
+            .get(&header_commit)
+            .ok_or(ApiError::MissingHeader(header_commit))?
+            .to_str()?
+            .to_string();
+
+        let response = if response.status().is_redirection() {
+            self.client
+                .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
+                .header(RANGE, "bytes=0-0")
+                .send()
+                .await?
+        } else {
+            response
+        };
+        let headers = response.headers();
+        let content_range = headers
+            .get(CONTENT_RANGE)
+            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
+            .to_str()?;
+
+        let size = content_range
+            .split('/')
+            .last()
+            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
+            .parse()?;
+        Ok(Metadata {
+            commit_hash,
+            etag,
+            size,
+        })
+    }
+
+    pub fn repo(&self, repo: Repo) -> ApiRepo {
+        ApiRepo::new(self.clone(), repo)
+    }
+
+    pub fn model(&self, model_id: String) -> ApiRepo {
+        self.repo(Repo::new(model_id, RepoType::Model))
+    }
+
+    pub fn dataset(&self, model_id: String) -> ApiRepo {
+        self.repo(Repo::new(model_id, RepoType::Dataset))
+    }
+
+    pub fn space(&self, model_id: String) -> ApiRepo {
+        self.repo(Repo::new(model_id, RepoType::Space))
+    }
+}
+
+#[derive(Debug)]
+pub struct ApiRepo {
+    api: Api,
+    repo: Repo,
+}
+
+impl ApiRepo {
+    fn new(api: Api, repo: Repo) -> Self {
+        Self { api, repo }
+    }
+}
+
+impl ApiRepo {
+    pub fn url(&self, filename: &str) -> String {
+        let endpoint = &self.api.endpoint;
+        let revision = &self.repo.url_revision();
+        self.api
+            .url_template
+            .replace("{endpoint}", endpoint)
+            .replace("{repo_id}", &self.repo.url())
+            .replace("{revision}", revision)
+            .replace("{filename}", filename)
+    }
+
+    async fn download_tempfile(
+        &self,
+        url: &str,
+        length: usize,
+        progressbar: Option<ProgressBar>,
+    ) -> Result<PathBuf, ApiError> {
+        let mut handles = vec![];
+        let semaphore = Arc::new(Semaphore::new(self.api.max_files));
+        let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
+        let filename = self.api.cache.temp_path();
+
+        tokio::fs::File::create(&filename)
+            .await?
+            .set_len(length as u64)
+            .await?;
+
+        let chunk_size = self.api.chunk_size;
+        for start in (0..length).step_by(chunk_size) {
+            let url = url.to_string();
+            let filename = filename.clone();
+            let headers = self.api.headers.clone();
+            // let client = self.api.client.clone();
+
+            let stop = std::cmp::min(start + chunk_size - 1, length);
+            let permit = semaphore.clone().acquire_owned().await?;
+            let parallel_failures = self.api.parallel_failures;
+            let max_retries = self.api.max_retries;
+            let parallel_failures_semaphore = parallel_failures_semaphore.clone();
+            let progress = progressbar.clone();
+            handles.push(tokio::spawn(async move {
+                let mut chunk = Self::download_chunk(&headers, &url, &filename, start, stop).await;
+                let mut i = 0;
+                if parallel_failures > 0 {
+                    while let Err(dlerr) = chunk {
+                        let parallel_failure_permit =
+                            parallel_failures_semaphore.clone().try_acquire_owned()?;
+
+                        let wait_time = exponential_backoff(300, i, 10_000);
+                        tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
+                            .await;
+
+                        chunk = Self::download_chunk(&headers, &url, &filename, start, stop).await;
+                        i += 1;
+                        if i > max_retries {
+                            return Err(ApiError::TooManyRetries(dlerr.into()));
+                        }
+                        drop(parallel_failure_permit);
+                    }
+                }
+                drop(permit);
+                if let Some(p) = progress {
+                    p.inc((stop - start) as u64);
+                }
+                chunk
+            }));
+        }
+
+        let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
+            futures::future::join_all(handles).await;
+        let results: Result<(), ApiError> = results.into_iter().flatten().collect();
+        results?;
+        if let Some(p) = progressbar {
+            p.finish();
+        }
+        Ok(filename)
+    }
+
+    async fn download_chunk(
+        headers: &HeaderMap,
+        url: &str,
+        filename: &PathBuf,
+        start: usize,
+        stop: usize,
+    ) -> Result<(), ApiError> {
+        let client = Client::builder().default_headers(headers.clone()).build()?;
+        let range = format!("bytes={start}-{stop}");
+        let mut file = tokio::fs::OpenOptions::new()
+            .write(true)
+            .open(filename)
+            .await?;
+        file.seek(SeekFrom::Start(start as u64)).await?;
+        let response = client
+            .get(url)
+            .header(RANGE, range)
+            .send()
+            .await?
+            .error_for_status()?;
+        let content = response.bytes().await?;
+        file.write_all(&content).await?;
+        Ok(())
+    }
+
+    pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
+            Ok(path)
+        } else {
+            self.download(filename).await
+        }
+    }
+
+    pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        let url = self.url(filename);
+        let metadata = self.api.metadata(&url).await?;
+        let cache = self.api.cache.repo(self.repo.clone());
+
+        let blob_path = cache.blob_path(&metadata.etag);
+        std::fs::create_dir_all(blob_path.parent().unwrap())?;
+
+        let progressbar = if self.api.progress {
+            let progress = ProgressBar::new(metadata.size as u64);
+            progress.set_style(
+                ProgressStyle::with_template(
+                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
+                )
+                .unwrap(),
+            );
+            let maxlength = 30;
+            let message = if filename.len() > maxlength {
+                format!("..{}", &filename[filename.len() - maxlength..])
+            } else {
+                filename.to_string()
+            };
+            progress.set_message(message);
+            Some(progress)
+        } else {
+            None
+        };
+
+        let tmp_filename = self
+            .download_tempfile(&url, metadata.size, progressbar)
+            .await?;
+
+        tokio::fs::rename(&tmp_filename, &blob_path).await?;
+
+        let mut pointer_path = cache.pointer_path(&metadata.commit_hash);
+        pointer_path.push(filename);
+        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
+
+        symlink_or_rename(&blob_path, &pointer_path)?;
+        cache.create_ref(&metadata.commit_hash)?;
+
+        Ok(pointer_path)
+    }
+
+    pub async fn info(&self) -> Result<RepoInfo, ApiError> {
+        Ok(self.info_request().send().await?.json().await?)
+    }
+
+    pub fn info_request(&self) -> RequestBuilder {
+        let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
+        self.api.client.get(url)
+    }
+}
+
diff --git a/sample.json b/sample.json
new file mode 100644
index 0000000..e1fae71
--- /dev/null
+++ b/sample.json
@@ -0,0 +1,647 @@
+{
+  "siblings": [
+    {
+      "rfilename": ".gitattributes"
+    },
+    {
+      "rfilename": "README.md"
+    },
+    {
+      "rfilename": "app.py"
+    },
+    {
+      "rfilename": "banned_ids.txt"
+    },
+    {
+      "rfilename": "data/000000.parquet"
+    },
+    {
+      "rfilename": "data/000001.parquet"
+    },
+    {
+      "rfilename": "data/000002.parquet"
+    },
+    {
+      "rfilename": "data/000003.parquet"
+    },
+    {
+      "rfilename": "data/000004.parquet"
+    },
+    {
+      "rfilename": "data/000005.parquet"
+    },
+    {
+      "rfilename": "data/000006.parquet"
+    },
+    {
+      "rfilename": "data/000007.parquet"
+    },
+    {
+      "rfilename": "data/000008.parquet"
+    },
+    {
+      "rfilename": "data/000009.parquet"
+    },
+    {
+      "rfilename": "data/000010.parquet"
+    },
+    {
+      "rfilename": "data/000011.parquet"
+    },
+    {
+      "rfilename": "data/000012.parquet"
+    },
+    {
+      "rfilename": "data/000013.parquet"
+    },
+    {
+      "rfilename": "data/000014.parquet"
+    },
+    {
+      "rfilename": "data/000015.parquet"
+    },
+    {
+      "rfilename": "data/000016.parquet"
+    },
+    {
+      "rfilename": "data/000017.parquet"
+    },
+    {
+      "rfilename": "data/000018.parquet"
+    },
+    {
+      "rfilename": "data/000019.parquet"
+    },
+    {
+      "rfilename": "data/000020.parquet"
+    },
+    {
+      "rfilename": "data/000021.parquet"
+    },
+    {
+      "rfilename": "data/000022.parquet"
+    },
+    {
+      "rfilename": "data/000023.parquet"
+    },
+    {
+      "rfilename": "data/000024.parquet"
+    },
+    {
+      "rfilename": "data/000025.parquet"
+    },
+    {
+      "rfilename": "data/000026.parquet"
+    },
+    {
+      "rfilename": "data/000027.parquet"
+    },
+    {
+      "rfilename": "data/000028.parquet"
+    },
+    {
+      "rfilename": "data/000029.parquet"
+    },
+    {
+      "rfilename": "data/000030.parquet"
+    },
+    {
+      "rfilename": "data/000031.parquet"
+    },
+    {
+      "rfilename": "data/000032.parquet"
+    },
+    {
+      "rfilename": "data/000033.parquet"
+    },
+    {
+      "rfilename": "data/000034.parquet"
+    },
+    {
+      "rfilename": "data/000035.parquet"
+    },
+    {
+      "rfilename": "data/000036.parquet"
+    },
+    {
+      "rfilename": "data/000037.parquet"
+    },
+    {
+      "rfilename": "data/000038.parquet"
+    },
+    {
+      "rfilename": "data/000039.parquet"
+    },
+    {
+      "rfilename": "data/000040.parquet"
+    },
+    {
+      "rfilename": "data/000041.parquet"
+    },
+    {
+      "rfilename": "data/000042.parquet"
+    },
+    {
+      "rfilename": "data/000043.parquet"
+    },
+    {
+      "rfilename": "data/000044.parquet"
+    },
+    {
+      "rfilename": "data/000045.parquet"
+    },
+    {
+      "rfilename": "data/000046.parquet"
+    },
+    {
+      "rfilename": "data/000047.parquet"
+    },
+    {
+      "rfilename": "data/000048.parquet"
+    },
+    {
+      "rfilename": "data/000049.parquet"
+    },
+    {
+      "rfilename": "data/000050.parquet"
+    },
+    {
+      "rfilename": "data/000051.parquet"
+    },
+    {
+      "rfilename": "data/000052.parquet"
+    },
+    {
+      "rfilename": "data/000053.parquet"
+    },
+    {
+      "rfilename": "data/000054.parquet"
+    },
+    {
+      "rfilename": "data/000055.parquet"
+    },
+    {
+      "rfilename": "data/000056.parquet"
+    },
+    {
+      "rfilename": "data/000057.parquet"
+    },
+    {
+      "rfilename": "data/000058.parquet"
+    },
+    {
+      "rfilename": "data/000059.parquet"
+    },
+    {
+      "rfilename": "data/000060.parquet"
+    },
+    {
+      "rfilename": "data/000061.parquet"
+    },
+    {
+      "rfilename": "data/000062.parquet"
+    },
+    {
+      "rfilename": "data/000063.parquet"
+    },
+    {
+      "rfilename": "data/000064.parquet"
+    },
+    {
+      "rfilename": "data/000065.parquet"
+    },
+    {
+      "rfilename": "data/000066.parquet"
+    },
+    {
+      "rfilename": "data/000067.parquet"
+    },
+    {
+      "rfilename": "data/000068.parquet"
+    },
+    {
+      "rfilename": "data/000069.parquet"
+    },
+    {
+      "rfilename": "data/000070.parquet"
+    },
+    {
+      "rfilename": "data/000071.parquet"
+    },
+    {
+      "rfilename": "data/000072.parquet"
+    },
+    {
+      "rfilename": "data/000073.parquet"
+    },
+    {
+      "rfilename": "data/000074.parquet"
+    },
+    {
+      "rfilename": "data/000075.parquet"
+    },
+    {
+      "rfilename": "data/000076.parquet"
+    },
+    {
+      "rfilename": "data/000077.parquet"
+    },
+    {
+      "rfilename": "data/000078.parquet"
+    },
+    {
+      "rfilename": "data/000079.parquet"
+    },
+    {
+      "rfilename": "data/000080.parquet"
+    },
+    {
+      "rfilename": "data/000081.parquet"
+    },
+    {
+      "rfilename": "data/000082.parquet"
+    },
+    {
+      "rfilename": "data/000083.parquet"
+    },
+    {
+      "rfilename": "data/000084.parquet"
+    },
+    {
+      "rfilename": "data/000085.parquet"
+    },
+    {
+      "rfilename": "data/000086.parquet"
+    },
+    {
+      "rfilename": "data/000087.parquet"
+    },
+    {
+      "rfilename": "data/000088.parquet"
+    },
+    {
+      "rfilename": "data/000089.parquet"
+    },
+    {
+      "rfilename": "data/000090.parquet"
+    },
+    {
+      "rfilename": "data/000091.parquet"
+    },
+    {
+      "rfilename": "data/000092.parquet"
+    },
+    {
+      "rfilename": "data/000093.parquet"
+    },
+    {
+      "rfilename": "data/000094.parquet"
+    },
+    {
+      "rfilename": "data/000095.parquet"
+    },
+    {
+      "rfilename": "data/000096.parquet"
+    },
+    {
+      "rfilename": "data/000097.parquet"
+    },
+    {
+      "rfilename": "data/000098.parquet"
+    },
+    {
+      "rfilename": "data/000099.parquet"
+    },
+    {
+      "rfilename": "data/000100.parquet"
+    },
+    {
+      "rfilename": "data/000101.parquet"
+    },
+    {
+      "rfilename": "data/000102.parquet"
+    },
+    {
+      "rfilename": "data/000103.parquet"
+    },
+    {
+      "rfilename": "data/000104.parquet"
+    },
+    {
+      "rfilename": "data/000105.parquet"
+    },
+    {
+      "rfilename": "data/000106.parquet"
+    },
+    {
+      "rfilename": "data/000107.parquet"
+    },
+    {
+      "rfilename": "data/000108.parquet"
+    },
+    {
+      "rfilename": "data/000109.parquet"
+    },
+    {
+      "rfilename": "data/000110.parquet"
+    },
+    {
+      "rfilename": "data/000111.parquet"
+    },
+    {
+      "rfilename": "data/000112.parquet"
+    },
+    {
+      "rfilename": "data/000113.parquet"
+    },
+    {
+      "rfilename": "data/000114.parquet"
+    },
+    {
+      "rfilename": "data/000115.parquet"
+    },
+    {
+      "rfilename": "data/000116.parquet"
+    },
+    {
+      "rfilename": "data/000117.parquet"
+    },
+    {
+      "rfilename": "data/000118.parquet"
+    },
+    {
+      "rfilename": "data/000119.parquet"
+    },
+    {
+      "rfilename": "data/000120.parquet"
+    },
+    {
+      "rfilename": "data/000121.parquet"
+    },
+    {
+      "rfilename": "data/000122.parquet"
+    },
+    {
+      "rfilename": "data/000123.parquet"
+    },
+    {
+      "rfilename": "data/000124.parquet"
+    },
+    {
+      "rfilename": "data/000125.parquet"
+    },
+    {
+      "rfilename": "data/000126.parquet"
+    },
+    {
+      "rfilename": "data/000127.parquet"
+    },
+    {
+      "rfilename": "data/000128.parquet"
+    },
+    {
+      "rfilename": "data/000129.parquet"
+    },
+    {
+      "rfilename": "data/000130.parquet"
+    },
+    {
+      "rfilename": "data/000131.parquet"
+    },
+    {
+      "rfilename": "data/000132.parquet"
+    },
+    {
+      "rfilename": "data/000133.parquet"
+    },
+    {
+      "rfilename": "data/000134.parquet"
+    },
+    {
+      "rfilename": "data/000135.parquet"
+    },
+    {
+      "rfilename": "data/000136.parquet"
+    },
+    {
+      "rfilename": "data/000137.parquet"
+    },
+    {
+      "rfilename": "data/000138.parquet"
+    },
+    {
+      "rfilename": "data/000139.parquet"
+    },
+    {
+      "rfilename": "data/000140.parquet"
+    },
+    {
+      "rfilename": "data/000141.parquet"
+    },
+    {
+      "rfilename": "data/000142.parquet"
+    },
+    {
+      "rfilename": "data/000143.parquet"
+    },
+    {
+      "rfilename": "data/000144.parquet"
+    },
+    {
+      "rfilename": "data/000145.parquet"
+    },
+    {
+      "rfilename": "data/000146.parquet"
+    },
+    {
+      "rfilename": "data/000147.parquet"
+    },
+    {
+      "rfilename": "data/000148.parquet"
+    },
+    {
+      "rfilename": "data/000149.parquet"
+    },
+    {
+      "rfilename": "data/000150.parquet"
+    },
+    {
+      "rfilename": "data/000151.parquet"
+    },
+    {
+      "rfilename": "data/000152.parquet"
+    },
+    {
+      "rfilename": "data/000153.parquet"
+    },
+    {
+      "rfilename": "data/000154.parquet"
+    },
+    {
+      "rfilename": "data/000155.parquet"
+    },
+    {
+      "rfilename": "data/000156.parquet"
+    },
+    {
+      "rfilename": "data/000157.parquet"
+    },
+    {
+      "rfilename": "data/000158.parquet"
+    },
+    {
+      "rfilename": "data/000159.parquet"
+    },
+    {
+      "rfilename": "data/000160.parquet"
+    },
+    {
+      "rfilename": "data/000161.parquet"
+    },
+    {
+      "rfilename": "data/000162.parquet"
+    },
+    {
+      "rfilename": "data/000163.parquet"
+    },
+    {
+      "rfilename": "data/000164.parquet"
+    },
+    {
+      "rfilename": "data/000165.parquet"
+    },
+    {
+      "rfilename": "data/000166.parquet"
+    },
+    {
+      "rfilename": "data/000167.parquet"
+    },
+    {
+      "rfilename": "data/000168.parquet"
+    },
+    {
+      "rfilename": "data/000169.parquet"
+    },
+    {
+      "rfilename": "data/000170.parquet"
+    },
+    {
+      "rfilename": "data/000171.parquet"
+    },
+    {
+      "rfilename": "data/000172.parquet"
+    },
+    {
+      "rfilename": "data/000173.parquet"
+    },
+    {
+      "rfilename": "data/000174.parquet"
+    },
+    {
+      "rfilename": "data/000175.parquet"
+    },
+    {
+      "rfilename": "data/000176.parquet"
+    },
+    {
+      "rfilename": "data/000177.parquet"
+    },
+    {
+      "rfilename": "data/000178.parquet"
+    },
+    {
+      "rfilename": "data/000179.parquet"
+    },
+    {
+      "rfilename": "data/000180.parquet"
+    },
+    {
+      "rfilename": "data/000181.parquet"
+    },
+    {
+      "rfilename": "data/000182.parquet"
+    },
+    {
+      "rfilename": "data/000183.parquet"
+    },
+    {
+      "rfilename": "data/000184.parquet"
+    },
+    {
+      "rfilename": "data/000185.parquet"
+    },
+    {
+      "rfilename": "data/000186.parquet"
+    },
+    {
+      "rfilename": "data/000187.parquet"
+    },
+    {
+      "rfilename": "data/000188.parquet"
+    },
+    {
+      "rfilename": "data/000189.parquet"
+    },
+    {
+      "rfilename": "data/000190.parquet"
+    },
+    {
+      "rfilename": "data/000191.parquet"
+    },
+    {
+      "rfilename": "data/000192.parquet"
+    },
+    {
+      "rfilename": "data/000193.parquet"
+    },
+    {
+      "rfilename": "data/000194.parquet"
+    },
+    {
+      "rfilename": "data/000195.parquet"
+    },
+    {
+      "rfilename": "data/000196.parquet"
+    },
+    {
+      "rfilename": "data/000197.parquet"
+    },
+    {
+      "rfilename": "data/000198.parquet"
+    },
+    {
+      "rfilename": "data/000199.parquet"
+    },
+    {
+      "rfilename": "data/000200.parquet"
+    },
+    {
+      "rfilename": "data/000201.parquet"
+    },
+    {
+      "rfilename": "data/000202.parquet"
+    },
+    {
+      "rfilename": "data/000203.parquet"
+    },
+    {
+      "rfilename": "data/000204.parquet"
+    },
+    {
+      "rfilename": "scene_list_adaptive.jsonl"
+    },
+    {
+      "rfilename": "scene_list_all.jsonl"
+    },
+    {
+      "rfilename": "scene_list_content.jsonl"
+    },
+    {
+      "rfilename": "scene_list_threshold.jsonl"
+    },
+    {
+      "rfilename": "test.bin"
+    }
+  ],
+  "createdAt": "2024-03-18T06:30:44.000Z"
+}
\ No newline at end of file