zoe_client/pqxdh/
transport.rs1use futures::{Sink, Stream};
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use super::{PqxdhSessionId, Result};
7
8#[async_trait::async_trait]
9pub trait PqxdhTarpcTransportSender<Resp>
10where
11 Resp: serde::Serialize + Send + Sync,
12{
13 async fn send_response(&self, session_id: &PqxdhSessionId, resp: &Resp) -> Result<()>;
14}
15
16pub struct PqxdhTarpcTransport<Req, Resp> {
21 incoming_stream: Pin<Box<dyn Stream<Item = Req> + Send>>,
23 outgoing_queue: tokio::sync::mpsc::UnboundedSender<Resp>,
25 background_handle: tokio::task::JoinHandle<()>,
27}
28
29impl<Req, Resp> PqxdhTarpcTransport<Req, Resp>
30where
31 Req: for<'de> serde::Deserialize<'de> + Send,
32 Resp: serde::Serialize + Send + Sync + 'static,
33{
34 pub(crate) fn new<T>(
39 session_id: PqxdhSessionId,
40 incoming_stream: Pin<Box<dyn Stream<Item = Req> + Send>>,
41 client: T,
42 ) -> Self
43 where
44 T: PqxdhTarpcTransportSender<Resp> + Send + Sync + 'static,
45 {
46 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
47
48 let background_handle = {
50 let tx = tx.clone();
51 tokio::spawn(async move {
52 while let Some(resp) = rx.recv().await {
53 if let Err(e) = client.send_response(&session_id, &resp).await {
54 tracing::error!("Failed to send response: {e}. requeuing");
55 if let Err(e) = tx.send(resp) {
56 tracing::error!("Failed to requeue response: {e}");
57 }
58 }
59 }
60 })
61 };
62
63 Self {
64 incoming_stream,
65 outgoing_queue: tx,
66 background_handle,
67 }
68 }
69}
70
71impl<Req, Resp> Drop for PqxdhTarpcTransport<Req, Resp> {
72 fn drop(&mut self) {
73 self.background_handle.abort();
74 }
75}
76
77impl<Req, Resp> Stream for PqxdhTarpcTransport<Req, Resp>
78where
79 Req: for<'de> serde::Deserialize<'de> + Send,
80 Resp: serde::Serialize + Send + Sync,
81{
82 type Item = std::result::Result<Req, std::io::Error>;
83
84 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
85 let this = unsafe { self.get_unchecked_mut() };
87 match this.incoming_stream.as_mut().poll_next(cx) {
88 Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
89 Poll::Ready(None) => Poll::Ready(None),
90 Poll::Pending => Poll::Pending,
91 }
92 }
93}
94
95impl<Req, Resp> Sink<Resp> for PqxdhTarpcTransport<Req, Resp>
96where
97 Req: for<'de> serde::Deserialize<'de> + Send,
98 Resp: serde::Serialize + Send + Sync,
99{
100 type Error = std::io::Error;
101
102 fn poll_ready(
103 self: Pin<&mut Self>,
104 _cx: &mut Context<'_>,
105 ) -> Poll<std::result::Result<(), Self::Error>> {
106 Poll::Ready(Ok(()))
108 }
109
110 fn start_send(self: Pin<&mut Self>, item: Resp) -> std::result::Result<(), Self::Error> {
111 let this = unsafe { self.get_unchecked_mut() };
112 this.outgoing_queue
113 .send(item)
114 .map_err(|e| std::io::Error::other(e.to_string()))
115 }
116
117 fn poll_flush(
118 self: Pin<&mut Self>,
119 _cx: &mut Context<'_>,
120 ) -> Poll<std::result::Result<(), Self::Error>> {
121 Poll::Ready(Ok(()))
125 }
126
127 fn poll_close(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 ) -> Poll<std::result::Result<(), Self::Error>> {
131 match self.as_mut().poll_flush(cx) {
133 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
134 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
135 Poll::Pending => Poll::Pending,
136 }
137 }
138}