Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve custom URL scheme handler robustness #1440

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 71 additions & 41 deletions src/wkwebview/class/url_scheme_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ extern "C" fn start_task(
) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();

let task_key = task.hash(); // hash by task object address
let task_uuid = webview.add_custom_task_key(task_key);
Expand Down Expand Up @@ -122,7 +122,6 @@ extern "C" fn start_task(
if let Some(all_headers) = all_headers {
for current_header in all_headers.allKeys().to_vec() {
let header_value = all_headers.valueForKey(current_header).unwrap();

// inject the header into the request
http_request = http_request.header(current_header.to_string(), header_value.to_string());
}
Expand All @@ -145,37 +144,52 @@ extern "C" fn start_task(
task.didFinish();
};

fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |sent_response| {
fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}
/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
// Consolidate checks before calling into `did*` methods.
let validate = || -> crate::Result<()> {
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;
Ok(())
};

// Perform an upfront validation
if let Err(e) = validate() {
#[cfg(feature = "tracing")]
tracing::warn!("Task invalid before sending response: {:?}", e);
return; // If invalid, return early without calling task methods.
}

unsafe fn response(
Expand All @@ -189,7 +203,9 @@ extern "C" fn start_task(
url: Retained<NSURL>,
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
// Validate
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
Expand All @@ -200,7 +216,6 @@ extern "C" fn start_task(
let wanted_version = format!("{:#?}", sent_response.version());

let mut headers = NSMutableDictionary::new();

if let Some(mime) = wanted_mime {
headers.insert_id(
NSString::from_str(CONTENT_TYPE.as_str()).as_ref(),
Expand Down Expand Up @@ -232,34 +247,42 @@ extern "C" fn start_task(
)
.unwrap();

// Re-validate before calling didReceiveResponse
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

// Use map_err to convert Option<Retained<Exception>> to crate::Error
objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveResponse(&response);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data = NSData::alloc();
// MIGRATE NOTE: we copied the content to the NSData because content will be freed
// when out of scope but NSData will also free the content when it's done and cause doube free.
let data = NSData::initWithBytes_length(data, bytes, content.len());
let data = NSData::initWithBytes_length(
data,
content.as_ptr() as *mut c_void,
content.len(),
);

// Check validity again
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveData(&data);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didFinish();
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

{
let ids = WEBVIEW_IDS.lock().unwrap();
Expand All @@ -272,15 +295,21 @@ extern "C" fn start_task(
}
}

let _ = response(
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

if let Err(e) = response(
task,
webview,
task_key,
task_uuid,
webview_id,
url.clone(),
sent_response,
);
) {
#[cfg(feature = "tracing")]
tracing::error!("Error responding to task: {:?}", e);
}
});

#[cfg(feature = "tracing")]
Expand All @@ -301,6 +330,7 @@ extern "C" fn start_task(
}
}
}

extern "C" fn stop_task(
_this: &ProtocolObject<dyn WKURLSchemeHandler>,
_sel: objc2::runtime::Sel,
Expand Down
Loading