zoe_wire_protocol/
streaming.rs

1use serde::{Deserialize, Serialize};
2use std::hash::Hash as StdHash;
3use tarpc::{ClientMessage, Response};
4
5use crate::{KeyId, MessageFull, MessageId, StoreKey, Tag};
6
/// Unified filter type for different kinds of message filtering.
///
/// Each variant carries the identifier a message must be associated with for
/// [`Filter::matches`] to accept it. The type is `Eq` + `Hash`, so values can
/// be compared, de-duplicated, or used as set/map keys.
#[derive(Clone, PartialEq, Eq, StdHash, Serialize, Deserialize)]
pub enum Filter {
    /// Filter by message author: matches when the author's key ID equals this one.
    Author(KeyId),
    /// Filter by channel ID: matches messages carrying a `Tag::Channel` with this raw ID.
    Channel(Vec<u8>),
    /// Filter by event ID: matches messages carrying a `Tag::Event` with this message ID.
    Event(MessageId),
    /// Filter by user key (for user-targeted messages): matches messages
    /// carrying a `Tag::User` with this key ID.
    User(KeyId),
}
19
20impl std::fmt::Debug for Filter {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            Filter::Author(id) => write!(f, "Author(#{})", hex::encode(id)),
24            Filter::Channel(id) => write!(f, "Channel(#{})", hex::encode(id)),
25            Filter::Event(id) => write!(f, "Event(#{})", hex::encode(id.as_bytes())),
26            Filter::User(id) => write!(f, "User(#{})", hex::encode(id)),
27        }
28    }
29}
30
31impl Filter {
32    pub fn matches(&self, message: &MessageFull) -> bool {
33        if let Filter::Author(author) = self {
34            return message.author().id() == *author;
35        }
36        for t in message.tags() {
37            match (t, &self) {
38                (Tag::Channel { id, .. }, Filter::Channel(channel)) => {
39                    if id == channel {
40                        return true;
41                    }
42                }
43                (Tag::Event { id, .. }, Filter::Event(event)) => {
44                    if id == event {
45                        return true;
46                    }
47                }
48                (Tag::User { id, .. }, Filter::User(user)) => {
49                    if id == user {
50                        return true;
51                    }
52                }
53                _ => {}
54            }
55        }
56        false
57    }
58}
59
60impl From<&Tag> for Filter {
61    fn from(tag: &Tag) -> Self {
62        match tag {
63            Tag::Channel { id, .. } => Filter::Channel(id.clone()),
64            Tag::Event { id, .. } => Filter::Event(*id),
65            Tag::User { id, .. } => Filter::User(*id),
66            Tag::Protected => {
67                unreachable!("There is no filtering for protected tags. Programmer Error.")
68            }
69        }
70    }
71}
72
73impl From<Tag> for Filter {
74    fn from(tag: Tag) -> Self {
75        Filter::from(&tag)
76    }
77}
78
/// Message filtering criteria for querying stored messages.
///
/// `MessageFilters::is_empty` treats `None` and `Some(vec![])` alike, but the
/// derived `PartialEq` distinguishes them.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct MessageFilters {
    /// The active filters; `None` (the default) means no filter is set.
    pub filters: Option<Vec<Filter>>,
}
84
/// Type-safe filter operations using the unified [`Filter`] type.
///
/// An operation describes one change to a [`MessageFilters`] set and is
/// applied with `MessageFilters::apply_operation`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterOperation {
    /// Add filters to the active set (filters already present are skipped)
    Add(Vec<Filter>),
    /// Remove specific filters from the active set
    Remove(Vec<Filter>),
    /// Replace all filters (forces restart - use sparingly)
    ReplaceAll(Vec<Filter>),
    /// Clear all filters
    Clear,
}
97
98impl FilterOperation {
99    /// Add channels to the filter
100    pub fn add_channels(channels: Vec<Vec<u8>>) -> Self {
101        Self::Add(channels.into_iter().map(Filter::Channel).collect())
102    }
103
104    /// Remove channels from the filter
105    pub fn remove_channels(channels: Vec<Vec<u8>>) -> Self {
106        Self::Remove(channels.into_iter().map(Filter::Channel).collect())
107    }
108
109    /// Add authors to the filter
110    pub fn add_authors(authors: Vec<KeyId>) -> Self {
111        Self::Add(authors.into_iter().map(Filter::Author).collect())
112    }
113
114    /// Remove authors from the filter
115    pub fn remove_authors(authors: Vec<KeyId>) -> Self {
116        Self::Remove(authors.into_iter().map(Filter::Author).collect())
117    }
118
119    /// Add events to the filter
120    pub fn add_events(events: Vec<MessageId>) -> Self {
121        Self::Add(events.into_iter().map(Filter::Event).collect())
122    }
123
124    /// Remove events from the filter
125    pub fn remove_events(events: Vec<MessageId>) -> Self {
126        Self::Remove(events.into_iter().map(Filter::Event).collect())
127    }
128
129    /// Add users to the filter
130    pub fn add_users(users: Vec<KeyId>) -> Self {
131        Self::Add(users.into_iter().map(Filter::User).collect())
132    }
133
134    /// Remove users from the filter
135    pub fn remove_users(users: Vec<KeyId>) -> Self {
136        Self::Remove(users.into_iter().map(Filter::User).collect())
137    }
138
139    /// Replace all filters
140    pub fn replace_all(filters: Vec<Filter>) -> Self {
141        Self::ReplaceAll(filters)
142    }
143
144    /// Clear all filters
145    pub fn clear() -> Self {
146        Self::Clear
147    }
148}
149
150impl MessageFilters {
151    pub fn is_empty(&self) -> bool {
152        self.filters.as_ref().is_none_or(|f| f.is_empty())
153    }
154
155    /// Apply a type-safe filter operation to this filter set
156    pub fn apply_operation(&mut self, operation: &FilterOperation) {
157        match operation {
158            FilterOperation::Add(new_filters) => {
159                let filter_vec = self.filters.get_or_insert_with(Vec::new);
160                for filter in new_filters {
161                    if !filter_vec.contains(filter) {
162                        filter_vec.push(filter.clone());
163                    }
164                }
165            }
166            FilterOperation::Remove(filters_to_remove) => {
167                if let Some(filter_vec) = self.filters.as_mut() {
168                    filter_vec.retain(|existing| !filters_to_remove.contains(existing));
169                    if filter_vec.is_empty() {
170                        self.filters = None;
171                    }
172                }
173            }
174            FilterOperation::ReplaceAll(new_filters) => {
175                if new_filters.is_empty() {
176                    self.filters = None;
177                } else {
178                    self.filters = Some(new_filters.clone());
179                }
180            }
181            FilterOperation::Clear => {
182                self.filters = None;
183            }
184        }
185    }
186}
187
/// Messages sent over the streaming protocol
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamMessage {
    /// A new message received that matches our filter
    MessageReceived {
        /// The full message itself (boxed), not just its hash
        message: Box<MessageFull>,
        /// Redis stream position
        stream_height: String,
    },
    /// A stream height update that is not accompanied by a message because
    /// our filter did not apply.
    ///
    /// It indicates that we are live now and have received all messages the
    /// server knows about up to the carried stream height.
    StreamHeightUpdate(String),
}
205
/// Configuration of a subscription: the argument of `MessageService::subscribe`
/// and the value returned by `update_filters` and `catch_up`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SubscriptionConfig {
    /// Filters attached to the subscription.
    pub filters: MessageFilters,
    /// Optional starting position — presumably a stream ID of the same kind
    /// as `StreamMessage::MessageReceived::stream_height`; TODO confirm.
    pub since: Option<String>,
    /// Optional message limit — NOTE(review): whether this is a total or a
    /// per-batch cap is not visible here; confirm against the server.
    pub limit: Option<usize>,
}
212
/// Message store service for message interaction operations.
///
/// Declared through `#[tarpc::service]`, which generates the
/// `MessageServiceRequest` / `MessageServiceResponse` types used by the
/// request/response wrappers in this module.
#[tarpc::service]
pub trait MessageService {
    // Core message operations
    /// Publish a message; the result says whether it was newly stored,
    /// already existed, or was rejected as expired.
    async fn publish(message: MessageFull) -> Result<PublishResult, MessageError>;

