//! zoe_message_store — service.rs

use crate::RedisMessageStorage;
use futures::StreamExt;
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use tokio::sync::{mpsc, RwLock};
use tracing::{error, info, warn};
use zoe_wire_protocol::{
    CatchUpRequest, CatchUpResponse, FilterOperation, FilterUpdateRequest, KeyId, MessageError,
    MessageFilters, MessageFull, MessageId, MessageService as MessageServiceRpc, PublishResult,
    StoreKey, StreamMessage, SubscriptionConfig,
};
11
/// RPC message service backed by Redis storage.
///
/// Live subscription messages and catch-up responses are not returned from
/// the RPC methods directly; they are pushed through the two unbounded
/// channels below, whose receiving halves are handed out by `new`.
#[derive(Clone)]
pub struct MessagesRpcService {
    /// Redis-backed message storage shared with the background tasks.
    pub store: Arc<RedisMessageStorage>,
    /// Channel for sending streaming messages back to the relay service
    pub stream_sender: mpsc::UnboundedSender<StreamMessage>,
    /// Channel for sending catch-up responses back to the relay service
    pub response_sender: mpsc::UnboundedSender<CatchUpResponse>,
    /// The current subscription config (also updated by the background task
    /// as it advances through the stream).
    pub subscription: Arc<RwLock<SubscriptionConfig>>,
    /// Abort handle of the currently running subscription task, if any.
    pub task_handle: Arc<RwLock<Option<tokio::task::AbortHandle>>>,
}
24
25impl MessagesRpcService {
26    pub fn new(
27        store: Arc<RedisMessageStorage>,
28    ) -> (
29        mpsc::UnboundedReceiver<StreamMessage>,
30        mpsc::UnboundedReceiver<CatchUpResponse>,
31        Self,
32    ) {
33        // Channels for receiving messages from background tasks
34        let (sub_sender, sub_receiver) = mpsc::unbounded_channel::<StreamMessage>();
35        let (response_sender, response_receiver) = mpsc::unbounded_channel::<CatchUpResponse>();
36        (
37            sub_receiver,
38            response_receiver,
39            Self {
40                store,
41                stream_sender: sub_sender,
42                response_sender,
43                subscription: Arc::new(RwLock::new(SubscriptionConfig::default())),
44                task_handle: Arc::new(RwLock::new(None)),
45            },
46        )
47    }
48
49    async fn start_subscription_task(&self) -> Result<(), crate::MessageStoreError> {
50        let config = self.subscription.read().await;
51
52        // Spawn the background subscription task
53        let filters = config.filters.clone();
54        if filters.is_empty() {
55            // if we are empty, we are not starting the task,
56            // just clear the current handle.
57            self.abort_subscription_task().await?;
58            return Ok(());
59        }
60
61        // Check if we have the stream sender channel
62        let stream_sender = self.stream_sender.clone();
63
64        let since = config.since.clone();
65        let store = self.store.clone();
66        let limit = config.limit;
67        let subscription = self.subscription.clone();
68
69        let task_handle = tokio::spawn(async move {
70            if let Err(e) =
71                subscription_task(store, filters, since, limit, subscription, stream_sender).await
72            {
73                error!(error = ?e, "Subscription task failed");
74            }
75        });
76        self.task_handle
77            .write()
78            .await
79            .replace(task_handle.abort_handle());
80        Ok(())
81    }
82
83    async fn abort_subscription_task(&self) -> Result<(), crate::MessageStoreError> {
84        if let Some(task_handle) = self.task_handle.write().await.take() {
85            task_handle.abort();
86        }
87        Ok(())
88    }
89}
90
91// Background task functions for handling subscriptions and catch-up requests
92async fn subscription_task(
93    service: Arc<RedisMessageStorage>,
94    filters: MessageFilters,
95    since: Option<String>,
96    limit: Option<usize>,
97    subscription: Arc<RwLock<SubscriptionConfig>>,
98    sender: mpsc::UnboundedSender<StreamMessage>,
99) -> Result<(), crate::MessageStoreError> {
100    let task_id = format!("{:p}", &sender);
101    info!(
102        "🔄 Starting subscription task {} with filters: {:?}",
103        task_id, filters
104    );
105
106    let stream = service.listen_for_messages(&filters, since, limit).await?;
107    info!("Subscription stream created, starting to listen for messages");
108
109    // Pin the stream so we can use it with .next()
110    tokio::pin!(stream);
111
112    while let Some(result) = stream.next().await {
113        let to_client = match result {
114            Ok((Some(message), height)) => {
115                tracing::debug!(
116                    "📤 Subscription task {} yielding message to client: {}",
117                    task_id,
118                    hex::encode(message.id().as_bytes())
119                );
120                StreamMessage::MessageReceived {
121                    message: Box::new(message),
122                    stream_height: height,
123                }
124            }
125            Ok((None, height)) => {
126                // Empty batch - just a stream height update
127                StreamMessage::StreamHeightUpdate(height)
128            }
129            Err(e) => {
130                error!("Error in subscription stream: {}", e);
131                break;
132            }
133        };
134
135        let new_height = match &to_client {
136            StreamMessage::MessageReceived { stream_height, .. } => stream_height.clone(),
137            StreamMessage::StreamHeightUpdate(height) => height.clone(),
138        };
139
140        // Send response to relay service
141        if let Err(error) = sender.send(to_client) {
142            error!(?error, "Relay service closed, stopping subscription");
143            break;
144        }
145
146        {
147            // also update the internal subscription state height for future restarting.
148            let mut subscription = subscription.write().await;
149            subscription.since = Some(new_height);
150        }
151    }
152
153    info!("Subscription task ended");
154    Ok(())
155}
156
157async fn handle_catch_up_request(
158    service: Arc<RedisMessageStorage>,
159    request: CatchUpRequest,
160    sender: mpsc::UnboundedSender<CatchUpResponse>,
161) -> Result<(), crate::MessageStoreError> {
162    info!("Handling catch-up request: {:?}", request);
163
164    let stream = service.catch_up(&request.filter, request.since).await?;
165
166    tokio::pin!(stream);
167
168    let mut messages = Vec::new();
169    // let max_messages = request.max_messages.unwrap_or(100);
170
171    while let Some(result) = stream.next().await {
172        match result {
173            Ok((message, (_global_height, _local_height))) => {
174                messages.push(message);
175
176                // Send in batches
177                if messages.len() >= 10 {
178                    let response = CatchUpResponse {
179                        request_id: request.request_id,
180                        filter: request.filter.clone(),
181                        messages: messages.clone(),
182                        is_complete: false,
183                        next_since: None, // Could be enhanced for pagination
184                    };
185
186                    if let Err(e) = sender.send(response) {
187                        warn!("Relay service closed, stopping catch-up: {}", e);
188                        return Err(crate::MessageStoreError::Internal(format!(
189                            "Relay service closed, stopping catch-up: {e}",
190                        )));
191                    }
192                    messages.clear();
193                }
194            }
195            Err(e) => {
196                error!("Error in catch-up stream: {}", e);
197                break;
198            }
199        }
200    }
201
202    // Always send a completion response, even if there are no messages
203    let response = CatchUpResponse {
204        request_id: request.request_id,
205        filter: request.filter.clone(),
206        messages,
207        is_complete: true,
208        next_since: None,
209    };
210
211    if let Err(e) = sender.send(response) {
212        warn!("Relay service closed during final catch-up send: {}", e);
213        return Err(crate::MessageStoreError::Internal(format!(
214            "Relay service closed during final catch-up send: {e}"
215        )));
216    }
217
218    info!(
219        "Catch-up request completed for request_id: {}",
220        request.request_id
221    );
222    Ok(())
223}
224
impl MessageServiceRpc for MessagesRpcService {
    /// Persists `message` in the store and returns the store's publish result.
    async fn publish(
        self,
        _context: ::tarpc::context::Context,
        message: MessageFull,
    ) -> Result<PublishResult, MessageError> {
        self.store
            .store_message(&message)
            .await
            .map_err(MessageError::from)
    }

