Skip to content

Commit

Permalink
bump to v2.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Mon-ius committed May 12, 2024
1 parent c9b8c3f commit 6332e4f
Show file tree
Hide file tree
Showing 16 changed files with 1,738 additions and 40 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = ["hfd", "hfd-cli"]
# members = ["hfd", "hfd-cli"]
members = ["hfd-cli"]
resolver = "2"

[profile.release]
Expand Down
32 changes: 31 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,34 @@
[![GitHub release (with filter)](https://img.shields.io/github/v/release/AUTOM77/hfd?logo=github)](https://github.com/AUTOM77/hfd/releases)


🎈Rust-based interface for Huggingface 🤗 download
🎈Rust-based interface for Huggingface 🤗 download.

`./hdf "https://huggingface.co/deepseek-ai/DeepSeek-V2"`

For a more convinent user experience, execute:

```bash
cat <<EOF | sudo tee -a /etc/security/limits.conf
root soft nofile 20000000
root hard nofile 20000000
* hard nofile 20000000
* soft nofile 20000000
EOF

cat <<EOF | sudo tee /etc/sysctl.d/bbr.conf
net.core.default_qdisc=fq_codel
net.ipv4.tcp_congestion_control=bbr
net.ipv4.tcp_moderate_rcvbuf = 1
net.ipv4.tcp_mem = '10000000 10000000 10000000'
net.ipv4.tcp_rmem = '1024 4096 16384'
net.ipv4.tcp_wmem = '1024 4096 16384'
net.core.wmem_max = 26214400
net.core.rmem_max = 26214400
fs.file-max = 12000500
fs.nr_open = 20000500
EOF
```

- https://gist.github.com/mustafaturan/47268d8ad6d56cadda357e4c438f51ca
14 changes: 14 additions & 0 deletions example/hyper/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "rddd"
version = "0.1.0"
edition = "2021"

[dependencies]
hyper = { version = "1.3.1", default-features = false, features = ["http2", "client"] }
hyper-util = { version = "0.1.3", features = ["client", "http2", "tokio"] }
tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] }
tokio-rustls = { version = "0.26.0", default-features = false, features = ["ring", "tls12"] }
rustls-pki-types = { version = "1.7.0", features = ["alloc"] }
tokio-stream = "0.1.15"
webpki-roots = "0.26"
http-body-util = "0.1"
157 changes: 157 additions & 0 deletions example/hyper/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use std::net::ToSocketAddrs;
use hyper::body::Bytes;
use hyper::header::{RANGE, HOST, CONTENT_RANGE, LOCATION};

use http_body_util::BodyExt;
use tokio::io::{AsyncWriteExt, AsyncSeekExt, SeekFrom};
use tokio_rustls::rustls;
use tokio_stream::StreamExt;

const ALPN_H2: &str = "h2";
const CHUNK_SIZE: usize = 100_000_000;

const URL: &str = "https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors";
const FILE: &str = "fp16.safetensors";

async fn download_chunk(u: hyper::Uri, s: usize, e: usize) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let host = u.host().expect("no host");
let port = u.port_u16().unwrap_or(443);
let addr = format!("{}:{}", host, port).to_socket_addrs()?.next().unwrap();

let conf = std::sync::Arc::new({
let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut c = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned());
c
});

let tcp = tokio::net::TcpStream::connect(&addr).await?;
let domain = rustls_pki_types::ServerName::try_from(host)?.to_owned();;
let connector = tokio_rustls::TlsConnector::from(conf);

let stream = connector.connect(domain, tcp).await?;
let _io = hyper_util::rt::TokioIo::new(stream);
let exec = hyper_util::rt::tokio::TokioExecutor::new();

let (mut client, mut h2) = hyper::client::conn::http2::handshake(exec, _io).await?;
tokio::spawn(async move {
if let Err(e) = h2.await {
println!("Error: {:?}", e);
}
});

let range = format!("bytes={s}-{e}");

let req = hyper::Request::builder()
.uri(u)
.header("user-agent", "hyper-client-http2")
.header(RANGE, range)
.version(hyper::Version::HTTP_2)
.body(http_body_util::Empty::<Bytes>::new())?;

let mut response = client.send_request(req).await?;

