zoe_client/pqxdh/
message_listener.rs1use std::{
2 pin::Pin,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use eyeball::{AsyncLock, SharedObservable};
8use futures::{Stream, StreamExt};
9use zoe_wire_protocol::{Filter, KeyId, MessageFull, PqxdhEncryptedContent, Tag};
10
11use super::{PqxdhProtocolState, PqxdhSessionId, Result};
12use crate::{pqxdh::PqxdhError, services::MessagesManagerTrait};
13
14pub struct PqxdhMessageListener<U> {
15 inner: Pin<Box<dyn futures::Stream<Item = U> + Send>>,
16}
17
18impl<U> PqxdhMessageListener<U>
19where
20 U: for<'de> serde::Deserialize<'de>,
21{
22 pub(super) async fn new<T: MessagesManagerTrait>(
23 messages_manager: Arc<T>,
24 session_id: PqxdhSessionId,
25 state: SharedObservable<super::PqxdhProtocolState, AsyncLock>,
26 listening_tag: Tag,
27 catch_up: bool,
28 ) -> Result<Self> {
29 let messages_stream = if catch_up {
31 messages_manager
32 .catch_up_and_subscribe((&listening_tag).into(), None)
33 .await?
34 } else {
35 messages_manager
36 .ensure_contains_filter(Filter::from(listening_tag.clone()))
37 .await?;
38 messages_manager.filtered_messages_stream(Filter::from(listening_tag))
39 };
40
41 let inner = Box::pin(messages_stream.filter_map(move |message_full| {
42 let state = state.clone();
43 async move {
44 tracing::debug!(
45 "🔄 PQXDH handler received message: {}",
46 hex::encode(message_full.id().as_bytes())
47 );
48 Self::on_regular_message(&state, &message_full, &session_id)
49 .await
50 .inspect_err(|e| {
51 tracing::error!(
52 msg_id = hex::encode(message_full.id().as_bytes()),
53 "error processing inbox message: {e}"
54 );
55 })
56 .inspect(|_result| {
57 tracing::debug!(
58 "✅ PQXDH handler successfully processed message: {}",
59 hex::encode(message_full.id().as_bytes())
60 );
61 })
62 .ok()
63 }
64 }));
65 Ok(Self { inner })
66 }
67
68 async fn on_regular_message(
69 state: &SharedObservable<PqxdhProtocolState, AsyncLock>,
70 message_full: &MessageFull,
71 session_id: &PqxdhSessionId,
72 ) -> Result<U> {
73 let shared_secret = {
74 let current_state = state.get().await;
75 let Some(session) = current_state.sessions.get(&KeyId::from_bytes(*session_id)) else {
76 return Err(PqxdhError::SessionNotFound);
77 };
78
79 if &session.their_key != message_full.author() {
80 return Err(PqxdhError::InvalidSender);
81 };
82 session.shared_secret.clone()
83 };
84
85 let Some(PqxdhEncryptedContent::Session(pqxdh_content)) =
86 message_full.content().as_pqxdh_encrypted()
87 else {
88 return Err(PqxdhError::NotPqxdhMessage);
89 };
90
91 let decrypted_bytes =
92 zoe_wire_protocol::inbox::pqxdh::pqxdh_crypto::decrypt_pqxdh_session_message(
93 &shared_secret,
94 pqxdh_content,
95 )
96 .map_err(|e| PqxdhError::Crypto(e.to_string()))?;
97 Ok(postcard::from_bytes(&decrypted_bytes)?)
98 }
99}
100
101impl<U> Stream for PqxdhMessageListener<U>
102where
103 U: for<'de> serde::Deserialize<'de>,
104{
105 type Item = U;
106
107 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108 let this = self.get_mut();
109 this.inner.poll_next_unpin(cx)
110 }
111}