zoe_client/services/messages.rs

1use std::ops::Deref;
2
3use crate::error::{ClientError, Result};
4use futures::{SinkExt, StreamExt};
5use quinn::Connection;
6use tokio::{
7    io::{AsyncReadExt, AsyncWriteExt},
8    select,
9    sync::mpsc::{UnboundedReceiver, unbounded_channel},
10    task::JoinHandle,
11};
12use zoe_wire_protocol::{
13    CatchUpRequest, CatchUpResponse, FilterUpdateRequest, MessageServiceClient,
14    MessageServiceResponseWrap, MessagesServiceRequestWrap, StreamMessage, SubscriptionConfig,
15    ZoeServices, stream_pair::create_postcard_streams,
16};
17
18pub type MessagesStream = UnboundedReceiver<StreamMessage>;
19pub type CatchUpStream = UnboundedReceiver<CatchUpResponse>;
20
21pub struct MessagesService {
22    rpc_client: MessageServiceClient,
23    handle: JoinHandle<Result<()>>,
24}
25
26impl Deref for MessagesService {
27    type Target = MessageServiceClient;
28    fn deref(&self) -> &Self::Target {
29        &self.rpc_client
30    }
31}
32
33impl MessagesService {
34    pub async fn connect(
35        connection: &Connection,
36    ) -> Result<(Self, (MessagesStream, CatchUpStream))> {
37        // Open bidirectional stream
38        let (mut send, mut recv) = connection.open_bi().await?;
39
40        // Send service ID
41        send.write_u8(ZoeServices::Messages as u8).await?;
42
43        let service_ok = recv.read_u8().await?;
44        if service_ok != 1 {
45            return Err(ClientError::Generic(
46                "Service ID not acknowledged".to_string(),
47            ));
48        }
49
50        let (mut stream, mut sink) = create_postcard_streams::<
51            MessageServiceResponseWrap,
52            MessagesServiceRequestWrap,
53        >(recv, send);
54        let (incoming_tx, incoming_rx) = unbounded_channel::<StreamMessage>();
55        let (catch_up_tx, catch_up_rx) = unbounded_channel::<CatchUpResponse>();
56
57        let (client_transport, mut server_transport) = tarpc::transport::channel::unbounded();
58        let rpc_client = MessageServiceClient::new(Default::default(), client_transport).spawn();
59
60        let handle = tokio::spawn(async move {
61            loop {
62                select! {
63                    // Receive messages from server and forward to client
64                    message_result = stream.next() => {
65                        let Some(incoming_message) = message_result else {
66                            tracing::info!("Stream ended - server closed connection");
67                            break;
68                        };
69                        let inner = match incoming_message {
70                            Ok(msg) => msg,
71                            Err(e) => {
72                                tracing::warn!("Stream error (connection may be closing): {e}");
73                                // Don't return error immediately - let the loop continue to handle graceful shutdown
74                                continue;
75                            }
76                        };
77                        match inner {
78                            MessageServiceResponseWrap::StreamMessage(message) => {
79                                if let Err(e) = incoming_tx.send(message) {
80                                    tracing::warn!("Stream message send failed (receiver may be dropped): {e}");
81                                    // Don't return error - continue processing other messages
82                                }
83                            }
84                            MessageServiceResponseWrap::RpcResponse(response) => {
85                                if let Err(e) = server_transport.send(*response).await {
86                                    tracing::warn!("RPC response send failed (connection may be closing): {e}");
87                                    break; // Break on RPC transport errors as they indicate connection issues
88                                }
89                            }
90                            MessageServiceResponseWrap::CatchUpResponse(catch_up_response) => {
91                                if let Err(e) = catch_up_tx.send(catch_up_response) {
92                                    tracing::warn!("Catch-up response send failed (receiver may be dropped): {e}");
93                                    // Don't return error - continue processing other messages
94                                }
95                            }
96                        }
97                    }
98                    // Poll for messages from rpc client
99                    rpc_message = server_transport.next() => {
100                        let Some(rpc_message) = rpc_message else {
101                            tracing::trace!("RPC client closed");
102                            break;
103                        };
104                        let rpc_message = match rpc_message {
105                            Ok(msg) => msg,
106                            Err(e) => {
107                                tracing::warn!("RPC client error: {e}");
108                                continue;
109                            }
110                        };
111                        // Send RPC message directly since MessagesServiceRequestWrap is now just ClientMessage
112                        if let Err(e) = sink.send(rpc_message).await {
113                            tracing::warn!("Failed to send RPC message (connection may be closing): {e}");
114                            break;
115                        }
116                    }
117                }
118            }
119            Ok(())
120        });
121
122        Ok((Self { rpc_client, handle }, (incoming_rx, catch_up_rx)))
123    }
124
125    pub async fn subscribe(&self, filters: SubscriptionConfig) -> Result<()> {
126        // Use RPC client directly for subscription - returns subscription ID
127        self.rpc_client
128            .subscribe(tarpc::context::current(), filters)
129            .await
130            .map_err(|e| ClientError::Generic(format!("Subscription failed: {e}")))?
131            .map_err(|e| ClientError::Generic(format!("Subscription error: {e}")))
132    }
133
134    pub async fn update_filters(&self, request: FilterUpdateRequest) -> Result<SubscriptionConfig> {
135        self.rpc_client
136            .update_filters(tarpc::context::current(), request)
137            .await
138            .map_err(|e| ClientError::Generic(format!("Update filters failed: {e}")))?
139            .map_err(|e| ClientError::Generic(format!("Update filters error: {e}")))
140    }
141
142    pub async fn catch_up(&self, request: CatchUpRequest) -> Result<SubscriptionConfig> {
143        self.rpc_client
144            .catch_up(tarpc::context::current(), request)
145            .await
146            .map_err(|e| ClientError::Generic(format!("Catch up failed: {e}")))?
147            .map_err(|e| ClientError::Generic(format!("Catch up error: {e}")))
148    }
149
150    /// Check if the service is closed
151    pub fn is_closed(&self) -> bool {
152        self.handle.is_finished()
153    }
154}