// zoe_wire_protocol/connection/client.rs

1use crate::crypto::CryptoError;
2use crate::version::ClientProtocolConfig;
3use crate::VerifyingKey;
4use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint};
5use rustls::ClientConfig as RustlsClientConfig;
6use std::sync::Arc;
7use std::time::Duration;
8
9mod ed25519 {
10    use rustls::pki_types::CertificateDer;
11    use x509_parser::oid_registry::asn1_rs::oid;
12    use x509_parser::prelude::*;
13
14    use crate::crypto::CryptoError;
15
16    /// Extract Ed25519 public key from a certificate
17    /// This function extracts the Ed25519 public key directly from the certificate's
18    /// SubjectPublicKeyInfo field when the certificate uses the Ed25519 algorithm identifier.
19    pub fn extract_ed25519_public_key_from_cert(
20        cert_der: &CertificateDer,
21    ) -> std::result::Result<ed25519_dalek::VerifyingKey, CryptoError> {
22        // Parse the certificate
23        let (_, cert) = X509Certificate::from_der(cert_der.as_ref())
24            .map_err(|e| CryptoError::ParseError(format!("Failed to parse certificate: {e:?}")))?;
25
26        // Get the subject public key info
27        let spki = cert.public_key();
28        let algorithm_oid = &spki.algorithm.algorithm;
29
30        // Ed25519 algorithm identifier: 1.3.101.112
31        let ed25519_oid = oid!(1.3.101 .112);
32
33        tracing::debug!("🔍 Certificate algorithm OID: {}", algorithm_oid);
34
35        if algorithm_oid != &ed25519_oid {
36            return Err(CryptoError::ParseError(format!(
37                "Certificate is not using Ed25519 algorithm. Found OID: {algorithm_oid}"
38            )));
39        }
40
41        tracing::debug!("🔍 Found Ed25519 certificate");
42        let public_key_bits = &spki.subject_public_key;
43
44        // The BIT STRING contains the raw 32-byte Ed25519 public key
45        // Note: BIT STRING may have unused bits indicator as first byte
46        let key_bytes = if public_key_bits.data.len() == 33 && public_key_bits.data[0] == 0 {
47            // Skip the unused bits indicator (should be 0 for Ed25519)
48            &public_key_bits.data[1..]
49        } else if public_key_bits.data.len() == 32 {
50            // Direct 32-byte key
51            &public_key_bits.data
52        } else {
53            return Err(CryptoError::ParseError(format!(
54                "Invalid Ed25519 public key length: {} bytes",
55                public_key_bits.data.len()
56            )));
57        };
58
59        tracing::debug!("🔍 Extracted Ed25519 public key: {} bytes", key_bytes.len());
60
61        // Validate key length
62        if key_bytes.len() != 32 {
63            return Err(CryptoError::ParseError(format!(
64                "Invalid Ed25519 public key length: {} bytes (expected 32)",
65                key_bytes.len()
66            )));
67        }
68
69        // Convert to Ed25519 VerifyingKey
70        let mut key_array = [0u8; 32];
71        key_array.copy_from_slice(key_bytes);
72
73        let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(&key_array)
74            .map_err(|e| CryptoError::ParseError(format!("Invalid Ed25519 public key: {e}")))?;
75
76        tracing::debug!("✅ Successfully extracted Ed25519 public key from certificate");
77        Ok(verifying_key)
78    }
79
80    #[derive(Debug)]
81    pub(super) struct AcceptSpecificEd25519ServerCertVerifier {
82        expected_server_key_ed25519: ed25519_dalek::VerifyingKey,
83    }
84
85    impl AcceptSpecificEd25519ServerCertVerifier {
86        pub(super) fn new(expected_server_key_ed25519: ed25519_dalek::VerifyingKey) -> Self {
87            Self {
88                expected_server_key_ed25519,
89            }
90        }
91    }
92
93    impl rustls::client::danger::ServerCertVerifier for AcceptSpecificEd25519ServerCertVerifier {
94        fn verify_server_cert(
95            &self,
96            end_entity: &CertificateDer,
97            _intermediates: &[CertificateDer],
98            _server_name: &rustls::pki_types::ServerName,
99            _ocsp_response: &[u8],
100            _now: rustls::pki_types::UnixTime,
101        ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error>
102        {
103            // Extract Ed25519 key from certificate
104            match extract_ed25519_public_key_from_cert(end_entity) {
105                Ok(server_ed25519_key) => {
106                    let extracted_key_hex = hex::encode(server_ed25519_key.to_bytes());
107                    let expected_key_hex = hex::encode(self.expected_server_key_ed25519.to_bytes());
108
109                    tracing::debug!("🔍 Extracted server key: {}", extracted_key_hex);
110                    tracing::debug!("🔍 Expected server key:  {}", expected_key_hex);
111
112                    // Verify it matches our expected key
113                    if server_ed25519_key.to_bytes() == self.expected_server_key_ed25519.to_bytes()
114                    {
115                        tracing::debug!("✅ Server Ed25519 identity verified via certificate");
116                        Ok(rustls::client::danger::ServerCertVerified::assertion())
117                    } else {
118                        tracing::error!("❌ Server Ed25519 key mismatch");
119                        tracing::error!("   Extracted: {}", extracted_key_hex);
120                        tracing::error!("   Expected:  {}", expected_key_hex);
121                        Err(rustls::Error::InvalidCertificate(
122                            rustls::CertificateError::ApplicationVerificationFailure,
123                        ))
124                    }
125                }
126                Err(e) => {
127                    tracing::error!("❌ Failed to extract Ed25519 key from certificate: {}", e);
128                    Err(rustls::Error::InvalidCertificate(
129                        rustls::CertificateError::ApplicationVerificationFailure,
130                    ))
131                }
132            }
133        }
134
135        fn verify_tls12_signature(
136            &self,
137            _message: &[u8],
138            _cert: &CertificateDer,
139            _dss: &rustls::DigitallySignedStruct,
140        ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
141        {
142            // We only support TLS 1.3
143            Err(rustls::Error::UnsupportedNameType)
144        }
145
146        fn verify_tls13_signature(
147            &self,
148            _message: &[u8],
149            _cert: &CertificateDer,
150            _dss: &rustls::DigitallySignedStruct,
151        ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
152        {
153            // Accept any TLS 1.3 signature - we verify identity via the embedded Ed25519 key
154            Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
155        }
156
157        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
158            vec![rustls::SignatureScheme::ED25519]
159        }
160    }
161}
162
163pub fn create_client_endpoint(
164    server_public_key: &VerifyingKey,
165) -> std::result::Result<Endpoint, CryptoError> {
166    create_client_endpoint_with_protocols(server_public_key, &ClientProtocolConfig::default())
167}
168
169pub fn create_client_endpoint_with_protocols(
170    server_public_key: &VerifyingKey,
171    protocol_versions: &ClientProtocolConfig,
172) -> std::result::Result<Endpoint, CryptoError> {
173    let cert_verifier = match server_public_key {
174        VerifyingKey::Ed25519(verifying_key) => {
175            ed25519::AcceptSpecificEd25519ServerCertVerifier::new(**verifying_key)
176        }
177        _ => {
178            return Err(CryptoError::TlsError(
179                "Server certificate type not supported for TLS".to_string(),
180            ));
181        }
182    };
183
184    // Create client config with certificate resolver using the appropriate crypto provider
185    let mut crypto = RustlsClientConfig::builder()
186        .dangerous()
187        .with_custom_certificate_verifier(Arc::new(cert_verifier))
188        .with_no_client_auth();
189
190    // Set ALPN protocols for version negotiation
191    let alpn_protocols = protocol_versions.alpn_protocols();
192    crypto.alpn_protocols = alpn_protocols;
193
194    // we need to set the keep alive interval to 25 seconds, below the 30s timeout default so we keep the connection alive
195    let mut transport_config = quinn::TransportConfig::default();
196    transport_config.keep_alive_interval(Some(Duration::from_secs(25)));
197
198    let mut client_config = ClientConfig::new(Arc::new(
199        QuicClientConfig::try_from(crypto)
200            .map_err(|e| CryptoError::TlsError(format!("Failed to create client config: {e}")))?,
201    ));
202    client_config.transport_config(Arc::new(transport_config));
203
204    let mut endpoint = Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into())
205        .map_err(|e| CryptoError::TlsError(format!("Failed to create endpoint: {e}")))?;
206    endpoint.set_default_client_config(client_config);
207
208    Ok(endpoint)
209}