From c96f87cbde738b99b0bdaf00339f1ba780d2a751 Mon Sep 17 00:00:00 2001 From: deedy5 <65482418+deedy5@users.noreply.github.com> Date: Wed, 1 May 2024 18:24:50 +0300 Subject: [PATCH] Migrate to async reqwest_impersonate::Client, use rayon ThreadPool and Tokio multi-thread runtime --- Cargo.lock | 64 ++++++++--- Cargo.toml | 3 +- src/lib.rs | 300 ++++++++++++++++++++++++++++-------------------- src/response.rs | 2 +- 4 files changed, 231 insertions(+), 138 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d314f0c..4a58358 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,9 +201,9 @@ checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "cc" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" +checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd" [[package]] name = "cexpr" @@ -293,6 +293,31 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "deranged" version = "0.3.11" @@ -390,12 +415,6 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - [[package]] name = "futures-sink" version = "0.3.30" @@ -415,12 +434,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", - "futures-io", "futures-task", - "memchr", "pin-project-lite", "pin-utils", - "slab", ] [[package]] @@ -918,8 +934,10 @@ dependencies = [ "encoding_rs", "pyo3", "pythonize", + "rayon", "reqwest-impersonate", "serde_json", + "tokio", ] [[package]] @@ -941,6 +959,26 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.1" @@ -981,9 +1019,9 @@ checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "reqwest-impersonate" -version = "0.11.72" +version = "0.11.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b725314a794f526bac22292a10e2bf581cbea5ebd527fa3032e2cfd829068be" +checksum = "78ba0d953bf10f7aa5d454bc8d9136efd94c3cfb4d6db4f02323b54957f7310f" dependencies = [ "async-compression", "base64", diff --git a/Cargo.toml b/Cargo.toml index 582bb61..8a63da5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ crate-type = ["cdylib"] pyo3 = { version = "0.21", features = ["extension-module", "abi3-py38"] } reqwest-impersonate = { version = "0.11", default-features = false, features = [ "cookies", - "blocking", "boring-tls", "impersonate", "json", @@ -28,6 +27,8 @@ reqwest-impersonate = { version = "0.11", default-features = false, features = [ encoding_rs = "0.8" pythonize = "0.21" serde_json = "1" +tokio = { version = "1", features = ["fs", "rt-multi-thread"] } +rayon = "1" [profile.release] codegen-units = 1 diff --git a/src/lib.rs b/src/lib.rs index e6929b6..0c9d77c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,43 @@ use std::collections::HashMap; use std::str::FromStr; +use std::sync::{mpsc, Arc, OnceLock}; use std::time::Duration; use pyo3::exceptions; use pyo3::prelude::*; -use pyo3::types::PyBytes; -use pyo3::types::{PyDict, PyString}; -use reqwest_impersonate::blocking::multipart; +use pyo3::types::{PyBytes, PyDict, PyString}; +use rayon::{ThreadPool, ThreadPoolBuilder}; use reqwest_impersonate::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest_impersonate::impersonate::Impersonate; +use reqwest_impersonate::multipart; use reqwest_impersonate::redirect::Policy; use reqwest_impersonate::Method; +use tokio::runtime::{self, Runtime}; mod response; use response::Response; +// Rayon global thread pool +fn cpu_pool() -> &'static ThreadPool { + static CPU_POOL: OnceLock = OnceLock::new(); + CPU_POOL.get_or_init(|| ThreadPoolBuilder::new().num_threads(32).build().unwrap()) +} + +// Tokio global multi-thread runtime +fn runtime() -> &'static Runtime { + static RUNTIME: OnceLock = OnceLock::new(); + RUNTIME.get_or_init(|| { + runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + }) +} + #[pyclass] /// HTTP client that can impersonate web browsers. -struct Client { - client: reqwest_impersonate::blocking::Client, +pub struct Client { + client: Arc, auth: Option<(String, Option)>, auth_bearer: Option, params: Option>, @@ -94,10 +113,10 @@ impl Client { )); } - let mut client_builder = reqwest_impersonate::blocking::Client::builder() + // Client builder + let mut client_builder = reqwest_impersonate::Client::builder() .enable_ech_grease(true) - .permute_extensions(true) - .timeout(timeout.map(Duration::from_secs_f64)); + .permute_extensions(true); // Headers if let Some(headers) = headers { @@ -132,6 +151,11 @@ impl Client { client_builder = client_builder.proxy(proxy); } + // Timeout + if let Some(seconds) = timeout { + client_builder = client_builder.timeout(Duration::from_secs_f64(seconds)); + } + // Impersonate if let Some(impersonation_type) = impersonate { let impersonation = Impersonate::from_str(impersonation_type).map_err(|_| { @@ -166,9 +190,10 @@ impl Client { _ => (), } - let client = client_builder - .build() - .map_err(|_| PyErr::new::("Failed to build client"))?; + let client = + Arc::new(client_builder.build().map_err(|_| { + PyErr::new::("Failed to build client") + })?); Ok(Client { client, @@ -215,115 +240,115 @@ impl Client { auth_bearer: Option, timeout: Option, ) -> PyResult { - // Check if method is POST || PUT || PATCH - let is_post_put_patch = method == "POST" || method == "PUT" || method == "PATCH"; - - // Method - let method = match method { - "GET" => Ok(Method::GET), - "POST" => Ok(Method::POST), - "HEAD" => Ok(Method::HEAD), - "OPTIONS" => Ok(Method::OPTIONS), - "PUT" => Ok(Method::PUT), - "PATCH" => Ok(Method::PATCH), - "DELETE" => Ok(Method::DELETE), - &_ => Err(PyErr::new::( - "Unrecognized HTTP method", - )), - }; - let method = method?; - - // Create request builder - let mut request_builder = self.client.request(method, url); - - // Params (use the provided `params` if available; otherwise, fall back to `self.params`) - let params_to_use = params.or(self.params.clone()).unwrap_or_default(); - if !params_to_use.is_empty() { - request_builder = request_builder.query(¶ms_to_use); - } - - // Headers - if let Some(headers) = headers { - let mut headers_new = HeaderMap::new(); - for (key, value) in headers { - headers_new.insert( - HeaderName::from_bytes(key.as_bytes()).map_err(|_| { - PyErr::new::("Invalid header name") - })?, - HeaderValue::from_str(&value).map_err(|_| { - PyErr::new::("Invalid header value") - })?, - ); + let client = Arc::clone(&self.client); + let method = method.to_owned(); + let url = url.to_owned(); + let auth = auth.or(self.auth.clone()); + let auth_bearer = auth_bearer.or(self.auth_bearer.clone()); + let params = params.or(self.params.clone()); + + let future = async move { + // Check if method is POST || PUT || PATCH + let is_post_put_patch = method == "POST" || method == "PUT" || method == "PATCH"; + + // Method + let method = match method.as_str() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PUT" => Ok(Method::PUT), + "PATCH" => Ok(Method::PATCH), + "DELETE" => Ok(Method::DELETE), + &_ => Err(PyErr::new::( + "Unrecognized HTTP method", + )), + }?; + + // Create request builder + let mut request_builder = client.request(method, url); + + // Params + if let Some(params) = params { + request_builder = request_builder.query(¶ms); } - request_builder = request_builder.headers(headers_new); - } - // Only if method POST || PUT || PATCH - if is_post_put_patch { - // Content - if let Some(content) = content { - request_builder = request_builder.body(content); - } - // Data - if let Some(data) = data { - request_builder = request_builder.form(&data); - } - // Files - if let Some(files) = files { - let mut form = multipart::Form::new(); - for (field, path) in files { - form = form.file(field, path)?; + // Headers + if let Some(headers) = headers { + let mut headers_new = HeaderMap::new(); + for (key, value) in headers { + headers_new.insert( + HeaderName::from_bytes(key.as_bytes()).map_err(|_| { + PyErr::new::("Invalid header name") + })?, + HeaderValue::from_str(&value).map_err(|_| { + PyErr::new::("Invalid header value") + })?, + ); } - request_builder = request_builder.multipart(form); + request_builder = request_builder.headers(headers_new); } - } - // Auth - match (auth, auth_bearer, &self.auth, &self.auth_bearer) { - (Some((username, password)), None, _, _) => { - request_builder = request_builder.basic_auth(username, password.as_deref()); - } - (None, Some(token), _, _) => { - request_builder = request_builder.bearer_auth(token); - } - (None, None, Some((username, password)), None) => { - request_builder = request_builder.basic_auth(username, password.as_deref()); - } - (None, None, None, Some(token)) => { - request_builder = request_builder.bearer_auth(token); + // Only if method POST || PUT || PATCH + if is_post_put_patch { + // Content + if let Some(content) = content { + request_builder = request_builder.body(content); + } + // Data + if let Some(data) = data { + request_builder = request_builder.form(&data); + } + // Files + if let Some(files) = files { + let mut form = multipart::Form::new(); + for (field, path) in files { + let file_content = tokio::fs::read(&path).await.map_err(|e| { + PyErr::new::(format!( + "Error reading file {}: {}", + path, e + )) + })?; + let part = multipart::Part::bytes(file_content); + form = form.part(field, part); + } + request_builder = request_builder.multipart(form); + } } - (Some(_), Some(_), None, None) | (None, None, Some(_), Some(_)) => { - return Err(PyErr::new::( - "Cannot provide both auth and auth_bearer", - )); + + // Auth + match (auth, auth_bearer) { + (Some((username, password)), None) => { + request_builder = request_builder.basic_auth(username, password.as_deref()); + } + (None, Some(token)) => { + request_builder = request_builder.bearer_auth(token); + } + (Some(_), Some(_)) => { + return Err(PyErr::new::( + "Cannot provide both auth and auth_bearer", + )); + } + _ => {} // No authentication provided } - _ => {} // No authentication provided - } - // Timeout - if let Some(seconds) = timeout { - request_builder = request_builder.timeout(Duration::from_secs_f64(seconds)); - } + // Timeout + if let Some(seconds) = timeout { + request_builder = request_builder.timeout(Duration::from_secs_f64(seconds)); + } - // Send request | release GIL - let resp = py.allow_threads(|| { - request_builder.send().map_err(|e| { + // Send the request and await the response + let resp = request_builder.send().await.map_err(|e| { PyErr::new::(format!("Error in request: {}", e)) - }) - })?; - - // Response items - let cookies_dict = PyDict::new_bound(py); - for cookie in resp.cookies() { - let key = cookie.name().to_string(); - let value = cookie.value().to_string(); - cookies_dict.set_item(key, value)?; - } - let cookies = cookies_dict.unbind(); + })?; - // Encoding from "Content-Type" header or "UTF-8" - let encoding = { - let encoding_str = resp + // Response items + let cookies: HashMap = resp + .cookies() + .map(|cookie| (cookie.name().to_string(), cookie.value().to_string())) + .collect(); + // Encoding from "Content-Type" header or "UTF-8" + let encoding = resp .headers() .get("Content-Type") .and_then(|ct| ct.to_str().ok()) @@ -340,25 +365,54 @@ impl Client { }) }) .unwrap_or("UTF-8".to_string()); - PyString::new_bound(py, &encoding_str).unbind() + let headers: HashMap = resp + .headers() + .iter() + .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + let status_code = resp.status().as_u16(); + let url = resp.url().to_string(); + let buf = resp.bytes().await.map_err(|e| { + PyErr::new::(format!( + "Error reading response bytes: {}", + e + )) + })?; + Ok((buf, cookies, encoding, headers, status_code, url)) }; + // Execute an async future in Python, releasing the GIL for concurrency. + // Uses Rayon's global thread pool and Tokio global runtime to block on the future. + let (tx, rx) = mpsc::sync_channel(1); + py.allow_threads(|| { + cpu_pool().install(|| { + let result = runtime().block_on(future); + _ = tx.send(result); + }); + }); + let result = rx.recv().map_err(|e| { + PyErr::new::(format!("Error executing future: {}", e)) + })?; + let (f_buf, f_cookies, f_encoding, f_headers, f_status_code, f_url) = match result { + Ok(value) => value, + Err(e) => return Err(e), + }; + + // Response items + let cookies_dict = PyDict::new_bound(py); + for (key, value) in f_cookies { + cookies_dict.set_item(key, value)?; + } + let cookies = cookies_dict.unbind(); + let encoding = PyString::new_bound(py, f_encoding.as_str()).unbind(); let headers_dict = PyDict::new_bound(py); - for (key, value) in resp.headers().iter() { - let key_str = key.as_str(); - let value_str = value.to_str().unwrap_or(""); - headers_dict.set_item(key_str, value_str)?; + for (key, value) in f_headers { + headers_dict.set_item(key, value)?; } let headers = headers_dict.unbind(); - - let status_code = resp.status().as_u16().into_py(py); - - let url = PyString::new_bound(py, resp.url().as_ref()).into(); - - let buf = resp.bytes().map_err(|e| { - PyErr::new::(format!("Error reading response bytes: {}", e)) - })?; - let content = PyBytes::new_bound(py, &buf).unbind(); + let status_code = f_status_code.into_py(py); + let url = PyString::new_bound(py, &f_url).unbind(); + let content = PyBytes::new_bound(py, &f_buf).unbind(); Ok(Response { content, diff --git a/src/response.rs b/src/response.rs index 5d3948b..ac8e9d9 100644 --- a/src/response.rs +++ b/src/response.rs @@ -22,7 +22,7 @@ pub struct Response { #[pyo3(get)] pub status_code: Py, #[pyo3(get)] - pub url: Py, + pub url: Py, } #[pymethods]