zoe_wire_protocol/challenge/
server.rs1use super::{DEFAULT_CHALLENGE_TIMEOUT_SECS, MAX_PACKAGE_SIZE};
2use crate::{
3 KeyChallenge, KeyPair, KeyProof, KeyResponse, KeyResult, VerifyingKey, ZoeChallenge,
4 ZoeChallengeRejection, ZoeChallengeResult,
5};
6use anyhow::Result;
7use quinn::{RecvStream, SendStream};
8use rand::RngCore;
9use std::collections::HashSet;
10use std::time::{SystemTime, UNIX_EPOCH};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tracing::{debug, warn};
13
14pub async fn perform_multi_challenge_handshake(
57 mut send: SendStream,
58 mut recv: RecvStream,
59 server_keypair: &KeyPair,
60) -> Result<HashSet<VerifyingKey>> {
61 debug!("🔐 Starting multi-challenge handshake");
62 send_result(&mut send, &ZoeChallengeResult::Next).await?;
63
64 debug!("📝 Sending key challenge");
66 let key_challenge = generate_key_challenge(server_keypair)?;
67 debug!("🔧 Generated challenge, sending to client...");
68 send_challenge(
69 &mut send,
70 &ZoeChallenge::Key(Box::new(key_challenge.clone())),
71 )
72 .await?;
73 debug!("✅ Challenge sent, waiting for client response...");
74
75 debug!("📥 Waiting to receive key response from client...");
77 let key_response = receive_key_response(&mut recv).await?;
78 debug!(
79 "✅ Received key response with {} proofs",
80 key_response.key_proofs.len()
81 );
82
83 let (keys, _key_result) = verify_key_proofs(key_response, &key_challenge)?;
85
86 if keys.is_empty() {
87 let result = ZoeChallengeResult::Rejected(ZoeChallengeRejection::ChallengeIncomplete);
89 send_result(&mut send, &result).await?;
90 return Err(anyhow::anyhow!("Key challenge failed: no valid key proofs"));
91 }
92
93 debug!("✅ All challenges completed successfully");
97 let result = ZoeChallengeResult::Accepted;
98 send_result(&mut send, &result).await?;
99
100 debug!(
101 "✅ Multi-challenge handshake completed. Verified {} keys",
102 keys.len()
103 );
104
105 Ok(keys)
106}
107
108pub fn generate_key_challenge(server_keypair: &KeyPair) -> Result<KeyChallenge> {
123 let mut nonce = [0u8; 32];
124 rand::thread_rng().fill_bytes(&mut nonce);
125
126 let expires_at =
127 SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() + DEFAULT_CHALLENGE_TIMEOUT_SECS;
128
129 let server_signature = server_keypair.sign(&nonce);
131
132 let challenge_data = KeyChallenge {
133 nonce,
134 signature: server_signature,
135 expires_at,
136 };
137
138 debug!(
139 "Generated key challenge with nonce: {} expires at: {}",
140 hex::encode(&nonce[..8]),
141 expires_at
142 );
143
144 Ok(challenge_data)
145}
146
147pub async fn send_challenge(send: &mut SendStream, challenge: &ZoeChallenge) -> Result<()> {
156 let challenge_bytes = postcard::to_stdvec(challenge)?;
157
158 debug!("Sending challenge ({} bytes)", challenge_bytes.len());
159
160 send.write_u32(challenge_bytes.len() as u32).await?;
162
163 send.write_all(&challenge_bytes).await?;
165
166 Ok(())
167}
168
169pub async fn receive_key_response(recv: &mut RecvStream) -> Result<KeyResponse> {
182 let response_len = recv.read_u32().await? as usize;
184
185 if response_len > MAX_PACKAGE_SIZE {
186 return Err(anyhow::anyhow!(
187 "Response too large: {} bytes (max: {})",
188 response_len,
189 MAX_PACKAGE_SIZE
190 ));
191 }
192
193 debug!("Receiving response ({} bytes)", response_len);
194
195 let mut response_buf = vec![0u8; response_len];
197 recv.read_exact(&mut response_buf).await?;
198
199 let response: KeyResponse = postcard::from_bytes(&response_buf)?;
201
202 debug!(
203 "Received key response with {} key proofs",
204 response.key_proofs.len()
205 );
206 Ok(response)
207}
208
209pub fn verify_key_proofs(
225 response: KeyResponse,
226 challenge: &KeyChallenge,
227) -> Result<(HashSet<VerifyingKey>, KeyResult)> {
228 let challenge_data = challenge;
229
230 let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
232 if current_time > challenge_data.expires_at {
233 warn!(
234 "Challenge expired: current={}, expires={}",
235 current_time, challenge_data.expires_at
236 );
237 return Ok((HashSet::new(), KeyResult::AllFailed));
238 }
239
240 let mut verified_keys = HashSet::new();
241 let mut failed_indices = Vec::new();
242
243 let signature_data = challenge_data.nonce.to_vec();
245 let total_key_proofs = response.key_proofs.len();
246 debug!("Verifying {} key proofs", total_key_proofs);
247
248 for (index, key_proof) in response.key_proofs.into_iter().enumerate() {
249 match verify_single_key_proof(&key_proof, &signature_data) {
250 Ok(()) => {
251 debug!(
252 "✅ Verified key proof {}: {}",
253 index,
254 hex::encode(&key_proof.public_key.encode()[..8])
255 );
256 verified_keys.insert(key_proof.public_key);
257 }
258 Err(e) => {
259 failed_indices.push(index);
260 warn!("❌ Key proof {} failed: {}", index, e);
261 }
262 }
263 }
264
265 let result = if failed_indices.is_empty() {
266 KeyResult::AllValid
267 } else if verified_keys.is_empty() {
268 KeyResult::AllFailed
269 } else {
270 KeyResult::PartialFailure { failed_indices }
271 };
272
273 debug!(
274 "Verification complete: {}/{} keys verified",
275 verified_keys.len(),
276 total_key_proofs
277 );
278
279 Ok((verified_keys, result))
280}
281
282fn verify_single_key_proof(key_proof: &KeyProof, signature_data: &[u8]) -> Result<()> {
296 let verifying_key = &key_proof.public_key;
298 let signature = &key_proof.signature;
299
300 verifying_key
302 .verify(signature_data, signature)
303 .map_err(|e| anyhow::anyhow!("Signature verification failed: {}", e))
304}
305
306pub async fn send_result(send: &mut SendStream, result: &ZoeChallengeResult) -> Result<()> {
313 let result_bytes = postcard::to_stdvec(result)?;
314
315 debug!("Sending result ({} bytes)", result_bytes.len());
316
317 send.write_u32(result_bytes.len() as u32).await?;
319
320 send.write_all(&result_bytes).await?;
322
323 Ok(())
324}
325
326pub fn create_key_proofs(challenge: &KeyChallenge, keypairs: &[&KeyPair]) -> Result<KeyResponse> {
331 let mut key_proofs = Vec::new();
332
333 let signature_data = challenge.nonce.to_vec();
335
336 for keypair in keypairs {
338 let signature = keypair.sign(&signature_data);
339 let verifying_key = keypair.public_key();
340
341 let key_proof = KeyProof {
342 public_key: verifying_key,
343 signature,
344 };
345 key_proofs.push(key_proof);
346 }
347
348 Ok(KeyResponse { key_proofs })
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KeyPair;

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Server identity for tests: challenges are signed with an Ed25519 key.
    fn server_identity() -> KeyPair {
        KeyPair::generate_ed25519(&mut rand::thread_rng())
    }

    /// Client identity for tests, using the default `KeyPair::generate` algorithm.
    fn client_identity() -> KeyPair {
        KeyPair::generate(&mut rand::thread_rng())
    }

    #[test]
    fn test_key_challenge_generation() {
        let challenge = generate_key_challenge(&server_identity()).unwrap();

        assert!(challenge.expires_at > unix_now());
    }

    #[test]
    fn test_single_key_proof_verification() {
        let client = client_identity();
        let message = [42u8; 32].to_vec();

        let proof = KeyProof {
            signature: client.sign(&message),
            public_key: client.public_key(),
        };

        assert!(verify_single_key_proof(&proof, &message).is_ok());
    }

    #[test]
    fn test_invalid_signature_fails() {
        let client = client_identity();

        // Signed over different bytes than the ones we verify against.
        let proof = KeyProof {
            signature: client.sign(b"wrong data"),
            public_key: client.public_key(),
        };

        assert!(verify_single_key_proof(&proof, b"test data").is_err());
    }

    #[test]
    fn test_generate_key_challenge() {
        let server = server_identity();
        let challenge = generate_key_challenge(&server).unwrap();

        assert_eq!(challenge.nonce.len(), 32);
        assert!(server
            .public_key()
            .verify(&challenge.nonce, &challenge.signature)
            .is_ok());
        assert!(challenge.expires_at > unix_now());
    }

    #[test]
    fn test_create_key_proofs() {
        let server = server_identity();
        let first = client_identity();
        let second = client_identity();
        let challenge = generate_key_challenge(&server).unwrap();

        let response = create_key_proofs(&challenge, &[&first, &second]).unwrap();
        assert_eq!(response.key_proofs.len(), 2);

        response
            .key_proofs
            .iter()
            .enumerate()
            .for_each(|(i, proof)| {
                assert!(
                    verify_single_key_proof(proof, &challenge.nonce).is_ok(),
                    "Proof {i} should verify"
                );
            });
    }

    #[test]
    fn test_verify_key_proofs() {
        let server = server_identity();
        let alice = client_identity();
        let bob = client_identity();
        let challenge = generate_key_challenge(&server).unwrap();
        let response = create_key_proofs(&challenge, &[&alice, &bob]).unwrap();

        let (verified, outcome) = verify_key_proofs(response, &challenge).unwrap();

        assert_eq!(verified.len(), 2);
        assert!(matches!(outcome, KeyResult::AllValid));
        for keypair in [&alice, &bob] {
            assert!(verified.contains(&keypair.public_key()));
        }
    }

    #[test]
    fn test_verify_key_proofs_partial_failure() {
        let server = server_identity();
        let honest = client_identity();
        let forger = client_identity();
        let challenge = generate_key_challenge(&server).unwrap();

        let response = KeyResponse {
            key_proofs: vec![
                KeyProof {
                    public_key: honest.public_key(),
                    signature: honest.sign(&challenge.nonce),
                },
                KeyProof {
                    public_key: forger.public_key(),
                    signature: forger.sign(b"wrong data"),
                },
            ],
        };

        let (verified, outcome) = verify_key_proofs(response, &challenge).unwrap();

        assert_eq!(verified.len(), 1);
        assert!(
            matches!(outcome, KeyResult::PartialFailure { failed_indices } if failed_indices == vec![1])
        );
        assert!(verified.contains(&honest.public_key()));
    }

    #[test]
    fn test_challenge_serialization_roundtrip() {
        let challenge = generate_key_challenge(&server_identity()).unwrap();

        let bytes = postcard::to_stdvec(&challenge).unwrap();
        let decoded: KeyChallenge = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(challenge.nonce, decoded.nonce);
        assert_eq!(challenge.expires_at, decoded.expires_at);
    }

    #[test]
    fn test_response_serialization_roundtrip() {
        let client = client_identity();
        let response = KeyResponse {
            key_proofs: vec![KeyProof {
                public_key: client.public_key(),
                signature: client.sign(b"test data"),
            }],
        };

        let bytes = postcard::to_stdvec(&response).unwrap();
        let decoded: KeyResponse = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(response.key_proofs.len(), decoded.key_proofs.len());
    }
}