zoe_client/
rpc_transport.rs

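//! Simple RPC message listener over X25519-encrypted messages.
//!
//! This module provides a listener that detects ephemeral RPC messages targeted
//! at this client, decrypts their content, and deserializes it with postcard.
//! It also provides tarpc bridges ([`TarpcOverMessagesServer`] and
//! [`TarpcOverMessagesClient`]) built on top of that listener; sending RPC
//! messages is currently disabled during the ML-DSA migration.
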
6use crate::error::{ClientError, Result};
7use crate::services::messages::{MessagesService, MessagesStream};
8use ed25519_dalek::{SigningKey, VerifyingKey};
9use futures::{SinkExt, Stream, StreamExt};
10use std::collections::HashMap;
11use std::marker::PhantomData;
12use std::ops::Deref;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15// use std::time::{SystemTime, UNIX_EPOCH}; // Temporarily disabled
16use tarpc::transport::channel::UnboundedChannel;
17use tarpc::{ClientMessage, Response};
18use tokio::select;
19use tokio::task::JoinHandle;
20use tracing::{debug, error};
21use zoe_wire_protocol::{Kind, Message, MessageFull, MessageV0Header, StreamMessage, Tag};
22
23/// RPC message containing both header metadata and deserialized content
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct RpcMessage<T> {
26    /// Message header with sender, timestamp, kind, and tags
27    pub header: MessageV0Header,
28    /// Deserialized message content
29    pub content: T,
30}
31
32/// Simple RPC message listener that detects and decrypts RPC messages
33/// Now specifically for tarpc wrapper types
34pub struct RpcMessageListener<TarpcMsg> {
35    signing_key: SigningKey,
36    messages_stream: MessagesStream,
37    _phantom: PhantomData<TarpcMsg>,
38}
39
40impl<TarpcMsg> RpcMessageListener<TarpcMsg> {
41    pub fn new(signing_key: SigningKey, messages_stream: MessagesStream) -> Self {
42        Self {
43            signing_key,
44            messages_stream,
45            _phantom: PhantomData,
46        }
47    }
48
49    /// Check if this message is an ephemeral RPC message targeted at us
50    fn is_rpc_message_for_us(&self, message: &MessageFull) -> bool {
51        // Check if it's an ephemeral message
52        if !matches!(message.kind(), Kind::Ephemeral(_)) {
53            tracing::debug!("Message is not ephemeral: {:?}", message.kind());
54            return false;
55        }
56
57        // Check if it's targeted at our public key
58        let our_public_key = *self.signing_key.verifying_key().as_bytes();
59        tracing::debug!("Our public key: {}", hex::encode(our_public_key));
60
61        let is_targeted = message.tags().iter().any(|tag| {
62            if let Tag::User { id, .. } = tag {
63                tracing::debug!("Checking tag user ID: {}", hex::encode(id));
64                id.as_bytes() == &our_public_key
65            } else {
66                false
67            }
68        });
69
70        tracing::debug!("Message is targeted at us: {}", is_targeted);
71        is_targeted
72    }
73
74    /// Try to decrypt ephemeral ECDH encrypted content and deserialize as `RpcMessage<TarpcMsg>`
75    fn try_decrypt_and_deserialize_message(
76        &self,
77        message: &MessageFull,
78    ) -> Option<RpcMessage<TarpcMsg>>
79    where
80        TarpcMsg: serde::de::DeserializeOwned,
81    {
82        // Try to decrypt ephemeral ECDH encrypted content
83        let ecdh_content = message.content().as_ephemeral_ecdh()?;
84        tracing::debug!(
85            "Got ephemeral ECDH encrypted content with {} bytes ciphertext",
86            ecdh_content.ciphertext.len()
87        );
88
89        // Decrypt using our private key (ephemeral X25519 public key is stored in the content)
90        tracing::debug!(
91            "Sender ML-DSA public key: {}",
92            hex::encode(message.author().encode())
93        );
94        let decrypted_data = match ecdh_content.decrypt(&self.signing_key) {
95            Ok(data) => {
96                tracing::debug!("Successfully decrypted {} bytes", data.len());
97                data
98            }
99            Err(e) => {
100                tracing::debug!("Failed to decrypt X25519 content: {}", e);
101                return None;
102            }
103        };
104
105        // Try to deserialize the decrypted data using postcard
106        let content = match postcard::from_bytes::<TarpcMsg>(&decrypted_data) {
107            Ok(content) => {
108                tracing::debug!("Successfully deserialized RPC content");
109                content
110            }
111            Err(e) => {
112                tracing::debug!("Failed to deserialize content: {}", e);
113                return None;
114            }
115        };
116
117        // Extract header information from the message
118        let header = match message.message() {
119            Message::MessageV0(msg) => msg.header.clone(),
120        };
121
122        Some(RpcMessage { header, content })
123    }
124}
125
126impl<TarpcMsg> Stream for RpcMessageListener<TarpcMsg>
127where
128    TarpcMsg: serde::de::DeserializeOwned + Unpin,
129{
130    type Item = RpcMessage<TarpcMsg>;
131
132    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        loop {
134            // Poll the underlying messages stream
135            let this = self.as_mut().get_mut();
136            match this.messages_stream.poll_recv(cx) {
137                Poll::Ready(Some(stream_message)) => {
138                    match stream_message {
139                        StreamMessage::MessageReceived { message, .. } => {
140                            tracing::debug!(
141                                "RpcMessageListener received message from {}, kind: {:?}, tags: {:?}",
142                                hex::encode(message.author().encode()),
143                                message.kind(),
144                                message.tags()
145                            );
146
147                            // Check if this is an ephemeral message targeted at us
148                            if this.is_rpc_message_for_us(&message) {
149                                tracing::debug!(
150                                    "Message is targeted at us, attempting to decrypt..."
151                                );
152                                // Try to decrypt and deserialize the content
153                                if let Some(deserialized_message) =
154                                    this.try_decrypt_and_deserialize_message(&message)
155                                {
156                                    tracing::debug!(
157                                        "Successfully decrypted and deserialized RPC message"
158                                    );
159                                    return Poll::Ready(Some(deserialized_message));
160                                } else {
161                                    tracing::debug!("Failed to decrypt or deserialize message");
162                                }
163                            } else {
164                                tracing::debug!("Message is not for us or not ephemeral");
165                            }
166                        }
167                        other => {
168                            tracing::debug!(
169                                "RpcMessageListener received non-message stream event: {:?}",
170                                other
171                            );
172                        }
173                    }
174                    // Continue polling if this wasn't a valid RPC message for us
175                    continue;
176                }
177                Poll::Ready(None) => {
178                    // Stream ended
179                    return Poll::Ready(None);
180                }
181                Poll::Pending => {
182                    // No messages available right now
183                    return Poll::Pending;
184                }
185            }
186        }
187    }
188}
189
190type ServiceMaker<Req, Resp> =
191    fn(UnboundedChannel<ClientMessage<Req>, Response<Resp>>) -> JoinHandle<Result<()>>;
192
193pub struct TarpcOverMessagesServer {
194    // Bridge task handle
195    handle: JoinHandle<Result<()>>,
196    rpc_spawn: JoinHandle<Result<()>>,
197}
198
199impl TarpcOverMessagesServer {
200    pub fn new<S, Req, Resp>(
201        mut request_listener: S,
202        _signing_key: SigningKey,
203        _messages_service: MessagesService,
204        service_maker: ServiceMaker<Req, Resp>,
205    ) -> Self
206    where
207        Req: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
208        Resp: serde::Serialize + Unpin + Send + Sync + 'static,
209        S: Stream<Item = RpcMessage<ClientMessage<Req>>> + Unpin + Send + Sync + 'static,
210    {
211        // Create tarpc transport channel - just like messages.rs
212        let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
213        let rpc_spawn = service_maker(server_transport);
214
215        let mut target_public_keys = HashMap::new();
216
217        // Bridge task: bidirectional bridge like messages.rs select! loop
218        let handle = tokio::spawn(async move {
219            loop {
220                tokio::select! {
221                    // Incoming: RpcMessageListener -> tarpc server
222                    rpc_request = request_listener.next() => {
223                        match rpc_request {
224                            Some(RpcMessage { header, content }) => {
225                                debug!("📨 Bridge forwarding request to tarpc server from {}",
226                                       hex::encode(header.sender.encode()));
227                                match &content {
228                                    ClientMessage::Request(request) => {
229                                        target_public_keys.insert(request.id, header.sender);
230                                    }
231                                    ClientMessage::Cancel { request_id, .. } => {
232                                        target_public_keys.remove(request_id);
233                                    }
234                                    _ => {
235                                        error!("Unexpected request type");
236                                    }
237                                }
238
239                                if let Err(e) = client_transport.send(content).await {
240                                    error!("Failed to forward request to tarpc server: {e}");
241                                    break;
242                                }
243                            }
244                            None => {
245                                debug!("Request listener stream ended");
246                                break;
247                            }
248                        }
249                    }
250
251                    // Outgoing: tarpc server -> send_rpc_response
252                    tarpc_response = client_transport.next() => {
253                        match tarpc_response {
254                            Some(Ok(response)) => {
255                                debug!("📤 Bridge sending tarpc response via RPC message");
256                                let Some(_target_public_key) = target_public_keys.remove(&response.request_id) else {
257                                    // This should never happen, but just in case
258                                    error!("Target public key not found for response ID: {}", response.request_id);
259                                    continue;
260                                };
261                                // TODO: Temporarily disabled due to Ed25519/ML-DSA key type mismatch
262                                error!("RPC-over-messages temporarily disabled during ML-DSA migration");
263                                break;
264                            }
265                            Some(Err(e)) => {
266                                error!("tarpc transport error: {e}");
267                                break;
268                            }
269                            None => {
270                                debug!("tarpc server transport closed");
271                                break;
272                            }
273                        }
274                    }
275                }
276            }
277            debug!("RPC bridge server task ending");
278            Ok(())
279        });
280
281        Self { handle, rpc_spawn }
282    }
283
284    /// Check if the bridge is still running
285    pub fn is_running(&self) -> bool {
286        !self.handle.is_finished()
287    }
288
289    pub fn abort(&self) {
290        self.handle.abort();
291        self.rpc_spawn.abort();
292    }
293}
294
295type ClientMaker<C, Req, Resp> = fn(UnboundedChannel<Response<Resp>, ClientMessage<Req>>) -> C;
296
297pub struct TarpcOverMessagesClient<C> {
298    // Bridge task handle
299    handle: JoinHandle<Result<()>>,
300    client: C,
301}
302
303impl<C> Deref for TarpcOverMessagesClient<C> {
304    type Target = C;
305    fn deref(&self) -> &Self::Target {
306        &self.client
307    }
308}
309
310impl<C> TarpcOverMessagesClient<C> {
311    pub fn new<S, Req, Resp>(
312        request_listener: S,
313        signing_key: SigningKey,
314        messages_service: MessagesService,
315        target_public_key: VerifyingKey,
316        client_maker: ClientMaker<C, Req, Resp>,
317    ) -> Self
318    where
319        S: Stream<Item = RpcMessage<Response<Resp>>> + Unpin + Send + Sync + 'static,
320        Req: serde::Serialize + Unpin + Send + Sync + 'static,
321        Resp: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
322    {
323        Self::new_with_mapper(
324            request_listener,
325            signing_key,
326            messages_service,
327            target_public_key,
328            client_maker,
329            |rpc_message| rpc_message,
330        )
331    }
332
333    pub fn new_with_mapper<S, Req, Resp, M, T>(
334        mut request_listener: S,
335        signing_key: SigningKey,
336        messages_service: MessagesService,
337        target_public_key: VerifyingKey,
338        client_maker: ClientMaker<C, Req, Resp>,
339        mapper: M,
340    ) -> Self
341    where
342        S: Stream<Item = RpcMessage<Response<Resp>>> + Unpin + Send + Sync + 'static,
343        T: serde::Serialize + Unpin + Send + Sync + 'static,
344        Req: serde::Serialize + Unpin + Send + Sync + 'static,
345        Resp: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
346        M: Fn(ClientMessage<Req>) -> T + Send + Sync + 'static,
347    {
348        let (client_transport, mut server_transport) = tarpc::transport::channel::unbounded();
349        let client = client_maker(client_transport);
350
351        let handle = tokio::spawn(async move {
352            loop {
353                select! {
354                    // Incoming: RpcMessageListener coming from the tarpc server
355                    rpc_request = request_listener.next() => {
356                        match rpc_request {
357                            Some(RpcMessage { header, content }) => {
358                                debug!("📨 Bridge forwarding request to tarpc server from {}",
359                                    hex::encode(header.sender.encode()));
360
361                                if let Err(e) = server_transport.send(content).await {
362                                    error!("Failed to forward request to tarpc server: {e}");
363                                    break;
364                                }
365                            }
366                            None => {
367                                debug!("Request listener stream ended");
368                                break;
369                            }
370                        }
371                    }
372                    // Poll for messages from rpc client
373                    rpc_message = server_transport.next() => {
374                        let Some(Ok(rpc_message)) = rpc_message else {
375                            tracing::info!("RPC client closed");
376                            break;
377                        };
378                        // Send RPC message directly since MessagesServiceRequestWrap is now just ClientMessage
379
380                        if let Err(e) = send_tarpc_message(&signing_key, target_public_key, &messages_service, &mapper(rpc_message)).await {
381                            return Err(ClientError::Generic(format!("Send error: {e}")));
382                        }
383                    }
384                }
385            }
386            Ok(())
387        });
388
389        Self { client, handle }
390    }
391
392    /// Check if the bridge is still running
393    pub fn is_running(&self) -> bool {
394        !self.handle.is_finished()
395    }
396
397    pub fn abort(&self) {
398        self.handle.abort();
399    }
400}
401
402/// Internal helper function to send tarpc messages  
403/// Serializes the message using postcard before encryption
404async fn send_tarpc_message<TarpcMsg>(
405    _signing_key: &SigningKey,
406    _target_public_key: VerifyingKey,
407    _messages_service: &MessagesService,
408    _message: &TarpcMsg,
409) -> Result<()>
410where
411    TarpcMsg: serde::Serialize,
412{
413    // TODO: Temporarily disabled due to Ed25519/ML-DSA key type mismatch
414    // The message system now uses ML-DSA keys but RPC transport uses Ed25519 keys
415    // This needs to be redesigned to work with the hybrid architecture
416    Err(ClientError::Generic(
417        "RPC-over-messages temporarily disabled during ML-DSA migration".to_string(),
418    ))
419}