Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support bytes::Bytes & bytes::BytesMut payloads for binary & text messaging #465

Merged
merged 13 commits into from
Dec 14, 2024
Merged
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ __rustls-tls = ["rustls", "rustls-pki-types"]
[dependencies]
data-encoding = { version = "2", optional = true }
byteorder = "1.3.2"
bytes = "1.0"
bytes = "1.9.0"
http = { version = "1.0", optional = true }
httparse = { version = "1.3.4", optional = true }
log = "0.4.8"
Expand Down
4 changes: 2 additions & 2 deletions benches/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn benchmark(c: &mut Criterion) {
writer
.send(match i {
_ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()),
_ => Message::Text(format!("{{\"id\":{i}}}")),
_ => Message::text(format!("{{\"id\":{i}}}")),
})
.unwrap();
sum += i;
Expand All @@ -68,7 +68,7 @@ fn benchmark(c: &mut Criterion) {
sum += u64::from_le_bytes(*a);
}
Message::Text(msg) => {
let i: u64 = msg[6..msg.len() - 1].parse().unwrap();
let i: u64 = msg.as_str()[6..msg.len() - 1].parse().unwrap();
sum += i;
}
m => panic!("Unexpected {m}"),
Expand Down
9 changes: 4 additions & 5 deletions benches/write.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Benchmarks for write performance.
use criterion::Criterion;
use std::{
hint,
io::{self, Read, Write},
hint, io,
time::{Duration, Instant},
};
use tungstenite::{protocol::Role, Message, WebSocket};
Expand All @@ -16,12 +15,12 @@ const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);

impl Read for MockWrite {
impl io::Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
impl io::Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
Expand Down Expand Up @@ -58,7 +57,7 @@ fn benchmark(c: &mut Criterion) {
for i in 0_u64..100_000 {
let msg = match i {
_ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()),
_ => Message::Text(format!("{{\"id\":{i}}}")),
_ => Message::text(format!("{{\"id\":{i}}}")),
};
ws.write(msg).unwrap();
}
Expand Down
2 changes: 1 addition & 1 deletion examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect("ws://localhost:9001/getCaseCount")?;
let msg = socket.read()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
Ok(msg.into_text()?.as_str().parse::<u32>().unwrap())
}

fn update_reports() -> Result<()> {
Expand Down
78 changes: 43 additions & 35 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
use byteorder::{NetworkEndian, ReadBytesExt};
use log::*;
use std::{
borrow::Cow,
default::Default,
fmt,
io::{Cursor, ErrorKind, Read, Write},
result::Result as StdResult,
str::Utf8Error,
string::{FromUtf8Error, String},
string::String,
};

use byteorder::{NetworkEndian, ReadBytesExt};
use log::*;

use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
Payload,
};
use crate::{
error::{Error, ProtocolError, Result},
protocol::frame::Utf8Payload,
};
use crate::error::{Error, ProtocolError, Result};
use bytes::{Buf, BytesMut};

/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -207,7 +213,7 @@ impl FrameHeader {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
header: FrameHeader,
payload: Vec<u8>,
payload: Payload,
}

impl Frame {
Expand Down Expand Up @@ -239,14 +245,14 @@ impl Frame {

/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &Vec<u8> {
&self.payload
pub fn payload(&self) -> &[u8] {
self.payload.as_slice()
}

/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload
pub fn payload_mut(&mut self) -> &mut [u8] {
self.payload.as_mut_slice()
}

/// Test whether the frame is masked.
Expand All @@ -269,26 +275,26 @@ impl Frame {
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask);
apply_mask(self.payload.as_mut_slice(), mask);
}
}

/// Consume the frame into its payload as binary.
/// Consume the frame into its payload as string.
#[inline]
pub fn into_data(self) -> Vec<u8> {
self.payload
pub fn into_text(self) -> StdResult<Utf8Payload, Utf8Error> {
self.payload.into_text()
}

/// Consume the frame into its payload as string.
/// Consume the frame into its payload.
#[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload)
pub fn into_payload(self) -> Payload {
self.payload
}

/// Get frame payload as `&str`.
#[inline]
pub fn to_text(&self) -> Result<&str, Utf8Error> {
std::str::from_utf8(&self.payload)
std::str::from_utf8(self.payload.as_slice())
}

/// Consume the frame into a closing frame.
Expand All @@ -298,64 +304,66 @@ impl Frame {
0 => Ok(None),
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => {
let mut data = self.payload;
let mut data = self.payload.as_slice();
let code = u16::from_be_bytes([data[0], data[1]]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
data.advance(2);
let text = String::from_utf8(data.to_vec())?;
Ok(Some(CloseFrame { code, reason: text.into() }))
}
}
}

/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
pub fn message(data: impl Into<Payload>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");

Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
Frame {
header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
payload: data.into(),
}
}

/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {
pub fn pong(data: impl Into<Payload>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}

/// Create a new Ping control frame.
#[inline]
pub fn ping(data: Vec<u8>) -> Frame {
pub fn ping(data: impl Into<Payload>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}

/// Create a new Close control frame.
#[inline]
pub fn close(msg: Option<CloseFrame>) -> Frame {
let payload = if let Some(CloseFrame { code, reason }) = msg {
let mut p = Vec::with_capacity(reason.len() + 2);
let mut p = BytesMut::with_capacity(reason.len() + 2);
p.extend(u16::from(code).to_be_bytes());
p.extend_from_slice(reason.as_bytes());
p
} else {
Vec::new()
<_>::default()
};

Frame { header: FrameHeader::default(), payload }
Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) }
}

/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
pub fn from_payload(header: FrameHeader, payload: Payload) -> Self {
Frame { header, payload }
}

Expand Down Expand Up @@ -391,7 +399,7 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().fold(String::new(), |mut output, byte| {
self.payload.as_slice().iter().fold(String::new(), |mut output, byte| {
_ = write!(output, "{byte:02x}");
output
})
Expand Down Expand Up @@ -465,8 +473,8 @@ mod tests {
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let frame = Frame::from_payload(header, payload.into());
assert_eq!(frame.into_payload().as_slice(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
}

#[test]
Expand All @@ -479,7 +487,7 @@ mod tests {

#[test]
fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true);
let view = format!("{f}");
assert!(view.contains("payload:"));
}
Expand Down
Loading
Loading