1use std::sync::Arc;
2
3use futures_util::Stream;
4use redis::{
5 aio::{ConnectionManager, ConnectionManagerConfig},
6 AsyncCommands, SetOptions,
7};
8use tracing::{debug, error, info, trace, warn};
9use zoe_wire_protocol::{
10 Filter, KeyId, MessageFilters, MessageFull, MessageId, PublishResult, StoreKey, Tag,
11};
12
13use crate::error::{MessageStoreError, Result};
14
// ---- Redis key names -------------------------------------------------------

/// Name of the Redis stream every newly stored message is appended to.
const GLOBAL_MESSAGES_STREAM_NAME: &str = "message_stream";
/// Prefix of the keys that map a hex-encoded message id to its global stream entry id.
const MESSAGE_TO_STREAM_ID_PREFIX: &str = "msg_stream_id:";

// ---- Field names used inside stream entries --------------------------------

/// Raw message id bytes.
const ID_KEY: &str = "id";
/// Expiration timestamp, stored as 8 little-endian bytes of a `u64`.
const EXPIRATION_KEY: &str = "exp";
/// Raw author id bytes.
const AUTHOR_KEY: &str = "author";
/// Event tag value.
const EVENT_KEY: &str = "event";
/// User tag value.
const USER_KEY: &str = "user";
/// Channel tag value.
const CHANNEL_KEY: &str = "channel";
/// Global stream entry id, recorded in the per-author / per-tag index streams.
const STREAM_HEIGHT_KEY: &str = "stream_height";
25
/// Lua script that atomically stores a message and appends it to the global stream.
///
/// KEYS: message key, message-id -> stream-id mapping key, global stream name.
/// ARGV: serialized message, message id bytes, author id bytes, hex-encoded
/// little-endian expiration time (or ""), timeout in seconds (or ""), followed by
/// alternating tag key / tag value pairs.
///
/// Returns `{'STORED', stream_id}`, `{'EXISTS', stream_id}` or `{'ERROR', reason}`.
const STORE_MESSAGE_SCRIPT: &str = r#"
local message_key = KEYS[1]
local stream_id_key = KEYS[2]
local global_stream = KEYS[3]

local message_data = ARGV[1]
local message_id_bytes = ARGV[2]
local author_bytes = ARGV[3]
local expiration_time = ARGV[4] -- empty string if no expiration
local timeout = ARGV[5] -- empty string if no timeout

-- Both keys written by this script get the same TTL (when there is one), so the
-- message-id -> stream-id mapping never outlives the message it points to.
local has_ttl = expiration_time ~= "" and timeout ~= ""

-- Try to store message with NX (only if not exists)
local set_result
if has_ttl then
    set_result = redis.call('SET', message_key, message_data, 'EX', timeout, 'NX')
else
    set_result = redis.call('SET', message_key, message_data, 'NX')
end

-- If message already exists, return existing stream ID
if not set_result then
    local existing_stream_id = redis.call('GET', stream_id_key)
    if existing_stream_id then
        return {'EXISTS', existing_stream_id}
    else
        return {'ERROR', 'Message exists but no stream ID mapping found'}
    end
end

-- Message is new - add to global stream
local xadd_args = {global_stream, '*', 'id', message_id_bytes, 'author', author_bytes}

-- Add expiration if provided
if expiration_time ~= "" then
    table.insert(xadd_args, 'exp')
    -- Decode hex-encoded expiration time back to bytes
    local exp_bytes = {}
    for i = 1, #expiration_time, 2 do
        local byte = expiration_time:sub(i, i+1)
        table.insert(exp_bytes, string.char(tonumber(byte, 16)))
    end
    table.insert(xadd_args, table.concat(exp_bytes))
end

-- Add tags from remaining ARGV (starting at index 6)
for i = 6, #ARGV, 2 do
    if ARGV[i] and ARGV[i+1] then
        table.insert(xadd_args, ARGV[i]) -- tag key
        table.insert(xadd_args, ARGV[i+1]) -- tag value
    end
end

local stream_id = redis.call('XADD', unpack(xadd_args))

-- Store the mapping from message-id to stream-id, expiring together with the message
if has_ttl then
    redis.call('SET', stream_id_key, stream_id, 'EX', timeout)
else
    redis.call('SET', stream_id_key, stream_id)
