Skip to content

Commit

Permalink
feat(websocket): Allow wss connections on IP addresses
Browse files Browse the repository at this point in the history
Pull-Request: #5525.
  • Loading branch information
oblique authored Aug 7, 2024
1 parent 823acd6 commit 98da34a
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 20 deletions.
3 changes: 3 additions & 0 deletions transports/websocket/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

- Implement refactored `Transport`.
See [PR 4568](https://github.com/libp2p/rust-libp2p/pull/4568)
- Allow wss connections on IP addresses.
See [PR 5525](https://github.com/libp2p/rust-libp2p/pull/5525).

## 0.43.2

- fix: Avoid websocket panic on polling after errors. See [PR 5482].
Expand Down
141 changes: 121 additions & 20 deletions transports/websocket/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
use crate::{error::Error, quicksink, tls};
use either::Either;
use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
use futures_rustls::{client, rustls, server};
use futures_rustls::rustls::pki_types::ServerName;
use futures_rustls::{client, server};
use libp2p_core::{
multiaddr::{Multiaddr, Protocol},
transport::{DialOpts, ListenerId, TransportError, TransportEvent},
Expand All @@ -32,6 +33,7 @@ use soketto::{
connection::{self, CloseReason},
handshake,
};
use std::net::IpAddr;
use std::{collections::HashMap, ops::DerefMut, sync::Arc};
use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
use url::Url;
Expand Down Expand Up @@ -315,15 +317,12 @@ where

let stream = if addr.use_tls {
// begin TLS session
let dns_name = addr
.dns_name
.expect("for use_tls we have checked that dns_name is some");
tracing::trace!(?dns_name, "Starting TLS handshake");
tracing::trace!(?addr.server_name, "Starting TLS handshake");
let stream = tls_config
.client
.connect(dns_name.clone(), stream)
.connect(addr.server_name.clone(), stream)
.map_err(|e| {
tracing::debug!(?dns_name, "TLS handshake failed: {}", e);
tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
Error::Tls(tls::Error::from(e))
})
.await?;
Expand Down Expand Up @@ -451,7 +450,7 @@ where
struct WsAddress {
host_port: String,
path: String,
dns_name: Option<rustls::pki_types::ServerName<'static>>,
server_name: ServerName<'static>,
use_tls: bool,
tcp_addr: Multiaddr,
}
Expand All @@ -468,19 +467,21 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
let mut protocols = addr.iter();
let mut ip = protocols.next();
let mut tcp = protocols.next();
let (host_port, dns_name) = loop {
let (host_port, server_name) = loop {
match (ip, tcp) {
(Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
break (format!("{ip}:{port}"), None)
let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
break (format!("{ip}:{port}"), server_name);
}
(Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
break (format!("{ip}:{port}"), None)
let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
break (format!("[{ip}]:{port}"), server_name);
}
(Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
| (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
| (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port)))
| (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => {
break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?))
break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
}
(Some(_), Some(p)) => {
ip = Some(p);
Expand All @@ -499,13 +500,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
match protocols.pop() {
p @ Some(Protocol::P2p(_)) => p2p = p,
Some(Protocol::Ws(path)) => break (false, path.into_owned()),
Some(Protocol::Wss(path)) => {
if dns_name.is_none() {
tracing::debug!(address=%addr, "Missing DNS name in WSS address");
return Err(Error::InvalidMultiaddr(addr));
}
break (true, path.into_owned());
}
Some(Protocol::Wss(path)) => break (true, path.into_owned()),
_ => return Err(Error::InvalidMultiaddr(addr)),
}
};
Expand All @@ -519,7 +514,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {

Ok(WsAddress {
host_port,
dns_name,
server_name,
path,
use_tls,
tcp_addr,
Expand Down Expand Up @@ -757,3 +752,109 @@ where
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}

#[cfg(test)]
mod tests {
use super::*;
use libp2p_identity::PeerId;
use std::io;

#[test]
fn dial_addr() {
let peer_id = PeerId::random();

// Check `/wss`
let addr = "/dns4/example.com/tcp/2222/wss"
.parse::<Multiaddr>()
.unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "example.com:2222");
assert_eq!(info.path, "/");
assert!(info.use_tls);
assert_eq!(info.server_name, "example.com".try_into().unwrap());
assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());

// Check `/wss` with `/p2p`
let addr = format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}")
.parse()
.unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "example.com:2222");
assert_eq!(info.path, "/");
assert!(info.use_tls);
assert_eq!(info.server_name, "example.com".try_into().unwrap());
assert_eq!(
info.tcp_addr,
format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
.parse()
.unwrap()
);

// Check `/wss` with `/ip4`
let addr = "/ip4/127.0.0.1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "127.0.0.1:2222");
assert_eq!(info.path, "/");
assert!(info.use_tls);
assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());

// Check `/wss` with `/ip6`
let addr = "/ip6/::1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "[::1]:2222");
assert_eq!(info.path, "/");
assert!(info.use_tls);
assert_eq!(info.server_name, "::1".try_into().unwrap());
assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());

// Check `/ws`
let addr = "/dns4/example.com/tcp/2222/ws"
.parse::<Multiaddr>()
.unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "example.com:2222");
assert_eq!(info.path, "/");
assert!(!info.use_tls);
assert_eq!(info.server_name, "example.com".try_into().unwrap());
assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());

// Check `/ws` with `/p2p`
let addr = format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}")
.parse()
.unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "example.com:2222");
assert_eq!(info.path, "/");
assert!(!info.use_tls);
assert_eq!(info.server_name, "example.com".try_into().unwrap());
assert_eq!(
info.tcp_addr,
format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
.parse()
.unwrap()
);

// Check `/ws` with `/ip4`
let addr = "/ip4/127.0.0.1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "127.0.0.1:2222");
assert_eq!(info.path, "/");
assert!(!info.use_tls);
assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());

// Check `/ws` with `/ip6`
let addr = "/ip6/::1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
assert_eq!(info.host_port, "[::1]:2222");
assert_eq!(info.path, "/");
assert!(!info.use_tls);
assert_eq!(info.server_name, "::1".try_into().unwrap());
assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());

// Check non-ws address
let addr = "/ip4/127.0.0.1/tcp/2222".parse::<Multiaddr>().unwrap();
parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
}
}

0 comments on commit 98da34a

Please sign in to comment.