Skip to content

Commit

Permalink
yamux: Backport window auto tune
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandru Vasile <[email protected]>
  • Loading branch information
lexnv committed Sep 26, 2024
1 parent d50ec10 commit 066dae3
Show file tree
Hide file tree
Showing 7 changed files with 606 additions and 187 deletions.
157 changes: 61 additions & 96 deletions src/yamux/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@

mod cleanup;
mod closing;
pub(crate) mod flow_control;
pub(crate) mod rtt;
mod stream;

use crate::yamux::{
Expand Down Expand Up @@ -368,6 +370,14 @@ struct Active<T> {

pending_frames: VecDeque<Frame<()>>,
new_outbound_stream_waker: Option<Waker>,

rtt: rtt::Rtt,

/// A stream's `max_stream_receive_window` can grow beyond [`DEFAULT_CREDIT`], see
/// [`Stream::next_window_update`]. This field is the sum of the bytes by which all streams'
/// `max_stream_receive_window` have each exceeded [`DEFAULT_CREDIT`]. Used to enforce
/// [`Config::max_connection_receive_window`].
accumulated_max_stream_windows: Arc<Mutex<usize>>,
}

/// `Stream` to `Connection` commands.
Expand All @@ -385,13 +395,9 @@ enum Action {
/// Nothing to be done.
None,
/// A new stream has been opened by the remote.
New(Stream, Option<Frame<WindowUpdate>>),
/// A window update should be sent to the remote.
Update(Frame<WindowUpdate>),
New(Stream),
/// A ping should be answered.
Ping(Frame<Ping>),
/// A stream should be reset.
Reset(Frame<Data>),
/// The connection should be terminated.
Terminate(Frame<GoAway>),
}
Expand Down Expand Up @@ -424,7 +430,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
let id = Id::random();
tracing::debug!(target: LOG_TARGET, "new connection: {} ({:?})", id, mode);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
let socket = frame::Io::new(id, socket).fuse();
Active {
id,
mode,
Expand All @@ -439,6 +445,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
},
pending_frames: VecDeque::default(),
new_outbound_stream_waker: None,
rtt: rtt::Rtt::new(),
accumulated_max_stream_windows: Default::default(),
}
}

