zoe_app_primitives/
connection.rs

1use std::collections::BTreeSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3
4use serde::{Deserialize, Serialize};
5use zoe_wire_protocol::VerifyingKey;
6
7#[cfg(feature = "frb-api")]
8use flutter_rust_bridge::frb;
9
/// Network address information for connecting to a service
///
/// Supports multiple address types including DNS names, IPv4, and IPv6 addresses
/// with optional port specifications for maximum flexibility.
///
/// Every variant carries an optional port; `None` means the caller is expected
/// to supply a default port when the address is actually used.
///
/// `PartialOrd`/`Ord` are derived, so values compare by variant in declaration
/// order (`Dns`, then `Ipv4`, then `Ipv6`) and then field by field. That ordering
/// is what an ordered collection such as `BTreeSet<NetworkAddress>` iterates in.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "frb-api", frb(opaque))]
pub enum NetworkAddress {
    /// DNS hostname with optional port
    ///
    /// Examples: `relay.example.com`, `relay.example.com:8443`
    /// If no port is specified, the default port should be used.
    Dns { hostname: String, port: Option<u16> },

    /// IPv4 address with optional port
    ///
    /// Examples: `192.168.1.100`, `192.168.1.100:8443`
    /// If no port is specified, the default port should be used.
    Ipv4 {
        address: Ipv4Addr,
        port: Option<u16>,
    },