    /// Looks up a single message by id; `Ok(None)` when it is not stored.
    async fn message(
        self,
        _context: ::tarpc::context::Context,
        id: MessageId,
    ) -> Result<Option<MessageFull>, MessageError> {
        self.store
            .get_message(id.as_bytes())
            .await
            .map_err(MessageError::from)
    }

    /// Fetches the user-data message stored under (`author`, `storage_key`),
    /// if any.
    async fn user_data(
        self,
        _context: ::tarpc::context::Context,
        author: KeyId,
        storage_key: StoreKey,
    ) -> Result<Option<MessageFull>, MessageError> {
        self.store
            .get_user_data(author, storage_key)
            .await
            .map_err(MessageError::from)
    }

    /// Checks which of `message_ids` are known to the store. The result is
    /// positionally aligned with the input; each entry is `None` for an
    /// unknown id. NOTE(review): the `String` payload presumably is the
    /// stream height of the stored message — confirm against
    /// `RedisMessageStorage::check_messages`.
    async fn check_messages(
        self,
        _context: ::tarpc::context::Context,
        message_ids: Vec<MessageId>,
    ) -> Result<Vec<Option<String>>, MessageError> {
        self.store
            .check_messages(&message_ids)
            .await
            .map_err(MessageError::from)
    }