let mut file = tokio::fs::OpenOptions::new().write(true).open(FILE).await?;
file.seek(SeekFrom::Start(s as u64)).await?;
while let Some(chunk) = response.frame().await {
let chunk = chunk?;
if let Some(c) = chunk.data_ref() {
tokio::io::copy(&mut c.as_ref(), &mut file).await?;
}
}
Ok(())
}

async fn download() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut url: hyper::Uri = URL.parse()?;
let host = url.host().expect("no host");
let port = url.port_u16().unwrap_or(443);
let addr = format!("{}:{}", host, port).to_socket_addrs()?.next().unwrap();

let conf = std::sync::Arc::new({
let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut c = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned());
c
});

let tcp = tokio::net::TcpStream::connect(&addr).await?;
let domain = rustls_pki_types::ServerName::try_from(host)?.to_owned();;
let connector = tokio_rustls::TlsConnector::from(conf);

let stream = connector.connect(domain, tcp).await?;
let _io = hyper_util::rt::TokioIo::new(stream);
let exec = hyper_util::rt::tokio::TokioExecutor::new();

let (mut client, mut h2) = hyper::client::conn::http2::handshake(exec, _io).await?;
tokio::spawn(async move {
if let Err(e) = h2.await {
println!("Error: {:?}", e);
}
});

let req = hyper::Request::builder()
.uri(url.clone())
.header("user-agent", "hyper-client-http2")
.header(RANGE, "bytes=0-0")
.version(hyper::Version::HTTP_2)
.body(http_body_util::Empty::<Bytes>::new())?;

let mut response = client.send_request(req).await?;
while let Some(location) = response.headers().get(LOCATION) {
let _cdn: hyper::Uri = location.to_str()?.parse()?;
let _req = hyper::Request::builder()
.uri(_cdn.clone())
.header("user-agent", "hyper-client-http2")
.version(hyper::Version::HTTP_2)
.body(http_body_util::Empty::<Bytes>::new())?;
response = client.send_request(_req).await?;
url = _cdn;
}

println!("{:?}", url);
let req = hyper::Request::builder()
.uri(url.clone())
.header("user-agent", "hyper-client-http2")
.header(RANGE, "bytes=0-0")
.version(hyper::Version::HTTP_2)
.body(http_body_util::Empty::<Bytes>::new())?;
let response = client.send_request(req).await?;

println!("{:?}", response);
let length: usize = response
.headers()
.get(CONTENT_RANGE)
.ok_or("Content-Length not found")?
.to_str()?.rsplit('/').next()
.and_then(|s| s.parse().ok())
.ok_or("Failed to parse size")?;

let _ = tokio::fs::File::create(FILE).await?.set_len(length as u64).await?;
let tasks: Vec<_> = (0..length)
.into_iter()
.step_by(CHUNK_SIZE)
.map(|s| {
let _url = url.clone();
let e = std::cmp::min(s + CHUNK_SIZE - 1, length);
tokio::spawn(async move { download_chunk(_url, s, e).await })
})
.collect();

for task in tasks {
let _ = task.await.unwrap();
}
Ok(())
}

fn main() {
let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
let start_time = std::time::Instant::now();
let _ = rt.block_on(download());
println!("Processing time: {:?}", start_time.elapsed());
}

2 changes: 1 addition & 1 deletion example/mirror.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async fn main() {

let _filename = api
.model("ByteDance/Hyper-SD".to_string())
.get("Hyper-SDXL-8steps-lora.safetensors")
.get("Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
.await
.unwrap();
}
9 changes: 9 additions & 0 deletions example/rdd/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "rdd"
version = "0.1.0"
edition = "2021"

[dependencies]
reqwest = { version = "0.12.4", default-features = false, features = ["stream", "http2", "rustls-tls"] }
tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] }
tokio-stream = "0.1.15"
63 changes: 63 additions & 0 deletions example/rdd/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use tokio;
use reqwest::header::{RANGE, CONTENT_RANGE};
use tokio::io::{AsyncSeekExt, SeekFrom};
use tokio_stream::StreamExt;

const CHUNK_SIZE: usize = 10_000_000;
const URL: &str = "https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors";
const FILE: &str = "fp16.safetensors";

