Skip to content

Commit

Permalink
Use anyhow for error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
deedy5 committed Aug 13, 2024
1 parent 99ed93b commit 5afc54d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 78 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ name = "primp"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module", "abi3-py38", "indexmap"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3-py38", "indexmap", "anyhow"] }
anyhow = "1"
rquest = { version = "0.20", default-features = false, features = [
"boring-tls",
"http2",
Expand Down
103 changes: 34 additions & 69 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::{Arc, OnceLock};
use std::time::Duration;

use ahash::RandomState;
use anyhow::{anyhow, Result};
use indexmap::IndexMap;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyString};
use rquest::header::{HeaderMap, HeaderName, HeaderValue, COOKIE};
Expand Down Expand Up @@ -109,11 +109,9 @@ impl Client {
verify: Option<bool>,
http1: Option<bool>,
http2: Option<bool>,
) -> PyResult<Self> {
) -> Result<Self> {
if auth.is_some() && auth_bearer.is_some() {
return Err(PyErr::new::<exceptions::PyValueError, _>(
"Cannot provide both auth and auth_bearer",
));
return Err(anyhow!("Cannot provide both auth and auth_bearer"));
}

// Client builder
Expand All @@ -123,9 +121,7 @@ impl Client {

// Impersonate
if let Some(impersonation_type) = impersonate {
let impersonation = Impersonate::from_str(impersonation_type).map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Invalid impersonate param")
})?;
let impersonation = Impersonate::from_str(impersonation_type).map_err(|e| anyhow!(e))?;
client_builder = client_builder.impersonate(impersonation);
}

Expand All @@ -134,12 +130,8 @@ impl Client {
let mut headers_new = HeaderMap::with_capacity(headers.len());
for (key, value) in headers {
headers_new.insert(
HeaderName::from_bytes(key.as_bytes()).map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Invalid header name")
})?,
HeaderValue::from_str(&value).map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Invalid header value")
})?,
HeaderName::from_bytes(key.as_bytes())?,
HeaderValue::from_str(&value)?,
);
}
client_builder = client_builder.default_headers(headers_new);
Expand All @@ -157,8 +149,7 @@ impl Client {

// Proxy
if let Some(proxy_url) = proxy {
let proxy = rquest::Proxy::all(proxy_url)
.map_err(|_| PyErr::new::<exceptions::PyValueError, _>("Invalid proxy URL"))?;
let proxy = rquest::Proxy::all(proxy_url)?;
client_builder = client_builder.proxy(proxy);
}

Expand All @@ -168,35 +159,29 @@ impl Client {
}

// Redirects
let max_redirects = max_redirects.unwrap_or(20); // Default to 20 if not provided
let max_redirects = max_redirects.unwrap_or(20);
if follow_redirects.unwrap_or(true) {
client_builder = client_builder.redirect(Policy::limited(max_redirects));
} else {
client_builder = client_builder.redirect(Policy::none());
}

// Verify
let verify = verify.unwrap_or(true);
let verify: bool = verify.unwrap_or(true);
if !verify {
client_builder = client_builder.danger_accept_invalid_certs(true);
}

// Http version: http1 || http2
match (http1, http2) {
(Some(true), Some(true)) => {
return Err(PyErr::new::<exceptions::PyValueError, _>(
"Both http1 and http2 cannot be true",
));
}
(Some(true), Some(true)) => return Err(anyhow!("Both http1 and http2 cannot be true")),
(Some(true), _) => client_builder = client_builder.http1_only(),
(_, Some(true)) => client_builder = client_builder.http2_prior_knowledge(),
_ => (),
}

let client =
Arc::new(client_builder.build().map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Failed to build client")
})?);
Arc::new(client_builder.build()?);

