diff --git a/Cargo.toml b/Cargo.toml index 412137aa..8a4a340a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ __rustls-tls = ["rustls", "rustls-pki-types"] [dependencies] data-encoding = { version = "2", optional = true } byteorder = "1.3.2" -bytes = "1.0" +bytes = "1.9.0" http = { version = "1.0", optional = true } httparse = { version = "1.3.4", optional = true } log = "0.4.8" diff --git a/benches/read.rs b/benches/read.rs index a60c86e0..252ef3e3 100644 --- a/benches/read.rs +++ b/benches/read.rs @@ -52,7 +52,7 @@ fn benchmark(c: &mut Criterion) { writer .send(match i { _ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()), - _ => Message::Text(format!("{{\"id\":{i}}}")), + _ => Message::text(format!("{{\"id\":{i}}}")), }) .unwrap(); sum += i; @@ -68,7 +68,7 @@ fn benchmark(c: &mut Criterion) { sum += u64::from_le_bytes(*a); } Message::Text(msg) => { - let i: u64 = msg[6..msg.len() - 1].parse().unwrap(); + let i: u64 = msg.as_str()[6..msg.len() - 1].parse().unwrap(); sum += i; } m => panic!("Unexpected {m}"), diff --git a/benches/write.rs b/benches/write.rs index 435f9a3b..cf7ab7e6 100644 --- a/benches/write.rs +++ b/benches/write.rs @@ -1,8 +1,7 @@ //! Benchmarks for write performance. use criterion::Criterion; use std::{ - hint, - io::{self, Read, Write}, + hint, io, time::{Duration, Instant}, }; use tungstenite::{protocol::Role, Message, WebSocket}; @@ -16,12 +15,12 @@ const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024; /// Each `flush` takes **~8µs** to simulate flush io. struct MockWrite(Vec); -impl Read for MockWrite { +impl io::Read for MockWrite { fn read(&mut self, _: &mut [u8]) -> io::Result { Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported")) } } -impl Write for MockWrite { +impl io::Write for MockWrite { fn write(&mut self, buf: &[u8]) -> io::Result { if self.0.len() + buf.len() > MOCK_WRITE_LEN { self.flush()?; @@ -58,7 +57,7 @@ fn benchmark(c: &mut Criterion) { for i in 0_u64..100_000 { let msg = match i { _ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()), - _ => Message::Text(format!("{{\"id\":{i}}}")), + _ => Message::text(format!("{{\"id\":{i}}}")), }; ws.write(msg).unwrap(); } diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index dcc3e75f..4ba5f6b9 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -8,7 +8,7 @@ fn get_case_count() -> Result { let (mut socket, _) = connect("ws://localhost:9001/getCaseCount")?; let msg = socket.read()?; socket.close(None)?; - Ok(msg.into_text()?.parse::().unwrap()) + Ok(msg.into_text()?.as_str().parse::().unwrap()) } fn update_reports() -> Result<()> { diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 87b3e2c2..ef8dfe3c 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -1,5 +1,3 @@ -use byteorder::{NetworkEndian, ReadBytesExt}; -use log::*; use std::{ borrow::Cow, default::Default, @@ -7,14 +5,22 @@ use std::{ io::{Cursor, ErrorKind, Read, Write}, result::Result as StdResult, str::Utf8Error, - string::{FromUtf8Error, String}, + string::String, }; +use byteorder::{NetworkEndian, ReadBytesExt}; +use log::*; + use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, + Payload, +}; +use crate::{ + error::{Error, ProtocolError, Result}, + protocol::frame::Utf8Payload, }; -use crate::error::{Error, ProtocolError, Result}; +use bytes::{Buf, BytesMut}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -207,7 +213,7 @@ impl FrameHeader { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Frame { header: FrameHeader, - payload: Vec, + payload: Payload, } impl Frame { @@ -239,14 +245,14 @@ impl Frame { /// Get a reference to the frame's payload. #[inline] - pub fn payload(&self) -> &Vec { - &self.payload + pub fn payload(&self) -> &[u8] { + self.payload.as_slice() } /// Get a mutable reference to the frame's payload. #[inline] - pub fn payload_mut(&mut self) -> &mut Vec { - &mut self.payload + pub fn payload_mut(&mut self) -> &mut [u8] { + self.payload.as_mut_slice() } /// Test whether the frame is masked. @@ -269,26 +275,26 @@ impl Frame { #[inline] pub(crate) fn apply_mask(&mut self) { if let Some(mask) = self.header.mask.take() { - apply_mask(&mut self.payload, mask); + apply_mask(self.payload.as_mut_slice(), mask); } } - /// Consume the frame into its payload as binary. + /// Consume the frame into its payload as string. #[inline] - pub fn into_data(self) -> Vec { - self.payload + pub fn into_text(self) -> StdResult { + self.payload.into_text() } - /// Consume the frame into its payload as string. + /// Consume the frame into its payload. #[inline] - pub fn into_string(self) -> StdResult { - String::from_utf8(self.payload) + pub fn into_payload(self) -> Payload { + self.payload } /// Get frame payload as `&str`. #[inline] pub fn to_text(&self) -> Result<&str, Utf8Error> { - std::str::from_utf8(&self.payload) + std::str::from_utf8(self.payload.as_slice()) } /// Consume the frame into a closing frame. @@ -298,10 +304,10 @@ impl Frame { 0 => Ok(None), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { - let mut data = self.payload; + let mut data = self.payload.as_slice(); let code = u16::from_be_bytes([data[0], data[1]]).into(); - data.drain(0..2); - let text = String::from_utf8(data)?; + data.advance(2); + let text = String::from_utf8(data.to_vec())?; Ok(Some(CloseFrame { code, reason: text.into() })) } } @@ -309,33 +315,35 @@ impl Frame { /// Create a new data frame. #[inline] - pub fn message(data: Vec, opcode: OpCode, is_final: bool) -> Frame { + pub fn message(data: impl Into, opcode: OpCode, is_final: bool) -> Frame { debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); - - Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } + Frame { + header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, + payload: data.into(), + } } /// Create a new Pong control frame. #[inline] - pub fn pong(data: Vec) -> Frame { + pub fn pong(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Pong), ..FrameHeader::default() }, - payload: data, + payload: data.into(), } } /// Create a new Ping control frame. #[inline] - pub fn ping(data: Vec) -> Frame { + pub fn ping(data: impl Into) -> Frame { Frame { header: FrameHeader { opcode: OpCode::Control(Control::Ping), ..FrameHeader::default() }, - payload: data, + payload: data.into(), } } @@ -343,19 +351,19 @@ impl Frame { #[inline] pub fn close(msg: Option) -> Frame { let payload = if let Some(CloseFrame { code, reason }) = msg { - let mut p = Vec::with_capacity(reason.len() + 2); + let mut p = BytesMut::with_capacity(reason.len() + 2); p.extend(u16::from(code).to_be_bytes()); p.extend_from_slice(reason.as_bytes()); p } else { - Vec::new() + <_>::default() }; - Frame { header: FrameHeader::default(), payload } + Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) } } /// Create a frame from given header and data. - pub fn from_payload(header: FrameHeader, payload: Vec) -> Self { + pub fn from_payload(header: FrameHeader, payload: Payload) -> Self { Frame { header, payload } } @@ -391,7 +399,7 @@ payload: 0x{} // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), self.len(), self.payload.len(), - self.payload.iter().fold(String::new(), |mut output, byte| { + self.payload.as_slice().iter().fold(String::new(), |mut output, byte| { _ = write!(output, "{byte:02x}"); output }) @@ -465,8 +473,8 @@ mod tests { assert_eq!(length, 7); let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); - let frame = Frame::from_payload(header, payload); - assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + let frame = Frame::from_payload(header, payload.into()); + assert_eq!(frame.into_payload().as_slice(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } #[test] @@ -479,7 +487,7 @@ mod tests { #[test] fn display() { - let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); + let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true); let view = format!("{f}"); assert!(view.contains("payload:")); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index 69bee831..8b508047 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -5,15 +5,20 @@ pub mod coding; #[allow(clippy::module_inception)] mod frame; mod mask; +mod payload; + +pub use self::{ + frame::{CloseFrame, Frame, FrameHeader}, + payload::{Payload, Utf8Payload}, +}; use crate::{ error::{CapacityError, Error, Result}, - Message, ReadBuffer, + Message, }; +use bytes::BytesMut; use log::*; -use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; - -pub use self::frame::{CloseFrame, Frame, FrameHeader}; +use std::io::{Cursor, Error as IoError, ErrorKind as IoErrorKind, Read, Write}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -36,8 +41,8 @@ impl FrameSocket { } /// Extract a stream from the socket. - pub fn into_inner(self) -> (Stream, Vec) { - (self.stream, self.codec.in_buffer.into_vec()) + pub fn into_inner(self) -> (Stream, BytesMut) { + (self.stream, self.codec.in_buffer) } /// Returns a shared reference to the inner stream. @@ -94,7 +99,7 @@ where #[derive(Debug)] pub(super) struct FrameCodec { /// Buffer to read data from the stream. - in_buffer: ReadBuffer, + in_buffer: BytesMut, /// Buffer to send packets to the network. out_buffer: Vec, /// Capacity limit for `out_buffer`. @@ -109,12 +114,14 @@ pub(super) struct FrameCodec { header: Option<(FrameHeader, u64)>, } +const READ_BUFFER_CAP: usize = 64 * 1024; + impl FrameCodec { /// Create a new frame codec. pub(super) fn new() -> Self { Self { - in_buffer: ReadBuffer::new(), - out_buffer: Vec::new(), + in_buffer: BytesMut::with_capacity(READ_BUFFER_CAP), + out_buffer: <_>::default(), max_out_buffer_len: usize::MAX, out_buffer_write_len: 0, header: None, @@ -123,9 +130,11 @@ impl FrameCodec { /// Create a new frame codec from partially read data. pub(super) fn from_partially_read(part: Vec) -> Self { + let mut in_buffer = BytesMut::from_iter(part); + in_buffer.reserve(READ_BUFFER_CAP.saturating_sub(in_buffer.len())); Self { - in_buffer: ReadBuffer::from_partially_read(part), - out_buffer: Vec::new(), + in_buffer, + out_buffer: <_>::default(), max_out_buffer_len: usize::MAX, out_buffer_write_len: 0, header: None, @@ -156,38 +165,39 @@ impl FrameCodec { let payload = loop { { - let cursor = self.in_buffer.as_cursor_mut(); - if self.header.is_none() { - self.header = FrameHeader::parse(cursor)?; + let mut cursor = Cursor::new(&mut self.in_buffer); + self.header = FrameHeader::parse(&mut cursor)?; + let advanced = cursor.position(); + bytes::Buf::advance(&mut self.in_buffer, advanced as _); } - if let Some((_, ref length)) = self.header { - let length = *length; + if let Some((_, len)) = &self.header { + let len = *len as usize; // Enforce frame size limit early and make sure `length` // is not too big (fits into `usize`). - if length > max_size as u64 { + if len > max_size { return Err(Error::Capacity(CapacityError::MessageTooLong { - size: length as usize, + size: len, max_size, })); } - let input_size = cursor.get_ref().len() as u64 - cursor.position(); - if length <= input_size { - // No truncation here since `length` is checked above - let mut payload = Vec::with_capacity(length as usize); - if length > 0 { - cursor.take(length).read_to_end(&mut payload)?; - } - break payload; + if len <= self.in_buffer.len() { + break self.in_buffer.split_to(len); } } } // Not enough data in buffer. - let size = self.in_buffer.read_from(stream)?; + let reserve_len = self.header.as_ref().map(|(_, l)| *l as usize).unwrap_or(6); + self.in_buffer.reserve(reserve_len); + let mut buf = self.in_buffer.split_off(self.in_buffer.len()); + buf.resize(reserve_len.max(buf.capacity()), 0); + let size = stream.read(&mut buf)?; + buf.truncate(size); + self.in_buffer.unsplit(buf); if size == 0 { trace!("no frame received"); return Ok(None); @@ -196,7 +206,7 @@ impl FrameCodec { let (header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, payload); + let frame = Frame::from_payload(header, Payload::Owned(payload)); trace!("received frame {frame}"); Ok(Some(frame)) } @@ -263,6 +273,8 @@ mod tests { #[test] fn read_frames() { + env_logger::init(); + let raw = Cursor::new(vec![ 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01, 0x99, @@ -270,10 +282,13 @@ mod tests { let mut sock = FrameSocket::new(raw); assert_eq!( - sock.read(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); + assert_eq!( + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x03, 0x02, 0x01] ); - assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert!(sock.read(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -285,8 +300,8 @@ mod tests { let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( - sock.read(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); } diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs new file mode 100644 index 00000000..8f3c073f --- /dev/null +++ b/src/protocol/frame/payload.rs @@ -0,0 +1,267 @@ +use bytes::{Bytes, BytesMut}; +use core::str; +use std::{fmt::Display, mem}; + +/// Utf8 payload. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct Utf8Payload(Payload); + +impl Utf8Payload { + /// Creates from a static str. + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(Payload::Shared(Bytes::from_static(str.as_bytes()))) + } + + /// Returns a slice of the payload. + #[inline] + pub fn as_slice(&self) -> &[u8] { + self.0.as_slice() + } + + /// Returns as a string slice. + #[inline] + pub fn as_str(&self) -> &str { + // safety: is valid uft8 + unsafe { str::from_utf8_unchecked(self.as_slice()) } + } + + /// Returns length in bytes. + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + /// Returns true if the length is 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// If owned converts into [`Bytes`] internals & then clones (cheaply). + #[inline] + pub fn share(&mut self) -> Self { + Self(self.0.share()) + } +} + +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(payload: Payload) -> Result { + str::from_utf8(payload.as_slice())?; + Ok(Self(payload)) + } +} + +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + Payload::from(bytes).try_into() + } +} + +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: BytesMut) -> Result { + Payload::from(bytes).try_into() + } +} + +impl TryFrom> for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: Vec) -> Result { + Payload::from(bytes).try_into() + } +} + +impl From for Utf8Payload { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Payload { + #[inline] + fn from(s: &str) -> Self { + Self(Payload::Owned(s.as_bytes().into())) + } +} + +impl From<&String> for Utf8Payload { + #[inline] + fn from(s: &String) -> Self { + s.as_str().into() + } +} + +impl From for Payload { + #[inline] + fn from(Utf8Payload(payload): Utf8Payload) -> Self { + payload + } +} + +impl Display for Utf8Payload { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// A payload of a WebSocket frame. +#[derive(Debug, Clone)] +pub enum Payload { + /// Owned data with unique ownership. + Owned(BytesMut), + /// Shared data with shared ownership. + Shared(Bytes), + /// Owned vec data. + Vec(Vec), +} + +impl Payload { + /// Creates from static bytes. + #[inline] + pub const fn from_static(bytes: &'static [u8]) -> Self { + Self::Shared(Bytes::from_static(bytes)) + } + + /// Converts into [`Bytes`] internals & then clones (cheaply). + pub fn share(&mut self) -> Self { + match self { + Self::Owned(data) => { + *self = Self::Shared(mem::take(data).freeze()); + } + Self::Vec(data) => { + *self = Self::Shared(Bytes::from_owner(mem::take(data))); + } + Self::Shared(_) => {} + } + self.clone() + } + + /// Returns a slice of the payload. + #[inline] + pub fn as_slice(&self) -> &[u8] { + match self { + Payload::Owned(v) => v, + Payload::Shared(v) => v, + Payload::Vec(v) => v, + } + } + + /// Returns a mutable slice of the payload. + /// + /// Note that this will internally allocate if the payload is shared + /// and there are other references to the same data. No allocation + /// would happen if the payload is owned or if there is only one + /// `Bytes` instance referencing the data. + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [u8] { + match self { + Payload::Owned(v) => &mut *v, + Payload::Vec(v) => &mut *v, + Payload::Shared(v) => { + // Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy. + // `Bytes::into()` would not make a copy if our `Bytes` instance is the only one. + let data = mem::take(v).into(); + *self = Payload::Owned(data); + match self { + Payload::Owned(v) => v, + _ => unreachable!(), + } + } + } + } + + /// Returns the length of the payload. + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + /// Returns true if the payload has a length of 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Consumes the payload and returns the underlying data as a string. + #[inline] + pub fn into_text(self) -> Result { + self.try_into() + } +} + +impl Default for Payload { + #[inline] + fn default() -> Self { + Self::Owned(<_>::default()) + } +} + +impl From> for Payload { + #[inline] + fn from(v: Vec) -> Self { + Payload::Vec(v) + } +} + +impl From for Payload { + #[inline] + fn from(v: String) -> Self { + v.into_bytes().into() + } +} + +impl From for Payload { + #[inline] + fn from(v: Bytes) -> Self { + Payload::Shared(v) + } +} + +impl From for Payload { + #[inline] + fn from(v: BytesMut) -> Self { + Payload::Owned(v) + } +} + +impl From<&[u8]> for Payload { + #[inline] + fn from(v: &[u8]) -> Self { + Self::Owned(v.into()) + } +} + +impl PartialEq for Payload { + #[inline] + fn eq(&self, other: &Payload) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl Eq for Payload {} + +impl PartialEq<[u8]> for Payload { + #[inline] + fn eq(&self, other: &[u8]) -> bool { + self.as_slice() == other + } +} + +impl PartialEq<&[u8; N]> for Payload { + #[inline] + fn eq(&self, other: &&[u8; N]) -> bool { + self.as_slice() == *other + } +} diff --git a/src/protocol/message.rs b/src/protocol/message.rs index d71ac109..ee098c09 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,6 +1,6 @@ -use std::{fmt, result::Result as StdResult, str}; +use std::{borrow::Cow, fmt, result::Result as StdResult, str}; -use super::frame::{CloseFrame, Frame}; +use super::frame::{CloseFrame, Frame, Payload, Utf8Payload}; use crate::error::{CapacityError, Error, Result}; mod string_collect { @@ -135,10 +135,10 @@ impl IncompleteMessage { /// Convert an incomplete message into a complete one. pub fn complete(self) -> Result { match self.collector { - IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)), + IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; - Ok(Message::Text(text)) + Ok(Message::text(text)) } } } @@ -154,17 +154,17 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Utf8Payload), /// A binary WebSocket message - Binary(Vec), + Binary(Payload), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes - Ping(Vec), + Ping(Payload), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes - Pong(Vec), + Pong(Payload), /// A close message with the optional close frame. Close(Option>), /// Raw frame. Note, that you're not going to get this value while reading the message. @@ -175,7 +175,7 @@ impl Message { /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message where - S: Into, + S: Into, { Message::Text(string.into()) } @@ -183,7 +183,7 @@ impl Message { /// Create a new binary WebSocket message by converting to `Vec`. pub fn binary(bin: B) -> Message where - B: Into>, + B: Into, { Message::Binary(bin.into()) } @@ -232,26 +232,32 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Vec { + pub fn into_data(self) -> Payload { match self { - Message::Text(string) => string.into_bytes(), + Message::Text(string) => string.into(), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, - Message::Close(None) => Vec::new(), - Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), - Message::Frame(frame) => frame.into_data(), + Message::Close(None) => <_>::default(), + Message::Close(Some(frame)) => match frame.reason { + Cow::Borrowed(s) => Payload::from_static(s.as_bytes()), + Cow::Owned(s) => s.into(), + }, + Message::Frame(frame) => frame.into_payload(), } } /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + pub fn into_text(self) -> Result { match self { - Message::Text(string) => Ok(string), + Message::Text(txt) => Ok(txt), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { - Ok(String::from_utf8(data)?) + Ok(data.try_into()?) } - Message::Close(None) => Ok(String::new()), - Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), - Message::Frame(frame) => Ok(frame.into_string()?), + Message::Close(None) => Ok(<_>::default()), + Message::Close(Some(frame)) => Ok(match frame.reason { + Cow::Borrowed(s) => Utf8Payload::from_static(s), + Cow::Owned(s) => s.into(), + }), + Message::Frame(frame) => Ok(frame.into_text()?), } } @@ -259,9 +265,9 @@ impl Message { /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str> { match *self { - Message::Text(ref string) => Ok(string), + Message::Text(ref string) => Ok(string.as_str()), Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { - Ok(str::from_utf8(data)?) + Ok(str::from_utf8(data.as_slice())?) } Message::Close(None) => Ok(""), Message::Close(Some(ref frame)) => Ok(&frame.reason), @@ -271,40 +277,37 @@ impl Message { } impl From for Message { + #[inline] fn from(string: String) -> Self { Message::text(string) } } impl<'s> From<&'s str> for Message { + #[inline] fn from(string: &'s str) -> Self { Message::text(string) } } impl<'b> From<&'b [u8]> for Message { + #[inline] fn from(data: &'b [u8]) -> Self { Message::binary(data) } } impl From> for Message { + #[inline] fn from(data: Vec) -> Self { Message::binary(data) } } impl From for Vec { + #[inline] fn from(message: Message) -> Self { - message.into_data() - } -} - -impl TryFrom for String { - type Error = Error; - - fn try_from(value: Message) -> StdResult { - value.into_text() + message.into_data().as_slice().into() } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index fb1f7755..87b17156 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -13,7 +13,7 @@ use self::{ }, message::{IncompleteMessage, IncompleteMessageType}, }; -use crate::error::{Error, ProtocolError, Result}; +use crate::error::{CapacityError, Error, ProtocolError, Result}; use log::*; use std::{ io::{self, Read, Write}, @@ -439,7 +439,7 @@ impl WebSocketContext { } let frame = match message { - Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true), + Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true), Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true), Message::Ping(data) => Frame::ping(data), Message::Pong(data) => { @@ -603,14 +603,14 @@ impl WebSocketContext { Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { - let data = frame.into_data(); + let mut data = frame.into_payload(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.set_additional(Frame::pong(data.clone())); + self.set_additional(Frame::pong(data.share())); } Ok(Some(Message::Ping(data))) } - OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))), + OpCtl::Pong => Ok(Some(Message::Pong(frame.into_payload()))), } } @@ -619,7 +619,10 @@ impl WebSocketContext { match data { OpData::Continue => { if let Some(ref mut msg) = self.incomplete { - msg.extend(frame.into_data(), self.config.max_message_size)?; + msg.extend( + frame.into_payload().as_slice(), + self.config.max_message_size, + )?; } else { return Err(Error::Protocol( ProtocolError::UnexpectedContinueFrame, @@ -634,23 +637,27 @@ impl WebSocketContext { c if self.incomplete.is_some() => { Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) } + OpData::Text if fin => { + check_max_size(frame.payload().len(), self.config.max_message_size)?; + Ok(Some(Message::Text(frame.into_text()?))) + } + OpData::Binary if fin => { + check_max_size(frame.payload().len(), self.config.max_message_size)?; + Ok(Some(Message::Binary(frame.into_payload()))) + } OpData::Text | OpData::Binary => { - let msg = { - let message_type = match data { - OpData::Text => IncompleteMessageType::Text, - OpData::Binary => IncompleteMessageType::Binary, - _ => panic!("Bug: message is not text nor binary"), - }; - let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data(), self.config.max_message_size)?; - m + let message_type = match data { + OpData::Text => IncompleteMessageType::Text, + OpData::Binary => IncompleteMessageType::Binary, + _ => panic!("Bug: message is not text nor binary"), }; - if fin { - Ok(Some(msg.complete()?)) - } else { - self.incomplete = Some(msg); - Ok(None) - } + let mut incomplete = IncompleteMessage::new(message_type); + incomplete.extend( + frame.into_payload().as_slice(), + self.config.max_message_size, + )?; + self.incomplete = Some(incomplete); + Ok(None) } OpData::Reserved(i) => { Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) @@ -737,6 +744,15 @@ impl WebSocketContext { } } +fn check_max_size(size: usize, max_size: Option) -> crate::Result<()> { + if let Some(max_size) = max_size { + if size > max_size { + return Err(Error::Capacity(CapacityError::MessageTooLong { size, max_size })); + } + } + Ok(()) +} + /// The current connection state. #[derive(Debug, PartialEq, Eq, Clone, Copy)] enum WebSocketState { @@ -826,10 +842,10 @@ mod tests { 0x03, ]); let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None); - assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2])); - assert_eq!(socket.read().unwrap(), Message::Pong(vec![3])); + assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2].into())); + assert_eq!(socket.read().unwrap(), Message::Pong(vec![3].into())); assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into())); - assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03])); + assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03].into())); } #[test]