    /// Retrieve a specific message by its ID (`None` if there is none)
    async fn message(id: MessageId) -> Result<Option<MessageFull>, MessageError>;

    /// Retrieve a specific user's data by their key and storage key
    async fn user_data(
        author: KeyId,
        storage_key: StoreKey,
    ) -> Result<Option<MessageFull>, MessageError>;

    // Bulk operations for sync
    /// Check which messages the server already has and return their global stream IDs.
    /// Returns a vec of `Option<String>` in the same order as the input, where:
    /// - `Some(stream_id)` means the server has the message with that global stream ID
    /// - `None` means the server doesn't have this message yet
    async fn check_messages(
        message_ids: Vec<MessageId>,
    ) -> Result<Vec<Option<String>>, MessageError>;

    /// Start the subscription. The RPC itself carries no payload on success.
    async fn subscribe(config: SubscriptionConfig) -> Result<(), MessageError>;

    /// Update the running subscription filters with the actions. Returns the now final subscription config.
    async fn update_filters(
        request: FilterUpdateRequest,
    ) -> Result<SubscriptionConfig, MessageError>;

    /// Update the internal subscription and catch up to the latest stream height for the given filter.
    ///
    /// NOTE(review): an earlier comment said this returns a `catch_up_id`, but
    /// the declared return type is the `SubscriptionConfig`.
    async fn catch_up(request: CatchUpRequest) -> Result<SubscriptionConfig, MessageError>;
}
248
/// Result type for message operations: shorthand for `Result<T, MessageError>`.
pub type MessageResult<T> = Result<T, MessageError>;
251
/// Result of publishing a message to the relay.
///
/// Two of the three variants carry the global stream ID of the stored
/// message; `Expired` carries nothing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PublishResult {
    /// Message was newly stored with this global stream ID
    StoredNew { global_stream_id: String },
    /// Message already existed at this global stream ID
    AlreadyExists { global_stream_id: String },
    /// Message was expired and not stored
    Expired,
}
262
263impl PublishResult {
264    /// Get the global stream ID if available (None for expired messages)
265    pub fn global_stream_id(&self) -> Option<&str> {
266        match self {
267            PublishResult::StoredNew { global_stream_id } => Some(global_stream_id),
268            PublishResult::AlreadyExists { global_stream_id } => Some(global_stream_id),
269            PublishResult::Expired => None,
270        }
271    }
272
273    /// Check if the message was stored (either new or already existed)
274    pub fn was_stored(&self) -> bool {
275        !matches!(self, PublishResult::Expired)
276    }
277}
278
/// Error types for message operations.
///
/// Every variant carries plain `String` data, so the error can be cloned,
/// compared, and serialized; `Display` comes from the `#[error]` attributes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)]
pub enum MessageError {
    /// No message was found; `hash` is shown in the error text.
    #[error("Message not found: {hash}")]
    NotFound { hash: String },