end

return {'STORED', stream_id}
"#;
86
87pub type GlobalStreamHeight = String;
88pub type LocalStreamHeight = String;
89pub type CatchUpItem = (MessageFull, (GlobalStreamHeight, LocalStreamHeight));
90pub type GlobalStreamItem = (Option<MessageFull>, GlobalStreamHeight);
91
92#[derive(Clone)]
94pub struct RedisMessageStorage {
95 pub conn: Arc<tokio::sync::Mutex<ConnectionManager>>,
96 pub client: redis::Client,
97}
98
99impl RedisMessageStorage {
101 async fn get_inner<R: redis::FromRedisValue>(&self, id: &str) -> Result<Option<R>> {
102 info!("Reading: {id}");
103 let mut conn = self.conn.lock().await;
104
105 return conn.get(id).await.map_err(MessageStoreError::Redis);
106 }
107 async fn get_inner_raw(&self, id: &str) -> Result<Option<Vec<u8>>> {
109 self.get_inner::<Vec<u8>>(id).await
110 }
111 async fn get_inner_full(&self, id: &str) -> Result<Option<MessageFull>> {
113 let mut conn = self.conn.lock().await.clone();
114 Self::get_full(&mut conn, id).await
115 }
116
117 async fn get_message_full(
118 conn: &mut ConnectionManager,
119 id: &[u8],
120 ) -> Result<Option<MessageFull>> {
121 let message_id = hex::encode(id);
122 Self::get_full(conn, &message_id).await
123 }
124
125 async fn get_full(conn: &mut ConnectionManager, id: &str) -> Result<Option<MessageFull>> {
126 let Some(value): Option<Vec<u8>> = conn.get(id).await? else {
127 return Ok(None);
128 };
129
130 match MessageFull::from_storage_value(&value) {
133 Ok(message) => Ok(Some(message)),
134 Err(e) => {
135 tracing::warn!("Failed to deserialize message {}: {}. Skipping corrupted/incompatible message.", id, e);
136 Ok(None)
137 }
138 }
139 }
140
141 async fn add_to_index_stream(
142 conn: &mut ConnectionManager,
143 stream_name: &str,
144 message_id: &[u8],
145 stream_height: &str,
146 expiration_time: Option<u64>,
147 ) -> Result<String> {
148 let mut channel_xadd = redis::cmd("XADD");
150 channel_xadd
151 .arg(stream_name)
152 .arg("*") .arg(ID_KEY)
154 .arg(message_id)
155 .arg(STREAM_HEIGHT_KEY)
156 .arg(stream_height); if let Some(expiration_time) = expiration_time {
159 channel_xadd
160 .arg(EXPIRATION_KEY)
161 .arg(expiration_time.to_le_bytes().to_vec());
162 }
163
164 let tags_stream_id: String = channel_xadd
166 .query_async(conn)
167 .await
168 .map_err(MessageStoreError::Redis)?;
169
170 debug!(
171 "Added message {} to stream {}",
172 hex::encode(message_id),
173 stream_name
174 );
175
176 Ok(tags_stream_id)
177 }
178}
179
/// Decoded `XREAD` reply: for each stream, its name and its entries; each entry
/// is its id plus the list of raw (field, value) pairs.
type RedisStreamResult = Vec<(
    String,
    Vec<(String, Vec<(Vec<u8>, Vec<u8>)>)>,
)>;
181
182impl RedisMessageStorage {
183 fn is_expired_from_timestamp(expiration_time: u64, current_time: u64) -> bool {
187 expiration_time < current_time
188 }
189
190 pub async fn new(redis_url: String) -> Result<Self> {
192 debug!("Connecting to Redis at {}", redis_url);
193 let client = redis::Client::open(redis_url).map_err(MessageStoreError::Redis)?;
194 trace!("Starting connection manager");
195
196 let mut conn_manager = ConnectionManager::new_with_config(
197 client.clone(),
198 ConnectionManagerConfig::default()
199 .set_connection_timeout(std::time::Duration::from_secs(5)),
200 )
201 .await
202 .map_err(MessageStoreError::Redis)?;
203
204 conn_manager.ping::<()>().await?;
206
207 trace!("Connection manager started");
208
209 Ok(Self {
210 conn: Arc::new(tokio::sync::Mutex::new(conn_manager)),
211 client,
212 })
213 }
214
215 pub async fn get_message_raw(&self, id: &[u8]) -> Result<Option<Vec<u8>>> {
217 let message_id = hex::encode(id);
218 self.get_inner_raw(&message_id).await
219 }
220 pub async fn store_message(&self, message: &MessageFull) -> Result<PublishResult> {
228 let mut conn = { self.conn.lock().await.clone() };
229
230 let current_time = std::time::SystemTime::now()
232 .duration_since(std::time::UNIX_EPOCH)?
233 .as_secs();
234
235 if message.is_expired(current_time) {
236 debug!("Message is expired, ignoring to store");
237 return Ok(PublishResult::Expired);
238 }
239
240 let (ex_time, timeout_str) = if let Some(timeout) = message.storage_timeout() {
242 if timeout > 0 {
243 let expiration_time = message.when().saturating_add(timeout);
244 (Some(expiration_time), timeout.to_string())
245 } else {
246 (None, String::new())
247 }
248 } else {
249 (None, String::new())
250 };
251
252 let storage_value = message
254 .storage_value()
255 .map_err(|e| MessageStoreError::Serialization(e.to_string()))?;
256
257 let msg_id_bytes = message.id().as_bytes();
258 let message_id = hex::encode(msg_id_bytes);
259
260 let stream_id_key = format!("{MESSAGE_TO_STREAM_ID_PREFIX}{message_id}");
262
263 let mut script_args = vec![
265 storage_value.to_vec(), msg_id_bytes.to_vec(), message.author().id().as_bytes().to_vec(), ];
269 script_args.push(
270 ex_time
271 .map_or(String::new(), |t| hex::encode(t.to_le_bytes()))
272 .into_bytes(),
273 ); script_args.push(timeout_str.into_bytes()); for tag in message.tags() {
278 match tag {
279 Tag::Event { id: event_id, .. } => {
280 script_args.push(EVENT_KEY.as_bytes().to_vec());
281 script_args.push(event_id.as_bytes().to_vec());
282 }
283 Tag::User { id: user_id, .. } => {
284 script_args.push(USER_KEY.as_bytes().to_vec());
285 script_args.push(user_id.as_bytes().to_vec());
286 }
287 Tag::Channel { id: channel_id, .. } => {
288 script_args.push(CHANNEL_KEY.as_bytes().to_vec());
289 script_args.push(channel_id.clone());
290 }
291 Tag::Protected => {
292 }
294 }
295 }
296
297 let script_result: Vec<String> = redis::Script::new(STORE_MESSAGE_SCRIPT)
299 .key(&message_id) .key(&stream_id_key) .key(GLOBAL_MESSAGES_STREAM_NAME) .arg(script_args)
303 .invoke_async(&mut conn)
304 .await
305 .map_err(MessageStoreError::Redis)?;
306
307 let (result_type, stream_id) = match script_result.as_slice() {
308 [result_type, stream_id] => (result_type, stream_id),
309 _ => {
310 return Err(MessageStoreError::Internal(
311 "Invalid response from store_message script".to_string(),
312 ))
313 }
314 };
315
316 let publish_result = match result_type.as_str() {
317 "EXISTS" => PublishResult::AlreadyExists {
318 global_stream_id: stream_id.clone(),
319 },
320 "STORED" => PublishResult::StoredNew {
321 global_stream_id: stream_id.clone(),
322 },
323 "ERROR" => {
324 error!("Script error: {}", stream_id);
325 return Err(MessageStoreError::Internal(stream_id.clone()));
326 }
327 _ => {
328 return Err(MessageStoreError::Internal(format!(
329 "Unknown script result type: {result_type}"
330 )))
331 }
332 };
333
334 let PublishResult::StoredNew {
335 ref global_stream_id,
336 } = publish_result
337 else {
338 return Ok(publish_result);
339 };
340 Self::add_to_index_stream(
345 &mut conn,
346 &format!("author:{}:stream", hex::encode(message.author().id())),
347 msg_id_bytes,
348 global_stream_id,
349 ex_time,
350 )
351 .await?;
352
353 for tag in message.tags() {
355 let tags_stream = match tag {
356 Tag::Channel { id: channel_id, .. } => {
357 format!("channel:{}:stream", hex::encode(channel_id))
358 }
359 Tag::Event { id, .. } => {
360 format!("event:{}:stream", hex::encode(id.as_bytes()))
361 }
362 Tag::User { id, .. } => {
363 format!("user:{}:stream", hex::encode(id))
364 }
365 _ => continue, };
367
368 Self::add_to_index_stream(
369 &mut conn,
370 &tags_stream,
371 msg_id_bytes,
372 global_stream_id,
373 ex_time,
374 )
375 .await?;
376 }
377
378 if let Some(storage_key) = message.store_key() {
380 let author_id = hex::encode(message.author().id());
381 let storage_key_enc: u32 = storage_key.into();
382 let storage_id = format!("{author_id}:{storage_key_enc}");
383
384 info!(
385 redis_key = storage_id,
386 message_id = message_id,
387 "storing for key"
388 );
389
390 if let Some(previous_id) = conn
391 .set_options(&storage_id, &message_id, SetOptions::default().get(true))
392 .await?
393 {
394 let mut previous_id: String = previous_id;
396 'retry: loop {
397 info!(redis_key = previous_id, "checking previous message");
398 let Some(previous_message) = Self::get_full(&mut conn, &previous_id).await?
399 else {
400 info!(
401 redis_key = storage_id,
402 "No previous message found, all good"
403 );
404 break 'retry;
405 };
406 info!("previous message found. comparing timestamps");
407 let prev_when = previous_message.when();
408 let msg_when = message.when();
409 if msg_when > prev_when {
410 info!(redis_key = previous_id, "We are newer, ignore");
412 break 'retry;
413 } else if prev_when == msg_when {
414 if previous_message.signature() < message.signature() {
416 info!(redis_key = previous_id, "We are older, ignore");
418 break 'retry;
419 }
420 }
421
422 info!(
423 redis_key = previous_id,
424 "The previous message needs to be restored"
425 );
426
427 let Some(new_previous_id): Option<String> = conn
429 .set_options(&storage_id, &previous_id, SetOptions::default().get(true))
430 .await?
431 else {
432 warn!("Restored without it being set. curious...");
434 break 'retry;
435 };
436
437 if new_previous_id == previous_id || new_previous_id == message_id {
438 break 'retry;
440 } else {
441 previous_id = new_previous_id;
442 }
443 }
444 }
445 }
446
447 Ok(publish_result)
448 }
449
450 pub async fn check_messages(&self, message_ids: &[MessageId]) -> Result<Vec<Option<String>>> {
455 if message_ids.is_empty() {
456 return Ok(vec![]);
457 }
458
459 let mut conn = { self.conn.lock().await.clone() };
460
461 let mut pipe = redis::pipe();
462 let stream_id_keys: Vec<String> = message_ids
463 .iter()
464 .map(|id| {
465 format!(
466 "{MESSAGE_TO_STREAM_ID_PREFIX}{}",
467 hex::encode(id.as_bytes())
468 )
469 })
470 .collect();
471
472 for stream_id_key in &stream_id_keys {
474 pipe.get(stream_id_key);
475 }
476
477 let pipeline_results: Vec<Option<String>> = pipe
479 .query_async(&mut conn)
480 .await
481 .map_err(MessageStoreError::Redis)?;
482
483 Ok(pipeline_results)
484 }
485
486 pub async fn get_message(&self, id: &[u8]) -> Result<Option<MessageFull>> {
488 let mut conn = { self.conn.lock().await.clone() };
489 Self::get_message_full(&mut conn, id).await
490 }
491
492 pub async fn catch_up<'a>(
494 &'a self,
495 filter: &Filter,
496 since: Option<String>,
497 ) -> Result<impl Stream<Item = Result<CatchUpItem>> + 'a> {
498 let channel_stream = match filter {
499 Filter::Channel(channel_id) => format!("channel:{}:stream", hex::encode(channel_id)),
500 Filter::Event(event_id) => format!("event:{}:stream", hex::encode(event_id.as_bytes())),
501 Filter::User(user_id) => format!("user:{}:stream", hex::encode(user_id)),
502 Filter::Author(author_id) => format!("author:{}:stream", hex::encode(author_id)),
503 };
504
505 let mut conn = {
506 self.client
510 .get_connection_manager()
511 .await
512 .map_err(MessageStoreError::Redis)?
513 };
514 let mut fetch_con = {
515 self.client
517 .get_connection_manager()
518 .await
519 .map_err(MessageStoreError::Redis)?
520 };
521 let mut last_seen_height = since.unwrap_or_else(|| "0-0".to_string());
522
523 Ok(async_stream::stream! {
524 loop {
525 let mut read = redis::cmd("XREAD");
526
527 read.arg("STREAMS")
528 .arg(&channel_stream)
529 .arg(&last_seen_height);
530
531 let stream_result = match read.query_async(&mut conn).await {
532 Ok(stream_result) => stream_result,
533 Err(e) => {
534 error!(error=?e, "Error reading messages at catch up");
535 yield Err(MessageStoreError::Redis(e));
536 break;
537 }
538 };
539
540 let rows: RedisStreamResult = match redis::from_redis_value(&stream_result) {
542 Ok(rows) => rows,
543 Err(e) => {
544 error!(error=?e, "Error parsing messages at catch up");
545 yield Err(MessageStoreError::Redis(e));
546 break;
547 }
548 };
549
550 if rows.is_empty() {
551 break;
553 }
554
555 for (_, entries) in rows {
556 'messages: for (height, meta) in entries {
557 let mut id = None;
558 last_seen_height = height.clone();
559 let mut stream_height = None;
560
561 'meta: for (key, value) in meta {
562 let key_str = String::from_utf8_lossy(&key);
564
565 match key_str.as_ref() {
567 ID_KEY => {
568 id = Some(value);
569 }
570
571 EXPIRATION_KEY => {
574 let expiration_time = match value.try_into().map(u64::from_le_bytes) {
575 Ok(expiration_time) => expiration_time,
576 Err(e) => {
577 error!(error=?e, "Message has a bad expiration time");
578 continue 'meta;
579 }
580 };
581 let current_time = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
582 if Self::is_expired_from_timestamp(expiration_time, current_time) {
583 debug!("Message is expired, ignoring to yield in catch up");
585 continue 'messages;
586 }
587
588 }
589 STREAM_HEIGHT_KEY => {
591 stream_height = Some(String::from_utf8_lossy(&value).to_string());
592 }
593 _ => {
594 }
596 }
597 }
598
599 let Some(msg_id) = id else {
601 error!("Message ID not found in stream info at catch up");
602 continue 'messages;
603 };
604
605
606 let Some(msg_full) = Self::get_message_full(&mut fetch_con, &msg_id).await? else {
607 error!("Message not found in storage at catch up. odd...");
609 continue 'messages;
610 };
611 yield Ok((msg_full, (stream_height.clone().unwrap_or_else(|| "0-0".to_string()), height.clone())));
612 }
613 }
614 }
615 })
616 }
617
618 pub async fn listen_for_messages<'a>(
620 &'a self,
621 filters: &'a MessageFilters,
622 since: Option<String>,
623 limit: Option<usize>,
624 ) -> Result<impl Stream<Item = Result<GlobalStreamItem>> + 'a> {
625 if filters.is_empty() {
626 return Err(MessageStoreError::EmptyFilters);
627 }
628
629 let mut conn = {
630 self.client
634 .get_connection_manager()
635 .await
636 .map_err(MessageStoreError::Redis)?
637 };
638 let mut fetch_con = {
639 self.client
641 .get_connection_manager()
642 .await
643 .map_err(MessageStoreError::Redis)?
644 };
645 let mut since = since;
646 let mut block = false;
647
648 Ok(async_stream::stream! {
649 loop {
650 let mut read = redis::cmd("XREAD");
651
652 if block {
653 read.arg("BLOCK").arg(10000);
654 } else {
655 match &limit {
656 Some(l) if *l > 0 => {
657 read.arg("COUNT").arg(l);
658 }
659 _ => {}
660 }
661 }
662 read.arg("STREAMS").arg(GLOBAL_MESSAGES_STREAM_NAME);
663 if let Some(since) = &since {
664 read.arg(since);
665 } else {
666 read.arg("0-0"); }
668
669 debug!("redis listening for messages with filters: {:?}", filters);
670
671 let stream_result = match read.query_async(&mut conn).await {
672 Ok(stream_result) => stream_result,
673 Err(e) => {
674 error!("Error reading messages: {:?}", e);
675 yield Err(MessageStoreError::Redis(e));
676 break;
677 }
678 };
679
680 let rows: RedisStreamResult = match redis::from_redis_value(&stream_result) {
682 Ok(rows) => rows,
683 Err(e) => {
684 error!("Error parsing messages: {:?}", e);
685 yield Err(MessageStoreError::Redis(e));
686 break;
687 }
688 };
689
690 if rows.is_empty() {
691 if !block {
693 block = true;
694 info!("Switching to blocking mode");
695 yield Ok((None, since.clone().unwrap_or_else(|| "0-0".to_string())));
697 }
698 continue;
699 }
700
701 let mut did_yield = false;
705 let mut last_seen_height = since.clone();
706
707 for (_, entries) in rows {
708 'messages: for (height, meta) in entries {
709 let mut should_yield = false;
710 let mut id = None;
711 last_seen_height = Some(height.clone());
712
713 'meta: for (key, value) in meta {
714 let key_str = String::from_utf8_lossy(&key);
716 since = Some(height.clone());
717
718 match key_str.as_ref() {
720 ID_KEY => {
721 id = Some(value);
722 }
724
725 EXPIRATION_KEY => {
728 let expiration_time = match value.try_into().map(u64::from_le_bytes) {
729 Ok(expiration_time) => expiration_time,
730 Err(e) => {
731 error!("Message has a bad expiration time: {:?}", e);
732 continue 'meta;
733 }
734 };
735 let current_time = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
736 if Self::is_expired_from_timestamp(expiration_time, current_time) {
737 debug!("Message is expired, ignoring to yield in regular listen");
739 continue 'messages;
740 }
741
742 }
743
744 EVENT_KEY => {
746 if let Some(filter_list) = &filters.filters {
747 for filter in filter_list {
748 if let Filter::Event(event_id) = filter {
749 if value == event_id.as_bytes() {
750 should_yield = true;
751 break 'meta;
752 }
753 }
754 }
755 }
756 }
757 AUTHOR_KEY => {
758 if let Some(filter_list) = &filters.filters {
759 for filter in filter_list {
760 if let Filter::Author(author_id) = filter {
761 if value == author_id.as_bytes() {
762 should_yield = true;
763 break 'meta;
764 }
765 }
766 }
767 }
768 }
769 USER_KEY => {
770 if let Some(filter_list) = &filters.filters {
771 for filter in filter_list {
772 if let Filter::User(user_id) = filter {
773 if value == user_id.as_bytes() {
774 should_yield = true;
775 break 'meta;
776 }
777 }
778 }
779 }
780 }
781 CHANNEL_KEY => {
782 if let Some(filter_list) = &filters.filters {
783 for filter in filter_list {
784 if let Filter::Channel(channel_id) = filter {
785 if value == channel_id.as_slice() {
786 should_yield = true;
787 break 'meta;
788 }
789 }
790 }
791 }
792 }
793 _ => {
794 }
796 }
797 }
798
799 if should_yield {
801 let Some(msg_id) = id else {
802 error!("Message ID not found in stream info");
803 continue 'messages;
804 };
805 info!("Message ID found in stream info: {}", hex::encode(&msg_id));
806 let Some(msg_full) = Self::get_message_full(&mut fetch_con, &msg_id).await? else {
807 tracing::debug!("Message {} not found or corrupted, skipping", hex::encode(&msg_id));
809 continue 'messages;
810 };
811 yield Ok((Some(msg_full), height.clone()));
812 did_yield = true;
813 }
814 }
815 }
816
817 if !did_yield {
818 info!("No messages matched filters, yielding empty");
819 yield Ok((None, last_seen_height.clone().unwrap_or_else(|| "0-0".to_string())));
820 }
821 }
822 })
823 }
824
825 pub async fn get_user_data(
826 &self,
827 user_id: KeyId,
828 key: StoreKey,
829 ) -> Result<Option<MessageFull>> {
830 let message_id = hex::encode(user_id);
831 let storage_key: u32 = key.into();
832 let target_key = format!("{message_id}:{storage_key}");
833 let Some(message_id) = self.get_inner::<String>(&target_key).await? else {
834 return Ok(None);
835 };
836 self.get_inner_full(&message_id).await
837 }
838}