zoe_client/pqxdh/
handler.rs

1use eyeball::{AsyncLock, ObservableWriteGuard, SharedObservable};
2use futures::StreamExt;
3use rand::RngCore;
4use serde::{Deserialize, Serialize};
5
6use std::sync::Arc;
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use tracing::{error, warn};
10use zoe_wire_protocol::{Content, Filter};
11use zoe_wire_protocol::{
12    KeyId, Kind, Message, MessageFull, PqxdhInboxProtocol, StoreKey, Tag, VerifyingKey,
13    inbox::pqxdh::{
14        InboxType, PqxdhInbox, PqxdhInitialPayload, generate_pqxdh_prekeys, pqxdh_initiate,
15    },
16};
17
18use crate::pqxdh::PqxdhMessageListener;
19
20use super::{
21    PqxdhError, PqxdhProtocolState, PqxdhSession, PqxdhSessionId, PqxdhTarpcTransport, Result,
22};
23
24/// A complete PQXDH protocol handler that encapsulates all session management,
25/// key observation, subscription handling, and message routing logic.
26///
27/// This provides a high-level abstraction over the entire PQXDH workflow with
28/// automatic state management and persistence support. It can operate in two modes:
29///
30/// ## Service Provider Mode
31/// - Publishes a PQXDH inbox for clients to discover
32/// - Listens for incoming client connections
33/// - Manages multiple concurrent client sessions
34/// - Handles initial message decryption and session establishment
35///
36/// ## Client Mode  
37/// - Discovers and connects to service provider inboxes
38/// - Establishes secure sessions with service providers
39/// - Sends messages over established sessions
40/// - Manages session state and channel subscriptions
41///
42/// ## Key Features
43/// - **State Persistence**: All state can be serialized and restored across restarts
44/// - **State Observation**: Observable state changes via broadcast channels for reactive programming
45/// - **Automatic Subscriptions**: Handles message routing and channel management
46/// - **Session Management**: Tracks multiple concurrent sessions by user ID
47/// - **Privacy Preserving**: Uses randomized channel IDs and derived tags
48/// - **Type Safety**: Generic over message payload types with compile-time safety
49pub struct PqxdhProtocolHandler<T: crate::services::MessagesManagerTrait> {
50    messages_manager: Arc<T>,
51    client_keypair: Arc<zoe_wire_protocol::KeyPair>,
52    /// Observable state that can be subscribed to for reactive programming
53    pub(crate) state: SharedObservable<PqxdhProtocolState, AsyncLock>,
54}
55
56impl<T: crate::services::MessagesManagerTrait> Clone for PqxdhProtocolHandler<T> {
57    fn clone(&self) -> Self {
58        Self {
59            messages_manager: self.messages_manager.clone(),
60            client_keypair: self.client_keypair.clone(),
61            state: self.state.clone(),
62        }
63    }
64}
65
66impl<T: crate::services::MessagesManagerTrait + 'static> PqxdhProtocolHandler<T> {
67    /// Creates a new protocol handler for a specific PQXDH protocol
68    ///
69    /// This creates a fresh handler with empty state that can be used either as:
70    /// - **Service Provider**: Call `publish_service()` then use `inbox_stream()`
71    /// - **Client**: Call `connect_to_service()` then `send_message()`
72    ///
73    /// # Arguments
74    /// * `messages_manager` - The messages manager to use for message operations
75    /// * `client_keypair` - The client's keypair for signing and encryption
76    /// * `protocol` - The specific PQXDH protocol variant to use
77    ///
78    /// # Example
79    /// ```rust,no_run
80    /// # use zoe_client::pqxdh::*;
81    /// # use zoe_wire_protocol::*;
82    /// # async fn example() -> Result<()> {
83    /// # let messages_manager = todo!();
84    /// # let keypair = todo!();
85    /// let handler = PqxdhProtocolHandler::new(
86    ///     &messages_manager,
87    ///     &keypair,
88    ///     PqxdhInboxProtocol::EchoService
89    /// );
90    /// # Ok(())
91    /// # }
92    /// ```
93    pub fn new(
94        messages_manager: Arc<T>,
95        client_keypair: Arc<zoe_wire_protocol::KeyPair>,
96        protocol: PqxdhInboxProtocol,
97    ) -> Self {
98        let initial_state = PqxdhProtocolState::new(protocol);
99        let state = SharedObservable::new_async(initial_state);
100
101        Self {
102            messages_manager,
103            client_keypair,
104            state,
105        }
106    }
107
108    /// Creates a protocol handler from existing serialized state
109    ///
110    /// This allows restoring a handler across application restarts by loading
111    /// previously serialized state from a database or file. All sessions and
112    /// cryptographic state will be restored to their previous state.
113    ///
114    /// # Arguments
115    /// * `messages_manager` - The messages manager to use for message operations
116    /// * `client_keypair` - The client's keypair for signing and encryption
117    /// * `state` - Previously serialized protocol state
118    ///
119    /// # Example
120    /// ```rust,no_run
121    /// # use zoe_client::pqxdh::*;
122    /// # use zoe_wire_protocol::*;
123    /// # async fn example() -> Result<()> {
124    /// # let messages_manager = todo!();
125    /// # let keypair = todo!();
126    /// // Load state from storage
127    /// let saved_state: PqxdhProtocolState = load_from_database()?;
128    ///
129    /// // Restore handler with previous state
130    /// let handler = PqxdhProtocolHandler::from_state(
131    ///     &messages_manager,
132    ///     &keypair,
133    ///     saved_state
134    /// );
135    /// # Ok(())
136    /// # }
137    /// # fn load_from_database() -> Result<PqxdhProtocolState> { todo!() }
138    /// ```
139    pub fn from_state(
140        messages_manager: Arc<T>,
141        client_keypair: Arc<zoe_wire_protocol::KeyPair>,
142        state: PqxdhProtocolState,
143    ) -> Self {
144        let state = SharedObservable::new_async(state);
145
146        Self {
147            messages_manager,
148            client_keypair,
149            state,
150        }
151    }
152
153    /// Subscribe to state changes for reactive programming
154    ///
155    /// This method returns a broadcast receiver that can be used to observe changes to the
156    /// protocol handler's internal state. This is useful for building reactive UIs
157    /// or implementing state-dependent logic that needs to respond to changes in
158    /// session state, inbox status, or other protocol state.
159    ///
160    /// # Returns
161    /// Returns a `broadcast::Receiver<PqxdhProtocolState>` that receives state updates
162    ///
163    /// # Example
164    /// ```rust,no_run
165    /// # use zoe_client::pqxdh::*;
166    /// # use futures::StreamExt;
167    /// # async fn example() -> Result<()> {
168    /// # let handler: PqxdhProtocolHandler = todo!();
169    /// // Subscribe to state changes
170    /// let mut state_receiver = handler.subscribe_to_state();
171    ///
172    /// // Get current state
173    /// let current_state = handler.current_state();
174    /// println!("Current sessions: {}", current_state.sessions.len());
175    ///
176    /// // Watch for state changes
177    /// let mut state_stream = state_receiver.subscribe();
178    /// while let Some(new_state) = state_stream.next().await {
179    ///     println!("State updated! Sessions: {}", new_state.sessions.len());
180    /// }
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub async fn subscribe_to_state(&self) -> eyeball::Subscriber<PqxdhProtocolState, AsyncLock> {
185        self.state.subscribe().await
186    }
187
188    /// Publishes a service inbox for this protocol (SERVICE PROVIDERS ONLY)
189    ///
190    /// This makes the current client discoverable as a service provider for the given protocol.
191    /// It generates fresh prekey bundles, publishes the inbox to the message store, and sets up
192    /// the necessary subscriptions for receiving client connections.
193    ///
194    /// Only call this if you want to provide a service that others can connect to.
195    /// After calling this, use `inbox_stream()` to listen for incoming client messages.
196    ///
197    /// # Arguments
198    /// * `force_overwrite` - If true, overwrites any existing published inbox
199    ///
200    /// # Returns
201    /// Returns the `Tag` of the published inbox
202    ///
203    /// # Errors
204    /// Returns an error if an inbox is already published and `force_overwrite` is false,
205    /// or if the publishing process fails.
206    ///
207    /// # Example
208    /// ```rust,no_run
209    /// # use zoe_client::pqxdh::*;
210    /// # use futures::StreamExt;
211    /// # async fn example() -> Result<()> {
212    /// # let mut handler: PqxdhProtocolHandler = todo!();
213    /// // Publish service for clients to discover
214    /// let inbox_tag = handler.publish_service(false).await?;
215    /// println!("Service published with tag: {:?}", inbox_tag);
216    ///
217    /// // Now listen for client connections
218    /// let mut inbox_stream = Box::pin(handler.inbox_stream::<String>().await?);
219    /// # Ok(())
220    /// # }
221    /// ```
222    pub async fn publish_service(&self, force_overwrite: bool) -> Result<Tag> {
223        let (inbox_tag, protocol) = {
224            let current_state = self.state.get().await;
225            (
226                current_state.inbox_tag.clone(),
227                current_state.protocol.clone(),
228            )
229        };
230
231        if inbox_tag.is_some() && !force_overwrite {
232            return Err(PqxdhError::InboxAlreadyPublished);
233        }
234
235        // Generate prekey bundle with private keys
236        let (prekey_bundle, private_keys) =
237            create_pqxdh_prekey_bundle_with_private_keys(&self.client_keypair, 5)?;
238
239        // Create inbox
240        let inbox = PqxdhInbox::new(
241            InboxType::Public,
242            prekey_bundle,
243            Some(1024), // Max message size
244            None,       // No expiration
245        );
246
247        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
248
249        // Create storage message for the PQXDH inbox
250        let inbox_message = Message::new_v0(
251            Content::Raw(postcard::to_stdvec(&inbox)?),
252            self.client_keypair.public_key(),
253            timestamp,
254            Kind::Store(StoreKey::PqxdhInbox(protocol)),
255            vec![],
256        );
257
258        let inbox_message_full =
259            MessageFull::new(inbox_message, &self.client_keypair).map_err(|e| {
260                PqxdhError::MessageCreation(format!("Failed to create MessageFull for inbox: {e}"))
261            })?;
262
263        let target_tag = Tag::from(&inbox_message_full);
264
265        // Publish the inbox
266        self.messages_manager.publish(inbox_message_full).await?;
267
268        // Update state and notify observers
269        {
270            let mut state = self.state.write().await;
271            ObservableWriteGuard::update(&mut state, |state| {
272                state.private_keys = Some(private_keys);
273                state.inbox_tag = Some(target_tag.clone());
274                state.inbox = Some(inbox);
275            });
276        }
277
278        Ok(target_tag)
279    }
280
281    /// Connects to a service provider's inbox (CLIENTS ONLY)
282    ///
283    /// This discovers the target service provider's inbox and establishes a secure PQXDH session.
284    /// It performs the full PQXDH key exchange protocol and sets up the necessary subscriptions
285    /// for receiving responses from the service.
286    ///
287    /// Use this when you want to connect to a service as a client.
288    /// After calling this, use `send_message()` to send additional messages to the service.
289    ///
290    /// # Arguments
291    /// * `target_service_key` - The public key of the service provider to connect to
292    /// * `initial_message` - The first message to send as part of the connection
293    ///
294    /// # Returns
295    /// Returns a stream of messages from the service provider
296    ///
297    /// # Example
298    /// ```rust,no_run
299    /// # use zoe_client::pqxdh::*;
300    /// # use futures::StreamExt;
301    /// # async fn example() -> Result<()> {
302    /// # let mut handler: PqxdhProtocolHandler = todo!();
303    /// # let service_key = todo!();
304    /// // Connect to service and send initial message
305    /// let mut response_stream = Box::pin(handler.connect_to_service::<String, String>(
306    ///     &service_key,
307    ///     &"Hello, service!".to_string()
308    /// ).await?);
309    ///
310    /// // Listen for responses
311    /// while let Some(message) = response_stream.next().await {
312    ///     // Handle service responses
313    /// }
314    /// # Ok(())
315    /// # }
316    /// ```
317    pub async fn connect_to_service<O, I>(
318        &self,
319        target_service_key: &VerifyingKey,
320        initial_message: &O,
321    ) -> Result<(PqxdhSessionId, PqxdhMessageListener<I>)>
322    where
323        O: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
324        I: for<'de> serde::Deserialize<'de> + Clone + 'static,
325    {
326        let protocol = self.state.get().await.protocol.clone();
327        // Discover inbox
328        let (inbox, inbox_tag) = self.fetch_pqxdh_inbox(target_service_key, protocol).await?;
329
330        // Establish session
331        let session = self
332            .initiate_session(
333                target_service_key.clone(),
334                &inbox,
335                vec![inbox_tag],
336                initial_message,
337            )
338            .await?;
339
340        let my_session_id = session.my_session_channel_id;
341
342        // Update state and notify observers
343        {
344            let mut state = self.state.write().await;
345            ObservableWriteGuard::update(&mut state, |state| {
346                if state
347                    .sessions
348                    .insert(KeyId::from_bytes(my_session_id), session)
349                    .is_some()
350                {
351                    warn!("overwriting existing pqxdh session. Shouldn't happen");
352                }
353            });
354        }
355
356        Ok((
357            my_session_id,
358            self.listen_for_messages(my_session_id, true).await?,
359        ))
360    }
361
362    pub async fn listen_for_messages<I>(
363        &self,
364        my_session_id: PqxdhSessionId,
365        catch_up: bool,
366    ) -> Result<PqxdhMessageListener<I>>
367    where
368        I: for<'de> serde::Deserialize<'de>,
369    {
370        let listening_tag = {
371            let state = self.state.get().await;
372            let Some(session) = state.sessions.get(&KeyId::from_bytes(my_session_id)) else {
373                return Err(PqxdhError::SessionNotFound);
374            };
375            session.listening_channel_tag()
376        };
377
378        let messages_manager = self.messages_manager.clone();
379        let state = self.state.clone();
380        PqxdhMessageListener::new(
381            messages_manager,
382            my_session_id,
383            state,
384            listening_tag,
385            catch_up,
386        )
387        .await
388    }
389
390    pub async fn tarpc_transport<Req, Resp>(
391        &self,
392        session_id: PqxdhSessionId,
393    ) -> Result<PqxdhTarpcTransport<Req, Resp>>
394    where
395        Req: for<'de> serde::Deserialize<'de> + Send + 'static,
396        Resp: serde::Serialize + Send + Sync + 'static,
397    {
398        // Create the incoming stream first
399        let stream = self
400            .clone()
401            .listen_for_messages::<Req>(session_id, false)
402            .await?;
403
404        let me: PqxdhProtocolHandler<T> = self.clone();
405
406        // Create the transport with the stream and client
407        Ok(PqxdhTarpcTransport::new(session_id, Box::pin(stream), me))
408    }
409
410    /// Sends a message to an established session (CLIENTS ONLY)
411    ///
412    /// Use this to send additional messages after calling `connect_to_service()`.
413    /// The message will be encrypted and sent over the established secure PQXDH session
414    /// using the session's private channel.
415    ///
416    /// # Arguments
417    /// * `session_id` - The session ID of the established PQXDH session
418    /// * `message` - The message payload to encrypt and send
419    ///
420    /// # Errors
421    /// Returns an error if no active session exists with the given session ID
422    ///
423    /// # Example
424    /// ```rust,no_run
425    /// # use zoe_client::pqxdh::*;
426    /// # async fn example() -> Result<()> {
427    /// # let mut handler: PqxdhProtocolHandler = todo!();
428    /// # let service_key = todo!();
429    /// # let session_id = [0u8; 32]; // Session ID from established connection
430    /// // First establish connection
431    /// let _stream = handler.connect_to_service::<String, String>(&service_key, &"initial".to_string()).await?;
432    ///
433    /// // Send follow-up messages using session ID
434    /// handler.send_message(&session_id, &"follow-up message".to_string()).await?;
435    /// handler.send_message(&session_id, &"another message".to_string()).await?;
436    /// # Ok(())
437    /// # }
438    /// ```
439    pub async fn send_message<U>(&self, session_id: &PqxdhSessionId, message: &U) -> Result<()>
440    where
441        U: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
442    {
443        self.send_message_inner(session_id, message, Kind::Regular)
444            .await
445    }
446
447    pub async fn send_ephemeral_message<U>(
448        &self,
449        session_id: &PqxdhSessionId,
450        message: &U,
451        timeout: u32,
452    ) -> Result<()>
453    where
454        U: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
455    {
456        self.send_message_inner(session_id, message, Kind::Ephemeral(timeout))
457            .await
458    }
459
460    /// Creates a stream of messages that arrive to our inbox (SERVICE PROVIDERS ONLY)
461    ///
462    /// This method returns a stream of incoming messages from clients who are connecting
463    /// to or communicating with this service. The stream includes both initial PQXDH
464    /// messages (new client connections) and session messages (ongoing communication).
465    ///
466    /// # Returns
467    /// Returns a stream of `(PqxdhSessionId, T)` tuples where:
468    /// - `PqxdhSessionId` is the session ID for the client connection
469    /// - `T` is the deserialized message payload from the client
470    ///
471    /// # Errors
472    /// Returns an error if `publish_service()` has not been called first
473    ///
474    /// # Example
475    /// ```rust,no_run
476    /// # use zoe_client::pqxdh::*;
477    /// # use futures::StreamExt;
478    /// # async fn example() -> Result<()> {
479    /// # let mut handler: PqxdhProtocolHandler = todo!();
480    /// // First publish the service
481    /// handler.publish_service(false).await?;
482    ///
483    /// // Then listen for client messages
484    /// let mut inbox_stream = Box::pin(handler.inbox_stream::<String>().await?);
485    /// while let Some((session_id, message)) = inbox_stream.next().await {
486    ///     println!("Received message from session {:?}: {}", session_id, message);
487    ///     
488    ///     // Echo the message back to the client
489    ///     handler.send_message(&session_id, &format!("Echo: {}", message)).await?;
490    /// }
491    /// # Ok(())
492    /// # }
493    /// ```
494    pub async fn inbox_stream<U>(&self) -> Result<impl futures::Stream<Item = (PqxdhSessionId, U)>>
495    where
496        U: for<'de> serde::Deserialize<'de>,
497    {
498        let inbox_tag = {
499            let current_state = self.state.get().await;
500            if current_state.private_keys.is_none() {
501                return Err(PqxdhError::ServiceNotPublished);
502            };
503
504            let Some(inbox_tag) = &current_state.inbox_tag else {
505                return Err(PqxdhError::NoInboxSubscription);
506            };
507            inbox_tag.clone()
508        };
509
510        self.messages_manager
511            .ensure_contains_filter(Filter::from(inbox_tag.clone()))
512            .await?;
513
514        let stream = self
515            .messages_manager
516            .filtered_messages_stream(Filter::from(inbox_tag));
517
518        let state = self.state.clone();
519        let my_public_key = self.client_keypair.public_key().clone();
520
521        let stream = stream.filter_map(move |message_full| {
522            let state = state.clone();
523            let my_public_key = my_public_key.clone();
524            async move {
525                Self::on_incoming_inbox_message::<U>(&state, &my_public_key, &message_full)
526                    .await
527                    .inspect_err(|e| {
528                        error!(
529                            msg_id = hex::encode(message_full.id().as_bytes()),
530                            "error processing inbox message: {e}"
531                        );
532                    })
533                    .ok()
534            }
535        });
536
537        Ok(stream)
538    }
539}
540
541#[async_trait::async_trait]
542impl<Resp, T: crate::services::MessagesManagerTrait> super::PqxdhTarpcTransportSender<Resp>
543    for PqxdhProtocolHandler<T>
544where
545    Resp: serde::Serialize + Send + Sync,
546{
547    async fn send_response(&self, session_id: &PqxdhSessionId, resp: &Resp) -> Result<()> {
548        self.send_message_inner(session_id, &resp, Kind::Ephemeral(10))
549            .await
550    }
551}
552
553// Internal functions
554
555impl<T: crate::services::MessagesManagerTrait> PqxdhProtocolHandler<T> {
556    /// Fetch a PQXDH inbox using the trait method
557    async fn fetch_pqxdh_inbox<U: for<'de> Deserialize<'de>>(
558        &self,
559        provider_key: &VerifyingKey,
560        protocol: PqxdhInboxProtocol,
561    ) -> Result<(U, Tag)> {
562        let provider_user_id = *provider_key.id();
563        let store_key = StoreKey::PqxdhInbox(protocol);
564
565        let Some(message_full) = self
566            .messages_manager
567            .user_data(KeyId::from(*provider_user_id.as_bytes()), store_key)
568            .await?
569        else {
570            return Err(PqxdhError::InboxNotFound);
571        };
572
573        let Some(content_bytes) = message_full.raw_content() else {
574            return Err(PqxdhError::NoContent);
575        };
576
577        let inbox_data: U = postcard::from_bytes(content_bytes)?;
578
579        Ok((inbox_data, Tag::from(&message_full)))
580    }
581
582    /// Initiates a PQXDH session with a target user using an already loaded inbox
583    async fn initiate_session<U: Serialize>(
584        &self,
585        target_public_key: VerifyingKey,
586        inbox: &PqxdhInbox,
587        target_tags: Vec<Tag>,
588        initial_payload: &U,
589    ) -> Result<PqxdhSession> {
590        // Extract the prekey bundle from the inbox
591        let prekey_bundle = &inbox.pqxdh_prekeys;
592
593        // Generate randomized channel ID for session messages
594        let mut rng = rand::thread_rng();
595        let mut session_channel_id_prefix = PqxdhSessionId::default();
596        rng.fill_bytes(&mut session_channel_id_prefix);
597
598        let their_session_channel_id = {
599            let mut hasher = blake3::Hasher::new();
600            hasher.update(&session_channel_id_prefix);
601            hasher.update(target_public_key.id().as_ref());
602            hasher.finalize()
603        };
604
605        let my_session_channel_id = {
606            let mut hasher = blake3::Hasher::new();
607            hasher.update(&session_channel_id_prefix);
608            hasher.update(self.client_keypair.public_key().id().as_ref());
609            hasher.finalize()
610        };
611
612        // Create the combined initial payload with channel ID
613        let initial_payload_struct = PqxdhInitialPayload {
614            user_payload: initial_payload,
615            session_channel_id_prefix,
616        };
617
618        let combined_payload_bytes = postcard::to_stdvec(&initial_payload_struct)?;
619
620        // Initiate PQXDH
621        let (initial_message, shared_secret) = pqxdh_initiate(
622            &self.client_keypair,
623            prekey_bundle,
624            &combined_payload_bytes,
625            &mut rng,
626        )
627        .map_err(|e| PqxdhError::Crypto(e.to_string()))?;
628
629        let pqxdh_content = zoe_wire_protocol::PqxdhEncryptedContent::Initial(initial_message);
630
631        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
632        let message = Message::new_v0(
633            Content::PqxdhEncrypted(pqxdh_content),
634            self.client_keypair.public_key(),
635            timestamp,
636            Kind::Ephemeral(0),
637            target_tags,
638        );
639
640        let message_full = MessageFull::new(message, &self.client_keypair).map_err(|e| {
641            PqxdhError::MessageCreation(format!("Failed to create initial message: {e}"))
642        })?;
643
644        // The initial message will be routed using its derived tag (Tag::Event with message ID)
645        // This is automatically created by Tag::from(&MessageFull)
646
647        self.messages_manager.publish(message_full).await?;
648
649        Ok(PqxdhSession {
650            shared_secret,
651            sequence_number: 1,
652            my_session_channel_id: my_session_channel_id.into(),
653            their_session_channel_id: their_session_channel_id.into(),
654            their_key: target_public_key,
655        })
656    }
657
658    async fn on_incoming_inbox_message<U>(
659        state: &SharedObservable<PqxdhProtocolState, AsyncLock>,
660        my_public_key: &VerifyingKey,
661        message_full: &MessageFull,
662    ) -> Result<(PqxdhSessionId, U)>
663    where
664        U: for<'de> serde::Deserialize<'de>,
665    {
666        let msg_id_hex = hex::encode(message_full.id().as_bytes());
667        let Some(pqxdh_content) = message_full.content().as_pqxdh_encrypted() else {
668            warn!(
669                msg_id = msg_id_hex,
670                "not the proper content type on message"
671            );
672            return Err(PqxdhError::InvalidContentType);
673        };
674
675        let (private_keys, prekey_bundle) = {
676            let current_state = state.get().await;
677            let Some(private_keys) = current_state.private_keys else {
678                error!(msg_id = msg_id_hex, "no private keys");
679                return Err(PqxdhError::NoPrivateKeys);
680            };
681            let Some(ref inbox) = current_state.inbox else {
682                error!(msg_id = msg_id_hex, "no inbox");
683                return Err(PqxdhError::NoInbox);
684            };
685            (private_keys.clone(), inbox.pqxdh_prekeys.clone())
686        };
687
688        let initial_msg = match pqxdh_content {
689            zoe_wire_protocol::PqxdhEncryptedContent::Initial(initial_msg) => initial_msg,
690            _ => {
691                warn!(
692                    msg_id = msg_id_hex,
693                    "not an inbox initial message. ignoring incoming message."
694                );
695                return Err(PqxdhError::NotInitialMessage);
696            }
697        };
698
699        let (decrypted_payload, shared_secret) =
700            zoe_wire_protocol::inbox::pqxdh::pqxdh_crypto::pqxdh_respond(
701                initial_msg,
702                &private_keys,
703                &prekey_bundle,
704            )
705            .map_err(|e| PqxdhError::Crypto(e.to_string()))?;
706
707        // Deserialize the PqxdhInitialPayload structure with the user payload type
708        let initial_payload: PqxdhInitialPayload<U> = postcard::from_bytes(&decrypted_payload)?;
709
710        let PqxdhInitialPayload {
711            user_payload: user_message,
712            session_channel_id_prefix,
713        } = initial_payload;
714
715        let my_session_channel_id = {
716            let mut hasher = blake3::Hasher::new();
717            hasher.update(&session_channel_id_prefix);
718            hasher.update(my_public_key.id().as_ref());
719            hasher.finalize().into()
720        };
721
722        let their_session_channel_id = {
723            let mut hasher = blake3::Hasher::new();
724            hasher.update(&session_channel_id_prefix);
725            hasher.update(message_full.author().id().as_ref());
726            hasher.finalize().into()
727        };
728
729        // Create session
730        let session = PqxdhSession::from_shared_secret(
731            shared_secret,
732            my_session_channel_id,
733            their_session_channel_id,
734            message_full.author().clone(),
735        );
736
737        // Update state
738        let mut current_state = state.write().await;
739        ObservableWriteGuard::update(&mut current_state, |state| {
740            if state
741                .sessions
742                .insert(KeyId::from_bytes(my_session_channel_id), session)
743                .is_some()
744            {
745                error!("overwriting existing pqxdh session. Shouldn't happen");
746            }
747        });
748
749        Ok((my_session_channel_id, user_message))
750    }
751
752    async fn send_message_inner<U>(
753        &self,
754        session_id: &PqxdhSessionId,
755        message: &U,
756        kind: Kind,
757    ) -> Result<()>
758    where
759        U: serde::Serialize,
760    {
761        let full_msg = {
762            let mut current_state = self.state.write().await;
763            let Some(mut session) = current_state
764                .sessions
765                .get(&KeyId::from_bytes(*session_id))
766                .cloned()
767            else {
768                return Err(PqxdhError::SessionNotFound);
769            };
770
771            let msg = session.gen_next_message(&self.client_keypair, message, kind)?;
772
773            ObservableWriteGuard::update(&mut current_state, |state: &mut PqxdhProtocolState| {
774                state
775                    .sessions
776                    .insert(KeyId::from_bytes(*session_id), session); // re-add the changed session
777            });
778
779            msg
780        };
781
782        self.messages_manager.publish(full_msg).await?;
783
784        Ok(())
785    }
786}
787
788pub(crate) fn create_pqxdh_prekey_bundle_with_private_keys(
789    identity_keypair: &zoe_wire_protocol::KeyPair,
790    num_one_time_keys: usize,
791) -> Result<(
792    zoe_wire_protocol::inbox::pqxdh::PqxdhPrekeyBundle,
793    zoe_wire_protocol::inbox::pqxdh::PqxdhPrivateKeys,
794)> {
795    let mut rng = rand::thread_rng();
796    generate_pqxdh_prekeys(identity_keypair, num_one_time_keys, &mut rng)
797        .map_err(|e| PqxdhError::KeyGeneration(format!("Failed to generate PQXDH prekeys: {e}")))
798}