diff --git a/bindings/matrix-sdk-ffi/src/widget.rs b/bindings/matrix-sdk-ffi/src/widget.rs index 9cdd5856033..5c41d9aa5e3 100644 --- a/bindings/matrix-sdk-ffi/src/widget.rs +++ b/bindings/matrix-sdk-ffi/src/widget.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use language_tags::LanguageTag; use matrix_sdk::{ async_trait, - widget::{MessageLikeEventFilter, StateEventFilter}, + widget::{MessageLikeEventFilter, StateEventFilter, ToDeviceEventFilter}, }; use tracing::error; @@ -427,6 +427,8 @@ pub enum WidgetEventFilter { StateWithType { event_type: String }, /// Matches state events with the given `type` and `state_key`. StateWithTypeAndStateKey { event_type: String, state_key: String }, + /// Matches ToDevice events with the given `type`. + ToDeviceWithType { event_type: String }, } impl From for matrix_sdk::widget::EventFilter { @@ -444,6 +446,9 @@ impl From for matrix_sdk::widget::EventFilter { WidgetEventFilter::StateWithTypeAndStateKey { event_type, state_key } => { Self::State(StateEventFilter::WithTypeAndStateKey(event_type.into(), state_key)) } + WidgetEventFilter::ToDeviceWithType { event_type } => { + Self::ToDevice(ToDeviceEventFilter(event_type.into())) + } } } } @@ -465,6 +470,9 @@ impl From for WidgetEventFilter { F::State(StateEventFilter::WithTypeAndStateKey(event_type, state_key)) => { Self::StateWithTypeAndStateKey { event_type: event_type.to_string(), state_key } } + F::ToDevice(to_device_event_filter) => { + Self::ToDeviceWithType { event_type: to_device_event_filter.0.to_string() } + } } } } diff --git a/crates/matrix-sdk/src/widget/capabilities.rs b/crates/matrix-sdk/src/widget/capabilities.rs index b2492dc9eac..4830752935d 100644 --- a/crates/matrix-sdk/src/widget/capabilities.rs +++ b/crates/matrix-sdk/src/widget/capabilities.rs @@ -18,12 +18,15 @@ use std::fmt; use async_trait::async_trait; -use ruma::{events::AnyTimelineEvent, serde::Raw}; use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer}; use tracing::{debug, error}; +use crate::widget::filter::ToDeviceEventFilter; + use super::{ - filter::MatrixEventFilterInput, EventFilter, MessageLikeEventFilter, StateEventFilter, + filter::{EventAndRequestFilter, MatrixEventFilterInput, MatrixEventFilterInputData}, + machine::MatrixEvent, + EventFilter, MessageLikeEventFilter, StateEventFilter, }; /// Must be implemented by a component that provides functionality of deciding @@ -59,23 +62,38 @@ pub struct Capabilities { impl Capabilities { /// Tells if a given raw event matches the read filter. - pub fn raw_event_matches_read_filter(&self, raw: &Raw) -> bool { - let filter_in = match raw.deserialize_as::() { - Ok(filter) => filter, - Err(err) => { - error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}"); - return false; + pub fn raw_event_matches_read_filter(&self, raw: &MatrixEvent) -> bool { + let filter_in = match raw { + MatrixEvent::Timeline(raw) => { + match raw.deserialize_as::() { + Ok(filter) => MatrixEventFilterInput::Timeline(filter), + Err(err) => { + error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}"); + return false; + } + } + } + MatrixEvent::ToDevice(raw) => { + match raw.deserialize_as::() { + Ok(filter) => MatrixEventFilterInput::ToDevice(filter), + Err(err) => { + error!("Failed to deserialize raw event as MatrixEventFilterInput: {err}"); + return false; + } + } } }; - self.read.iter().any(|f| f.matches(&filter_in)) + self.read.iter().any(|f| f.matches_event(&filter_in)) } } const SEND_EVENT: &str = "org.matrix.msc2762.send.event"; -const READ_EVENT: &str = "org.matrix.msc2762.receive.event"; +const RECEIVE_EVENT: &str = "org.matrix.msc2762.receive.event"; const SEND_STATE: &str = "org.matrix.msc2762.send.state_event"; -const READ_STATE: &str = "org.matrix.msc2762.receive.state_event"; +const RECEIVE_STATE: &str = "org.matrix.msc2762.receive.state_event"; +const SEND_TODEVICE: &str = "org.matrix.msc3819.send.to_device"; +const RECEIVE_TODEVICE: &str = "org.matrix.msc3819.receive.to_device"; const REQUIRES_CLIENT: &str = "io.element.requires_client"; pub(super) const SEND_DELAYED_EVENT: &str = "org.matrix.msc4157.send.delayed_event"; pub(super) const UPDATE_DELAYED_EVENT: &str = "org.matrix.msc4157.update_delayed_event"; @@ -91,6 +109,7 @@ impl Serialize for Capabilities { match self.0 { EventFilter::MessageLike(filter) => PrintMessageLikeEventFilter(filter).fmt(f), EventFilter::State(filter) => PrintStateEventFilter(filter).fmt(f), + EventFilter::ToDevice(filter) => filter.fmt(f), } } } @@ -136,8 +155,9 @@ impl Serialize for Capabilities { } for filter in &self.read { let name = match filter { - EventFilter::MessageLike(_) => READ_EVENT, - EventFilter::State(_) => READ_STATE, + EventFilter::MessageLike(_) => RECEIVE_EVENT, + EventFilter::State(_) => RECEIVE_STATE, + EventFilter::ToDevice(_) => RECEIVE_TODEVICE, }; seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?; } @@ -145,6 +165,7 @@ impl Serialize for Capabilities { let name = match filter { EventFilter::MessageLike(_) => SEND_EVENT, EventFilter::State(_) => SEND_STATE, + EventFilter::ToDevice(_) => SEND_TODEVICE, }; seq.serialize_element(&format!("{name}:{}", PrintEventFilter(filter)))?; } @@ -184,18 +205,24 @@ impl<'de> Deserialize<'de> for Capabilities { } match s.split_once(':') { - Some((READ_EVENT, filter_s)) => Ok(Permission::Read(EventFilter::MessageLike( - parse_message_event_filter(filter_s), - ))), + Some((RECEIVE_EVENT, filter_s)) => Ok(Permission::Read( + EventFilter::MessageLike(parse_message_event_filter(filter_s)), + )), Some((SEND_EVENT, filter_s)) => Ok(Permission::Send(EventFilter::MessageLike( parse_message_event_filter(filter_s), ))), - Some((READ_STATE, filter_s)) => { + Some((RECEIVE_STATE, filter_s)) => { Ok(Permission::Read(EventFilter::State(parse_state_event_filter(filter_s)))) } Some((SEND_STATE, filter_s)) => { Ok(Permission::Send(EventFilter::State(parse_state_event_filter(filter_s)))) } + Some((RECEIVE_TODEVICE, filter_s)) => Ok(Permission::Read( + EventFilter::ToDevice(parse_to_device_event_filter(filter_s)), + )), + Some((SEND_TODEVICE, filter_s)) => Ok(Permission::Send(EventFilter::ToDevice( + parse_to_device_event_filter(filter_s), + ))), _ => { debug!("Unknown capability `{s}`"); Ok(Self::Unknown) @@ -222,6 +249,10 @@ impl<'de> Deserialize<'de> for Capabilities { } } + fn parse_to_device_event_filter(s: &str) -> ToDeviceEventFilter { + ToDeviceEventFilter(s.into()) + } + let mut capabilities = Capabilities::default(); for capability in Vec::::deserialize(deserializer)? { match capability { @@ -263,8 +294,10 @@ mod tests { "org.matrix.msc2762.receive.event:org.matrix.rageshake_request", "org.matrix.msc2762.receive.state_event:m.room.member", "org.matrix.msc2762.receive.state_event:org.matrix.msc3401.call.member", + "org.matrix.msc3819.receive.to_device:io.element.call.encryption_keys", "org.matrix.msc2762.send.event:org.matrix.rageshake_request", "org.matrix.msc2762.send.state_event:org.matrix.msc3401.call.member#@user:matrix.server", + "org.matrix.msc3819.send.to_device:io.element.call.encryption_keys", "org.matrix.msc4157.send.delayed_event", "org.matrix.msc4157.update_delayed_event" ]"#; @@ -279,6 +312,9 @@ mod tests { EventFilter::State(StateEventFilter::WithType( "org.matrix.msc3401.call.member".into(), )), + EventFilter::ToDevice(ToDeviceEventFilter( + "io.element.call.encryption_keys".into(), + )), ], send: vec![ EventFilter::MessageLike(MessageLikeEventFilter::WithType( @@ -288,6 +324,9 @@ mod tests { "org.matrix.msc3401.call.member".into(), "@user:matrix.server".into(), )), + EventFilter::ToDevice(ToDeviceEventFilter( + "io.element.call.encryption_keys".into(), + )), ], requires_client: true, update_delayed_event: true, @@ -309,6 +348,9 @@ mod tests { "org.matrix.msc3401.call.member".into(), "@user:matrix.server".into(), )), + EventFilter::ToDevice(ToDeviceEventFilter( + "io.element.call.encryption_keys".into(), + )), ], send: vec![ EventFilter::MessageLike(MessageLikeEventFilter::WithType( @@ -318,6 +360,7 @@ mod tests { "org.matrix.msc3401.call.member".into(), "@user:matrix.server".into(), )), + EventFilter::ToDevice(ToDeviceEventFilter("my.org.other.to_device_event".into())), ], requires_client: true, update_delayed_event: false, diff --git a/crates/matrix-sdk/src/widget/filter.rs b/crates/matrix-sdk/src/widget/filter.rs index c3daf2ce062..6c1cb4f2c16 100644 --- a/crates/matrix-sdk/src/widget/filter.rs +++ b/crates/matrix-sdk/src/widget/filter.rs @@ -14,10 +14,16 @@ #![allow(dead_code)] // temporary -use ruma::events::{MessageLikeEventType, StateEventType, TimelineEventType}; -use serde::Deserialize; +use std::fmt; + +use ruma::events::{MessageLikeEventType, StateEventType, TimelineEventType, ToDeviceEventType}; +use serde::{Deserialize, Serialize}; /// Different kinds of filters for timeline events. + +// Refactor this to only have two methods: matches_request(RequestFilterInput) and matches_event(MatrixEventFilterInput). +// or only one matches(FilterInput), enum FilterInput{Event(MatrixEventFilterInput), Request(RequestFilterInput)} and from impls +// for FilterInput... #[derive(Clone, Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum EventFilter { @@ -25,35 +31,33 @@ pub enum EventFilter { MessageLike(MessageLikeEventFilter), /// Filter for state events. State(StateEventFilter), + /// Filter for to device events. + ToDevice(ToDeviceEventFilter), } -impl EventFilter { - pub(super) fn matches(&self, matrix_event: &MatrixEventFilterInput) -> bool { - match self { - EventFilter::MessageLike(message_filter) => message_filter.matches(matrix_event), - EventFilter::State(state_filter) => state_filter.matches(matrix_event), - } +impl EventAndRequestFilter for EventFilter { + fn matches_event(&self, matrix_event: &MatrixEventFilterInput) -> bool { + let filter: &dyn EventAndRequestFilter = match self { + EventFilter::MessageLike(filter) => filter, + EventFilter::State(filter) => filter, + EventFilter::ToDevice(filter) => filter, + }; + filter.matches_event(matrix_event) } - pub(super) fn matches_state_event_with_any_state_key( - &self, - event_type: &StateEventType, - ) -> bool { - matches!( - self, - Self::State(filter) if filter.matches_state_event_with_any_state_key(event_type) - ) + fn matches_request(&self, event_type: &FilterEventType) -> bool { + let filter: &dyn EventAndRequestFilter = match self { + EventFilter::MessageLike(filter) => filter, + EventFilter::State(filter) => filter, + EventFilter::ToDevice(filter) => filter, + }; + filter.matches_request(event_type) } +} - pub(super) fn matches_message_like_event_type( - &self, - event_type: &MessageLikeEventType, - ) -> bool { - matches!( - self, - Self::MessageLike(filter) if filter.matches_message_like_event_type(event_type) - ) - } +pub trait EventAndRequestFilter { + fn matches_event(&self, matrix_event: &MatrixEventFilterInput) -> bool; + fn matches_request(&self, event_type: &FilterEventType) -> bool; } /// Filter for message-like events. @@ -66,29 +70,41 @@ pub enum MessageLikeEventFilter { RoomMessageWithMsgtype(String), } -impl MessageLikeEventFilter { - fn matches(&self, matrix_event: &MatrixEventFilterInput) -> bool { +impl EventAndRequestFilter for MessageLikeEventFilter { + fn matches_event(&self, matrix_event: &MatrixEventFilterInput) -> bool { + let MatrixEventFilterInput::Timeline(matrix_event) = matrix_event else { + return false; + }; if matrix_event.state_key.is_some() { // State event doesn't match a message-like event filter. return false; } - - match self { - MessageLikeEventFilter::WithType(event_type) => { - matrix_event.event_type == TimelineEventType::from(event_type.clone()) - } - MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => { - matrix_event.event_type == TimelineEventType::RoomMessage - && matrix_event.content.msgtype.as_ref() == Some(msgtype) + if let FilterEventType::Timeline(event_type) = &matrix_event.event_type { + match self { + MessageLikeEventFilter::WithType(filter_event_type) => { + *event_type == TimelineEventType::from(filter_event_type.clone()) + } + MessageLikeEventFilter::RoomMessageWithMsgtype(msgtype) => { + *event_type == TimelineEventType::RoomMessage + && matrix_event.content.msgtype.as_ref() == Some(msgtype) + } } + } else { + false } } - fn matches_message_like_event_type(&self, event_type: &MessageLikeEventType) -> bool { + fn matches_request(&self, event_type: &FilterEventType) -> bool { + let FilterEventType::Timeline(event_type) = event_type else { + return false; + }; + match self { - MessageLikeEventFilter::WithType(filter_event_type) => filter_event_type == event_type, + MessageLikeEventFilter::WithType(filter_event_type) => { + TimelineEventType::from(filter_event_type.clone()) == *event_type + } MessageLikeEventFilter::RoomMessageWithMsgtype(_) => { - event_type == &MessageLikeEventType::RoomMessage + &TimelineEventType::RoomMessage == event_type } } } @@ -104,8 +120,11 @@ pub enum StateEventFilter { WithTypeAndStateKey(StateEventType, String), } -impl StateEventFilter { - fn matches(&self, matrix_event: &MatrixEventFilterInput) -> bool { +impl EventAndRequestFilter for StateEventFilter { + fn matches_event(&self, matrix_event: &MatrixEventFilterInput) -> bool { + let MatrixEventFilterInput::Timeline(matrix_event) = matrix_event else { + return false; + }; let Some(state_key) = &matrix_event.state_key else { // Message-like event doesn't match a state event filter. return false; @@ -113,28 +132,96 @@ impl StateEventFilter { match self { StateEventFilter::WithType(event_type) => { - matrix_event.event_type == TimelineEventType::from(event_type.clone()) + matrix_event.event_type + == FilterEventType::Timeline(TimelineEventType::from(event_type.clone())) } StateEventFilter::WithTypeAndStateKey(event_type, filter_state_key) => { - matrix_event.event_type == TimelineEventType::from(event_type.clone()) + matrix_event.event_type + == FilterEventType::Timeline(TimelineEventType::from(event_type.clone())) && state_key == filter_state_key } } } - fn matches_state_event_with_any_state_key(&self, event_type: &StateEventType) -> bool { - matches!(self, Self::WithType(ty) if ty == event_type) + fn matches_request(&self, event_type: &FilterEventType) -> bool { + matches!(self, Self::WithType(ty) if FilterEventType::Timeline(TimelineEventType::from(ty.clone())) == *event_type) + } +} + +/// Filter for to-device events. +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq))] + +pub struct ToDeviceEventFilter(pub ToDeviceEventType); + +impl EventAndRequestFilter for ToDeviceEventFilter { + fn matches_event(&self, matrix_event: &MatrixEventFilterInput) -> bool { + let MatrixEventFilterInput::ToDevice(matrix_event) = matrix_event else { + return false; + }; + match self { + ToDeviceEventFilter(event_type) => { + matrix_event.event_type == FilterEventType::ToDevice(event_type.clone()) + } + } + } + + fn matches_request(&self, _: &FilterEventType) -> bool { + // There is no way to request events. We will only need to run checks on sending and receiving already existing + // events. + false + } +} + +impl fmt::Display for ToDeviceEventFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{0}", self.0) + } +} + +#[derive(Debug, Deserialize, PartialEq, Serialize)] +#[serde(untagged)] +pub enum FilterEventType { + Timeline(TimelineEventType), + ToDevice(ToDeviceEventType), +} + +impl From for FilterEventType { + fn from(value: TimelineEventType) -> Self { + Self::Timeline(value) + } +} + +impl From for FilterEventType { + fn from(value: ToDeviceEventType) -> Self { + Self::ToDevice(value) + } +} +impl From for FilterEventType { + fn from(value: StateEventType) -> Self { + TimelineEventType::from(value).into() + } +} +impl From for FilterEventType { + fn from(value: MessageLikeEventType) -> Self { + TimelineEventType::from(value).into() } } #[derive(Debug, Deserialize)] -pub(super) struct MatrixEventFilterInput { +pub(super) struct MatrixEventFilterInputData { #[serde(rename = "type")] - pub(super) event_type: TimelineEventType, + pub(super) event_type: FilterEventType, pub(super) state_key: Option, pub(super) content: MatrixEventContent, } +#[derive(Debug)] +pub(crate) enum MatrixEventFilterInput { + Timeline(MatrixEventFilterInputData), + ToDevice(MatrixEventFilterInputData), +} + #[derive(Debug, Default, Deserialize)] pub(super) struct MatrixEventContent { pub(super) msgtype: Option, @@ -144,32 +231,38 @@ pub(super) struct MatrixEventContent { mod tests { use ruma::events::{MessageLikeEventType, StateEventType, TimelineEventType}; + use crate::widget::filter::EventAndRequestFilter; + use super::{ - EventFilter, MatrixEventContent, MatrixEventFilterInput, MessageLikeEventFilter, - StateEventFilter, + EventFilter, MatrixEventContent, MatrixEventFilterInput, MatrixEventFilterInputData, + MessageLikeEventFilter, StateEventFilter, }; fn message_event(event_type: TimelineEventType) -> MatrixEventFilterInput { - MatrixEventFilterInput { event_type, state_key: None, content: Default::default() } + MatrixEventFilterInput::Timeline(MatrixEventFilterInputData { + event_type: event_type.into(), + state_key: None, + content: Default::default(), + }) } fn message_event_with_msgtype( event_type: TimelineEventType, msgtype: String, ) -> MatrixEventFilterInput { - MatrixEventFilterInput { - event_type, + MatrixEventFilterInput::Timeline(MatrixEventFilterInputData { + event_type: event_type.into(), state_key: None, content: MatrixEventContent { msgtype: Some(msgtype) }, - } + }) } fn state_event(event_type: TimelineEventType, state_key: String) -> MatrixEventFilterInput { - MatrixEventFilterInput { - event_type, + MatrixEventFilterInput::Timeline(MatrixEventFilterInputData { + event_type: event_type.into(), state_key: Some(state_key), content: Default::default(), - } + }) } // Tests against an `m.room.message` filter with `msgtype = m.text` @@ -181,7 +274,7 @@ mod tests { #[test] fn text_event_filter_matches_text_event() { - assert!(room_message_text_event_filter().matches(&message_event_with_msgtype( + assert!(room_message_text_event_filter().matches_event(&message_event_with_msgtype( TimelineEventType::RoomMessage, "m.text".to_owned() ))); @@ -189,7 +282,7 @@ mod tests { #[test] fn text_event_filter_does_not_match_image_event() { - assert!(!room_message_text_event_filter().matches(&message_event_with_msgtype( + assert!(!room_message_text_event_filter().matches_event(&message_event_with_msgtype( TimelineEventType::RoomMessage, "m.image".to_owned() ))); @@ -197,7 +290,7 @@ mod tests { #[test] fn text_event_filter_does_not_match_custom_event_with_msgtype() { - assert!(!room_message_text_event_filter().matches(&message_event_with_msgtype( + assert!(!room_message_text_event_filter().matches_event(&message_event_with_msgtype( "io.element.message".into(), "m.text".to_owned() ))); @@ -210,12 +303,12 @@ mod tests { #[test] fn reaction_event_filter_matches_reaction() { - assert!(reaction_event_filter().matches(&message_event(TimelineEventType::Reaction))); + assert!(reaction_event_filter().matches_event(&message_event(TimelineEventType::Reaction))); } #[test] fn reaction_event_filter_does_not_match_room_message() { - assert!(!reaction_event_filter().matches(&message_event_with_msgtype( + assert!(!reaction_event_filter().matches_event(&message_event_with_msgtype( TimelineEventType::RoomMessage, "m.text".to_owned() ))); @@ -223,7 +316,7 @@ mod tests { #[test] fn reaction_event_filter_does_not_match_state_event() { - assert!(!reaction_event_filter().matches(&state_event( + assert!(!reaction_event_filter().matches_event(&state_event( // Use the `m.reaction` event type to make sure the event would pass // the filter without state event checks, even though in practice // that event type won't be used for a state event. @@ -235,7 +328,7 @@ mod tests { #[test] fn reaction_event_filter_does_not_match_state_event_any_key() { assert!( - !reaction_event_filter().matches_state_event_with_any_state_key(&"m.reaction".into()) + !reaction_event_filter().matches_request(&StateEventType::from("m.reaction").into()) ); } @@ -249,13 +342,15 @@ mod tests { #[test] fn self_member_event_filter_matches_self_member_event() { - assert!(self_member_event_filter() - .matches(&state_event(TimelineEventType::RoomMember, "@self:example.me".to_owned()))); + assert!(self_member_event_filter().matches_event(&state_event( + TimelineEventType::RoomMember, + "@self:example.me".to_owned() + ))); } #[test] fn self_member_event_filter_does_not_match_somebody_elses_member_event() { - assert!(!self_member_event_filter().matches(&state_event( + assert!(!self_member_event_filter().matches_event(&state_event( TimelineEventType::RoomMember, "@somebody_else.example.me".to_owned() ))); @@ -263,7 +358,7 @@ mod tests { #[test] fn self_member_event_filter_does_not_match_unrelated_state_event_with_same_state_key() { - assert!(!self_member_event_filter().matches(&state_event( + assert!(!self_member_event_filter().matches_event(&state_event( TimelineEventType::from("io.element.test_state_event"), "@self.example.me".to_owned() ))); @@ -271,13 +366,14 @@ mod tests { #[test] fn self_member_event_filter_does_not_match_reaction_event() { - assert!(!self_member_event_filter().matches(&message_event(TimelineEventType::Reaction))); + assert!( + !self_member_event_filter().matches_event(&message_event(TimelineEventType::Reaction)) + ); } #[test] fn self_member_event_filter_only_matches_specific_state_key() { - assert!(!self_member_event_filter() - .matches_state_event_with_any_state_key(&StateEventType::RoomMember)); + assert!(!self_member_event_filter().matches_request(&StateEventType::RoomMember.into())); } // Tests against an `m.room.member` filter with any `state_key`. @@ -288,24 +384,23 @@ mod tests { #[test] fn member_event_filter_matches_some_member_event() { assert!(member_event_filter() - .matches(&state_event(TimelineEventType::RoomMember, "@foo.bar.baz".to_owned()))); + .matches_event(&state_event(TimelineEventType::RoomMember, "@foo.bar.baz".to_owned()))); } #[test] fn member_event_filter_does_not_match_room_name_event() { assert!(!member_event_filter() - .matches(&state_event(TimelineEventType::RoomName, "".to_owned()))); + .matches_event(&state_event(TimelineEventType::RoomName, "".to_owned()))); } #[test] fn member_event_filter_does_not_match_reaction_event() { - assert!(!member_event_filter().matches(&message_event(TimelineEventType::Reaction))); + assert!(!member_event_filter().matches_event(&message_event(TimelineEventType::Reaction))); } #[test] fn member_event_filter_matches_any_state_key() { - assert!(member_event_filter() - .matches_state_event_with_any_state_key(&StateEventType::RoomMember)); + assert!(member_event_filter().matches_request(&StateEventType::RoomMember.into())); } // Tests against an `m.room.topic` filter with `state_key = ""` @@ -318,8 +413,7 @@ mod tests { #[test] fn topic_event_filter_does_not_match_any_state_key() { - assert!(!topic_event_filter() - .matches_state_event_with_any_state_key(&StateEventType::RoomTopic)); + assert!(!topic_event_filter().matches_request(&StateEventType::RoomTopic.into())); } // Tests against an `m.room.message` filter with `msgtype = m.custom` @@ -339,37 +433,34 @@ mod tests { #[test] fn room_message_event_type_matches_room_message_text_event_filter() { assert!(room_message_text_event_filter() - .matches_message_like_event_type(&MessageLikeEventType::RoomMessage)); + .matches_request(&MessageLikeEventType::RoomMessage.into())); } #[test] fn reaction_event_type_does_not_match_room_message_text_event_filter() { assert!(!room_message_text_event_filter() - .matches_message_like_event_type(&MessageLikeEventType::Reaction)); + .matches_request(&MessageLikeEventType::Reaction.into())); } #[test] fn room_message_event_type_matches_room_message_custom_event_filter() { assert!(room_message_custom_event_filter() - .matches_message_like_event_type(&MessageLikeEventType::RoomMessage)); + .matches_request(&MessageLikeEventType::RoomMessage.into())); } #[test] fn reaction_event_type_does_not_match_room_message_custom_event_filter() { assert!(!room_message_custom_event_filter() - .matches_message_like_event_type(&MessageLikeEventType::Reaction)); + .matches_request(&MessageLikeEventType::Reaction.into())); } #[test] fn room_message_event_type_matches_room_message_event_filter() { - assert!(room_message_filter() - .matches_message_like_event_type(&MessageLikeEventType::RoomMessage)); + assert!(room_message_filter().matches_request(&MessageLikeEventType::RoomMessage.into())); } #[test] fn reaction_event_type_does_not_match_room_message_event_filter() { - assert!( - !room_message_filter().matches_message_like_event_type(&MessageLikeEventType::Reaction) - ); + assert!(!room_message_filter().matches_request(&MessageLikeEventType::Reaction.into())); } } diff --git a/crates/matrix-sdk/src/widget/machine/driver_req.rs b/crates/matrix-sdk/src/widget/machine/driver_req.rs index feeba69888b..5cdda61db88 100644 --- a/crates/matrix-sdk/src/widget/machine/driver_req.rs +++ b/crates/matrix-sdk/src/widget/machine/driver_req.rs @@ -14,12 +14,20 @@ //! A high-level API for requests that we send to the matrix driver. -use std::marker::PhantomData; +use std::{collections::BTreeMap, marker::PhantomData}; use ruma::{ - api::client::{account::request_openid_token, delayed_events::update_delayed_event}, - events::{AnyTimelineEvent, MessageLikeEventType, StateEventType, TimelineEventType}, + api::client::{ + account::request_openid_token, delayed_events::update_delayed_event, + to_device::send_event_to_device, + }, + events::{ + AnyTimelineEvent, AnyToDeviceEventContent, MessageLikeEventType, StateEventType, + TimelineEventType, ToDeviceEventType, + }, serde::Raw, + to_device::DeviceIdOrAllDevices, + OwnedUserId, }; use serde::Deserialize; use serde_json::value::RawValue as RawJsonValue; @@ -52,6 +60,9 @@ pub(crate) enum MatrixDriverRequestData { /// Send matrix event that corresponds to the given description. SendMatrixEvent(SendEventRequest), + /// Send matrix event that corresponds to the given description. + SendToDeviceEvent(SendToDeviceRequest), + /// Data for sending a UpdateDelayedEvent client server api request. UpdateDelayedEvent(UpdateDelayedEventRequest), } @@ -251,6 +262,43 @@ impl FromMatrixDriverResponse for SendEventResponse { } } +/// Ask the client to send matrix event that corresponds to the given +/// description and returns an event ID (or a delay ID, +/// see [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140)) as a response. +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct SendToDeviceRequest { + /// The type of the event. + #[serde(rename = "type")] + pub(crate) event_type: ToDeviceEventType, + // If the to_device message should be encrypted or not. + pub(crate) encrypted: bool, + /// The messages body of the to + pub(crate) messages: + BTreeMap>>, +} + +impl From for MatrixDriverRequestData { + fn from(value: SendToDeviceRequest) -> Self { + MatrixDriverRequestData::SendToDeviceEvent(value) + } +} + +impl MatrixDriverRequest for SendToDeviceRequest { + type Response = send_event_to_device::v3::Response; +} + +impl FromMatrixDriverResponse for send_event_to_device::v3::Response { + fn from_response(ev: MatrixDriverResponse) -> Option { + match ev { + MatrixDriverResponse::MatrixToDeviceSent(response) => Some(response), + _ => { + error!("bug in MatrixDriver, received wrong event response"); + None + } + } + } +} + /// Ask the client to send a UpdateDelayedEventRequest with the given `delay_id` /// and `action`. Defined by [MSC4157](https://github.com/matrix-org/matrix-spec-proposals/pull/4157) #[derive(Deserialize, Debug, Clone)] diff --git a/crates/matrix-sdk/src/widget/machine/from_widget.rs b/crates/matrix-sdk/src/widget/machine/from_widget.rs index 7dcb3c86ad7..52c276259b0 100644 --- a/crates/matrix-sdk/src/widget/machine/from_widget.rs +++ b/crates/matrix-sdk/src/widget/machine/from_widget.rs @@ -15,8 +15,9 @@ use std::fmt; use ruma::{ - api::client::delayed_events::{ - delayed_message_event, delayed_state_event, update_delayed_event, + api::client::{ + delayed_events::{delayed_message_event, delayed_state_event, update_delayed_event}, + to_device::send_event_to_device, }, events::{AnyTimelineEvent, MessageLikeEventType, StateEventType}, serde::Raw, @@ -24,7 +25,7 @@ use ruma::{ }; use serde::{Deserialize, Serialize}; -use super::{SendEventRequest, UpdateDelayedEventRequest}; +use super::{driver_req::SendToDeviceRequest, SendEventRequest, UpdateDelayedEventRequest}; use crate::widget::StateKeySelector; #[derive(Deserialize, Debug)] @@ -37,6 +38,8 @@ pub(super) enum FromWidgetRequest { #[serde(rename = "org.matrix.msc2876.read_events")] ReadEvent(ReadEventRequest), SendEvent(SendEventRequest), + #[serde(rename = "org.matrix.msc3819.send_to_device")] + SendToDevice(SendToDeviceRequest), #[serde(rename = "org.matrix.msc4157.update_delayed_event")] DelayedEventUpdate(UpdateDelayedEventRequest), } @@ -175,9 +178,15 @@ impl From for SendEventResponse { /// which derives Serialize. (The response struct from Ruma does not derive /// serialize) #[derive(Serialize, Debug)] -pub(crate) struct UpdateDelayedEventResponse {} -impl From for UpdateDelayedEventResponse { +pub(crate) struct EmptySerializableEvenResponse {} +impl From for EmptySerializableEvenResponse { fn from(_: update_delayed_event::unstable::Response) -> Self { Self {} } } + +impl From for EmptySerializableEvenResponse { + fn from(_: send_event_to_device::v3::Response) -> Self { + Self {} + } +} diff --git a/crates/matrix-sdk/src/widget/machine/incoming.rs b/crates/matrix-sdk/src/widget/machine/incoming.rs index e0ca2c1b968..95f5bc5ff6d 100644 --- a/crates/matrix-sdk/src/widget/machine/incoming.rs +++ b/crates/matrix-sdk/src/widget/machine/incoming.rs @@ -13,8 +13,8 @@ // limitations under the License. use ruma::{ - api::client::{account::request_openid_token, delayed_events}, - events::AnyTimelineEvent, + api::client::{account::request_openid_token, delayed_events, to_device::send_event_to_device}, + events::{AnyTimelineEvent, AnyToDeviceEvent}, serde::Raw, }; use serde::{de, Deserialize, Deserializer}; @@ -45,7 +45,13 @@ pub(crate) enum IncomingMessage { /// /// This means that the machine previously subscribed to some events /// (`Action::Subscribe` request). - MatrixEventReceived(Raw), + MatrixEventReceived(MatrixEvent), +} + +#[derive(Debug)] +pub enum MatrixEvent { + Timeline(Raw), + ToDevice(Raw), } pub(crate) enum MatrixDriverResponse { @@ -62,6 +68,7 @@ pub(crate) enum MatrixDriverResponse { /// Client sent some matrix event. The response contains the event ID. /// A response to an `Action::SendMatrixEvent` command. MatrixEventSent(SendEventResponse), + MatrixToDeviceSent(send_event_to_device::v3::Response), MatrixDelayedEventUpdate(delayed_events::update_delayed_event::unstable::Response), } diff --git a/crates/matrix-sdk/src/widget/machine/mod.rs b/crates/matrix-sdk/src/widget/machine/mod.rs index 31e96fc7da9..bbf0af7358b 100644 --- a/crates/matrix-sdk/src/widget/machine/mod.rs +++ b/crates/matrix-sdk/src/widget/machine/mod.rs @@ -18,8 +18,8 @@ use std::{fmt, iter, time::Duration}; -use driver_req::UpdateDelayedEventRequest; -use from_widget::UpdateDelayedEventResponse; +use driver_req::{SendToDeviceRequest, UpdateDelayedEventRequest}; +use from_widget::EmptySerializableEvenResponse; use indexmap::IndexMap; use ruma::{ serde::{JsonObject, Raw}, @@ -27,6 +27,7 @@ use ruma::{ }; use serde::Serialize; use serde_json::value::RawValue as RawJsonValue; +use to_widget::NotifyNewToDeviceEvent; use tracing::{debug, error, info, instrument, warn}; use uuid::Uuid; @@ -51,10 +52,12 @@ use self::{ use super::WidgetDriver; use super::{ capabilities, - filter::{MatrixEventContent, MatrixEventFilterInput}, + filter::{ + FilterEventType, MatrixEventContent, MatrixEventFilterInput, MatrixEventFilterInputData, + }, Capabilities, StateKeySelector, }; -use crate::widget::EventFilter; +use crate::widget::{filter::EventAndRequestFilter, EventFilter}; mod driver_req; mod from_widget; @@ -68,7 +71,7 @@ mod to_widget; pub(crate) use self::{ driver_req::{MatrixDriverRequestData, ReadStateEventRequest, SendEventRequest}, from_widget::SendEventResponse, - incoming::{IncomingMessage, MatrixDriverResponse}, + incoming::{IncomingMessage, MatrixDriverResponse, MatrixEvent}, }; /// Action (a command) that client (driver) must perform. @@ -95,12 +98,20 @@ pub(crate) enum Action { /// Subscribe to the events in the *current* room, i.e. a room which this /// widget is instantiated with. The client is aware of the room. #[allow(dead_code)] - Subscribe, + SubscribeTimeline, /// Unsuscribe from the events in the *current* room. Symmetrical to /// `Subscribe`. #[allow(dead_code)] - Unsubscribe, + UnsubscribeTimeline, + + /// Subscribe to to-events events, this widget has access to. + #[allow(dead_code)] + SubscribeToDevice, + + //// Unsubscribe from to-events events. + #[allow(dead_code)] + UnsubscribeToDevice, } /// No I/O state machine. @@ -147,7 +158,7 @@ impl WidgetMachine { self.pending_to_widget_requests.remove_expired(); self.pending_matrix_driver_requests.remove_expired(); - match event { + let vec = match event { IncomingMessage::WidgetMessage(raw) => self.process_widget_message(&raw), IncomingMessage::MatrixDriverResponse { request_id, response } => { self.process_matrix_driver_response(request_id, response) @@ -157,16 +168,23 @@ impl WidgetMachine { error!("Received matrix event before capabilities negotiation"); return Vec::new(); }; - capabilities .raw_event_matches_read_filter(&event) .then(|| { - let action = self.send_to_widget_request(NotifyNewMatrixEvent(event)).1; + let action = match event { + MatrixEvent::Timeline(event) => { + self.send_to_widget_request(NotifyNewMatrixEvent(event)).1 + } + MatrixEvent::ToDevice(event) => { + self.send_to_widget_request(NotifyNewToDeviceEvent(event)).1 + } + }; action.map(|a| vec![a]).unwrap_or_default() }) .unwrap_or_default() } - } + }; + vec } fn process_widget_message(&mut self, raw: &str) -> Vec { @@ -230,6 +248,11 @@ impl WidgetMachine { .map(|a| vec![a]) .unwrap_or_default(), + FromWidgetRequest::SendToDevice(req) => self + .process_to_device_request(req, raw_request) + .map(|a| vec![a]) + .unwrap_or_default(), + FromWidgetRequest::GetOpenId {} => { let (request, request_action) = self.send_matrix_driver_request(RequestOpenId); request.then(|res, machine| { @@ -273,7 +296,7 @@ impl WidgetMachine { raw_request, // This is mapped to another type because the update_delay_event::Response // does not impl Serialize - res.map(Into::::into), + res.map(Into::::into), )] }); request_action.map(|a| vec![a]).unwrap_or_default() @@ -293,7 +316,7 @@ impl WidgetMachine { match request { ReadEventRequest::ReadMessageLikeEvent { event_type, limit } => { - let filter_fn = |f: &EventFilter| f.matches_message_like_event_type(&event_type); + let filter_fn = |f: &EventFilter| f.matches_request(&event_type.clone().into()); if !capabilities.read.iter().any(filter_fn) { return Some(self.send_from_widget_error_response(raw_request, "Not allowed")); } @@ -310,7 +333,10 @@ impl WidgetMachine { return Err(err.into()); }; - events.retain(|e| capabilities.raw_event_matches_read_filter(e)); + events.retain(|e| { + capabilities + .raw_event_matches_read_filter(&MatrixEvent::Timeline(e.clone())) + }); Ok(ReadEventResponse { events }) }); vec![machine.send_from_widget_result_response(raw_request, response)] @@ -318,21 +344,23 @@ impl WidgetMachine { action } ReadEventRequest::ReadStateEvent { event_type, state_key } => { + let event_type_filter: FilterEventType = event_type.clone().into(); let allowed = match &state_key { StateKeySelector::Any => capabilities .read .iter() - .any(|filter| filter.matches_state_event_with_any_state_key(&event_type)), + .any(|filter| filter.matches_request(&event_type_filter)), StateKeySelector::Key(state_key) => { - let filter_in = MatrixEventFilterInput { - event_type: event_type.to_string().into(), - state_key: Some(state_key.clone()), - // content doesn't matter for state events - content: MatrixEventContent::default(), - }; - - capabilities.read.iter().any(|filter| filter.matches(&filter_in)) + let filter_in = + MatrixEventFilterInput::Timeline(MatrixEventFilterInputData { + event_type: event_type_filter, + state_key: Some(state_key.clone()), + // content doesn't matter for state events + content: MatrixEventContent::default(), + }); + + capabilities.read.iter().any(|filter| filter.matches_event(&filter_in)) } }; @@ -361,8 +389,8 @@ impl WidgetMachine { return None; }; - let filter_in = MatrixEventFilterInput { - event_type: request.event_type.clone(), + let filter_in = MatrixEventFilterInput::Timeline(MatrixEventFilterInputData { + event_type: request.event_type.clone().into(), state_key: request.state_key.clone(), content: serde_json::from_str(request.content.get()).unwrap_or_else(|e| { debug!("Failed to deserialize event content for filter: {e}"); @@ -370,7 +398,7 @@ impl WidgetMachine { // that matches with it when it otherwise wouldn't. Default::default() }), - }; + }); if !capabilities.send_delayed_event && request.delay.is_some() { return Some(self.send_from_widget_error_response( raw_request, @@ -380,7 +408,7 @@ impl WidgetMachine { ), )); } - if !capabilities.send.iter().any(|filter| filter.matches(&filter_in)) { + if !capabilities.send.iter().any(|filter| filter.matches_event(&filter_in)) { return Some(self.send_from_widget_error_response(raw_request, "Not allowed")); } @@ -394,6 +422,41 @@ impl WidgetMachine { action } + fn process_to_device_request( + &mut self, + request: SendToDeviceRequest, + raw_request: Raw, + ) -> Option { + let CapabilitiesState::Negotiated(capabilities) = &self.capabilities else { + error!("Received send event request before capabilities negotiation"); + return None; + }; + + let filter_in = MatrixEventFilterInput::ToDevice(MatrixEventFilterInputData { + event_type: request.event_type.clone().into(), + state_key: None, + content: MatrixEventContent { msgtype: None }, + }); + + if !capabilities.send.iter().any(|filter| filter.matches_event(&filter_in)) { + return Some(self.send_from_widget_error_response( + raw_request, + format!("Not allowed to send to-device message of type: {}", request.event_type), + )); + } + + let (request, action) = self.send_matrix_driver_request(request); + request.then(|result, machine| { + vec![machine.send_from_widget_result_response( + raw_request, + // This is mapped to another type because the update_delay_event::Response + // does not impl Serialize + result.map(Into::::into), + )] + }); + action + } + #[instrument(skip_all, fields(?request_id))] fn process_to_widget_response( &mut self, @@ -553,13 +616,21 @@ impl WidgetMachine { let update = NotifyCapabilitiesChanged { approved, requested }; let (_request, action) = machine.send_to_widget_request(update); - (subscribe_required).then(|| Action::Subscribe).into_iter().chain(action).collect() + (subscribe_required) + .then(|| Action::SubscribeTimeline) + .into_iter() + .chain(action) + .collect() }); action.map(|a| vec![a]).unwrap_or_default() }); - unsubscribe_required.then(|| Action::Unsubscribe).into_iter().chain(action).collect() + unsubscribe_required + .then(|| Action::UnsubscribeTimeline) + .into_iter() + .chain(action) + .collect() } } diff --git a/crates/matrix-sdk/src/widget/machine/tests/capabilities.rs b/crates/matrix-sdk/src/widget/machine/tests/capabilities.rs index 3c06da002ae..d1ed5c8680b 100644 --- a/crates/matrix-sdk/src/widget/machine/tests/capabilities.rs +++ b/crates/matrix-sdk/src/widget/machine/tests/capabilities.rs @@ -196,7 +196,7 @@ pub(super) fn assert_capabilities_dance( .any(|c| capability.starts_with(c)) { let action = actions.remove(0); - assert_matches!(action, Action::Subscribe); + assert_matches!(action, Action::SubscribeTimeline); } // Inform the widget about the acquired capabilities. diff --git a/crates/matrix-sdk/src/widget/machine/to_widget.rs b/crates/matrix-sdk/src/widget/machine/to_widget.rs index 5315ffb5850..08ae59f2e9f 100644 --- a/crates/matrix-sdk/src/widget/machine/to_widget.rs +++ b/crates/matrix-sdk/src/widget/machine/to_widget.rs @@ -14,7 +14,10 @@ use std::marker::PhantomData; -use ruma::{events::AnyTimelineEvent, serde::Raw}; +use ruma::{ + events::{AnyTimelineEvent, AnyToDeviceEvent}, + serde::Raw, +}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::value::RawValue as RawJsonValue; use tracing::error; @@ -130,3 +133,14 @@ impl ToWidgetRequest for NotifyNewMatrixEvent { #[derive(Deserialize)] pub(crate) struct Empty {} + +/// Notify the widget that we received a new matrix event. +/// This is a "response" to the widget subscribing to the events in the room. +#[derive(Serialize)] +#[serde(transparent)] +pub(crate) struct NotifyNewToDeviceEvent(pub(crate) Raw); + +impl ToWidgetRequest for NotifyNewToDeviceEvent { + const ACTION: &'static str = "send_to_device"; + type ResponseData = Empty; +} diff --git a/crates/matrix-sdk/src/widget/matrix.rs b/crates/matrix-sdk/src/widget/matrix.rs index b0d09b18ab7..02b5a67a067 100644 --- a/crates/matrix-sdk/src/widget/matrix.rs +++ b/crates/matrix-sdk/src/widget/matrix.rs @@ -17,28 +17,36 @@ use std::collections::BTreeMap; -use matrix_sdk_base::deserialized_responses::RawAnySyncOrStrippedState; +use futures_util::future::join_all; +use matrix_sdk_base::deserialized_responses::{EncryptionInfo, RawAnySyncOrStrippedState}; use ruma::{ api::client::{ account::request_openid_token::v3::{Request as OpenIdRequest, Response as OpenIdResponse}, delayed_events::{self, update_delayed_event::unstable::UpdateAction}, filter::RoomEventFilter, + to_device::send_event_to_device::{self, v3::Request as RumaToDeviceRequest}, }, assign, events::{ AnyMessageLikeEventContent, AnyStateEventContent, AnySyncTimelineEvent, AnyTimelineEvent, - MessageLikeEventType, StateEventType, TimelineEventType, + AnyToDeviceEvent, AnyToDeviceEventContent, MessageLikeEventType, StateEventType, + TimelineEventType, ToDeviceEventType, }, serde::Raw, - RoomId, TransactionId, + to_device::DeviceIdOrAllDevices, + OwnedUserId, RoomId, TransactionId, UserId, }; -use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; -use tracing::error; +use serde_json::{json, value::RawValue as RawJsonValue, Value}; +use tokio::{ + spawn, + sync::mpsc::{unbounded_channel, UnboundedReceiver}, +}; +use tracing::{error, info}; use super::{machine::SendEventResponse, StateKeySelector}; use crate::{ - event_handler::EventHandlerDropGuard, room::MessagesOptions, HttpResult, Result, Room, + encryption::identities::Device, event_handler::EventHandlerDropGuard, room::MessagesOptions, + Client, Error, HttpResult, Result, Room, }; /// Thin wrapper around a [`Room`] that provides functionality relevant for @@ -47,6 +55,14 @@ pub(crate) struct MatrixDriver { room: Room, } +// pub enum SendToDeviceResult { +// HttpError(HttpResult), +// } +// impl From> for SendEventResponse { +// fn from(value: HttpResult) -> Self { +// SendToDeviceResult::HttpError(value) +// } +// } impl MatrixDriver { /// Creates a new `MatrixDriver` for a given `room`. pub(crate) fn new(room: Room) -> Self { @@ -161,7 +177,7 @@ impl MatrixDriver { /// Starts forwarding new room events. Once the returned `EventReceiver` /// is dropped, forwarding will be stopped. - pub(crate) fn events(&self) -> EventReceiver { + pub(crate) fn events(&self) -> EventReceiver { let (tx, rx) = unbounded_channel(); let room_id = self.room.room_id().to_owned(); let handle = self.room.add_event_handler(move |raw: Raw| { @@ -172,17 +188,167 @@ impl MatrixDriver { let drop_guard = self.room.client().event_handler_drop_guard(handle); EventReceiver { rx, _drop_guard: drop_guard } } + + /// Starts forwarding new room events. Once the returned `EventReceiver` + /// is dropped, forwarding will be stopped. + pub(crate) fn to_device_events(&self) -> EventReceiver { + let (tx, rx) = unbounded_channel(); + + let to_device_handle = self.room.client().add_event_handler( + move |raw: Raw, encryption_info: Option| { + // Deserialize the Raw to a mutable structure + let mut event_with_encryption_flag: Value = + raw.deserialize_as().expect("Invalid event JSON"); + + if let Some(content) = event_with_encryption_flag.get_mut("content") { + content["encrypted"] = json!(encryption_info.is_some()); + } + let ev_for_widget = match Raw::::from_json_string( + event_with_encryption_flag.to_string(), + ) { + Ok(ev) => ev, + Err(_) => raw, + }; + + let _ = tx.send(ev_for_widget); + async {} + }, + ); + + let drop_guard = self.room.client().event_handler_drop_guard(to_device_handle); + EventReceiver { rx, _drop_guard: drop_guard } + } + + pub(crate) async fn send_to_device( + &self, + event_type: ToDeviceEventType, + encrypted: bool, + messages: BTreeMap< + OwnedUserId, + BTreeMap>, + >, + ) -> Result { + let client = self.room.client(); + // This encrypts the content for a device. + // A device_id can be "*" in this case the function also computes all devices for a user and + // returns an iterator of devices. + async fn encrypted_content_for_device_or_all_devices( + client: &Client, + unencrypted: &Raw, + device_or_all_id: DeviceIdOrAllDevices, + user_id: &UserId, + event_type: &ToDeviceEventType, + ) -> Result)>> + { + let user_devices = client.encryption().get_user_devices(&user_id).await?; + + let devices: Vec = match device_or_all_id { + DeviceIdOrAllDevices::DeviceId(device_id) => { + vec![user_devices.get(&device_id)].into_iter().flatten().collect() + } + DeviceIdOrAllDevices::AllDevices => user_devices.devices().collect(), + }; + + let content: Value = unencrypted.deserialize_as().map_err(Into::::into)?; + let event_type = event_type.clone(); + let device_content_tasks = devices.into_iter().map(|device| spawn({ + let value = event_type.clone(); + let content = content.clone(); + async move { + let a =match device + .inner + .encrypt_event_raw(&value.to_string(), &content) + .await{ + Ok(encrypted) => Some((device.device_id().to_owned().into(), encrypted.cast())), + Err(e) =>{ info!("Failed to encrypt to_device event from widget for device: {} because, {}", device.device_id(), e); None}, + }; + a + } + })); + let t = join_all(device_content_tasks).await.into_iter().flatten().flatten(); + Ok(t) + } + + // Here we convert the device content map for one user into the same content map with encrypted content + // This needs to flatten the vectors we get from `encrypted_content_for_device_or_all_devices` + // since one DeviceIdOrAllDevices id can be multiple devices. + async fn encrypted_device_content_map( + client: &Client, + user_id: &UserId, + event_type: &ToDeviceEventType, + device_content_map: BTreeMap>, + ) -> BTreeMap> { + let device_map_futures = + device_content_map.into_iter().map(|(device_or_all_id, content)| spawn({let client = client.clone();let user_id = user_id.to_owned();let event_type = event_type.clone();async move { + encrypted_content_for_device_or_all_devices( + &client, + &content, + device_or_all_id, + &user_id, + &event_type, + ) + .await + .map_err(|e| info!("could not encrypt content for to device widget event content: {}. because, {}", content.json(), e)) + .ok() + }})); + // The first flatten takes the iterator of Option's iterators and converts it to just a iterator over Option<(Device, data)>'s. + // The second flatten takes the the iterator over Option<(Device,data)>'s and converts it to just a iterator over Device + join_all(device_map_futures).await.into_iter().flatten().flatten().flatten().collect() + } + + // We first want to get all missing session before we start any to device sending! + client.claim_one_time_keys(messages.iter().map(|(u, _)| u.as_ref())).await?; + + let request = match encrypted { + true => { + let encrypted_content: BTreeMap< + OwnedUserId, + BTreeMap>, + > = join_all(messages.into_iter().map(|(user_id, device_content_map)| { + let event_type = event_type.clone(); + async move { + ( + user_id.clone(), + encrypted_device_content_map( + &self.room.client(), + &user_id, + &event_type, + device_content_map, + ) + .await, + ) + } + })) + .await + .into_iter() + .collect(); + + RumaToDeviceRequest::new_raw( + event_type.clone(), + TransactionId::new(), + encrypted_content, + ) + } + false => RumaToDeviceRequest::new_raw(event_type, TransactionId::new(), messages), + }; + + let response = client.send(request, None).await; + // if let Ok(res){ + // client.mark_request_as_sent(request.request_id(), &res).await; + // }; + response.map_err(Into::into) + } } /// A simple entity that wraps an `UnboundedReceiver` /// along with the drop guard for the room event handler. -pub(crate) struct EventReceiver { - rx: UnboundedReceiver>, +pub(crate) struct EventReceiver { + rx: UnboundedReceiver>, _drop_guard: EventHandlerDropGuard, } -impl EventReceiver { - pub(crate) async fn recv(&mut self) -> Option> { +impl EventReceiver { + pub(crate) async fn recv(&mut self) -> Option> { self.rx.recv().await } } diff --git a/crates/matrix-sdk/src/widget/mod.rs b/crates/matrix-sdk/src/widget/mod.rs index d5f7109e4ff..048d39f41b8 100644 --- a/crates/matrix-sdk/src/widget/mod.rs +++ b/crates/matrix-sdk/src/widget/mod.rs @@ -17,6 +17,7 @@ use std::{fmt, time::Duration}; use async_channel::{Receiver, Sender}; +use machine::MatrixEvent; use ruma::api::client::delayed_events::DelayParameters; use serde::de::{self, Deserialize, Deserializer, Visitor}; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; @@ -39,7 +40,7 @@ mod settings; pub use self::{ capabilities::{Capabilities, CapabilitiesProvider}, - filter::{EventFilter, MessageLikeEventFilter, StateEventFilter}, + filter::{EventFilter, MessageLikeEventFilter, StateEventFilter, ToDeviceEventFilter}, settings::{ ClientProperties, EncryptionSystem, VirtualElementCallWidgetOptions, WidgetSettings, }, @@ -151,6 +152,7 @@ impl WidgetDriver { widget_machine: client_api, matrix_driver: MatrixDriver::new(room.clone()), event_forwarding_guard: None, + to_device_event_forwarding_guard: None, to_widget_tx: self.to_widget_tx, events_tx, capabilities_provider, @@ -175,6 +177,7 @@ struct ProcessingContext { widget_machine: WidgetMachine, matrix_driver: MatrixDriver, event_forwarding_guard: Option, + to_device_event_forwarding_guard: Option, to_widget_tx: Sender, events_tx: UnboundedSender, capabilities_provider: T, @@ -247,13 +250,24 @@ impl ProcessingContext { .await .map(MatrixDriverResponse::MatrixDelayedEventUpdate) .map_err(|e: HttpError| e.to_string()), + + MatrixDriverRequestData::SendToDeviceEvent(send_to_device_request) => self + .matrix_driver + .send_to_device( + send_to_device_request.event_type, + send_to_device_request.encrypted, + send_to_device_request.messages, + ) + .await + .map(MatrixDriverResponse::MatrixToDeviceSent) + .map_err(|e| e.to_string()), }; self.events_tx .send(IncomingMessage::MatrixDriverResponse { request_id, response }) .map_err(|_| ())?; } - Action::Subscribe => { + Action::SubscribeTimeline => { // Only subscribe if we are not already subscribed. if self.event_forwarding_guard.is_none() { let (stop_forwarding, guard) = { @@ -269,16 +283,44 @@ impl ProcessingContext { tokio::select! { _ = stop_forwarding.cancelled() => { return } Some(event) = matrix.recv() => { - let _ = events_tx.send(IncomingMessage::MatrixEventReceived(event)); + let ev = MatrixEvent::Timeline(event); + let _ = events_tx.send(IncomingMessage::MatrixEventReceived(ev)); } } } }); } } - Action::Unsubscribe => { + Action::UnsubscribeTimeline => { self.event_forwarding_guard = None; } + Action::SubscribeToDevice => { + // Only subscribe if we are not already subscribed. + if self.to_device_event_forwarding_guard.is_none() { + let (stop_forwarding, guard) = { + let token = CancellationToken::new(); + (token.child_token(), token.drop_guard()) + }; + + self.to_device_event_forwarding_guard = Some(guard); + let (mut matrix, events_tx) = + (self.matrix_driver.to_device_events(), self.events_tx.clone()); + tokio::spawn(async move { + loop { + tokio::select! { + _ = stop_forwarding.cancelled() => { return } + Some(event) = matrix.recv() => { + let ev = MatrixEvent::ToDevice(event); + let _ = events_tx.send(IncomingMessage::MatrixEventReceived(ev)); + } + } + } + }); + } + } + Action::UnsubscribeToDevice => { + self.to_device_event_forwarding_guard = None; + } } Ok(()) diff --git a/crates/matrix-sdk/tests/integration/widget.rs b/crates/matrix-sdk/tests/integration/widget.rs index 72d29fcbd04..f950bd70814 100644 --- a/crates/matrix-sdk/tests/integration/widget.rs +++ b/crates/matrix-sdk/tests/integration/widget.rs @@ -42,7 +42,7 @@ use ruma::{ user_id, OwnedRoomId, }; use serde::Serialize; -use serde_json::{json, Value as JsonValue}; +use serde_json::{json, Map, Value as JsonValue}; use tracing::error; use wiremock::{ matchers::{header, method, path_regex, query_param}, @@ -111,15 +111,15 @@ async fn send_request( action: &str, data: impl Serialize, ) { - let sent = driver_handle - .send(json_string!({ - "api": "fromWidget", - "widgetId": WIDGET_ID, - "requestId": request_id, - "action": action, - "data": data, - })) - .await; + let json_string = json_string!({ + "api": "fromWidget", + "widgetId": WIDGET_ID, + "requestId": request_id, + "action": action, + "data": data, + }); + println!("Json string sent from the widget {}", json_string); + let sent = driver_handle.send(json_string).await; assert!(sent); } @@ -488,6 +488,8 @@ async fn test_receive_live_events() { )), ); + sync_builder.add_to_device_event(json!({"some":"Event"})); + mock_sync(&mock_server, sync_builder.build_json_sync_response(), None).await; let _response = client.sync_once(SyncSettings::new().timeout(Duration::from_millis(3000))).await.unwrap(); @@ -827,6 +829,108 @@ async fn test_try_update_delayed_event_without_permission_negotiate() { } } +async fn send_to_device_test_helper( + event_type: &str, + data: JsonValue, + expected_response: JsonValue, +) -> JsonValue { + let (_, mock_server, driver_handle) = run_test_driver(false).await; + + negotiate_capabilities( + &driver_handle, + json!([ + "org.matrix.msc3819.send.to_device:my.custom.to_device_type", + "org.matrix.msc3819.send.to_device:my.other_type" + ]), + ) + .await; + + Mock::given(method("PUT")) + .and(path_regex(format!(r"^/_matrix/client/r0/sendToDevice/{}/.*", event_type))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .mount(&mock_server) + .await; + + send_request(&driver_handle, event_type, "org.matrix.msc3819.send_to_device", data).await; + + // Receive the response + let msg = recv_message(&driver_handle).await; + assert_eq!(msg["api"], "fromWidget"); + assert_eq!(msg["action"], "org.matrix.msc3819.send_to_device"); + let response = msg["response"].clone(); + assert_eq!(&response, &expected_response); + + // Make sure the event-sending endpoint was hit exactly once + mock_server.verify().await; + + response +} + +#[async_test] +async fn test_send_to_device_event() { + let r = send_to_device_test_helper( + "my.custom.to_device_type", + json!({ + "type": "my.custom.to_device_type", + "encrypted": false, + "messages":{ + "@username:test.org": { + "DEVICEID": { + "param1":"test", + }, + }, + } + }), + json! {{}}, + ) + .await; + assert_eq!(r.as_object(), Some(Map::new()).as_ref()); +} + +#[async_test] +async fn test_error_to_device_event_no_permission() { + let r = send_to_device_test_helper( + "my.custom.to_device_type", + json!({ + "type": "my.unallowed_type", + "encrypted": false, + "messages":{ + "@username:test.org": { + "DEVICEID": { + "param1":"test", + }, + }, + } + }), + // this means the server did not get the correct event type + json! {{"error": {"message": "Not allowed to send to-device message of type: my.unallowed_type"}}}, + ) + .await; + assert_eq!(r.as_object(), Some(Map::new()).as_ref()); +} + +#[async_test] +async fn test_send_encrypted_to_device_event() { + let r = send_to_device_test_helper( + "my.custom.to_device_type", + json!({ + "type": "my.custom.to_device_type", + "encrypted": true, + "messages":{ + "@username:test.org": { + "DEVICEID": { + "param1":"test", + }, + }, + } + }), + json! {{}}, + ) + .await; + assert_eq!(r.as_object(), Some(Map::new()).as_ref()); +} + async fn negotiate_capabilities(driver_handle: &WidgetDriverHandle, caps: JsonValue) { { // Receive toWidget capabilities request diff --git a/testing/matrix-sdk-test/src/sync_builder/mod.rs b/testing/matrix-sdk-test/src/sync_builder/mod.rs index ea6d05bd8cb..e9bc291343a 100644 --- a/testing/matrix-sdk-test/src/sync_builder/mod.rs +++ b/testing/matrix-sdk-test/src/sync_builder/mod.rs @@ -8,7 +8,7 @@ use ruma::{ }, IncomingResponse, }, - events::{presence::PresenceEvent, AnyGlobalAccountDataEvent}, + events::{presence::PresenceEvent, AnyGlobalAccountDataEvent, AnyToDeviceEvent}, serde::Raw, OwnedRoomId, OwnedUserId, UserId, }; @@ -54,6 +54,7 @@ pub struct SyncResponseBuilder { batch_counter: i64, /// The device lists of the user. changed_device_lists: Vec, + to_device_events: Vec>, } impl SyncResponseBuilder { @@ -143,6 +144,12 @@ impl SyncResponseBuilder { self } + /// Add a presence event. + pub fn add_to_device_event(&mut self, event: JsonValue) -> &mut Self { + self.to_device_events.push(from_json_value(event).unwrap()); + self + } + /// Builds a sync response as a JSON Value containing the events we queued /// so far. /// @@ -171,7 +178,7 @@ impl SyncResponseBuilder { "leave": self.left_rooms, }, "to_device": { - "events": [] + "events": self.to_device_events, }, "presence": { "events": self.presence,