Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for shared websocket messages #104

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions examples/autobahn-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ fn main() {

for stream in server.incoming() {
spawn(move || match stream {
Ok(stream) => if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
Ok(stream) => {
if let Err(err) = handle_client(stream) {
match err {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8 => (),
e => error!("test: {}", e),
}
}
},
}
Err(e) => error!("Error accepting stream: {}", e),
});
}
Expand Down
4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::string;
use http;
use httparse;

use crate::protocol::Message;
use crate::protocol::EitherMessage;

#[cfg(feature = "tls")]
pub mod tls {
Expand Down Expand Up @@ -59,7 +59,7 @@ pub enum Error {
/// Protocol violation.
Protocol(Cow<'static, str>),
/// Message send queue full.
SendQueueFull(Message),
SendQueueFull(EitherMessage),
/// UTF coding error
Utf8,
/// Invalid URL.
Expand Down
18 changes: 12 additions & 6 deletions src/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,18 @@ fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
let mut req = Vec::new();
let uri = request.uri();

let authority = uri.authority()
let authority = uri
.authority()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?
.as_str();
let host = if let Some(idx) = authority.find('@') { // handle possible name:password@
let host = if let Some(idx) = authority.find('@') {
// handle possible name:password@
authority.split_at(idx + 1).1
} else {
authority
};
if authority.is_empty() {
return Err(Error::Url("URL contains empty host name".into()))
return Err(Error::Url("URL contains empty host name".into()));
}

write!(
Expand Down Expand Up @@ -261,8 +263,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::super::machine::TryParse;
use crate::client::IntoClientRequest;
use super::{generate_key, generate_request, Response};
use crate::client::IntoClientRequest;

#[test]
fn random_keys() {
Expand Down Expand Up @@ -299,7 +301,9 @@ mod tests {

#[test]
fn request_formatting_with_host() {
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Expand All @@ -316,7 +320,9 @@ mod tests {

#[test]
fn request_formatting_with_at() {
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
let request = "wss://user:pass@localhost:9001/getCaseCount"
.into_client_request()
.unwrap();
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
let correct = b"\
GET /getCaseCount HTTP/1.1\r\n\
Expand Down
140 changes: 113 additions & 27 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Bytes, BytesMut};
use log::*;
use std::borrow::Cow;
use std::convert::{TryFrom, TryInto};
use std::default::Default;
use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult;
use std::str;
use std::string::{FromUtf8Error, String};

use super::coding::{CloseCode, Control, Data, OpCode};
Expand Down Expand Up @@ -205,11 +208,83 @@ impl FrameHeader {
}
}

/// A binary payload that might or might not be shared.
#[derive(Debug, Clone)]
pub enum Payload {
Bytes(Vec<u8>),
ShBytes(Bytes),
}

impl Payload {
pub fn len(&self) -> usize {
match self {
Self::Bytes(bytes) => bytes.len(),
Self::ShBytes(bytes) => bytes.len(),
}
}

pub fn as_bytes(&self) -> &[u8] {
match self {
Self::Bytes(bytes) => bytes.as_slice(),
Self::ShBytes(bytes) => bytes.as_ref(),
}
}

pub fn unwrap_bytes(self) -> Vec<u8> {
match self {
Self::Bytes(bytes) => bytes,
_ => panic!("expected variant `Payload::Bytes`"),
}
}
}

impl TryFrom<Payload> for String {
type Error = FromUtf8Error;

fn try_from(payload: Payload) -> std::result::Result<Self, Self::Error> {
let vec = match payload {
Payload::Bytes(bytes) => bytes,
Payload::ShBytes(bytes) => bytes.as_ref().to_owned(),
};
String::from_utf8(vec)
}
}

impl From<Vec<u8>> for Payload {
fn from(bytes: Vec<u8>) -> Self {
Self::Bytes(bytes)
}
}

impl From<&[u8]> for Payload {
fn from(bytes: &[u8]) -> Self {
bytes.to_owned().into()
}
}

impl From<String> for Payload {
fn from(string: String) -> Self {
Self::Bytes(string.into())
}
}

impl From<&str> for Payload {
fn from(string: &str) -> Self {
string.to_owned().into()
}
}

impl From<Bytes> for Payload {
fn from(bytes: Bytes) -> Self {
Self::ShBytes(bytes)
}
}

/// A struct representing a WebSocket frame.
#[derive(Debug, Clone)]
pub struct Frame {
header: FrameHeader,
payload: Vec<u8>,
payload: Payload,
}

impl Frame {
Expand Down Expand Up @@ -241,14 +316,8 @@ impl Frame {

/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &Vec<u8> {
&self.payload
}

/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload
pub fn payload(&self) -> &[u8] {
self.payload.as_bytes()
}

/// Test whether the frame is masked.
Expand All @@ -271,20 +340,27 @@ impl Frame {
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask)
match &mut self.payload {
Payload::Bytes(bytes) => apply_mask(bytes, mask),
Payload::ShBytes(bytes) => {
let mut bytes_mut = BytesMut::from(bytes.as_ref());
apply_mask(&mut bytes_mut, mask);
*bytes = bytes_mut.freeze();
}
}
}
}

/// Consume the frame into its payload as binary.
#[inline]
pub fn into_data(self) -> Vec<u8> {
pub fn into_payload(self) -> Payload {
self.payload
}

/// Consume the frame into its payload as string.
#[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload)
self.payload.try_into()
}

/// Consume the frame into a closing frame.
Expand All @@ -294,10 +370,16 @@ impl Frame {
0 => Ok(None),
1 => Err(Error::Protocol("Invalid close sequence".into())),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
let data = self.payload;
let code = NetworkEndian::read_u16(&data.as_bytes()[0..2]).into();
let bytes = match data {
Payload::Bytes(mut bytes) => {
bytes.drain(0..2);
bytes
}
Payload::ShBytes(bytes) => bytes.as_ref()[2..].to_owned(),
};
let text = String::from_utf8(bytes)?;
Ok(Some(CloseFrame {
code,
reason: text.into(),
Expand All @@ -308,7 +390,10 @@ impl Frame {

/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
pub fn message<P>(payload: P, opcode: OpCode, is_final: bool) -> Frame
where
P: Into<Payload>,
{
debug_assert!(
match opcode {
OpCode::Data(_) => true,
Expand All @@ -323,7 +408,7 @@ impl Frame {
opcode,
..FrameHeader::default()
},
payload: data,
payload: payload.into(),
}
}

Expand All @@ -335,7 +420,7 @@ impl Frame {
opcode: OpCode::Control(Control::Pong),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}

Expand All @@ -347,7 +432,7 @@ impl Frame {
opcode: OpCode::Control(Control::Ping),
..FrameHeader::default()
},
payload: data,
payload: data.into(),
}
}

Expand All @@ -365,12 +450,12 @@ impl Frame {

Frame {
header: FrameHeader::default(),
payload,
payload: payload.into(),
}
}

/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
pub fn from_payload(header: FrameHeader, payload: Payload) -> Self {
Frame { header, payload }
}

Expand Down Expand Up @@ -405,6 +490,7 @@ payload: 0x{}
self.len(),
self.payload.len(),
self.payload
.as_bytes()
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
Expand Down Expand Up @@ -476,11 +562,11 @@ mod tests {
Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
let mut bytes = Vec::new();
raw.read_to_end(&mut bytes).unwrap();
let frame = Frame::from_payload(header, bytes.into());
assert_eq!(
frame.into_data(),
frame.into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
Expand All @@ -495,7 +581,7 @@ mod tests {

#[test]
fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
let f = Frame::message("hi there", OpCode::Data(Data::Text), true);
let view = format!("{}", f);
assert!(view.contains("payload:"));
}
Expand Down
8 changes: 4 additions & 4 deletions src/protocol/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl FrameCodec {

let (header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, payload);
let frame = Frame::from_payload(header, payload.into());
trace!("received frame {}", frame);
Ok(Some(frame))
}
Expand Down Expand Up @@ -228,11 +228,11 @@ mod tests {
let mut sock = FrameSocket::new(raw);

assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none());
Expand All @@ -246,7 +246,7 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
sock.read_frame(None).unwrap().unwrap().into_payload().unwrap_bytes(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
Expand Down
Loading