Ok(Client {
client,
Expand Down Expand Up @@ -249,7 +234,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Arc::clone(&self.client);
let auth = auth.or(self.auth.clone());
let auth_bearer = auth_bearer.or(self.auth_bearer.clone());
Expand Down Expand Up @@ -277,9 +262,7 @@ impl Client {
"PUT" => Ok(Method::PUT),
"PATCH" => Ok(Method::PATCH),
"DELETE" => Ok(Method::DELETE),
&_ => Err(PyErr::new::<exceptions::PyException, _>(
"Unrecognized HTTP method",
)),
&_ => Err(anyhow!("Unrecognized HTTP method")),
}?;

// Create request builder
Expand All @@ -295,12 +278,8 @@ impl Client {
let mut headers_new = HeaderMap::with_capacity(headers.len());
for (key, value) in headers {
headers_new.insert(
HeaderName::from_bytes(key.as_bytes()).map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Invalid header name")
})?,
HeaderValue::from_str(&value).map_err(|_| {
PyErr::new::<exceptions::PyValueError, _>("Invalid header value")
})?,
HeaderName::from_bytes(key.as_bytes())?,
HeaderValue::from_str(&value)?,
);
}
request_builder = request_builder.headers(headers_new);
Expand Down Expand Up @@ -355,11 +334,7 @@ impl Client {
(None, Some(token)) => {
request_builder = request_builder.bearer_auth(token);
}
(Some(_), Some(_)) => {
return Err(PyErr::new::<exceptions::PyValueError, _>(
"Cannot provide both auth and auth_bearer",
));
}
(Some(_), Some(_)) => return Err(anyhow!("Cannot provide both auth and auth_bearer")),
_ => {} // No authentication provided
}

Expand All @@ -369,9 +344,7 @@ impl Client {
}

// Send the request and await the response
let resp = request_builder.send().await.map_err(|e| {
PyErr::new::<exceptions::PyException, _>(format!("Error in request: {}", e))
})?;
let resp = request_builder.send().await?;

// Response items
let cookies: IndexMap<String, String, RandomState> = resp
Expand All @@ -385,12 +358,7 @@ impl Client {
.collect();
let status_code = resp.status().as_u16();
let url = resp.url().to_string();
let buf = resp.bytes().await.map_err(|e| {
PyErr::new::<exceptions::PyException, _>(format!(
"Error reading response bytes: {}",
e
))
})?;
let buf = resp.bytes().await?;
let encoding = get_encoding_from_headers(&headers)
.or_else(|| get_encoding_from_content(&buf))
.unwrap_or_else(|| "UTF-8".to_string());
Expand All @@ -400,10 +368,7 @@ impl Client {
// Execute an async future, releasing the Python GIL for concurrency.
// Use Tokio global runtime to block on the future.
let result = py.allow_threads(|| runtime().block_on(future));
let (f_buf, f_cookies, f_encoding, f_headers, f_status_code, f_url) = match result {
Ok(value) => value,
Err(e) => return Err(e),
};
let (f_buf, f_cookies, f_encoding, f_headers, f_status_code, f_url) = result?;

// Response items
let cookies_dict = PyDict::new_bound(py);
Expand Down Expand Up @@ -442,7 +407,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"GET",
Expand Down Expand Up @@ -471,7 +436,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"HEAD",
Expand Down Expand Up @@ -500,7 +465,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"OPTIONS",
Expand Down Expand Up @@ -529,7 +494,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"DELETE",
Expand Down Expand Up @@ -563,7 +528,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"POST",
Expand Down Expand Up @@ -597,7 +562,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"PUT",
Expand Down Expand Up @@ -631,7 +596,7 @@ impl Client {
auth: Option<(String, Option<String>)>,
auth_bearer: Option<String>,
timeout: Option<f64>,
) -> PyResult<Response> {
) -> Result<Response> {
self.request(
py,
"PATCH",
Expand Down Expand Up @@ -670,7 +635,7 @@ fn request(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -719,7 +684,7 @@ fn get(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -763,7 +728,7 @@ fn head(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -807,7 +772,7 @@ fn options(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -851,7 +816,7 @@ fn delete(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -899,7 +864,7 @@ fn post(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -951,7 +916,7 @@ fn put(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down Expand Up @@ -1003,7 +968,7 @@ fn patch(
timeout: Option<f64>,
impersonate: Option<&str>,
verify: Option<bool>,
) -> PyResult<Response> {
) -> Result<Response> {
let client = Client::new(
None,
None,
Expand Down
Loading

0 comments on commit 5afc54d

Please sign in to comment.