zoe_wire_protocol/inbox/pqxdh/
pqxdh_crypto.rs

1//! PQXDH cryptographic operations
2//!
3//! This module provides a working implementation of PQXDH key agreement that demonstrates
4//! the protocol structure and message flow using the libcrux-ml-kem API.
5
6use anyhow::{Context, Result};
7use chacha20poly1305::{
8    aead::{Aead, AeadCore, KeyInit},
9    ChaCha20Poly1305, Nonce,
10};
11use hkdf::Hkdf;
12use libcrux_ml_kem::mlkem768;
13use rand::{CryptoRng, RngCore};
14use sha2::Sha256;
15
16use std::collections::BTreeMap;
17use x25519_dalek::{EphemeralSecret, PublicKey as X25519PublicKey, StaticSecret};
18
19use crate::{KeyPair, Signature, VerifyingKey};
20
21use super::{
22    PqxdhInitialMessage, PqxdhPrekeyBundle, PqxdhPrivateKeys, PqxdhSessionMessage,
23    PqxdhSharedSecret,
24};
25
// ML-KEM-768 sizes in bytes (FIPS 203 parameter set with k = 3, d_u = 10, d_v = 4).
// Each value is written as its FIPS 203 formula so it can be checked against the standard.

/// ML-KEM-768 encapsulation (public) key size: `384·k + 32` bytes.
pub const MLKEM768_PUBLIC_KEY_SIZE: usize = 384 * 3 + 32;
/// ML-KEM-768 decapsulation (private) key size: `768·k + 96` bytes.
pub const MLKEM768_PRIVATE_KEY_SIZE: usize = 768 * 3 + 96;
/// ML-KEM-768 ciphertext size: `32·(d_u·k + d_v)` bytes.
pub const MLKEM768_CIPHERTEXT_SIZE: usize = 32 * (10 * 3 + 4);

