zoe_client/
session_manager.rs

1//! Session Manager for automatic PQXDH state synchronization
2//!
3//! This module provides a state synchronization layer that:
4//! - Uses a builder pattern for initialization and state loading
5//! - Manages PQXDH protocol handlers and their state changes
6//! - Automatically persists state changes to storage transparently
7//! - Returns clones of specific types when requested
8//! - Emits state change events for reactive programming
9//!
10//! The SessionManager acts as a central hub for PQXDH state management, providing:
11//! - **Automatic Persistence**: State changes are immediately persisted to storage
12//! - **Event Broadcasting**: State changes are broadcast to subscribers
13//! - **Handler Management**: Registration and lifecycle management of PQXDH handlers
14//! - **State Access**: Fast access to current states via cloning
15//!
//! ## Usage Pattern
//!
//! ```rust,ignore
//! // Build the session manager (loads all persisted states from storage).
//! // A client keypair is mandatory: without it `build()` fails with
//! // `SessionManagerError::ClientKeypairNotFound`.
//! let manager = SessionManager::builder(storage, messages_manager)
//!     .client_keypair(client_keypair)
//!     .build()
//!     .await?;
//!
//! // Get the handler for a protocol; it is created and its initial state
//! // persisted on first use, and the same handler is returned afterwards.
//! let handler = manager.pqxdh_handler(PqxdhInboxProtocol::EchoService).await?;
//!
//! // Observe the handler's state; changes are persisted automatically.
//! let mut state = handler.subscribe_to_state().await;
//!
//! // Access the group manager
//! let group_manager = manager.group_manager();
//! ```
33
34use std::collections::BTreeMap;
35use std::collections::btree_map::Entry;
36use std::sync::Arc;
37
38use zoe_wire_protocol::KeyId;
39
40use tokio::sync::RwLock;
41use tokio::task::JoinHandle;
42
43use zoe_client_storage::{StateNamespace, StateStorage, StorageError};
44use zoe_state_machine::{GroupDataUpdate, GroupManager};
45use zoe_wire_protocol::PqxdhInboxProtocol;
46
47use crate::pqxdh::PqxdhProtocolHandler;
48use crate::services::MessagesManagerTrait;
49
50/// Errors that can occur during session management operations
51#[derive(Debug, thiserror::Error)]
52pub enum SessionManagerError {
53    /// Storage operation failed
54    #[error("Storage error: {0}")]
55    Storage(#[from] StorageError),
56    /// Serialization/deserialization failed
57    #[error("Serialization error: {0}")]
58    Serialization(String),
59    /// Handler registration failed
60    #[error("Handler registration error: {0}")]
61    HandlerRegistration(String),
62    /// Client keypair not found
63    #[error("Client keypair not found")]
64    ClientKeypairNotFound,
65}
66
67/// Result type for session manager operations
68pub type SessionManagerResult<T> = Result<T, SessionManagerError>;
69/// Builder for creating a SessionManager with proper initialization
70pub struct SessionManagerBuilder<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> {
71    storage: Arc<S>,
72    messages_manager: Arc<M>,
73    client_keypair: Option<Arc<zoe_wire_protocol::KeyPair>>,
74}
75
76impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManagerBuilder<S, M> {
77    /// Create a new SessionManager builder
78    pub fn new(storage: Arc<S>, messages_manager: Arc<M>) -> Self {
79        Self {
80            storage,
81            messages_manager,
82            client_keypair: None,
83        }
84    }
85
86    pub fn client_keypair(mut self, client_keypair: Arc<zoe_wire_protocol::KeyPair>) -> Self {
87        self.client_keypair = Some(client_keypair);
88        self
89    }
90
91    /// Build the SessionManager, loading all states from storage and setting up listeners
92    pub async fn build(self) -> SessionManagerResult<SessionManager<S, M>> {
93        let SessionManagerBuilder {
94            storage,
95            messages_manager,
96            client_keypair,
97        } = self;
98
99        let client_keypair = client_keypair.ok_or(SessionManagerError::ClientKeypairNotFound)?;
100
101        let states = SessionManager::load_pqxdh_states(
102            storage.clone(),
103            messages_manager.clone(),
104            client_keypair.clone(),
105            StateNamespace::PqxdhSession(KeyId::from(*client_keypair.id())),
106        )
107        .await?;
108
109        // Create group manager instance and initialize listener
110        let (group_manager, group_manager_task) =
111            SessionManager::<S, M>::init_group_manager(storage.clone(), client_keypair.clone())
112                .await?;
113
114        let manager = SessionManager {
115            storage,
116            messages_manager,
117            pqxdh_handlers: RwLock::new(BTreeMap::from_iter(states)),
118            group_manager,
119            group_manager_task,
120            client_keypair,
121        };
122
123        tracing::info!("SessionManager built and initialized successfully");
124        Ok(manager)
125    }
126}
127
128type PqxdhHandlerState<M> = (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>);
129type PqxdhHandlerStates<M> = BTreeMap<PqxdhInboxProtocol, PqxdhHandlerState<M>>;
130
131/// State synchronization manager that automatically listens to and persists state changes.
132///
133/// This manager follows the persistent messenger pattern:
134/// - External objects register with the manager
135/// - The manager listens to their state changes
136/// - State changes are automatically persisted to storage
137/// - State change events are broadcast to subscribers
138///
139/// The SessionManager is designed to be long-lived and handles the lifecycle
140/// of multiple PQXDH protocol handlers and group encryption sessions, automatically
141/// managing their state persistence and providing reactive access to state changes.
142pub struct SessionManager<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> {
143    /// Underlying state storage
144    storage: Arc<S>,
145    /// Messages manager for subscription and message handling
146    messages_manager: Arc<M>,
147    /// PQXDH handlers with their background listener tasks  
148    pqxdh_handlers: RwLock<PqxdhHandlerStates<M>>,
149    /// Group manager instance with background listener task
150    group_manager: GroupManager,
151    #[allow(dead_code)]
152    group_manager_task: JoinHandle<()>,
153    client_keypair: Arc<zoe_wire_protocol::KeyPair>,
154}
155
156impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
157    /// Create a new SessionManager builder
158    ///
159    /// # Arguments
160    /// * `storage` - The state storage backend
161    /// * `messages_manager` - The messages manager for subscription and message handling
162    ///
163    /// # Returns
164    /// A SessionManagerBuilder for configuring and building the SessionManager
165    pub fn builder(storage: Arc<S>, messages_manager: Arc<M>) -> SessionManagerBuilder<S, M> {
166        SessionManagerBuilder::new(storage, messages_manager)
167    }
168    /// Get a reference to the underlying storage
169    pub fn storage(&self) -> &Arc<S> {
170        &self.storage
171    }
172
173    /// Get a reference to the messages manager
174    pub fn messages_manager(&self) -> &Arc<M> {
175        &self.messages_manager
176    }
177
178    pub async fn pqxdh_handler(
179        &self,
180        protocol: PqxdhInboxProtocol,
181    ) -> SessionManagerResult<Arc<PqxdhProtocolHandler<M>>> {
182        match self.pqxdh_handlers.write().await.entry(protocol.clone()) {
183            Entry::Occupied(occupied) => Ok(occupied.get().0.clone()),
184            Entry::Vacant(vacant) => {
185                let handler = Arc::new(PqxdhProtocolHandler::new(
186                    self.messages_manager.clone(),
187                    self.client_keypair.clone(),
188                    protocol.clone(),
189                ));
190                let (_idx, (handler_arc, listener_task)) = Self::init_pqxdh_handler(
191                    self.storage.clone(),
192                    protocol,
193                    handler.clone(),
194                    StateNamespace::PqxdhSession(KeyId::from(*self.client_keypair.id())),
195                    true,
196                )
197                .await?;
198                vacant.insert((handler_arc.clone(), listener_task));
199                Ok(handler_arc)
200            }
201        }
202    }
203}
204
205/// Group Manager Integration
206impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
207    /// Initialize group manager with listener task
208    async fn init_group_manager(
209        storage: Arc<S>,
210        client_keypair: Arc<zoe_wire_protocol::KeyPair>,
211    ) -> SessionManagerResult<(GroupManager, JoinHandle<()>)> {
212        let namespace = StateNamespace::GroupSession(KeyId::from(*client_keypair.id()));
213        let sessions: Vec<(Vec<u8>, zoe_state_machine::GroupSession)> = storage
214            .list_namespace_data(&namespace)
215            .await
216            .map_err(SessionManagerError::Storage)?;
217
218        let group_manager = GroupManager::builder()
219            .with_sessions(sessions.into_iter().map(|(_, session)| session).collect())
220            .build();
221        let mut group_updates = group_manager.subscribe_to_updates();
222
223        let task = tokio::spawn(async move {
224            while let Ok(update) = group_updates.recv().await {
225                tracing::debug!("Received group manager update: {:?}", update);
226                match update {
227                    GroupDataUpdate::GroupAdded(group_session)
228                    | GroupDataUpdate::GroupUpdated(group_session) => {
229                        if let Err(e) = storage
230                            .store(
231                                &namespace,
232                                group_session.state.group_id.as_bytes(),
233                                &group_session,
234                            )
235                            .await
236                        {
237                            tracing::error!(error=?e, "Failed to persist group session");
238                        }
239                    }
240                    GroupDataUpdate::GroupRemoved(group_session) => {
241                        if let Err(e) = storage
242                            .delete(&namespace, group_session.state.group_id.as_bytes())
243                            .await
244                        {
245                            tracing::error!(error=?e, "Failed to delete group session");
246                        }
247                    }
248                }
249            }
250            tracing::info!("Group manager listener task ended");
251        });
252
253        tracing::info!("Group manager listener started");
254        Ok((group_manager, task))
255    }
256
257    /// Get access to the group manager
258    pub fn group_manager(&self) -> GroupManager {
259        self.group_manager.clone()
260    }
261}
262
263/// Live Object Access and State Cloning
264impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
265    /// Load PQXDH states from storage
266    async fn load_pqxdh_states(
267        storage: Arc<S>,
268        messages_manager: Arc<M>,
269        client_keypair: Arc<zoe_wire_protocol::KeyPair>,
270        namespace: StateNamespace,
271    ) -> SessionManagerResult<
272        Vec<(
273            PqxdhInboxProtocol,
274            (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>),
275        )>,
276    > {
277        let pqxdh_data = storage
278            .list_namespace_data(&namespace)
279            .await
280            .map_err(SessionManagerError::Storage)?;
281
282        let mut handlers = Vec::new();
283        for (key, protocol_state) in pqxdh_data {
284            let protocol = PqxdhInboxProtocol::from_bytes(&key)
285                .map_err(|e| SessionManagerError::Serialization(e.to_string()))?;
286            let handler = Arc::new(PqxdhProtocolHandler::from_state(
287                messages_manager.clone(),
288                client_keypair.clone(),
289                protocol_state,
290            ));
291            handlers.push(
292                Self::init_pqxdh_handler(
293                    storage.clone(),
294                    protocol,
295                    handler,
296                    namespace.clone(),
297                    false,
298                )
299                .await?,
300            );
301        }
302
303        tracing::info!("PQXDH state loading completed");
304        Ok(handlers)
305    }
306
307    async fn init_pqxdh_handler(
308        storage: Arc<S>,
309        handler_id: PqxdhInboxProtocol,
310        handler: Arc<PqxdhProtocolHandler<M>>,
311        namespace: StateNamespace,
312        persist_initial_state: bool,
313    ) -> SessionManagerResult<(
314        PqxdhInboxProtocol,
315        (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>),
316    )> {
317        // Subscribe to state changes for this handler
318        let mut subscriber = handler.subscribe_to_state().await;
319        let handler_id_bytes = handler_id.into_bytes();
320        if persist_initial_state {
321            let initial_state = subscriber.get().await;
322            // Persist initial state
323            storage
324                .store(&namespace, &handler_id_bytes, &initial_state)
325                .await
326                .map_err(SessionManagerError::Storage)?;
327        }
328
329        // Spawn background task to listen to state changes
330        let storage_clone = storage.clone();
331        let handler_id_clone = handler_id.clone();
332
333        let listener_task = tokio::spawn(async move {
334            tracing::info!("Started PQXDH state listener for handler: {handler_id_clone}");
335
336            // Listen for state changes
337            while let Some(new_state) = subscriber.next().await {
338                // Only process if state actually changed (compare serialized forms)
339                let _new_state_bytes = match postcard::to_stdvec(&new_state) {
340                    Ok(bytes) => bytes,
341                    Err(e) => {
342                        tracing::error!(error=?e, "Failed to serialize PQXDH state for {handler_id_clone}");
343                        continue;
344                    }
345                };
346
347                // Persist to storage
348                if let Err(e) = storage_clone
349                    .store(&namespace, &handler_id_bytes, &new_state)
350                    .await
351                {
352                    tracing::error!(error=?e, "Failed to persist PQXDH state for {handler_id_clone}");
353                    continue;
354                }
355            }
356
357            tracing::debug!("PQXDH state listener ended for handler: {handler_id_clone}");
358        });
359
360        tracing::info!("Registered PQXDH handler: {}", handler_id);
361        Ok((handler_id.clone(), (handler, listener_task)))
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;
    use std::sync::Arc;

    use zoe_client_storage::{StateNamespace, storage::MockStateStorage};

    use zoe_wire_protocol::{KeyPair, PqxdhInboxProtocol};

    use crate::pqxdh::PqxdhProtocolState;
    use crate::services::messages_manager::MockMessagesManagerTrait;

    /// SessionManager instantiated with the mock storage and messages manager
    type TestManager = SessionManager<MockStateStorage, MockMessagesManagerTrait>;

    fn create_test_keypair() -> KeyPair {
        let mut rng = thread_rng();
        KeyPair::generate_ed25519(&mut rng)
    }

    #[allow(dead_code)]
    fn create_test_namespace(keypair: &KeyPair) -> StateNamespace {
        StateNamespace::PqxdhSession(KeyId::from(*keypair.public_key().id()))
    }

    /// Mock storage whose initial loads report no persisted PQXDH or group sessions
    fn empty_mock_storage() -> MockStateStorage {
        let mut storage = MockStateStorage::new();
        storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));
        storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));
        storage
    }

    /// Build a manager over `storage` with a fresh mock messages manager,
    /// panicking if the build fails
    async fn build_manager(storage: MockStateStorage, keypair: &Arc<KeyPair>) -> TestManager {
        SessionManagerBuilder::new(
            Arc::new(storage),
            Arc::new(MockMessagesManagerTrait::new()),
        )
        .client_keypair(keypair.clone())
        .build()
        .await
        .expect("Failed to build SessionManager")
    }

    /// Test SessionManagerBuilder creation and configuration
    #[tokio::test]
    async fn test_session_manager_builder() {
        let keypair = Arc::new(create_test_keypair());

        // The builder is never built here, so storage needs no expectations
        let builder = SessionManagerBuilder::new(
            Arc::new(MockStateStorage::new()),
            Arc::new(MockMessagesManagerTrait::new()),
        )
        .client_keypair(keypair.clone());

        assert!(builder.client_keypair.is_some());
        assert_eq!(
            builder.client_keypair.unwrap().public_key().id(),
            keypair.public_key().id()
        );
    }

    /// Test SessionManager creation through builder
    #[tokio::test]
    async fn test_session_manager_creation() {
        let keypair = Arc::new(create_test_keypair());

        let manager = build_manager(empty_mock_storage(), &keypair).await;

        assert_eq!(
            manager.client_keypair.public_key().id(),
            keypair.public_key().id()
        );
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 0);
    }

    /// Test SessionManager creation without client keypair fails
    #[tokio::test]
    async fn test_session_manager_creation_without_keypair() {
        let mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();

        let result = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .build()
            .await;

        assert!(matches!(
            result,
            Err(SessionManagerError::ClientKeypairNotFound)
        ));
    }

    /// Test PQXDH handler creation and registration
    #[tokio::test]
    async fn test_pqxdh_handler_registration() {
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        // Empty initial load, plus persistence of the new handler's initial state
        let mut mock_storage = empty_mock_storage();
        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .returning(|_, _, _| Ok(()));

        let manager = build_manager(mock_storage, &keypair).await;

        // Request a PQXDH handler
        let handler = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get PQXDH handler");

        // Verify the handler was registered and is the one that was returned
        {
            let handlers = manager.pqxdh_handlers.read().await;
            assert_eq!(handlers.len(), 1);
            let (registered, _task) = handlers
                .get(&protocol)
                .expect("handler should be registered under its protocol");
            assert!(Arc::ptr_eq(registered, &handler));
        }

        // The protocol field is private, so only check that the state is readable
        let _state = handler.subscribe_to_state().await.get().await;
    }

    /// Test PQXDH handler reuse (same protocol returns same handler)
    #[tokio::test]
    async fn test_pqxdh_handler_reuse() {
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        // Initial state persistence must happen exactly once
        let mut mock_storage = empty_mock_storage();
        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .times(1)
            .returning(|_, _, _| Ok(()));

        let manager = build_manager(mock_storage, &keypair).await;

        // Request the same handler twice
        let handler1 = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get first handler");
        let handler2 = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get second handler");

        // Should be the same Arc instance
        assert!(Arc::ptr_eq(&handler1, &handler2));
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 1);
    }

    /// Test loading existing PQXDH states from storage
    #[tokio::test]
    async fn test_load_existing_pqxdh_states() {
        let mut mock_storage = MockStateStorage::new();
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        // Create existing state data
        let existing_state = PqxdhProtocolState::new(protocol.clone());
        let protocol_bytes = protocol.clone().into_bytes();
        let existing_data = vec![(protocol_bytes, existing_state)];

        // Mock storage to return existing PQXDH data
        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(move |_| Ok(existing_data.clone()));

        // Mock storage to return empty group session data
        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        let manager = build_manager(mock_storage, &keypair).await;

        // Should have loaded the existing handler
        let handlers = manager.pqxdh_handlers.read().await;
        assert_eq!(handlers.len(), 1);
        assert!(handlers.contains_key(&protocol));
    }

    /// Test error handling when storage operations fail
    #[tokio::test]
    async fn test_storage_error_handling() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());

        // Mock storage to fail on list_namespace_data for PQXDH
        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| {
                Err(zoe_client_storage::StorageError::Serialization(
                    postcard::Error::DeserializeUnexpectedEnd,
                ))
            });

        let result = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await;

        assert!(matches!(result, Err(SessionManagerError::Storage(_))));
    }

    /// Test multiple different PQXDH protocols
    #[tokio::test]
    async fn test_multiple_pqxdh_protocols() {
        let keypair = Arc::new(create_test_keypair());

        // One initial state persistence per protocol
        let mut mock_storage = empty_mock_storage();
        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .times(2)
            .returning(|_, _, _| Ok(()));

        let manager = build_manager(mock_storage, &keypair).await;

        // Request handlers for different protocols
        let handler1 = manager
            .pqxdh_handler(PqxdhInboxProtocol::EchoService)
            .await
            .expect("Failed to get handler1");
        let handler2 = manager
            .pqxdh_handler(PqxdhInboxProtocol::CustomProtocol(12345))
            .await
            .expect("Failed to get handler2");

        // Should have two different handlers
        assert!(!Arc::ptr_eq(&handler1, &handler2));
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 2);
    }
}