// zoe_client/services/messages_manager.rs

1use crate::error::{ClientError, Result};
2use async_broadcast::{Receiver, RecvError, Sender};
3use async_trait::async_trait;
4use eyeball::{AsyncLock, ObservableWriteGuard, SharedObservable};
5use futures::{Stream, StreamExt, pin_mut};
6use serde::{Deserialize, Serialize};
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::{select, task::JoinHandle};
10use tracing::warn;
11use zoe_client_storage::SubscriptionState;
12use zoe_wire_protocol::{
13    CatchUpRequest, CatchUpResponse, Filter, FilterOperation, FilterUpdateRequest, MessageFilters,
14    MessageFull, PublishResult, StreamMessage, SubscriptionConfig,
15};
16
17use super::messages::{CatchUpStream, MessagesService, MessagesStream};
18use async_stream::stream;
19use std::sync::atomic::AtomicU32;
20
21#[cfg(test)]
22use mockall::automock;
23
24/// Trait abstraction for MessagesManager to enable mocking in tests
25#[cfg_attr(test, automock)]
26#[async_trait]
27pub trait MessagesManagerTrait: Send + Sync {
28    /// Get a stream of all message events for persistence and monitoring
29    fn message_events_stream(&self) -> Receiver<MessageEvent>;
30
31    /// Subscribe to subscription state changes for reactive programming
32    ///
33    /// This returns an eyeball::Subscriber that can be used to observe changes to the
34    /// subscription state reactively. The subscriber will be notified whenever:
35    /// - Stream height is updated
36    /// - Filters are added or removed
37    /// - Any other subscription state changes occur
38    ///
39    /// # Example
40    /// ```rust,no_run
41    /// # use zoe_client::services::MessagesManagerTrait;
42    /// # async fn example(manager: &impl MessagesManagerTrait) {
43    /// let subscriber = manager.subscribe_to_subscription_state();
44    /// let current_state = subscriber.get();
45    /// println!("Current stream height: {:?}", current_state.latest_stream_height);
46    /// # }
47    /// ```
48    async fn get_subscription_state_updates(
49        &self,
50    ) -> eyeball::Subscriber<SubscriptionState, AsyncLock>;
51
52    /// Subscribe to messages with current filters
53    async fn subscribe(&self) -> Result<()>;
54
55    /// Publish a message
56    async fn publish(&self, message: MessageFull) -> Result<PublishResult>;
57
58    /// Ensure a filter is included in the subscription
59    async fn ensure_contains_filter(&self, filter: Filter) -> Result<()>;
60
61    /// Get a stream of incoming messages
62    fn messages_stream(&self) -> Receiver<StreamMessage>;
63
64    /// Get a stream of catch-up responses
65    fn catch_up_stream(&self) -> Receiver<CatchUpResponse>;
66
67    /// Get a filtered stream of messages matching the given filter
68    fn filtered_messages_stream(
69        &self,
70        filter: Filter,
71    ) -> Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>>;
72
73    /// Catch up to historical messages and subscribe to new ones for a filter
74    async fn catch_up_and_subscribe(
75        &self,
76        filter: Filter,
77        since: Option<String>,
78    ) -> Result<Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>>>;
79
80    /// Get user data by author and storage key (for PQXDH inbox fetching)
81    async fn user_data(
82        &self,
83        author: zoe_wire_protocol::KeyId,
84        storage_key: zoe_wire_protocol::StoreKey,
85    ) -> Result<Option<MessageFull>>;
86
87    /// Check which messages the server already has and return their global stream IDs.
88    /// Returns a vec of `Option<String>` in the same order as the input, where:
89    /// - `Some(stream_id)` means the server has the message with that global stream ID
90    /// - `None` means the server doesn't have this message yet
91    async fn check_messages(
92        &self,
93        message_ids: Vec<zoe_wire_protocol::MessageId>,
94    ) -> Result<Vec<Option<String>>>;
95}
96
97/// Comprehensive message event that covers all message flows for persistence and monitoring.
98///
99/// This enum captures every type of message activity in the MessagesManager,
100/// enabling complete message persistence and audit trails.
101#[derive(Debug, Clone)]
102pub enum MessageEvent {
103    /// Message received from subscription stream
104    MessageReceived {
105        message: MessageFull,
106        stream_height: String,
107    },
108    /// Message sent by this client
109    MessageSent {
110        message: MessageFull,
111        publish_result: PublishResult,
112    },
113    /// Historical message from catch-up
114    CatchUpMessage {
115        message: MessageFull,
116        request_id: u32,
117    },
118    /// Stream height update
119    StreamHeightUpdate { height: String },
120    /// Catch-up completed
121    CatchUpCompleted { request_id: u32 },
122}
123
124/// Configuration for catching up on historical messages
125#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
126pub struct CatchUpConfig {
127    /// How far back to catch up (in messages or time)
128    pub since: Option<u64>,
129    /// Maximum number of messages to catch up
130    pub limit: Option<u32>,
131}
132
133/// Convert SubscriptionState to SubscriptionConfig for wire protocol
134fn subscription_state_to_config(state: &SubscriptionState) -> SubscriptionConfig {
135    SubscriptionConfig {
136        filters: state.current_filters.clone(),
137        since: state.latest_stream_height.clone(),
138        limit: None,
139    }
140}
141
142/// Builder for creating MessagesManager instances with persistent state support.
143///
144/// This builder allows you to configure the MessagesManager with previous state,
145/// buffer sizes, and other options before connecting to the message service.
146///
147/// # Example
148///
149/// ```rust,no_run
150/// # use zoe_client::services::{MessagesManagerBuilder, SubscriptionState};
151/// # async fn example(connection: &quinn::Connection) -> zoe_client::error::Result<()> {
152/// # let saved_bytes: Vec<u8> = vec![];
153/// // Create with previous state
154/// let previous_state = SubscriptionState::new();
155/// let manager = MessagesManagerBuilder::new()
156///     .state(previous_state)
157///     .buffer_size(2000)
158///     .build(connection)
159///     .await?;
160///
161/// // Or create fresh
162/// let manager = MessagesManagerBuilder::new()
163///     .build(connection)
164///     .await?;
165/// # Ok(())
166/// # }
167/// ```
168#[derive(Debug)]
169pub struct MessagesManagerBuilder {
170    /// Previous subscription state to restore
171    state: SubscriptionState,
172    /// Buffer size for the broadcast channel
173    buffer_size: Option<usize>,
174    /// whether to automatically issue the subscribe command at start
175    autosubscribe: bool,
176}
177
178impl Default for MessagesManagerBuilder {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl MessagesManagerBuilder {
185    /// Create a new builder with default settings
186    pub fn new() -> Self {
187        Self {
188            state: SubscriptionState::new(),
189            buffer_size: None,
190            autosubscribe: false,
191        }
192    }
193
194    /// Set the subscription state to restore from
195    pub fn state(mut self, state: SubscriptionState) -> Self {
196        self.state = state;
197        self
198    }
199
200    pub fn with_filters(mut self, filters: MessageFilters) -> Self {
201        self.state.current_filters = filters;
202        self
203    }
204
205    /// Set the buffer size for the broadcast channel
206    pub fn buffer_size(mut self, size: usize) -> Self {
207        self.buffer_size = Some(size);
208        self
209    }
210
211    pub fn autosubscribe(mut self, autosubscribe: bool) -> Self {
212        self.autosubscribe = autosubscribe;
213        self
214    }
215
216    /// Build the MessagesManager by connecting to the service
217    ///
218    /// This will:
219    /// 1. Connect to the messages service
220    /// 2. Restore previous subscription state if any
221    /// 3. Start the message broadcasting
222    /// 4. Return a fully configured MessagesManager
223    pub async fn build(self, connection: &quinn::Connection) -> Result<MessagesManager> {
224        // Create the messages service and stream
225        let (messages_service, (messages_stream, catch_up_stream)) =
226            MessagesService::connect(connection).await?;
227        let MessagesManagerBuilder {
228            state,
229            buffer_size,
230            autosubscribe,
231        } = self;
232
233        // Create the manager
234        let manager = MessagesManager::new_with_state(
235            messages_service,
236            messages_stream,
237            catch_up_stream,
238            state,
239            buffer_size,
240        );
241
242        if autosubscribe {
243            manager.subscribe().await?;
244        }
245
246        Ok(manager)
247    }
248}
249
250struct AbortOnDrop<T>(JoinHandle<T>);
251
252impl<T> Drop for AbortOnDrop<T> {
253    fn drop(&mut self) {
254        self.0.abort();
255    }
256}
257
258/// High-level messages manager that provides a unified interface for message operations.
259///
260/// The `MessagesManager` combines message broadcasting and subscription management:
261/// - **Message Broadcasting**: Distributes incoming messages to multiple subscribers
262/// - **Subscription Management**: Manages server-side subscriptions with in-flight updates
263/// - **Stream Filtering**: Provides client-side filtering and routing capabilities
264/// - **Lifecycle Management**: Automatic subscription creation, updates, and cleanup
265///
266/// This is the primary interface for interacting with the messaging system.
267///
268#[derive(Clone)]
269pub struct MessagesManager {
270    /// The underlying messages service for RPC operations
271    messages_service: Arc<MessagesService>,
272    /// Broadcast sender for distributing messages to subscribers
273    broadcast_tx: Arc<Sender<StreamMessage>>,
274    /// Broadcast sender for distributing catch-up responses to subscribers
275    catch_up_tx: Arc<Sender<CatchUpResponse>>,
276    /// Broadcast sender for all message events (for persistence and monitoring)
277    message_events_tx: Arc<Sender<MessageEvent>>,
278    /// Keeper receiver to prevent broadcast channel closure (not actively used)
279    _broadcast_keeper: async_broadcast::InactiveReceiver<StreamMessage>,
280    /// Keeper receiver to prevent catch-up channel closure (not actively used)
281    _catch_up_keeper: async_broadcast::InactiveReceiver<CatchUpResponse>,
282    /// Keeper receiver to prevent message events channel closure (not actively used)
283    _message_events_keeper: async_broadcast::InactiveReceiver<MessageEvent>,
284    /// Current subscription state (persistent across reconnections)
285    state: SharedObservable<SubscriptionState, AsyncLock>,
286    /// Background task handle for syncing with the server
287    _sync_handler: Arc<AbortOnDrop<Result<()>>>,
288    /// Catch-up request ID counter
289    catch_up_request_id: Arc<AtomicU32>,
290}
291
292impl MessagesManager {
293    pub fn builder() -> MessagesManagerBuilder {
294        MessagesManagerBuilder::new()
295    }
296
297    /// Helper function to safely broadcast messages using try_broadcast
298    /// Handles TrySendError cases gracefully without panicking
299    fn safe_broadcast<T: Clone>(sender: &Sender<T>, message: T, context: &str) {
300        match sender.try_broadcast(message) {
301            Ok(_msg) => {
302                tracing::trace!("{context}: Successfully broadcast message");
303            }
304            Err(async_broadcast::TrySendError::Inactive(_msg)) => {
305                tracing::debug!("{context}: All receivers inactive, message not sent");
306            }
307            Err(async_broadcast::TrySendError::Full(_msg)) => {
308                tracing::warn!("{context}: Broadcast channel full, message dropped");
309            }
310            Err(async_broadcast::TrySendError::Closed(_msg)) => {
311                tracing::debug!("{context}: Broadcast channel closed");
312            }
313        }
314    }
315
316    /// Create a new MessagesManager with existing subscription state.
317    ///
318    /// This allows restoring a manager to a previous state after reconnection.
319    ///
320    /// # Arguments
321    /// * `messages_service` - The underlying messages service for RPC operations
322    /// * `messages_stream` - The stream of messages from the server
323    /// * `state` - Previous subscription state to restore
324    /// * `buffer_size` - Optional buffer size for the broadcast channel (default: 1000)
325    fn new_with_state(
326        messages_service: MessagesService,
327        messages_stream: MessagesStream,
328        catch_up_stream: CatchUpStream,
329        state: SubscriptionState,
330        buffer_size: Option<usize>,
331    ) -> Self {
332        let buffer_size = buffer_size.unwrap_or(1000);
333        let (broadcast_tx, broadcast_keeper) = async_broadcast::broadcast(buffer_size);
334        let (catch_up_tx, catch_up_keeper) = async_broadcast::broadcast(buffer_size);
335        let (message_events_tx, message_events_keeper) = async_broadcast::broadcast(buffer_size);
336
337        // Create observable state
338        let state = SharedObservable::new_async(state);
339
340        // Start background task to forward messages from stream to broadcast channel
341        let tx_clone = broadcast_tx.clone();
342        let catch_up_tx_clone = catch_up_tx.clone();
343        let message_events_tx_clone = message_events_tx.clone();
344        let state_clone = state.clone();
345        let sync_handler = tokio::spawn(async move {
346            let mut m_stream = messages_stream;
347            let mut c_stream = catch_up_stream;
348            loop {
349                select! {
350                    message = m_stream.recv() => {
351                        let Some(message) = message else {
352                            tracing::debug!("📪 Subscriptions stream ended");
353                            break;
354                        };
355                        match &message {
356                            StreamMessage::StreamHeightUpdate(height) => {
357                                // Update both the internal state and the observable
358                                {
359                                    let mut state = state_clone.write().await;
360                                    ObservableWriteGuard::update(&mut state, |state: &mut SubscriptionState| {
361                                        state.set_stream_height(height.clone());
362                                    });
363                                }
364                                // Emit height update event
365                                let event = MessageEvent::StreamHeightUpdate { height: height.clone() };
366                                Self::safe_broadcast(&message_events_tx_clone, event, "StreamHeightUpdate event");
367                            },
368                            StreamMessage::MessageReceived { message: msg, stream_height } => {
369                                // Update both the internal state and the observable
370                                {
371                                    let mut state = state_clone.write().await;
372                                    ObservableWriteGuard::update(&mut state, |state: &mut SubscriptionState| {
373                                        state.set_stream_height(stream_height.clone());
374                                    });
375                                }
376
377                                // Emit message received event
378                                let event = MessageEvent::MessageReceived {
379                                    message: (**msg).clone(),
380                                    stream_height: stream_height.clone()
381                                };
382                                Self::safe_broadcast(&message_events_tx_clone, event, "MessageReceived event");
383                            }
384                        }
385
386                        // Forward message to all subscribers
387                        // async-broadcast queues messages for receivers even if they're not actively polling
388                        tracing::debug!("MessagesManager forwarding message to broadcast channel: {:?}", message);
389                        Self::safe_broadcast(&tx_clone, message, "StreamMessage");
390                    }
391                    catch_up_response = c_stream.recv() => {
392                        let Some(catch_up_response) = catch_up_response else {
393                            tracing::debug!("📪 Catch-up stream ended");
394                            break;
395                        };
396                        tracing::debug!("📨 MessagesManager received catch-up response: {:?}", catch_up_response);
397
398                        // Emit catch-up message events
399                        for message in &catch_up_response.messages {
400                            let event = MessageEvent::CatchUpMessage {
401                                message: message.clone(),
402                                request_id: catch_up_response.request_id
403                            };
404                            Self::safe_broadcast(&message_events_tx_clone, event, "CatchUpMessage event");
405                        }
406
407                        if catch_up_response.is_complete {
408                            let event = MessageEvent::CatchUpCompleted {
409                                request_id: catch_up_response.request_id
410                            };
411                            Self::safe_broadcast(&message_events_tx_clone, event, "CatchUpCompleted event");
412                        }
413
414                        Self::safe_broadcast(&catch_up_tx_clone, catch_up_response, "CatchUpResponse");
415                    }
416                }
417            }
418
419            Ok(())
420        });
421
422        Self {
423            messages_service: Arc::new(messages_service),
424            broadcast_tx: Arc::new(broadcast_tx),
425            catch_up_tx: Arc::new(catch_up_tx),
426            message_events_tx: Arc::new(message_events_tx),
427            state,
428            catch_up_request_id: Arc::new(AtomicU32::new(0)),
429            _broadcast_keeper: broadcast_keeper.deactivate(),
430            _catch_up_keeper: catch_up_keeper.deactivate(),
431            _message_events_keeper: message_events_keeper.deactivate(),
432            _sync_handler: Arc::new(AbortOnDrop(sync_handler)),
433        }
434    }
435
436    pub async fn subscribe(&self) -> Result<()> {
437        let state = self.state.read().await.clone();
438        self.messages_service
439            .subscribe(subscription_state_to_config(&state))
440            .await
441    }
442
443    pub async fn ensure_contains_filter(&self, filter: Filter) -> Result<()> {
444        let new_filters = self
445            .messages_service
446            .update_filters(FilterUpdateRequest {
447                operations: vec![FilterOperation::Add(vec![filter])],
448            })
449            .await?;
450
451        // Update both the internal state and the observable
452        {
453            let mut state = self.state.write().await;
454            ObservableWriteGuard::update(&mut state, |state: &mut SubscriptionState| {
455                state.current_filters = new_filters.filters;
456            });
457        }
458
459        Ok(())
460    }
461
462    pub async fn catch_up_and_subscribe(
463        self,
464        filter: Filter,
465        since: Option<String>,
466    ) -> Result<impl Stream<Item = Box<MessageFull>>> {
467        // Enure if the underlying service is still alive
468        if self.messages_service.is_closed() {
469            return Err(ClientError::Generic(
470                "Messages service connection is closed".to_string(),
471            ));
472        }
473
474        // First, ensure the filter is added to the server-side subscription
475        // This is crucial so that future messages matching this filter will be delivered
476        self.ensure_contains_filter(filter.clone()).await?;
477
478        let request_id = self
479            .catch_up_request_id
480            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
481        let request = CatchUpRequest {
482            filter: filter.clone(),
483            since,
484            max_messages: None,
485            request_id,
486        };
487
488        // Store request_id before moving the request
489        let request_id_filter = request.request_id;
490
491        // Create and start polling the catch-up receiver immediately to make it "active"
492        let mut catch_up_receiver = self.catch_up_tx.new_receiver();
493
494        // Send the catch-up request to the server
495        self.messages_service.catch_up(request).await?;
496
497        let regular_messages_stream = self.clone().filtered_messages_stream(filter.clone());
498
499        let catch_up_stream = {
500            // we put this into a scope so the broadcaster is dropped when the stream finished
501
502            let catch_up_rcv = async_stream::stream! {
503                loop {
504                    match catch_up_receiver.recv().await {
505                        Ok(CatchUpResponse {
506                            request_id,
507                            messages,
508                            is_complete,
509                            ..
510                        }) => {
511                            if request_id == request_id_filter {
512                                yield (messages, is_complete);
513                            }
514                        }
515                        Err(RecvError::Closed) => break, // we are done processing
516                        Err(RecvError::Overflowed(skipped)) => {
517                            warn!(
518                                "MessagesManager catch-up subscriber lagged, skipped {} responses",
519                                skipped
520                            );
521                            // Continue receiving after overflow
522                        }
523                    }
524                }
525            };
526
527            stream! {
528                pin_mut!(catch_up_rcv);
529                tracing::debug!("🔄 Catch-up stream starting for request_id: {request_id_filter}");
530                while let Some((messages, is_complete)) = catch_up_rcv.next().await {
531                    tracing::debug!("📦 Catch-up received {} messages, is_complete: {}", messages.len(), is_complete);
532                    for message in messages {
533                        yield Box::new(message);
534                    }
535                    if is_complete {
536                        tracing::debug!("✅ Catch-up stream completed for request_id: {request_id_filter}");
537                        break;
538                    }
539                }
540                tracing::debug!("🏁 Catch-up stream ended for request_id: {request_id_filter}");
541            }
542        };
543
544        Ok(Box::pin(catch_up_stream.chain(regular_messages_stream)))
545    }
546
547    pub fn filtered_messages_stream(
548        self,
549        filter: Filter,
550    ) -> Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>> {
551        Box::pin(self.filtered_fn(move |msg| {
552            let StreamMessage::MessageReceived { message, .. } = msg else {
553                return None;
554            };
555            if filter.matches(&message) {
556                tracing::debug!(
557                    "✅ Message matched filter: {:?}, message_id: {}",
558                    filter,
559                    hex::encode(message.id().as_bytes())
560                );
561                Some(message)
562            } else {
563                tracing::debug!(
564                    "❌ Message did not match filter: {:?}, message_id: {}",
565                    filter,
566                    hex::encode(message.id().as_bytes())
567                );
568                None
569            }
570        }))
571    }
572
573    /// Get a filtered stream of messages.
574    ///
575    /// This creates a client-side filtered stream from the internal broadcast channel.
576    /// The filter function is applied to all messages received by the manager.
577    ///
578    /// # Arguments
579    /// * `filter` - A function that returns true for messages to include
580    ///
581    /// # Returns
582    /// A stream of messages that match the filter
583    pub fn filtered_fn<F, T>(self, filter: F) -> impl Stream<Item = T>
584    where
585        F: Fn(StreamMessage) -> Option<T> + Send + Clone + 'static,
586    {
587        let mut receiver = self.broadcast_tx.new_receiver();
588        tracing::info!(
589            "🔧 Created new broadcast receiver for filtered stream (manager: {:p})",
590            self.broadcast_tx.as_ref()
591        );
592
593        // Convert async-broadcast receiver to stream and apply filter
594        async_stream::stream! {
595            tracing::info!("🎯 Filtered stream started, waiting for messages...");
596            loop {
597                match receiver.recv().await {
598                    Ok(message) => {
599                        if let Some(filtered) = filter(message) {
600                            yield filtered;
601                        }
602                    }
603                    Err(RecvError::Closed) => {
604                        break;
605                    }
606                    Err(RecvError::Overflowed(skipped)) => {
607                        warn!(
608                            "MessagesManager subscriber lagged, skipped {} messages",
609                            skipped
610                        );
611                        // Continue receiving after overflow
612                    }
613                }
614            }
615        }
616    }
617
618    /// Get a stream of all messages.
619    ///
620    /// # Returns
621    /// A stream of all messages received by the manager
622    pub fn all_messages_stream(&self) -> Receiver<StreamMessage> {
623        self.broadcast_tx.new_receiver()
624    }
625
626    pub async fn publish(&self, message: MessageFull) -> Result<PublishResult> {
627        // Publish to network
628        let result = self
629            .messages_service
630            .publish(tarpc::context::current(), message.clone())
631            .await??;
632
633        // Emit the sent message event
634        let event = MessageEvent::MessageSent {
635            message,
636            publish_result: result.clone(),
637        };
638
639        Self::safe_broadcast(&self.message_events_tx, event, "MessageSent event");
640
641        Ok(result)
642    }
643
644    /// Get the current subscription state for persistence.
645    ///
646    /// This state can be saved and later restored using MessagesManagerBuilder.
647    ///
648    /// # Returns
649    /// A clone of the current subscription state
650    pub async fn get_subscription_state(&self) -> SubscriptionState {
651        self.state.read().await.clone()
652    }
653    /// Get the latest stream height received.
654    ///
655    /// This can be used to determine how up-to-date the client is.
656    pub async fn get_latest_stream_height(&self) -> Option<String> {
657        self.state.read().await.latest_stream_height.clone()
658    }
659
660    /// Get the current combined filters.
661    ///
662    /// This shows all the filters that are currently active in the subscription.
663    pub async fn get_current_filters(&self) -> MessageFilters {
664        self.state.read().await.current_filters.clone()
665    }
666
667    /// Get a stream of all message events for persistence and monitoring.
668    ///
669    /// This stream captures every message activity including:
670    /// - Messages received from subscriptions
671    /// - Messages sent by this client
672    /// - Historical messages from catch-up requests
673    /// - Stream height updates
674    /// - Catch-up completion notifications
675    ///
676    /// This is primarily intended for persistence services and audit logging.
677    ///
678    /// # Returns
679    /// A stream of `MessageEvent` that captures all message activities
680    pub fn message_events_stream(&self) -> Receiver<MessageEvent> {
681        self.message_events_tx.new_receiver()
682    }
683}
684
685#[async_trait]
686impl MessagesManagerTrait for MessagesManager {
687    fn message_events_stream(&self) -> Receiver<MessageEvent> {
688        self.message_events_tx.new_receiver()
689    }
690
691    async fn get_subscription_state_updates(
692        &self,
693    ) -> eyeball::Subscriber<SubscriptionState, AsyncLock> {
694        self.state.subscribe().await
695    }
696
697    async fn subscribe(&self) -> Result<()> {
698        MessagesManager::subscribe(self).await
699    }
700
701    async fn publish(&self, message: MessageFull) -> Result<PublishResult> {
702        MessagesManager::publish(self, message).await
703    }
704
705    async fn ensure_contains_filter(&self, filter: Filter) -> Result<()> {
706        MessagesManager::ensure_contains_filter(self, filter).await
707    }
708
709    fn messages_stream(&self) -> Receiver<StreamMessage> {
710        self.all_messages_stream()
711    }
712
713    fn catch_up_stream(&self) -> Receiver<CatchUpResponse> {
714        self.catch_up_tx.new_receiver()
715    }
716
717    fn filtered_messages_stream(
718        &self,
719        filter: Filter,
720    ) -> std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>> {
721        Box::pin(MessagesManager::filtered_messages_stream(
722            self.clone(),
723            filter,
724        ))
725    }
726
727    async fn catch_up_and_subscribe(
728        &self,
729        filter: Filter,
730        since: Option<String>,
731    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Box<MessageFull>> + Send>>> {
732        let stream = MessagesManager::catch_up_and_subscribe(self.clone(), filter, since).await?;
733        Ok(Box::pin(stream))
734    }
735
736    async fn user_data(
737        &self,
738        author: zoe_wire_protocol::KeyId,
739        storage_key: zoe_wire_protocol::StoreKey,
740    ) -> Result<Option<MessageFull>> {
741        use tarpc::context;
742        let result = self
743            .messages_service
744            .user_data(context::current(), author, storage_key)
745            .await?;
746        Ok(result?)
747    }
748
749    async fn check_messages(
750        &self,
751        message_ids: Vec<zoe_wire_protocol::MessageId>,
752    ) -> Result<Vec<Option<String>>> {
753        use tarpc::context;
754        let result = self
755            .messages_service
756            .check_messages(context::current(), message_ids)
757            .await?;
758        Ok(result?)
759    }
760}
761
#[cfg(test)]
mod tests {
    // Unit tests for the building blocks the manager relies on:
    // - stream-message predicates
    // - `SubscriptionState` bookkeeping
    // - the eyeball observable
    // - `Filter::matches`
    // None of these tests `.await` anything, so they are plain `#[test]` functions.
    // Building a tokio runtime per test added nothing.
    use super::*;
    use rand::rngs::OsRng;
    use zoe_wire_protocol::{
        Content, Filter, KeyId, KeyPair, Kind, Message, MessageFull, MessageV0, MessageV0Header,
        Tag,
    };

    /// Builds a `MessageV0` with a fixed timestamp and a raw `b"test message"` body,
    /// then signs it with `keypair`.
    fn signed_message(keypair: &KeyPair, kind: Kind, tags: Vec<Tag>) -> MessageFull {
        let message = Message::MessageV0(MessageV0 {
            header: MessageV0Header {
                sender: keypair.public_key(),
                when: 1640995200, // 2022-01-01T00:00:00Z
                kind,
                tags,
            },
            content: Content::Raw(b"test message".to_vec()),
        });
        MessageFull::new(message, keypair).expect("signing a test message should succeed")
    }

    #[test]
    fn test_filtered_stream_logic() {
        // Exercises a stand-alone predicate over `StreamMessage` that keeps height
        // updates strictly above 150.
        // NOTE(review): this does not call into `MessagesManager` itself.
        let above_150 = |msg: &StreamMessage| -> bool {
            match msg {
                StreamMessage::StreamHeightUpdate(height) => {
                    height.parse::<i32>().unwrap_or(0) > 150
                }
                // Listed explicitly (not `_`) so a new variant forces a review here.
                StreamMessage::MessageReceived { .. } => false,
            }
        };

        let updates = [
            StreamMessage::StreamHeightUpdate("100".to_string()),
            StreamMessage::StreamHeightUpdate("200".to_string()),
            StreamMessage::StreamHeightUpdate("150".to_string()),
        ];

        let kept: Vec<&StreamMessage> = updates.iter().filter(|msg| above_150(msg)).collect();
        assert_eq!(kept.len(), 1, "only \"200\" should pass");

        let StreamMessage::StreamHeightUpdate(height) = kept[0] else {
            panic!("Expected StreamHeightUpdate");
        };
        assert_eq!(height, "200");
    }

    #[test]
    fn test_tag_filtering_logic() {
        let channel_tag = Tag::Channel {
            id: b"test-channel".to_vec(),
            relays: vec![],
        };
        let user_tag = Tag::User {
            id: KeyId::from_bytes([0u8; 32]),
            relays: vec![],
        };

        // Predicate: accept only received messages that carry `channel_tag`.
        // NOTE(review): like the test above, this checks a local closure.
        let wanted = channel_tag.clone();
        let has_channel_tag = move |msg: &StreamMessage| -> bool {
            match msg {
                StreamMessage::MessageReceived { message, .. } => {
                    message.tags().contains(&wanted)
                }
                StreamMessage::StreamHeightUpdate(_) => false,
            }
        };

        let keypair = KeyPair::generate(&mut OsRng);

        // Kind::Ephemeral(3600): 1 hour TTL.
        let with_channel = StreamMessage::MessageReceived {
            message: Box::new(signed_message(
                &keypair,
                Kind::Ephemeral(3600),
                vec![channel_tag],
            )),
            stream_height: "100".to_string(),
        };
        let with_user = StreamMessage::MessageReceived {
            message: Box::new(signed_message(&keypair, Kind::Ephemeral(3600), vec![user_tag])),
            stream_height: "101".to_string(),
        };
        let height_update = StreamMessage::StreamHeightUpdate("100".to_string());

        assert!(has_channel_tag(&with_channel), "Should pass channel tag filter");
        assert!(!has_channel_tag(&with_user), "Should not pass channel tag filter");
        assert!(!has_channel_tag(&height_update), "Should not pass height update");
    }

    #[test]
    fn test_subscription_state_tracking() {
        let mut state = SubscriptionState::new();

        let filter: Filter = Tag::Channel {
            id: b"test".to_vec(),
            relays: vec![],
        }
        .into();

        // Adding a filter makes it active.
        state.add_filters(std::slice::from_ref(&filter));
        assert!(state.has_active_filters());
        assert_eq!(state.current_filters.filters.as_ref().unwrap().len(), 1);

        // Stream height is recorded verbatim.
        state.set_stream_height("123".to_string());
        assert_eq!(state.latest_stream_height, Some("123".to_string()));

        // Removing the only filter leaves none active.
        state.remove_filters(&[filter]);
        assert!(!state.has_active_filters());

        // The state round-trips through postcard unchanged.
        let bytes = postcard::to_stdvec(&state).unwrap();
        let restored: SubscriptionState = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(state, restored);
    }

    #[test]
    fn test_subscription_state_observable() {
        let initial_state = SubscriptionState::new();
        let state_observable = SharedObservable::new(initial_state.clone());

        // A fresh subscriber sees the initial value.
        let subscriber = state_observable.subscribe();
        assert_eq!(subscriber.get(), initial_state);

        // Publishing a changed state is visible through the same subscriber.
        let mut updated_state = initial_state.clone();
        updated_state.set_stream_height("123".to_string());
        state_observable.set_if_not_eq(updated_state.clone());

        let observed_state = subscriber.get();
        assert_eq!(observed_state.latest_stream_height, Some("123".to_string()));
        assert_eq!(observed_state, updated_state);
    }

    #[test]
    fn test_filter_matching_logic() {
        let keypair = KeyPair::generate(&mut OsRng);

        // Channel filters match on the exact channel id.
        let channel_id = b"test-channel-123".to_vec();
        let channel_filter = Filter::Channel(channel_id.clone());
        let channel_msg = signed_message(
            &keypair,
            Kind::Regular,
            vec![Tag::Channel {
                id: channel_id,
                relays: vec![],
            }],
        );
        assert!(
            channel_filter.matches(&channel_msg),
            "Channel filter should match message with same channel tag"
        );
        assert!(
            !Filter::Channel(b"different-channel".to_vec()).matches(&channel_msg),
            "Channel filter should not match message with different channel tag"
        );

        // Event filters match on a referenced message id.
        let event_id = *channel_msg.id();
        let event_msg = signed_message(
            &keypair,
            Kind::Regular,
            vec![Tag::Event {
                id: event_id,
                relays: vec![],
            }],
        );
        assert!(
            Filter::Event(event_id).matches(&event_msg),
            "Event filter should match message with same event tag"
        );
        assert!(
            !channel_filter.matches(&event_msg),
            "Channel filter should not match message with event tag"
        );

        // Author filters match every message signed by that key, and only those.
        let author_filter = Filter::Author(KeyId::from(*keypair.public_key().id()));
        assert!(
            author_filter.matches(&channel_msg),
            "Author filter should match message from same author"
        );
        assert!(
            author_filter.matches(&event_msg),
            "Author filter should match any message from same author"
        );

        let other_keypair = KeyPair::generate(&mut OsRng);
        let other_author_filter = Filter::Author(KeyId::from(*other_keypair.public_key().id()));
        assert!(
            !other_author_filter.matches(&channel_msg),
            "Author filter should not match message from different author"
        );
    }
}