From 955c773919e6043daf5b3e2b0ae51353f324ea4a Mon Sep 17 00:00:00 2001 From: LimpidCrypto Date: Mon, 19 Aug 2024 18:17:08 +0000 Subject: [PATCH] current state of refactoring --- Cargo.toml | 2 + examples/async_web_socket_client.rs | 50 +++++ .../websocket/async_websocket_client.rs | 86 ++++---- src/client/websocket/exceptions.rs | 30 +-- src/client/websocket/mod.rs | 107 ++++++++-- src/core/dns/mod.rs | 2 +- src/core/dns/queries/a/mod.rs | 10 +- src/core/dns/queries/aaaa/mod.rs | 10 +- .../dns/queries/{errors.rs => exceptions.rs} | 4 +- src/core/dns/queries/mod.rs | 4 +- src/core/mod.rs | 187 +++++++++++++++++- src/core/tcp/exceptions.rs | 24 +++ src/core/tcp/mod.rs | 100 +++++++++- src/core/tls/exceptions.rs | 6 +- src/core/tls/mod.rs | 76 ++++--- src/lib.rs | 2 + src/utils/mod.rs | 56 +++++- tests/common/codec.rs | 34 ---- tests/common/constants.rs | 4 +- tests/common/mod.rs | 45 +---- tests/integration/clients/async_websocket.rs | 104 +++------- tests/integration/clients/mod.rs | 2 +- 22 files changed, 659 insertions(+), 286 deletions(-) create mode 100644 examples/async_web_socket_client.rs rename src/core/dns/queries/{errors.rs => exceptions.rs} (84%) create mode 100644 src/core/tcp/exceptions.rs delete mode 100644 tests/common/codec.rs diff --git a/Cargo.toml b/Cargo.toml index 3c1bebb..74f7e44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ embedded-websocket = { git = "https://github.com/LimpidCrypto/embedded-websocket reqwless = "0.12.1" webpki-roots = { version = "0.26.3", optional = true } rand_core = "0.6.4" +strum_macros = "0.26.4" # strum_macros = { version = "0.26.4", default-features = false } @@ -62,6 +63,7 @@ rand_core = "0.6.4" # rand_core = { version = "0.6.4", default-features = false } [dev-dependencies] +rand = "0.8.5" tokio = { version = "1.27.0", features = ["full"] } [features] diff --git a/examples/async_web_socket_client.rs b/examples/async_web_socket_client.rs new file mode 100644 index 0000000..2c2d776 --- /dev/null +++ b/examples/async_web_socket_client.rs @@ -0,0 +1,50 @@ +use em_as_net::{ + client::websocket::{ + AsyncWebSocketClient, ReadResult, WebSocketRead, WebSocketSendMessageType, WebSocketWrite, + }, + core::{tcp::TcpStream, tls::TlsStream}, +}; +use rand::thread_rng; +use url::Url; + +#[tokio::main] +async fn main() { + let uri = Url::parse("wss://ws.vi-server.org:443/mirror/").unwrap(); + let stream = TcpStream::connect(&uri).await.unwrap(); + println!("TCP Connected"); + let mut tls_stream = TlsStream::connect(stream, &uri).await.unwrap(); + println!("TLS Handshake Done"); + let mut buffer = [0u8; 4096]; + let rng = thread_rng(); + let mut websocket = + AsyncWebSocketClient::open(&mut tls_stream, &mut buffer, &uri, rng, None, None) + .await + .unwrap(); + println!("WebSocket Connected"); + websocket + .write( + &mut tls_stream, + &mut buffer, + WebSocketSendMessageType::Text, + true, + "Hello World".as_bytes(), + ) + .await + .unwrap(); + println!("Message Sent"); + loop { + let message = websocket + .try_read(&mut tls_stream, &mut buffer) + .await + .unwrap() + .unwrap(); + match message { + ReadResult::Text(text) => { + assert_eq!("Hello World".to_string(), text); + println!("Received message: {}", text); + } + _ => panic!("Expected 'Hello World' as text message."), + } + break; + } +} diff --git a/src/client/websocket/async_websocket_client.rs b/src/client/websocket/async_websocket_client.rs index b36098e..d53a7df 100644 --- a/src/client/websocket/async_websocket_client.rs +++ b/src/client/websocket/async_websocket_client.rs @@ -1,4 +1,4 @@ -use crate::{client::websocket::errors::WebsocketError, Err}; +use crate::{client::websocket::errors::WebSocketError, Err}; use alloc::string::ToString; use anyhow::Result; @@ -10,8 +10,8 @@ use core::{ task::Poll, }; use embedded_websocket::{ - framer_async::Framer as EmbeddedWebsocketFramer, Client as EmbeddedWebsocketClient, - WebSocket as EmbeddedWebsocket, + framer_async::Framer as EmbeddedWebSocketFramer, Client as EmbeddedWebSocketClient, + WebSocket as EmbeddedWebSocket, }; use futures::{Sink, Stream}; use rand_core::RngCore; @@ -22,43 +22,43 @@ use tokio::net::TcpStream; #[cfg(feature = "std")] use tokio_tungstenite::{ connect_async as tungstenite_connect_async, MaybeTlsStream as TungsteniteMaybeTlsStream, - WebSocketStream as TungsteniteWebsocketStream, + WebSocketStream as TungsteniteWebSocketStream, }; // Exports pub use embedded_websocket::{ framer_async::{ - FramerError as EmbeddedWebsocketFramerError, ReadResult as EmbeddedWebsocketReadMessageType, + FramerError as EmbeddedWebSocketFramerError, ReadResult as EmbeddedWebSocketReadMessageType, }, - Error as EmbeddedWebsocketError, WebSocketCloseStatusCode as EmbeddedWebsocketCloseStatusCode, - WebSocketOptions as EmbeddedWebsocketOptions, - WebSocketSendMessageType as EmbeddedWebsocketSendMessageType, - WebSocketState as EmbeddedWebsocketState, + Error as EmbeddedWebSocketError, WebSocketCloseStatusCode as EmbeddedWebSocketCloseStatusCode, + WebSocketOptions as EmbeddedWebSocketOptions, + WebSocketSendMessageType as EmbeddedWebSocketSendMessageType, + WebSocketState as EmbeddedWebSocketState, }; #[cfg(feature = "std")] -pub type AsyncWebsocketClientTungstenite = - AsyncWebsocketClient>, Status>; -pub type AsyncWebsocketClientEmbeddedWebsocketTokio = - AsyncWebsocketClient, Status>; +pub type AsyncWebSocketClientTungstenite = + AsyncWebSocketClient>, Status>; +pub type AsyncWebSocketClientEmbeddedWebSocketTokio = + AsyncWebSocketClient, Status>; #[cfg(feature = "std")] pub use tokio_tungstenite::tungstenite::Message as TungsteniteMessage; -pub struct WebsocketOpen; -pub struct WebsocketClosed; +pub struct WebSocketOpen; +pub struct WebSocketClosed; -pub struct AsyncWebsocketClient { +pub struct AsyncWebSocketClient { inner: T, status: PhantomData, } -impl AsyncWebsocketClient { +impl AsyncWebSocketClient { pub fn is_open(&self) -> bool { - core::any::type_name::() == core::any::type_name::() + core::any::type_name::() == core::any::type_name::() } } -impl Sink for AsyncWebsocketClient +impl Sink for AsyncWebSocketClient where T: Sink + Unpin, >::Error: Display, @@ -109,7 +109,7 @@ where } } -impl Stream for AsyncWebsocketClient +impl Stream for AsyncWebSocketClient where T: Stream + Unpin, { @@ -129,30 +129,30 @@ where #[cfg(feature = "std")] impl - AsyncWebsocketClient< - TungsteniteWebsocketStream>, - WebsocketClosed, + AsyncWebSocketClient< + TungsteniteWebSocketStream>, + WebSocketClosed, > { pub async fn open( uri: Url, ) -> Result< - AsyncWebsocketClient< - TungsteniteWebsocketStream>, - WebsocketOpen, + AsyncWebSocketClient< + TungsteniteWebSocketStream>, + WebSocketOpen, >, > { let (websocket_stream, _) = tungstenite_connect_async(uri.to_string()).await.unwrap(); - Ok(AsyncWebsocketClient { + Ok(AsyncWebSocketClient { inner: websocket_stream, - status: PhantomData::, + status: PhantomData::, }) } } impl - AsyncWebsocketClient, WebsocketClosed> + AsyncWebSocketClient, WebSocketClosed> where Rng: RngCore, { @@ -160,35 +160,35 @@ where stream: &mut (impl Stream> + for<'a> Sink<&'a [u8], Error = E> + Unpin), buffer: &mut [u8], rng: Rng, - websocket_options: &EmbeddedWebsocketOptions<'_>, + websocket_options: &EmbeddedWebSocketOptions<'_>, ) -> Result< - AsyncWebsocketClient, WebsocketOpen>, + AsyncWebSocketClient, WebSocketOpen>, > where B: AsRef<[u8]>, E: Debug, { - let websocket = EmbeddedWebsocket::::new_client(rng); - let mut framer = EmbeddedWebsocketFramer::new(websocket); + let websocket = EmbeddedWebSocket::::new_client(rng); + let mut framer = EmbeddedWebSocketFramer::new(websocket); framer .connect(stream, buffer, websocket_options) .await .unwrap(); - Ok(AsyncWebsocketClient { + Ok(AsyncWebSocketClient { inner: framer, - status: PhantomData::, + status: PhantomData::, }) } } -impl AsyncWebsocketClient, WebsocketOpen> +impl AsyncWebSocketClient, WebSocketOpen> where Rng: RngCore, { pub fn encode( &mut self, - message_type: EmbeddedWebsocketSendMessageType, + message_type: EmbeddedWebSocketSendMessageType, end_of_message: bool, from: &[u8], to: &mut [u8], @@ -208,7 +208,7 @@ where &mut self, stream: &mut (impl Sink<&'b [u8], Error = E> + Unpin), stream_buf: &'b mut [u8], - message_type: EmbeddedWebsocketSendMessageType, + message_type: EmbeddedWebSocketSendMessageType, end_of_message: bool, frame_buf: &'b [u8], ) -> Result<()> @@ -227,7 +227,7 @@ where &mut self, stream: &mut (impl Sink<&'b [u8], Error = E> + Unpin), stream_buf: &'b mut [u8], - close_status: EmbeddedWebsocketCloseStatusCode, + close_status: EmbeddedWebSocketCloseStatusCode, status_description: Option<&str>, ) -> Result<()> where @@ -245,13 +245,13 @@ where &'a mut self, stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), buffer: &'a mut [u8], - ) -> Option>> + ) -> Option>> where E: Debug, { match self.inner.read(stream, buffer).await { Some(Ok(read_result)) => Some(Ok(read_result)), - Some(Err(error)) => Some(Err!(WebsocketError::from(error))), + Some(Err(error)) => Some(Err!(WebSocketError::from(error))), None => None, } } @@ -260,13 +260,13 @@ where &'a mut self, stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), buffer: &'a mut [u8], - ) -> Result>> + ) -> Result>> where E: Debug, { match self.inner.read(stream, buffer).await { Some(Ok(read_result)) => Ok(Some(read_result)), - Some(Err(error)) => Err!(WebsocketError::from(error)), + Some(Err(error)) => Err!(WebSocketError::from(error)), None => Ok(None), } } diff --git a/src/client/websocket/exceptions.rs b/src/client/websocket/exceptions.rs index b394643..1da608a 100644 --- a/src/client/websocket/exceptions.rs +++ b/src/client/websocket/exceptions.rs @@ -5,9 +5,11 @@ use embedded_websocket::framer_async::FramerError; use thiserror_no_std::Error; #[derive(Debug, PartialEq, Eq, Error)] -pub enum WebsocketError { - #[error("Stream is not connected.")] - NotConnected, +pub enum WebSocketException { + #[error("Invalid domain")] + InvalidDomain, + #[error("Invalid scheme")] + InvalidScheme, // FramerError #[error("I/O error: {0:?}")] Io(E), @@ -17,7 +19,7 @@ pub enum WebsocketError { Utf8(Utf8Error), #[error("Invalid HTTP header")] HttpHeader, - #[error("Websocket error: {0:?}")] + #[error("WebSocket error: {0:?}")] WebSocket(embedded_websocket::Error), #[error("Disconnected")] Disconnected, @@ -25,25 +27,25 @@ pub enum WebsocketError { RxBufferTooSmall(usize), } -impl From> for WebsocketError { +impl From> for WebSocketException { fn from(e: FramerError) -> Self { match e { - FramerError::Io(e) => WebsocketError::Io(e), - FramerError::FrameTooLarge(size) => WebsocketError::FrameTooLarge(size), - FramerError::Utf8(e) => WebsocketError::Utf8(e), - FramerError::HttpHeader(_) => WebsocketError::HttpHeader, - FramerError::WebSocket(e) => WebsocketError::WebSocket(e), - FramerError::Disconnected => WebsocketError::Disconnected, - FramerError::RxBufferTooSmall(size) => WebsocketError::RxBufferTooSmall(size), + FramerError::Io(e) => WebSocketException::Io(e), + FramerError::FrameTooLarge(size) => WebSocketException::FrameTooLarge(size), + FramerError::Utf8(e) => WebSocketException::Utf8(e), + FramerError::HttpHeader(_) => WebSocketException::HttpHeader, + FramerError::WebSocket(e) => WebSocketException::WebSocket(e), + FramerError::Disconnected => WebSocketException::Disconnected, + FramerError::RxBufferTooSmall(size) => WebSocketException::RxBufferTooSmall(size), } } } -impl Into for WebsocketError { +impl Into for WebSocketException { fn into(self) -> anyhow::Error { anyhow!(self) } } #[cfg(feature = "std")] -impl alloc::error::Error for WebsocketError {} +impl alloc::error::Error for WebSocketException {} diff --git a/src/client/websocket/mod.rs b/src/client/websocket/mod.rs index 36866de..d86f3e1 100644 --- a/src/client/websocket/mod.rs +++ b/src/client/websocket/mod.rs @@ -6,33 +6,70 @@ use alloc::string::ToString; use anyhow::Result; use embedded_io_async::{Read, Write}; use embedded_websocket::{framer_async::Framer, Client, WebSocketClient, WebSocketOptions}; -use exceptions::WebsocketError; +use exceptions::WebSocketException; use rand_core::RngCore; use url::Url; +pub use embedded_websocket::{framer_async::ReadResult, WebSocketSendMessageType}; + use crate::Err; -pub struct WebsocketClosed; -pub struct WebsocketOpen; +pub struct WebSocketClosed; +pub struct WebSocketOpen; + +#[allow(async_fn_in_trait)] +pub trait WebSocketRead { + async fn read<'a, S: Read + Write + Unpin>( + &'a mut self, + stream: &mut S, + buf: &'a mut [u8], + ) -> Option>>; -pub struct AsyncWebsocketClient { + async fn try_read<'a, S: Read + Write + Unpin>( + &'a mut self, + stream: &mut S, + buf: &'a mut [u8], + ) -> Result>> { + match self.read(stream, buf).await { + Some(Ok(result)) => Ok(Some(result)), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } +} + +#[allow(async_fn_in_trait)] +pub trait WebSocketWrite { + async fn write( + &mut self, + tx: &mut S, + tx_buf: &mut [u8], + message_type: WebSocketSendMessageType, + end_of_message: bool, + frame_buf: &[u8], + ) -> Result<()>; +} + +pub struct AsyncWebSocketClient { inner: Framer, status: PhantomData, } -impl AsyncWebsocketClient { +impl AsyncWebSocketClient { pub fn is_open(&self) -> bool { - core::any::type_name::() == core::any::type_name::() + core::any::type_name::() == core::any::type_name::() } } -impl AsyncWebsocketClient { +impl AsyncWebSocketClient { pub async fn open( - buf: &mut [u8], stream: &mut S, - uri: Url, + buf: &mut [u8], + uri: &Url, rng: T, - ) -> Result> + sub_protocols: Option<&[&str]>, + additional_headers: Option<&[&str]>, + ) -> Result> where S: Read + Write + Unpin, { @@ -54,23 +91,57 @@ impl AsyncWebsocketClient { let path = uri.path(); let host = match uri.host_str() { Some(host) => host, - None => return Err(WebsocketError::Disconnected.into()), + None => return Err(WebSocketException::::InvalidDomain.into()), }; let origin = scheme.to_string() + "://" + host + ":" + &port + path; let websocket_options = WebSocketOptions { path, host, - origin: &origin, - sub_protocols: None, - additional_headers: None, + origin: origin.as_str(), + sub_protocols, + additional_headers, }; - let websocket = Framer::new(WebSocketClient::new_client(rng)); + let mut websocket = Framer::new(WebSocketClient::new_client(rng)); match websocket.connect(stream, buf, &websocket_options).await { - Ok(_) => Ok(AsyncWebsocketClient { + Ok(_) => Ok(AsyncWebSocketClient { inner: websocket, - status: PhantomData::, + status: PhantomData::, }), - Err(e) => Err!(e), + Err(e) => Err(WebSocketException::from(e).into()), + } + } +} + +impl WebSocketRead for AsyncWebSocketClient { + async fn read<'a, S: Read + Write + Unpin>( + &'a mut self, + stream: &mut S, + buf: &'a mut [u8], + ) -> Option>> { + match self.inner.read(stream, buf).await { + Some(Ok(result)) => Some(Ok(result)), + Some(Err(e)) => Some(Err!(WebSocketException::from(e))), + None => None, + } + } +} + +impl WebSocketWrite for AsyncWebSocketClient { + async fn write( + &mut self, + tx: &mut S, + tx_buf: &mut [u8], + message_type: WebSocketSendMessageType, + end_of_message: bool, + frame_buf: &[u8], + ) -> Result<()> { + match self + .inner + .write(tx, tx_buf, message_type, end_of_message, frame_buf) + .await + { + Ok(_) => Ok(()), + Err(e) => Err!(WebSocketException::from(e)), } } } diff --git a/src/core/dns/mod.rs b/src/core/dns/mod.rs index 980cea2..5d9a61d 100644 --- a/src/core/dns/mod.rs +++ b/src/core/dns/mod.rs @@ -1,6 +1,6 @@ mod queries; -pub use queries::DnsError; +pub use queries::DnsException; use queries::{Aaaa, Lookup, A}; use anyhow::Result; diff --git a/src/core/dns/queries/a/mod.rs b/src/core/dns/queries/a/mod.rs index bbb38e9..0138312 100644 --- a/src/core/dns/queries/a/mod.rs +++ b/src/core/dns/queries/a/mod.rs @@ -10,7 +10,7 @@ mod impl_lookup { use core::net::SocketAddr; use super::*; - use crate::{core::dns::DnsError, Err}; + use crate::{core::dns::DnsException, Err}; use alloc::{string::ToString, vec::Vec}; use tokio::net::lookup_host; use url::Url; @@ -18,8 +18,8 @@ mod impl_lookup { impl Lookup for A { async fn lookup(url: &Url) -> Result { let url = url.to_string(); - let addresses = match lookup_host(&*url).await { - Err(_) => return Err!(DnsError::LookupError(url.into())), + let addresses = match lookup_host(&url).await { + Err(_) => return Err!(DnsException::LookupError(url.clone().into())), Ok(socket_addrs_iter) => socket_addrs_iter, }; return match addresses @@ -28,8 +28,8 @@ mod impl_lookup { .first() { Some(SocketAddr::V4(addrs)) => Ok(Ipv4Addr::from(addrs.ip().octets())), - None => Err!(DnsError::LookupIpv4Error(url.into())), - _ => Err!(DnsError::LookupIpv4Error(url.into())), + None => Err!(DnsException::LookupIpv4Error(url.into())), + _ => Err!(DnsException::LookupIpv4Error(url.into())), }; } } diff --git a/src/core/dns/queries/aaaa/mod.rs b/src/core/dns/queries/aaaa/mod.rs index d0c4817..ab1c7e8 100644 --- a/src/core/dns/queries/aaaa/mod.rs +++ b/src/core/dns/queries/aaaa/mod.rs @@ -10,7 +10,7 @@ mod impl_lookup { use super::*; use crate::core::dns::queries::Lookup; - use crate::core::dns::DnsError; + use crate::core::dns::DnsException; use crate::Err; use alloc::string::ToString; use alloc::vec::Vec; @@ -20,8 +20,8 @@ mod impl_lookup { impl Lookup for Aaaa { async fn lookup(url: &Url) -> Result { let url = url.to_string(); - let addresses = match lookup_host(&*url).await { - Err(_) => return Err!(DnsError::LookupError(url.into())), + let addresses = match lookup_host(&url).await { + Err(_) => return Err!(DnsException::LookupError(url.clone().into())), Ok(socket_addrs_iter) => socket_addrs_iter, }; return match addresses @@ -30,8 +30,8 @@ mod impl_lookup { .first() { Some(SocketAddr::V6(addrs)) => Ok(Ipv6Addr::from(addrs.ip().octets())), - None => Err!(DnsError::LookupIpv6Error(url.into())), - _ => Err!(DnsError::LookupIpv6Error(url.into())), + None => Err!(DnsException::LookupIpv6Error(url.into())), + _ => Err!(DnsException::LookupIpv6Error(url.into())), }; } } diff --git a/src/core/dns/queries/errors.rs b/src/core/dns/queries/exceptions.rs similarity index 84% rename from src/core/dns/queries/errors.rs rename to src/core/dns/queries/exceptions.rs index a7d7b54..cee0f66 100644 --- a/src/core/dns/queries/errors.rs +++ b/src/core/dns/queries/exceptions.rs @@ -2,7 +2,7 @@ use alloc::borrow::Cow; use thiserror_no_std::Error; #[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum DnsError<'a> { +pub enum DnsException<'a> { #[error("Invalid socket address (found: {0:?})")] LookupError(Cow<'a, str>), #[error("Unable to look up IPv4 address for hostname (found: {0:?})")] @@ -12,4 +12,4 @@ pub enum DnsError<'a> { } #[cfg(feature = "std")] -impl alloc::error::Error for DnsError<'_> {} +impl alloc::error::Error for DnsException<'_> {} diff --git a/src/core/dns/queries/mod.rs b/src/core/dns/queries/mod.rs index c9e45e2..f83b693 100644 --- a/src/core/dns/queries/mod.rs +++ b/src/core/dns/queries/mod.rs @@ -3,8 +3,8 @@ mod a; pub use a::A; mod aaaa; pub use aaaa::Aaaa; -mod errors; -pub use errors::DnsError; +mod exceptions; +pub use exceptions::DnsException; use anyhow::Result; use url::Url; diff --git a/src/core/mod.rs b/src/core/mod.rs index c96eef2..3d581f0 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,9 +1,190 @@ +use alloc::{ + dbg, format, + string::{String, ToString}, +}; +use anyhow::Result; +use thiserror_no_std::Error; +use url::Host; + #[cfg(feature = "dns")] pub mod dns; -// mod framed; -// mod io; #[cfg(feature = "tcp")] pub mod tcp; -// TODO: uncomment and make tls public as soon as it's working #[cfg(feature = "tls")] pub mod tls; + +#[derive(Debug, Clone)] +pub struct Url(url::Url); + +impl Url { + pub fn parse(url: &str) -> Result { + let mut url = + Url(url::Url::parse(url).map_err(|e| UrlException::InvalidUrl(e.to_string()).into())?); + url.assure_port().unwrap(); + dbg!(url.clone()); + Ok(url) + } + + pub fn for_tcp(&self) -> Result { + let mut url = self.clone(); + url.assure_port().unwrap(); + url.set_port(Some(443)).unwrap(); + let host = match self.host_str() { + Some(host) => host, + None => return Err(anyhow::anyhow!("Host not found")), + }; + let port = match self.port() { + Some(port) => port, + None => return Err(anyhow::anyhow!("Port not found")), + }; + let path = self.path(); + let url = host.to_string() + ":" + &port.to_string() + path; + + Ok(url) + } + + pub fn scheme(&self) -> &str { + self.0.scheme() + } + + pub fn host(&self) -> Option> { + self.0.host() + } + + pub fn host_str(&self) -> Option<&str> { + self.0.host_str() + } + + pub fn port(&self) -> Option { + self.0.port() + } + + pub fn path(&self) -> &str { + self.0.path() + } + + pub fn query(&self) -> Option<&str> { + self.0.query() + } + + pub fn fragment(&self) -> Option<&str> { + self.0.fragment() + } + + pub fn set_scheme(&mut self, scheme: &str) -> Result<(), ()> { + self.0.set_scheme(scheme) + } + + pub fn set_port(&mut self, port: Option) -> Result<(), ()> { + self.0.set_port(port) + } + + pub fn set_path(&mut self, path: &str) -> () { + self.0.set_path(path) + } + + pub fn set_query(&mut self, query: &str) -> () { + self.0.set_query(Some(query)) + } + + pub fn get_with_port(&self, default: Option) -> String { + let scheme = self.0.scheme(); + let host = match self.0.host_str() { + Some(host) => host, + None => return "".to_string(), + }; + let port = match self.0.port() { + Some(port) => port, + None => match scheme { + "wss" | "https" => 443, + "ws" | "http" => 80, + _ => default.unwrap_or(80), + }, + }; + let path = self.path(); + // let query = self.query().map_or("".into(), |q| format!("?{}", q)); + // let fragment = self.fragment().map_or("".into(), |f| format!("#{}", f)); + + format!( + "{}://{}:{}{}", + scheme, + host, + port, + path, + // query, + // fragment + ) + } + + fn assure_port(&mut self) -> Result<(), ()> { + if self.0.port().is_none() { + let port = match self.0.scheme() { + "wss" | "https" => 443, + "ws" | "http" => 80, + _ => 80, + }; + self.0.set_port(Some(port))?; + } + + Ok(()) + } +} + +impl TryFrom for Url { + type Error = (); + + fn try_from(url: url::Url) -> Result { + let mut url = Url(url); + url.assure_port()?; + Ok(url) + } +} + +impl From for url::Url { + fn from(url: Url) -> url::Url { + url.0 + } +} + +impl TryFrom<&str> for Url { + type Error = anyhow::Error; + + fn try_from(url: &str) -> Result { + let url = + Url(url::Url::parse(url).map_err(|e| UrlException::InvalidUrl(e.to_string()).into())?); + + Ok(url) + } +} + +impl ToString for Url { + fn to_string(&self) -> String { + self.0.to_string() + } +} + +impl From for String { + fn from(url: Url) -> Self { + url.to_string() + } +} + +#[derive(Debug, Error)] +pub enum UrlException { + #[error("Invalid URL: {0}")] + InvalidUrl(String), + #[error("Invalid scheme: {0}")] + InvalidScheme(String), +} + +impl From for UrlException { + fn from(e: url::ParseError) -> Self { + UrlException::InvalidUrl(e.to_string()) + } +} + +impl Into for UrlException { + fn into(self) -> anyhow::Error { + anyhow::anyhow!(self) + } +} diff --git a/src/core/tcp/exceptions.rs b/src/core/tcp/exceptions.rs new file mode 100644 index 0000000..55ab5a5 --- /dev/null +++ b/src/core/tcp/exceptions.rs @@ -0,0 +1,24 @@ +use embedded_io_async::ErrorKind; +use strum_macros::Display; +use thiserror_no_std::Error; + +#[derive(Debug, Error, Display)] +pub enum TcpException { + IoError(#[from] alloc::io::Error), + EmbeddedIoAsyncError(embedded_io_async::ErrorKind), +} + +impl embedded_io_async::Error for TcpException { + fn kind(&self) -> embedded_io_async::ErrorKind { + match self { + TcpException::EmbeddedIoAsyncError(e) => *e, + _ => ErrorKind::Other, + } + } +} + +impl Into for TcpException { + fn into(self) -> anyhow::Error { + anyhow::anyhow!(self) + } +} diff --git a/src/core/tcp/mod.rs b/src/core/tcp/mod.rs index 777e189..ada8a72 100644 --- a/src/core/tcp/mod.rs +++ b/src/core/tcp/mod.rs @@ -1,3 +1,5 @@ +pub mod exceptions; + #[cfg(not(feature = "std"))] pub use _embassy::*; #[cfg(feature = "std")] @@ -12,8 +14,100 @@ mod _embassy { #[cfg(feature = "std")] mod _tokio { + use super::exceptions::TcpException; + use crate::{utils::derive_tcp_url, Err}; + use anyhow::Result; use embedded_io_adapters::tokio_1::FromTokio; - use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream}; - pub type TcpStream = FromTokio; - pub type TcpListener = FromTokio; + use embedded_io_async::{ErrorType, Read, Write}; + use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream}, + }; + use url::Url; + + pub struct TcpStream(FromTokio); + + impl TcpStream { + pub async fn connect(url: &Url) -> Result { + let stream = TokioTcpStream::connect(derive_tcp_url(url, None)?) + .await + .map_err(|e| TcpException::IoError(e).into())?; + Ok(TcpStream(FromTokio::new(stream))) + } + } + + impl ErrorType for TcpStream { + type Error = TcpException; + } + + impl Read for TcpStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + self.0.read(buf).await.map_err(|e| TcpException::IoError(e)) + } + } + + impl Write for TcpStream { + async fn write(&mut self, buf: &[u8]) -> Result { + self.0 + .write(buf) + .await + .map_err(|e| TcpException::IoError(e)) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.0.flush().await.map_err(|e| TcpException::IoError(e)) + } + } + + impl AsyncRead for TcpStream { + fn poll_read( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> core::task::Poll> { + core::pin::Pin::new(self.get_mut().0.inner_mut()).poll_read(cx, buf) + } + } + + impl AsyncWrite for TcpStream { + fn poll_write( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + buf: &[u8], + ) -> core::task::Poll> { + core::pin::Pin::new(self.get_mut().0.inner_mut()).poll_write(cx, buf) + } + + fn poll_flush( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + core::pin::Pin::new(self.get_mut().0.inner_mut()).poll_flush(cx) + } + + fn poll_shutdown( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + core::pin::Pin::new(self.get_mut().0.inner_mut()).poll_shutdown(cx) + } + } + + pub struct TcpListener(TokioTcpListener); + + impl TcpListener { + pub async fn bind(url: &Url) -> Result { + match TokioTcpListener::bind(derive_tcp_url(url, None)?).await { + Ok(listener) => Ok(TcpListener(listener)), + Err(e) => Err!(TcpException::IoError(e)), + } + } + + pub async fn accept(&self) -> Result<(TcpStream, alloc::net::SocketAddr)> { + match self.0.accept().await { + Ok((stream, addr)) => Ok((TcpStream(FromTokio::new(stream)), addr)), + Err(e) => Err!(TcpException::IoError(e)), + } + } + } } diff --git a/src/core/tls/exceptions.rs b/src/core/tls/exceptions.rs index 93bced3..36b0235 100644 --- a/src/core/tls/exceptions.rs +++ b/src/core/tls/exceptions.rs @@ -1,5 +1,6 @@ use anyhow::anyhow; use embedded_io_async::ErrorKind; +use rustls::pki_types::InvalidDnsNameError; use thiserror_no_std::Error; #[derive(Debug, Error)] @@ -10,6 +11,8 @@ pub enum TlsException { NoDomain, #[error("Embedded IO async error")] EmbeddedIoAsyncError(ErrorKind), + #[error("Invalid server name: {0}")] + InvalidServerName(InvalidDnsNameError), } impl From for TlsException { @@ -27,9 +30,8 @@ impl Into for TlsException { impl embedded_io_async::Error for TlsException { fn kind(&self) -> embedded_io_async::ErrorKind { match self { - TlsException::IoError(_) => ErrorKind::Other, - TlsException::NoDomain => ErrorKind::Other, TlsException::EmbeddedIoAsyncError(e) => *e, + _ => ErrorKind::Other, } } } diff --git a/src/core/tls/mod.rs b/src/core/tls/mod.rs index edc8341..c3e9dd4 100644 --- a/src/core/tls/mod.rs +++ b/src/core/tls/mod.rs @@ -2,75 +2,79 @@ mod exceptions; use embedded_io_adapters::tokio_1::FromTokio; pub use exceptions::*; -use tokio_rustls::client::TlsStream; +use tokio_rustls::client::TlsStream as TokioRustlsTlsStream; use anyhow::Result; use embedded_io_async::{Read, Write}; -#[cfg(not(feature = "std"))] -use rustls::{ClientConnection, ServerConnection}; #[cfg(feature = "std")] -use tokio_rustls::TlsConnector; - -#[cfg(not(feature = "std"))] -pub struct TlsSocketClient(ClientConnection); -#[cfg(not(feature = "std"))] -pub struct TlsSocketServer(ServerConnection); - -#[cfg(feature = "std")] -pub struct TlsSocket(FromTokio>); +pub use tokio_tls_stream::*; #[cfg(feature = "std")] -mod tokio_tls_client { - use alloc::sync::Arc; +mod tokio_tls_stream { + use alloc::{borrow::Cow, string::String, sync::Arc}; use embedded_io_async::ErrorType; use rustls::{pki_types::ServerName, ClientConfig, RootCertStore}; use tokio::io::{AsyncRead, AsyncWrite}; + use tokio_rustls::{TlsAcceptor, TlsConnector}; use url::Url; - use crate::core::tcp::TcpStream; - use super::*; + use crate::{utils::ws_to_http, Err}; + + pub struct TlsStream(FromTokio>); - impl TlsSocket + impl TlsStream where S: AsyncRead + AsyncWrite + Unpin, { - pub async fn connect<'a>(url: Url) -> Result> { - let stream = TcpStream::new(inner) + pub async fn connect(stream: S, url: &Url) -> Result> { let mut root_cert_store = RootCertStore::empty(); root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = ClientConfig::builder() .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); + let url_with_http = ws_to_http(url); + let dns_name = Self::get_dns_name(&url_with_http)?; + let server_name = match ServerName::try_from(String::from(dns_name)) { + Ok(server_name) => server_name, + Err(e) => return Err!(TlsException::InvalidServerName(e)), + }; + + let stream = match connector.connect(server_name, stream).await { + Ok(stream) => stream, + Err(e) => return Err!(TlsException::IoError(e)), + }; + + Ok(TlsStream(FromTokio::new(stream))) + } - let stream = connector - .connect(server_name, stream) - .await - .map_err(|e| TlsException::IoError(e).into())?; - - Ok(TlsSocket(FromTokio::new(stream))) + fn get_dns_name(url: &Url) -> Result> { + match url.host_str() { + Some(host) => Ok(host.into()), + None => Err!(TlsException::NoDomain), + } } } - impl TlsSocket + impl TlsStream where S: AsyncRead + AsyncWrite + Unpin, { - pub async fn accept(stream: S, url: &Url) -> Result { + pub async fn accept(_stream: S, _url: &Url) -> Result { todo!("Implement accept as TlsListener"); } } - impl ErrorType for TlsSocket + impl ErrorType for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { type Error = TlsException; } - impl Read for TlsSocket + impl Read for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { @@ -82,7 +86,7 @@ mod tokio_tls_client { } } - impl Write for TlsSocket + impl Write for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { @@ -100,4 +104,16 @@ mod tokio_tls_client { .map_err(|e| TlsException::IoError(e).into()) } } + + pub struct TlsListener(TlsAcceptor); + + // impl TlsListener { + // pub async fn bind(_url: &Url) -> Result { + // todo!("Implement bind as TlsListener"); + // } + + // pub async fn accept(&self, _stream: S) -> Result> { + // todo!("Implement accept as TlsListener"); + // } + // } } diff --git a/src/lib.rs b/src/lib.rs index 600433b..d48be96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,3 +15,5 @@ pub mod core; pub mod utils; mod _anyhow; + +pub use url::Url; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index e1b6cba..c64eb9d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,8 +1,50 @@ -#[macro_export] -macro_rules! singleton { - ($val:expr) => {{ - type T = impl Sized; - static STATIC_CELL: StaticCell = StaticCell::new(); - STATIC_CELL.init_with(move || $val) - }}; +use alloc::string::{String, ToString}; +use anyhow::Result; +use url::Url; + +pub fn http_to_ws(uri: &Url) -> Url { + let mut ws_uri = uri.clone(); + ws_uri + .set_scheme(match uri.scheme() { + "https" => "wss", + "http" => "ws", + _ => uri.scheme(), + }) + .unwrap(); + ws_uri +} + +pub fn ws_to_http(uri: &Url) -> Url { + let mut http_uri = uri.clone(); + http_uri + .set_scheme(match uri.scheme() { + "wss" => "https", + "ws" => "http", + _ => uri.scheme(), + }) + .unwrap(); + http_uri +} + +pub fn derive_tcp_url(url: &Url, default: Option) -> Result { + let host = match url.host_str() { + Some(host) => host, + None => return Err(anyhow::anyhow!("Host not found")), + }; + let port = match url.port() { + Some(port) => port, + None => match default { + Some(port) => port, + None => match url.scheme() { + "https" | "wss" => 443, + "http" | "ws" => 80, + _ => 80, + }, + }, + } + .to_string(); + // let path = url.path().to_string(); + let url = host.to_string() + ":" + &port; // + &path; + + Ok(url) } diff --git a/tests/common/codec.rs b/tests/common/codec.rs deleted file mode 100644 index 4410468..0000000 --- a/tests/common/codec.rs +++ /dev/null @@ -1,34 +0,0 @@ -use bytes::{BufMut, BytesMut}; -use tokio_util::codec::{Decoder, Encoder}; - -pub struct Codec {} - -impl Codec { - pub fn new() -> Self { - Codec {} - } -} - -impl Decoder for Codec { - type Item = BytesMut; - type Error = std::io::Error; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, std::io::Error> { - if !buf.is_empty() { - let len = buf.len(); - Ok(Some(buf.split_to(len))) - } else { - Ok(None) - } - } -} - -impl Encoder<&[u8]> for Codec { - type Error = std::io::Error; - - fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), std::io::Error> { - buf.reserve(data.len()); - buf.put(data); - Ok(()) - } -} diff --git a/tests/common/constants.rs b/tests/common/constants.rs index f6faa35..154eb14 100644 --- a/tests/common/constants.rs +++ b/tests/common/constants.rs @@ -1,3 +1,3 @@ -pub const ECHO_WS_SERVER: &'static str = "ws://ws.vi-server.org/mirror/"; -pub const ECHO_WSS_SERVER: &'static str = "wss://ws.vi-server.org/mirror/"; +pub const ECHO_WS_SERVER: &'static str = "ws://ws.vi-server.org:80/mirror/"; +pub const ECHO_WSS_SERVER: &'static str = "wss://ws.vi-server.org:443/mirror/"; pub const ECHO_WS_AS_IP_SERVER: &'static str = "192.236.209.31:80"; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index de36f06..6925cd4 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,46 +1,21 @@ -pub mod codec; mod constants; pub use constants::*; -use em_as_net::client::websocket::{ - AsyncWebsocketClientEmbeddedWebsocketTokio, AsyncWebsocketClientTungstenite, - EmbeddedWebsocketOptions, WebsocketOpen, -}; +use em_as_net::client::websocket::{AsyncWebSocketClient, WebSocketOpen}; +use embedded_io_async::{Read, Write}; use rand::{rngs::ThreadRng, thread_rng}; -use tokio::net::TcpStream; -use tokio_util::codec::Framed; +use url::Url; -pub async fn connect_to_ws_tungstenite_echo<'a>() -> AsyncWebsocketClientTungstenite -{ - let websocket = AsyncWebsocketClientTungstenite::open(ECHO_WS_SERVER.parse().unwrap()) - .await - .unwrap(); - assert!(websocket.is_open()); - - websocket -} +pub async fn connect_ws( + uri: &Url, + stream: &mut S, + buffer: &mut [u8], +) -> AsyncWebSocketClient { + let rng = thread_rng(); -pub async fn connect_to_tungstenite_wss_echo<'a>() -> AsyncWebsocketClientTungstenite -{ - let websocket = AsyncWebsocketClientTungstenite::open(ECHO_WSS_SERVER.parse().unwrap()) + let websocket = AsyncWebSocketClient::open(stream, buffer, uri, rng, None, None) .await .unwrap(); - assert!(websocket.is_open()); - - websocket -} - -pub async fn connect_to_embedded_websocket_tokio_ws_echo<'a>( - stream: &'a mut Framed, - buffer: &'a mut [u8], - websocket_options: &'a EmbeddedWebsocketOptions<'a>, -) -> AsyncWebsocketClientEmbeddedWebsocketTokio { - let rng = thread_rng(); - - let websocket = - AsyncWebsocketClientEmbeddedWebsocketTokio::open(stream, buffer, rng, websocket_options) - .await - .unwrap(); assert!(websocket.is_open()); diff --git a/tests/integration/clients/async_websocket.rs b/tests/integration/clients/async_websocket.rs index d4f946c..fd4899c 100644 --- a/tests/integration/clients/async_websocket.rs +++ b/tests/integration/clients/async_websocket.rs @@ -1,100 +1,46 @@ -use crate::common::connect_to_embedded_websocket_tokio_ws_echo; -use crate::common::{ - codec::Codec, connect_to_tungstenite_wss_echo, connect_to_ws_tungstenite_echo, - ECHO_WS_AS_IP_SERVER, +use em_as_net::{ + client::websocket::{ReadResult, WebSocketRead, WebSocketSendMessageType, WebSocketWrite}, + core::{tcp::TcpStream, tls::TlsStream}, }; +use url::Url; -use em_as_net::client::websocket::{ - EmbeddedWebsocketOptions, EmbeddedWebsocketReadMessageType, EmbeddedWebsocketSendMessageType, - TungsteniteMessage, -}; -use futures::{SinkExt, TryStreamExt}; -use tokio::net::TcpStream; -use tokio_util::codec::Framed; - -#[tokio::test] -async fn test_websocket_non_tls() { - let mut websocket = connect_to_ws_tungstenite_echo().await; - websocket - .send(TungsteniteMessage::Text("Hello World".to_string())) - .await - .unwrap(); - - loop { - let message = websocket.try_next().await.unwrap().unwrap(); - match message { - TungsteniteMessage::Text(text) => { - assert_eq!("Hello World".to_string(), text) - } - _ => panic!("Expected 'Hello World' as text message."), - } - break; - } -} +use crate::common::{connect_ws, ECHO_WSS_SERVER}; #[tokio::test] async fn test_websocket_tls() { - let mut websocket = connect_to_tungstenite_wss_echo().await; - websocket - .send(TungsteniteMessage::Text("Hello World".to_string())) - .await - .unwrap(); - - loop { - let message = websocket.try_next().await.unwrap().unwrap(); - match message { - TungsteniteMessage::Text(text) => { - assert_eq!("Hello World".to_string(), text) - } - _ => panic!("Expected 'Hello World' as text message."), - } - break; - } -} - -#[tokio::test] -async fn test_websocket_embedded_ws_tokio() { - let stream = TcpStream::connect(ECHO_WS_AS_IP_SERVER).await.unwrap(); - let mut framed = Framed::new(stream, Codec::new()); + let uri = Url::parse(ECHO_WSS_SERVER).unwrap(); + println!("Connecting"); + let stream = TcpStream::connect(&uri).await.unwrap(); + println!("TCP Connected"); + let mut tls_stream = TlsStream::connect(stream, &uri).await.unwrap(); + println!("TLS Handshake Done"); let mut buffer = [0u8; 4096]; - let websocket_options = EmbeddedWebsocketOptions { - path: "/mirror", - host: "ws.vi-server.org", - origin: "http://ws.vi-server.org:80", - sub_protocols: None, - additional_headers: None, - }; - let mut websocket = - connect_to_embedded_websocket_tokio_ws_echo(&mut framed, &mut buffer, &websocket_options) - .await; + let mut websocket = connect_ws(&uri, &mut tls_stream, &mut buffer).await; + println!("WebSocket Connected"); websocket - .send( - &mut framed, + .write( + &mut tls_stream, &mut buffer, - EmbeddedWebsocketSendMessageType::Binary, - false, - b"Hello World", + WebSocketSendMessageType::Text, + true, + "Hello World".as_bytes(), ) .await .unwrap(); - + println!("Message Sent"); loop { let message = websocket - .next(&mut framed, &mut buffer) + .try_read(&mut tls_stream, &mut buffer) .await .unwrap() .unwrap(); match message { - EmbeddedWebsocketReadMessageType::Text(text) => { - println!("Text: {:?}", text) - } - EmbeddedWebsocketReadMessageType::Binary(msg) => { - assert_eq!(b"Hello World", msg); - break; + ReadResult::Text(text) => { + assert_eq!("Hello World".to_string(), text); + println!("Received message: {}", text); } - EmbeddedWebsocketReadMessageType::Pong(t) => println!("Pong: {:?}", t), - EmbeddedWebsocketReadMessageType::Ping(t) => println!("Ping: {:?}", t), - EmbeddedWebsocketReadMessageType::Close(_) => println!("Close:"), + _ => panic!("Expected 'Hello World' as text message."), } + break; } } diff --git a/tests/integration/clients/mod.rs b/tests/integration/clients/mod.rs index 9bb82ae..8454f9d 100644 --- a/tests/integration/clients/mod.rs +++ b/tests/integration/clients/mod.rs @@ -1 +1 @@ -mod async_websocket; \ No newline at end of file +mod async_websocket;