    /// IPv6 address with optional port
    ///
    /// Examples: `::1`, `[::1]:8443` (the brackets separate the address from the port)
    /// If no port is specified, the default port should be used.
    Ipv6 {
        address: Ipv6Addr,
        port: Option<u16>,
    },
}
41
42impl std::fmt::Display for NetworkAddress {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let (first, port) = match self {
45            NetworkAddress::Dns { hostname, port } => (hostname.clone(), port),
46            NetworkAddress::Ipv4 { address, port } => (format!("{address}"), port),
47            NetworkAddress::Ipv6 { address, port } => (format!("[{address}]"), port),
48        };
49        if let Some(port) = port {
50            write!(f, "{first}:{port}")
51        } else {
52            write!(f, "{first}")
53        }
54    }
55}
56
57#[cfg_attr(feature = "frb-api", frb)]
58impl NetworkAddress {
59    /// Create a DNS network address
60    pub fn dns(hostname: impl Into<String>) -> Self {
61        Self::Dns {
62            hostname: hostname.into(),
63            port: None,
64        }
65    }
66
67    /// Create a DNS network address with port
68    pub fn dns_with_port(hostname: impl Into<String>, port: u16) -> Self {
69        Self::Dns {
70            hostname: hostname.into(),
71            port: Some(port),
72        }
73    }
74
75    /// Create an IPv4 network address
76    pub fn ipv4(address: Ipv4Addr) -> Self {
77        Self::Ipv4 {
78            address,
79            port: None,
80        }
81    }
82
83    /// Create an IPv4 network address with port
84    pub fn ipv4_with_port(address: Ipv4Addr, port: u16) -> Self {
85        Self::Ipv4 {
86            address,
87            port: Some(port),
88        }
89    }
90
91    /// Create an IPv6 network address
92    pub fn ipv6(address: Ipv6Addr) -> Self {
93        Self::Ipv6 {
94            address,
95            port: None,
96        }
97    }
98
99    /// Create an IPv6 network address with port
100    pub fn ipv6_with_port(address: Ipv6Addr, port: u16) -> Self {
101        Self::Ipv6 {
102            address,
103            port: Some(port),
104        }
105    }
106
107    /// Get the port if specified, otherwise return the default port
108    pub fn port_or_default(&self, default_port: u16) -> u16 {
109        match self {
110            NetworkAddress::Dns { port, .. } => port.unwrap_or(default_port),
111            NetworkAddress::Ipv4 { port, .. } => port.unwrap_or(default_port),
112            NetworkAddress::Ipv6 { port, .. } => port.unwrap_or(default_port),
113        }
114    }
115
116    /// Get the port if specified
117    pub fn port(&self) -> Option<u16> {
118        match self {
119            NetworkAddress::Dns { port, .. } => *port,
120            NetworkAddress::Ipv4 { port, .. } => *port,
121            NetworkAddress::Ipv6 { port, .. } => *port,
122        }
123    }
124    /// Resolve this network address to a socket address
125    ///
126    /// For IP addresses, returns immediately. For DNS addresses, performs resolution.
127    pub async fn resolve_to_socket_addr(&self, default_port: u16) -> Result<SocketAddr, String> {
128        match self {
129            NetworkAddress::Ipv4 { address, port } => Ok(SocketAddr::V4(
130                std::net::SocketAddrV4::new(*address, port.unwrap_or(default_port)),
131            )),
132            NetworkAddress::Ipv6 { address, port } => Ok(SocketAddr::V6(
133                std::net::SocketAddrV6::new(*address, port.unwrap_or(default_port), 0, 0),
134            )),
135            NetworkAddress::Dns { hostname, port } => {
136                let connection_string = match port {
137                    Some(p) => format!("{}:{}", hostname, p),
138                    None => format!("{}:{}", hostname, default_port),
139                };
140
141                // Use tokio's lookup_host for DNS resolution
142                use tokio::net::lookup_host;
143                let addrs = lookup_host(connection_string.clone())
144                    .await
145                    .map_err(|e| e.to_string())?;
146                if let Some(addr) = addrs.into_iter().next() {
147                    Ok(addr)
148                } else {
149                    Err(format!("No addresses found for {}", connection_string))
150                }
151            }
152        }
153    }
154}
155
156impl From<IpAddr> for NetworkAddress {
157    fn from(addr: IpAddr) -> Self {
158        match addr {
159            IpAddr::V4(ipv4) => NetworkAddress::ipv4(ipv4),
160            IpAddr::V6(ipv6) => NetworkAddress::ipv6(ipv6),
161        }
162    }
163}
164
165impl From<SocketAddr> for NetworkAddress {
166    fn from(addr: SocketAddr) -> Self {
167        match addr {
168            SocketAddr::V4(ipv4) => NetworkAddress::Ipv4 {
169                address: *ipv4.ip(),
170                port: Some(ipv4.port()),
171            },
172            SocketAddr::V6(ipv6) => NetworkAddress::Ipv6 {
173                address: *ipv6.ip(),
174                port: Some(ipv6.port()),
175            },
176        }
177    }
178}
179
180impl From<&str> for NetworkAddress {
181    fn from(addr_str: &str) -> Self {
182        // Try to parse as a full socket address first
183        if let Ok(socket_addr) = addr_str.parse::<SocketAddr>() {
184            return match socket_addr.ip() {
185                std::net::IpAddr::V4(ipv4) => {
186                    NetworkAddress::ipv4_with_port(ipv4, socket_addr.port())
187                }
188                std::net::IpAddr::V6(ipv6) => {
189                    NetworkAddress::ipv6_with_port(ipv6, socket_addr.port())
190                }
191            };
192        }
193
194        // Try to parse as IP:port
195        if let Some((ip_str, port_str)) = addr_str.rsplit_once(':') {
196            if let (Ok(ip), Ok(port)) =
197                (ip_str.parse::<std::net::IpAddr>(), port_str.parse::<u16>())
198            {
199                return match ip {
200                    std::net::IpAddr::V4(ipv4) => NetworkAddress::ipv4_with_port(ipv4, port),
201                    std::net::IpAddr::V6(ipv6) => NetworkAddress::ipv6_with_port(ipv6, port),
202                };
203            }
204
205            // Try as hostname:port
206            if let Ok(port) = port_str.parse::<u16>() {
207                return NetworkAddress::dns_with_port(ip_str, port);
208            }
209        }
210
211        // Try to parse as plain IP
212        if let Ok(ip) = addr_str.parse::<std::net::IpAddr>() {
213            return match ip {
214                std::net::IpAddr::V4(ipv4) => NetworkAddress::ipv4(ipv4),
215                std::net::IpAddr::V6(ipv6) => NetworkAddress::ipv6(ipv6),
216            };
217        }
218
219        // Assume it's a DNS name without port
220        NetworkAddress::dns(addr_str.to_string())
221    }
222}
223
224impl From<String> for NetworkAddress {
225    fn from(addr_str: String) -> Self {
226        Self::from(addr_str.as_str())
227    }
228}
229
/// Relay address information for a service
///
/// Contains the public key and network addresses needed to connect to a service.
/// This structure is designed to be compact and suitable for QR code encoding.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RelayAddress {
    /// Public key of the service
    ///
    /// Used to verify the service's identity during the connection handshake.
    /// This prevents man-in-the-middle attacks and ensures the client is
    /// connecting to the correct service. Supports Ed25519 and ML-DSA keys.
    pub public_key: VerifyingKey,

    /// Network addresses where the service can be reached
    ///
    /// Multiple addresses can be provided for redundancy and different network
    /// configurations. Because this is a `BTreeSet`, duplicate addresses are
    /// collapsed and iteration follows `NetworkAddress`'s `Ord` ordering, not
    /// the order in which addresses were added. Clients should try addresses in
    /// that iteration order until one succeeds.
    pub addresses: BTreeSet<NetworkAddress>,

    /// Optional human-readable name for the service
    ///
    /// Can be used for display purposes or debugging. Examples:
    /// "Primary Relay", "EU West", "Backup Server", etc.
    pub name: Option<String>,
}
255
256#[cfg_attr(feature = "frb-api", frb(opaque))]
257impl RelayAddress {
258    /// Create a new connection info with minimal required fields
259    pub fn new(public_key: VerifyingKey) -> Self {
260        Self {
261            public_key,
262            addresses: BTreeSet::new(),
263            name: None,
264        }
265    }
266
267    /// Add a network address
268    pub fn with_address(mut self, address: NetworkAddress) -> Self {
269        self.addresses.insert(address);
270        self
271    }
272
273    pub fn with_address_str(mut self, address: String) -> Self {
274        self.addresses.insert(address.into());
275        self
276    }
277
278    /// Add multiple network addresses
279    pub fn with_addresses(mut self, addresses: impl IntoIterator<Item = NetworkAddress>) -> Self {
280        self.addresses.extend(addresses);
281        self
282    }
283
284    /// Set a human-readable name for this service
285    pub fn with_name(mut self, name: impl Into<String>) -> Self {
286        self.name = Some(name.into());
287        self
288    }
289
290    /// Get the service's display name (name if set, otherwise first address)
291    pub fn display_name(&self) -> String {
292        self.name.clone().unwrap_or_else(|| {
293            self.addresses
294                .iter()
295                .next()
296                .map(|addr| addr.to_string())
297                .unwrap_or_else(|| "Unknown Service".to_string())
298        })
299    }
300
301    /// Get all addresses that use the specified port (or default port)
302    pub fn addresses_with_port(&self, port: u16) -> Vec<NetworkAddress> {
303        self.addresses
304            .iter()
305            .filter(|addr| addr.port_or_default(port) == port)
306            .cloned()
307            .collect()
308    }
309
310    /// Get the first address, if any
311    pub fn primary_address(&self) -> Option<&NetworkAddress> {
312        self.addresses.iter().next()
313    }
314
315    /// Get the relay ID (public key ID)
316    pub fn id(&self) -> zoe_wire_protocol::KeyId {
317        self.public_key.id()
318    }
319
320    /// Get all addresses for connection attempts
321    pub fn all_addresses(&self) -> &BTreeSet<NetworkAddress> {
322        &self.addresses
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};
    use zoe_wire_protocol::KeyPair;

    /// Generates a throwaway Ed25519 public key for building `RelayAddress` fixtures.
    fn fresh_public_key() -> VerifyingKey {
        KeyPair::generate_ed25519(&mut rand::thread_rng()).public_key()
    }

    #[test]
    fn test_network_address_dns() {
        let addr = NetworkAddress::dns("example.com");
        assert_eq!(None, addr.port());
        assert_eq!(8080, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_dns_with_port() {
        let addr = NetworkAddress::dns_with_port("example.com", 9090);
        assert_eq!(Some(9090), addr.port());
        assert_eq!(9090, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_ipv4() {
        let ip = Ipv4Addr::new(192, 168, 1, 1);
        let addr = NetworkAddress::ipv4(ip);
        assert_eq!(None, addr.port());
        assert_eq!(8080, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_ipv4_with_port() {
        let ip = Ipv4Addr::new(192, 168, 1, 1);
        let addr = NetworkAddress::ipv4_with_port(ip, 9090);
        assert_eq!(Some(9090), addr.port());
        assert_eq!(9090, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_ipv6() {
        let addr = NetworkAddress::ipv6(Ipv6Addr::LOCALHOST);
        assert_eq!(None, addr.port());
        assert_eq!(8080, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_ipv6_with_port() {
        let addr = NetworkAddress::ipv6_with_port(Ipv6Addr::LOCALHOST, 9090);
        assert_eq!(Some(9090), addr.port());
        assert_eq!(9090, addr.port_or_default(8080));
    }

    #[test]
    fn test_network_address_from_ip_addr() {
        let from_v4: NetworkAddress = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)).into();
        assert!(matches!(from_v4, NetworkAddress::Ipv4 { .. }));

        let from_v6: NetworkAddress = IpAddr::V6(Ipv6Addr::LOCALHOST).into();
        assert!(matches!(from_v6, NetworkAddress::Ipv6 { .. }));
    }

    #[test]
    fn test_connection_info_creation() {
        let lan_addr = NetworkAddress::ipv4_with_port(Ipv4Addr::new(192, 168, 1, 100), 8443);
        let info = RelayAddress::new(fresh_public_key())
            .with_address(NetworkAddress::dns("relay.example.com"))
            .with_address(lan_addr)
            .with_name("Test Relay");

        assert_eq!(info.addresses.len(), 2);
        assert_eq!(info.name, Some("Test Relay".to_string()));
        assert_eq!(info.display_name(), "Test Relay");
    }

    #[test]
    fn test_connection_info_display_name_fallback() {
        // Without a name, the display name falls back to the first address.
        let info = RelayAddress::new(fresh_public_key())
            .with_address(NetworkAddress::dns("relay.example.com"));

        assert_eq!(info.display_name(), "relay.example.com");
    }

    #[test]
    fn test_connection_info_addresses_with_port() {
        let info = RelayAddress::new(fresh_public_key())
            .with_address(NetworkAddress::dns_with_port("relay1.example.com", 8443))
            .with_address(NetworkAddress::dns_with_port("relay2.example.com", 9443))
            .with_address(NetworkAddress::ipv4_with_port(
                Ipv4Addr::new(192, 168, 1, 100),
                8443,
            ));

        let port_8443_addrs = info.addresses_with_port(8443);
        assert_eq!(port_8443_addrs.len(), 2);
    }

    #[test]
    fn test_postcard_serialization_network_address() {
        let addr = NetworkAddress::dns_with_port("example.com", 8443);

        let bytes = postcard::to_stdvec(&addr).unwrap();
        let round_tripped: NetworkAddress = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(addr, round_tripped);
    }

    #[test]
    fn test_postcard_serialization_connection_info() {
        let info = RelayAddress::new(fresh_public_key())
            .with_address(NetworkAddress::dns("relay.example.com"))
            .with_name("Test Relay");

        let bytes = postcard::to_stdvec(&info).unwrap();
        let round_tripped: RelayAddress = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(info, round_tripped);
    }

    #[test]
    fn test_from_socket_addr_ipv4_preserves_port() {
        // Converting from SocketAddr must keep both the IPv4 address and the port.
        let ip = Ipv4Addr::new(192, 168, 1, 100);
        let network_addr = NetworkAddress::from(SocketAddr::from(([192, 168, 1, 100], 13918)));

        assert_eq!(
            network_addr,
            NetworkAddress::Ipv4 {
                address: ip,
                port: Some(13918),
            }
        );

        // The port accessors must report the preserved port, not the default.
        assert_eq!(network_addr.port(), Some(13918));
        assert_eq!(network_addr.port_or_default(8080), 13918);
    }

    #[test]
    fn test_from_socket_addr_ipv6_preserves_port() {
        // Converting from SocketAddr must keep both the IPv6 address and the port.
        let ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let network_addr =
            NetworkAddress::from(SocketAddr::from(([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1], 13918)));

        assert_eq!(
            network_addr,
            NetworkAddress::Ipv6 {
                address: ip,
                port: Some(13918),
            }
        );

        // The port accessors must report the preserved port, not the default.
        assert_eq!(network_addr.port(), Some(13918));
        assert_eq!(network_addr.port_or_default(8080), 13918);
    }

    #[test]
    fn test_from_socket_addr_localhost_preserves_port() {
        // Localhost addresses are commonly used in development.
        let ipv4_network = NetworkAddress::from(SocketAddr::from(([127, 0, 0, 1], 13918)));
        let ipv6_network = NetworkAddress::from(SocketAddr::from((Ipv6Addr::LOCALHOST, 13918)));

        assert_eq!(ipv4_network.port(), Some(13918));
        assert_eq!(ipv6_network.port(), Some(13918));
    }

    #[test]
    fn test_relay_address_with_socket_addr_preserves_port() {
        // Integration test: a RelayAddress built from a SocketAddr keeps its port.
        let socket_addr = SocketAddr::from(([89, 58, 47, 227], 13918));

        let relay_address = RelayAddress::new(fresh_public_key())
            .with_address(socket_addr.into())
            .with_name("Test Server".to_string());

        // Exactly one address must be stored, and it must carry the original port.
        let expected = NetworkAddress::ipv4_with_port(Ipv4Addr::new(89, 58, 47, 227), 13918);
        let mut stored = relay_address.all_addresses().iter();
        assert_eq!(stored.next(), Some(&expected));
        assert_eq!(stored.next(), None);
    }
}