    /// Replaces the current subscription config and restarts the background
    /// subscription task to match it.
    async fn subscribe(
        self,
        _context: ::tarpc::context::Context,
        config: SubscriptionConfig,
    ) -> Result<(), MessageError> {
        // Stop the old task first so it cannot observe or race the config
        // swap below.
        self.abort_subscription_task().await?;
        {
            // also update the internal subscription state height for future restarting.
            let mut subscription = self.subscription.write().await;
            *subscription = config.clone();
        }
        self.start_subscription_task().await?;
        Ok(())
    }

    /// Applies `request.operations` (in order) to the current filters and
    /// restarts the subscription task; returns the resulting config.
    async fn update_filters(
        self,
        _context: ::tarpc::context::Context,
        request: FilterUpdateRequest,
    ) -> Result<SubscriptionConfig, MessageError> {
        // Stop the running task before mutating the filters it was started
        // with; it is restarted with the updated config below.
        self.abort_subscription_task().await?;
        let new_config = {
            let mut subscription = self.subscription.write().await;

            // Apply filter operations to the current config
            let updated_filters = &mut subscription.filters;
            for operation in &request.operations {
                updated_filters.apply_operation(operation);
            }
            subscription.clone()
        };

        self.start_subscription_task().await?;

        Ok(new_config)
    }

    /// Stops the live subscription, streams historical messages matching
    /// `request.filter` through the catch-up response channel, then adds the
    /// filter to the subscription and restarts the live task.
    ///
    /// NOTE(review): the write lock on `self.subscription` is held across the
    /// entire `handle_catch_up_request` await, so any concurrent
    /// subscribe/update_filters call blocks until catch-up completes —
    /// confirm this serialization is intended.
    async fn catch_up(
        self,
        _context: ::tarpc::context::Context,
        request: CatchUpRequest,
    ) -> Result<SubscriptionConfig, MessageError> {
        self.abort_subscription_task().await?;
        let new_config = {
            // we are stopping live subscriptions, so we can just return the current subscription state.
            let mut subscription = self.subscription.write().await; // hold the lock to update the subscription state.
            let filter = request.filter.clone();

            // Get the response sender channel
            let response_sender = self.response_sender.clone();

            handle_catch_up_request(self.store.clone(), request, response_sender).await?;

            // we apply the new filters
            subscription
                .filters
                .apply_operation(&FilterOperation::Add(vec![filter]));
            subscription.clone()
        };
        self.start_subscription_task().await?;
        Ok(new_config)
    }
}