zoe_client/pqxdh/
message_listener.rs

1use std::{
2    pin::Pin,
3    sync::Arc,
4    task::{Context, Poll},
5};
6
7use eyeball::{AsyncLock, SharedObservable};
8use futures::{Stream, StreamExt};
9use zoe_wire_protocol::{Filter, KeyId, MessageFull, PqxdhEncryptedContent, Tag};
10
11use super::{PqxdhProtocolState, PqxdhSessionId, Result};
12use crate::{pqxdh::PqxdhError, services::MessagesManagerTrait};
13
14pub struct PqxdhMessageListener<U> {
15    inner: Pin<Box<dyn futures::Stream<Item = U> + Send>>,
16}
17
18impl<U> PqxdhMessageListener<U>
19where
20    U: for<'de> serde::Deserialize<'de>,
21{
22    pub(super) async fn new<T: MessagesManagerTrait>(
23        messages_manager: Arc<T>,
24        session_id: PqxdhSessionId,
25        state: SharedObservable<super::PqxdhProtocolState, AsyncLock>,
26        listening_tag: Tag,
27        catch_up: bool,
28    ) -> Result<Self> {
29        // Subscribe to the session channel for responses
30        let messages_stream = if catch_up {
31            messages_manager
32                .catch_up_and_subscribe((&listening_tag).into(), None)
33                .await?
34        } else {
35            messages_manager
36                .ensure_contains_filter(Filter::from(listening_tag.clone()))
37                .await?;
38            messages_manager.filtered_messages_stream(Filter::from(listening_tag))
39        };
40
41        let inner = Box::pin(messages_stream.filter_map(move |message_full| {
42            let state = state.clone();
43            async move {
44                tracing::debug!(
45                    "🔄 PQXDH handler received message: {}",
46                    hex::encode(message_full.id().as_bytes())
47                );
48                Self::on_regular_message(&state, &message_full, &session_id)
49                    .await
50                    .inspect_err(|e| {
51                        tracing::error!(
52                            msg_id = hex::encode(message_full.id().as_bytes()),
53                            "error processing inbox message: {e}"
54                        );
55                    })
56                    .inspect(|_result| {
57                        tracing::debug!(
58                            "✅ PQXDH handler successfully processed message: {}",
59                            hex::encode(message_full.id().as_bytes())
60                        );
61                    })
62                    .ok()
63            }
64        }));
65        Ok(Self { inner })
66    }
67
68    async fn on_regular_message(
69        state: &SharedObservable<PqxdhProtocolState, AsyncLock>,
70        message_full: &MessageFull,
71        session_id: &PqxdhSessionId,
72    ) -> Result<U> {
73        let shared_secret = {
74            let current_state = state.get().await;
75            let Some(session) = current_state.sessions.get(&KeyId::from_bytes(*session_id)) else {
76                return Err(PqxdhError::SessionNotFound);
77            };
78
79            if &session.their_key != message_full.author() {
80                return Err(PqxdhError::InvalidSender);
81            };
82            session.shared_secret.clone()
83        };
84
85        let Some(PqxdhEncryptedContent::Session(pqxdh_content)) =
86            message_full.content().as_pqxdh_encrypted()
87        else {
88            return Err(PqxdhError::NotPqxdhMessage);
89        };
90
91        let decrypted_bytes =
92            zoe_wire_protocol::inbox::pqxdh::pqxdh_crypto::decrypt_pqxdh_session_message(
93                &shared_secret,
94                pqxdh_content,
95            )
96            .map_err(|e| PqxdhError::Crypto(e.to_string()))?;
97        Ok(postcard::from_bytes(&decrypted_bytes)?)
98    }
99}
100
101impl<U> Stream for PqxdhMessageListener<U>
102where
103    U: for<'de> serde::Deserialize<'de>,
104{
105    type Item = U;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        let this = self.get_mut();
109        this.inner.poll_next_unpin(cx)
110    }
111}