zoe_client/pqxdh/
transport.rs

1use futures::{Sink, Stream};
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use super::{PqxdhSessionId, Result};
7
8#[async_trait::async_trait]
9pub trait PqxdhTarpcTransportSender<Resp>
10where
11    Resp: serde::Serialize + Send + Sync,
12{
13    async fn send_response(&self, session_id: &PqxdhSessionId, resp: &Resp) -> Result<()>;
14}
15
16/// A tarpc transport that uses PQXDH for message delivery
17///
18/// This transport owns its incoming stream and manages outgoing messages
19/// through a send callback function.
20pub struct PqxdhTarpcTransport<Req, Resp> {
21    /// Incoming message stream
22    incoming_stream: Pin<Box<dyn Stream<Item = Req> + Send>>,
23    /// Queue of outgoing messages to be sent
24    outgoing_queue: tokio::sync::mpsc::UnboundedSender<Resp>,
25    /// Background handle for the background task
26    background_handle: tokio::task::JoinHandle<()>,
27}
28
29impl<Req, Resp> PqxdhTarpcTransport<Req, Resp>
30where
31    Req: for<'de> serde::Deserialize<'de> + Send,
32    Resp: serde::Serialize + Send + Sync + 'static,
33{
34    /// Creates a new PQXDH transport for the given session
35    ///
36    /// This sets up the incoming message stream by calling `listen_for_messages`
37    /// and initializes the outgoing message queue.
38    pub(crate) fn new<T>(
39        session_id: PqxdhSessionId,
40        incoming_stream: Pin<Box<dyn Stream<Item = Req> + Send>>,
41        client: T,
42    ) -> Self
43    where
44        T: PqxdhTarpcTransportSender<Resp> + Send + Sync + 'static,
45    {
46        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
47
48        // Spawn a background task to handle outgoing responses
49        let background_handle = {
50            let tx = tx.clone();
51            tokio::spawn(async move {
52                while let Some(resp) = rx.recv().await {
53                    if let Err(e) = client.send_response(&session_id, &resp).await {
54                        tracing::error!("Failed to send response: {e}. requeuing");
55                        if let Err(e) = tx.send(resp) {
56                            tracing::error!("Failed to requeue response: {e}");
57                        }
58                    }
59                }
60            })
61        };
62
63        Self {
64            incoming_stream,
65            outgoing_queue: tx,
66            background_handle,
67        }
68    }
69}
70
71impl<Req, Resp> Drop for PqxdhTarpcTransport<Req, Resp> {
72    fn drop(&mut self) {
73        self.background_handle.abort();
74    }
75}
76
77impl<Req, Resp> Stream for PqxdhTarpcTransport<Req, Resp>
78where
79    Req: for<'de> serde::Deserialize<'de> + Send,
80    Resp: serde::Serialize + Send + Sync,
81{
82    type Item = std::result::Result<Req, std::io::Error>;
83
84    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
85        // Poll the incoming stream for new requests
86        let this = unsafe { self.get_unchecked_mut() };
87        match this.incoming_stream.as_mut().poll_next(cx) {
88            Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
89            Poll::Ready(None) => Poll::Ready(None),
90            Poll::Pending => Poll::Pending,
91        }
92    }
93}
94
95impl<Req, Resp> Sink<Resp> for PqxdhTarpcTransport<Req, Resp>
96where
97    Req: for<'de> serde::Deserialize<'de> + Send,
98    Resp: serde::Serialize + Send + Sync,
99{
100    type Error = std::io::Error;
101
102    fn poll_ready(
103        self: Pin<&mut Self>,
104        _cx: &mut Context<'_>,
105    ) -> Poll<std::result::Result<(), Self::Error>> {
106        // Always ready to accept responses into our queue
107        Poll::Ready(Ok(()))
108    }
109
110    fn start_send(self: Pin<&mut Self>, item: Resp) -> std::result::Result<(), Self::Error> {
111        let this = unsafe { self.get_unchecked_mut() };
112        this.outgoing_queue
113            .send(item)
114            .map_err(|e| std::io::Error::other(e.to_string()))
115    }
116
117    fn poll_flush(
118        self: Pin<&mut Self>,
119        _cx: &mut Context<'_>,
120    ) -> Poll<std::result::Result<(), Self::Error>> {
121        // For now, we'll just return Ready since we queue messages and send them immediately
122        // In a more sophisticated implementation, we might want to track pending sends
123        // and only return Ready when all are confirmed sent
124        Poll::Ready(Ok(()))
125    }
126
127    fn poll_close(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130    ) -> Poll<std::result::Result<(), Self::Error>> {
131        // Flush any remaining messages first
132        match self.as_mut().poll_flush(cx) {
133            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
134            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
135            Poll::Pending => Poll::Pending,
136        }
137    }
138}