diff --git a/src/wkwebview/class/url_scheme_handler.rs b/src/wkwebview/class/url_scheme_handler.rs index 07bb514af..e1741091c 100644 --- a/src/wkwebview/class/url_scheme_handler.rs +++ b/src/wkwebview/class/url_scheme_handler.rs @@ -59,8 +59,8 @@ extern "C" fn start_task( ) { unsafe { #[cfg(feature = "tracing")] - let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty) - .entered(); + let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty) + .entered(); let task_key = task.hash(); // hash by task object address let task_uuid = webview.add_custom_task_key(task_key); @@ -122,7 +122,6 @@ extern "C" fn start_task( if let Some(all_headers) = all_headers { for current_header in all_headers.allKeys().to_vec() { let header_value = all_headers.valueForKey(current_header).unwrap(); - // inject the header into the request http_request = http_request.header(current_header.to_string(), header_value.to_string()); } @@ -145,43 +144,57 @@ extern "C" fn start_task( task.didFinish(); }; + fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> { + if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) { + return Err(crate::Error::CustomProtocolTaskInvalid); + } + Ok(()) + } + + /// Task may not live longer than async custom protocol handler. + /// + /// There are roughly 2 ways to cause segfault: + /// 1. Task has stopped. pointer of the task not valid anymore. + /// 2. Task had stopped, but the pointer of the task has allocated to a new task. + /// Outdated custom handler may call to the new task instance and cause segfault. + fn check_task_is_valid( + webview: &WryWebView, + task_key: usize, + current_uuid: Retained, + ) -> crate::Result<()> { + let latest_task_uuid = webview.get_custom_task_uuid(task_key); + if let Some(latest_uuid) = latest_task_uuid { + if latest_uuid != current_uuid { + return Err(crate::Error::CustomProtocolTaskInvalid); + } + } else { + return Err(crate::Error::CustomProtocolTaskInvalid); + } + Ok(()) + } + // send response match http_request.body(sent_form_body) { Ok(final_request) => { let responder: Box>)> = Box::new(move |sent_response| { - fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> { - if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) { - return Err(crate::Error::CustomProtocolTaskInvalid); - } - Ok(()) - } - /// Task may not live longer than async custom protocol handler. - /// - /// There are roughly 2 ways to cause segfault: - /// 1. Task has stopped. pointer of the task not valid anymore. - /// 2. Task had stopped, but the pointer of the task has allocated to a new task. - /// Outdated custom handler may call to the new task instance and cause segfault. - fn check_task_is_valid( - webview: &WryWebView, - task_key: usize, - current_uuid: Retained, - ) -> crate::Result<()> { - let latest_task_uuid = webview.get_custom_task_uuid(task_key); - if let Some(latest_uuid) = latest_task_uuid { - if latest_uuid != current_uuid { - return Err(crate::Error::CustomProtocolTaskInvalid); - } - } else { - return Err(crate::Error::CustomProtocolTaskInvalid); - } + // Consolidate checks before calling into `did*` methods. + let validate = || -> crate::Result<()> { + check_webview_id_valid(webview_id)?; + check_task_is_valid(webview, task_key, task_uuid.clone())?; Ok(()) + }; + + // Perform an upfront validation + if let Err(e) = validate() { + #[cfg(feature = "tracing")] + tracing::warn!("Task invalid before sending response: {:?}", e); + return; // If invalid, return early without calling task methods. } + // ... unsafe fn response( - // FIXME: though we give it a static lifetime, it's not guaranteed to be valid. task: &'static ProtocolObject, - // FIXME: though we give it a static lifetime, it's not guaranteed to be valid. webview: &'static mut WryWebView, task_key: usize, task_uuid: Retained, @@ -189,18 +202,16 @@ extern "C" fn start_task( url: Retained, sent_response: HttpResponse>, ) -> crate::Result<()> { - check_task_is_valid(&*webview, task_key, task_uuid.clone())?; + // Validate + check_webview_id_valid(webview_id)?; + check_task_is_valid(webview, task_key, task_uuid.clone())?; let content = sent_response.body(); - // default: application/octet-stream, but should be provided by the client let wanted_mime = sent_response.headers().get(CONTENT_TYPE); - // default to 200 let wanted_status_code = sent_response.status().as_u16() as i32; - // default to HTTP/1.1 let wanted_version = format!("{:#?}", sent_response.version()); let mut headers = NSMutableDictionary::new(); - if let Some(mime) = wanted_mime { headers.insert_id( NSString::from_str(CONTENT_TYPE.as_str()).as_ref(), @@ -212,7 +223,6 @@ extern "C" fn start_task( NSString::from_str(&content.len().to_string()), ); - // add headers for (name, value) in sent_response.headers().iter() { if let Ok(value) = value.to_str() { headers.insert_id( @@ -232,40 +242,55 @@ extern "C" fn start_task( ) .unwrap(); + // Re-validate before calling didReceiveResponse check_webview_id_valid(webview_id)?; - check_task_is_valid(&*webview, task_key, task_uuid.clone())?; + check_task_is_valid(webview, task_key, task_uuid.clone())?; + // Use map_err to convert Option> to crate::Error objc2::exception::catch(AssertUnwindSafe(|| { task.didReceiveResponse(&response); })) - .unwrap(); + .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?; - // Send data - let bytes = content.as_ptr() as *mut c_void; let data = NSData::alloc(); - // MIGRATE NOTE: we copied the content to the NSData because content will be freed - // when out of scope but NSData will also free the content when it's done and cause doube free. - let data = NSData::initWithBytes_length(data, bytes, content.len()); + let data = NSData::initWithBytes_length( + data, + content.as_ptr() as *mut c_void, + content.len(), + ); + + // Check validity again check_webview_id_valid(webview_id)?; - check_task_is_valid(&*webview, task_key, task_uuid.clone())?; + check_task_is_valid(webview, task_key, task_uuid.clone())?; + objc2::exception::catch(AssertUnwindSafe(|| { task.didReceiveData(&data); })) - .unwrap(); + .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?; - // Finish check_webview_id_valid(webview_id)?; - check_task_is_valid(&*webview, task_key, task_uuid.clone())?; + check_task_is_valid(webview, task_key, task_uuid.clone())?; + objc2::exception::catch(AssertUnwindSafe(|| { task.didFinish(); })) - .unwrap(); - - webview.remove_custom_task_key(task_key); - Ok(()) + .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?; + + { + let ids = WEBVIEW_IDS.lock().unwrap(); + if ids.contains(webview_id) { + webview.remove_custom_task_key(task_key); + Ok(()) + } else { + Err(crate::Error::CustomProtocolTaskInvalid) + } + } } - let _ = response( + #[cfg(feature = "tracing")] + let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered(); + + if let Err(e) = response( task, webview, task_key, @@ -273,7 +298,10 @@ extern "C" fn start_task( webview_id, url.clone(), sent_response, - ); + ) { + #[cfg(feature = "tracing")] + tracing::error!("Error responding to task: {:?}", e); + } }); #[cfg(feature = "tracing")] @@ -294,6 +322,7 @@ extern "C" fn start_task( } } } + extern "C" fn stop_task( _this: &ProtocolObject, _sel: objc2::runtime::Sel,