diff --git a/.changes/fix-macos-async-command-panic.md b/.changes/fix-macos-async-command-panic.md new file mode 100644 index 000000000..ea79b27c4 --- /dev/null +++ b/.changes/fix-macos-async-command-panic.md @@ -0,0 +1,5 @@ +--- +"wry": patch +--- + +On macOS, fix an issue that could cause a panic when running an async command. diff --git a/src/error.rs b/src/error.rs index afe0547e7..a0b2518ed 100644 --- a/src/error.rs +++ b/src/error.rs @@ -57,4 +57,6 @@ pub enum Error { #[cfg(target_os = "android")] #[error(transparent)] CrossBeamRecvError(#[from] crossbeam_channel::RecvError), + #[error("Custom protocol task is invalid.")] + CustomProtocolTaskInvalid, } diff --git a/src/wkwebview/mod.rs b/src/wkwebview/mod.rs index b61500de9..0da98f4c5 100644 --- a/src/wkwebview/mod.rs +++ b/src/wkwebview/mod.rs @@ -40,6 +40,7 @@ use core_graphics::{ use objc::{ declare::ClassDecl, runtime::{Class, Object, Sel, BOOL}, + Message, }; use objc_id::Id; @@ -193,7 +194,7 @@ impl InnerWebView { } // Task handler for custom protocol - extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: id) { + extern "C" fn start_task(this: &Object, _: Sel, _webview: id, task: *mut Object) { unsafe { #[cfg(feature = "tracing")] let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty) @@ -274,53 +275,75 @@ impl InnerWebView { // send response match http_request.body(sent_form_body) { Ok(final_request) => { + let () = msg_send![task, retain]; + let responder: Box>)> = Box::new( move |sent_response| { - let content = sent_response.body(); - // default: application/octet-stream, but should be provided by the client - let wanted_mime = sent_response.headers().get(CONTENT_TYPE); - // default to 200 - let wanted_status_code = sent_response.status().as_u16() as i32; - // default to HTTP/1.1 - let wanted_version = format!("{:#?}", sent_response.version()); - - let dictionary: id = msg_send![class!(NSMutableDictionary), alloc]; - let headers: id = msg_send![dictionary, initWithCapacity:1]; - if let Some(mime) = wanted_mime { - let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())]; + fn check_webview_id_valid(webview_id: u32) -> crate::Result<()> { + match WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { + true => Ok(()), + false => Err(crate::Error::CustomProtocolTaskInvalid), + } } - let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())]; - // add headers - for (name, value) in sent_response.headers().iter() { - let header_key = name.as_str(); - if let Ok(value) = value.to_str() { - let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)]; + unsafe fn response( + task: id, + webview_id: u32, + url: id, /* NSURL */ + sent_response: HttpResponse>, + ) -> crate::Result<()> { + let content = sent_response.body(); + // default: application/octet-stream, but should be provided by the client + let wanted_mime = sent_response.headers().get(CONTENT_TYPE); + // default to 200 + let wanted_status_code = sent_response.status().as_u16() as i32; + // default to HTTP/1.1 + let wanted_version = format!("{:#?}", sent_response.version()); + + let dictionary: id = msg_send![class!(NSMutableDictionary), alloc]; + let headers: id = msg_send![dictionary, initWithCapacity:1]; + if let Some(mime) = wanted_mime { + let () = msg_send![headers, setObject:NSString::new(mime.to_str().unwrap()) forKey: NSString::new(CONTENT_TYPE.as_str())]; + } + let () = msg_send![headers, setObject:NSString::new(&content.len().to_string()) forKey: NSString::new(CONTENT_LENGTH.as_str())]; + + // add headers + for (name, value) in sent_response.headers().iter() { + let header_key = name.as_str(); + if let Ok(value) = value.to_str() { + let () = msg_send![headers, setObject:NSString::new(value) forKey: NSString::new(header_key)]; + } } - } - let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc]; - let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers]; - if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { - return; - } - let () = msg_send![task, didReceiveResponse: response]; + let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc]; + let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers]; - // Send data - let bytes = content.as_ptr() as *mut c_void; - let data: id = msg_send![class!(NSData), alloc]; - let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }]; + check_webview_id_valid(webview_id)?; + (*task) + .send_message::<(id,), ()>(sel!(didReceiveResponse:), (response,)) + .map_err(|_| crate::Error::CustomProtocolTaskInvalid)?; - if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { - return; - } - let () = msg_send![task, didReceiveData: data]; + // Send data + let bytes = content.as_ptr() as *mut c_void; + let data: id = msg_send![class!(NSData), alloc]; + let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }]; + + check_webview_id_valid(webview_id)?; + (*task) + .send_message::<(id,), ()>(sel!(didReceiveData:), (data,)) + .map_err(|_| crate::Error::CustomProtocolTaskInvalid)?; - // Finish - if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { - return; + // Finish + check_webview_id_valid(webview_id)?; + (*task) + .send_message::<(), ()>(sel!(didFinish), ()) + .map_err(|_| crate::Error::CustomProtocolTaskInvalid)?; + + Ok(()) } - let () = msg_send![task, didFinish]; + + let _ = response(task, webview_id, url, sent_response); + let () = msg_send![task, release]; }, );