// zoe_client/services/multi_relay_message_manager.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3use std::time::SystemTime;
4
5use async_broadcast::{Receiver, Sender as BroadcastSender};
6use async_trait::async_trait;
7use futures::stream::Stream;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10
11use eyeball::{AsyncLock, SharedObservable};
12use zoe_client_storage::{MessageStorage, SubscriptionState};
13use zoe_wire_protocol::{
14    CatchUpResponse, Filter, KeyId, MessageFull, PublishResult, StoreKey, StreamMessage,
15};
16
17use crate::error::{ClientError, Result};
18use crate::services::{MessageEvent, MessagesManager, MessagesManagerTrait};
19
// Constants for catch-up processing

/// Number of unsynced messages fetched from storage per catch-up batch.
const CATCH_UP_BATCH_SIZE: usize = 50;
22
/// Connection state for a relay
#[derive(Debug, Clone)]
pub enum ConnectionState {
    /// Relay is connected and operational
    Connected,
    /// Relay is disconnected but may reconnect
    Disconnected,
    /// Currently attempting to reconnect
    Reconnecting,
    /// Connection failed, will retry later
    Failed {
        /// Human-readable description of the most recent error
        last_error: String,
        /// Earliest time at which a reconnect should be attempted
        retry_at: SystemTime,
    },
}
38
/// Represents a connection to a single relay with its associated managers
pub struct RelayConnection {
    /// The messages manager for this relay
    pub manager: Arc<MessagesManager>,
    /// Connection state tracking (see [`ConnectionState`])
    pub connection_state: ConnectionState,
    /// When this relay was last seen as active
    pub last_seen: SystemTime,
    /// Relay-specific subscription state
    pub subscription_state: SubscriptionState,
}
50
/// Multi-relay message manager that provides unified messaging across multiple relays
/// with offline support and automatic failover.
///
/// This manager:
/// - Manages connections to multiple relay servers
/// - Uses persistent storage for offline message queuing (no in-memory queue)
/// - Implements automatic failover and load balancing
/// - Aggregates messages from all connected relays
/// - Deduplicates messages based on message ID
/// - Maintains the same interface as a single MessagesManager
///
/// Cloning is cheap: all clones share the same connection map, storage handle,
/// broadcast channels, and catch-up task map via `Arc`.
#[derive(Clone)]
pub struct MultiRelayMessageManager<S: MessageStorage> {
    /// Map of relay ID to relay connection info
    relay_connections: Arc<RwLock<BTreeMap<KeyId, RelayConnection>>>,
    /// Storage for message persistence and offline queuing
    storage: Arc<S>,
    /// Global message event broadcaster (aggregates from all relays)
    global_events_tx: BroadcastSender<MessageEvent>,
    /// Global message broadcaster (aggregates from all relays)
    global_messages_tx: BroadcastSender<StreamMessage>,
    /// Global catch-up response broadcaster
    global_catchup_tx: BroadcastSender<CatchUpResponse>,
    /// Map of active catch-up tasks by relay ID; tasks remove their own entry
    /// when they finish
    catch_up_tasks: Arc<RwLock<BTreeMap<KeyId, JoinHandle<()>>>>,
}
76
77impl<S: MessageStorage + 'static> MultiRelayMessageManager<S> {
78    /// Create a new multi-relay message manager
79    pub fn new(storage: Arc<S>) -> Self {
80        let (global_events_tx, _) = async_broadcast::broadcast(1000);
81        let (global_messages_tx, _) = async_broadcast::broadcast(1000);
82        let (global_catchup_tx, _) = async_broadcast::broadcast(1000);
83
84        let relay_connections = Arc::new(RwLock::new(BTreeMap::new()));
85        let catch_up_tasks = Arc::new(RwLock::new(BTreeMap::new()));
86
87        Self {
88            relay_connections,
89            storage,
90            global_events_tx,
91            global_messages_tx,
92            global_catchup_tx,
93            catch_up_tasks,
94        }
95    }
96
97    /// Add a relay connection to the manager
98    ///
99    /// # Arguments
100    /// * `relay_id` - The unique identifier for the relay
101    /// * `manager` - The messages manager for this relay
102    /// * `should_catch_up` - Whether to start catching up on historical messages
103    pub async fn add_relay(
104        &self,
105        relay_id: KeyId,
106        manager: Arc<MessagesManager>,
107        should_catch_up: bool,
108    ) -> Result<()> {
109        let connection = RelayConnection {
110            manager: Arc::clone(&manager),
111            connection_state: ConnectionState::Connected,
112            last_seen: SystemTime::now(),
113            subscription_state: SubscriptionState::default(),
114        };
115
116        // Add to our connections map
117        {
118            let mut connections = self.relay_connections.write().await;
119            connections.insert(relay_id, connection);
120        }
121
122        tracing::info!(
123            "Added relay connection: {}",
124            hex::encode(relay_id.as_bytes())
125        );
126
127        // Start catch-up task if requested
128        if should_catch_up {
129            self.start_catch_up_task(relay_id, manager).await?;
130        }
131
132        Ok(())
133    }
134
135    /// Remove a relay connection
136    pub async fn remove_relay(&self, relay_id: &KeyId) -> Option<Arc<MessagesManager>> {
137        // Cancel any active catch-up task for this relay
138        {
139            let mut tasks = self.catch_up_tasks.write().await;
140            if let Some(task) = tasks.remove(relay_id) {
141                task.abort();
142                tracing::debug!(
143                    "Cancelled catch-up task for relay: {}",
144                    hex::encode(relay_id.as_bytes())
145                );
146            }
147        }
148
149        let mut connections = self.relay_connections.write().await;
150        let removed = connections.remove(relay_id);
151
152        if let Some(connection) = &removed {
153            tracing::info!(
154                "Removed relay connection: {}",
155                hex::encode(relay_id.as_bytes())
156            );
157            Some(Arc::clone(&connection.manager))
158        } else {
159            None
160        }
161    }
162
163    /// Get list of currently connected relay IDs
164    pub async fn get_connected_relay_ids(&self) -> Vec<KeyId> {
165        let connections = self.relay_connections.read().await;
166        connections
167            .iter()
168            .filter(|(_, conn)| matches!(conn.connection_state, ConnectionState::Connected))
169            .map(|(id, _)| *id)
170            .collect()
171    }
172
173    /// Get list of all relay IDs (connected and disconnected)
174    pub async fn get_all_relay_ids(&self) -> Vec<KeyId> {
175        let connections = self.relay_connections.read().await;
176        connections.keys().copied().collect()
177    }
178
179    /// Check if any relays are currently connected
180    pub async fn has_connected_relays(&self) -> bool {
181        !self.get_connected_relay_ids().await.is_empty()
182    }
183
184    /// Start a catch-up task for a specific relay
185    async fn start_catch_up_task(
186        &self,
187        relay_id: KeyId,
188        manager: Arc<MessagesManager>,
189    ) -> Result<()> {
190        let storage = Arc::clone(&self.storage);
191        let tasks = Arc::clone(&self.catch_up_tasks);
192
193        let task = tokio::spawn(async move {
194            tracing::info!(
195                "Starting catch-up task for relay: {}",
196                hex::encode(relay_id.as_bytes())
197            );
198
199            match Self::process_all_unsynced_messages_for_relay::<MessagesManager>(
200                &storage,
201                &relay_id,
202                &manager,
203                CATCH_UP_BATCH_SIZE,
204            )
205            .await
206            {
207                Ok(total_processed) => {
208                    tracing::info!(
209                        "Catch-up completed for relay: {} ({} messages processed)",
210                        hex::encode(relay_id.as_bytes()),
211                        total_processed
212                    );
213                }
214                Err(e) => {
215                    tracing::warn!(
216                        "Catch-up failed for relay {}: {}",
217                        hex::encode(relay_id.as_bytes()),
218                        e
219                    );
220                }
221            }
222
223            // Remove this task from the active tasks map when done
224            let mut tasks_guard = tasks.write().await;
225            tasks_guard.remove(&relay_id);
226        });
227
228        // Store the task handle
229        {
230            let mut tasks_guard = self.catch_up_tasks.write().await;
231            tasks_guard.insert(relay_id, task);
232        }
233
234        Ok(())
235    }
236
237    /// Process all unsynced messages for a specific relay in batches
238    /// Returns the total number of messages processed
239    async fn process_all_unsynced_messages_for_relay<M: MessagesManagerTrait>(
240        storage: &Arc<S>,
241        relay_id: &KeyId,
242        manager: &Arc<M>,
243        batch_size: usize,
244    ) -> Result<usize> {
245        let mut total_processed = 0;
246        let mut batch_count = 0;
247
248        loop {
249            tracing::debug!(
250                "Processing batch {} for relay {} (batch size: {})",
251                batch_count,
252                hex::encode(relay_id.as_bytes()),
253                batch_size
254            );
255
256            let batch_result = Self::process_unsynced_messages_batch_for_relay(
257                storage, relay_id, manager, batch_size,
258            )
259            .await?;
260
261            let Some(batch_processed) = batch_result else {
262                // No more messages to process - we're done
263                tracing::debug!(
264                    "No unsynced messages left for relay {}, stopping catch-up after {} batches",
265                    hex::encode(relay_id.as_bytes()),
266                    batch_count
267                );
268                break;
269            };
270
271            total_processed += batch_processed;
272            batch_count += 1;
273
274            tracing::debug!(
275                "Batch {} complete: {} messages processed for relay {}",
276                batch_count,
277                batch_processed,
278                hex::encode(relay_id.as_bytes())
279            );
280        }
281
282        Ok(total_processed)
283    }
284
285    /// Process a single batch of unsynced messages for a specific relay
286    /// Returns:
287    /// - Ok(Some(count)) if messages were processed (count = number successfully processed)
288    /// - Ok(None) if no unsynced messages were found (indicates completion)
289    /// - Err(_) if there was an error
290    async fn process_unsynced_messages_batch_for_relay<M: MessagesManagerTrait>(
291        storage: &Arc<S>,
292        relay_id: &KeyId,
293        manager: &Arc<M>,
294        batch_size: usize,
295    ) -> Result<Option<usize>> {
296        // Convert KeyId to Hash for storage API
297        let relay_key_id = relay_id;
298
299        // Get unsynced messages for this relay
300        let unsynced_messages = storage
301            .get_unsynced_messages_for_relay(relay_key_id, Some(batch_size))
302            .await
303            .map_err(|e| ClientError::Generic(format!("Failed to get unsynced messages: {e}")))?;
304
305        if unsynced_messages.is_empty() {
306            return Ok(None); // No messages to process - indicates completion
307        }
308
309        let batch_message_count = unsynced_messages.len();
310        tracing::debug!(
311            "Processing batch of {} unsynced messages for relay {}",
312            batch_message_count,
313            hex::encode(relay_id.as_bytes())
314        );
315
316        // Filter out expired messages and delete them immediately from storage
317        let mut valid_messages = Vec::new();
318        let mut expired_count = 0;
319
320        let current_time = std::time::SystemTime::now()
321            .duration_since(std::time::UNIX_EPOCH)
322            .unwrap_or_default()
323            .as_secs();
324
325        for message in unsynced_messages {
326            if message.is_expired(current_time) {
327                // Delete expired message immediately
328                if let Err(e) = storage.delete_message(message.id()).await {
329                    tracing::error!(
330                        "Failed to delete expired message {} from storage: {}. Continuing with other messages.",
331                        hex::encode(message.id().as_bytes()),
332                        e
333                    );
334                    // Continue processing other messages even if this deletion failed
335                }
336                expired_count += 1;
337                continue;
338            } else {
339                valid_messages.push(message);
340            }
341        }
342
343        // If all messages were expired, return Some(0) to indicate we processed a batch
344        // but didn't sync any messages (they were expired and removed)
345        if valid_messages.is_empty() {
346            tracing::debug!(
347                "All {} messages in batch were expired and removed for relay {}",
348                expired_count,
349                hex::encode(relay_id.as_bytes())
350            );
351            return Ok(Some(0));
352        }
353
354        // Now check which valid messages already exist on the server
355        let message_ids: Vec<_> = valid_messages.iter().map(|msg| *msg.id()).collect();
356
357        tracing::debug!(
358            "Checking existence of {} valid messages on relay {} ({} expired messages removed)",
359            message_ids.len(),
360            hex::encode(relay_id.as_bytes()),
361            expired_count
362        );
363
364        let existence_results = manager.check_messages(message_ids).await?;
365        let mut processed_count = 0;
366        // Process messages based on existence check results
367        for (message, existence_result) in valid_messages.iter().zip(existence_results.iter()) {
368            // Ensure message is stored locally first (for both existing and new messages)
369            storage.store_message(message).await?;
370
371            if let Some(global_stream_id) = existence_result {
372                // Message already exists on server, just mark as synced
373                if let Err(e) = storage
374                    .mark_message_synced(message.id(), relay_key_id, global_stream_id)
375                    .await
376                {
377                    tracing::error!(
378                        "Failed to mark existing message {} as synced to relay {}: {}",
379                        hex::encode(message.id().as_bytes()),
380                        hex::encode(relay_id.as_bytes()),
381                        e
382                    );
383                    continue;
384                }
385
386                tracing::debug!(
387                    "Message {} already exists on relay {}, marked as synced",
388                    hex::encode(message.id().as_bytes()),
389                    hex::encode(relay_id.as_bytes())
390                );
391                processed_count += 1;
392                continue;
393            }
394
395            // Message doesn't exist, need to send it
396            let global_stream_id = match manager.publish(message.clone()).await {
397                Ok(result) => match result {
398                    PublishResult::StoredNew { global_stream_id } => global_stream_id,
399                    PublishResult::AlreadyExists { global_stream_id } => global_stream_id,
400                    PublishResult::Expired => {
401                        tracing::warn!("Message expired: {}", hex::encode(message.id().as_bytes()));
402                        continue;
403                    }
404                },
405                Err(e) => {
406                    tracing::error!(
407                        "Failed to send message {} to relay {}: {}",
408                        hex::encode(message.id().as_bytes()),
409                        hex::encode(relay_id.as_bytes()),
410                        e
411                    );
412                    continue;
413                }
414            };
415
416            // Mark as synced in storage
417            if let Err(e) = storage
418                .mark_message_synced(message.id(), relay_key_id, &global_stream_id)
419                .await
420            {
421                tracing::error!(
422                    "Failed to mark message {} as synced to relay {}: {}",
423                    hex::encode(message.id().as_bytes()),
424                    hex::encode(relay_id.as_bytes()),
425                    e
426                );
427                continue;
428            }
429            processed_count += 1;
430        }
431
432        tracing::debug!(
433            "Batch complete: {}/{} valid messages successfully processed for relay {} ({} expired messages removed)",
434            processed_count,
435            valid_messages.len(),
436            hex::encode(relay_id.as_bytes()),
437            expired_count
438        );
439
440        Ok(Some(processed_count))
441    }
442}
443
444#[async_trait]
445impl<S: MessageStorage + 'static> MessagesManagerTrait for MultiRelayMessageManager<S> {
    /// Get a stream of all message events from all relays.
    ///
    /// Each call returns an independent receiver on the shared broadcast
    /// channel; events are delivered to every receiver.
    fn message_events_stream(&self) -> Receiver<MessageEvent> {
        self.global_events_tx.new_receiver()
    }