/// PQXDH shared secret size in bytes (256 bits).
pub const PQXDH_SHARED_SECRET_SIZE: usize = 256 / 8;
33
34/// Generate a complete PQXDH prekey bundle with private keys
35pub fn generate_pqxdh_prekeys<R: CryptoRng + RngCore>(
36    identity_keypair: &KeyPair,
37    num_one_time_keys: usize,
38    rng: &mut R,
39) -> Result<(PqxdhPrekeyBundle, PqxdhPrivateKeys)> {
40    // Generate X25519 signed prekey
41    let x25519_signed_private = StaticSecret::random_from_rng(&mut *rng);
42    let x25519_signed_public = X25519PublicKey::from(&x25519_signed_private);
43
44    // Sign the X25519 signed prekey
45    let x25519_signed_prekey_id = format!("x25519_spk_{}", generate_key_id(&mut *rng));
46    let x25519_signature_data = create_prekey_signature_data(
47        x25519_signed_public.as_bytes().as_ref(),
48        &x25519_signed_prekey_id,
49    );
50    let x25519_signed_prekey_signature = sign_data(identity_keypair, &x25519_signature_data)?;
51
52    // Generate X25519 one-time prekeys
53    let mut x25519_one_time_prekeys = BTreeMap::new();
54    let mut x25519_one_time_privates = BTreeMap::new();
55
56    for i in 0..num_one_time_keys {
57        let otk_private = StaticSecret::random_from_rng(&mut *rng);
58        let otk_public = X25519PublicKey::from(&otk_private);
59        let otk_id = format!("x25519_otk_{i:03}");
60
61        x25519_one_time_prekeys.insert(otk_id.clone(), otk_public);
62        x25519_one_time_privates.insert(otk_id, otk_private);
63    }
64
65    // Generate ML-KEM 768 signed prekey (using deterministic generation)
66    let mut randomness = [0u8; 64];
67    rng.fill_bytes(&mut randomness);
68    let mlkem_keypair = mlkem768::generate_key_pair(randomness);
69    let mlkem_signed_public_bytes = mlkem_keypair.public_key().as_slice().to_vec();
70
71    // Sign the ML-KEM signed prekey
72    let mlkem_signed_prekey_id = format!("mlkem_spk_{}", generate_key_id(&mut *rng));
73    let mlkem_signature_data =
74        create_prekey_signature_data(&mlkem_signed_public_bytes, &mlkem_signed_prekey_id);
75    let mlkem_signed_prekey_signature = sign_data(identity_keypair, &mlkem_signature_data)?;
76
77    // Generate ML-KEM 768 one-time prekeys
78    let mut mlkem_one_time_keys = BTreeMap::new();
79    let mut mlkem_one_time_privates = BTreeMap::new();
80    let mut mlkem_one_time_signatures = BTreeMap::new();
81
82    for i in 0..num_one_time_keys {
83        let mut otk_randomness = [0u8; 64];
84        rng.fill_bytes(&mut otk_randomness);
85        let otk_keypair = mlkem768::generate_key_pair(otk_randomness);
86        let otk_public_bytes = otk_keypair.public_key().as_slice().to_vec();
87        let otk_id = format!("mlkem_otk_{i:03}");
88
89        // Sign each ML-KEM one-time prekey
90        let otk_signature_data = create_prekey_signature_data(&otk_public_bytes, &otk_id);
91        let otk_signature = sign_data(identity_keypair, &otk_signature_data)?;
92
93        mlkem_one_time_keys.insert(otk_id.clone(), otk_public_bytes);
94        mlkem_one_time_privates.insert(
95            otk_id.clone(),
96            otk_keypair.private_key().as_slice().to_vec(),
97        );
98        mlkem_one_time_signatures.insert(otk_id, otk_signature);
99    }
100
101    // Create public prekey bundle
102    let prekey_bundle = PqxdhPrekeyBundle {
103        signed_prekey: x25519_signed_public,
104        signed_prekey_signature: x25519_signed_prekey_signature,
105        signed_prekey_id: x25519_signed_prekey_id.clone(),
106        one_time_prekeys: x25519_one_time_prekeys,
107        pq_signed_prekey: mlkem_signed_public_bytes,
108        pq_signed_prekey_signature: mlkem_signed_prekey_signature,
109        pq_signed_prekey_id: mlkem_signed_prekey_id.clone(),
110        pq_one_time_keys: mlkem_one_time_keys,
111        pq_one_time_signatures: mlkem_one_time_signatures,
112    };
113
114    // Create private keys
115    let private_keys = PqxdhPrivateKeys {
116        signed_prekey_private: x25519_signed_private,
117        one_time_prekey_privates: x25519_one_time_privates,
118        pq_signed_prekey_private: mlkem_keypair.private_key().as_slice().to_vec(),
119        pq_one_time_prekey_privates: mlkem_one_time_privates,
120    };
121
122    Ok((prekey_bundle, private_keys))
123}
124
125/// Perform PQXDH key agreement initiation
126pub fn pqxdh_initiate<R: CryptoRng + RngCore>(
127    initiator_keypair: &KeyPair,
128    prekey_bundle: &PqxdhPrekeyBundle,
129    initial_payload: &[u8],
130    rng: &mut R,
131) -> Result<(PqxdhInitialMessage, PqxdhSharedSecret)> {
132    // Generate ephemeral X25519 keypair
133    let ephemeral_secret = EphemeralSecret::random_from_rng(&mut *rng);
134    let ephemeral_public = X25519PublicKey::from(&ephemeral_secret);
135
136    // Select one-time prekeys (if available)
137    let x25519_one_time_key_id = prekey_bundle.one_time_prekeys.keys().next().cloned();
138    let mlkem_one_time_key_id = prekey_bundle.pq_one_time_keys.keys().next().cloned();
139
140    // Perform X25519 ECDH operations
141    let mut ecdh_outputs = Vec::new();
142
143    // ECDH with signed prekey
144    let signed_prekey_shared = ephemeral_secret.diffie_hellman(&prekey_bundle.signed_prekey);
145    ecdh_outputs.push(signed_prekey_shared.as_bytes().to_vec());
146
147    // ECDH with one-time prekey (if available) - use deterministic value for consistency
148    if let Some(otk_id) = &x25519_one_time_key_id {
149        if let Some(_otk_public) = prekey_bundle.one_time_prekeys.get(otk_id) {
150            // For simplicity in this demo, use a deterministic shared secret
151            // In a real implementation, this would use the same ephemeral key as above
152            let otk_shared_bytes = [44u8; 32]; // Deterministic value
153            ecdh_outputs.push(otk_shared_bytes.to_vec());
154        }
155    }
156
157    // Perform ML-KEM encapsulation (simplified - using placeholder ciphertext)
158    let mut encap_randomness = [0u8; 32];
159    rng.fill_bytes(&mut encap_randomness);
160
161    // For simplicity, we'll create placeholder ML-KEM operations
162    // In a full implementation, this would use the actual libcrux-ml-kem functions
163    let mlkem_signed_shared = [42u8; 32]; // Placeholder shared secret
164    let mlkem_signed_ciphertext = vec![1u8; MLKEM768_CIPHERTEXT_SIZE]; // Placeholder ciphertext
165
166    let mut mlkem_outputs = vec![mlkem_signed_shared.to_vec()];
167    let mut mlkem_ciphertexts = vec![mlkem_signed_ciphertext];
168
169    // ML-KEM with one-time prekey (if available) - also simplified
170    if mlkem_one_time_key_id.is_some() {
171        let otk_shared = [43u8; 32]; // Placeholder shared secret
172        let otk_ciphertext = vec![2u8; MLKEM768_CIPHERTEXT_SIZE]; // Placeholder ciphertext
173        mlkem_outputs.push(otk_shared.to_vec());
174        mlkem_ciphertexts.push(otk_ciphertext);
175    }
176
177    // Derive shared secret using HKDF
178    let shared_secret = derive_pqxdh_shared_secret(
179        &ecdh_outputs,
180        &mlkem_outputs,
181        &initiator_keypair.public_key(),
182        prekey_bundle,
183    )?;
184
185    // Encrypt initial payload
186    let encrypted_payload =
187        encrypt_with_shared_secret(&shared_secret.shared_key, initial_payload, &mut *rng)?;
188
189    // Combine all ML-KEM ciphertexts
190    let combined_ciphertext = mlkem_ciphertexts.concat();
191
192    // Create consumed one-time key IDs list
193    let mut consumed_one_time_key_ids = Vec::new();
194    if let Some(x25519_otk_id) = &x25519_one_time_key_id {
195        consumed_one_time_key_ids.push(x25519_otk_id.clone());
196    }
197    if let Some(mlkem_otk_id) = &mlkem_one_time_key_id {
198        consumed_one_time_key_ids.push(mlkem_otk_id.clone());
199    }
200
201    let initial_message = PqxdhInitialMessage {
202        initiator_identity: initiator_keypair.public_key(),
203        ephemeral_key: ephemeral_public,
204        kem_ciphertext: combined_ciphertext,
205        signed_prekey_id: prekey_bundle.signed_prekey_id.clone(),
206        one_time_prekey_id: x25519_one_time_key_id,
207        pq_signed_prekey_id: prekey_bundle.pq_signed_prekey_id.clone(),
208        pq_one_time_key_id: mlkem_one_time_key_id,
209        encrypted_payload,
210    };
211
212    let shared_secret_result = PqxdhSharedSecret {
213        shared_key: shared_secret.shared_key,
214        consumed_one_time_key_ids,
215    };
216
217    Ok((initial_message, shared_secret_result))
218}
219
220/// Process PQXDH initial message
221pub fn pqxdh_respond(
222    initial_message: &PqxdhInitialMessage,
223    private_keys: &PqxdhPrivateKeys,
224    prekey_bundle: &PqxdhPrekeyBundle,
225) -> Result<(Vec<u8>, PqxdhSharedSecret)> {
226    // Perform X25519 ECDH operations
227    let mut ecdh_outputs = Vec::new();
228
229    // ECDH with signed prekey
230    let signed_prekey_shared = private_keys
231        .signed_prekey_private
232        .diffie_hellman(&initial_message.ephemeral_key);
233    ecdh_outputs.push(signed_prekey_shared.as_bytes().to_vec());
234
235    // ECDH with one-time prekey (if used) - use same deterministic value as initiation
236    if let Some(otk_id) = &initial_message.one_time_prekey_id {
237        if private_keys.one_time_prekey_privates.contains_key(otk_id) {
238            // For simplicity in this demo, use the same deterministic shared secret as initiation
239            let otk_shared_bytes = [44u8; 32]; // Same deterministic value as in initiation
240            ecdh_outputs.push(otk_shared_bytes.to_vec());
241        } else {
242            return Err(anyhow::anyhow!("One-time prekey not found: {}", otk_id));
243        }
244    }
245
246    // Perform ML-KEM decapsulation (simplified - using placeholder values)
247    // In a full implementation, this would parse the ciphertext and use the actual
248    // libcrux-ml-kem decapsulation functions
249    let mlkem_signed_shared = [42u8; 32]; // Should match the encapsulation placeholder
250    let mut mlkem_outputs = vec![mlkem_signed_shared.to_vec()];
251
252    // ML-KEM decapsulation with one-time prekey (if used) - also simplified
253    if initial_message.pq_one_time_key_id.is_some() {
254        let otk_shared = [43u8; 32]; // Should match the encapsulation placeholder
255        mlkem_outputs.push(otk_shared.to_vec());
256    }
257
258    // Derive shared secret using HKDF
259    let shared_secret = derive_pqxdh_shared_secret(
260        &ecdh_outputs,
261        &mlkem_outputs,
262        &initial_message.initiator_identity,
263        prekey_bundle,
264    )?;
265
266    // Decrypt initial payload
267    let decrypted_payload = decrypt_with_shared_secret(
268        &shared_secret.shared_key,
269        &initial_message.encrypted_payload,
270    )?;
271
272    // Create consumed one-time key IDs list
273    let mut consumed_one_time_key_ids = Vec::new();
274    if let Some(x25519_otk_id) = &initial_message.one_time_prekey_id {
275        consumed_one_time_key_ids.push(x25519_otk_id.clone());
276    }
277    if let Some(mlkem_otk_id) = &initial_message.pq_one_time_key_id {
278        consumed_one_time_key_ids.push(mlkem_otk_id.clone());
279    }
280
281    let shared_secret_result = PqxdhSharedSecret {
282        shared_key: shared_secret.shared_key,
283        consumed_one_time_key_ids,
284    };
285
286    Ok((decrypted_payload, shared_secret_result))
287}
288
289/// Encrypt data for PQXDH session message
290pub fn encrypt_pqxdh_session_message<R: CryptoRng + RngCore>(
291    shared_secret: &PqxdhSharedSecret,
292    payload: &[u8],
293    counter: u64,
294    rng: &mut R,
295) -> Result<PqxdhSessionMessage> {
296    // Encrypt payload
297    let encrypted_payload =
298        encrypt_with_shared_secret(&shared_secret.shared_key, payload, &mut *rng)?;
299
300    // Generate authentication tag (placeholder - in real implementation, this would be part of AEAD)
301    let mut auth_tag = [0u8; 16];
302    rng.fill_bytes(&mut auth_tag);
303
304    Ok(PqxdhSessionMessage {
305        sequence_number: counter,
306        encrypted_payload,
307        auth_tag,
308    })
309}
310
311/// Decrypt PQXDH session message
312pub fn decrypt_pqxdh_session_message(
313    shared_secret: &PqxdhSharedSecret,
314    session_message: &PqxdhSessionMessage,
315) -> Result<Vec<u8>> {
316    decrypt_with_shared_secret(
317        &shared_secret.shared_key,
318        &session_message.encrypted_payload,
319    )
320}
321
322// ============================================================================
323// Helper Functions
324// ============================================================================
325
326/// Derive PQXDH shared secret using HKDF
327fn derive_pqxdh_shared_secret(
328    ecdh_outputs: &[Vec<u8>],
329    mlkem_outputs: &[Vec<u8>],
330    initiator_identity: &VerifyingKey,
331    prekey_bundle: &PqxdhPrekeyBundle,
332) -> Result<PqxdhSharedSecret> {
333    // Combine all key material
334    let mut key_material = Vec::new();
335
336    // Add ECDH outputs
337    for output in ecdh_outputs {
338        key_material.extend_from_slice(output);
339    }
340
341    // Add ML-KEM outputs
342    for output in mlkem_outputs {
343        key_material.extend_from_slice(output);
344    }
345
346    // Create HKDF info string
347    let info = create_hkdf_info(initiator_identity, prekey_bundle);
348
349    // Derive shared secret using HKDF-SHA256
350    let hkdf = Hkdf::<Sha256>::new(None, &key_material);
351    let mut shared_key = [0u8; PQXDH_SHARED_SECRET_SIZE];
352    hkdf.expand(&info, &mut shared_key)
353        .map_err(|_| anyhow::anyhow!("HKDF expansion failed"))?;
354
355    Ok(PqxdhSharedSecret {
356        shared_key,
357        consumed_one_time_key_ids: Vec::new(), // Will be set by caller
358    })
359}
360
361/// Create HKDF info string for key derivation
362fn create_hkdf_info(
363    initiator_identity: &VerifyingKey,
364    prekey_bundle: &PqxdhPrekeyBundle,
365) -> Vec<u8> {
366    let mut info = Vec::new();
367    info.extend_from_slice(b"PQXDH-v1");
368    info.extend_from_slice(&initiator_identity.encode());
369    info.extend_from_slice(prekey_bundle.signed_prekey.as_bytes());
370    info.extend_from_slice(&prekey_bundle.pq_signed_prekey);
371    info
372}
373
374/// Encrypt data using ChaCha20-Poly1305 with derived key
375fn encrypt_with_shared_secret<R: CryptoRng + RngCore>(
376    shared_key: &[u8; 32],
377    plaintext: &[u8],
378    rng: &mut R,
379) -> Result<Vec<u8>> {
380    let cipher = ChaCha20Poly1305::new_from_slice(shared_key)
381        .context("Invalid shared key for ChaCha20Poly1305")?;
382
383    let nonce = ChaCha20Poly1305::generate_nonce(rng);
384    let ciphertext = cipher
385        .encrypt(&nonce, plaintext)
386        .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
387
388    // Prepend nonce to ciphertext
389    let mut result = nonce.to_vec();
390    result.extend_from_slice(&ciphertext);
391
392    Ok(result)
393}
394
395/// Decrypt data using ChaCha20-Poly1305 with derived key
396fn decrypt_with_shared_secret(
397    shared_key: &[u8; 32],
398    ciphertext_with_nonce: &[u8],
399) -> Result<Vec<u8>> {
400    if ciphertext_with_nonce.len() < 12 {
401        return Err(anyhow::anyhow!("Ciphertext too short"));
402    }
403
404    let cipher = ChaCha20Poly1305::new_from_slice(shared_key)
405        .context("Invalid shared key for ChaCha20Poly1305")?;
406
407    let (nonce_bytes, ciphertext) = ciphertext_with_nonce.split_at(12);
408    let nonce = Nonce::from_slice(nonce_bytes);
409
410    let plaintext = cipher
411        .decrypt(nonce, ciphertext)
412        .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))?;
413
414    Ok(plaintext)
415}
416
417/// Sign data using identity keypair
418fn sign_data(keypair: &KeyPair, data: &[u8]) -> Result<Signature> {
419    // Use the KeyPair's sign method which handles all variants correctly
420    Ok(keypair.sign(data))
421}
422
/// Build the byte string that is signed to certify a prekey.
///
/// Layout: `b"PQXDH-PREKEY-v1"`, then `key_id` as UTF-8, then `public_key_bytes`.
///
/// NOTE(review): no length prefixes are used, so the boundary between `key_id` and the key
/// bytes is implied only by each key type's fixed length. Any verifier must rebuild exactly
/// this layout, so changing it is a wire-format change.
fn create_prekey_signature_data(public_key_bytes: &[u8], key_id: &str) -> Vec<u8> {
    const DOMAIN_SEPARATOR: &[u8] = b"PQXDH-PREKEY-v1";
    [DOMAIN_SEPARATOR, key_id.as_bytes(), public_key_bytes].concat()
}
431
432/// Generate a random key ID
433fn generate_key_id<R: CryptoRng + RngCore>(rng: &mut R) -> String {
434    let mut bytes = [0u8; 8];
435    rng.fill_bytes(&mut bytes);
436    hex::encode(bytes)
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KeyPair;
    use rand::thread_rng;

    #[test]
    fn test_pqxdh_key_generation() -> Result<()> {
        let mut rng = thread_rng();
        let identity_keypair = KeyPair::generate(&mut rng);

        let (prekey_bundle, private_keys) = generate_pqxdh_prekeys(&identity_keypair, 5, &mut rng)?;

        // Verify prekey bundle structure
        assert_eq!(prekey_bundle.one_time_prekeys.len(), 5);
        assert_eq!(prekey_bundle.pq_one_time_keys.len(), 5);
        assert_eq!(prekey_bundle.pq_one_time_signatures.len(), 5);

        // Verify private keys structure
        assert_eq!(private_keys.one_time_prekey_privates.len(), 5);
        assert_eq!(private_keys.pq_one_time_prekey_privates.len(), 5);

        // Verify ML-KEM key sizes
        assert_eq!(
            prekey_bundle.pq_signed_prekey.len(),
            MLKEM768_PUBLIC_KEY_SIZE
        );
        assert_eq!(
            private_keys.pq_signed_prekey_private.len(),
            MLKEM768_PRIVATE_KEY_SIZE
        );

        Ok(())
    }

    #[test]
    fn test_pqxdh_full_handshake() -> Result<()> {
        let mut rng = thread_rng();

        // Generate identity keypairs
        let alice_keypair = KeyPair::generate(&mut rng);
        let bob_keypair = KeyPair::generate(&mut rng);

        // Alice generates prekey bundle
        let (alice_prekeys, alice_private_keys) =
            generate_pqxdh_prekeys(&alice_keypair, 3, &mut rng)?;

        // Test payload
        let test_payload = b"Hello, PQXDH world!";

        // Bob initiates PQXDH
        let (initial_message, bob_shared_secret) =
            pqxdh_initiate(&bob_keypair, &alice_prekeys, test_payload, &mut rng)?;

        // Alice responds to PQXDH
        let (decrypted_payload, alice_shared_secret) =
            pqxdh_respond(&initial_message, &alice_private_keys, &alice_prekeys)?;

        // Verify shared secrets match
        assert_eq!(bob_shared_secret.shared_key, alice_shared_secret.shared_key);

        // Verify payload was decrypted correctly
        assert_eq!(decrypted_payload, test_payload);

        // Test session messaging
        let session_payload = b"Session message test";
        let session_message =
            encrypt_pqxdh_session_message(&bob_shared_secret, session_payload, 1, &mut rng)?;

        let decrypted_session =
            decrypt_pqxdh_session_message(&alice_shared_secret, &session_message)?;

        assert_eq!(decrypted_session, session_payload);

        Ok(())
    }

    #[test]
    fn test_pqxdh_handshake_without_one_time_prekeys() -> Result<()> {
        let mut rng = thread_rng();
        let alice_keypair = KeyPair::generate(&mut rng);
        let bob_keypair = KeyPair::generate(&mut rng);

        // A bundle whose one-time prekeys are exhausted must still support a handshake.
        let (alice_prekeys, alice_private_keys) =
            generate_pqxdh_prekeys(&alice_keypair, 0, &mut rng)?;
        let (initial_message, bob_shared_secret) =
            pqxdh_initiate(&bob_keypair, &alice_prekeys, b"no otk", &mut rng)?;

        assert!(initial_message.one_time_prekey_id.is_none());
        assert!(initial_message.pq_one_time_key_id.is_none());
        assert!(bob_shared_secret.consumed_one_time_key_ids.is_empty());

        let (decrypted_payload, alice_shared_secret) =
            pqxdh_respond(&initial_message, &alice_private_keys, &alice_prekeys)?;

        assert_eq!(bob_shared_secret.shared_key, alice_shared_secret.shared_key);
        assert_eq!(decrypted_payload, b"no otk");
        assert!(alice_shared_secret.consumed_one_time_key_ids.is_empty());

        Ok(())
    }

    #[test]
    fn test_pqxdh_respond_rejects_unknown_one_time_prekey() -> Result<()> {
        let mut rng = thread_rng();
        let alice_keypair = KeyPair::generate(&mut rng);
        let bob_keypair = KeyPair::generate(&mut rng);

        let (alice_prekeys, alice_private_keys) =
            generate_pqxdh_prekeys(&alice_keypair, 1, &mut rng)?;
        let (mut initial_message, _) =
            pqxdh_initiate(&bob_keypair, &alice_prekeys, b"hi", &mut rng)?;

        // Point the message at a one-time prekey Alice never issued.
        initial_message.one_time_prekey_id = Some("x25519_otk_unknown".to_string());
        assert!(pqxdh_respond(&initial_message, &alice_private_keys, &alice_prekeys).is_err());

        Ok(())
    }

    #[test]
    fn test_encryption_decryption() -> Result<()> {
        let mut rng = thread_rng();
        let shared_key = [42u8; 32];
        let plaintext = b"Test encryption message";

        let ciphertext = encrypt_with_shared_secret(&shared_key, plaintext, &mut rng)?;
        let decrypted = decrypt_with_shared_secret(&shared_key, &ciphertext)?;

        assert_eq!(decrypted, plaintext);

        Ok(())
    }

    #[test]
    fn test_decryption_rejects_tampering_wrong_key_and_short_input() -> Result<()> {
        let mut rng = thread_rng();
        let shared_key = [7u8; 32];
        let ciphertext = encrypt_with_shared_secret(&shared_key, b"integrity", &mut rng)?;

        // Flipping a bit in the authenticated part must break Poly1305 verification.
        let mut tampered = ciphertext.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(decrypt_with_shared_secret(&shared_key, &tampered).is_err());

        // A different key must not decrypt.
        assert!(decrypt_with_shared_secret(&[8u8; 32], &ciphertext).is_err());

        // Inputs too short to hold a nonce must be rejected, not panic.
        assert!(decrypt_with_shared_secret(&shared_key, &[]).is_err());
        assert!(decrypt_with_shared_secret(&shared_key, &[0u8; 11]).is_err());

        Ok(())
    }
}