1use crate::SessionManager;
2use crate::challenge::perform_client_challenge_handshake;
3use crate::error::{ClientError, Result};
4use crate::services::{BlobService, MessagePersistenceManager, MessagePersistenceManagerBuilder};
5use async_once_cell::OnceCell;
6use quinn::{Connection, Endpoint};
7use std::net::SocketAddr;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tracing::info;
11use zoe_client_storage::{SqliteMessageStorage, StorageConfig as DbConfig};
12use zoe_wire_protocol::{KeyPair, VerifyingKey, connection::client::create_client_endpoint};
13
/// Shared state behind every [`RelayClient`] clone.
///
/// NOTE: struct fields are dropped in declaration order; keep this order stable
/// unless a different teardown order is intended.
struct RelayClientInner {
    /// Ed25519 keypair generated fresh in `RelayClientBuilder::build`.
    ///
    /// NOTE(review): `connect_with_transport_keys` only logs this key; it is not
    /// passed to `create_client_endpoint`. Confirm whether the TLS layer uses it.
    client_keypair_tls: KeyPair,
    /// Inner keypair proven to the relay in the ML-DSA challenge handshake.
    client_keypair_inner: Arc<KeyPair>,
    /// QUIC connection to the relay server.
    connection: Connection,
    /// Blob service, connected lazily over `connection` on first request.
    blob_service: OnceCell<Arc<BlobService>>,
    /// Message persistence manager built over `storage` and `connection`.
    persistence_manager: Arc<MessagePersistenceManager>,
    /// Session manager built from `storage`, `persistence_manager` and the inner keypair.
    session_manager: SessionManager<SqliteMessageStorage, MessagePersistenceManager>,
    /// Message storage, either supplied to the builder or created by it.
    storage: Arc<SqliteMessageStorage>,
    /// Client endpoint the connection was opened from; `close` waits on it to go idle.
    endpoint: Endpoint,
}
24
/// Builder for [`RelayClient`].
///
/// `server_public_key` and `server_address` are required. An `encryption_key` is
/// required unless an already-open `storage` is supplied.
pub struct RelayClientBuilder {
    /// Inner (application-level) keypair; `build` generates one when unset.
    client_keypair_inner: Option<Arc<KeyPair>>,
    /// Relay server's public key (required).
    server_public_key: Option<VerifyingKey>,
    /// Relay server's socket address (required).
    server_address: Option<SocketAddr>,
    /// Database configuration used when `build` creates storage; default config when unset.
    db_config: Option<DbConfig>,
    /// Key used to create storage; only required when `storage` is unset.
    encryption_key: Option<[u8; 32]>,
    /// Pre-existing storage; when set, `db_config` and `encryption_key` are not used.
    storage: Option<Arc<SqliteMessageStorage>>,
    /// Forwarded to the message persistence manager's `autosubscribe` option.
    autosubscribe: bool,
    /// Forwarded to the message persistence manager; `build` uses 1000 when unset.
    buffer_size: Option<usize>,
}
63
64impl RelayClientBuilder {
65 pub fn new() -> Self {
67 Self {
68 client_keypair_inner: None,
69 server_public_key: None,
70 server_address: None,
71 db_config: None,
72 encryption_key: None,
73 storage: None,
74 autosubscribe: false,
75 buffer_size: None,
76 }
77 }
78
79 pub fn client_keypair(mut self, keypair: Arc<KeyPair>) -> Self {
82 self.client_keypair_inner = Some(keypair);
83 self
84 }
85
86 pub fn server_public_key(mut self, key: VerifyingKey) -> Self {
88 self.server_public_key = Some(key);
89 self
90 }
91
92 pub fn server_address(mut self, addr: SocketAddr) -> Self {
94 self.server_address = Some(addr);
95 self
96 }
97
98 pub fn db_config(mut self, config: DbConfig) -> Self {
100 self.db_config = Some(config);
101 self
102 }
103
104 pub fn db_storage_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
106 let mut config = self.db_config.unwrap_or_default();
107 config.database_path = path.into();
108 self.db_config = Some(config);
109 self
110 }
111
112 pub fn encryption_key(mut self, key: [u8; 32]) -> Self {
114 self.encryption_key = Some(key);
115 self
116 }
117
118 pub fn storage(mut self, storage: Arc<SqliteMessageStorage>) -> Self {
123 self.storage = Some(storage);
124 self
125 }
126
127 pub fn autosubscribe(mut self, enable: bool) -> Self {
129 self.autosubscribe = enable;
130 self
131 }
132
133 pub fn buffer_size(mut self, size: usize) -> Self {
135 self.buffer_size = Some(size);
136 self
137 }
138
139 pub async fn build(self) -> Result<RelayClient> {
143 let server_public_key = self
144 .server_public_key
145 .ok_or_else(|| ClientError::Generic("Server public key is required".to_string()))?;
146
147 let server_address = self
148 .server_address
149 .ok_or_else(|| ClientError::Generic("Server address is required".to_string()))?;
150
151 let encryption_key = if self.storage.is_some() {
153 self.encryption_key.unwrap_or([0u8; 32]) } else {
155 self.encryption_key.ok_or_else(|| {
156 ClientError::Generic(
157 "Encryption key is required when no storage is provided".to_string(),
158 )
159 })?
160 };
161
162 let client_keypair_inner = self
163 .client_keypair_inner
164 .unwrap_or_else(|| Arc::new(KeyPair::generate(&mut rand::thread_rng())));
165
166 let client_keypair_tls = KeyPair::generate_ed25519(&mut rand::thread_rng());
168
169 let (endpoint, connection) = RelayClientInner::connect_with_transport_keys(
171 &client_keypair_tls,
172 &client_keypair_inner,
173 server_address,
174 &server_public_key,
175 )
176 .await?;
177
178 let storage = if let Some(storage) = self.storage {
180 storage
181 } else {
182 let db_config = self.db_config.unwrap_or_default();
183 Arc::new(
184 SqliteMessageStorage::new(db_config, &encryption_key)
185 .await
186 .map_err(|e| ClientError::Generic(format!("Failed to create storage: {e}")))?,
187 )
188 };
189
190 let persistence_manager = Arc::new(
192 MessagePersistenceManagerBuilder::new()
193 .storage(storage.clone())
194 .relay_pubkey(server_public_key)
195 .autosubscribe(self.autosubscribe)
196 .buffer_size(self.buffer_size.unwrap_or(1000))
197 .build(&connection)
198 .await?,
199 );
200
201 let session_manager = SessionManager::builder(storage.clone(), persistence_manager.clone())
202 .client_keypair(client_keypair_inner.clone())
203 .build()
204 .await?;
205
206 Ok(RelayClient {
207 inner: Arc::new(RelayClientInner {
208 client_keypair_tls,
209 client_keypair_inner,
210 storage,
211 connection,
212 persistence_manager,
213 blob_service: OnceCell::new(),
214 session_manager,
215 endpoint,
216 }),
217 })
218 }
219}
220
221impl Default for RelayClientBuilder {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
/// Handle to a relay client, produced by `RelayClientBuilder::build`.
///
/// Cloning is cheap. All clones share the same connection, storage and managers
/// through one `Arc`, so closing any clone closes the shared connection.
#[derive(Clone)]
pub struct RelayClient {
    /// Shared client state.
    inner: Arc<RelayClientInner>,
}
232
233impl RelayClientInner {
234 async fn close(&self) {
235 self.connection.close(0u32.into(), b"Client closed");
236 self.endpoint.wait_idle().await;
237 }
238 async fn connect_with_transport_keys(
240 client_keypair_tls: &KeyPair, client_keypair_inner: &KeyPair, server_addr: SocketAddr,
243 server_public_key: &VerifyingKey,
244 ) -> Result<(Endpoint, Connection)> {
245 info!("🚀 Starting relay client with transport keys");
246 info!(
247 "🔑 Client TLS key: {} ({})",
248 client_keypair_tls.public_key(),
249 client_keypair_tls.algorithm()
250 );
251 info!(
252 "🔑 Client inner public key id: {}",
253 hex::encode(client_keypair_inner.public_key().id())
254 );
255 info!("🌐 Connecting to server: {}", server_addr);
256 info!(
257 "🔐 Server public key: {} ({})",
258 hex::encode(server_public_key.id()),
259 server_public_key.algorithm()
260 );
261
262 let client_endpoint = create_client_endpoint(server_public_key)?;
264 let connection = client_endpoint.connect(server_addr, "localhost")?.await?;
265
266 let client_protocol_config = zoe_wire_protocol::version::ClientProtocolConfig::default();
268 match zoe_wire_protocol::version::validate_server_protocol_support(
269 &connection,
270 &client_protocol_config,
271 ) {
272 Ok(negotiated_version) => {
273 info!(
274 "✅ Connected to relay server with protocol: {}",
275 negotiated_version
276 );
277 }
278 Err(e) => {
279 return Err(ClientError::ProtocolError(format!(
280 "Server protocol validation failed: {e}"
281 )));
282 }
283 }
284
285 let (send, recv) = connection.accept_bi().await?;
289 let Ok((verified_count, _)) = perform_client_challenge_handshake(
290 send,
291 recv,
292 server_public_key,
293 &[client_keypair_inner],
294 )
295 .await
296 else {
297 connection.close(0u32.into(), b"ML-DSA handshake failed");
298 client_endpoint.wait_idle().await;
299 return Err(anyhow::anyhow!("ML-DSA handshake failed").into());
300 };
301
302 info!(
303 "🔐 ML-DSA handshake completed: {} out of {} keys verified",
304 verified_count, 1
305 );
306
307 Ok((client_endpoint, connection))
308 }
309}
310
311impl RelayClientInner {
313 pub async fn blob_service(&self) -> Result<&Arc<BlobService>> {
314 self.blob_service
315 .get_or_try_init(
316 async move { Ok(Arc::new(BlobService::connect(&self.connection).await?)) },
317 )
318 .await
319 }
320}
321
322impl RelayClient {
324 pub fn public_key(&self) -> VerifyingKey {
326 self.inner.client_keypair_inner.public_key()
327 }
328
329 pub fn keypair(&self) -> &Arc<KeyPair> {
331 &self.inner.client_keypair_inner
332 }
333
334 pub fn tls_public_key(&self) -> VerifyingKey {
336 self.inner.client_keypair_tls.public_key()
337 }
338
339 pub fn connection(&self) -> &Connection {
340 &self.inner.connection
341 }
342
343 pub fn storage(&self) -> &Arc<SqliteMessageStorage> {
344 &self.inner.storage
345 }
346
347 pub async fn close(&self) {
348 self.inner.close().await;
349 }
350
351 pub async fn persistence_manager(&self) -> &MessagePersistenceManager {
353 &self.inner.persistence_manager
354 }
355
356 pub async fn session_manager(
357 &self,
358 ) -> &SessionManager<SqliteMessageStorage, MessagePersistenceManager> {
359 &self.inner.session_manager
360 }
361
362 pub async fn blob_service(&self) -> Result<&Arc<BlobService>> {
363 self.inner.blob_service().await
364 }
365}