1use crate::RedisMessageStorage;
2use futures::StreamExt;
3use std::sync::Arc;
4use tokio::sync::{mpsc, RwLock};
5use tracing::{error, info, warn};
6use zoe_wire_protocol::{
7 CatchUpRequest, CatchUpResponse, FilterOperation, FilterUpdateRequest, KeyId, MessageError,
8 MessageFilters, MessageFull, MessageId, MessageService as MessageServiceRpc, PublishResult,
9 StoreKey, StreamMessage, SubscriptionConfig,
10};
11
12#[derive(Clone)]
13pub struct MessagesRpcService {
14 pub store: Arc<RedisMessageStorage>,
15 pub stream_sender: mpsc::UnboundedSender<StreamMessage>,
17 pub response_sender: mpsc::UnboundedSender<CatchUpResponse>,
19 pub subscription: Arc<RwLock<SubscriptionConfig>>,
21 pub task_handle: Arc<RwLock<Option<tokio::task::AbortHandle>>>,
23}
24
25impl MessagesRpcService {
26 pub fn new(
27 store: Arc<RedisMessageStorage>,
28 ) -> (
29 mpsc::UnboundedReceiver<StreamMessage>,
30 mpsc::UnboundedReceiver<CatchUpResponse>,
31 Self,
32 ) {
33 let (sub_sender, sub_receiver) = mpsc::unbounded_channel::<StreamMessage>();
35 let (response_sender, response_receiver) = mpsc::unbounded_channel::<CatchUpResponse>();
36 (
37 sub_receiver,
38 response_receiver,
39 Self {
40 store,
41 stream_sender: sub_sender,
42 response_sender,
43 subscription: Arc::new(RwLock::new(SubscriptionConfig::default())),
44 task_handle: Arc::new(RwLock::new(None)),
45 },
46 )
47 }
48
49 async fn start_subscription_task(&self) -> Result<(), crate::MessageStoreError> {
50 let config = self.subscription.read().await;
51
52 let filters = config.filters.clone();
54 if filters.is_empty() {
55 self.abort_subscription_task().await?;
58 return Ok(());
59 }
60
61 let stream_sender = self.stream_sender.clone();
63
64 let since = config.since.clone();
65 let store = self.store.clone();
66 let limit = config.limit;
67 let subscription = self.subscription.clone();
68
69 let task_handle = tokio::spawn(async move {
70 if let Err(e) =
71 subscription_task(store, filters, since, limit, subscription, stream_sender).await
72 {
73 error!(error = ?e, "Subscription task failed");
74 }
75 });
76 self.task_handle
77 .write()
78 .await
79 .replace(task_handle.abort_handle());
80 Ok(())
81 }
82
83 async fn abort_subscription_task(&self) -> Result<(), crate::MessageStoreError> {
84 if let Some(task_handle) = self.task_handle.write().await.take() {
85 task_handle.abort();
86 }
87 Ok(())
88 }
89}
90
91async fn subscription_task(
93 service: Arc<RedisMessageStorage>,
94 filters: MessageFilters,
95 since: Option<String>,
96 limit: Option<usize>,
97 subscription: Arc<RwLock<SubscriptionConfig>>,
98 sender: mpsc::UnboundedSender<StreamMessage>,
99) -> Result<(), crate::MessageStoreError> {
100 let task_id = format!("{:p}", &sender);
101 info!(
102 "🔄 Starting subscription task {} with filters: {:?}",
103 task_id, filters
104 );
105
106 let stream = service.listen_for_messages(&filters, since, limit).await?;
107 info!("Subscription stream created, starting to listen for messages");
108
109 tokio::pin!(stream);
111
112 while let Some(result) = stream.next().await {
113 let to_client = match result {
114 Ok((Some(message), height)) => {
115 tracing::debug!(
116 "📤 Subscription task {} yielding message to client: {}",
117 task_id,
118 hex::encode(message.id().as_bytes())
119 );
120 StreamMessage::MessageReceived {
121 message: Box::new(message),
122 stream_height: height,
123 }
124 }
125 Ok((None, height)) => {
126 StreamMessage::StreamHeightUpdate(height)
128 }
129 Err(e) => {
130 error!("Error in subscription stream: {}", e);
131 break;
132 }
133 };
134
135 let new_height = match &to_client {
136 StreamMessage::MessageReceived { stream_height, .. } => stream_height.clone(),
137 StreamMessage::StreamHeightUpdate(height) => height.clone(),
138 };
139
140 if let Err(error) = sender.send(to_client) {
142 error!(?error, "Relay service closed, stopping subscription");
143 break;
144 }
145
146 {
147 let mut subscription = subscription.write().await;
149 subscription.since = Some(new_height);
150 }
151 }
152
153 info!("Subscription task ended");
154 Ok(())
155}
156
157async fn handle_catch_up_request(
158 service: Arc<RedisMessageStorage>,
159 request: CatchUpRequest,
160 sender: mpsc::UnboundedSender<CatchUpResponse>,
161) -> Result<(), crate::MessageStoreError> {
162 info!("Handling catch-up request: {:?}", request);
163
164 let stream = service.catch_up(&request.filter, request.since).await?;
165
166 tokio::pin!(stream);
167
168 let mut messages = Vec::new();
169 while let Some(result) = stream.next().await {
172 match result {
173 Ok((message, (_global_height, _local_height))) => {
174 messages.push(message);
175
176 if messages.len() >= 10 {
178 let response = CatchUpResponse {
179 request_id: request.request_id,
180 filter: request.filter.clone(),
181 messages: messages.clone(),
182 is_complete: false,
183 next_since: None, };
185
186 if let Err(e) = sender.send(response) {
187 warn!("Relay service closed, stopping catch-up: {}", e);
188 return Err(crate::MessageStoreError::Internal(format!(
189 "Relay service closed, stopping catch-up: {e}",
190 )));
191 }
192 messages.clear();
193 }
194 }
195 Err(e) => {
196 error!("Error in catch-up stream: {}", e);
197 break;
198 }
199 }
200 }
201
202 let response = CatchUpResponse {
204 request_id: request.request_id,
205 filter: request.filter.clone(),
206 messages,
207 is_complete: true,
208 next_since: None,
209 };
210
211 if let Err(e) = sender.send(response) {
212 warn!("Relay service closed during final catch-up send: {}", e);
213 return Err(crate::MessageStoreError::Internal(format!(
214 "Relay service closed during final catch-up send: {e}"
215 )));
216 }
217
218 info!(
219 "Catch-up request completed for request_id: {}",
220 request.request_id
221 );
222 Ok(())
223}
224
225impl MessageServiceRpc for MessagesRpcService {
226 async fn publish(
227 self,
228 _context: ::tarpc::context::Context,
229 message: MessageFull,
230 ) -> Result<PublishResult, MessageError> {
231 self.store
232 .store_message(&message)
233 .await
234 .map_err(MessageError::from)
235 }
236
237 async fn message(
238 self,
239 _context: ::tarpc::context::Context,
240 id: MessageId,
241 ) -> Result<Option<MessageFull>, MessageError> {
242 self.store
243 .get_message(id.as_bytes())
244 .await
245 .map_err(MessageError::from)
246 }
247
248 async fn user_data(
249 self,
250 _context: ::tarpc::context::Context,
251 author: KeyId,
252 storage_key: StoreKey,
253 ) -> Result<Option<MessageFull>, MessageError> {
254 self.store
255 .get_user_data(author, storage_key)
256 .await
257 .map_err(MessageError::from)
258 }
259
260 async fn check_messages(
261 self,
262 _context: ::tarpc::context::Context,
263 message_ids: Vec<MessageId>,
264 ) -> Result<Vec<Option<String>>, MessageError> {
265 self.store
266 .check_messages(&message_ids)
267 .await
268 .map_err(MessageError::from)
269 }
270
271 async fn subscribe(
272 self,
273 _context: ::tarpc::context::Context,
274 config: SubscriptionConfig,
275 ) -> Result<(), MessageError> {
276 self.abort_subscription_task().await?;
277 {
278 let mut subscription = self.subscription.write().await;
280 *subscription = config.clone();
281 }
282 self.start_subscription_task().await?;
283 Ok(())
284 }
285
286 async fn update_filters(
287 self,
288 _context: ::tarpc::context::Context,
289 request: FilterUpdateRequest,
290 ) -> Result<SubscriptionConfig, MessageError> {
291 self.abort_subscription_task().await?;
292 let new_config = {
293 let mut subscription = self.subscription.write().await;
294
295 let updated_filters = &mut subscription.filters;
297 for operation in &request.operations {
298 updated_filters.apply_operation(operation);
299 }
300 subscription.clone()
301 };
302
303 self.start_subscription_task().await?;
304
305 Ok(new_config)
306 }
307
308 async fn catch_up(
309 self,
310 _context: ::tarpc::context::Context,
311 request: CatchUpRequest,
312 ) -> Result<SubscriptionConfig, MessageError> {
313 self.abort_subscription_task().await?;
314 let new_config = {
315 let mut subscription = self.subscription.write().await; let filter = request.filter.clone();
318
319 let response_sender = self.response_sender.clone();
321
322 handle_catch_up_request(self.store.clone(), request, response_sender).await?;
323
324 subscription
326 .filters
327 .apply_operation(&FilterOperation::Add(vec![filter]));
328 subscription.clone()
329 };
330 self.start_subscription_task().await?;
331 Ok(new_config)
332 }
333}