zoe_wire_protocol/challenge/server.rs

1use super::{DEFAULT_CHALLENGE_TIMEOUT_SECS, MAX_PACKAGE_SIZE};
2use crate::{
3    KeyChallenge, KeyPair, KeyProof, KeyResponse, KeyResult, VerifyingKey, ZoeChallenge,
4    ZoeChallengeRejection, ZoeChallengeResult,
5};
6use anyhow::Result;
7use quinn::{RecvStream, SendStream};
8use rand::RngCore;
9use std::collections::HashSet;
10use std::time::{SystemTime, UNIX_EPOCH};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tracing::{debug, warn};
13
14/// Performs a multi-challenge handshake with a client
15///
16/// This function implements the server side of the new flexible challenge protocol:
17/// 1. Sends multiple challenges sequentially (ML-DSA, proof-of-work, etc.)
18/// 2. Receives and verifies each challenge response
19/// 3. Returns the set of successfully verified ML-DSA public keys
20///
21/// The server can send multiple different challenge types, and the client must
22/// respond to each one. The handshake continues until all challenges are completed
23/// or a challenge fails.
24///
25/// # Arguments
26///
27/// * `send` - Stream for sending data to the client
28/// * `recv` - Stream for receiving data from the client  
29/// * `server_keypair` - Server's keypair for signing the challenge nonce
30///
31/// # Returns
32///
33/// A `BTreeSet` of successfully verified public keys (as encoded bytes)
34///
35/// # Errors
36///
37/// Returns an error if:
38/// - Network I/O fails
39/// - Serialization/deserialization fails
40/// - Any challenge fails verification
41/// - Client response is malformed or too large
42///
43/// # Example
44///
45/// ```rust
46/// use zoe_relay::challenge::perform_multi_challenge_handshake;
47///
48/// let verified_keys = perform_multi_challenge_handshake(
49///     send_stream,
50///     recv_stream,
51///     &server_keypair
52/// ).await?;
53///
54/// debug!("Verified {} keys after multi-challenge handshake", verified_keys.len());
55/// ```
56pub async fn perform_multi_challenge_handshake(
57    mut send: SendStream,
58    mut recv: RecvStream,
59    server_keypair: &KeyPair,
60) -> Result<HashSet<VerifyingKey>> {
61    debug!("🔐 Starting multi-challenge handshake");
62    send_result(&mut send, &ZoeChallengeResult::Next).await?;
63
64    // Challenge 1: Key proof
65    debug!("📝 Sending key challenge");
66    let key_challenge = generate_key_challenge(server_keypair)?;
67    debug!("🔧 Generated challenge, sending to client...");
68    send_challenge(
69        &mut send,
70        &ZoeChallenge::Key(Box::new(key_challenge.clone())),
71    )
72    .await?;
73    debug!("✅ Challenge sent, waiting for client response...");
74
75    // Receive key response
76    debug!("📥 Waiting to receive key response from client...");
77    let key_response = receive_key_response(&mut recv).await?;
78    debug!(
79        "✅ Received key response with {} proofs",
80        key_response.key_proofs.len()
81    );
82
83    // Verify key proofs
84    let (keys, _key_result) = verify_key_proofs(key_response, &key_challenge)?;
85
86    if keys.is_empty() {
87        // Key challenge failed - send rejection and close
88        let result = ZoeChallengeResult::Rejected(ZoeChallengeRejection::ChallengeIncomplete);
89        send_result(&mut send, &result).await?;
90        return Err(anyhow::anyhow!("Key challenge failed: no valid key proofs"));
91    }
92
93    //  We can add more challenge types here (proof-of-work, etc.)
94    // For now, we only have ML-DSA, so we send Accepted
95
96    debug!("✅ All challenges completed successfully");
97    let result = ZoeChallengeResult::Accepted;
98    send_result(&mut send, &result).await?;
99
100    debug!(
101        "✅ Multi-challenge handshake completed. Verified {} keys",
102        keys.len()
103    );
104
105    Ok(keys)
106}
107
108/// Generates a new key challenge with a random nonce
109///
110/// The challenge includes:
111/// - A cryptographically random 32-byte nonce
112/// - The server's signature over the nonce (for server authentication)
113/// - An expiration timestamp (current time + timeout)
114///
115/// # Arguments
116///
117/// * `server_keypair` - Server's keypair for signing the nonce
118///
119/// # Returns
120///
121/// A `KeyChallenge` containing the challenge data
122pub fn generate_key_challenge(server_keypair: &KeyPair) -> Result<KeyChallenge> {
123    let mut nonce = [0u8; 32];
124    rand::thread_rng().fill_bytes(&mut nonce);
125
126    let expires_at =
127        SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() + DEFAULT_CHALLENGE_TIMEOUT_SECS;
128
129    // Server signs the nonce to prove its identity
130    let server_signature = server_keypair.sign(&nonce);
131
132    let challenge_data = KeyChallenge {
133        nonce,
134        signature: server_signature,
135        expires_at,
136    };
137
138    debug!(
139        "Generated key challenge with nonce: {} expires at: {}",
140        hex::encode(&nonce[..8]),
141        expires_at
142    );
143
144    Ok(challenge_data)
145}
146
147/// Sends a challenge to the client over the stream
148///
149/// Serializes the challenge using postcard and sends it with a length prefix.
150///
151/// # Arguments
152///
153/// * `send` - Stream to send the challenge on
154/// * `challenge` - Challenge to send
155pub async fn send_challenge(send: &mut SendStream, challenge: &ZoeChallenge) -> Result<()> {
156    let challenge_bytes = postcard::to_stdvec(challenge)?;
157
158    debug!("Sending challenge ({} bytes)", challenge_bytes.len());
159
160    // Send length prefix (4 bytes, big endian)
161    send.write_u32(challenge_bytes.len() as u32).await?;
162
163    // Send challenge data
164    send.write_all(&challenge_bytes).await?;
165
166    Ok(())
167}
168
169/// Receives a key challenge response from the client
170///
171/// Reads the response with length prefix and deserializes it directly
172/// as a KeyResponse (no wrapper enum).
173///
174/// # Arguments
175///
176/// * `recv` - Stream to receive the response from
177///
178/// # Returns
179///
180/// The parsed `KeyResponse` from the client
181pub async fn receive_key_response(recv: &mut RecvStream) -> Result<KeyResponse> {
182    // Read length prefix
183    let response_len = recv.read_u32().await? as usize;
184
185    if response_len > MAX_PACKAGE_SIZE {
186        return Err(anyhow::anyhow!(
187            "Response too large: {} bytes (max: {})",
188            response_len,
189            MAX_PACKAGE_SIZE
190        ));
191    }
192
193    debug!("Receiving response ({} bytes)", response_len);
194
195    // Read response data
196    let mut response_buf = vec![0u8; response_len];
197    recv.read_exact(&mut response_buf).await?;
198
199    // Parse response directly as KeyResponse
200    let response: KeyResponse = postcard::from_bytes(&response_buf)?;
201
202    debug!(
203        "Received key response with {} key proofs",
204        response.key_proofs.len()
205    );
206    Ok(response)
207}
208
209/// Verifies all key proofs in a response
210///
211/// Each key proof is verified independently. The function continues even if some
212/// proofs fail, collecting all successful verifications.
213///
214/// # Arguments
215///
216/// * `response` - Client's response containing key proofs
217/// * `challenge` - Original key challenge (needed for signature verification)
218///
219/// # Returns
220///
221/// A tuple containing:
222/// - Set of successfully verified public keys (as encoded bytes)
223/// - Key specific result indicating which proofs succeeded/failed
224pub fn verify_key_proofs(
225    response: KeyResponse,
226    challenge: &KeyChallenge,
227) -> Result<(HashSet<VerifyingKey>, KeyResult)> {
228    let challenge_data = challenge;
229
230    // Check if challenge has expired
231    let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
232    if current_time > challenge_data.expires_at {
233        warn!(
234            "Challenge expired: current={}, expires={}",
235            current_time, challenge_data.expires_at
236        );
237        return Ok((HashSet::new(), KeyResult::AllFailed));
238    }
239
240    let mut verified_keys = HashSet::new();
241    let mut failed_indices = Vec::new();
242
243    // Prepare signature data: just the nonce (clients sign the nonce)
244    let signature_data = challenge_data.nonce.to_vec();
245    let total_key_proofs = response.key_proofs.len();
246    debug!("Verifying {} key proofs", total_key_proofs);
247
248    for (index, key_proof) in response.key_proofs.into_iter().enumerate() {
249        match verify_single_key_proof(&key_proof, &signature_data) {
250            Ok(()) => {
251                debug!(
252                    "✅ Verified key proof {}: {}",
253                    index,
254                    hex::encode(&key_proof.public_key.encode()[..8])
255                );
256                verified_keys.insert(key_proof.public_key);
257            }
258            Err(e) => {
259                failed_indices.push(index);
260                warn!("❌ Key proof {} failed: {}", index, e);
261            }
262        }
263    }
264
265    let result = if failed_indices.is_empty() {
266        KeyResult::AllValid
267    } else if verified_keys.is_empty() {
268        KeyResult::AllFailed
269    } else {
270        KeyResult::PartialFailure { failed_indices }
271    };
272
273    debug!(
274        "Verification complete: {}/{} keys verified",
275        verified_keys.len(),
276        total_key_proofs
277    );
278
279    Ok((verified_keys, result))
280}
281
282/// Verifies a single key proof
283///
284/// Uses the public key and signature from the key proof to verify the signature over
285/// the challenge data.
286///
287/// # Arguments
288///
289/// * `key_proof` - The key proof to verify
290/// * `signature_data` - The data that should have been signed (nonce)
291///
292/// # Returns
293///
294/// `Ok(())` if verification succeeds, `Err` with details if it fails
295fn verify_single_key_proof(key_proof: &KeyProof, signature_data: &[u8]) -> Result<()> {
296    // Use the public key and signature directly from the key proof
297    let verifying_key = &key_proof.public_key;
298    let signature = &key_proof.signature;
299
300    // Verify the signature - returns Result<bool, _>
301    verifying_key
302        .verify(signature_data, signature)
303        .map_err(|e| anyhow::anyhow!("Signature verification failed: {}", e))
304}
305
306/// Sends the challenge result back to the client
307///
308/// # Arguments
309///
310/// * `send` - Stream to send the result on
311/// * `result` - Challenge result to send (Accepted, Next, Rejected, Error)
312pub async fn send_result(send: &mut SendStream, result: &ZoeChallengeResult) -> Result<()> {
313    let result_bytes = postcard::to_stdvec(result)?;
314
315    debug!("Sending result ({} bytes)", result_bytes.len());
316
317    // Send length prefix (4 bytes, big endian)
318    send.write_u32(result_bytes.len() as u32).await?;
319
320    // Send result data
321    send.write_all(&result_bytes).await?;
322
323    Ok(())
324}
325
326/// Create key proofs for a challenge response (used in tests)
327///
328/// This function creates key proofs for the given keypairs in response to a challenge.
329/// It's primarily used in integration tests.
330pub fn create_key_proofs(challenge: &KeyChallenge, keypairs: &[&KeyPair]) -> Result<KeyResponse> {
331    let mut key_proofs = Vec::new();
332
333    // Construct the signature data (just the nonce)
334    let signature_data = challenge.nonce.to_vec();
335
336    // Create a proof for each keypair
337    for keypair in keypairs {
338        let signature = keypair.sign(&signature_data);
339        let verifying_key = keypair.public_key();
340
341        let key_proof = KeyProof {
342            public_key: verifying_key,
343            signature,
344        };
345        key_proofs.push(key_proof);
346    }
347
348    Ok(KeyResponse { key_proofs })
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KeyPair;

    #[test]
    fn test_key_challenge_generation() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let challenge = generate_key_challenge(&server_keypair).unwrap();

        // A freshly generated challenge must expire in the future.
        assert!(
            challenge.expires_at
                > SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_secs()
        );
    }

    #[test]
    fn test_single_key_proof_verification() {
        // Generate test keys
        let client_keypair = KeyPair::generate(&mut rand::thread_rng());

        // Create signature data (just the nonce)
        let nonce = [42u8; 32];
        let signature_data = nonce.to_vec();

        // Create signature
        let signature = client_keypair.sign(&signature_data);
        let verifying_key = client_keypair.public_key();

        // Create key proof
        let key_proof = KeyProof {
            public_key: verifying_key,
            signature,
        };

        // Verify proof
        let result = verify_single_key_proof(&key_proof, &signature_data);
        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_signature_fails() {
        // Generate test keys
        let client_keypair = KeyPair::generate(&mut rand::thread_rng());

        // Create signature data
        let signature_data = b"test data";

        // Create signature over different data
        let wrong_signature = client_keypair.sign(b"wrong data");
        let verifying_key = client_keypair.public_key();

        // Create key proof with wrong signature
        let key_proof = KeyProof {
            public_key: verifying_key,
            signature: wrong_signature,
        };

        // Verify proof should fail
        let result = verify_single_key_proof(&key_proof, signature_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_key_challenge() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());

        let challenge = generate_key_challenge(&server_keypair).unwrap();

        // Check that nonce is 32 bytes
        assert_eq!(challenge.nonce.len(), 32);

        // Check that the server's signature over the nonce verifies
        let server_public_key = server_keypair.public_key();
        assert!(server_public_key
            .verify(&challenge.nonce, &challenge.signature)
            .is_ok());

        // Check that expiration is in the future
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        assert!(challenge.expires_at > now);
    }

    #[test]
    fn test_create_key_proofs() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let client_keypair1 = KeyPair::generate(&mut rand::thread_rng());
        let client_keypair2 = KeyPair::generate(&mut rand::thread_rng());

        let challenge = generate_key_challenge(&server_keypair).unwrap();
        let client_keys = [&client_keypair1, &client_keypair2];

        let response = create_key_proofs(&challenge, &client_keys).unwrap();

        // Should have proofs for both keys
        assert_eq!(response.key_proofs.len(), 2);

        // Each proof should verify
        for (i, proof) in response.key_proofs.iter().enumerate() {
            let result = verify_single_key_proof(proof, &challenge.nonce);
            assert!(result.is_ok(), "Proof {i} should verify");
        }
    }

    #[test]
    fn test_verify_key_proofs() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let client_keypair1 = KeyPair::generate(&mut rand::thread_rng());
        let client_keypair2 = KeyPair::generate(&mut rand::thread_rng());

        let challenge = generate_key_challenge(&server_keypair).unwrap();
        let client_keys = [&client_keypair1, &client_keypair2];

        let response = create_key_proofs(&challenge, &client_keys).unwrap();

        let (verified_keys, result) = verify_key_proofs(response, &challenge).unwrap();

        // Should verify both keys
        assert_eq!(verified_keys.len(), 2);
        assert!(matches!(result, KeyResult::AllValid));

        // Verified keys should match the client public keys
        assert!(verified_keys.contains(&client_keypair1.public_key()));
        assert!(verified_keys.contains(&client_keypair2.public_key()));
    }

    #[test]
    fn test_verify_key_proofs_partial_failure() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let client_keypair1 = KeyPair::generate(&mut rand::thread_rng());
        let client_keypair2 = KeyPair::generate(&mut rand::thread_rng());

        let challenge = generate_key_challenge(&server_keypair).unwrap();

        // Create one valid proof and one invalid proof
        let valid_signature = client_keypair1.sign(&challenge.nonce);
        let invalid_signature = client_keypair2.sign(b"wrong data");

        let response = KeyResponse {
            key_proofs: vec![
                KeyProof {
                    public_key: client_keypair1.public_key(),
                    signature: valid_signature,
                },
                KeyProof {
                    public_key: client_keypair2.public_key(),
                    signature: invalid_signature,
                },
            ],
        };

        let (verified_keys, result) = verify_key_proofs(response, &challenge).unwrap();

        // Should verify only one key
        assert_eq!(verified_keys.len(), 1);
        assert!(
            matches!(result, KeyResult::PartialFailure { failed_indices } if failed_indices == vec![1])
        );

        // Only the valid key should be verified
        assert!(verified_keys.contains(&client_keypair1.public_key()));
    }

    #[test]
    fn test_verify_key_proofs_expired_challenge() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let client_keypair = KeyPair::generate(&mut rand::thread_rng());

        let mut challenge = generate_key_challenge(&server_keypair).unwrap();
        let response = create_key_proofs(&challenge, &[&client_keypair]).unwrap();

        // Push the expiry into the past: even otherwise-valid proofs must be rejected.
        challenge.expires_at = 0;

        let (verified_keys, result) = verify_key_proofs(response, &challenge).unwrap();
        assert!(verified_keys.is_empty());
        assert!(matches!(result, KeyResult::AllFailed));
    }

    #[test]
    fn test_verify_key_proofs_empty_response() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let challenge = generate_key_challenge(&server_keypair).unwrap();

        let response = KeyResponse { key_proofs: vec![] };
        let (verified_keys, _result) = verify_key_proofs(response, &challenge).unwrap();

        // No proofs means no verified keys; the handshake rejects on an empty set.
        assert!(verified_keys.is_empty());
    }

    #[test]
    fn test_challenge_serialization_roundtrip() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let challenge = generate_key_challenge(&server_keypair).unwrap();

        // Test serialization and deserialization
        let serialized = postcard::to_stdvec(&challenge).unwrap();
        let deserialized: KeyChallenge = postcard::from_bytes(&serialized).unwrap();

        // Should be identical
        assert_eq!(challenge.nonce, deserialized.nonce);
        assert_eq!(challenge.expires_at, deserialized.expires_at);
        // Note: Signature comparison would need custom implementation
    }

    #[test]
    fn test_response_serialization_roundtrip() {
        let client_keypair = KeyPair::generate(&mut rand::thread_rng());
        let signature = client_keypair.sign(b"test data");

        let response = KeyResponse {
            key_proofs: vec![KeyProof {
                public_key: client_keypair.public_key(),
                signature,
            }],
        };

        // Test serialization and deserialization
        let serialized = postcard::to_stdvec(&response).unwrap();
        let deserialized: KeyResponse = postcard::from_bytes(&serialized).unwrap();

        // Should be identical
        assert_eq!(response.key_proofs.len(), deserialized.key_proofs.len());
    }
}