From ca15c3ef3d4e0c03ba6908a2ec470d6f8067d788 Mon Sep 17 00:00:00 2001
From: Udara Jay <UdaraJay@users.noreply.github.com>
Date: Mon, 16 Dec 2024 12:47:23 -0500
Subject: [PATCH 1/2] Improve custom URL scheme handler robustness and error
 handling

---
 src/wkwebview/class/url_scheme_handler.rs | 135 +++++++++++++---------
 1 file changed, 82 insertions(+), 53 deletions(-)

diff --git a/src/wkwebview/class/url_scheme_handler.rs b/src/wkwebview/class/url_scheme_handler.rs
index 07bb514af..e1741091c 100644
--- a/src/wkwebview/class/url_scheme_handler.rs
+++ b/src/wkwebview/class/url_scheme_handler.rs
@@ -59,8 +59,8 @@ extern "C" fn start_task(
 ) {
   unsafe {
     #[cfg(feature = "tracing")]
-          let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
-            .entered();
+    let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
+      .entered();
 
     let task_key = task.hash(); // hash by task object address
     let task_uuid = webview.add_custom_task_key(task_key);
@@ -122,7 +122,6 @@ extern "C" fn start_task(
       if let Some(all_headers) = all_headers {
         for current_header in all_headers.allKeys().to_vec() {
           let header_value = all_headers.valueForKey(current_header).unwrap();
-
           // inject the header into the request
           http_request = http_request.header(current_header.to_string(), header_value.to_string());
         }
@@ -145,43 +144,57 @@ extern "C" fn start_task(
         task.didFinish();
       };
 
+      fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
+        if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
+          return Err(crate::Error::CustomProtocolTaskInvalid);
+        }
+        Ok(())
+      }
+
+      /// Task may not live longer than async custom protocol handler.
+      ///
+      /// There are roughly 2 ways to cause segfault:
+      /// 1. Task has stopped. pointer of the task not valid anymore.
+      /// 2. Task had stopped, but the pointer of the task has allocated to a new task.
+      ///    Outdated custom handler may call to the new task instance and cause segfault.
+      fn check_task_is_valid(
+        webview: &WryWebView,
+        task_key: usize,
+        current_uuid: Retained<NSUUID>,
+      ) -> crate::Result<()> {
+        let latest_task_uuid = webview.get_custom_task_uuid(task_key);
+        if let Some(latest_uuid) = latest_task_uuid {
+          if latest_uuid != current_uuid {
+            return Err(crate::Error::CustomProtocolTaskInvalid);
+          }
+        } else {
+          return Err(crate::Error::CustomProtocolTaskInvalid);
+        }
+        Ok(())
+      }
+
       // send response
       match http_request.body(sent_form_body) {
         Ok(final_request) => {
           let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
             Box::new(move |sent_response| {
-              fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
-                if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
-                  return Err(crate::Error::CustomProtocolTaskInvalid);
-                }
-                Ok(())
-              }
-              /// Task may not live longer than async custom protocol handler.
-              ///
-              /// There are roughly 2 ways to cause segfault:
-              /// 1. Task has stopped. pointer of the task not valid anymore.
-              /// 2. Task had stopped, but the pointer of the task has allocated to a new task.
-              ///    Outdated custom handler may call to the new task instance and cause segfault.
-              fn check_task_is_valid(
-                webview: &WryWebView,
-                task_key: usize,
-                current_uuid: Retained<NSUUID>,
-              ) -> crate::Result<()> {
-                let latest_task_uuid = webview.get_custom_task_uuid(task_key);
-                if let Some(latest_uuid) = latest_task_uuid {
-                  if latest_uuid != current_uuid {
-                    return Err(crate::Error::CustomProtocolTaskInvalid);
-                  }
-                } else {
-                  return Err(crate::Error::CustomProtocolTaskInvalid);
-                }
+              // Consolidate checks before calling into `did*` methods.
+              let validate = || -> crate::Result<()> {
+                check_webview_id_valid(webview_id)?;
+                check_task_is_valid(webview, task_key, task_uuid.clone())?;
                 Ok(())
+              };
+
+              // Perform an upfront validation
+              if let Err(e) = validate() {
+                #[cfg(feature = "tracing")]
+                tracing::warn!("Task invalid before sending response: {:?}", e);
+                return; // If invalid, return early without calling task methods.
               }
 
+              // ...
               unsafe fn response(
-                // FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
                 task: &'static ProtocolObject<dyn WKURLSchemeTask>,
-                // FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
                 webview: &'static mut WryWebView,
                 task_key: usize,
                 task_uuid: Retained<NSUUID>,
@@ -189,18 +202,16 @@ extern "C" fn start_task(
                 url: Retained<NSURL>,
                 sent_response: HttpResponse<Cow<'_, [u8]>>,
               ) -> crate::Result<()> {
-                check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
+                // Validate
+                check_webview_id_valid(webview_id)?;
+                check_task_is_valid(webview, task_key, task_uuid.clone())?;
 
                 let content = sent_response.body();
-                // default: application/octet-stream, but should be provided by the client
                 let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
-                // default to 200
                 let wanted_status_code = sent_response.status().as_u16() as i32;
-                // default to HTTP/1.1
                 let wanted_version = format!("{:#?}", sent_response.version());
 
                 let mut headers = NSMutableDictionary::new();
-
                 if let Some(mime) = wanted_mime {
                   headers.insert_id(
                     NSString::from_str(CONTENT_TYPE.as_str()).as_ref(),
@@ -212,7 +223,6 @@ extern "C" fn start_task(
                   NSString::from_str(&content.len().to_string()),
                 );
 
-                // add headers
                 for (name, value) in sent_response.headers().iter() {
                   if let Ok(value) = value.to_str() {
                     headers.insert_id(
@@ -232,40 +242,55 @@ extern "C" fn start_task(
                 )
                 .unwrap();
 
+                // Re-validate before calling didReceiveResponse
                 check_webview_id_valid(webview_id)?;
-                check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
+                check_task_is_valid(webview, task_key, task_uuid.clone())?;
 
+                // Use map_err to convert Option<Retained<Exception>> to crate::Error
                 objc2::exception::catch(AssertUnwindSafe(|| {
                   task.didReceiveResponse(&response);
                 }))
-                .unwrap();
+                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;
 
-                // Send data
-                let bytes = content.as_ptr() as *mut c_void;
                 let data = NSData::alloc();
-                // MIGRATE NOTE: we copied the content to the NSData because content will be freed
-                // when out of scope but NSData will also free the content when it's done and cause doube free.
-                let data = NSData::initWithBytes_length(data, bytes, content.len());
+                let data = NSData::initWithBytes_length(
+                  data,
+                  content.as_ptr() as *mut c_void,
+                  content.len(),
+                );
+
+                // Check validity again
                 check_webview_id_valid(webview_id)?;
-                check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
+                check_task_is_valid(webview, task_key, task_uuid.clone())?;
+
                 objc2::exception::catch(AssertUnwindSafe(|| {
                   task.didReceiveData(&data);
                 }))
-                .unwrap();
+                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;
 
-                // Finish
                 check_webview_id_valid(webview_id)?;
-                check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
+                check_task_is_valid(webview, task_key, task_uuid.clone())?;
+
                 objc2::exception::catch(AssertUnwindSafe(|| {
                   task.didFinish();
                 }))
-                .unwrap();
-
-                webview.remove_custom_task_key(task_key);
-                Ok(())
+                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;
+
+                {
+                  let ids = WEBVIEW_IDS.lock().unwrap();
+                  if ids.contains(webview_id) {
+                    webview.remove_custom_task_key(task_key);
+                    Ok(())
+                  } else {
+                    Err(crate::Error::CustomProtocolTaskInvalid)
+                  }
+                }
               }
 
-              let _ = response(
+              #[cfg(feature = "tracing")]
+              let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
+
+              if let Err(e) = response(
                 task,
                 webview,
                 task_key,
@@ -273,7 +298,10 @@ extern "C" fn start_task(
                 webview_id,
                 url.clone(),
                 sent_response,
-              );
+              ) {
+                #[cfg(feature = "tracing")]
+                tracing::error!("Error responding to task: {:?}", e);
+              }
             });
 
           #[cfg(feature = "tracing")]
@@ -294,6 +322,7 @@ extern "C" fn start_task(
     }
   }
 }
+
 extern "C" fn stop_task(
   _this: &ProtocolObject<dyn WKURLSchemeHandler>,
   _sel: objc2::runtime::Sel,

From db6a6eed2efa81f0cf3c26e15537b18553f9db72 Mon Sep 17 00:00:00 2001
From: FabianLars <github@fabianlars.de>
Date: Thu, 23 Jan 2025 17:13:14 +0100
Subject: [PATCH 2/2] re-add comments

---
 src/wkwebview/class/url_scheme_handler.rs | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/wkwebview/class/url_scheme_handler.rs b/src/wkwebview/class/url_scheme_handler.rs
index e1741091c..7edb28f52 100644
--- a/src/wkwebview/class/url_scheme_handler.rs
+++ b/src/wkwebview/class/url_scheme_handler.rs
@@ -192,9 +192,10 @@ extern "C" fn start_task(
                 return; // If invalid, return early without calling task methods.
               }
 
-              // ...
               unsafe fn response(
+                // FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
                 task: &'static ProtocolObject<dyn WKURLSchemeTask>,
+                // FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
                 webview: &'static mut WryWebView,
                 task_key: usize,
                 task_uuid: Retained<NSUUID>,
@@ -207,8 +208,11 @@ extern "C" fn start_task(
                 check_task_is_valid(webview, task_key, task_uuid.clone())?;
 
                 let content = sent_response.body();
+                // default: application/octet-stream, but should be provided by the client
                 let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
+                // default to 200
                 let wanted_status_code = sent_response.status().as_u16() as i32;
+                // default to HTTP/1.1
                 let wanted_version = format!("{:#?}", sent_response.version());
 
                 let mut headers = NSMutableDictionary::new();
@@ -223,6 +227,7 @@ extern "C" fn start_task(
                   NSString::from_str(&content.len().to_string()),
                 );
 
+                // add headers
                 for (name, value) in sent_response.headers().iter() {
                   if let Ok(value) = value.to_str() {
                     headers.insert_id(
@@ -252,7 +257,10 @@ extern "C" fn start_task(
                 }))
                 .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;
 
+                // Send data
                 let data = NSData::alloc();
+                // MIGRATE NOTE: we copied the content to the NSData because content will be freed
+                // when out of scope but NSData will also free the content when it's done and cause doube free.
                 let data = NSData::initWithBytes_length(
                   data,
                   content.as_ptr() as *mut c_void,