450
    /// Subscribe to subscription state changes (aggregated from all relays).
    ///
    /// NOTE: placeholder implementation — every call creates a *fresh*
    /// observable seeded with `SubscriptionState::default()`; it never
    /// reflects actual relay state and never emits updates. A full
    /// implementation would aggregate subscription states from all relays.
    async fn get_subscription_state_updates(
        &self,
    ) -> eyeball::Subscriber<SubscriptionState, AsyncLock> {
        // For now, return a default state. In a full implementation, we'd aggregate
        // subscription states from all relays
        let state = SubscriptionState::default();
        let observable = SharedObservable::new_async(state);
        observable.subscribe().await
    }
461
462    /// Subscribe to messages on all connected relays
463    async fn subscribe(&self) -> Result<()> {
464        let connections = self.relay_connections.read().await;
465        let mut results = Vec::new();
466
467        for (relay_id, connection) in connections.iter() {
468            if matches!(connection.connection_state, ConnectionState::Connected) {
469                match connection.manager.subscribe().await {
470                    Ok(()) => {
471                        tracing::debug!("Subscribed to relay {}", hex::encode(relay_id.as_bytes()));
472                    }
473                    Err(e) => {
474                        tracing::warn!(
475                            "Failed to subscribe to relay {}: {}",
476                            hex::encode(relay_id.as_bytes()),
477                            e
478                        );
479                        results.push(e);
480                    }
481                }
482            }
483        }
484
485        // Return error if all subscriptions failed
486        if !results.is_empty() && results.len() == connections.len() {
487            return Err(results.into_iter().next().unwrap());
488        }
489
490        Ok(())
491    }
492
    /// Publish a message to available relays or queue for offline delivery.
    ///
    /// Flow:
    /// 1. If no relay is connected, the message is persisted and a sentinel
    ///    `StoredNew { global_stream_id: "queued_offline" }` is returned.
    /// 2. Otherwise the message is persisted locally, then published to the
    ///    first connected relay (no load balancing yet).
    /// 3. On success the message is marked synced to that relay; on failure
    ///    it is stored for later retry and the publish error is returned.
    async fn publish(&self, message: MessageFull) -> Result<PublishResult> {
        // Snapshot connected relays so the read lock is not held across awaits.
        let connected_relays = {
            let connections = self.relay_connections.read().await;
            connections
                .iter()
                .filter(|(_, conn)| matches!(conn.connection_state, ConnectionState::Connected))
                .map(|(id, conn)| (*id, Arc::clone(&conn.manager)))
                .collect::<Vec<_>>()
        };

        if connected_relays.is_empty() {
            // No relays available, store message for offline processing
            self.storage.store_message(&message).await.map_err(|e| {
                ClientError::Generic(format!("Failed to store message for offline delivery: {e}"))
            })?;

            tracing::info!(
                "No relays available, stored message {} for offline processing",
                hex::encode(message.id().as_bytes())
            );

            // Sentinel stream id signalling "queued locally, not yet on any relay".
            return Ok(PublishResult::StoredNew {
                global_stream_id: "queued_offline".to_string(),
            });
        }

        // Try to send to the first available relay
        // In a more sophisticated implementation, we could implement load balancing
        let (relay_id, manager) = &connected_relays[0];

        // Store message locally first to ensure it exists before any sync operations
        self.storage.store_message(&message).await.map_err(|e| {
            ClientError::Generic(format!(
                "Failed to store message {} locally: {}",
                hex::encode(message.id().as_bytes()),
                e
            ))
        })?;

        match manager.publish(message.clone()).await {
            Ok(result) => {
                // Mark as synced to this relay
                let relay_key_id = relay_id;
                let global_stream_id = match &result {
                    PublishResult::StoredNew { global_stream_id } => global_stream_id,
                    PublishResult::AlreadyExists { global_stream_id } => global_stream_id,
                    // Expired messages have no stream id; nothing to mark.
                    PublishResult::Expired => {
                        tracing::warn!("Message expired during publish");
                        return Ok(result);
                    }
                };

                self.storage
                    .mark_message_synced(message.id(), relay_key_id, global_stream_id)
                    .await
                    .map_err(|e| {
                        ClientError::Generic(format!(
                            "Failed to mark message {} as synced to relay {}: {}",
                            hex::encode(message.id().as_bytes()),
                            hex::encode(relay_id.as_bytes()),
                            e
                        ))
                    })?;

                tracing::debug!(
                    "Successfully published message to relay {}",
                    hex::encode(relay_id.as_bytes())
                );
                Ok(result)
            }
            Err(e) => {
                tracing::warn!(
                    "Failed to publish message to relay {}: {}",
                    hex::encode(relay_id.as_bytes()),
                    e
                );

                // Store message for offline processing - background task will retry
                // NOTE(review): the message was already stored above before the
                // publish attempt, so this second store looks redundant unless
                // `store_message` is non-idempotent — confirm against the
                // MessageStorage contract.
                self.storage
                    .store_message(&message)
                    .await
                    .map_err(|storage_err| {
                        ClientError::Generic(format!(
                            "Failed to store message for offline delivery: {storage_err}"
                        ))
                    })?;

                tracing::info!(
                    "Stored message {} for offline processing after publish failure",
                    hex::encode(message.id().as_bytes())
                );

                Err(e)
            }
        }
    }
