zoe_client/
relay_client.rs

1use crate::SessionManager;
2use crate::challenge::perform_client_challenge_handshake;
3use crate::error::{ClientError, Result};
4use crate::services::{BlobService, MessagePersistenceManager, MessagePersistenceManagerBuilder};
5use async_once_cell::OnceCell;
6use quinn::{Connection, Endpoint};
7use std::net::SocketAddr;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tracing::info;
11use zoe_client_storage::{SqliteMessageStorage, StorageConfig as DbConfig};
12use zoe_wire_protocol::{KeyPair, VerifyingKey, connection::client::create_client_endpoint};
13
/// Shared state behind a [`RelayClient`] handle.
///
/// A single instance is created by `RelayClientBuilder::build` and placed in an
/// `Arc`, so every clone of the `RelayClient` refers to the same connection,
/// storage and managers.
struct RelayClientInner {
    /// Keypair for TLS certificates (Ed25519 or ML-DSA-44).
    client_keypair_tls: KeyPair,
    /// Keypair for the inner protocol.
    client_keypair_inner: Arc<KeyPair>,
    /// QUIC connection to the relay server.
    connection: Connection,
    /// Blob service handle; empty until the first successful `blob_service()` call.
    blob_service: OnceCell<Arc<BlobService>>,
    /// Message persistence manager built on top of `connection` and `storage`.
    persistence_manager: Arc<MessagePersistenceManager>,
    /// Session manager built from `storage`, `persistence_manager` and the inner keypair.
    session_manager: SessionManager<SqliteMessageStorage, MessagePersistenceManager>,
    /// Message storage, either supplied by the caller or created by the builder.
    storage: Arc<SqliteMessageStorage>,
    /// Endpoint the connection was opened on; `close()` waits for it to become idle.
    endpoint: Endpoint,
}
24
/// Builder for creating RelayClient instances with configurable options.
///
/// This builder allows configuring storage, connection parameters, and message persistence
/// before creating a RelayClient instance. Every RelayClient is built with a message
/// persistence manager. Its storage is either a pre-created instance passed to `storage`,
/// or is created by `build` from `db_config` (the default configuration if unset) and
/// `encryption_key` (required in that case).
///
/// The server public key and server address are always required; `build` returns an
/// error if either is missing.
///
/// # Example
///
/// ```rust,no_run
/// # use zoe_client::RelayClientBuilder;
/// # use std::net::SocketAddr;
/// # use zoe_wire_protocol::{KeyPair, VerifyingKey};
/// # async fn example() -> zoe_client::error::Result<()> {
/// let server_key = VerifyingKey::from([0u8; 32]); // Replace with actual key
/// let server_addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
/// let encryption_key = [42u8; 32]; // Use a proper encryption key
///
/// let client = RelayClientBuilder::new()
///     .server_public_key(server_key)
///     .server_address(server_addr)
///     .db_storage_path("client_messages.db")
///     .encryption_key(encryption_key)
///     .autosubscribe(true)
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```
pub struct RelayClientBuilder {
    /// Inner-protocol keypair; `build` generates a random one when this is `None`.
    client_keypair_inner: Option<Arc<KeyPair>>,
    /// Public key of the relay server (required).
    server_public_key: Option<VerifyingKey>,
    /// Socket address of the relay server (required).
    server_address: Option<SocketAddr>,
    /// Storage configuration; `build` uses the default configuration when this is `None`.
    db_config: Option<DbConfig>,
    /// Storage encryption key; required by `build` unless `storage` is set.
    encryption_key: Option<[u8; 32]>,
    /// Pre-created storage; when set, `db_config` and `encryption_key` are not used.
    storage: Option<Arc<SqliteMessageStorage>>,
    /// Forwarded to the persistence manager's `autosubscribe` option; `false` by default.
    autosubscribe: bool,
    /// Message-processing buffer size; `build` falls back to 1000 when this is `None`.
    buffer_size: Option<usize>,
}
63
64impl RelayClientBuilder {
65    /// Create a new RelayClientBuilder with default settings
66    pub fn new() -> Self {
67        Self {
68            client_keypair_inner: None,
69            server_public_key: None,
70            server_address: None,
71            db_config: None,
72            encryption_key: None,
73            storage: None,
74            autosubscribe: false,
75            buffer_size: None,
76        }
77    }
78
79    /// Set the client's inner protocol keypair (for message signing/verification)
80    /// If not set, a random keypair will be generated
81    pub fn client_keypair(mut self, keypair: Arc<KeyPair>) -> Self {
82        self.client_keypair_inner = Some(keypair);
83        self
84    }
85
86    /// Set the server's public key for TLS verification
87    pub fn server_public_key(mut self, key: VerifyingKey) -> Self {
88        self.server_public_key = Some(key);
89        self
90    }
91
92    /// Set the server address to connect to
93    pub fn server_address(mut self, addr: SocketAddr) -> Self {
94        self.server_address = Some(addr);
95        self
96    }
97
98    /// Set the storage configuration
99    pub fn db_config(mut self, config: DbConfig) -> Self {
100        self.db_config = Some(config);
101        self
102    }
103
104    /// Set the storage database path (convenience method)
105    pub fn db_storage_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
106        let mut config = self.db_config.unwrap_or_default();
107        config.database_path = path.into();
108        self.db_config = Some(config);
109        self
110    }
111
112    /// Set the encryption key for storage
113    pub fn encryption_key(mut self, key: [u8; 32]) -> Self {
114        self.encryption_key = Some(key);
115        self
116    }
117
118    /// Set a pre-created storage instance
119    ///
120    /// When this is set, the builder will use this storage instead of creating one
121    /// from db_config and encryption_key.
122    pub fn storage(mut self, storage: Arc<SqliteMessageStorage>) -> Self {
123        self.storage = Some(storage);
124        self
125    }
126
127    /// Enable or disable automatic subscription to messages
128    pub fn autosubscribe(mut self, enable: bool) -> Self {
129        self.autosubscribe = enable;
130        self
131    }
132
133    /// Set the buffer size for message processing
134    pub fn buffer_size(mut self, size: usize) -> Self {
135        self.buffer_size = Some(size);
136        self
137    }
138
139    /// Build the RelayClient with the configured options
140    ///
141    /// Storage and encryption key are required for message persistence.
142    pub async fn build(self) -> Result<RelayClient> {
143        let server_public_key = self
144            .server_public_key
145            .ok_or_else(|| ClientError::Generic("Server public key is required".to_string()))?;
146
147        let server_address = self
148            .server_address
149            .ok_or_else(|| ClientError::Generic("Server address is required".to_string()))?;
150
151        // Encryption key is only required if no pre-created storage is provided
152        let encryption_key = if self.storage.is_some() {
153            self.encryption_key.unwrap_or([0u8; 32]) // Default if storage is pre-created
154        } else {
155            self.encryption_key.ok_or_else(|| {
156                ClientError::Generic(
157                    "Encryption key is required when no storage is provided".to_string(),
158                )
159            })?
160        };
161
162        let client_keypair_inner = self
163            .client_keypair_inner
164            .unwrap_or_else(|| Arc::new(KeyPair::generate(&mut rand::thread_rng())));
165
166        // Generate TLS keypair for certificates (default to Ed25519)
167        let client_keypair_tls = KeyPair::generate_ed25519(&mut rand::thread_rng());
168
169        // Establish connection
170        let (endpoint, connection) = RelayClientInner::connect_with_transport_keys(
171            &client_keypair_tls,
172            &client_keypair_inner,
173            server_address,
174            &server_public_key,
175        )
176        .await?;
177
178        // Use pre-created storage or create new one
179        let storage = if let Some(storage) = self.storage {
180            storage
181        } else {
182            let db_config = self.db_config.unwrap_or_default();
183            Arc::new(
184                SqliteMessageStorage::new(db_config, &encryption_key)
185                    .await
186                    .map_err(|e| ClientError::Generic(format!("Failed to create storage: {e}")))?,
187            )
188        };
189
190        // Create persistence manager (always required now)
191        let persistence_manager = Arc::new(
192            MessagePersistenceManagerBuilder::new()
193                .storage(storage.clone())
194                .relay_pubkey(server_public_key)
195                .autosubscribe(self.autosubscribe)
196                .buffer_size(self.buffer_size.unwrap_or(1000))
197                .build(&connection)
198                .await?,
199        );
200
201        let session_manager = SessionManager::builder(storage.clone(), persistence_manager.clone())
202            .client_keypair(client_keypair_inner.clone())
203            .build()
204            .await?;
205
206        Ok(RelayClient {
207            inner: Arc::new(RelayClientInner {
208                client_keypair_tls,
209                client_keypair_inner,
210                storage,
211                connection,
212                persistence_manager,
213                blob_service: OnceCell::new(),
214                session_manager,
215                endpoint,
216            }),
217        })
218    }
219}
220
221impl Default for RelayClientBuilder {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
/// A Zoe Relay Client with integrated message persistence.
///
/// This is a cheap, clonable handle: the state lives in a shared
/// `Arc<RelayClientInner>`, so all clones use the same connection, storage,
/// persistence manager and session manager. Instances are created through
/// `RelayClientBuilder`.
#[derive(Clone)]
pub struct RelayClient {
    /// Shared connection state and services.
    inner: Arc<RelayClientInner>,
}
232
233impl RelayClientInner {
234    async fn close(&self) {
235        self.connection.close(0u32.into(), b"Client closed");
236        self.endpoint.wait_idle().await;
237    }
238    /// Connect to relay server with transport keys and return the connection
239    async fn connect_with_transport_keys(
240        client_keypair_tls: &KeyPair, // For TLS certificates (Ed25519 or ML-DSA-44)
241        client_keypair_inner: &KeyPair, // For inner protocol
242        server_addr: SocketAddr,
243        server_public_key: &VerifyingKey,
244    ) -> Result<(Endpoint, Connection)> {
245        info!("🚀 Starting relay client with transport keys");
246        info!(
247            "🔑 Client TLS key: {} ({})",
248            client_keypair_tls.public_key(),
249            client_keypair_tls.algorithm()
250        );
251        info!(
252            "🔑 Client inner public key id: {}",
253            hex::encode(client_keypair_inner.public_key().id())
254        );
255        info!("🌐 Connecting to server: {}", server_addr);
256        info!(
257            "🔐 Server public key: {} ({})",
258            hex::encode(server_public_key.id()),
259            server_public_key.algorithm()
260        );
261
262        // Create client endpoint and establish QUIC connection
263        let client_endpoint = create_client_endpoint(server_public_key)?;
264        let connection = client_endpoint.connect(server_addr, "localhost")?.await?;
265
266        // Validate that the server supports our protocol
267        let client_protocol_config = zoe_wire_protocol::version::ClientProtocolConfig::default();
268        match zoe_wire_protocol::version::validate_server_protocol_support(
269            &connection,
270            &client_protocol_config,
271        ) {
272            Ok(negotiated_version) => {
273                info!(
274                    "✅ Connected to relay server with protocol: {}",
275                    negotiated_version
276                );
277            }
278            Err(e) => {
279                return Err(ClientError::ProtocolError(format!(
280                    "Server protocol validation failed: {e}"
281                )));
282            }
283        }
284
285        // No conversion needed - server_public_key is already a VerifyingKey
286
287        // Perform ML-DSA challenge-response handshake
288        let (send, recv) = connection.accept_bi().await?;
289        let Ok((verified_count, _)) = perform_client_challenge_handshake(
290            send,
291            recv,
292            server_public_key,
293            &[client_keypair_inner],
294        )
295        .await
296        else {
297            connection.close(0u32.into(), b"ML-DSA handshake failed");
298            client_endpoint.wait_idle().await;
299            return Err(anyhow::anyhow!("ML-DSA handshake failed").into());
300        };
301
302        info!(
303            "🔐 ML-DSA handshake completed: {} out of {} keys verified",
304            verified_count, 1
305        );
306
307        Ok((client_endpoint, connection))
308    }
309}
310
311// different services
312impl RelayClientInner {
313    pub async fn blob_service(&self) -> Result<&Arc<BlobService>> {
314        self.blob_service
315            .get_or_try_init(
316                async move { Ok(Arc::new(BlobService::connect(&self.connection).await?)) },
317            )
318            .await
319    }
320}
321
322// public methods
323impl RelayClient {
324    /// Get the client's inner protocol public key
325    pub fn public_key(&self) -> VerifyingKey {
326        self.inner.client_keypair_inner.public_key()
327    }
328
329    /// Get the client's inner protocol keypair
330    pub fn keypair(&self) -> &Arc<KeyPair> {
331        &self.inner.client_keypair_inner
332    }
333
334    /// Get the client's TLS public key (Ed25519 or ML-DSA-44)
335    pub fn tls_public_key(&self) -> VerifyingKey {
336        self.inner.client_keypair_tls.public_key()
337    }
338
339    pub fn connection(&self) -> &Connection {
340        &self.inner.connection
341    }
342
343    pub fn storage(&self) -> &Arc<SqliteMessageStorage> {
344        &self.inner.storage
345    }
346
347    pub async fn close(&self) {
348        self.inner.close().await;
349    }
350
351    /// Get access to the message persistence manager
352    pub async fn persistence_manager(&self) -> &MessagePersistenceManager {
353        &self.inner.persistence_manager
354    }
355
356    pub async fn session_manager(
357        &self,
358    ) -> &SessionManager<SqliteMessageStorage, MessagePersistenceManager> {
359        &self.inner.session_manager
360    }
361
362    pub async fn blob_service(&self) -> Result<&Arc<BlobService>> {
363        self.inner.blob_service().await
364    }
365}