zoe_wire_protocol/challenge/client.rs

1use super::MAX_PACKAGE_SIZE;
2use crate::{
3    keys::*, KeyChallenge, KeyProof, KeyResponse, ZoeChallenge, ZoeChallengeResult,
4    ZoeChallengeWarning,
5};
6use anyhow::Result;
7use quinn::{RecvStream, SendStream};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tracing::{debug, warn};
10
11/// Performs the client side of the challenge-response handshake
12///
13/// This function implements the client side of the challenge protocol:
14/// 1. Receives a challenge from the server
15/// 2. Creates proofs for all provided keys
16/// 3. Sends the response to the server
17/// 4. Receives and processes the verification result
18///
19/// # Arguments
20///
21/// * `send` - Stream for sending data to the server
22/// * `recv` - Stream for receiving data from the server
23/// * `key_pairs` - Slice of signing keys to prove possession of
24///
25/// # Returns
26///
27/// The number of keys that were successfully verified by the server
28///
29/// # Errors
30///
31/// Returns an error if:
32/// - Network I/O fails
33/// - Serialization/deserialization fails
34/// - All key proofs fail verification
35/// - Server response is malformed or too large
36///
37/// # Example
38///
39/// ```rust
40/// use zoe_client::challenge::perform_client_challenge_handshake;
41///
42/// let verified_count = perform_client_challenge_handshake(
43///     send_stream,
44///     recv_stream,
45///     &[&personal_key, &work_key]
46/// ).await?;
47///
48/// debug!("Successfully verified {} out of {} keys", verified_count, 2);
49/// ```
50pub async fn perform_client_challenge_handshake(
51    mut send: SendStream,
52    mut recv: RecvStream,
53    server_public_key: &VerifyingKey,
54    key_pairs: &[&KeyPair],
55) -> Result<(usize, Vec<ZoeChallengeWarning>)> {
56    debug!("🔐 Starting client-side multi-challenge handshake");
57
58    if key_pairs.is_empty() {
59        return Err(anyhow::anyhow!("No keys provided for handshake"));
60    }
61
62    debug!("Proving possession of {} keys", key_pairs.len());
63
64    let verified_count = key_pairs.len();
65    let mut warnings = Vec::new();
66
67    loop {
68        // Step 3: Receive result from server
69        let result = receive_result(&mut recv).await?;
70
71        match result {
72            ZoeChallengeResult::Accepted => {
73                debug!("✅ All challenges completed successfully");
74                break;
75            }
76            ZoeChallengeResult::Warning(warning) => {
77                warn!("🔔 Warning received: {warning:?}");
78                warnings.push(warning);
79                continue; // we need to read for the next result
80            }
81            ZoeChallengeResult::Next => {
82                debug!("➡️ Challenge accepted, waiting for next challenge");
83                // Continue to next iteration to receive next challenge
84            }
85            ZoeChallengeResult::Rejected(rejection) => {
86                return Err(anyhow::anyhow!("Challenge rejected: {rejection:?}"));
87            }
88            ZoeChallengeResult::Error(error) => {
89                return Err(anyhow::anyhow!("Server error: {error}"));
90            }
91            ZoeChallengeResult::Unknown { discriminant, .. } => {
92                return Err(anyhow::anyhow!("Unsupported result type: {discriminant}"));
93            }
94        }
95        // Step 1: Receive challenge from server
96        debug!("📥 Waiting to receive challenge from server...");
97        let challenge = receive_challenge(&mut recv).await?;
98        debug!("✅ Received challenge from server");
99
100        // Step 2: Handle different challenge types
101        match challenge {
102            ZoeChallenge::Key(key_challenge) => {
103                debug!("📝 Received key challenge");
104
105                // check the signature
106                let nonce = key_challenge.nonce;
107                let signature = &key_challenge.signature;
108                debug!("🔍 Verifying server signature on challenge nonce...");
109                if server_public_key.verify(&nonce, signature).is_err() {
110                    return Err(anyhow::anyhow!(
111                        "Invalid signature in challenge. Person-in-the-middle attack?"
112                    ));
113                }
114                debug!("✅ Server signature verified");
115
116                // Create proofs for all keys
117                debug!("🔧 Creating key proofs for {} keys...", key_pairs.len());
118                let response = create_key_proofs(&key_challenge, key_pairs)?;
119                debug!("✅ Created {} key proofs", response.key_proofs.len());
120
121                // Send response directly (no wrapper enum)
122                debug!("📤 Sending key response to server...");
123                send_key_response(&mut send, &response).await?;
124                debug!("✅ Key response sent");
125            }
126            ZoeChallenge::Unknown { discriminant, .. } => {
127                return Err(anyhow::anyhow!(
128                    "Unsupported challenge type: {discriminant}"
129                ));
130            }
131        }
132    }
133
134    debug!(
135        "✅ Client-side multi-challenge handshake completed. {} keys verified",
136        verified_count
137    );
138
139    Ok((verified_count, warnings))
140}
141
142/// Receives a challenge from the server
143///
144/// Reads the challenge with length prefix and deserializes it.
145///
146/// # Arguments
147///
148/// * `recv` - Stream to receive the challenge from
149///
150/// # Returns
151///
152/// The parsed challenge from the server
153async fn receive_challenge(recv: &mut RecvStream) -> Result<ZoeChallenge> {
154    // Read length prefix
155    let challenge_len = recv.read_u32().await? as usize;
156
157    if challenge_len > MAX_PACKAGE_SIZE {
158        return Err(anyhow::anyhow!(
159            "Challenge too large: {} bytes (max: {})",
160            challenge_len,
161            MAX_PACKAGE_SIZE
162        ));
163    }
164
165    debug!("Receiving challenge ({} bytes)", challenge_len);
166
167    // Read challenge data
168    let mut challenge_buf = vec![0u8; challenge_len];
169    recv.read_exact(&mut challenge_buf).await?;
170
171    // Parse challenge
172    let challenge: ZoeChallenge = postcard::from_bytes(&challenge_buf)?;
173
174    debug!("Received challenge from server");
175    Ok(challenge)
176}
177
178/// Creates key proofs for all provided keys
179///
180/// For each key, creates a signature over (nonce || server_public_key) and
181/// packages it with the corresponding public key.
182///
183/// # Arguments
184///
185/// * `challenge` - Challenge received from server
186/// * `key_pairs` - Keys to create proofs for
187///
188/// # Returns
189///
190/// A response containing all key proofs
191pub fn create_key_proofs(challenge: &KeyChallenge, key_pairs: &[&KeyPair]) -> Result<KeyResponse> {
192    let challenge_data = challenge;
193
194    // Prepare signature data: just the nonce (as updated in the protocol)
195    let signature_data = challenge_data.nonce.to_vec();
196
197    debug!("Creating proofs for {} keys", key_pairs.len());
198
199    let mut key_proofs = Vec::new();
200
201    for (index, keypair) in key_pairs.iter().enumerate() {
202        // Create signature over challenge data
203        let signature = keypair.sign(&signature_data);
204        let verifying_key = keypair.public_key();
205
206        // Create key proof
207        let key_proof = KeyProof {
208            public_key: verifying_key,
209            signature,
210        };
211
212        key_proofs.push(key_proof);
213        debug!("Created proof for key {}", index);
214    }
215
216    let response = KeyResponse { key_proofs };
217    Ok(response)
218}
219
220/// Sends the key challenge response to the server
221///
222/// Serializes the response using postcard and sends it with a length prefix.
223///
224/// # Arguments
225///
226/// * `send` - Stream to send the response on
227/// * `response` - Key response to send
228async fn send_key_response(send: &mut SendStream, response: &KeyResponse) -> Result<()> {
229    let response_bytes = postcard::to_stdvec(response)?;
230
231    debug!("Sending response ({} bytes)", response_bytes.len());
232
233    // Send length prefix (4 bytes, big endian)
234    send.write_u32(response_bytes.len() as u32).await?;
235
236    // Send response data
237    send.write_all(&response_bytes).await?;
238
239    Ok(())
240}
241
242/// Receives the verification result from the server
243///
244/// Reads the result with length prefix and deserializes it.
245///
246/// # Arguments
247///
248/// * `recv` - Stream to receive the result from
249///
250/// # Returns
251///
252/// The parsed verification result from the server
253async fn receive_result(recv: &mut RecvStream) -> Result<ZoeChallengeResult> {
254    // Read length prefix
255    let result_len = recv.read_u32().await? as usize;
256
257    if result_len > MAX_PACKAGE_SIZE {
258        return Err(anyhow::anyhow!(
259            "Result too large: {} bytes (max: {})",
260            result_len,
261            MAX_PACKAGE_SIZE
262        ));
263    }
264
265    debug!("Receiving result ({} bytes)", result_len);
266
267    // Read result data
268    let mut result_buf = vec![0u8; result_len];
269    recv.read_exact(&mut result_buf).await?;
270
271    // Parse result
272    let result: ZoeChallengeResult = postcard::from_bytes(&result_buf)?;
273
274    debug!("Received verification result from server");
275    Ok(result)
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{KeyPair, ZoeChallengeRejection};
    use anyhow::Result;

    /// Maps a challenge result to the verified count (`Accepted` / `Next`) or
    /// to an error describing why the handshake would not proceed.
    fn process_result(result: &ZoeChallengeResult, expected_count: usize) -> Result<usize> {
        match result {
            ZoeChallengeResult::Accepted => Ok(expected_count),
            ZoeChallengeResult::Next => Ok(expected_count),
            ZoeChallengeResult::Warning(warning) => {
                Err(anyhow::anyhow!("Warning received: {warning:?}"))
            }
            ZoeChallengeResult::Rejected(rejection) => {
                Err(anyhow::anyhow!("Challenge rejected: {rejection:?}"))
            }
            ZoeChallengeResult::Error(error) => Err(anyhow::anyhow!("Challenge error: {error}")),
            ZoeChallengeResult::Unknown { discriminant, .. } => {
                Err(anyhow::anyhow!("Unknown challenge result: {discriminant}"))
            }
        }
    }

    #[test]
    fn test_create_key_proofs() {
        // Generate test keys
        let keypair1 = KeyPair::generate(&mut rand::thread_rng());
        let keypair2 = KeyPair::generate(&mut rand::thread_rng());

        // Create test challenge
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let nonce = [42u8; 32];
        let server_signature = server_keypair.sign(&nonce);

        let challenge_data = KeyChallenge {
            nonce,
            signature: server_signature,
            expires_at: 1234567890,
        };

        // Create proofs
        let keys = [&keypair1, &keypair2];
        let response = create_key_proofs(&challenge_data, &keys).unwrap();

        assert_eq!(response.key_proofs.len(), keys.len());

        // Proofs come back in key order, carry the signing key's public key,
        // and sign exactly the challenge nonce.
        for (proof, key) in response.key_proofs.iter().zip(keys) {
            assert_eq!(proof.public_key.encode(), key.public_key().encode());
            assert!(proof.public_key.verify(&nonce, &proof.signature).is_ok());
            assert!(proof
                .public_key
                .verify(&[0u8; 32], &proof.signature)
                .is_err());
        }
    }

    #[test]
    fn test_process_result_accepted() {
        let result = ZoeChallengeResult::Accepted;
        let verified_count = process_result(&result, 3).unwrap();
        assert_eq!(verified_count, 3);
    }

    #[test]
    fn test_process_result_next() {
        let result = ZoeChallengeResult::Next;
        let verified_count = process_result(&result, 3).unwrap();
        assert_eq!(verified_count, 3);
    }

    #[test]
    fn test_process_result_rejected() {
        let result = ZoeChallengeResult::Rejected(ZoeChallengeRejection::ChallengeFailed);
        let result = process_result(&result, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_key_proofs_client() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let client_keypair1 = KeyPair::generate(&mut rand::thread_rng());
        let client_keypair2 = KeyPair::generate(&mut rand::thread_rng());

        // Create a challenge
        let nonce = [42u8; 32];
        let signature = server_keypair.sign(&nonce);
        let challenge = KeyChallenge {
            nonce,
            signature,
            expires_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 300,
        };

        let client_keys = [&client_keypair1, &client_keypair2];
        let response = create_key_proofs(&challenge, &client_keys).unwrap();

        // Should have proofs for both keys
        assert_eq!(response.key_proofs.len(), 2);

        // Each proof should be valid (we can verify this by checking the signature)
        for (i, proof) in response.key_proofs.iter().enumerate() {
            let expected_key = &client_keys[i];
            assert_eq!(
                proof.public_key.encode(),
                expected_key.public_key().encode()
            );

            // Verify the signature
            assert!(proof
                .public_key
                .verify(&challenge.nonce, &proof.signature)
                .is_ok());
        }
    }

    #[test]
    fn test_send_key_response_serialization() {
        let client_keypair = KeyPair::generate(&mut rand::thread_rng());
        let signature = client_keypair.sign(b"test data");

        let response = KeyResponse {
            key_proofs: vec![KeyProof {
                public_key: client_keypair.public_key(),
                signature,
            }],
        };

        // Test that we can serialize the response (this is what send_key_response does internally)
        let serialized = postcard::to_stdvec(&response).unwrap();
        assert!(!serialized.is_empty());

        // Test that we can deserialize it back
        let deserialized: KeyResponse = postcard::from_bytes(&serialized).unwrap();
        assert_eq!(response.key_proofs.len(), deserialized.key_proofs.len());
    }

    #[test]
    fn test_challenge_signature_verification() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());
        let server_public_key = server_keypair.public_key();

        // Create a valid challenge
        let nonce = [42u8; 32];
        let signature = server_keypair.sign(&nonce);
        let challenge = KeyChallenge {
            nonce,
            signature,
            expires_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 300,
        };

        // Signature should verify
        assert!(server_public_key
            .verify(&challenge.nonce, &challenge.signature)
            .is_ok());

        // Wrong signature should fail
        let wrong_signature = server_keypair.sign(b"wrong data");
        let bad_challenge = KeyChallenge {
            nonce,
            signature: wrong_signature,
            expires_at: challenge.expires_at,
        };

        assert!(server_public_key
            .verify(&bad_challenge.nonce, &bad_challenge.signature)
            .is_err());
    }

    #[test]
    fn test_empty_key_list() {
        let server_keypair = KeyPair::generate_ed25519(&mut rand::thread_rng());

        let nonce = [42u8; 32];
        let signature = server_keypair.sign(&nonce);
        let challenge = KeyChallenge {
            nonce,
            signature,
            expires_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 300,
        };

        let empty_keys: [&KeyPair; 0] = [];
        let response = create_key_proofs(&challenge, &empty_keys).unwrap();

        // Should have no proofs
        assert_eq!(response.key_proofs.len(), 0);
    }
}