590
591    /// Ensure a filter is included in the subscription on all connected relays
592    async fn ensure_contains_filter(&self, filter: Filter) -> Result<()> {
593        let connections = self.relay_connections.read().await;
594        let mut results = Vec::new();
595
596        for (relay_id, connection) in connections.iter() {
597            if matches!(connection.connection_state, ConnectionState::Connected) {
598                match connection
599                    .manager
600                    .ensure_contains_filter(filter.clone())
601                    .await
602                {
603                    Ok(()) => {
604                        tracing::debug!(
605                            "Added filter to relay {}",
606                            hex::encode(relay_id.as_bytes())
607                        );
608                    }
609                    Err(e) => {
610                        tracing::warn!(
611                            "Failed to add filter to relay {}: {}",
612                            hex::encode(relay_id.as_bytes()),
613                            e
614                        );
615                        results.push(e);
616                    }
617                }
618            }
619        }
620
621        // Return error if all filter additions failed
622        if !results.is_empty() && results.len() == connections.len() {
623            return Err(results.into_iter().next().unwrap());
624        }
625
626        Ok(())
627    }
628
    /// Get a stream of incoming messages from all relays.
    ///
    /// Each call returns an independent receiver on the shared broadcast channel.
    fn messages_stream(&self) -> Receiver<StreamMessage> {
        self.global_messages_tx.new_receiver()
    }
