From 98da34a7dc216b5e022ff5400356c73a23dfac96 Mon Sep 17 00:00:00 2001 From: Yiannis Marangos Date: Wed, 7 Aug 2024 12:19:54 +0300 Subject: [PATCH] feat(websocket): Allow wss connections on IP addresses Pull-Request: #5525. --- transports/websocket/CHANGELOG.md | 3 + transports/websocket/src/framed.rs | 141 +++++++++++++++++++++++++---- 2 files changed, 124 insertions(+), 20 deletions(-) diff --git a/transports/websocket/CHANGELOG.md b/transports/websocket/CHANGELOG.md index 50b1c42d3e1..df51e2c807d 100644 --- a/transports/websocket/CHANGELOG.md +++ b/transports/websocket/CHANGELOG.md @@ -2,6 +2,9 @@ - Implement refactored `Transport`. See [PR 4568](https://github.com/libp2p/rust-libp2p/pull/4568) +- Allow wss connections on IP addresses. + See [PR 5525](https://github.com/libp2p/rust-libp2p/pull/5525). + ## 0.43.2 - fix: Avoid websocket panic on polling after errors. See [PR 5482]. diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index fc6a3f0e90e..074271e672f 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -21,7 +21,8 @@ use crate::{error::Error, quicksink, tls}; use either::Either; use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; -use futures_rustls::{client, rustls, server}; +use futures_rustls::rustls::pki_types::ServerName; +use futures_rustls::{client, server}; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, transport::{DialOpts, ListenerId, TransportError, TransportEvent}, @@ -32,6 +33,7 @@ use soketto::{ connection::{self, CloseReason}, handshake, }; +use std::net::IpAddr; use std::{collections::HashMap, ops::DerefMut, sync::Arc}; use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll}; use url::Url; @@ -315,15 +317,12 @@ where let stream = if addr.use_tls { // begin TLS session - let dns_name = addr - .dns_name - .expect("for use_tls we have checked that dns_name is some"); - tracing::trace!(?dns_name, "Starting TLS handshake"); + tracing::trace!(?addr.server_name, "Starting TLS handshake"); let stream = tls_config .client - .connect(dns_name.clone(), stream) + .connect(addr.server_name.clone(), stream) .map_err(|e| { - tracing::debug!(?dns_name, "TLS handshake failed: {}", e); + tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e); Error::Tls(tls::Error::from(e)) }) .await?; @@ -451,7 +450,7 @@ where struct WsAddress { host_port: String, path: String, - dns_name: Option>, + server_name: ServerName<'static>, use_tls: bool, tcp_addr: Multiaddr, } @@ -468,19 +467,21 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { let mut protocols = addr.iter(); let mut ip = protocols.next(); let mut tcp = protocols.next(); - let (host_port, dns_name) = loop { + let (host_port, server_name) = loop { match (ip, tcp) { (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => { - break (format!("{ip}:{port}"), None) + let server_name = ServerName::IpAddress(IpAddr::V4(ip).into()); + break (format!("{ip}:{port}"), server_name); } (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => { - break (format!("{ip}:{port}"), None) + let server_name = ServerName::IpAddress(IpAddr::V6(ip).into()); + break (format!("[{ip}]:{port}"), server_name); } (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) | (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => { - break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?)) + break (format!("{h}:{port}"), tls::dns_name_ref(&h)?) } (Some(_), Some(p)) => { ip = Some(p); @@ -499,13 +500,7 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { match protocols.pop() { p @ Some(Protocol::P2p(_)) => p2p = p, Some(Protocol::Ws(path)) => break (false, path.into_owned()), - Some(Protocol::Wss(path)) => { - if dns_name.is_none() { - tracing::debug!(address=%addr, "Missing DNS name in WSS address"); - return Err(Error::InvalidMultiaddr(addr)); - } - break (true, path.into_owned()); - } + Some(Protocol::Wss(path)) => break (true, path.into_owned()), _ => return Err(Error::InvalidMultiaddr(addr)), } }; @@ -519,7 +514,7 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { Ok(WsAddress { host_port, - dns_name, + server_name, path, use_tls, tcp_addr, @@ -757,3 +752,109 @@ where .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } } + +#[cfg(test)] +mod tests { + use super::*; + use libp2p_identity::PeerId; + use std::io; + + #[test] + fn dial_addr() { + let peer_id = PeerId::random(); + + // Check `/wss` + let addr = "/dns4/example.com/tcp/2222/wss" + .parse::() + .unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "example.com:2222"); + assert_eq!(info.path, "/"); + assert!(info.use_tls); + assert_eq!(info.server_name, "example.com".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap()); + + // Check `/wss` with `/p2p` + let addr = format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}") + .parse() + .unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "example.com:2222"); + assert_eq!(info.path, "/"); + assert!(info.use_tls); + assert_eq!(info.server_name, "example.com".try_into().unwrap()); + assert_eq!( + info.tcp_addr, + format!("/dns4/example.com/tcp/2222/p2p/{peer_id}") + .parse() + .unwrap() + ); + + // Check `/wss` with `/ip4` + let addr = "/ip4/127.0.0.1/tcp/2222/wss".parse::().unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "127.0.0.1:2222"); + assert_eq!(info.path, "/"); + assert!(info.use_tls); + assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap()); + + // Check `/wss` with `/ip6` + let addr = "/ip6/::1/tcp/2222/wss".parse::().unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "[::1]:2222"); + assert_eq!(info.path, "/"); + assert!(info.use_tls); + assert_eq!(info.server_name, "::1".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap()); + + // Check `/ws` + let addr = "/dns4/example.com/tcp/2222/ws" + .parse::() + .unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "example.com:2222"); + assert_eq!(info.path, "/"); + assert!(!info.use_tls); + assert_eq!(info.server_name, "example.com".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap()); + + // Check `/ws` with `/p2p` + let addr = format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}") + .parse() + .unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "example.com:2222"); + assert_eq!(info.path, "/"); + assert!(!info.use_tls); + assert_eq!(info.server_name, "example.com".try_into().unwrap()); + assert_eq!( + info.tcp_addr, + format!("/dns4/example.com/tcp/2222/p2p/{peer_id}") + .parse() + .unwrap() + ); + + // Check `/ws` with `/ip4` + let addr = "/ip4/127.0.0.1/tcp/2222/ws".parse::().unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "127.0.0.1:2222"); + assert_eq!(info.path, "/"); + assert!(!info.use_tls); + assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap()); + + // Check `/ws` with `/ip6` + let addr = "/ip6/::1/tcp/2222/ws".parse::().unwrap(); + let info = parse_ws_dial_addr::(addr).unwrap(); + assert_eq!(info.host_port, "[::1]:2222"); + assert_eq!(info.path, "/"); + assert!(!info.use_tls); + assert_eq!(info.server_name, "::1".try_into().unwrap()); + assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap()); + + // Check non-ws address + let addr = "/ip4/127.0.0.1/tcp/2222".parse::().unwrap(); + parse_ws_dial_addr::(addr).unwrap_err(); + } +}