    /// The given message hash was rejected as invalid.
    #[error("Invalid message hash: {hash}")]
    InvalidHash { hash: String },

    /// A storage failure, described by `message`.
    #[error("Storage error: {message}")]
    StorageError { message: String },

    /// A serialization failure, described by `message`.
    #[error("Serialization error: {message}")]
    SerializationError { message: String },

    /// An I/O failure, described by `message`.
    #[error("IO error: {message}")]
    IoError { message: String },

    /// An internal server failure, described by `message`.
    #[error("Internal server error: {message}")]
    InternalError { message: String },
}
300
/// Generic filter update request: the argument of `MessageService::update_filters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilterUpdateRequest {
    /// Filter operations to apply — presumably in list order; TODO confirm
    /// against the server implementation.
    pub operations: Vec<FilterOperation>,
}
306
/// Type-safe catch-up request for historical messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CatchUpRequest {
    /// The single filter to catch up on.
    pub filter: Filter,
    /// Optional starting position — presumably a stream ID; TODO confirm.
    pub since: Option<String>,
    /// Optional cap on the number of messages to return.
    pub max_messages: Option<usize>,
    /// Caller-chosen ID; `CatchUpResponse` carries a `request_id` as well,
    /// presumably echoing this one so responses can be correlated.
    pub request_id: u32,
}
315
316impl CatchUpRequest {
317    /// Convenience constructor for channel catch-up
318    pub fn for_channel(
319        channel_id: Vec<u8>,
320        since: Option<String>,
321        max_messages: Option<usize>,
322        request_id: u32,
323    ) -> Self {
324        Self {
325            filter: Filter::Channel(channel_id),
326            since,
327            max_messages,
328            request_id,
329        }
330    }
331
332    /// Convenience constructor for author catch-up
333    pub fn for_author(
334        author_key: KeyId,
335        since: Option<String>,
336        max_messages: Option<usize>,
337        request_id: u32,
338    ) -> Self {
339        Self {
340            filter: Filter::Author(author_key),
341            since,
342            max_messages,
343            request_id,
344        }
345    }
346
347    /// Convenience constructor for event catch-up
348    pub fn for_event(
349        event_id: MessageId,
350        since: Option<String>,
351        max_messages: Option<usize>,
352        request_id: u32,
353    ) -> Self {
354        Self {
355            filter: Filter::Event(event_id),
356            since,
357            max_messages,
358            request_id,
359        }
360    }
361
362    /// Convenience constructor for user catch-up
363    pub fn for_user(
364        user_key: KeyId,
365        since: Option<String>,
366        max_messages: Option<usize>,
367        request_id: u32,
368    ) -> Self {
369        Self {
370            filter: Filter::User(user_key),
371            since,
372            max_messages,
373            request_id,
374        }
375    }
376}
377
/// Catch-up response with historical messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CatchUpResponse {
    /// Request ID this response belongs to.
    pub request_id: u32,
    /// What filter was requested
    pub filter: Filter,
    /// The historical messages of this batch.
    pub messages: Vec<MessageFull>,
    /// False if more batches are coming
    pub is_complete: bool,
    /// For pagination
    pub next_since: Option<String>,
}
387
/// Request wrapper: a tarpc `ClientMessage` carrying a `MessageServiceRequest`.
///
/// All subscription, filter-update, and catch-up requests are handled as
/// RPC calls, so no additional request variants exist.
pub type MessagesServiceRequestWrap = ClientMessage<MessageServiceRequest>;
391
/// Response wrapper: either an RPC response or data pushed by background tasks.
#[derive(Debug, Serialize, Deserialize)]
pub enum MessageServiceResponseWrap {
    /// Streaming messages from background listener tasks
    StreamMessage(StreamMessage),

