1use crate::error::{ClientError, Result};
7use crate::services::messages::{MessagesService, MessagesStream};
8use ed25519_dalek::{SigningKey, VerifyingKey};
9use futures::{SinkExt, Stream, StreamExt};
10use std::collections::HashMap;
11use std::marker::PhantomData;
12use std::ops::Deref;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use tarpc::transport::channel::UnboundedChannel;
17use tarpc::{ClientMessage, Response};
18use tokio::select;
19use tokio::task::JoinHandle;
20use tracing::{debug, error};
21use zoe_wire_protocol::{Kind, Message, MessageFull, MessageV0Header, StreamMessage, Tag};
22
23#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct RpcMessage<T> {
26 pub header: MessageV0Header,
28 pub content: T,
30}
31
32pub struct RpcMessageListener<TarpcMsg> {
35 signing_key: SigningKey,
36 messages_stream: MessagesStream,
37 _phantom: PhantomData<TarpcMsg>,
38}
39
40impl<TarpcMsg> RpcMessageListener<TarpcMsg> {
41 pub fn new(signing_key: SigningKey, messages_stream: MessagesStream) -> Self {
42 Self {
43 signing_key,
44 messages_stream,
45 _phantom: PhantomData,
46 }
47 }
48
49 fn is_rpc_message_for_us(&self, message: &MessageFull) -> bool {
51 if !matches!(message.kind(), Kind::Ephemeral(_)) {
53 tracing::debug!("Message is not ephemeral: {:?}", message.kind());
54 return false;
55 }
56
57 let our_public_key = *self.signing_key.verifying_key().as_bytes();
59 tracing::debug!("Our public key: {}", hex::encode(our_public_key));
60
61 let is_targeted = message.tags().iter().any(|tag| {
62 if let Tag::User { id, .. } = tag {
63 tracing::debug!("Checking tag user ID: {}", hex::encode(id));
64 id.as_bytes() == &our_public_key
65 } else {
66 false
67 }
68 });
69
70 tracing::debug!("Message is targeted at us: {}", is_targeted);
71 is_targeted
72 }
73
74 fn try_decrypt_and_deserialize_message(
76 &self,
77 message: &MessageFull,
78 ) -> Option<RpcMessage<TarpcMsg>>
79 where
80 TarpcMsg: serde::de::DeserializeOwned,
81 {
82 let ecdh_content = message.content().as_ephemeral_ecdh()?;
84 tracing::debug!(
85 "Got ephemeral ECDH encrypted content with {} bytes ciphertext",
86 ecdh_content.ciphertext.len()
87 );
88
89 tracing::debug!(
91 "Sender ML-DSA public key: {}",
92 hex::encode(message.author().encode())
93 );
94 let decrypted_data = match ecdh_content.decrypt(&self.signing_key) {
95 Ok(data) => {
96 tracing::debug!("Successfully decrypted {} bytes", data.len());
97 data
98 }
99 Err(e) => {
100 tracing::debug!("Failed to decrypt X25519 content: {}", e);
101 return None;
102 }
103 };
104
105 let content = match postcard::from_bytes::<TarpcMsg>(&decrypted_data) {
107 Ok(content) => {
108 tracing::debug!("Successfully deserialized RPC content");
109 content
110 }
111 Err(e) => {
112 tracing::debug!("Failed to deserialize content: {}", e);
113 return None;
114 }
115 };
116
117 let header = match message.message() {
119 Message::MessageV0(msg) => msg.header.clone(),
120 };
121
122 Some(RpcMessage { header, content })
123 }
124}
125
126impl<TarpcMsg> Stream for RpcMessageListener<TarpcMsg>
127where
128 TarpcMsg: serde::de::DeserializeOwned + Unpin,
129{
130 type Item = RpcMessage<TarpcMsg>;
131
132 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 loop {
134 let this = self.as_mut().get_mut();
136 match this.messages_stream.poll_recv(cx) {
137 Poll::Ready(Some(stream_message)) => {
138 match stream_message {
139 StreamMessage::MessageReceived { message, .. } => {
140 tracing::debug!(
141 "RpcMessageListener received message from {}, kind: {:?}, tags: {:?}",
142 hex::encode(message.author().encode()),
143 message.kind(),
144 message.tags()
145 );
146
147 if this.is_rpc_message_for_us(&message) {
149 tracing::debug!(
150 "Message is targeted at us, attempting to decrypt..."
151 );
152 if let Some(deserialized_message) =
154 this.try_decrypt_and_deserialize_message(&message)
155 {
156 tracing::debug!(
157 "Successfully decrypted and deserialized RPC message"
158 );
159 return Poll::Ready(Some(deserialized_message));
160 } else {
161 tracing::debug!("Failed to decrypt or deserialize message");
162 }
163 } else {
164 tracing::debug!("Message is not for us or not ephemeral");
165 }
166 }
167 other => {
168 tracing::debug!(
169 "RpcMessageListener received non-message stream event: {:?}",
170 other
171 );
172 }
173 }
174 continue;
176 }
177 Poll::Ready(None) => {
178 return Poll::Ready(None);
180 }
181 Poll::Pending => {
182 return Poll::Pending;
184 }
185 }
186 }
187 }
188}
189
190type ServiceMaker<Req, Resp> =
191 fn(UnboundedChannel<ClientMessage<Req>, Response<Resp>>) -> JoinHandle<Result<()>>;
192
193pub struct TarpcOverMessagesServer {
194 handle: JoinHandle<Result<()>>,
196 rpc_spawn: JoinHandle<Result<()>>,
197}
198
199impl TarpcOverMessagesServer {
200 pub fn new<S, Req, Resp>(
201 mut request_listener: S,
202 _signing_key: SigningKey,
203 _messages_service: MessagesService,
204 service_maker: ServiceMaker<Req, Resp>,
205 ) -> Self
206 where
207 Req: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
208 Resp: serde::Serialize + Unpin + Send + Sync + 'static,
209 S: Stream<Item = RpcMessage<ClientMessage<Req>>> + Unpin + Send + Sync + 'static,
210 {
211 let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
213 let rpc_spawn = service_maker(server_transport);
214
215 let mut target_public_keys = HashMap::new();
216
217 let handle = tokio::spawn(async move {
219 loop {
220 tokio::select! {
221 rpc_request = request_listener.next() => {
223 match rpc_request {
224 Some(RpcMessage { header, content }) => {
225 debug!("📨 Bridge forwarding request to tarpc server from {}",
226 hex::encode(header.sender.encode()));
227 match &content {
228 ClientMessage::Request(request) => {
229 target_public_keys.insert(request.id, header.sender);
230 }
231 ClientMessage::Cancel { request_id, .. } => {
232 target_public_keys.remove(request_id);
233 }
234 _ => {
235 error!("Unexpected request type");
236 }
237 }
238
239 if let Err(e) = client_transport.send(content).await {
240 error!("Failed to forward request to tarpc server: {e}");
241 break;
242 }
243 }
244 None => {
245 debug!("Request listener stream ended");
246 break;
247 }
248 }
249 }
250
251 tarpc_response = client_transport.next() => {
253 match tarpc_response {
254 Some(Ok(response)) => {
255 debug!("📤 Bridge sending tarpc response via RPC message");
256 let Some(_target_public_key) = target_public_keys.remove(&response.request_id) else {
257 error!("Target public key not found for response ID: {}", response.request_id);
259 continue;
260 };
261 error!("RPC-over-messages temporarily disabled during ML-DSA migration");
263 break;
264 }
265 Some(Err(e)) => {
266 error!("tarpc transport error: {e}");
267 break;
268 }
269 None => {
270 debug!("tarpc server transport closed");
271 break;
272 }
273 }
274 }
275 }
276 }
277 debug!("RPC bridge server task ending");
278 Ok(())
279 });
280
281 Self { handle, rpc_spawn }
282 }
283
284 pub fn is_running(&self) -> bool {
286 !self.handle.is_finished()
287 }
288
289 pub fn abort(&self) {
290 self.handle.abort();
291 self.rpc_spawn.abort();
292 }
293}
294
295type ClientMaker<C, Req, Resp> = fn(UnboundedChannel<Response<Resp>, ClientMessage<Req>>) -> C;
296
297pub struct TarpcOverMessagesClient<C> {
298 handle: JoinHandle<Result<()>>,
300 client: C,
301}
302
303impl<C> Deref for TarpcOverMessagesClient<C> {
304 type Target = C;
305 fn deref(&self) -> &Self::Target {
306 &self.client
307 }
308}
309
310impl<C> TarpcOverMessagesClient<C> {
311 pub fn new<S, Req, Resp>(
312 request_listener: S,
313 signing_key: SigningKey,
314 messages_service: MessagesService,
315 target_public_key: VerifyingKey,
316 client_maker: ClientMaker<C, Req, Resp>,
317 ) -> Self
318 where
319 S: Stream<Item = RpcMessage<Response<Resp>>> + Unpin + Send + Sync + 'static,
320 Req: serde::Serialize + Unpin + Send + Sync + 'static,
321 Resp: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
322 {
323 Self::new_with_mapper(
324 request_listener,
325 signing_key,
326 messages_service,
327 target_public_key,
328 client_maker,
329 |rpc_message| rpc_message,
330 )
331 }
332
333 pub fn new_with_mapper<S, Req, Resp, M, T>(
334 mut request_listener: S,
335 signing_key: SigningKey,
336 messages_service: MessagesService,
337 target_public_key: VerifyingKey,
338 client_maker: ClientMaker<C, Req, Resp>,
339 mapper: M,
340 ) -> Self
341 where
342 S: Stream<Item = RpcMessage<Response<Resp>>> + Unpin + Send + Sync + 'static,
343 T: serde::Serialize + Unpin + Send + Sync + 'static,
344 Req: serde::Serialize + Unpin + Send + Sync + 'static,
345 Resp: serde::de::DeserializeOwned + Unpin + Send + Sync + 'static,
346 M: Fn(ClientMessage<Req>) -> T + Send + Sync + 'static,
347 {
348 let (client_transport, mut server_transport) = tarpc::transport::channel::unbounded();
349 let client = client_maker(client_transport);
350
351 let handle = tokio::spawn(async move {
352 loop {
353 select! {
354 rpc_request = request_listener.next() => {
356 match rpc_request {
357 Some(RpcMessage { header, content }) => {
358 debug!("📨 Bridge forwarding request to tarpc server from {}",
359 hex::encode(header.sender.encode()));
360
361 if let Err(e) = server_transport.send(content).await {
362 error!("Failed to forward request to tarpc server: {e}");
363 break;
364 }
365 }
366 None => {
367 debug!("Request listener stream ended");
368 break;
369 }
370 }
371 }
372 rpc_message = server_transport.next() => {
374 let Some(Ok(rpc_message)) = rpc_message else {
375 tracing::info!("RPC client closed");
376 break;
377 };
378 if let Err(e) = send_tarpc_message(&signing_key, target_public_key, &messages_service, &mapper(rpc_message)).await {
381 return Err(ClientError::Generic(format!("Send error: {e}")));
382 }
383 }
384 }
385 }
386 Ok(())
387 });
388
389 Self { client, handle }
390 }
391
392 pub fn is_running(&self) -> bool {
394 !self.handle.is_finished()
395 }
396
397 pub fn abort(&self) {
398 self.handle.abort();
399 }
400}
401
402async fn send_tarpc_message<TarpcMsg>(
405 _signing_key: &SigningKey,
406 _target_public_key: VerifyingKey,
407 _messages_service: &MessagesService,
408 _message: &TarpcMsg,
409) -> Result<()>
410where
411 TarpcMsg: serde::Serialize,
412{
413 Err(ClientError::Generic(
417 "RPC-over-messages temporarily disabled during ML-DSA migration".to_string(),
418 ))
419}