zoe_wire_protocol/connection/
client.rs1use crate::crypto::CryptoError;
2use crate::version::ClientProtocolConfig;
3use crate::VerifyingKey;
4use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint};
5use rustls::ClientConfig as RustlsClientConfig;
6use std::sync::Arc;
7use std::time::Duration;
8
9mod ed25519 {
10 use rustls::pki_types::CertificateDer;
11 use x509_parser::oid_registry::asn1_rs::oid;
12 use x509_parser::prelude::*;
13
14 use crate::crypto::CryptoError;
15
16 pub fn extract_ed25519_public_key_from_cert(
20 cert_der: &CertificateDer,
21 ) -> std::result::Result<ed25519_dalek::VerifyingKey, CryptoError> {
22 let (_, cert) = X509Certificate::from_der(cert_der.as_ref())
24 .map_err(|e| CryptoError::ParseError(format!("Failed to parse certificate: {e:?}")))?;
25
26 let spki = cert.public_key();
28 let algorithm_oid = &spki.algorithm.algorithm;
29
30 let ed25519_oid = oid!(1.3.101 .112);
32
33 tracing::debug!("🔍 Certificate algorithm OID: {}", algorithm_oid);
34
35 if algorithm_oid != &ed25519_oid {
36 return Err(CryptoError::ParseError(format!(
37 "Certificate is not using Ed25519 algorithm. Found OID: {algorithm_oid}"
38 )));
39 }
40
41 tracing::debug!("🔍 Found Ed25519 certificate");
42 let public_key_bits = &spki.subject_public_key;
43
44 let key_bytes = if public_key_bits.data.len() == 33 && public_key_bits.data[0] == 0 {
47 &public_key_bits.data[1..]
49 } else if public_key_bits.data.len() == 32 {
50 &public_key_bits.data
52 } else {
53 return Err(CryptoError::ParseError(format!(
54 "Invalid Ed25519 public key length: {} bytes",
55 public_key_bits.data.len()
56 )));
57 };
58
59 tracing::debug!("🔍 Extracted Ed25519 public key: {} bytes", key_bytes.len());
60
61 if key_bytes.len() != 32 {
63 return Err(CryptoError::ParseError(format!(
64 "Invalid Ed25519 public key length: {} bytes (expected 32)",
65 key_bytes.len()
66 )));
67 }
68
69 let mut key_array = [0u8; 32];
71 key_array.copy_from_slice(key_bytes);
72
73 let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(&key_array)
74 .map_err(|e| CryptoError::ParseError(format!("Invalid Ed25519 public key: {e}")))?;
75
76 tracing::debug!("✅ Successfully extracted Ed25519 public key from certificate");
77 Ok(verifying_key)
78 }
79
80 #[derive(Debug)]
81 pub(super) struct AcceptSpecificEd25519ServerCertVerifier {
82 expected_server_key_ed25519: ed25519_dalek::VerifyingKey,
83 }
84
85 impl AcceptSpecificEd25519ServerCertVerifier {
86 pub(super) fn new(expected_server_key_ed25519: ed25519_dalek::VerifyingKey) -> Self {
87 Self {
88 expected_server_key_ed25519,
89 }
90 }
91 }
92
93 impl rustls::client::danger::ServerCertVerifier for AcceptSpecificEd25519ServerCertVerifier {
94 fn verify_server_cert(
95 &self,
96 end_entity: &CertificateDer,
97 _intermediates: &[CertificateDer],
98 _server_name: &rustls::pki_types::ServerName,
99 _ocsp_response: &[u8],
100 _now: rustls::pki_types::UnixTime,
101 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error>
102 {
103 match extract_ed25519_public_key_from_cert(end_entity) {
105 Ok(server_ed25519_key) => {
106 let extracted_key_hex = hex::encode(server_ed25519_key.to_bytes());
107 let expected_key_hex = hex::encode(self.expected_server_key_ed25519.to_bytes());
108
109 tracing::debug!("🔍 Extracted server key: {}", extracted_key_hex);
110 tracing::debug!("🔍 Expected server key: {}", expected_key_hex);
111
112 if server_ed25519_key.to_bytes() == self.expected_server_key_ed25519.to_bytes()
114 {
115 tracing::debug!("✅ Server Ed25519 identity verified via certificate");
116 Ok(rustls::client::danger::ServerCertVerified::assertion())
117 } else {
118 tracing::error!("❌ Server Ed25519 key mismatch");
119 tracing::error!(" Extracted: {}", extracted_key_hex);
120 tracing::error!(" Expected: {}", expected_key_hex);
121 Err(rustls::Error::InvalidCertificate(
122 rustls::CertificateError::ApplicationVerificationFailure,
123 ))
124 }
125 }
126 Err(e) => {
127 tracing::error!("❌ Failed to extract Ed25519 key from certificate: {}", e);
128 Err(rustls::Error::InvalidCertificate(
129 rustls::CertificateError::ApplicationVerificationFailure,
130 ))
131 }
132 }
133 }
134
135 fn verify_tls12_signature(
136 &self,
137 _message: &[u8],
138 _cert: &CertificateDer,
139 _dss: &rustls::DigitallySignedStruct,
140 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
141 {
142 Err(rustls::Error::UnsupportedNameType)
144 }
145
146 fn verify_tls13_signature(
147 &self,
148 _message: &[u8],
149 _cert: &CertificateDer,
150 _dss: &rustls::DigitallySignedStruct,
151 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
152 {
153 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
155 }
156
157 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
158 vec![rustls::SignatureScheme::ED25519]
159 }
160 }
161}
162
163pub fn create_client_endpoint(
164 server_public_key: &VerifyingKey,
165) -> std::result::Result<Endpoint, CryptoError> {
166 create_client_endpoint_with_protocols(server_public_key, &ClientProtocolConfig::default())
167}
168
169pub fn create_client_endpoint_with_protocols(
170 server_public_key: &VerifyingKey,
171 protocol_versions: &ClientProtocolConfig,
172) -> std::result::Result<Endpoint, CryptoError> {
173 let cert_verifier = match server_public_key {
174 VerifyingKey::Ed25519(verifying_key) => {
175 ed25519::AcceptSpecificEd25519ServerCertVerifier::new(**verifying_key)
176 }
177 _ => {
178 return Err(CryptoError::TlsError(
179 "Server certificate type not supported for TLS".to_string(),
180 ));
181 }
182 };
183
184 let mut crypto = RustlsClientConfig::builder()
186 .dangerous()
187 .with_custom_certificate_verifier(Arc::new(cert_verifier))
188 .with_no_client_auth();
189
190 let alpn_protocols = protocol_versions.alpn_protocols();
192 crypto.alpn_protocols = alpn_protocols;
193
194 let mut transport_config = quinn::TransportConfig::default();
196 transport_config.keep_alive_interval(Some(Duration::from_secs(25)));
197
198 let mut client_config = ClientConfig::new(Arc::new(
199 QuicClientConfig::try_from(crypto)
200 .map_err(|e| CryptoError::TlsError(format!("Failed to create client config: {e}")))?,
201 ));
202 client_config.transport_config(Arc::new(transport_config));
203
204 let mut endpoint = Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into())
205 .map_err(|e| CryptoError::TlsError(format!("Failed to create endpoint: {e}")))?;
206 endpoint.set_default_client_config(client_config);
207
208 Ok(endpoint)
209}