    /// Catch-up response with batched historical messages from background catch-up tasks
    CatchUpResponse(CatchUpResponse),

    /// RPC response (includes subscription/filter update acknowledgments); boxed
    RpcResponse(Box<Response<MessageServiceResponse>>),
}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Derives a deterministic test `KeyId` from `bytes`.
    ///
    /// The input is copied into a zero-padded 32-byte seed (anything past
    /// 32 bytes is ignored), the seed drives a ChaCha20 RNG, and the ID is
    /// taken from the public half of the Ed25519 key that RNG generates.
    fn create_test_verifying_key_id(bytes: &[u8]) -> KeyId {
        use rand::SeedableRng;

        let mut seed = [0u8; 32];
        let used = bytes.len().min(seed.len());
        seed[..used].copy_from_slice(&bytes[..used]);

        let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
        let public_key = ed25519_dalek::SigningKey::generate(&mut rng).verifying_key();

        crate::keys::VerifyingKey::Ed25519(Box::new(public_key)).id()
    }

    #[test]
    fn test_filter_enum() {
        let author = Filter::Author(create_test_verifying_key_id(b"alice"));
        let channel = Filter::Channel(b"general".to_vec());
        let event = Filter::Event(MessageId::from_content(b"important"));
        let user = Filter::User(create_test_verifying_key_id(b"bob"));

        // Debug output names the variant.
        for (filter, label) in [
            (&author, "Author"),
            (&channel, "Channel"),
            (&event, "Event"),
            (&user, "User"),
        ] {
            assert!(format!("{filter:?}").contains(label));
        }

        // Equality: the same input yields the same key, and variants differ.
        assert_eq!(
            author,
            Filter::Author(create_test_verifying_key_id(b"alice"))
        );
        assert_ne!(author, channel);
    }

    #[test]
    fn test_message_filters_default() {
        let defaults = MessageFilters::default();
        assert!(defaults.is_empty());
        assert!(defaults.filters.is_none());
    }

    #[test]
    fn test_message_filters_is_empty() {
        let mut set = MessageFilters::default();
        assert!(set.is_empty());

        // A single channel filter makes the set non-empty.
        set.filters = Some(vec![Filter::Channel(b"general".to_vec())]);
        assert!(!set.is_empty());

        // So does a single author filter.
        set.filters = Some(vec![Filter::Author(create_test_verifying_key_id(b"alice"))]);
        assert!(!set.is_empty());

        // Dropping the list makes it empty again.
        set.filters = None;
        assert!(set.is_empty());
    }

    #[test]
    fn test_filter_operations() {
        let channels = vec![b"general".to_vec(), b"tech".to_vec()];
        let authors = vec![create_test_verifying_key_id(b"alice")];
        let general = Filter::Channel(b"general".to_vec());
        let tech = Filter::Channel(b"tech".to_vec());

        // Add constructors
        let FilterOperation::Add(added_channels) = FilterOperation::add_channels(channels.clone())
        else {
            panic!("Expected Add operation");
        };
        assert_eq!(added_channels.len(), 2);
        assert!(added_channels.contains(&general));
        assert!(added_channels.contains(&tech));

        let FilterOperation::Add(added_authors) = FilterOperation::add_authors(authors) else {
            panic!("Expected Add operation");
        };
        assert_eq!(added_authors.len(), 1);
        assert!(added_authors.contains(&Filter::Author(create_test_verifying_key_id(b"alice"))));

        // Remove constructor
        let FilterOperation::Remove(removed_channels) = FilterOperation::remove_channels(channels)
        else {
            panic!("Expected Remove operation");
        };
        assert_eq!(removed_channels.len(), 2);
        assert!(removed_channels.contains(&general));
        assert!(removed_channels.contains(&tech));

        // Clear constructor
        assert_eq!(FilterOperation::clear(), FilterOperation::Clear);
    }

    #[test]
    fn test_filter_operation_apply() {
        let mut set = MessageFilters::default();
        assert!(set.is_empty());

        // Add one channel
        set.apply_operation(&FilterOperation::add_channels(vec![b"general".to_vec()]));
        assert!(!set.is_empty());
        assert_eq!(set.filters.as_ref().map(Vec::len), Some(1));

        // Add one author
        set.apply_operation(&FilterOperation::add_authors(vec![
            create_test_verifying_key_id(b"alice"),
        ]));
        assert_eq!(set.filters.as_ref().map(Vec::len), Some(2));

        // Remove the channel again
        set.apply_operation(&FilterOperation::remove_channels(vec![b"general".to_vec()]));
        assert_eq!(set.filters.as_ref().map(Vec::len), Some(1));

        // Clear everything
        set.apply_operation(&FilterOperation::clear());
        assert!(set.is_empty());
    }

    #[test]
    fn test_catchup_request_constructors() {
        // Channel catch-up
        let by_channel = CatchUpRequest::for_channel(
            b"general".to_vec(),
            Some("0-0".to_string()),
            Some(100),
            123,
        );
        assert_eq!(by_channel.filter, Filter::Channel(b"general".to_vec()));
        assert_eq!(by_channel.request_id, 123);

        // Author catch-up
        let alice = create_test_verifying_key_id(b"alice");
        let by_author = CatchUpRequest::for_author(alice, None, None, 456);
        assert_eq!(by_author.filter, Filter::Author(alice));
        assert_eq!(by_author.request_id, 456);
    }
}