async fn download_chunk(s: usize, e: usize) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let client = reqwest::Client::builder()
.http2_keep_alive_timeout(tokio::time::Duration::from_secs(15)).build()?;
let range = format!("bytes={s}-{e}");

let response = client.get(URL).header(RANGE, range).send().await?;
let mut stream = response.bytes_stream();

let mut file = tokio::fs::OpenOptions::new().write(true).open(FILE).await?;
file.seek(SeekFrom::Start(s as u64)).await?;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
tokio::io::copy(&mut chunk.as_ref(), &mut file).await?;
}
Ok(())
}

async fn download() -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::Client::builder()
.http2_keep_alive_timeout(tokio::time::Duration::from_secs(15)).build()?;

let response = client.get(URL).header(RANGE, "bytes=0-0").send().await?;
let length: usize = response
.headers()
.get(CONTENT_RANGE)
.ok_or("Content-Length not found")?
.to_str()?.rsplit('/').next()
.and_then(|s| s.parse().ok())
.ok_or("Failed to parse size")?;

let _ = tokio::fs::File::create(FILE).await?.set_len(length as u64).await?;

let tasks: Vec<_> = (0..length)
.into_iter()
.step_by(CHUNK_SIZE)
.map(|s| {
let e = std::cmp::min(s + CHUNK_SIZE - 1, length);
tokio::spawn(async move { download_chunk(s, e).await })
})
.collect();

for task in tasks {
let _ = task.await.unwrap();
}
Ok(())
}

fn main() {
let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();

let start_time = std::time::Instant::now();
let _ = rt.block_on(download());
println!("Processing time: {:?}", start_time.elapsed());
}
15 changes: 15 additions & 0 deletions example/simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use tokio;

#[tokio::main]
async fn main() {
let start_time = std::time::Instant::now();

let api = libhfd::api::tokio::Api::new().unwrap();

let _filename = api
.model("ByteDance/Hyper-SD".to_string())
.get("Hyper-SDXL-8steps-lora.safetensors")
.await
.unwrap();
println!("Processing time: {:?}", start_time.elapsed());
}
7 changes: 4 additions & 3 deletions hfd-cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
[package]
name = "hfd-cli"
version = "0.1.8"
version = "0.2.0"
edition = "2021"

[dependencies]
tokio = { version = "1.37.0", default-features = false, features = ["rt", "rt-multi-thread", "fs"] }
clap = { version= "4.5.4", features=["derive"] }
reqwest = { version = "0.12.4", default-features = false, features = ["stream", "http2", "json", "rustls-tls"] }
tokio = { version = "1.37.0", default-features = false, features = ["rt", "fs"] }
serde_json = { version = "1.0.116" }
tokio-stream = "0.1.15"
libhfd = { path = "../hfd" }

[[bin]]
name = "hfd"
Expand Down
41 changes: 29 additions & 12 deletions hfd-cli/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
use tokio;
// use tokio_stream::StreamExt;
use clap::{Args, Parser};
use hfd_cli::_rt;

#[tokio::main]
async fn main() {
let start_time = std::time::Instant::now();
#[derive(Args)]
#[group(required = false, multiple = true)]
struct Opts {
#[arg(short = 't', long, name = "TOKEN")]
token: Option<String>,

#[arg(short = 'd', long, name = "DIR", help = "Save it to `$DIR` or `.` ")]
dir: Option<String>,
#[arg(short = 'm', long, name = "MIRROR", help = "Not yet applied")]
mirror: Option<String>,
#[arg(short = 'p', long, name = "PROXY", help = "Not yet applied")]
proxy: Option<String>,
}

#[derive(Parser)]
struct Cli {
url: String,

let api = libhfd::api::tokio::Api::new().unwrap();
#[command(flatten)]
opt: Opts,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let start_time = std::time::Instant::now();

let _filename = api
.model("ByteDance/Hyper-SD".to_string())
.get("Hyper-SDXL-8steps-lora.safetensors")
.await
.unwrap();
let cli = Cli::parse();
let _ = _rt(&cli.url, cli.opt.token.as_deref(), cli.opt.dir.as_deref());
println!("Processing time: {:?}", start_time.elapsed());
}
Ok(())
}
Loading

0 comments on commit 6332e4f

Please sign in to comment.