633
    /// Get a stream of catch-up responses from all relays.
    ///
    /// Each call returns an independent receiver on the shared broadcast channel.
    fn catch_up_stream(&self) -> Receiver<CatchUpResponse> {
        self.global_catchup_tx.new_receiver()
    }
638
    /// Get a filtered stream of messages matching the given filter from all relays.
    ///
    /// NOTE: placeholder — the filter is currently ignored and the returned
    /// stream yields nothing and terminates immediately.
    fn filtered_messages_stream(
        &self,
        _filter: Filter,
    ) -> std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>> {
        // TODO: Implement proper filtering
        // For now, return an empty stream
        Box::pin(futures::stream::empty())
    }
648
    /// Catch up to historical messages and subscribe to new ones for a filter.
    ///
    /// NOTE: placeholder — both arguments are currently ignored and the
    /// returned stream yields nothing and terminates immediately.
    async fn catch_up_and_subscribe(
        &self,
        _filter: Filter,
        _since: Option<String>,
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>>> {
        // TODO: Implement proper catch-up across all relays
        // For now, return an empty stream as a placeholder
        Ok(Box::pin(futures::stream::empty()))
    }
659
660    /// Get user data by author and storage key from any available relay
661    async fn user_data(&self, author: KeyId, storage_key: StoreKey) -> Result<Option<MessageFull>> {
662        let connections = self.relay_connections.read().await;
663
664        // Try each connected relay until we find the data
665        for (relay_id, connection) in connections.iter() {
666            if matches!(connection.connection_state, ConnectionState::Connected) {
667                match connection
668                    .manager
669                    .user_data(author, storage_key.clone())
670                    .await
671                {
672                    Ok(Some(data)) => {
673                        tracing::debug!(
674                            "Found user data on relay {}",
675                            hex::encode(relay_id.as_bytes())
676                        );
677                        return Ok(Some(data));
678                    }
679                    Ok(None) => {
680                        // Not found on this relay, try next
681                        continue;
682                    }
683                    Err(e) => {
684                        tracing::warn!(
685                            "Failed to get user data from relay {}: {}",
686                            hex::encode(relay_id.as_bytes()),
687                            e
688                        );
689                        continue;
690                    }
691                }
692            }
693        }
694
695        Ok(None) // Not found on any relay
696    }
697
698    /// Check which messages exist on any connected relay
699    /// Returns the first successful result from any relay
700    async fn check_messages(
701        &self,
702        message_ids: Vec<zoe_wire_protocol::MessageId>,
703    ) -> Result<Vec<Option<String>>> {
704        let connections = self.relay_connections.read().await;
705
706        // Try each connected relay until we get a successful response
707        for (relay_id, connection) in connections.iter() {
708            if matches!(connection.connection_state, ConnectionState::Connected) {
709                match connection.manager.check_messages(message_ids.clone()).await {
710                    Ok(result) => {
711                        tracing::debug!(
712                            "Successfully checked {} messages on relay {}",
713                            message_ids.len(),
714                            hex::encode(relay_id.as_bytes())
715                        );
716                        return Ok(result);
717                    }
718                    Err(e) => {
719                        tracing::warn!(
720                            "Failed to check messages on relay {}: {}",
721                            hex::encode(relay_id.as_bytes()),
722                            e
723                        );
724                        continue;
725                    }
726                }
727            }
728        }
729
730        // If no relay is available, return None for all messages (assume they don't exist)
731        tracing::warn!("No connected relays available for checking messages");
732        Ok(vec![None; message_ids.len()])
733    }
734}
735
736impl<S: MessageStorage> Drop for MultiRelayMessageManager<S> {
737    fn drop(&mut self) {
738        // Cancel all active catch-up tasks
739        if let Ok(mut tasks) = self.catch_up_tasks.try_write() {
740            for (relay_id, task) in tasks.iter() {
741                task.abort();
742                tracing::debug!(
743                    "Cancelled catch-up task for relay during drop: {}",
744                    hex::encode(relay_id.as_bytes())
745                );
746            }
747            tasks.clear();
748        }
749    }
750}
751
752#[cfg(test)]
753mod tests {
754    use super::*;
755    use zoe_client_storage::storage::MockMessageStorage;
756    use zoe_wire_protocol::{Content, KeyPair, Kind, Message, MessageFull};
757
    /// Build a signed `MessageFull` with a fresh random keypair for tests.
    fn create_test_message() -> MessageFull {
        let mut rng = rand::thread_rng();
        let keypair = KeyPair::generate(&mut rng);
        let message = Message::new_v0(
            Content::raw(b"Test message".to_vec()),
            keypair.public_key(),
            1234567890u64, // Fixed timestamp for testing
            Kind::Regular,
            vec![], // no tags
        );
        MessageFull::new(message, &keypair).unwrap()
    }
770
771    #[tokio::test]
772    async fn test_multi_relay_manager_creation() {
773        let mock_storage = MockMessageStorage::new();
774        let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
775
776        // Should start with no relays
777        assert_eq!(manager.get_all_relay_ids().await.len(), 0);
778        assert!(!manager.has_connected_relays().await);
779    }
780
    #[tokio::test]
    async fn test_offline_message_queuing() {
        // Expect exactly one local store when publishing with zero relays.
        let mut mock_storage = MockMessageStorage::new();
        mock_storage
            .expect_store_message()
            .times(1)
            .returning(|_| Ok(()));

        let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
        let test_message = create_test_message();

        // Publishing with no relays should queue the message offline
        let result = manager.publish(test_message).await;
        assert!(result.is_ok());

        // The offline path reports the sentinel "queued_offline" stream id.
        if let Ok(publish_result) = result {
            assert_eq!(
                publish_result,
                PublishResult::StoredNew {
                    global_stream_id: "queued_offline".to_string()
                }
            );
        }
    }
805
806    #[tokio::test]
807    async fn test_basic_functionality() {
808        let mock_storage = MockMessageStorage::new();
809        let manager = MultiRelayMessageManager::new(Arc::new(mock_storage));
810
811        // Should start with no relays
812        assert_eq!(manager.get_all_relay_ids().await.len(), 0);
813        assert!(!manager.has_connected_relays().await);
814
815        // Test getting connected relay IDs when none are connected
816        let connected_ids = manager.get_connected_relay_ids().await;
817        assert!(connected_ids.is_empty());
818    }
819
    /// Compatibility test: `SessionManager` must accept a
    /// `MultiRelayMessageManager` as its messages-manager dependency.
    /// Building the session manager and touching its event stream is the
    /// whole assertion — much of the value here is that this compiles.
    #[tokio::test]
    async fn test_session_manager_compatibility() {
        use crate::pqxdh::PqxdhProtocolState;
        use crate::session_manager::SessionManager;
        use zoe_client_storage::storage::MockStateStorage;
        use zoe_wire_protocol::KeyPair;

        let mock_message_storage = MockMessageStorage::new();
        let multi_relay_manager = Arc::new(MultiRelayMessageManager::new(Arc::new(
            mock_message_storage,
        )));

        let mut mock_state_storage = MockStateStorage::new();
        // Mock the expected calls for loading PQXDH states
        // (builder loads persisted protocol state at startup; empty is fine)
        mock_state_storage
            .expect_list_namespace_data::<PqxdhProtocolState>()
            .returning(|_| Ok(vec![]));
        // Mock the expected calls for group manager initialization
        mock_state_storage
            .expect_list_namespace_data::<zoe_state_machine::GroupSession>()
            .returning(|_| Ok(vec![]));

        let mock_state_storage = Arc::new(mock_state_storage);

        // Create a test keypair
        let mut rng = rand::thread_rng();
        let keypair = Arc::new(KeyPair::generate(&mut rng));

        // Test that SessionManager can be created with MultiRelayMessageManager
        let session_manager_result =
            SessionManager::builder(mock_state_storage, multi_relay_manager)
                .client_keypair(keypair)
                .build()
                .await;

        // This should compile and work without issues
        assert!(session_manager_result.is_ok());

        let session_manager = session_manager_result.unwrap();

        // Verify we can access the multi-relay manager through the session manager
        let messages_manager = session_manager.messages_manager();

        // Test that we can use the session manager's message manager interface
        let _events_stream = messages_manager.message_events_stream();
        // Stream creation should succeed (we can't easily test if it's active without subscribing)
    }
867
868    #[tokio::test]
869    async fn test_pqxdh_handler_compatibility() {
870        use crate::pqxdh::PqxdhProtocolHandler;
871        use zoe_wire_protocol::{KeyPair, PqxdhInboxProtocol};
872
873        let mock_message_storage = MockMessageStorage::new();
874        let multi_relay_manager = Arc::new(MultiRelayMessageManager::new(Arc::new(
875            mock_message_storage,
876        )));
877
878        // Create a test keypair
879        let mut rng = rand::thread_rng();
880        let keypair = Arc::new(KeyPair::generate(&mut rng));
881
882        // Test that PqxdhProtocolHandler can be created with MultiRelayMessageManager
883        let _pqxdh_handler = PqxdhProtocolHandler::new(
884            multi_relay_manager,
885            keypair,
886            PqxdhInboxProtocol::EchoService,
887        );
888
889        // Verify the handler was created successfully by checking it can be used
890        // We can't easily test internal state without more complex setup,
891        // but the fact that it compiles and creates successfully is the main test
892    }
893
    /// Single-batch catch-up: one unsynced message that the relay does not
    /// yet hold must be published, marked synced, and stored locally.
    #[tokio::test]
    async fn test_catch_up_processing() {
        use crate::services::messages_manager::MockMessagesManagerTrait;

        // Create a message that will be "unsynced" for a relay
        let test_message = create_test_message();
        let relay_id = KeyId::from([1u8; 32]);
        let relay_key_id = relay_id;

        // Mock storage to return the unsynced message
        // (the Some(50) limit must match the batch size passed below)
        let mut mock_storage = MockMessageStorage::new();
        mock_storage
            .expect_get_unsynced_messages_for_relay()
            .with(
                mockall::predicate::eq(relay_key_id),
                mockall::predicate::eq(Some(50)),
            )
            .times(1)
            .returning(move |_, _| Ok(vec![test_message.clone()]));

        // Mock successful message sync marking
        mock_storage
            .expect_mark_message_synced()
            .times(1)
            .returning(|_, _, _| Ok(()));

        // Mock store_message call (new requirement from our fix)
        mock_storage
            .expect_store_message()
            .times(1)
            .returning(|_| Ok(()));

        let storage = Arc::new(mock_storage);

        // Create a mock messages manager that will succeed
        let mut mock_manager = MockMessagesManagerTrait::new();

        // Mock check_messages to return that the message doesn't exist (None)
        mock_manager
            .expect_check_messages()
            .times(1)
            .returning(|_| Ok(vec![None])); // Message doesn't exist, needs to be sent

        // The missing message must therefore be published exactly once
        mock_manager.expect_publish().times(1).returning(|_| {
            Ok(PublishResult::StoredNew {
                global_stream_id: "test_stream_123".to_string(),
            })
        });

        let mock_manager = Arc::new(mock_manager);

        // Test the catch-up processing function directly
        let result = MultiRelayMessageManager::process_unsynced_messages_batch_for_relay(
            &storage,
            &relay_id,
            &mock_manager,
            50,
        )
        .await;

        assert!(result.is_ok(), "Catch-up processing should succeed");

        // Verify that we processed exactly 1 message
        // (Some(1): the batch was non-empty and one message was handled)
        assert_eq!(
            result.unwrap(),
            Some(1),
            "Should have processed exactly 1 message"
        );
    }
963
964    #[tokio::test]
965    async fn test_batched_catch_up_processing() {
966        use crate::services::messages_manager::MockMessagesManagerTrait;
967
968        // Create multiple test messages
969        let test_messages: Vec<MessageFull> = (0u64..5u64)
970            .map(|i| {
971                let mut rng = rand::thread_rng();
972                let keypair = KeyPair::generate(&mut rng);
973                let message = Message::new_v0(
974                    Content::raw(format!("Test message {i}").as_bytes().to_vec()),
975                    keypair.public_key(),
976                    1234567890u64 + i,
977                    Kind::Regular,
978                    vec![],
979                );
980                MessageFull::new(message, &keypair).unwrap()
981            })
982            .collect();
983
984        let relay_id = KeyId::from([1u8; 32]);
985        let relay_key_id = relay_id;
986
987        // Mock storage to return messages in batches
988        let mut mock_storage = MockMessageStorage::new();
989
990        // First call returns 2 messages (batch size 2)
991        let messages_batch_1 = vec![test_messages[0].clone(), test_messages[1].clone()];
992        mock_storage
993            .expect_get_unsynced_messages_for_relay()
994            .with(
995                mockall::predicate::eq(relay_key_id),
996                mockall::predicate::eq(Some(2)),
997            )
998            .times(1)
999            .returning(move |_, _| Ok(messages_batch_1.clone()));
1000
1001        // Second call returns 2 more messages
1002        let messages_batch_2 = vec![test_messages[2].clone(), test_messages[3].clone()];
1003        mock_storage
1004            .expect_get_unsynced_messages_for_relay()
1005            .with(
1006                mockall::predicate::eq(relay_key_id),
1007                mockall::predicate::eq(Some(2)),
1008            )
1009            .times(1)
1010            .returning(move |_, _| Ok(messages_batch_2.clone()));
1011
1012        // Third call returns 1 message
1013        let messages_batch_3 = vec![test_messages[4].clone()];
1014        mock_storage
1015            .expect_get_unsynced_messages_for_relay()
1016            .with(
1017                mockall::predicate::eq(relay_key_id),
1018                mockall::predicate::eq(Some(2)),
1019            )
1020            .times(1)
1021            .returning(move |_, _| Ok(messages_batch_3.clone()));
1022
1023        // Fourth call returns 0 messages (indicating we're done)
1024        mock_storage
1025            .expect_get_unsynced_messages_for_relay()
1026            .with(
1027                mockall::predicate::eq(relay_key_id),
1028                mockall::predicate::eq(Some(2)),
1029            )
1030            .times(1)
1031            .returning(move |_, _| Ok(vec![]));
1032
1033        // Mock successful message sync marking for all 5 messages
1034        mock_storage
1035            .expect_mark_message_synced()
1036            .times(5)
1037            .returning(|_, _, _| Ok(()));
1038
1039        // Mock store_message calls for all 5 messages (new requirement from our fix)
1040        mock_storage
1041            .expect_store_message()
1042            .times(5)
1043            .returning(|_| Ok(()));
1044
1045        let storage = Arc::new(mock_storage);
1046
1047        // Create a mock messages manager that will succeed for all messages
1048        let mut mock_manager = MockMessagesManagerTrait::new();
1049
1050        // Mock check_messages for each batch:
1051        // Batch 1: 2 messages, both don't exist (need to send)
1052        mock_manager
1053            .expect_check_messages()
1054            .times(1)
1055            .returning(|_| Ok(vec![None, None]));
1056
1057        // Batch 2: 2 messages, both don't exist (need to send)
1058        mock_manager
1059            .expect_check_messages()
1060            .times(1)
1061            .returning(|_| Ok(vec![None, None]));
1062
1063        // Batch 3: 1 message, doesn't exist (need to send)
1064        mock_manager
1065            .expect_check_messages()
1066            .times(1)
1067            .returning(|_| Ok(vec![None]));
1068
1069        mock_manager.expect_publish().times(5).returning(|_| {
1070            Ok(PublishResult::StoredNew {
1071                global_stream_id: "test_stream_123".to_string(),
1072            })
1073        });
1074
1075        let mock_manager = Arc::new(mock_manager);
1076
1077        // Test the full batched catch-up processing
1078        let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
1079            &storage,
1080            &relay_id,
1081            &mock_manager,
1082            2, // batch size of 2
1083        )
1084        .await;
1085
1086        assert!(result.is_ok(), "Batched catch-up processing should succeed");
1087
1088        // Verify that we processed all 5 messages across 4 batches (2+2+1+0)
1089        assert_eq!(
1090            result.unwrap(),
1091            5,
1092            "Should have processed exactly 5 messages total"
1093        );
1094    }
1095
1096    #[tokio::test]
1097    async fn test_efficient_catch_up_with_existing_messages() {
1098        use crate::services::messages_manager::MockMessagesManagerTrait;
1099
1100        // Create 3 test messages
1101        let test_messages: Vec<MessageFull> = (0u64..3u64)
1102            .map(|i| {
1103                let mut rng = rand::thread_rng();
1104                let keypair = KeyPair::generate(&mut rng);
1105                let message = Message::new_v0(
1106                    Content::raw(format!("Test message {i}").as_bytes().to_vec()),
1107                    keypair.public_key(),
1108                    1234567890u64 + i,
1109                    Kind::Regular,
1110                    vec![],
1111                );
1112                MessageFull::new(message, &keypair).unwrap()
1113            })
1114            .collect();
1115
1116        let relay_id = KeyId::from([1u8; 32]);
1117        let relay_key_id = relay_id;
1118
1119        // Extract message IDs before creating closures
1120        let message_ids = vec![
1121            *test_messages[0].id(),
1122            *test_messages[1].id(),
1123            *test_messages[2].id(),
1124        ];
1125
1126        // Mock storage to return 3 unsynced messages
1127        let mut mock_storage = MockMessageStorage::new();
1128        mock_storage
1129            .expect_get_unsynced_messages_for_relay()
1130            .with(
1131                mockall::predicate::eq(relay_key_id),
1132                mockall::predicate::eq(Some(3)),
1133            )
1134            .times(1)
1135            .returning(move |_, _| Ok(test_messages.clone()));
1136
1137        // Second call returns empty (indicating we're done)
1138        mock_storage
1139            .expect_get_unsynced_messages_for_relay()
1140            .with(
1141                mockall::predicate::eq(relay_key_id),
1142                mockall::predicate::eq(Some(3)),
1143            )
1144            .times(1)
1145            .returning(move |_, _| Ok(vec![]));
1146
1147        // Mock successful message sync marking for all 3 messages
1148        mock_storage
1149            .expect_mark_message_synced()
1150            .times(3)
1151            .returning(|_, _, _| Ok(()));
1152
1153        // Mock store_message calls for all 3 messages (new requirement from our fix)
1154        mock_storage
1155            .expect_store_message()
1156            .times(3)
1157            .returning(|_| Ok(()));
1158
1159        let storage = Arc::new(mock_storage);
1160
1161        // Create a mock messages manager
1162        let mut mock_manager = MockMessagesManagerTrait::new();
1163
1164        // Mock check_messages to return:
1165        // - Message 0: already exists (Some("existing_stream_1"))
1166        // - Message 1: doesn't exist (None)
1167        // - Message 2: already exists (Some("existing_stream_2"))
1168        mock_manager
1169            .expect_check_messages()
1170            .with(mockall::predicate::eq(message_ids))
1171            .times(1)
1172            .returning(|_| {
1173                Ok(vec![
1174                    Some("existing_stream_1".to_string()), // Message 0 exists
1175                    None,                                  // Message 1 doesn't exist
1176                    Some("existing_stream_2".to_string()), // Message 2 exists
1177                ])
1178            });
1179
1180        // Mock publish to be called only once (for message 1 that doesn't exist)
1181        mock_manager.expect_publish().times(1).returning(|_| {
1182            Ok(PublishResult::StoredNew {
1183                global_stream_id: "new_stream_123".to_string(),
1184            })
1185        });
1186
1187        let mock_manager = Arc::new(mock_manager);
1188
1189        // Test the efficient batch processing
1190        let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
1191            &storage,
1192            &relay_id,
1193            &mock_manager,
1194            3, // batch size of 3
1195        )
1196        .await;
1197
1198        assert!(
1199            result.is_ok(),
1200            "Efficient catch-up processing should succeed"
1201        );
1202
1203        // Verify that we processed all 3 messages:
1204        // - 2 were marked as synced without sending (already existed)
1205        // - 1 was sent and then marked as synced
1206        assert_eq!(
1207            result.unwrap(),
1208            3,
1209            "Should have processed exactly 3 messages total"
1210        );
1211    }
1212
    /// Mixed batch: an expired ephemeral message must be deleted from storage
    /// and never sent to the relay, while the still-valid message goes
    /// through the normal check / publish / mark-synced path.
    #[tokio::test]
    async fn test_expired_message_handling() {
        use crate::services::messages_manager::MockMessagesManagerTrait;
        use zoe_wire_protocol::Kind;

        // Create test messages - one expired, one valid
        let mut rng = rand::thread_rng();
        let keypair = KeyPair::generate(&mut rng);

        // Create an expired ephemeral message (timeout of 1 second, created 2 seconds ago)
        let expired_message = {
            let past_timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
                .saturating_sub(2); // 2 seconds ago

            let message = Message::new_v0(
                Content::raw(b"Expired message".to_vec()),
                keypair.public_key(),
                past_timestamp,
                Kind::Ephemeral(1), // 1 second timeout - should be expired
                vec![],
            );
            MessageFull::new(message, &keypair).unwrap()
        };

        // Create a valid regular message
        let valid_message = {
            let message = Message::new_v0(
                Content::raw(b"Valid message".to_vec()),
                keypair.public_key(),
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
                Kind::Regular,
                vec![],
            );
            MessageFull::new(message, &keypair).unwrap()
        };

        let test_messages = vec![expired_message.clone(), valid_message.clone()];
        let relay_id = KeyId::from([1u8; 32]);
        let relay_key_id = relay_id;

        // Mock storage to return both messages initially
        let mut mock_storage = MockMessageStorage::new();
        mock_storage
            .expect_get_unsynced_messages_for_relay()
            .with(
                mockall::predicate::eq(relay_key_id),
                mockall::predicate::eq(Some(2)),
            )
            .times(1)
            .returning(move |_, _| Ok(test_messages.clone()));

        // Second call returns empty (indicating we're done)
        mock_storage
            .expect_get_unsynced_messages_for_relay()
            .with(
                mockall::predicate::eq(relay_key_id),
                mockall::predicate::eq(Some(2)),
            )
            .times(1)
            .returning(move |_, _| Ok(vec![]));

        // Mock deletion of expired message
        // (pinned to the expired message's id so deleting the valid one would fail the test)
        let expired_id = *expired_message.id();
        mock_storage
            .expect_delete_message()
            .with(mockall::predicate::eq(expired_id))
            .times(1)
            .returning(|_| Ok(true));

        // Mock successful message sync marking for the valid message only
        mock_storage
            .expect_mark_message_synced()
            .times(1)
            .returning(|_, _, _| Ok(()));

        // Mock store_message call for the valid message only (new requirement from our fix)
        mock_storage
            .expect_store_message()
            .times(1)
            .returning(|_| Ok(()));

        let storage = Arc::new(mock_storage);

        // Create a mock messages manager
        let mut mock_manager = MockMessagesManagerTrait::new();

        // Mock check_messages for the valid message only (expired message filtered out)
        let valid_message_id = *valid_message.id();
        mock_manager
            .expect_check_messages()
            .with(mockall::predicate::eq(vec![valid_message_id]))
            .times(1)
            .returning(|_| Ok(vec![None])); // Valid message doesn't exist, needs to be sent

        // Mock publish to be called only for the valid message
        mock_manager.expect_publish().times(1).returning(|_| {
            Ok(PublishResult::StoredNew {
                global_stream_id: "valid_stream_123".to_string(),
            })
        });

        let mock_manager = Arc::new(mock_manager);

        // Test the batch processing with expired messages
        let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
            &storage,
            &relay_id,
            &mock_manager,
            2, // batch size of 2
        )
        .await;

        assert!(
            result.is_ok(),
            "Processing with expired messages should succeed"
        );

        // Verify that we processed only 1 message (the valid one)
        // The expired message should have been removed from storage
        assert_eq!(
            result.unwrap(),
            1,
            "Should have processed exactly 1 valid message (expired message removed)"
        );
    }
