zoe_wire_protocol/connection/
server.rs

1use crate::crypto::CryptoError;
2use crate::version::ServerProtocolConfig;
3use crate::KeyPair;
4use quinn::Endpoint;
5use std::net::SocketAddr;
6use tracing::debug;
7
8use quinn::ServerConfig;
9use std::sync::Arc;
10
11mod ed25519 {
12
13    use der::{asn1::*, Encode};
14    use ed25519_dalek::pkcs8::EncodePrivateKey;
15    use rustls::pki_types::CertificateDer;
16    use std::sync::Arc;
17    use x509_cert::{
18        attr::{AttributeTypeAndValue, AttributeValue},
19        certificate::{Certificate, TbsCertificate, Version},
20        ext::{Extension, Extensions},
21        name::{Name, RelativeDistinguishedName},
22        serial_number::SerialNumber,
23        spki::{AlgorithmIdentifier, SubjectPublicKeyInfo},
24        time::{Time, Validity},
25    };
26
27    use crate::{
28        crypto::CryptoError, version::ProtocolVersion, ClientProtocolConfig, ServerProtocolConfig,
29    };
30
31    // Create a simple certificate resolver that always returns our certificate
32    #[derive(Debug)]
33    struct Ed25519CertResolver {
34        server_signing_key: ed25519_dalek::SigningKey,
35        server_protocol_config: ServerProtocolConfig,
36    }
37
38    impl rustls::server::ResolvesServerCert for Ed25519CertResolver {
39        fn resolve(
40            &self,
41            client_hello: rustls::server::ClientHello,
42        ) -> Option<Arc<rustls::sign::CertifiedKey>> {
43            tracing::debug!("🔍 Resolving server certificate for client hello");
44
45            let Some(alpn) = client_hello.alpn() else {
46                tracing::debug!("❌ No ALPN protocols provided by client");
47                return None;
48            };
49
50            let alpn_protocols: Vec<&[u8]> = alpn.collect();
51            let client_protocol_config =
52                match ClientProtocolConfig::from_alpn_data(alpn_protocols.iter().copied()) {
53                    Ok(config) => config,
54                    Err(e) => {
55                        tracing::error!("❌ Failed to parse client ALPN protocol config: {e}");
56                        return None;
57                    }
58                };
59
60            let protocol_version = self
61                .server_protocol_config
62                .negotiate_version(&client_protocol_config.0);
63
64            // Generate TLS certificate with negotiated protocol version
65            let certs = match generate_ed25519_cert_for_tls(
66                &self.server_signing_key,
67                "localhost",
68                protocol_version,
69            ) {
70                Ok(certs) => certs,
71                Err(e) => {
72                    tracing::error!("Failed to generate certificate: {e}");
73                    return None;
74                }
75            };
76
77            // Create certificate resolver with Ed25519 signing key
78            let pkcs8_der = self
79                .server_signing_key
80                .to_pkcs8_der()
81                .inspect_err(|&e| {
82                    tracing::error!("Failed to encode Ed25519 key: {}", e);
83                })
84                .ok()?;
85
86            let private_key = rustls::pki_types::PrivateKeyDer::from(
87                rustls::pki_types::PrivatePkcs8KeyDer::from(pkcs8_der.as_bytes().to_vec()),
88            );
89
90            let signing_key =
91                rustls::crypto::aws_lc_rs::sign::any_supported_type(&private_key).ok()?;
92
93            Some(Arc::new(rustls::sign::CertifiedKey::new(
94                certs.to_vec(),
95                signing_key,
96            )))
97        }
98    }
99
100    /// Generate a deterministic TLS certificate using Ed25519
101    ///
102    /// This creates a proper Ed25519 certificate where the Ed25519 public key
103    /// is stored directly in the SubjectPublicKeyInfo field.
104    pub(crate) fn generate_ed25519_cert_for_tls(
105        ed25519_signing_key: &ed25519_dalek::SigningKey,
106        subject_name: &str,
107        selected_protocol_version: Option<ProtocolVersion>,
108    ) -> std::result::Result<Vec<CertificateDer<'static>>, CryptoError> {
109        tracing::debug!(
110            "🔧 Creating proper Ed25519 certificate for subject: {}",
111            subject_name
112        );
113
114        let verifying_key = ed25519_signing_key.verifying_key();
115        let public_key_bytes = verifying_key.to_bytes();
116
117        tracing::debug!(
118            "🔧 Ed25519 public key length: {} bytes",
119            public_key_bytes.len()
120        );
121
122        // Ed25519 algorithm identifier OID (RFC 8410)
123        let ed25519_oid = ObjectIdentifier::new("1.3.101.112")
124            .map_err(|e| CryptoError::ParseError(format!("Invalid Ed25519 OID: {e}")))?;
125
126        // Create SubjectPublicKeyInfo with Ed25519 public key
127        let algorithm = AlgorithmIdentifier {
128            oid: ed25519_oid,
129            parameters: None,
130        };
131
132        let subject_public_key_info = SubjectPublicKeyInfo {
133            algorithm,
134            subject_public_key: BitString::from_bytes(&public_key_bytes)
135                .map_err(|e| CryptoError::ParseError(format!("Failed to create BitString: {e}")))?,
136        };
137
138        // Create subject name (CN=subject_name)
139        let cn_oid = const_oid::db::rfc4519::CN;
140        let cn_value = AttributeValue::new(der::Tag::Utf8String, subject_name.as_bytes())
141            .map_err(|e| CryptoError::ParseError(format!("Failed to create CN value: {e}")))?;
142
143        let cn_attribute = AttributeTypeAndValue {
144            oid: cn_oid,
145            value: cn_value,
146        };
147
148        let rdn = RelativeDistinguishedName::from(
149            SetOfVec::try_from(vec![cn_attribute])
150                .map_err(|e| CryptoError::ParseError(format!("Failed to create RDN: {e}")))?,
151        );
152
153        let subject = Name::from(vec![rdn]);
154
155        // Set validity period (1 year from now)
156        let now = std::time::SystemTime::now();
157        let not_before = Time::GeneralTime(
158            GeneralizedTime::from_system_time(now)
159                .map_err(|e| CryptoError::ParseError(format!("Time conversion error: {e}")))?,
160        );
161        let not_after = Time::GeneralTime(
162            GeneralizedTime::from_system_time(
163                now + std::time::Duration::from_secs(365 * 24 * 3600),
164            )
165            .map_err(|e| CryptoError::ParseError(format!("Time conversion error: {e}")))?,
166        );
167
168        let validity = Validity {
169            not_before,
170            not_after,
171        };
172
173        // Create ALPN extension containing the negotiated protocol version
174        let protocol_version_bytes =
175            if let Some(selected_protocol_version) = selected_protocol_version {
176                postcard::to_stdvec(&selected_protocol_version).map_err(|e| {
177                    CryptoError::ParseError(format!("Failed to serialize protocol version: {e}"))
178                })?
179            } else {
180                vec![]
181            };
182
183        // Use a custom OID for the ALPN protocol version extension
184        // Using private enterprise arc: 1.3.6.1.4.1.99999.1 (placeholder OID)
185        let alpn_extension_oid = ObjectIdentifier::new("1.3.6.1.4.1.99999.1")
186            .map_err(|e| CryptoError::ParseError(format!("Invalid ALPN extension OID: {e}")))?;
187
188        let alpn_extension = Extension {
189            extn_id: alpn_extension_oid,
190            critical: false, // Non-critical extension
191            extn_value: OctetString::new(protocol_version_bytes).map_err(|e| {
192                CryptoError::ParseError(format!("Failed to create extension value: {e}"))
193            })?,
194        };
195
196        let extensions = Extensions::from(vec![alpn_extension]);
197
198        // Create TBS certificate
199        let tbs_certificate = TbsCertificate {
200            version: Version::V3,
201            serial_number: SerialNumber::from(1u32),
202            signature: AlgorithmIdentifier {
203                oid: ed25519_oid,
204                parameters: None,
205            },
206            issuer: subject.clone(), // Self-signed
207            validity,
208            subject,
209            subject_public_key_info,
210            issuer_unique_id: None,
211            subject_unique_id: None,
212            extensions: Some(extensions),
213        };
214
215        // Encode TBS certificate for signing
216        let tbs_der = tbs_certificate.to_der().map_err(|e| {
217            CryptoError::ParseError(format!("Failed to encode TBS certificate: {e}"))
218        })?;
219
220        // Sign the TBS certificate with Ed25519
221        use signature::Signer;
222        let signature = ed25519_signing_key.sign(&tbs_der);
223
224        // Create final certificate
225        let certificate = Certificate {
226            tbs_certificate,
227            signature_algorithm: AlgorithmIdentifier {
228                oid: ed25519_oid,
229                parameters: None,
230            },
231            signature: BitString::from_bytes(&signature.to_bytes()).map_err(|e| {
232                CryptoError::ParseError(format!("Failed to create signature BitString: {e}"))
233            })?,
234        };
235
236        // Encode certificate to DER
237        let cert_der = certificate
238            .to_der()
239            .map_err(|e| CryptoError::ParseError(format!("Failed to encode certificate: {e}")))?;
240
241        tracing::debug!("✅ Generated proper Ed25519 certificate successfully");
242
243        Ok(vec![CertificateDer::from(cert_der)])
244    }
245
246    /// Create a complete rustls ServerConfig for Ed25519 certificates
247    ///
248    /// This function creates a fully configured rustls ServerConfig that:
249    /// - Uses the default crypto provider with Ed25519 support
250    /// - Requires TLS 1.3
251    /// - Uses anonymous client authentication
252    /// - Configures the Ed25519 certificate and signing key
253    ///
254    /// # Arguments
255    /// * `server_signing_key` - The Ed25519 signing key for the server
256    /// * `hostname` - The hostname for the certificate (e.g., "localhost")
257    /// * `alpn_protocols` - The ALPN protocols to advertise to the client
258    ///
259    /// # Returns
260    /// A configured rustls ServerConfig ready for use with QUIC
261    pub(crate) fn create_ed25519_server_config_with_alpn(
262        server_signing_key: &ed25519_dalek::SigningKey,
263        _hostname: &str,
264        server_protocol_config: ServerProtocolConfig,
265    ) -> std::result::Result<rustls::ServerConfig, CryptoError> {
266        let cert_resolver = Ed25519CertResolver {
267            server_signing_key: server_signing_key.clone(),
268            server_protocol_config,
269        };
270
271        // Create rustls server config
272        let mut rustls_config = rustls::ServerConfig::builder()
273            .with_no_client_auth() // we accept any client
274            .with_cert_resolver(Arc::new(cert_resolver));
275
276        // Advertise the same protocols that the default client would send
277        // This ensures ALPN negotiation succeeds and we can do actual negotiation in the cert resolver
278        let default_client_config = crate::version::ClientProtocolConfig::default();
279        rustls_config.alpn_protocols = default_client_config.alpn_protocols();
280
281        Ok(rustls_config)
282    }
283}
284
285/// Create a QUIC server endpoint with TLS certificate (Ed25519 or ML-DSA-44)
286pub fn create_server_endpoint(
287    addr: SocketAddr,
288    server_keypair: &KeyPair,
289) -> std::result::Result<Endpoint, CryptoError> {
290    create_server_endpoint_with_protocols(addr, server_keypair, &ServerProtocolConfig::default())
291}
292
293/// Create a QUIC server endpoint with protocol version negotiation support
294pub fn create_server_endpoint_with_protocols(
295    addr: SocketAddr,
296    server_keypair: &KeyPair,
297    protocol_negotiation: &ServerProtocolConfig,
298) -> std::result::Result<Endpoint, CryptoError> {
299    debug!("🚀 Creating relay server endpoint on {}", addr);
300
301    let rustls_config = match server_keypair {
302        KeyPair::Ed25519(signing_key) => {
303            debug!(
304                "🔑 Server Ed25519 public key: {}",
305                hex::encode(signing_key.verifying_key().to_bytes())
306            );
307
308            // Create Ed25519 server configuration with ALPN
309            ed25519::create_ed25519_server_config_with_alpn(
310                signing_key.as_ref(),
311                "localhost",
312                protocol_negotiation.clone(),
313            )
314            .map_err(|e| {
315                CryptoError::TlsError(format!("Failed to create Ed25519 server config: {e}"))
316            })?
317        }
318
319        KeyPair::MlDsa44(_, _) => {
320            return Err(CryptoError::TlsError(
321                "ML-DSA-44 is not supported for TLS transport security yet. Use Ed25519."
322                    .to_string(),
323            ));
324        }
325
326        // ML-DSA-65 and ML-DSA-87 are not supported for TLS yet
327        KeyPair::MlDsa65(_, _) => {
328            return Err(CryptoError::TlsError(
329                "ML-DSA-65 is not supported for TLS transport security yet. Use Ed25519."
330                    .to_string(),
331            ));
332        }
333        KeyPair::MlDsa87(_, _) => {
334            return Err(CryptoError::TlsError(
335                "ML-DSA-87 is not supported for TLS transport security yet. Use Ed25519."
336                    .to_string(),
337            ));
338        }
339    };
340
341    let server_config = ServerConfig::with_crypto(Arc::new(
342        quinn::crypto::rustls::QuicServerConfig::try_from(rustls_config)
343            .map_err(|e| CryptoError::TlsError(format!("Failed to create server config: {e}")))?,
344    ));
345
346    let endpoint = Endpoint::server(server_config, addr)
347        .map_err(|e| CryptoError::TlsError(format!("Failed to create server endpoint: {e}")))?;
348    debug!("✅ Server endpoint ready on {}", addr);
349
350    Ok(endpoint)
351}