Expand All @@ -459,6 +467,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
if self.socket.poll_ready_unpin(cx).is_ready() {
// Note `next_ping` does not register a waker and thus if not called regularly (idle
// connection) no ping is sent. This is deliberate as an idle connection does not
// need RTT measurements to increase its stream receive window.
if let Some(frame) = self.rtt.next_ping() {
self.socket.start_send_unpin(frame.into())?;
continue;
}

if let Some(frame) = self.pending_frames.pop_front() {
self.socket.start_send_unpin(frame)?;
continue;
Expand Down Expand Up @@ -522,20 +538,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
tracing::trace!(target: LOG_TARGET, "{}: creating new outbound stream", self.id);

let id = self.next_stream_id()?;
let extra_credit = self.config.receive_window - DEFAULT_CREDIT;

if extra_credit > 0 {
let mut frame = Frame::window_update(id, extra_credit);
frame.header_mut().syn();
tracing::trace!(target: LOG_TARGET, "{}/{}: sending initial {}", self.id, id, frame.header());
self.pending_frames.push_back(frame.into());
}

let mut stream = self.make_new_outbound_stream(id, self.config.receive_window);

if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
let stream = self.make_new_outbound_stream(id);

tracing::debug!(target: LOG_TARGET, "{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone_shared());
Expand Down Expand Up @@ -584,23 +587,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if self.config.window_update_mode == WindowUpdateMode::OnRead
&& shared.window == 0
{
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
// The remote has either still credit or will be given more
// due to an enqueued window update or we already have
// inbound frames in the socket buffer which will be
// processed later. In any case we will reply with an RST in
// `Connection::on_data` because the stream will no longer
// be known.
None
}
// The stream was properly closed. We already have sent our FIN frame. The
// remote end has already done so in the past.
Expand Down Expand Up @@ -629,7 +622,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
tracing::trace!(target: LOG_TARGET, "{}: received: {}", self.id, frame.header());

if frame.header().flags().contains(header::ACK) {
if frame.header().flags().contains(header::ACK)
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
{
let id = frame.header().stream_id();
if let Some(stream) = self.streams.get(&id) {
stream.lock().update_state(self.id, id, State::Open { acknowledged: true });
Expand All @@ -647,26 +642,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
};
match action {
Action::None => {}
Action::New(stream, update) => {
Action::New(stream) => {
tracing::trace!(target: LOG_TARGET, "{}: new inbound {} of {}", self.id, stream, self);
if let Some(f) = update {
tracing::trace!(target: LOG_TARGET, "{}/{}: sending update", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
return Ok(Some(stream));
}
Action::Update(f) => {
tracing::trace!(target: LOG_TARGET, "{}: sending update: {:?}", self.id, f.header());
self.pending_frames.push_back(f.into());
}
Action::Ping(f) => {
tracing::trace!(target: LOG_TARGET, "{}/{}: pong", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
Action::Reset(f) => {
tracing::trace!(target: LOG_TARGET, "{}/{}: sending reset", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
Action::Terminate(f) => {
tracing::trace!(target: LOG_TARGET, "{}: sending term", self.id);
self.pending_frames.push_back(f.into());
Expand Down Expand Up @@ -718,35 +701,22 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
let mut window_update = None;
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
{
let mut shared = stream.shared();
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
shared.window = shared.window.saturating_sub(frame.body_len());
shared.consume_receive_window(frame.body_len());
shared.buffer.push(frame.into_body());

if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
if let Some(credit) = shared.next_window_update() {
shared.window += credit;
let mut frame = Frame::window_update(stream_id, credit);
frame.header_mut().ack();
window_update = Some(frame)
}
}
}
if window_update.is_none() {
stream.set_flag(stream::Flag::Ack)
}
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, window_update);
return Action::New(stream);
}

if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
if frame.body().len() > shared.window as usize {
if frame.body_len() > shared.receive_window() {
tracing::error!(target: LOG_TARGET,
"{}/{}: frame body larger than window of stream",
self.id,
Expand All @@ -757,29 +727,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
let max_buffer_size = self.config.max_buffer_size;
if shared.buffer.len() >= max_buffer_size {
tracing::error!(target: LOG_TARGET,
"{}/{}: buffer of stream grows beyond limit",
self.id,
stream_id
);
let mut header = Header::data(stream_id, 0);
header.rst();
return Action::Reset(Frame::new(header));
}
shared.window = shared.window.saturating_sub(frame.body_len());
shared.consume_receive_window(frame.body_len());
shared.buffer.push(frame.into_body());
if let Some(w) = shared.reader.take() {
w.wake()
}
if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
if let Some(credit) = shared.next_window_update() {
shared.window += credit;
let frame = Frame::window_update(stream_id, credit);
return Action::Update(frame);
}
}
} else {
tracing::trace!(target: LOG_TARGET,
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
Expand Down Expand Up @@ -835,19 +787,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_inbound_stream(stream_id, credit);
stream.set_flag(stream::Flag::Ack);
let stream = self.make_new_inbound_stream(stream_id, credit);

if is_finish {
stream.shared().update_state(self.id, stream_id, State::RecvClosed);
}
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, None);
return Action::New(stream);
}

if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.credit += frame.header().credit();
shared.increase_send_window_by(frame.header().credit());
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
Expand Down Expand Up @@ -876,8 +827,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn on_ping(&mut self, frame: &Frame<Ping>) -> Action {
let stream_id = frame.header().stream_id();
if frame.header().flags().contains(header::ACK) {
// pong
return Action::None;
return self.rtt.handle_pong(frame.nonce());
}
if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) {
let mut hdr = Header::ping(frame.header().nonce());
Expand Down Expand Up @@ -909,10 +859,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
waker.wake();
}

Stream::new_inbound(id, self.id, config, credit, sender)
Stream::new_inbound(
id,
self.id,
config,
credit,
sender,
self.rtt.clone(),
self.accumulated_max_stream_windows.clone(),
)
}

fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream {
fn make_new_outbound_stream(&mut self, id: StreamId) -> Stream {
let config = self.config.clone();

let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
Expand All @@ -921,7 +879,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
waker.wake();
}

Stream::new_outbound(id, self.id, config, window, sender)
Stream::new_outbound(
id,
self.id,
config,
sender,
self.rtt.clone(),
self.accumulated_max_stream_windows.clone(),
)
}

fn next_stream_id(&mut self) -> Result<StreamId> {
Expand Down
Loading

0 comments on commit 066dae3

Please sign in to comment.