1344
1345    #[tokio::test]
1346    async fn test_all_expired_messages_batch() {
1347        use crate::services::messages_manager::MockMessagesManagerTrait;
1348        use zoe_wire_protocol::Kind;
1349
1350        // Create test messages - all expired
1351        let mut rng = rand::thread_rng();
1352        let keypair = KeyPair::generate(&mut rng);
1353
1354        let expired_messages: Vec<MessageFull> = (0..3)
1355            .map(|i| {
1356                let past_timestamp = std::time::SystemTime::now()
1357                    .duration_since(std::time::UNIX_EPOCH)
1358                    .unwrap()
1359                    .as_secs()
1360                    .saturating_sub(10); // 10 seconds ago
1361
1362                let message = Message::new_v0(
1363                    Content::raw(format!("Expired message {i}").as_bytes().to_vec()),
1364                    keypair.public_key(),
1365                    past_timestamp,
1366                    Kind::Ephemeral(1), // 1 second timeout - all should be expired
1367                    vec![],
1368                );
1369                MessageFull::new(message, &keypair).unwrap()
1370            })
1371            .collect();
1372
1373        let relay_id = KeyId::from([1u8; 32]);
1374        let relay_key_id = relay_id;
1375
1376        // Mock storage to return all expired messages
1377        let mut mock_storage = MockMessageStorage::new();
1378        mock_storage
1379            .expect_get_unsynced_messages_for_relay()
1380            .with(
1381                mockall::predicate::eq(relay_key_id),
1382                mockall::predicate::eq(Some(3)),
1383            )
1384            .times(1)
1385            .returning(move |_, _| Ok(expired_messages.clone()));
1386
1387        // Second call returns empty (indicating we're done)
1388        mock_storage
1389            .expect_get_unsynced_messages_for_relay()
1390            .with(
1391                mockall::predicate::eq(relay_key_id),
1392                mockall::predicate::eq(Some(3)),
1393            )
1394            .times(1)
1395            .returning(move |_, _| Ok(vec![]));
1396
1397        // Mock deletion of all expired messages
1398        mock_storage
1399            .expect_delete_message()
1400            .times(3)
1401            .returning(|_| Ok(true));
1402
1403        let storage = Arc::new(mock_storage);
1404
1405        // Create a mock messages manager - no expectations for check_messages or publish
1406        // since all messages are expired and filtered out
1407        let mock_manager = MockMessagesManagerTrait::new();
1408        let mock_manager = Arc::new(mock_manager);
1409
1410        // Test the batch processing with all expired messages
1411        let result = MultiRelayMessageManager::process_all_unsynced_messages_for_relay(
1412            &storage,
1413            &relay_id,
1414            &mock_manager,
1415            3, // batch size of 3
1416        )
1417        .await;
1418
1419        assert!(
1420            result.is_ok(),
1421            "Processing all expired messages should succeed"
1422        );
1423
1424        // Verify that we processed 0 messages (all were expired and removed)
1425        assert_eq!(
1426            result.unwrap(),
1427            0,
1428            "Should have processed 0 messages (all were expired and removed)"
1429        );
1430    }
1431}