Skip to content

Commit

Permalink
remove global var which stores valid task id
Browse files Browse the repository at this point in the history
  • Loading branch information
pewsheen committed Jun 4, 2024
1 parent bd65717 commit 9fe7c40
Showing 1 changed file with 10 additions and 30 deletions.
40 changes: 10 additions & 30 deletions src/wkwebview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod util;
use cocoa::appkit::{NSView, NSViewHeightSizable, NSViewMinYMargin, NSViewWidthSizable};
use cocoa::{
base::{id, nil, NO, YES},
foundation::{NSDictionary, NSFastEnumeration, NSInteger, NSUInteger},
foundation::{NSDictionary, NSFastEnumeration, NSInteger},
};
use dpi::{LogicalPosition, LogicalSize};
use once_cell::sync::Lazy;
Expand Down Expand Up @@ -83,7 +83,6 @@ const NS_JSON_WRITING_FRAGMENTS_ALLOWED: u64 = 4;

static COUNTER: Counter = Counter::new();
static WEBVIEW_IDS: Lazy<Mutex<HashSet<u32>>> = Lazy::new(Default::default);
static TASK_IDS: Lazy<Mutex<HashSet<NSUInteger>>> = Lazy::new(Default::default);

#[derive(Debug, Default, Copy, Clone)]
pub struct PrintMargin {
Expand Down Expand Up @@ -277,23 +276,18 @@ impl InnerWebView {
match http_request.body(sent_form_body) {
Ok(final_request) => {
let () = msg_send![task, retain];
let task_id: NSUInteger = msg_send![task, hash];

let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> = Box::new(
move |sent_response| {
// Best-effort. OS may release task at any moment.
fn check_task_is_valid(webview_id: u32, task_id: u64) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id)
|| !TASK_IDS.lock().unwrap().contains(&task_id)
{
return Err(crate::Error::CustomProtocolTaskInvalid);
fn check_webview_id_valid(webview_id: u32) -> crate::Result<()> {
match WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
true => Ok(()),
false => Err(crate::Error::CustomProtocolTaskInvalid),
}
Ok(())
}

unsafe fn response(
task: id,
task_id: NSUInteger,
webview_id: u32,
url: id, /* NSURL */
sent_response: HttpResponse<Cow<'_, [u8]>>,
Expand Down Expand Up @@ -322,11 +316,9 @@ impl InnerWebView {
}

let urlresponse: id = msg_send![class!(NSHTTPURLResponse), alloc];
// url is part of the task, we need to check task is still valid
check_task_is_valid(webview_id, task_id)?;
let response: id = msg_send![urlresponse, initWithURL:url statusCode: wanted_status_code HTTPVersion:NSString::new(&wanted_version) headerFields:headers];

check_task_is_valid(webview_id, task_id)?;
check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveResponse:), (response,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;
Expand All @@ -336,36 +328,27 @@ impl InnerWebView {
let data: id = msg_send![class!(NSData), alloc];
let data: id = msg_send![data, initWithBytesNoCopy:bytes length:content.len() freeWhenDone: if content.len() == 0 { NO } else { YES }];

check_task_is_valid(webview_id, task_id)?;
check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(id,), ()>(sel!(didReceiveData:), (data,))
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_task_is_valid(webview_id, task_id)?;
check_webview_id_valid(webview_id)?;
(*task)
.send_message::<(), ()>(sel!(didFinish), ())
.map_err(|_| crate::Error::CustomProtocolTaskInvalid)?;

Ok(())
}

if check_task_is_valid(webview_id, task_id).is_ok() {
let _ = response(task, task_id, webview_id, url, sent_response);
}
TASK_IDS.lock().unwrap().remove(&task_id);
let _ = response(task, webview_id, url, sent_response);
let () = msg_send![task, release];
},
);

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

{
let mut task_ids = TASK_IDS.lock().unwrap();
task_ids.insert(task_id);
}

function(final_request, RequestAsyncResponder { responder });
}
Err(_) => respond_with_404(),
Expand All @@ -378,10 +361,7 @@ impl InnerWebView {
}
}
}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, task: id) {
let task_id: NSUInteger = unsafe { msg_send![task, hash] };
TASK_IDS.lock().unwrap().remove(&task_id);
}
extern "C" fn stop_task(_: &Object, _: Sel, _webview: id, _task: id) {}

let mut wv_ids = WEBVIEW_IDS.lock().unwrap();
let webview_id = COUNTER.next();
Expand Down

0 comments on commit 9fe7c40

Please sign in to comment.