1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::SystemTime;
4
5use async_broadcast::{Receiver, Sender as BroadcastSender};
6use async_trait::async_trait;
7use futures::stream::Stream;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10
11use eyeball::{AsyncLock, SharedObservable};
12use zoe_client_storage::{MessageStorage, SubscriptionState};
13use zoe_wire_protocol::{
14 CatchUpResponse, Filter, KeyId, MessageFull, PublishResult, StoreKey, StreamMessage,
15};
16
17use crate::error::{ClientError, Result};
18use crate::services::{MessageEvent, MessagesManager, MessagesManagerTrait};
19
/// Batch size handed to the catch-up routine by `start_catch_up_task`; it is
/// forwarded to storage as the limit on unsynced messages fetched per batch.
const CATCH_UP_BATCH_SIZE: usize = 50;
22
/// Connection status recorded for each relay tracked by the manager.
///
/// Within this file only the `Connected` variant is ever inspected: relays in
/// any other state are skipped by the publish / subscribe / query paths.
#[derive(Debug, Clone)]
pub enum ConnectionState {
    /// The relay is treated as usable; `add_relay` registers relays in this state.
    Connected,
    /// NOTE(review): presumably "no active connection" — never produced in this file.
    Disconnected,
    /// NOTE(review): presumably a reconnect attempt in progress — never produced in this file.
    Reconnecting,
    /// NOTE(review): presumably a failed connection awaiting a retry — confirm
    /// against the code that actually sets this variant.
    Failed {
        /// NOTE(review): presumably a description of the most recent error.
        last_error: String,
        /// NOTE(review): presumably the earliest time a retry should happen.
        retry_at: SystemTime,
    },
}
38
/// Per-relay bookkeeping entry stored in the manager's relay map.
pub struct RelayConnection {
    /// Messages manager used to talk to this relay.
    pub manager: Arc<MessagesManager>,
    /// Current status; only `ConnectionState::Connected` relays are used for operations.
    pub connection_state: ConnectionState,
    /// Set to the current time when the relay is added.
    /// NOTE(review): nothing in this file updates it afterwards.
    pub last_seen: SystemTime,
    /// Initialised to `SubscriptionState::default()` when the relay is added.
    pub subscription_state: SubscriptionState,
}
50
/// Front-end that fans message operations out over several relay connections
/// and uses local storage as a fallback queue when no relay is connected.
///
/// The relay map and the catch-up task registry are behind `Arc`, so clones
/// created through the derived `Clone` share them.
#[derive(Clone)]
pub struct MultiRelayMessageManager<S: MessageStorage> {
    /// Registered relays keyed by relay id.
    relay_connections: Arc<RwLock<BTreeMap<KeyId, RelayConnection>>>,
    /// Local message store used for offline queuing and sync bookkeeping.
    storage: Arc<S>,
    /// Sender side of the channel backing `message_events_stream`.
    global_events_tx: BroadcastSender<MessageEvent>,
    /// Sender side of the channel backing `messages_stream`.
    global_messages_tx: BroadcastSender<StreamMessage>,
    /// Sender side of the channel backing `catch_up_stream`.
    global_catchup_tx: BroadcastSender<CatchUpResponse>,
    /// Handles of the background catch-up tasks, keyed by relay id.
    catch_up_tasks: Arc<RwLock<BTreeMap<KeyId, JoinHandle<()>>>>,
}
76
77impl<S: MessageStorage + 'static> MultiRelayMessageManager<S> {
78 pub fn new(storage: Arc<S>) -> Self {
80 let (global_events_tx, _) = async_broadcast::broadcast(1000);
81 let (global_messages_tx, _) = async_broadcast::broadcast(1000);
82 let (global_catchup_tx, _) = async_broadcast::broadcast(1000);
83
84 let relay_connections = Arc::new(RwLock::new(BTreeMap::new()));
85 let catch_up_tasks = Arc::new(RwLock::new(BTreeMap::new()));
86
87 Self {
88 relay_connections,
89 storage,
90 global_events_tx,
91 global_messages_tx,
92 global_catchup_tx,
93 catch_up_tasks,
94 }
95 }
96
97 pub async fn add_relay(
104 &self,
105 relay_id: KeyId,
106 manager: Arc<MessagesManager>,
107 should_catch_up: bool,
108 ) -> Result<()> {
109 let connection = RelayConnection {
110 manager: Arc::clone(&manager),
111 connection_state: ConnectionState::Connected,
112 last_seen: SystemTime::now(),
113 subscription_state: SubscriptionState::default(),
114 };
115
116 {
118 let mut connections = self.relay_connections.write().await;
119 connections.insert(relay_id, connection);
120 }
121
122 tracing::info!(
123 "Added relay connection: {}",
124 hex::encode(relay_id.as_bytes())
125 );
126
127 if should_catch_up {
129 self.start_catch_up_task(relay_id, manager).await?;
130 }
131
132 Ok(())
133 }
134
135 pub async fn remove_relay(&self, relay_id: &KeyId) -> Option<Arc<MessagesManager>> {
137 {
139 let mut tasks = self.catch_up_tasks.write().await;
140 if let Some(task) = tasks.remove(relay_id) {
141 task.abort();
142 tracing::debug!(
143 "Cancelled catch-up task for relay: {}",
144 hex::encode(relay_id.as_bytes())
145 );
146 }
147 }
148
149 let mut connections = self.relay_connections.write().await;
150 let removed = connections.remove(relay_id);
151
152 if let Some(connection) = &removed {
153 tracing::info!(
154 "Removed relay connection: {}",
155 hex::encode(relay_id.as_bytes())
156 );
157 Some(Arc::clone(&connection.manager))
158 } else {
159 None
160 }
161 }
162
163 pub async fn get_connected_relay_ids(&self) -> Vec<KeyId> {
165 let connections = self.relay_connections.read().await;
166 connections
167 .iter()
168 .filter(|(_, conn)| matches!(conn.connection_state, ConnectionState::Connected))
169 .map(|(id, _)| *id)
170 .collect()
171 }
172
173 pub async fn get_all_relay_ids(&self) -> Vec<KeyId> {
175 let connections = self.relay_connections.read().await;
176 connections.keys().copied().collect()
177 }
178
179 pub async fn has_connected_relays(&self) -> bool {
181 !self.get_connected_relay_ids().await.is_empty()
182 }
183
184 async fn start_catch_up_task(
186 &self,
187 relay_id: KeyId,
188 manager: Arc<MessagesManager>,
189 ) -> Result<()> {
190 let storage = Arc::clone(&self.storage);
191 let tasks = Arc::clone(&self.catch_up_tasks);
192
193 let task = tokio::spawn(async move {
194 tracing::info!(
195 "Starting catch-up task for relay: {}",
196 hex::encode(relay_id.as_bytes())
197 );
198
199 match Self::process_all_unsynced_messages_for_relay::<MessagesManager>(
200 &storage,
201 &relay_id,
202 &manager,
203 CATCH_UP_BATCH_SIZE,
204 )
205 .await
206 {
207 Ok(total_processed) => {
208 tracing::info!(
209 "Catch-up completed for relay: {} ({} messages processed)",
210 hex::encode(relay_id.as_bytes()),
211 total_processed
212 );
213 }
214 Err(e) => {
215 tracing::warn!(
216 "Catch-up failed for relay {}: {}",
217 hex::encode(relay_id.as_bytes()),
218 e
219 );
220 }
221 }
222
223 let mut tasks_guard = tasks.write().await;
225 tasks_guard.remove(&relay_id);
226 });
227
228 {
230 let mut tasks_guard = self.catch_up_tasks.write().await;
231 tasks_guard.insert(relay_id, task);
232 }
233
234 Ok(())
235 }
236
237 async fn process_all_unsynced_messages_for_relay<M: MessagesManagerTrait>(
240 storage: &Arc<S>,
241 relay_id: &KeyId,
242 manager: &Arc<M>,
243 batch_size: usize,
244 ) -> Result<usize> {
245 let mut total_processed = 0;
246 let mut batch_count = 0;
247
248 loop {
249 tracing::debug!(
250 "Processing batch {} for relay {} (batch size: {})",
251 batch_count,
252 hex::encode(relay_id.as_bytes()),
253 batch_size
254 );
255
256 let batch_result = Self::process_unsynced_messages_batch_for_relay(
257 storage, relay_id, manager, batch_size,
258 )
259 .await?;
260
261 let Some(batch_processed) = batch_result else {
262 tracing::debug!(
264 "No unsynced messages left for relay {}, stopping catch-up after {} batches",
265 hex::encode(relay_id.as_bytes()),
266 batch_count
267 );
268 break;
269 };
270
271 total_processed += batch_processed;
272 batch_count += 1;
273
274 tracing::debug!(
275 "Batch {} complete: {} messages processed for relay {}",
276 batch_count,
277 batch_processed,
278 hex::encode(relay_id.as_bytes())
279 );
280 }
281
282 Ok(total_processed)
283 }
284
285 async fn process_unsynced_messages_batch_for_relay<M: MessagesManagerTrait>(
291 storage: &Arc<S>,
292 relay_id: &KeyId,
293 manager: &Arc<M>,
294 batch_size: usize,
295 ) -> Result<Option<usize>> {
296 let relay_key_id = relay_id;
298
299 let unsynced_messages = storage
301 .get_unsynced_messages_for_relay(relay_key_id, Some(batch_size))
302 .await
303 .map_err(|e| ClientError::Generic(format!("Failed to get unsynced messages: {e}")))?;
304
305 if unsynced_messages.is_empty() {
306 return Ok(None); }
308
309 let batch_message_count = unsynced_messages.len();
310 tracing::debug!(
311 "Processing batch of {} unsynced messages for relay {}",
312 batch_message_count,
313 hex::encode(relay_id.as_bytes())
314 );
315
316 let mut valid_messages = Vec::new();
318 let mut expired_count = 0;
319
320 let current_time = std::time::SystemTime::now()
321 .duration_since(std::time::UNIX_EPOCH)
322 .unwrap_or_default()
323 .as_secs();
324
325 for message in unsynced_messages {
326 if message.is_expired(current_time) {
327 if let Err(e) = storage.delete_message(message.id()).await {
329 tracing::error!(
330 "Failed to delete expired message {} from storage: {}. Continuing with other messages.",
331 hex::encode(message.id().as_bytes()),
332 e
333 );
334 }
336 expired_count += 1;
337 continue;
338 } else {
339 valid_messages.push(message);
340 }
341 }
342
343 if valid_messages.is_empty() {
346 tracing::debug!(
347 "All {} messages in batch were expired and removed for relay {}",
348 expired_count,
349 hex::encode(relay_id.as_bytes())
350 );
351 return Ok(Some(0));
352 }
353
354 let message_ids: Vec<_> = valid_messages.iter().map(|msg| *msg.id()).collect();
356
357 tracing::debug!(
358 "Checking existence of {} valid messages on relay {} ({} expired messages removed)",
359 message_ids.len(),
360 hex::encode(relay_id.as_bytes()),
361 expired_count
362 );
363
364 let existence_results = manager.check_messages(message_ids).await?;
365 let mut processed_count = 0;
366 for (message, existence_result) in valid_messages.iter().zip(existence_results.iter()) {
368 storage.store_message(message).await?;
370
371 if let Some(global_stream_id) = existence_result {
372 if let Err(e) = storage
374 .mark_message_synced(message.id(), relay_key_id, global_stream_id)
375 .await
376 {
377 tracing::error!(
378 "Failed to mark existing message {} as synced to relay {}: {}",
379 hex::encode(message.id().as_bytes()),
380 hex::encode(relay_id.as_bytes()),
381 e
382 );
383 continue;
384 }
385
386 tracing::debug!(
387 "Message {} already exists on relay {}, marked as synced",
388 hex::encode(message.id().as_bytes()),
389 hex::encode(relay_id.as_bytes())
390 );
391 processed_count += 1;
392 continue;
393 }
394
395 let global_stream_id = match manager.publish(message.clone()).await {
397 Ok(result) => match result {
398 PublishResult::StoredNew { global_stream_id } => global_stream_id,
399 PublishResult::AlreadyExists { global_stream_id } => global_stream_id,
400 PublishResult::Expired => {
401 tracing::warn!("Message expired: {}", hex::encode(message.id().as_bytes()));
402 continue;
403 }
404 },
405 Err(e) => {
406 tracing::error!(
407 "Failed to send message {} to relay {}: {}",
408 hex::encode(message.id().as_bytes()),
409 hex::encode(relay_id.as_bytes()),
410 e
411 );
412 continue;
413 }
414 };
415
416 if let Err(e) = storage
418 .mark_message_synced(message.id(), relay_key_id, &global_stream_id)
419 .await
420 {
421 tracing::error!(
422 "Failed to mark message {} as synced to relay {}: {}",
423 hex::encode(message.id().as_bytes()),
424 hex::encode(relay_id.as_bytes()),
425 e
426 );
427 continue;
428 }
429 processed_count += 1;
430 }
431
432 tracing::debug!(
433 "Batch complete: {}/{} valid messages successfully processed for relay {} ({} expired messages removed)",
434 processed_count,
435 valid_messages.len(),
436 hex::encode(relay_id.as_bytes()),
437 expired_count
438 );
439
440 Ok(Some(processed_count))
441 }
442}
443
444#[async_trait]
445impl<S: MessageStorage + 'static> MessagesManagerTrait for MultiRelayMessageManager<S> {
/// Returns a new receiver on the manager-wide message-event broadcast channel.
///
/// NOTE(review): nothing in this file sends on this channel — confirm where
/// events are expected to be forwarded from.
fn message_events_stream(&self) -> Receiver<MessageEvent> {
    self.global_events_tx.new_receiver()
}
450
451 async fn get_subscription_state_updates(
453 &self,
454 ) -> eyeball::Subscriber<SubscriptionState, AsyncLock> {
455 let state = SubscriptionState::default();
458 let observable = SharedObservable::new_async(state);
459 observable.subscribe().await
460 }
461
462 async fn subscribe(&self) -> Result<()> {
464 let connections = self.relay_connections.read().await;
465 let mut results = Vec::new();
466
467 for (relay_id, connection) in connections.iter() {
468 if matches!(connection.connection_state, ConnectionState::Connected) {
469 match connection.manager.subscribe().await {
470 Ok(()) => {
471 tracing::debug!("Subscribed to relay {}", hex::encode(relay_id.as_bytes()));
472 }
473 Err(e) => {
474 tracing::warn!(
475 "Failed to subscribe to relay {}: {}",
476 hex::encode(relay_id.as_bytes()),
477 e
478 );
479 results.push(e);
480 }
481 }
482 }
483 }
484
485 if !results.is_empty() && results.len() == connections.len() {
487 return Err(results.into_iter().next().unwrap());
488 }
489
490 Ok(())
491 }
492
493 async fn publish(&self, message: MessageFull) -> Result<PublishResult> {
495 let connected_relays = {
496 let connections = self.relay_connections.read().await;
497 connections
498 .iter()
499 .filter(|(_, conn)| matches!(conn.connection_state, ConnectionState::Connected))
500 .map(|(id, conn)| (*id, Arc::clone(&conn.manager)))
501 .collect::<Vec<_>>()
502 };
503
504 if connected_relays.is_empty() {
505 self.storage.store_message(&message).await.map_err(|e| {
507 ClientError::Generic(format!("Failed to store message for offline delivery: {e}"))
508 })?;
509
510 tracing::info!(
511 "No relays available, stored message {} for offline processing",
512 hex::encode(message.id().as_bytes())
513 );
514
515 return Ok(PublishResult::StoredNew {
516 global_stream_id: "queued_offline".to_string(),
517 });
518 }
519
520 let (relay_id, manager) = &connected_relays[0];
523
524 self.storage.store_message(&message).await.map_err(|e| {
526 ClientError::Generic(format!(
527 "Failed to store message {} locally: {}",
528 hex::encode(message.id().as_bytes()),
529 e
530 ))
531 })?;
532
533 match manager.publish(message.clone()).await {
534 Ok(result) => {
535 let relay_key_id = relay_id;
537 let global_stream_id = match &result {
538 PublishResult::StoredNew { global_stream_id } => global_stream_id,
539 PublishResult::AlreadyExists { global_stream_id } => global_stream_id,
540 PublishResult::Expired => {
541 tracing::warn!("Message expired during publish");
542 return Ok(result);
543 }
544 };
545
546 self.storage
547 .mark_message_synced(message.id(), relay_key_id, global_stream_id)
548 .await
549 .map_err(|e| {
550 ClientError::Generic(format!(
551 "Failed to mark message {} as synced to relay {}: {}",
552 hex::encode(message.id().as_bytes()),
553 hex::encode(relay_id.as_bytes()),
554 e
555 ))
556 })?;
557
558 tracing::debug!(
559 "Successfully published message to relay {}",
560 hex::encode(relay_id.as_bytes())
561 );
562 Ok(result)
563 }
564 Err(e) => {
565 tracing::warn!(
566 "Failed to publish message to relay {}: {}",
567 hex::encode(relay_id.as_bytes()),
568 e
569 );
570
571 self.storage
573 .store_message(&message)
574 .await
575 .map_err(|storage_err| {
576 ClientError::Generic(format!(
577 "Failed to store message for offline delivery: {storage_err}"
578 ))
579 })?;
580
581 tracing::info!(
582 "Stored message {} for offline processing after publish failure",
583 hex::encode(message.id().as_bytes())
584 );
585
586 Err(e)
587 }
588 }
589 }
590
591 async fn ensure_contains_filter(&self, filter: Filter) -> Result<()> {
593 let connections = self.relay_connections.read().await;
594 let mut results = Vec::new();
595
596 for (relay_id, connection) in connections.iter() {
597 if matches!(connection.connection_state, ConnectionState::Connected) {
598 match connection
599 .manager
600 .ensure_contains_filter(filter.clone())
601 .await
602 {
603 Ok(()) => {
604 tracing::debug!(
605 "Added filter to relay {}",
606 hex::encode(relay_id.as_bytes())
607 );
608 }
609 Err(e) => {
610 tracing::warn!(
611 "Failed to add filter to relay {}: {}",
612 hex::encode(relay_id.as_bytes()),
613 e
614 );
615 results.push(e);
616 }
617 }
618 }
619 }
620
621 if !results.is_empty() && results.len() == connections.len() {
623 return Err(results.into_iter().next().unwrap());
624 }
625
626 Ok(())
627 }
628
/// Returns a new receiver on the manager-wide stream-message broadcast channel.
///
/// NOTE(review): nothing in this file sends on this channel — confirm where
/// relay messages are expected to be forwarded from.
fn messages_stream(&self) -> Receiver<StreamMessage> {
    self.global_messages_tx.new_receiver()
}
633
/// Returns a new receiver on the manager-wide catch-up-response broadcast channel.
///
/// NOTE(review): nothing in this file sends on this channel — confirm where
/// catch-up responses are expected to be forwarded from.
fn catch_up_stream(&self) -> Receiver<CatchUpResponse> {
    self.global_catchup_tx.new_receiver()
}
638
/// Stub: ignores `_filter` and returns a stream that yields no items.
fn filtered_messages_stream(
    &self,
    _filter: Filter,
) -> std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>> {
    Box::pin(futures::stream::empty())
}
648
/// Stub: ignores `_filter` and `_since` and always succeeds with a stream
/// that yields no items.
async fn catch_up_and_subscribe(
    &self,
    _filter: Filter,
    _since: Option<String>,
) -> Result<std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>>> {
    Ok(Box::pin(futures::stream::empty()))
}
659
660 async fn user_data(&self, author: KeyId, storage_key: StoreKey) -> Result<Option<MessageFull>> {
662 let connections = self.relay_connections.read().await;
663
664 for (relay_id, connection) in connections.iter() {
666 if matches!(connection.connection_state, ConnectionState::Connected) {
667 match connection
668 .manager
669 .user_data(author, storage_key.clone())
670 .await
671 {
672 Ok(Some(data)) => {
673 tracing::debug!(
674 "Found user data on relay {}",
675 hex::encode(relay_id.as_bytes())
676 );
677 return Ok(Some(data));
678 }
679 Ok(None) => {
680 continue;
682 }
683 Err(e) => {
684 tracing::warn!(
685 "Failed to get user data from relay {}: {}",
686 hex::encode(relay_id.as_bytes()),
687 e
688 );
689 continue;
690 }
691 }
692 }
693 }
694
695 Ok(None) }
697
698 async fn check_messages(
701 &self,
702 message_ids: Vec<zoe_wire_protocol::MessageId>,
703 ) -> Result<Vec<Option<String>>> {
704 let connections = self.relay_connections.read().await;
705
706 for (relay_id, connection) in connections.iter() {
708 if matches!(connection.connection_state, ConnectionState::Connected) {
709 match connection.manager.check_messages(message_ids.clone()).await {
710 Ok(result) => {
711 tracing::debug!(
712 "Successfully checked {} messages on relay {}",
713 message_ids.len(),
714 hex::encode(relay_id.as_bytes())
715 );
716 return Ok(result);
717 }
718 Err(e) => {
719 tracing::warn!(
720 "Failed to check messages on relay {}: {}",
721 hex::encode(relay_id.as_bytes()),
722 e
723 );
724 continue;
725 }
726 }
727 }
728 }
729
730 tracing::warn!("No connected relays available for checking messages");
732 Ok(vec![None; message_ids.len()])
733 }
734}
735
736impl<S: MessageStorage> Drop for MultiRelayMessageManager<S> {
737 fn drop(&mut self) {
738 if let Ok(mut tasks) = self.catch_up_tasks.try_write() {
740 for (relay_id, task) in tasks.iter() {
741 task.abort();
742 tracing::debug!(
743 "Cancelled catch-up task for relay during drop: {}",
744 hex::encode(relay_id.as_bytes())
745 );
746 }
747 tasks.clear();
748 }
749 }
750}
751
752#[cfg(test)]
753mod tests {
754 use super::*;
755 use zoe_client_storage::storage::MockMessageStorage;
756 use zoe_wire_protocol::{Content, KeyPair, Kind, Message, MessageFull};
757
758 fn create_test_message() -> MessageFull {
759 let mut rng = rand::thread_rng();
760 let keypair = KeyPair::generate(&mut rng);
761 let message = Message::new_v0(
762 Content::raw(b"Test message".to_vec()),
763 keypair.public_key(),
764 1234567890u64, Kind::Regular,
766 vec![],
767 );
768 MessageFull::new(message, &keypair).unwrap()
769 }
770
771 #[tokio::test]
772 async fn test_multi_relay_manager_creation() {
773 let mock_storage = MockMessageStorage::new();
774 let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
775
776 assert_eq!(manager.get_all_relay_ids().await.len(), 0);
778 assert!(!manager.has_connected_relays().await);
779 }
780
781 #[tokio::test]
782 async fn test_offline_message_queuing() {
783 let mut mock_storage = MockMessageStorage::new();
784 mock_storage
785 .expect_store_message()
786 .times(1)
787 .returning(|_| Ok(()));
788
789 let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
790 let test_message = create_test_message();
791
792 let result = manager.publish(test_message).await;
794 assert!(result.is_ok());
795
796 if let Ok(publish_result) = result {
797 assert_eq!(
798 publish_result,
799 PublishResult::StoredNew {
800 global_stream_id: "queued_offline".to_string()
801 }
802 );
803 }
804 }
805
806 #[tokio::test]
807 async fn test_basic_functionality() {
808 let mock_storage = MockMessageStorage::new();
809 let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
810
811 assert_eq!(manager.get_all_relay_ids().await.len(), 0);
813 assert!(!manager.has_connected_relays().await);
814
815 let connected_ids = manager.get_connected_relay_ids().await;
817 assert!(connected_ids.is_empty());
818 }
819
/// Checks that a `MultiRelayMessageManager` can be used as the messages
/// manager of a `SessionManager`: the builder must succeed and the resulting
/// session manager must hand back a manager exposing the event stream.
#[tokio::test]
async fn test_session_manager_compatibility() {
    use crate::pqxdh::PqxdhProtocolState;
    use crate::session_manager::SessionManager;
    use zoe_client_storage::storage::MockStateStorage;
    use zoe_wire_protocol::KeyPair;

    let mock_message_storage = MockMessageStorage::new();
    let multi_relay_manager = Arc::new(MultiRelayMessageManager::new(Arc::new(
        mock_message_storage,
    )));

    // State storage reports no persisted data for either namespace type.
    let mut mock_state_storage = MockStateStorage::new();
    mock_state_storage
        .expect_list_namespace_data::<PqxdhProtocolState>()
        .returning(|_| Ok(vec![]));
    mock_state_storage
        .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
        .returning(|_| Ok(vec![]));

    let mock_state_storage = Arc::new(mock_state_storage);

    let mut rng = rand::thread_rng();
    let keypair = Arc::new(KeyPair::generate(&mut rng));

    let session_manager_result =
        SessionManager::builder(mock_state_storage, multi_relay_manager)
            .client_keypair(keypair)
            .build()
            .await;

    assert!(session_manager_result.is_ok());

    let session_manager = session_manager_result.unwrap();

    let messages_manager = session_manager.messages_manager();

    // Only checks that the trait method is callable through the session manager.
    let _events_stream = messages_manager.message_events_stream();
}
867
868 #[tokio::test]
869 async fn test_pqxdh_handler_compatibility() {
870 use crate::pqxdh::PqxdhProtocolHandler;
871 use zoe_wire_protocol::{KeyPair, PqxdhInboxProtocol};
872
873 let mock_message_storage = MockMessageStorage::new();
874 let multi_relay_manager = Arc::new(MultiRelayMessageManager::new(Arc::new(
875 mock_message_storage,
876 )));
877
878 let mut rng = rand::thread_rng();
880 let keypair = Arc::new(KeyPair::generate(&mut rng));
881
882 let _pqxdh_handler = PqxdhProtocolHandler::new(
884 multi_relay_manager,
885 keypair,
886 PqxdhInboxProtocol::EchoService,
887 );
888
889 }
893
/// Single-batch catch-up: one unsynced message that the relay does not have
/// yet must be stored, published and marked as synced exactly once.
#[tokio::test]
async fn test_catch_up_processing() {
    use crate::services::messages_manager::MockMessagesManagerTrait;

    let test_message = create_test_message();
    let relay_id = KeyId::from([1u8; 32]);
    let relay_key_id = relay_id;

    let mut mock_storage = MockMessageStorage::new();
    // Storage hands out the one unsynced message for a batch size of 50.
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(50)),
        )
        .times(1)
        .returning(move |_, _| Ok(vec![test_message.clone()]));

    mock_storage
        .expect_mark_message_synced()
        .times(1)
        .returning(|_, _, _| Ok(()));

    mock_storage
        .expect_store_message()
        .times(1)
        .returning(|_| Ok(()));

    let storage = Arc::new(mock_storage);

    let mut mock_manager = MockMessagesManagerTrait::new();

    // Relay reports the message as unknown, so it has to be published.
    mock_manager
        .expect_check_messages()
        .times(1)
        .returning(|_| Ok(vec![None]));
    mock_manager.expect_publish().times(1).returning(|_| {
        Ok(PublishResult::StoredNew {
            global_stream_id: "test_stream_123".to_string(),
        })
    });

    let mock_manager = Arc::new(mock_manager);

    let result = MultiRelayMessageManager::process_unsynced_messages_batch_for_relay(
        &storage,
        &relay_id,
        &mock_manager,
        50,
    )
    .await;

    assert!(result.is_ok(), "Catch-up processing should succeed");

    assert_eq!(
        result.unwrap(),
        Some(1),
        "Should have processed exactly 1 message"
    );
}
963
/// Multi-batch catch-up: five unsynced messages with a batch size of 2 are
/// served as batches of 2, 2 and 1, followed by an empty batch that ends the
/// run. All five must be stored, published and marked as synced.
#[tokio::test]
async fn test_batched_catch_up_processing() {
    use crate::services::messages_manager::MockMessagesManagerTrait;

    let test_messages: Vec<MessageFull> = (0u64..5u64)
        .map(|i| {
            let mut rng = rand::thread_rng();
            let keypair = KeyPair::generate(&mut rng);
            let message = Message::new_v0(
                Content::raw(format!("Test message {i}").as_bytes().to_vec()),
                keypair.public_key(),
                1234567890u64 + i,
                Kind::Regular,
                vec![],
            );
            MessageFull::new(message, &keypair).unwrap()
        })
        .collect();

    let relay_id = KeyId::from([1u8; 32]);
    let relay_key_id = relay_id;

    let mut mock_storage = MockMessageStorage::new();

    // The four storage expectations below share the same predicates and are
    // each limited to one call, so their declaration order defines the
    // sequence of batches the code under test sees.
    let messages_batch_1 = vec![test_messages[0].clone(), test_messages[1].clone()];
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(messages_batch_1.clone()));

    let messages_batch_2 = vec![test_messages[2].clone(), test_messages[3].clone()];
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(messages_batch_2.clone()));

    let messages_batch_3 = vec![test_messages[4].clone()];
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(messages_batch_3.clone()));

    // Final, empty batch: signals that nothing is left to sync.
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(vec![]));

    mock_storage
        .expect_mark_message_synced()
        .times(5)
        .returning(|_, _, _| Ok(()));

    mock_storage
        .expect_store_message()
        .times(5)
        .returning(|_| Ok(()));

    let storage = Arc::new(mock_storage);

    let mut mock_manager = MockMessagesManagerTrait::new();

    // One existence check per non-empty batch; the relay knows none of the messages.
    mock_manager
        .expect_check_messages()
        .times(1)
        .returning(|_| Ok(vec![None, None]));

    mock_manager
        .expect_check_messages()
        .times(1)
        .returning(|_| Ok(vec![None, None]));

    mock_manager
        .expect_check_messages()
        .times(1)
        .returning(|_| Ok(vec![None]));

    mock_manager.expect_publish().times(5).returning(|_| {
        Ok(PublishResult::StoredNew {
            global_stream_id: "test_stream_123".to_string(),
        })
    });

    let mock_manager = Arc::new(mock_manager);

    let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
        &storage,
        &relay_id,
        &mock_manager,
        2,
    )
    .await;

    assert!(result.is_ok(), "Batched catch-up processing should succeed");

    assert_eq!(
        result.unwrap(),
        5,
        "Should have processed exactly 5 messages total"
    );
}
1095
/// Catch-up with messages the relay already holds: of three unsynced
/// messages, two are reported as existing (only marked as synced) and one is
/// unknown (published, then marked). `publish` must be called exactly once.
#[tokio::test]
async fn test_efficient_catch_up_with_existing_messages() {
    use crate::services::messages_manager::MockMessagesManagerTrait;

    let test_messages: Vec<MessageFull> = (0u64..3u64)
        .map(|i| {
            let mut rng = rand::thread_rng();
            let keypair = KeyPair::generate(&mut rng);
            let message = Message::new_v0(
                Content::raw(format!("Test message {i}").as_bytes().to_vec()),
                keypair.public_key(),
                1234567890u64 + i,
                Kind::Regular,
                vec![],
            );
            MessageFull::new(message, &keypair).unwrap()
        })
        .collect();

    let relay_id = KeyId::from([1u8; 32]);
    let relay_key_id = relay_id;

    // Ids in storage order; the existence check must receive exactly these.
    let message_ids = vec![
        *test_messages[0].id(),
        *test_messages[1].id(),
        *test_messages[2].id(),
    ];

    let mut mock_storage = MockMessageStorage::new();
    // First call returns the three messages, the second an empty batch.
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(3)),
        )
        .times(1)
        .returning(move |_, _| Ok(test_messages.clone()));

    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(3)),
        )
        .times(1)
        .returning(move |_, _| Ok(vec![]));

    mock_storage
        .expect_mark_message_synced()
        .times(3)
        .returning(|_, _, _| Ok(()));

    mock_storage
        .expect_store_message()
        .times(3)
        .returning(|_| Ok(()));

    let storage = Arc::new(mock_storage);

    let mut mock_manager = MockMessagesManagerTrait::new();

    // Messages 0 and 2 already exist on the relay; message 1 does not.
    mock_manager
        .expect_check_messages()
        .with(mockall::predicate::eq(message_ids))
        .times(1)
        .returning(|_| {
            Ok(vec![
                Some("existing_stream_1".to_string()),
                None,
                Some("existing_stream_2".to_string()),
            ])
        });

    // Only the unknown message is published.
    mock_manager.expect_publish().times(1).returning(|_| {
        Ok(PublishResult::StoredNew {
            global_stream_id: "new_stream_123".to_string(),
        })
    });

    let mock_manager = Arc::new(mock_manager);

    let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
        &storage,
        &relay_id,
        &mock_manager,
        3,
    )
    .await;

    assert!(
        result.is_ok(),
        "Efficient catch-up processing should succeed"
    );

    assert_eq!(
        result.unwrap(),
        3,
        "Should have processed exactly 3 messages total"
    );
}
1212
/// Mixed batch: one expired and one valid message. The expired message must
/// be deleted from storage and never sent to the relay; the valid one must be
/// checked, stored, published and marked as synced.
#[tokio::test]
async fn test_expired_message_handling() {
    use crate::services::messages_manager::MockMessagesManagerTrait;
    use zoe_wire_protocol::Kind;

    let mut rng = rand::thread_rng();
    let keypair = KeyPair::generate(&mut rng);

    // Ephemeral message stamped two seconds in the past.
    // NOTE(review): presumably `Kind::Ephemeral(1)` means a one-second
    // lifetime, which is what makes `is_expired` true here — confirm.
    let expired_message = {
        let past_timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
            .saturating_sub(2);
        let message = Message::new_v0(
            Content::raw(b"Expired message".to_vec()),
            keypair.public_key(),
            past_timestamp,
            Kind::Ephemeral(1),
            vec![],
        );
        MessageFull::new(message, &keypair).unwrap()
    };

    // Regular message stamped with the current time.
    let valid_message = {
        let message = Message::new_v0(
            Content::raw(b"Valid message".to_vec()),
            keypair.public_key(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            Kind::Regular,
            vec![],
        );
        MessageFull::new(message, &keypair).unwrap()
    };

    let test_messages = vec![expired_message.clone(), valid_message.clone()];
    let relay_id = KeyId::from([1u8; 32]);
    let relay_key_id = relay_id;

    let mut mock_storage = MockMessageStorage::new();
    // First call returns both messages, the second an empty batch.
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(test_messages.clone()));

    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(2)),
        )
        .times(1)
        .returning(move |_, _| Ok(vec![]));

    // Exactly the expired message is deleted.
    let expired_id = *expired_message.id();
    mock_storage
        .expect_delete_message()
        .with(mockall::predicate::eq(expired_id))
        .times(1)
        .returning(|_| Ok(true));

    mock_storage
        .expect_mark_message_synced()
        .times(1)
        .returning(|_, _, _| Ok(()));

    mock_storage
        .expect_store_message()
        .times(1)
        .returning(|_| Ok(()));

    let storage = Arc::new(mock_storage);

    let mut mock_manager = MockMessagesManagerTrait::new();

    // The existence check must only see the valid message's id.
    let valid_message_id = *valid_message.id();
    mock_manager
        .expect_check_messages()
        .with(mockall::predicate::eq(vec![valid_message_id]))
        .times(1)
        .returning(|_| Ok(vec![None]));
    mock_manager.expect_publish().times(1).returning(|_| {
        Ok(PublishResult::StoredNew {
            global_stream_id: "valid_stream_123".to_string(),
        })
    });

    let mock_manager = Arc::new(mock_manager);

    let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
        &storage,
        &relay_id,
        &mock_manager,
        2,
    )
    .await;

    assert!(
        result.is_ok(),
        "Processing with expired messages should succeed"
    );

    assert_eq!(
        result.unwrap(),
        1,
        "Should have processed exactly 1 valid message (expired message removed)"
    );
}
1344
/// Batch consisting only of expired messages: all three must be deleted, the
/// relay must not be contacted at all (the manager mock has no expectations),
/// and the run must continue to the next (empty) batch and report 0 processed.
#[tokio::test]
async fn test_all_expired_messages_batch() {
    use crate::services::messages_manager::MockMessagesManagerTrait;
    use zoe_wire_protocol::Kind;

    let mut rng = rand::thread_rng();
    let keypair = KeyPair::generate(&mut rng);

    // Ephemeral messages stamped ten seconds in the past.
    // NOTE(review): presumably `Kind::Ephemeral(1)` means a one-second
    // lifetime, which is what makes `is_expired` true here — confirm.
    let expired_messages: Vec<MessageFull> = (0..3)
        .map(|i| {
            let past_timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
                .saturating_sub(10);
            let message = Message::new_v0(
                Content::raw(format!("Expired message {i}").as_bytes().to_vec()),
                keypair.public_key(),
                past_timestamp,
                Kind::Ephemeral(1),
                vec![],
            );
            MessageFull::new(message, &keypair).unwrap()
        })
        .collect();

    let relay_id = KeyId::from([1u8; 32]);
    let relay_key_id = relay_id;

    let mut mock_storage = MockMessageStorage::new();
    // First call returns the expired messages, the second an empty batch.
    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(3)),
        )
        .times(1)
        .returning(move |_, _| Ok(expired_messages.clone()));

    mock_storage
        .expect_get_unsynced_messages_for_relay()
        .with(
            mockall::predicate::eq(relay_key_id),
            mockall::predicate::eq(Some(3)),
        )
        .times(1)
        .returning(move |_, _| Ok(vec![]));

    mock_storage
        .expect_delete_message()
        .times(3)
        .returning(|_| Ok(true));

    let storage = Arc::new(mock_storage);

    // No expectations: any call on the relay manager fails the test.
    let mock_manager = MockMessagesManagerTrait::new();
    let mock_manager = Arc::new(mock_manager);

    let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
        &storage,
        &relay_id,
        &mock_manager,
        3,
    )
    .await;

    assert!(
        result.is_ok(),
        "Processing all expired messages should succeed"
    );

    assert_eq!(
        result.unwrap(),
        0,
        "Should have processed 0 messages (all were expired and removed)"
    );
}
1431}