diff --git a/Cargo.toml b/Cargo.toml index 74f7e44..fb5134b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,13 +78,23 @@ default = [ tcp = [] dns = ["embassy-net/dns"] tls = [] -websocket = [] +websocket = ["embedded-websocket"] json-rpc = [] std = [ - "dep:tokio", + "tokio", "embedded-io-adapters/tokio-1", "embedded-websocket/std", "tokio-rustls", "webpki-roots", ] webpki-roots = ["dep:webpki-roots"] + +[[example]] +name = "std_async_web_socket_client" +path = "examples/std_async_web_socket_client.rs" +required-features = ["std", "tcp", "websocket"] + +[[example]] +name = "std_tls_async_web_socket_client" +path = "examples/std_tls_async_web_socket_client.rs" +required-features = ["std", "tcp", "tls", "websocket"] diff --git a/examples/std_async_web_socket_client.rs b/examples/std_async_web_socket_client.rs new file mode 100644 index 0000000..b1e8539 --- /dev/null +++ b/examples/std_async_web_socket_client.rs @@ -0,0 +1,47 @@ +use em_as_net::{ + client::websocket::{ + AsyncWebSocketClient, ReadResult, WebSocketRead, WebSocketSendMessageType, WebSocketWrite, + }, + core::tcp::TcpStream, +}; +use rand::thread_rng; +use url::Url; + +#[tokio::main] +async fn main() { + let uri = Url::parse("ws://ws.vi-server.org:80/mirror/").unwrap(); + let mut stream = TcpStream::connect(&uri).await.unwrap(); + println!("TCP Connected"); + let mut buffer = [0u8; 4096]; + let rng = thread_rng(); + let mut websocket = AsyncWebSocketClient::open(&mut stream, &mut buffer, &uri, rng, None, None) + .await + .unwrap(); + println!("WebSocket Connected"); + websocket + .write( + &mut stream, + &mut buffer, + WebSocketSendMessageType::Text, + true, + "Hello World".as_bytes(), + ) + .await + .unwrap(); + println!("Message Sent"); + loop { + let message = websocket + .try_read(&mut stream, &mut buffer) + .await + .unwrap() + .unwrap(); + match message { + ReadResult::Text(text) => { + assert_eq!("Hello World".to_string(), text); + println!("Received message: {}", text); + } + _ => panic!("Expected 'Hello World' as text message."), + } + break; + } +} diff --git a/examples/async_web_socket_client.rs b/examples/std_tls_async_web_socket_client.rs similarity index 100% rename from examples/async_web_socket_client.rs rename to examples/std_tls_async_web_socket_client.rs diff --git a/src/core/framed/codec/decoder.rs b/src/core/framed/codec/decoder.rs deleted file mode 100644 index 670c74e..0000000 --- a/src/core/framed/codec/decoder.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::core::framed::{Framed, IoError}; -use crate::core::io::{AsyncRead, AsyncWrite}; -use bytes::BytesMut; -use core::fmt::Display; - -pub trait Decoder { - type Item; - type Error: From + Display; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error>; - - fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - match self.decode(buf)? { - Some(frame) => Ok(Some(frame)), - None => { - if buf.is_empty() { - Ok(None) - } else { - Err(IoError::DecodeError.into()) - } - } - } - } - - fn framed(self, io: T) -> Framed - where - Self: Sized, - { - Framed::new(io, self) - } -} diff --git a/src/core/framed/codec/encoder.rs b/src/core/framed/codec/encoder.rs deleted file mode 100644 index 2ed6159..0000000 --- a/src/core/framed/codec/encoder.rs +++ /dev/null @@ -1,8 +0,0 @@ -use bytes::BytesMut; -use core::fmt::Display; - -pub trait Encoder { - type Error: Display; - - fn encode(&mut self, data: Item, dst: &mut BytesMut) -> Result<(), Self::Error>; -} diff --git a/src/core/framed/codec/mod.rs b/src/core/framed/codec/mod.rs deleted file mode 100644 index a829346..0000000 --- a/src/core/framed/codec/mod.rs +++ /dev/null @@ -1,42 +0,0 @@ -use bytes::{BufMut, BytesMut}; - -mod decoder; -mod encoder; - -pub use decoder::Decoder; -pub use encoder::Encoder; - -use crate::core::framed::IoError; - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] -pub struct Codec(()); - -impl Codec { - pub fn new() -> Self { - Self(()) - } -} - -impl Encoder<&[u8]> for Codec { - type Error = IoError; - - fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), Self::Error> { - buf.reserve(data.len()); - buf.put(data); - Ok(()) - } -} - -impl Decoder for Codec { - type Item = BytesMut; - type Error = IoError; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - if !buf.is_empty() { - let len = buf.len(); - Ok(Some(buf.split_to(len))) - } else { - Ok(None) - } - } -} diff --git a/src/core/framed/errors.rs b/src/core/framed/errors.rs deleted file mode 100644 index 312ac19..0000000 --- a/src/core/framed/errors.rs +++ /dev/null @@ -1,74 +0,0 @@ -use thiserror_no_std::Error; - -#[derive(Debug, Clone, Error)] -pub enum IoError { - #[error("Failed to encode message while starting send.")] - EncodeWhileSendError, - #[error("Failed to flush message. Bytes remain on stream.")] - FailedToFlush, - #[error("Unable to decode bytes from stream.")] - DecodeError, - #[error("Unable to read from stream. Bytes remain on stream.")] - DecodeWhileReadError, - #[error("Tried to write but the stream is not connected.")] - WriteNotConnected, - #[error("Tried to flush but the stream is not connected.")] - FlushNotConnected, - #[error("Tried to shutdown but the stream is not connected.")] - ShutdownNotConnected, - #[error("Tried to read but the stream is not connected.")] - ReadNotConnected, - - // TlsConnection errors - #[error("TLS: Tried to write but the stream is not connected.")] - TlsWriteNotConnected, - #[error("TLS: Tried to flush but the stream is not connected.")] - TlsFlushNotConnected, - #[error("TLS: Tried to shutdown but the stream is not connected.")] - TlsShutdownNotConnected, - #[error("TLS: Tried to read but the stream is not connected.")] - TlsReadNotConnected, - - // FromTokio errors - #[error("FromTokio: Tried to write but the stream is not connected.")] - AdapterTokioWriteNotConnected, - #[error("FromTokio: Tried to flush but the stream is not connected.")] - AdapterTokioFlushNotConnected, - #[error("FromTokio: Tried to shutdown but the stream is not connected.")] - AdapterTokioShutdownNotConnected, - #[error("FromTokio: Tried to read but the stream is not connected.")] - AdapterTokioReadNotConnected, - - // AsyncRead errors - #[error("Error occured while reading from stream")] - UnableToRead, - - // AsyncWrite errors - #[error("Error occured while writing to stream")] - UnableToWrite, - #[error("Error occured while flushing stream")] - UnableToFlush, - #[error("Error occured while closing stream")] - UnableToClose, - - // embedded_io errors - #[error("{0:?}")] - Io(embedded_io_async::ErrorKind), - - // Tls errors during IO - #[cfg(feature = "tls")] - #[error("TLS: {0:?}")] - TlsRead(embedded_tls::TlsError), -} - -impl embedded_io_async::Error for IoError { - fn kind(&self) -> embedded_io_async::ErrorKind { - match self { - Self::Io(k) => *k, - _ => embedded_io_async::ErrorKind::Other, - } - } -} - -#[cfg(feature = "std")] -impl alloc::error::Error for IoError {} diff --git a/src/core/framed/framed_impl.rs b/src/core/framed/framed_impl.rs deleted file mode 100644 index a419d0d..0000000 --- a/src/core/framed/framed_impl.rs +++ /dev/null @@ -1,333 +0,0 @@ -//! A no_std implementation of https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/codec/framed_impl.rs - -use super::super::io::{io_slice::IoSlice, AsyncRead, AsyncWrite}; -use super::codec::{Decoder, Encoder}; -use super::errors::IoError; -use anyhow::Result; -use bytes::{Buf, BufMut, BytesMut}; -use core::borrow::{Borrow, BorrowMut}; -use core::mem::MaybeUninit; -use core::pin::Pin; -use core::task::{Context, Poll}; -use futures::{ready, Sink, Stream}; -use pin_project_lite::pin_project; - -#[cfg(not(feature = "std"))] -use crate::core::io::ReadBuf; -use crate::Err; -#[cfg(feature = "std")] -use tokio::io::ReadBuf; - -const INITIAL_CAPACITY: usize = 8 * 1024; - -pin_project! { - #[derive(Debug)] - pub(crate) struct FramedImpl - { - #[pin] - pub inner: T, - pub(crate) state: State, - pub codec: C, - } -} - -#[derive(Debug)] -pub(crate) struct ReadFrame { - pub(crate) eof: bool, - pub(crate) is_readable: bool, - pub(crate) buffer: BytesMut, - pub(crate) has_errored: bool, -} - -pub(crate) struct WriteFrame { - pub(crate) buffer: BytesMut, - pub(crate) backpressure_boundary: usize, -} - -#[derive(Default)] -pub(crate) struct RWFrames { - pub(crate) read: ReadFrame, - pub(crate) write: WriteFrame, -} - -impl Default for ReadFrame { - fn default() -> Self { - Self { - eof: false, - is_readable: false, - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - has_errored: false, - } - } -} - -impl Default for WriteFrame { - fn default() -> Self { - Self { - buffer: BytesMut::with_capacity(INITIAL_CAPACITY), - backpressure_boundary: INITIAL_CAPACITY, - } - } -} - -impl From for ReadFrame { - fn from(mut buffer: BytesMut) -> Self { - let size = buffer.capacity(); - if size < INITIAL_CAPACITY { - buffer.reserve(INITIAL_CAPACITY - size); - } - - Self { - buffer, - is_readable: size > 0, - eof: false, - has_errored: false, - } - } -} - -impl From for WriteFrame { - fn from(mut buffer: BytesMut) -> Self { - let size = buffer.capacity(); - if size < INITIAL_CAPACITY { - buffer.reserve(INITIAL_CAPACITY - size); - } - - Self { - buffer, - backpressure_boundary: INITIAL_CAPACITY, - } - } -} - -impl Borrow for RWFrames { - fn borrow(&self) -> &ReadFrame { - &self.read - } -} -impl BorrowMut for RWFrames { - fn borrow_mut(&mut self) -> &mut ReadFrame { - &mut self.read - } -} -impl Borrow for RWFrames { - fn borrow(&self) -> &WriteFrame { - &self.write - } -} -impl BorrowMut for RWFrames { - fn borrow_mut(&mut self) -> &mut WriteFrame { - &mut self.write - } -} - -impl Stream for FramedImpl -where - T: AsyncRead, - U: Decoder, - R: BorrowMut, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut pinned = self.project(); - let state: &mut ReadFrame = pinned.state.borrow_mut(); - - loop { - if state.has_errored { - state.is_readable = false; - state.has_errored = false; - return Poll::Ready(None); - } - - if state.is_readable { - if state.eof { - return match pinned.codec.decode_eof(&mut state.buffer) { - Err(err) => { - state.has_errored = true; - Poll::Ready(Some(Err!(err))) - } - Ok(frame) => { - if frame.is_none() { - state.is_readable = false; - } - Poll::Ready(frame.map(Ok)) - } - }; - } - - if let Some(frame) = match pinned.codec.decode(&mut state.buffer) { - Err(err) => { - state.has_errored = true; - return Poll::Ready(Some(Err!(err))); - } - Ok(frame) => frame, - } { - return Poll::Ready(Some(Ok(frame))); - } - - state.is_readable = false; - } - - state.buffer.reserve(1); - match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer) { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(bytect_res) => match bytect_res { - Err(err) => { - return Poll::Ready(Some(Err!(err))); - } - Ok(bytect) => { - if bytect == 0 { - if state.eof { - return Poll::Ready(None); - } - state.eof = true; - } else { - state.eof = false; - } - - state.is_readable = true; - } - }, - }; - } - } -} - -impl Sink for FramedImpl -where - T: AsyncWrite, - U: Encoder, - W: BorrowMut, -{ - type Error = anyhow::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary { - self.as_mut().poll_flush(cx) - } else { - Poll::Ready(Ok(())) - } - } - - fn start_send(self: Pin<&mut Self>, item: I) -> Result<()> { - let pinned = self.project(); - match pinned - .codec - .encode(item, &mut pinned.state.borrow_mut().buffer) - { - Ok(_) => Ok(()), - Err(_) => { - Err!(IoError::EncodeWhileSendError) - } - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut pinned = self.project(); - - while !pinned.state.borrow_mut().buffer.is_empty() { - let WriteFrame { buffer, .. } = pinned.state.borrow_mut(); - - match ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer)) { - Err(e) => { - return Poll::Ready(Err!(e)); - } - Ok(n) => { - if n == 0 { - return Poll::Ready(Err!(IoError::FailedToFlush)); - } - } - } - } - - match ready!(pinned.inner.poll_flush(cx)) { - Err(e) => Poll::Ready(Err!(e)), - Ok(_) => Poll::Ready(Ok(())), - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Err(err) = ready!(self.as_mut().poll_flush(cx)) { - return Poll::Ready(Err!(err)); - } - - if let Err(err) = ready!(self.project().inner.poll_shutdown(cx)) { - return Poll::Ready(Err!(err)); - } - - Poll::Ready(Ok(())) - } -} - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, -) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let n = { - let dst = buf.chunk_mut(); - - // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a - // transparent wrapper around `[MaybeUninit]`. - let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(io.poll_read(cx, &mut buf)?); - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -} - -pub fn poll_write_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, -) -> Poll> { - const MAX_BUFS: usize = 64; - - if !buf.has_remaining() { - return Poll::Ready(Ok(0)); - } - - let n = if io.is_write_vectored() { - let mut slices = [IoSlice::new(&[]); MAX_BUFS]; - let cnt = chunks_vectored(&buf, &mut slices); - ready!(io.poll_write_vectored(cx, &slices[..cnt]))? - } else { - ready!(io.poll_write(cx, buf.chunk()))? - }; - - buf.advance(n); - - Poll::Ready(Ok(n)) -} - -fn chunks_vectored<'a, B: Buf>(buf: &'a B, dst: &mut [IoSlice<'a>]) -> usize { - if dst.is_empty() { - return 0; - } - - if buf.has_remaining() { - dst[0] = IoSlice::new(buf.chunk()); - 1 - } else { - 0 - } -} diff --git a/src/core/framed/mod.rs b/src/core/framed/mod.rs deleted file mode 100644 index f776221..0000000 --- a/src/core/framed/mod.rs +++ /dev/null @@ -1,175 +0,0 @@ -//! A no_std version of `tokio::Framed` - -use anyhow::Result; -pub mod codec; -pub use codec::Codec; - -mod framed_impl; -use framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; - -pub mod errors; - -use bytes::BytesMut; -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use futures::{Sink, Stream}; -use pin_project_lite::pin_project; - -use super::io::{AsyncRead, AsyncWrite}; - -pub use errors::IoError; - -use codec::{Decoder, Encoder}; - -pin_project! { - pub struct Framed { - #[pin] - pub(crate) inner: FramedImpl - } -} - -impl Framed { - pub fn new(inner: T, codec: U) -> Framed { - Framed { - inner: FramedImpl { - inner, - codec, - state: Default::default(), - }, - } - } - - pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed { - Framed { - inner: FramedImpl { - inner, - codec, - state: RWFrames { - read: ReadFrame { - eof: false, - is_readable: false, - buffer: BytesMut::with_capacity(capacity), - has_errored: false, - }, - write: WriteFrame::default(), - }, - }, - } - } -} - -impl Framed { - pub fn get_ref(&self) -> &T { - &self.inner.inner - } - - pub fn get_mut(&mut self) -> &mut T { - &mut self.inner.inner - } - - pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { - self.project().inner.project().inner - } - - pub fn codec(&self) -> &U { - &self.inner.codec - } - - pub fn codec_mut(&mut self) -> &mut U { - &mut self.inner.codec - } - - pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U { - self.project().inner.project().codec - } - - pub fn read_buffer(&self) -> &BytesMut { - &self.inner.state.read.buffer - } - - pub fn read_buffer_mut(&mut self) -> &mut BytesMut { - &mut self.inner.state.read.buffer - } - - pub fn write_buffer(&self) -> &BytesMut { - &self.inner.state.write.buffer - } - - pub fn write_buffer_mut(&mut self) -> &mut BytesMut { - &mut self.inner.state.write.buffer - } - - pub fn backpressure_boundary(&self) -> usize { - self.inner.state.write.backpressure_boundary - } - - pub fn set_backpressure_boundary(&mut self, boundary: usize) { - self.inner.state.write.backpressure_boundary = boundary; - } - - pub fn into_inner(self) -> T { - self.inner.inner - } -} - -// This impl just defers to the underlying FramedImpl -impl Stream for Framed -where - T: AsyncRead, - U: Decoder, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx) - } -} - -// This impl just defers to the underlying FramedImpl -impl Sink for Framed -where - T: AsyncWrite, - U: Encoder, -{ - type Error = anyhow::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_ready(cx) - } - - fn start_send(self: Pin<&mut Self>, item: I) -> Result<()> { - self.project().inner.start_send(item) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_close(cx) - } -} - -impl fmt::Debug for Framed -where - T: fmt::Debug, - U: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Framed") - .field("io", self.get_ref()) - .field("codec", self.codec()) - .finish() - } -} - -impl From for Framed -where - T: AsyncRead + AsyncWrite, -{ - /// Uses [`Codec`] as default - fn from(value: T) -> Self { - Self::new(value, Codec::new()) - } -} diff --git a/src/core/io/async_read.rs b/src/core/io/async_read.rs deleted file mode 100644 index 60aa60a..0000000 --- a/src/core/io/async_read.rs +++ /dev/null @@ -1,76 +0,0 @@ -// use crate::Err; -// use alloc::boxed::Box; -use anyhow::Result; -use core::fmt::{Debug, Display}; -// use core::ops::DerefMut; -use core::pin::Pin; -use core::task::{Context, Poll}; - -#[cfg(not(feature = "std"))] -use crate::core::io::ReadBuf; -#[cfg(feature = "std")] -use tokio::io::ReadBuf; - -// use crate::core::framed::IoError; - -pub trait AsyncRead { - type Error: Debug + Display; - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll>; -} - -// macro_rules! deref_async_read { -// () => { -// fn poll_read( -// mut self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &mut ReadBuf<'_>, -// ) -> Poll> { -// match Pin::new(&mut **self).poll_read(cx, buf) { -// Poll::Ready(result) => match result { -// Ok(_) => Poll::Ready(Ok(())), -// Err(_) => Poll::Ready(Err!(IoError::DecodeWhileReadError)), -// }, -// Poll::Pending => Poll::Pending, -// } -// } -// }; -// } - -// impl AsyncRead for Box { -// type Error = anyhow::Error; - -// deref_async_read!(); -// } - -// impl AsyncRead for &mut T { -// type Error = anyhow::Error; - -// deref_async_read!(); -// } - -// impl

AsyncRead for Pin

-// where -// P: DerefMut + Unpin, -// P::Target: AsyncRead, -// { -// type Error = anyhow::Error; - -// fn poll_read( -// self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &mut ReadBuf<'_>, -// ) -> Poll> { -// match self.get_mut().as_mut().poll_read(cx, buf) { -// Poll::Ready(result) => match result { -// Ok(()) => Poll::Ready(Ok(())), -// Err(err) => Poll::Ready(Err!(err)), -// }, -// Poll::Pending => Poll::Pending, -// } -// } -// } diff --git a/src/core/io/async_write.rs b/src/core/io/async_write.rs deleted file mode 100644 index 6ff20b7..0000000 --- a/src/core/io/async_write.rs +++ /dev/null @@ -1,133 +0,0 @@ -// use alloc::boxed::Box; -use anyhow::Result; -// use core::ops::DerefMut; -use core::pin::Pin; -use core::task::{Context, Poll}; - -use super::io_slice::IoSlice; - -pub trait AsyncWrite { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - let buf = bufs - .iter() - .find(|b| !b.is_empty()) - .map_or(&[][..], |b| &**b); - self.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; - - fn is_write_vectored(&self) -> bool { - false - } -} - -// macro_rules! deref_async_write { -// () => { -// fn poll_write( -// mut self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &[u8], -// ) -> Poll> { -// Pin::new(&mut **self).poll_write(cx, buf) -// } - -// fn poll_write_vectored( -// mut self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// bufs: &[IoSlice<'_>], -// ) -> Poll> { -// Pin::new(&mut **self).poll_write_vectored(cx, bufs) -// } - -// fn is_write_vectored(&self) -> bool { -// (**self).is_write_vectored() -// } - -// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// Pin::new(&mut **self).poll_flush(cx) -// } - -// fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// Pin::new(&mut **self).poll_shutdown(cx) -// } -// }; -// } - -// impl AsyncWrite for Box { -// deref_async_write!(); -// } - -// impl AsyncWrite for &mut T { -// deref_async_write!(); -// } - -// impl

AsyncWrite for Pin

-// where -// P: DerefMut + Unpin, -// P::Target: AsyncWrite, -// { -// fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { -// self.get_mut().as_mut().poll_write(cx, buf) -// } - -// fn poll_write_vectored( -// self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// bufs: &[IoSlice<'_>], -// ) -> Poll> { -// self.get_mut().as_mut().poll_write_vectored(cx, bufs) -// } - -// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// self.get_mut().as_mut().poll_flush(cx) -// } - -// fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// self.get_mut().as_mut().poll_shutdown(cx) -// } - -// fn is_write_vectored(&self) -> bool { -// (**self).is_write_vectored() -// } -// } - -// TODO: implement if needed, otherwise delete -// impl AsyncWrite for Vec { -// fn poll_write( -// self: Pin<&mut Self>, -// _cx: &mut Context<'_>, -// buf: &[u8], -// ) -> Poll> { -// self.get_mut().extend_from_slice(buf); -// Poll::Ready(Ok(buf.len())) -// } -// -// fn poll_write_vectored( -// mut self: Pin<&mut Self>, -// _: &mut Context<'_>, -// bufs: &[IoSlice<'_>], -// ) -> Poll> { -// Poll::Ready(::write_vectored(&mut *self, bufs)) -// } -// -// fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { -// Poll::Ready(Ok(())) -// } -// -// fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { -// Poll::Ready(Ok(())) -// } -// -// fn is_write_vectored(&self) -> bool { -// true -// } -// } diff --git a/src/core/io/io_slice.rs b/src/core/io/io_slice.rs deleted file mode 100644 index fd6f841..0000000 --- a/src/core/io/io_slice.rs +++ /dev/null @@ -1,181 +0,0 @@ -use core::fmt::{Debug, Formatter}; -use core::marker::PhantomData; -use core::mem::take; -use core::ops::{Deref, DerefMut}; -use core::slice; -use libc::{c_void, iovec}; - -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct IoSlice<'a> { - vec: iovec, - _p: PhantomData<&'a [u8]>, -} - -impl<'a> IoSlice<'a> { - #[inline] - pub fn new(buf: &'a [u8]) -> IoSlice<'a> { - IoSlice { - vec: iovec { - iov_base: buf.as_ptr() as *mut u8 as *mut c_void, - iov_len: buf.len(), - }, - _p: PhantomData, - } - } - - #[inline] - pub fn advance(&mut self, n: usize) { - if self.vec.iov_len < n { - panic!("advancing IoSlice beyond its length"); - } - - unsafe { - self.vec.iov_len -= n; - self.vec.iov_base = self.vec.iov_base.add(n); - } - } - - #[inline] - pub fn as_slice(&self) -> &[u8] { - unsafe { slice::from_raw_parts(self.vec.iov_base as *mut u8, self.vec.iov_len) } - } - - #[inline] - pub fn advance_slices(bufs: &mut &mut [IoSlice<'a>], n: usize) { - // Number of buffers to remove. - let mut remove = 0; - // Total length of all the to be removed buffers. - let mut accumulated_len = 0; - for buf in bufs.iter() { - if accumulated_len + buf.len() > n { - break; - } else { - accumulated_len += buf.len(); - remove += 1; - } - } - - *bufs = &mut take(bufs)[remove..]; - if bufs.is_empty() { - assert_eq!( - n, accumulated_len, - "advancing io slices beyond their length" - ); - } else { - bufs[0].advance(n - accumulated_len) - } - } -} - -unsafe impl<'a> Send for IoSlice<'a> {} - -unsafe impl<'a> Sync for IoSlice<'a> {} - -impl<'a> Debug for IoSlice<'a> { - fn fmt(&self, fmt: &mut Formatter<'_>) -> core::fmt::Result { - Debug::fmt(self.as_slice(), fmt) - } -} - -impl<'a> Deref for IoSlice<'a> { - type Target = [u8]; - - #[inline] - fn deref(&self) -> &[u8] { - self.as_slice() - } -} - -#[repr(transparent)] -pub struct IoSliceMut<'a> { - vec: iovec, - _p: PhantomData<&'a mut [u8]>, -} - -impl<'a> IoSliceMut<'a> { - #[inline] - pub fn new(buf: &'a mut [u8]) -> IoSliceMut<'a> { - IoSliceMut { - vec: iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }, - _p: PhantomData, - } - } - - #[inline] - pub fn advance(&mut self, n: usize) { - if self.vec.iov_len < n { - panic!("advancing IoSliceMut beyond its length"); - } - - unsafe { - self.vec.iov_len -= n; - self.vec.iov_base = self.vec.iov_base.add(n); - } - } - - #[inline] - pub fn as_slice(&self) -> &[u8] { - unsafe { slice::from_raw_parts(self.vec.iov_base as *mut u8, self.vec.iov_len) } - } - - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [u8] { - unsafe { slice::from_raw_parts_mut(self.vec.iov_base as *mut u8, self.vec.iov_len) } - } - - #[inline] - pub fn advance_slices(bufs: &mut &mut [IoSliceMut<'a>], n: usize) { - // Number of buffers to remove. - let mut remove = 0; - // Total length of all the to be removed buffers. - let mut accumulated_len = 0; - for buf in bufs.iter() { - if accumulated_len + buf.len() > n { - break; - } else { - accumulated_len += buf.len(); - remove += 1; - } - } - - *bufs = &mut take(bufs)[remove..]; - if bufs.is_empty() { - assert_eq!( - n, accumulated_len, - "advancing io slices beyond their length" - ); - } else { - bufs[0].advance(n - accumulated_len) - } - } -} - -unsafe impl<'a> Send for IoSliceMut<'a> {} - -unsafe impl<'a> Sync for IoSliceMut<'a> {} - -impl<'a> Debug for IoSliceMut<'a> { - fn fmt(&self, fmt: &mut Formatter<'_>) -> core::fmt::Result { - Debug::fmt(self.as_slice(), fmt) - } -} - -impl<'a> Deref for IoSliceMut<'a> { - type Target = [u8]; - - #[inline] - fn deref(&self) -> &[u8] { - self.as_slice() - } -} - -impl<'a> DerefMut for IoSliceMut<'a> { - #[inline] - fn deref_mut(&mut self) -> &mut [u8] { - self.as_mut_slice() - } -} diff --git a/src/core/io/mod.rs b/src/core/io/mod.rs deleted file mode 100644 index f7b89bf..0000000 --- a/src/core/io/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod read_buf; -#[cfg(not(feature = "std"))] -pub(crate) use read_buf::ReadBuf; - -pub mod async_read; -pub use async_read::AsyncRead; - -pub mod async_write; -pub use async_write::AsyncWrite; - -pub mod io_slice; diff --git a/src/core/io/read_buf.rs b/src/core/io/read_buf.rs deleted file mode 100644 index d2ed59f..0000000 --- a/src/core/io/read_buf.rs +++ /dev/null @@ -1,185 +0,0 @@ -//! A no_std version of `tokio::io::ReadBuf` -//! -//! `` - -use core::fmt; -use core::mem::MaybeUninit; - -pub struct ReadBuf<'a> { - buf: &'a mut [MaybeUninit], - filled: usize, - initialized: usize, -} - -impl<'a> ReadBuf<'a> { - #[inline] - pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> { - let initialized = buf.len(); - let buf = unsafe { slice_to_uninit_mut(buf) }; - ReadBuf { - buf, - filled: 0, - initialized, - } - } - - #[inline] - pub fn uninit(buf: &'a mut [MaybeUninit]) -> ReadBuf<'a> { - ReadBuf { - buf, - filled: 0, - initialized: 0, - } - } - - #[inline] - pub fn capacity(&self) -> usize { - self.buf.len() - } - - #[inline] - pub fn filled(&self) -> &[u8] { - let slice = &self.buf[..self.filled]; - unsafe { slice_assume_init(slice) } - } - - #[inline] - pub fn filled_mut(&mut self) -> &mut [u8] { - let slice = &mut self.buf[..self.filled]; - unsafe { slice_assume_init_mut(slice) } - } - - #[inline] - pub fn take(&mut self, n: usize) -> ReadBuf<'_> { - let max = core::cmp::min(self.remaining(), n); - unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } - } - - #[inline] - pub fn initialized(&self) -> &[u8] { - let slice = &self.buf[..self.initialized]; - unsafe { slice_assume_init(slice) } - } - - #[inline] - pub fn initialized_mut(&mut self) -> &mut [u8] { - let slice = &mut self.buf[..self.initialized]; - unsafe { slice_assume_init_mut(slice) } - } - - #[inline] - pub unsafe fn inner_mut(&mut self) -> &mut [MaybeUninit] { - self.buf - } - - #[inline] - pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit] { - &mut self.buf[self.filled..] - } - - #[inline] - pub fn initialize_unfilled(&mut self) -> &mut [u8] { - self.initialize_unfilled_to(self.remaining()) - } - - #[inline] - #[track_caller] - pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] { - assert!(self.remaining() >= n, "n overflows remaining"); - - let end = self.filled + n; - - if self.initialized < end { - unsafe { - self.buf[self.initialized..end] - .as_mut_ptr() - .write_bytes(0, end - self.initialized); - } - self.initialized = end; - } - - let slice = &mut self.buf[self.filled..end]; - unsafe { slice_assume_init_mut(slice) } - } - - #[inline] - pub fn remaining(&self) -> usize { - self.capacity() - self.filled - } - - #[inline] - pub fn clear(&mut self) { - self.filled = 0; - } - - #[inline] - #[track_caller] - pub fn advance(&mut self, n: usize) { - let new = self.filled.checked_add(n).expect("filled overflow"); - self.set_filled(new); - } - - #[inline] - #[track_caller] - pub fn set_filled(&mut self, n: usize) { - assert!( - n <= self.initialized, - "filled must not become larger than initialized" - ); - self.filled = n; - } - - #[inline] - pub unsafe fn assume_init(&mut self, n: usize) { - let new = self.filled + n; - if new > self.initialized { - self.initialized = new; - } - } - - #[inline] - #[track_caller] - pub fn put_slice(&mut self, buf: &[u8]) { - assert!( - self.remaining() >= buf.len(), - "buf.len() must fit in remaining()" - ); - - let amt = buf.len(); - let end = self.filled + amt; - - unsafe { - self.buf[self.filled..end] - .as_mut_ptr() - .cast::() - .copy_from_nonoverlapping(buf.as_ptr(), amt); - } - - if self.initialized < end { - self.initialized = end; - } - self.filled = end; - } -} - -impl fmt::Debug for ReadBuf<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ReadBuf") - .field("filled", &self.filled) - .field("initialized", &self.initialized) - .field("capacity", &self.capacity()) - .finish() - } -} - -unsafe fn slice_to_uninit_mut(slice: &mut [u8]) -> &mut [MaybeUninit] { - &mut *(slice as *mut [u8] as *mut [MaybeUninit]) -} - -unsafe fn slice_assume_init(slice: &[MaybeUninit]) -> &[u8] { - &*(slice as *const [MaybeUninit] as *const [u8]) -} - -unsafe fn slice_assume_init_mut(slice: &mut [MaybeUninit]) -> &mut [u8] { - &mut *(slice as *mut [MaybeUninit] as *mut [u8]) -} diff --git a/src/core/mod.rs b/src/core/mod.rs index 3d581f0..cddbf3e 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,190 +1,6 @@ -use alloc::{ - dbg, format, - string::{String, ToString}, -}; -use anyhow::Result; -use thiserror_no_std::Error; -use url::Host; - #[cfg(feature = "dns")] pub mod dns; #[cfg(feature = "tcp")] pub mod tcp; #[cfg(feature = "tls")] pub mod tls; - -#[derive(Debug, Clone)] -pub struct Url(url::Url); - -impl Url { - pub fn parse(url: &str) -> Result { - let mut url = - Url(url::Url::parse(url).map_err(|e| UrlException::InvalidUrl(e.to_string()).into())?); - url.assure_port().unwrap(); - dbg!(url.clone()); - Ok(url) - } - - pub fn for_tcp(&self) -> Result { - let mut url = self.clone(); - url.assure_port().unwrap(); - url.set_port(Some(443)).unwrap(); - let host = match self.host_str() { - Some(host) => host, - None => return Err(anyhow::anyhow!("Host not found")), - }; - let port = match self.port() { - Some(port) => port, - None => return Err(anyhow::anyhow!("Port not found")), - }; - let path = self.path(); - let url = host.to_string() + ":" + &port.to_string() + path; - - Ok(url) - } - - pub fn scheme(&self) -> &str { - self.0.scheme() - } - - pub fn host(&self) -> Option> { - self.0.host() - } - - pub fn host_str(&self) -> Option<&str> { - self.0.host_str() - } - - pub fn port(&self) -> Option { - self.0.port() - } - - pub fn path(&self) -> &str { - self.0.path() - } - - pub fn query(&self) -> Option<&str> { - self.0.query() - } - - pub fn fragment(&self) -> Option<&str> { - self.0.fragment() - } - - pub fn set_scheme(&mut self, scheme: &str) -> Result<(), ()> { - self.0.set_scheme(scheme) - } - - pub fn set_port(&mut self, port: Option) -> Result<(), ()> { - self.0.set_port(port) - } - - pub fn set_path(&mut self, path: &str) -> () { - self.0.set_path(path) - } - - pub fn set_query(&mut self, query: &str) -> () { - self.0.set_query(Some(query)) - } - - pub fn get_with_port(&self, default: Option) -> String { - let scheme = self.0.scheme(); - let host = match self.0.host_str() { - Some(host) => host, - None => return "".to_string(), - }; - let port = match self.0.port() { - Some(port) => port, - None => match scheme { - "wss" | "https" => 443, - "ws" | "http" => 80, - _ => default.unwrap_or(80), - }, - }; - let path = self.path(); - // let query = self.query().map_or("".into(), |q| format!("?{}", q)); - // let fragment = self.fragment().map_or("".into(), |f| format!("#{}", f)); - - format!( - "{}://{}:{}{}", - scheme, - host, - port, - path, - // query, - // fragment - ) - } - - fn assure_port(&mut self) -> Result<(), ()> { - if self.0.port().is_none() { - let port = match self.0.scheme() { - "wss" | "https" => 443, - "ws" | "http" => 80, - _ => 80, - }; - self.0.set_port(Some(port))?; - } - - Ok(()) - } -} - -impl TryFrom for Url { - type Error = (); - - fn try_from(url: url::Url) -> Result { - let mut url = Url(url); - url.assure_port()?; - Ok(url) - } -} - -impl From for url::Url { - fn from(url: Url) -> url::Url { - url.0 - } -} - -impl TryFrom<&str> for Url { - type Error = anyhow::Error; - - fn try_from(url: &str) -> Result { - let url = - Url(url::Url::parse(url).map_err(|e| UrlException::InvalidUrl(e.to_string()).into())?); - - Ok(url) - } -} - -impl ToString for Url { - fn to_string(&self) -> String { - self.0.to_string() - } -} - -impl From for String { - fn from(url: Url) -> Self { - url.to_string() - } -} - -#[derive(Debug, Error)] -pub enum UrlException { - #[error("Invalid URL: {0}")] - InvalidUrl(String), - #[error("Invalid scheme: {0}")] - InvalidScheme(String), -} - -impl From for UrlException { - fn from(e: url::ParseError) -> Self { - UrlException::InvalidUrl(e.to_string()) - } -} - -impl Into for UrlException { - fn into(self) -> anyhow::Error { - anyhow::anyhow!(self) - } -} diff --git a/src/core/tcp/exceptions.rs b/src/core/tcp/exceptions.rs index 55ab5a5..26e41ff 100644 --- a/src/core/tcp/exceptions.rs +++ b/src/core/tcp/exceptions.rs @@ -1,10 +1,17 @@ +use core::fmt::Debug; + use embedded_io_async::ErrorKind; -use strum_macros::Display; use thiserror_no_std::Error; -#[derive(Debug, Error, Display)] +#[derive(Debug, Error)] pub enum TcpException { - IoError(#[from] alloc::io::Error), + #[cfg(not(feature = "std"))] + #[error("I/O error")] + IoError, + #[cfg(feature = "std")] + #[error("I/O error: {0}")] + IoError(alloc::io::Error), + #[error("Embedded IO async error")] EmbeddedIoAsyncError(embedded_io_async::ErrorKind), } @@ -16,9 +23,3 @@ impl embedded_io_async::Error for TcpException { } } } - -impl Into for TcpException { - fn into(self) -> anyhow::Error { - anyhow::anyhow!(self) - } -} diff --git a/src/core/tcp/mod.rs b/src/core/tcp/mod.rs index ada8a72..9660423 100644 --- a/src/core/tcp/mod.rs +++ b/src/core/tcp/mod.rs @@ -29,9 +29,10 @@ mod _tokio { impl TcpStream { pub async fn connect(url: &Url) -> Result { - let stream = TokioTcpStream::connect(derive_tcp_url(url, None)?) - .await - .map_err(|e| TcpException::IoError(e).into())?; + let stream = match TokioTcpStream::connect(derive_tcp_url(url, None)?).await { + Ok(stream) => stream, + Err(e) => return Err!(TcpException::IoError(e)), + }; Ok(TcpStream(FromTokio::new(stream))) } } @@ -42,20 +43,26 @@ mod _tokio { impl Read for TcpStream { async fn read(&mut self, buf: &mut [u8]) -> Result { - self.0.read(buf).await.map_err(|e| TcpException::IoError(e)) + match self.0.read(buf).await { + Ok(n) => Ok(n), + Err(e) => Err(TcpException::IoError(e)), + } } } impl Write for TcpStream { async fn write(&mut self, buf: &[u8]) -> Result { - self.0 - .write(buf) - .await - .map_err(|e| TcpException::IoError(e)) + match self.0.write(buf).await { + Ok(n) => Ok(n), + Err(e) => Err(TcpException::IoError(e)), + } } async fn flush(&mut self) -> Result<(), Self::Error> { - self.0.flush().await.map_err(|e| TcpException::IoError(e)) + match self.0.flush().await { + Ok(()) => Ok(()), + Err(e) => Err(TcpException::IoError(e)), + } } } diff --git a/src/core/tls/exceptions.rs b/src/core/tls/exceptions.rs index 36b0235..50bd741 100644 --- a/src/core/tls/exceptions.rs +++ b/src/core/tls/exceptions.rs @@ -1,10 +1,15 @@ -use anyhow::anyhow; +use core::fmt::Debug; + use embedded_io_async::ErrorKind; use rustls::pki_types::InvalidDnsNameError; use thiserror_no_std::Error; #[derive(Debug, Error)] pub enum TlsException { + #[cfg(not(feature = "std"))] + #[error("I/O error")] + IoError, + #[cfg(feature = "std")] #[error("I/O error: {0}")] IoError(alloc::io::Error), #[error("No domain")] @@ -15,18 +20,6 @@ pub enum TlsException { InvalidServerName(InvalidDnsNameError), } -impl From for TlsException { - fn from(e: alloc::io::Error) -> Self { - TlsException::IoError(e) - } -} - -impl Into for TlsException { - fn into(self) -> anyhow::Error { - anyhow!(self) - } -} - impl embedded_io_async::Error for TlsException { fn kind(&self) -> embedded_io_async::ErrorKind { match self { diff --git a/src/core/tls/mod.rs b/src/core/tls/mod.rs index c3e9dd4..e4a3465 100644 --- a/src/core/tls/mod.rs +++ b/src/core/tls/mod.rs @@ -1,21 +1,42 @@ mod exceptions; -use embedded_io_adapters::tokio_1::FromTokio; pub use exceptions::*; -use tokio_rustls::client::TlsStream as TokioRustlsTlsStream; use anyhow::Result; use embedded_io_async::{Read, Write}; +#[cfg(not(feature = "std"))] +pub use rustls_stream::*; + +#[cfg(not(feature = "std"))] +mod rustls_stream { + use embedded_io_async::{Read, Write}; + use embedded_websocket::Result; + + pub struct TlsStream(S); + + impl TlsStream { + pub async fn connect(stream: S, _url: &url::Url) -> Result { + Ok(TlsStream(stream)) + } + + pub async fn accept(_stream: S, _url: &url::Url) -> Result { + todo!("Implement accept as TlsListener"); + } + } +} + #[cfg(feature = "std")] pub use tokio_tls_stream::*; #[cfg(feature = "std")] mod tokio_tls_stream { use alloc::{borrow::Cow, string::String, sync::Arc}; + use embedded_io_adapters::tokio_1::FromTokio; use embedded_io_async::ErrorType; use rustls::{pki_types::ServerName, ClientConfig, RootCertStore}; use tokio::io::{AsyncRead, AsyncWrite}; + use tokio_rustls::client::TlsStream as TokioRustlsTlsStream; use tokio_rustls::{TlsAcceptor, TlsConnector}; use url::Url; @@ -79,10 +100,10 @@ mod tokio_tls_stream { S: AsyncRead + AsyncWrite + Unpin, { async fn read(&mut self, buf: &mut [u8]) -> core::result::Result { - self.0 - .read(buf) - .await - .map_err(|e| TlsException::IoError(e).into()) + match self.0.read(buf).await { + Ok(n) => Ok(n), + Err(e) => Err(TlsException::IoError(e)), + } } } @@ -91,17 +112,17 @@ mod tokio_tls_stream { S: AsyncRead + AsyncWrite + Unpin, { async fn write(&mut self, buf: &[u8]) -> core::result::Result { - self.0 - .write(buf) - .await - .map_err(|e| TlsException::IoError(e).into()) + match self.0.write(buf).await { + Ok(n) => Ok(n), + Err(e) => Err(TlsException::IoError(e)), + } } async fn flush(&mut self) -> core::result::Result<(), Self::Error> { - self.0 - .flush() - .await - .map_err(|e| TlsException::IoError(e).into()) + match self.0.flush().await { + Ok(()) => Ok(()), + Err(e) => Err(TlsException::IoError(e)), + } } } diff --git a/tests/integration/clients/async_websocket.rs b/tests/integration/clients/async_websocket.rs index fd4899c..3694f3b 100644 --- a/tests/integration/clients/async_websocket.rs +++ b/tests/integration/clients/async_websocket.rs @@ -6,6 +6,7 @@ use url::Url; use crate::common::{connect_ws, ECHO_WSS_SERVER}; +#[cfg(feature = "std")] #[tokio::test] async fn test_websocket_tls() { let uri = Url::parse(ECHO_WSS_SERVER).unwrap();