1use std::collections::BTreeMap;
35use std::collections::btree_map::Entry;
36use std::sync::Arc;
37
38use zoe_wire_protocol::KeyId;
39
40use tokio::sync::RwLock;
41use tokio::task::JoinHandle;
42
43use zoe_client_storage::{StateNamespace, StateStorage, StorageError};
44use zoe_state_machine::{GroupDataUpdate, GroupManager};
45use zoe_wire_protocol::PqxdhInboxProtocol;
46
47use crate::pqxdh::PqxdhProtocolHandler;
48use crate::services::MessagesManagerTrait;
49
50#[derive(Debug, thiserror::Error)]
52pub enum SessionManagerError {
53 #[error("Storage error: {0}")]
55 Storage(#[from] StorageError),
56 #[error("Serialization error: {0}")]
58 Serialization(String),
59 #[error("Handler registration error: {0}")]
61 HandlerRegistration(String),
62 #[error("Client keypair not found")]
64 ClientKeypairNotFound,
65}
66
67pub type SessionManagerResult<T> = Result<T, SessionManagerError>;
69pub struct SessionManagerBuilder<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> {
71 storage: Arc<S>,
72 messages_manager: Arc<M>,
73 client_keypair: Option<Arc<zoe_wire_protocol::KeyPair>>,
74}
75
76impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManagerBuilder<S, M> {
77 pub fn new(storage: Arc<S>, messages_manager: Arc<M>) -> Self {
79 Self {
80 storage,
81 messages_manager,
82 client_keypair: None,
83 }
84 }
85
86 pub fn client_keypair(mut self, client_keypair: Arc<zoe_wire_protocol::KeyPair>) -> Self {
87 self.client_keypair = Some(client_keypair);
88 self
89 }
90
91 pub async fn build(self) -> SessionManagerResult<SessionManager<S, M>> {
93 let SessionManagerBuilder {
94 storage,
95 messages_manager,
96 client_keypair,
97 } = self;
98
99 let client_keypair = client_keypair.ok_or(SessionManagerError::ClientKeypairNotFound)?;
100
101 let states = SessionManager::load_pqxdh_states(
102 storage.clone(),
103 messages_manager.clone(),
104 client_keypair.clone(),
105 StateNamespace::PqxdhSession(KeyId::from(*client_keypair.id())),
106 )
107 .await?;
108
109 let (group_manager, group_manager_task) =
111 SessionManager::<S, M>::init_group_manager(storage.clone(), client_keypair.clone())
112 .await?;
113
114 let manager = SessionManager {
115 storage,
116 messages_manager,
117 pqxdh_handlers: RwLock::new(BTreeMap::from_iter(states)),
118 group_manager,
119 group_manager_task,
120 client_keypair,
121 };
122
123 tracing::info!("SessionManager built and initialized successfully");
124 Ok(manager)
125 }
126}
127
128type PqxdhHandlerState<M> = (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>);
129type PqxdhHandlerStates<M> = BTreeMap<PqxdhInboxProtocol, PqxdhHandlerState<M>>;
130
131pub struct SessionManager<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> {
143 storage: Arc<S>,
145 messages_manager: Arc<M>,
147 pqxdh_handlers: RwLock<PqxdhHandlerStates<M>>,
149 group_manager: GroupManager,
151 #[allow(dead_code)]
152 group_manager_task: JoinHandle<()>,
153 client_keypair: Arc<zoe_wire_protocol::KeyPair>,
154}
155
156impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
157 pub fn builder(storage: Arc<S>, messages_manager: Arc<M>) -> SessionManagerBuilder<S, M> {
166 SessionManagerBuilder::new(storage, messages_manager)
167 }
168 pub fn storage(&self) -> &Arc<S> {
170 &self.storage
171 }
172
173 pub fn messages_manager(&self) -> &Arc<M> {
175 &self.messages_manager
176 }
177
178 pub async fn pqxdh_handler(
179 &self,
180 protocol: PqxdhInboxProtocol,
181 ) -> SessionManagerResult<Arc<PqxdhProtocolHandler<M>>> {
182 match self.pqxdh_handlers.write().await.entry(protocol.clone()) {
183 Entry::Occupied(occupied) => Ok(occupied.get().0.clone()),
184 Entry::Vacant(vacant) => {
185 let handler = Arc::new(PqxdhProtocolHandler::new(
186 self.messages_manager.clone(),
187 self.client_keypair.clone(),
188 protocol.clone(),
189 ));
190 let (_idx, (handler_arc, listener_task)) = Self::init_pqxdh_handler(
191 self.storage.clone(),
192 protocol,
193 handler.clone(),
194 StateNamespace::PqxdhSession(KeyId::from(*self.client_keypair.id())),
195 true,
196 )
197 .await?;
198 vacant.insert((handler_arc.clone(), listener_task));
199 Ok(handler_arc)
200 }
201 }
202 }
203}
204
205impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
207 async fn init_group_manager(
209 storage: Arc<S>,
210 client_keypair: Arc<zoe_wire_protocol::KeyPair>,
211 ) -> SessionManagerResult<(GroupManager, JoinHandle<()>)> {
212 let namespace = StateNamespace::GroupSession(KeyId::from(*client_keypair.id()));
213 let sessions: Vec<(Vec<u8>, zoe_state_machine::GroupSession)> = storage
214 .list_namespace_data(&namespace)
215 .await
216 .map_err(SessionManagerError::Storage)?;
217
218 let group_manager = GroupManager::builder()
219 .with_sessions(sessions.into_iter().map(|(_, session)| session).collect())
220 .build();
221 let mut group_updates = group_manager.subscribe_to_updates();
222
223 let task = tokio::spawn(async move {
224 while let Ok(update) = group_updates.recv().await {
225 tracing::debug!("Received group manager update: {:?}", update);
226 match update {
227 GroupDataUpdate::GroupAdded(group_session)
228 | GroupDataUpdate::GroupUpdated(group_session) => {
229 if let Err(e) = storage
230 .store(
231 &namespace,
232 group_session.state.group_id.as_bytes(),
233 &group_session,
234 )
235 .await
236 {
237 tracing::error!(error=?e, "Failed to persist group session");
238 }
239 }
240 GroupDataUpdate::GroupRemoved(group_session) => {
241 if let Err(e) = storage
242 .delete(&namespace, group_session.state.group_id.as_bytes())
243 .await
244 {
245 tracing::error!(error=?e, "Failed to delete group session");
246 }
247 }
248 }
249 }
250 tracing::info!("Group manager listener task ended");
251 });
252
253 tracing::info!("Group manager listener started");
254 Ok((group_manager, task))
255 }
256
257 pub fn group_manager(&self) -> GroupManager {
259 self.group_manager.clone()
260 }
261}
262
263impl<S: StateStorage + 'static, M: MessagesManagerTrait + 'static> SessionManager<S, M> {
265 async fn load_pqxdh_states(
267 storage: Arc<S>,
268 messages_manager: Arc<M>,
269 client_keypair: Arc<zoe_wire_protocol::KeyPair>,
270 namespace: StateNamespace,
271 ) -> SessionManagerResult<
272 Vec<(
273 PqxdhInboxProtocol,
274 (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>),
275 )>,
276 > {
277 let pqxdh_data = storage
278 .list_namespace_data(&namespace)
279 .await
280 .map_err(SessionManagerError::Storage)?;
281
282 let mut handlers = Vec::new();
283 for (key, protocol_state) in pqxdh_data {
284 let protocol = PqxdhInboxProtocol::from_bytes(&key)
285 .map_err(|e| SessionManagerError::Serialization(e.to_string()))?;
286 let handler = Arc::new(PqxdhProtocolHandler::from_state(
287 messages_manager.clone(),
288 client_keypair.clone(),
289 protocol_state,
290 ));
291 handlers.push(
292 Self::init_pqxdh_handler(
293 storage.clone(),
294 protocol,
295 handler,
296 namespace.clone(),
297 false,
298 )
299 .await?,
300 );
301 }
302
303 tracing::info!("PQXDH state loading completed");
304 Ok(handlers)
305 }
306
307 async fn init_pqxdh_handler(
308 storage: Arc<S>,
309 handler_id: PqxdhInboxProtocol,
310 handler: Arc<PqxdhProtocolHandler<M>>,
311 namespace: StateNamespace,
312 persist_initial_state: bool,
313 ) -> SessionManagerResult<(
314 PqxdhInboxProtocol,
315 (Arc<PqxdhProtocolHandler<M>>, JoinHandle<()>),
316 )> {
317 let mut subscriber = handler.subscribe_to_state().await;
319 let handler_id_bytes = handler_id.into_bytes();
320 if persist_initial_state {
321 let initial_state = subscriber.get().await;
322 storage
324 .store(&namespace, &handler_id_bytes, &initial_state)
325 .await
326 .map_err(SessionManagerError::Storage)?;
327 }
328
329 let storage_clone = storage.clone();
331 let handler_id_clone = handler_id.clone();
332
333 let listener_task = tokio::spawn(async move {
334 tracing::info!("Started PQXDH state listener for handler: {handler_id_clone}");
335
336 while let Some(new_state) = subscriber.next().await {
338 let _new_state_bytes = match postcard::to_stdvec(&new_state) {
340 Ok(bytes) => bytes,
341 Err(e) => {
342 tracing::error!(error=?e, "Failed to serialize PQXDH state for {handler_id_clone}");
343 continue;
344 }
345 };
346
347 if let Err(e) = storage_clone
349 .store(&namespace, &handler_id_bytes, &new_state)
350 .await
351 {
352 tracing::error!(error=?e, "Failed to persist PQXDH state for {handler_id_clone}");
353 continue;
354 }
355 }
356
357 tracing::debug!("PQXDH state listener ended for handler: {handler_id_clone}");
358 });
359
360 tracing::info!("Registered PQXDH handler: {}", handler_id);
361 Ok((handler_id.clone(), (handler, listener_task)))
362 }
363}
364
// Unit tests for `SessionManager` / `SessionManagerBuilder`, using mockall-generated
// mocks for the state storage and the messages manager.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use rand::thread_rng;

    use zoe_client_storage::{StateNamespace, storage::MockStateStorage};

    use zoe_wire_protocol::{KeyPair, PqxdhInboxProtocol};

    use crate::pqxdh::PqxdhProtocolState;
    use crate::services::messages_manager::MockMessagesManagerTrait;

    /// Generates a fresh random Ed25519 keypair for a test.
    fn create_test_keypair() -> KeyPair {
        let mut rng = thread_rng();
        KeyPair::generate_ed25519(&mut rng)
    }

    /// Builds the PQXDH session namespace for `keypair`.
    // Not called by any test below, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn create_test_namespace(keypair: &KeyPair) -> StateNamespace {
        StateNamespace::PqxdhSession(KeyId::from(*keypair.public_key().id()))
    }

    /// `client_keypair()` stores the given keypair on the builder.
    // NOTE(review): `build()` is never called here, so the `list_namespace_data`
    // expectation below is not exercised.
    #[tokio::test]
    async fn test_session_manager_builder() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));

        let builder = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone());

        assert!(builder.client_keypair.is_some());
        assert_eq!(
            builder.client_keypair.unwrap().public_key().id(),
            keypair.public_key().id()
        );
    }

    /// With empty PQXDH and group-session storage, `build()` succeeds, keeps the
    /// supplied keypair, and registers no PQXDH handlers.
    #[tokio::test]
    async fn test_session_manager_creation() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));
        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        let manager = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await
            .expect("Failed to build SessionManager");

        assert_eq!(
            manager.client_keypair.public_key().id(),
            keypair.public_key().id()
        );
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 0);
    }

    /// `build()` without a keypair fails with `ClientKeypairNotFound`. No storage
    /// expectations are set because the keypair check happens before any storage access.
    #[tokio::test]
    async fn test_session_manager_creation_without_keypair() {
        let mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();

        let result = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .build()
            .await;

        assert!(matches!(
            result,
            Err(SessionManagerError::ClientKeypairNotFound)
        ));
    }

    /// Requesting a handler for a new protocol persists its initial state (`store`)
    /// and registers it under that protocol.
    #[tokio::test]
    async fn test_pqxdh_handler_registration() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));
        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .returning(|_, _, _| Ok(()));

        let manager = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await
            .expect("Failed to build SessionManager");

        let handler = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get PQXDH handler");

        assert_eq!(manager.pqxdh_handlers.read().await.len(), 1);
        assert!(manager.pqxdh_handlers.read().await.contains_key(&protocol));

        // Smoke check: the returned handler's state can be read.
        let _state = handler.subscribe_to_state().await.get().await;
    }

    /// A second request for the same protocol returns the same `Arc`. `store` is
    /// expected exactly once, so the initial state is only persisted when the
    /// handler is first created.
    #[tokio::test]
    async fn test_pqxdh_handler_reuse() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));
        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .times(1)
            .returning(|_, _, _| Ok(()));

        let manager = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await
            .expect("Failed to build SessionManager");

        let handler1 = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get first handler");
        let handler2 = manager
            .pqxdh_handler(protocol.clone())
            .await
            .expect("Failed to get second handler");

        assert!(Arc::ptr_eq(&handler1, &handler2));
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 1);
    }

    /// A PQXDH state present in storage at build time is restored as a registered
    /// handler keyed by its decoded protocol.
    #[tokio::test]
    async fn test_load_existing_pqxdh_states() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());
        let protocol = PqxdhInboxProtocol::EchoService;

        let existing_state = PqxdhProtocolState::new(protocol.clone());
        let protocol_bytes = protocol.clone().into_bytes();
        let existing_data = vec![(protocol_bytes, existing_state)];

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(move |_| Ok(existing_data.clone()));

        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        let manager = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await
            .expect("Failed to build SessionManager");

        assert_eq!(manager.pqxdh_handlers.read().await.len(), 1);
        assert!(manager.pqxdh_handlers.read().await.contains_key(&protocol));
    }

    /// A storage failure while listing PQXDH states makes `build()` fail with
    /// `SessionManagerError::Storage`. Group sessions are not mocked because the
    /// PQXDH states are loaded first and the error returns before them.
    #[tokio::test]
    async fn test_storage_error_handling() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| {
                Err(zoe_client_storage::StorageError::Serialization(
                    postcard::Error::DeserializeUnexpectedEnd,
                ))
            });

        let result = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await;

        assert!(matches!(result, Err(SessionManagerError::Storage(_))));
    }

    /// Distinct protocols get distinct handlers, each persisted once on creation
    /// (`store` expected twice in total).
    #[tokio::test]
    async fn test_multiple_pqxdh_protocols() {
        let mut mock_storage = MockStateStorage::new();
        let mock_messages = MockMessagesManagerTrait::new();
        let keypair = Arc::new(create_test_keypair());

        mock_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(Vec::new()));
        mock_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(Vec::new()));

        mock_storage
            .expect_store::<PqxdhProtocolState>()
            .times(2)
            .returning(|_, _, _| Ok(()));

        let manager = SessionManagerBuilder::new(Arc::new(mock_storage), Arc::new(mock_messages))
            .client_keypair(keypair.clone())
            .build()
            .await
            .expect("Failed to build SessionManager");

        let handler1 = manager
            .pqxdh_handler(PqxdhInboxProtocol::EchoService)
            .await
            .expect("Failed to get handler1");
        let handler2 = manager
            .pqxdh_handler(PqxdhInboxProtocol::CustomProtocol(12345))
            .await
            .expect("Failed to get handler2");

        assert!(!Arc::ptr_eq(&handler1, &handler2));
        assert_eq!(manager.pqxdh_handlers.read().await.len(), 2);
    }
}