Skip to content

Commit

Permalink
Add backpressure on the streaming client.
Browse files Browse the repository at this point in the history
  • Loading branch information
bitbound committed Dec 24, 2024
1 parent 3d26580 commit c59ea7f
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 50 deletions.
5 changes: 2 additions & 3 deletions ControlR.Streamer/Services/DesktopCapturer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ private async Task EncodeScreenCaptures(CancellationToken stoppingToken)

if (captureResult.HadNoChanges)
{
_logger.LogDebug("DirectX output had no changes.");
await _delayer.Delay(_afterFailureDelay, stoppingToken);
continue;
}
Expand Down Expand Up @@ -385,14 +384,14 @@ private Task ProcessMetrics()
// Keep only frames in our sample window.
while (
_sentRegions.TryPeek(out var frame) &&
frame.Timestamp.AddSeconds(3) < _timeProvider.GetLocalNow())
frame.Timestamp.AddSeconds(20) < _timeProvider.GetLocalNow())
{
_sentRegions.TryDequeue(out _);
}

if (_sentRegions.Count >= 2)
{
var sampleSpan = _sentRegions.Last().Timestamp - _sentRegions.First().Timestamp;
var sampleSpan = _timeProvider.GetLocalNow() - _sentRegions.First().Timestamp;
_currentMbps = _sentRegions.Sum(x => x.Size) / 1024.0 / 1024.0 / sampleSpan.TotalSeconds * 8;
}
else if (_sentRegions.Count == 1)
Expand Down
3 changes: 2 additions & 1 deletion ControlR.Streamer/Services/StreamerStreamingClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ internal sealed class StreamerStreamingClient(
IClipboardManager clipboardManager,
IMemoryProvider memoryProvider,
IInputSimulator inputSimulator,
IDelayer delayer,
IOptions<StartupOptions> startupOptions,
ILogger<StreamerStreamingClient> logger)
: StreamingClient(messenger, memoryProvider, logger), IStreamerStreamingClient
: StreamingClient(messenger, memoryProvider, delayer, logger), IStreamerStreamingClient
{
private readonly IHostApplicationLifetime _appLifetime = appLifetime;
private readonly IClipboardManager _clipboardManager = clipboardManager;
Expand Down
13 changes: 8 additions & 5 deletions ControlR.Web.Client/Services/ViewerStreamingClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ public class ViewerStreamingClient(
IMessenger messenger,
IMemoryProvider memoryProvider,
IDelayer delayer,
ILogger<ViewerStreamingClient> logger,
ILogger<StreamingClient> baseLogger) : StreamingClient(messenger, memoryProvider, baseLogger), IViewerStreamingClient
ILogger<ViewerStreamingClient> logger)
: StreamingClient(messenger, memoryProvider, delayer, logger), IViewerStreamingClient
{
private readonly IDelayer _delayer = delayer;
private readonly ILogger<ViewerStreamingClient> _logger = logger;

public async Task RequestClipboardText(Guid sessionId, CancellationToken cancellationToken)
{
await TrySend(
Expand Down Expand Up @@ -168,13 +171,13 @@ private async Task TrySend(Func<Task> func, [CallerMemberName] string callerName
{
try
{
using var _ = logger.BeginScope(callerName);
using var _ = _logger.BeginScope(callerName);
await WaitForConnection();
await func.Invoke();
}
catch (Exception ex)
{
logger.LogError(ex, "Error while sending message via websocket stream..");
_logger.LogError(ex, "Error while sending message via websocket stream..");
}
}

Expand All @@ -188,7 +191,7 @@ private async Task WaitForConnection()
using var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromSeconds(30));

await delayer.WaitForAsync(
await _delayer.WaitForAsync(
condition: () => State == WebSocketState.Open || IsDisposed,
pollingDelay: TimeSpan.FromMilliseconds(100),
cancellationToken: cts.Token);
Expand Down
2 changes: 1 addition & 1 deletion ControlR.Web.Server/wwwroot/downloads/AgentVersion.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.76.0
0.11.78.0
183 changes: 143 additions & 40 deletions Libraries/ControlR.Libraries.Clients/Services/StreamingClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Net.WebSockets;
using ControlR.Libraries.Shared.Services;
using System.Net.WebSockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

Expand All @@ -16,14 +17,26 @@ public interface IStreamingClient : IAsyncDisposable, IClosable
public abstract class StreamingClient(
IMessenger messenger,
IMemoryProvider memoryProvider,
IDelayer delayer,
ILogger<StreamingClient> logger) : Closable(logger), IStreamingClient
{
protected readonly IMessenger Messenger = messenger;
private readonly int _maxSendBufferLength = ushort.MaxValue * 2;
private readonly CancellationTokenSource _clientDisposingCts = new();
private readonly IDelayer _delayer = delayer;
private readonly ILogger<StreamingClient> _logger = logger;
private readonly IMemoryProvider _memoryProvider = memoryProvider;
private readonly Guid _messageDelimiter = Guid.Parse("84da960a-54ec-47f5-a8b5-fa362221e8bf");
private readonly ConditionalWeakTable<object, Func<DtoWrapper, Task>> _messageHandlers = [];
private readonly SemaphoreSlim _sendLock = new(1);
protected readonly IMessenger Messenger = messenger;
private ClientWebSocket? _client;
private volatile int _sendBufferLength;

private enum MessageType : short
{
Dto,
Ack
}

public WebSocketState State => _client?.State ?? WebSocketState.Closed;

Expand Down Expand Up @@ -61,7 +74,7 @@ public async ValueTask DisposeAsync()
}
catch (Exception ex)
{
logger.LogError(ex, "Error while closing connection.");
_logger.LogError(ex, "Error while closing connection.");
}
finally
{
Expand All @@ -87,7 +100,39 @@ public IDisposable RegisterMessageHandler(object subscriber, Func<DtoWrapper, Ta

public async Task Send(DtoWrapper dto, CancellationToken cancellationToken)
{
await SendImpl(dto, cancellationToken);
if (!await WaitForSendBuffer(cancellationToken))
{
return;
}

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cts.Token);
await SendDto(dto, linkedCts.Token);
}

private async Task<bool> WaitForSendBuffer(CancellationToken cancellationToken)
{
if (_sendBufferLength < _maxSendBufferLength)
{
return true;
}

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cts.Token);

var waitResult = await _delayer.WaitForAsync(
() => _sendBufferLength < _maxSendBufferLength,
pollingDelay: TimeSpan.FromMilliseconds(25),
cancellationToken: linkedCts.Token);

if (waitResult)
{
return true;
}

_logger.LogError("Timed out while waiting for send buffer to drain.");
await DisposeAsync();
return false;
}

public async Task WaitForClose(CancellationToken cancellationToken)
Expand All @@ -99,15 +144,17 @@ private static MessageHeader GetHeader(byte[] buffer)
{
return new MessageHeader(
new Guid(buffer[..16]),
BitConverter.ToInt32(buffer.AsSpan()[16..20]));
(MessageType)BitConverter.ToInt16(buffer.AsSpan()[16..18]),
BitConverter.ToInt32(buffer.AsSpan()[18..22]));
}

private static byte[] GetHeaderBytes(MessageHeader header)
{
return
[
.. header.Delimiter.ToByteArray(),
.. BitConverter.GetBytes(header.DtoSize)
.. BitConverter.GetBytes((short)header.MessageType),
.. BitConverter.GetBytes(header.MessageSize)
];
}

Expand All @@ -119,6 +166,38 @@ private List<Func<DtoWrapper, Task>> GetMessageHandlers()
}
}

private void HandleAck(MessageHeader header)
{
_ = Interlocked.Add(ref _sendBufferLength, -header.MessageSize);
}

private async Task HandleDtoMessage(MessageHeader header, byte[] dtoBuffer)
{
using var dtoStream = _memoryProvider.GetRecyclableStream();

while (dtoStream.Position < header.MessageSize)
{
var result = await Client.ReceiveAsync(dtoBuffer, _clientDisposingCts.Token);

if (result.MessageType == WebSocketMessageType.Close ||
result.Count == 0)
{
_logger.LogWarning("Stream ended before DTO was complete.");
break;
}

await dtoStream.WriteAsync(dtoBuffer.AsMemory(0, result.Count));
await SendAck(result.Count);
}

dtoStream.Seek(0, SeekOrigin.Begin);

var dto = await MessagePackSerializer.DeserializeAsync<DtoWrapper>(dtoStream,
cancellationToken: _clientDisposingCts.Token);
await InvokeMessageHandlers(dto, _clientDisposingCts.Token);

}

private async Task InvokeMessageHandlers(DtoWrapper dto, CancellationToken cancellationToken)
{
var handlers = GetMessageHandlers();
Expand All @@ -136,7 +215,7 @@ private async Task InvokeMessageHandlers(DtoWrapper dto, CancellationToken cance
}
catch (Exception ex)
{
logger.LogError(ex, "Error while invoking message handler.");
_logger.LogError(ex, "Error while invoking message handler.");
}
}
}
Expand All @@ -154,74 +233,83 @@ private async Task ReadFromStream()

if (result.MessageType == WebSocketMessageType.Close)
{
logger.LogInformation("Websocket close message received.");
_logger.LogInformation("Websocket close message received.");
break;
}

var bytesRead = result.Count;

if (bytesRead < MessageHeader.Size)
{
logger.LogError("Failed to get DTO header.");
_logger.LogError("Failed to get DTO header.");
break;
}

var header = GetHeader(headerBuffer);

if (header.Delimiter != _messageDelimiter)
{
logger.LogCritical("Message header delimiter was incorrect. Value: {Delimiter}", header.Delimiter);
_logger.LogCritical("Message header delimiter was incorrect. Value: {Delimiter}", header.Delimiter);
break;
}

using var dtoStream = memoryProvider.GetRecyclableStream();

while (dtoStream.Position < header.DtoSize)
switch (header.MessageType)
{
result = await Client.ReceiveAsync(dtoBuffer, _clientDisposingCts.Token);

if (result.MessageType == WebSocketMessageType.Close ||
result.Count == 0)
{
logger.LogWarning("Stream ended before DTO was complete.");
case MessageType.Dto:
await SendAck(bytesRead);
await HandleDtoMessage(header, dtoBuffer);
break;
}

await dtoStream.WriteAsync(dtoBuffer.AsMemory(0, result.Count));
case MessageType.Ack:
HandleAck(header);
break;
default:
throw new InvalidOperationException($"Unknown message type: {header.MessageType}");
}

dtoStream.Seek(0, SeekOrigin.Begin);

var dto = await MessagePackSerializer.DeserializeAsync<DtoWrapper>(dtoStream,
cancellationToken: _clientDisposingCts.Token);
await InvokeMessageHandlers(dto, _clientDisposingCts.Token);
}
catch (OperationCanceledException)
{
logger.LogInformation("Streaming cancelled.");
_logger.LogInformation("Streaming cancelled.");
break;
}
catch (Exception ex)
{
logger.LogError(ex, "Error while reading from stream.");
_logger.LogError(ex, "Error while reading from stream.");
break;
}
}

await InvokeOnClosed();
}
private async Task SendAck(int receivedBytes)
{
await _sendLock.WaitAsync(_clientDisposingCts.Token);
try
{
var header = new MessageHeader(_messageDelimiter, MessageType.Ack, receivedBytes);
var headerBytes = GetHeaderBytes(header);
await Client.SendAsync(
headerBytes,
WebSocketMessageType.Binary,
true,
_clientDisposingCts.Token);
}
finally
{
_sendLock.Release();
}
}

private async Task SendImpl<T>(T dto, CancellationToken cancellationToken)
private async Task SendDto<T>(T dto, CancellationToken cancellationToken)
{
await _sendLock.WaitAsync(cancellationToken);
try
{
var payload = MessagePackSerializer.Serialize(dto, cancellationToken: cancellationToken);
var header = new MessageHeader(_messageDelimiter, payload.Length);

var header = new MessageHeader(_messageDelimiter, MessageType.Dto, payload.Length);
var headerBytes = GetHeaderBytes(header);

await Client.SendAsync(
GetHeaderBytes(header),
headerBytes,
WebSocketMessageType.Binary,
false,
cancellationToken);
Expand All @@ -231,6 +319,8 @@ await Client.SendAsync(
WebSocketMessageType.Binary,
true,
cancellationToken);

_ = Interlocked.Add(ref _sendBufferLength, headerBytes.Length + payload.Length);
}
finally
{
Expand All @@ -239,12 +329,25 @@ await Client.SendAsync(
}

[StructLayout(LayoutKind.Explicit)]
private struct MessageHeader(Guid delimiter, int messageSize)
private struct MessageHeader(Guid delimiter, MessageType messageType, int messageSize)
{
public const int Size = 20;

[FieldOffset(0)] public readonly Guid Delimiter = delimiter;

[FieldOffset(16)] public readonly int DtoSize = messageSize;
public const int Size = 22;

[FieldOffset(0)]
public readonly Guid Delimiter = delimiter;

[FieldOffset(16)]
public MessageType MessageType = messageType;

/// <summary>
/// <para>
/// For Dto message type, this will be the message size following the header.
/// </para>
/// <para>
/// For Ack message type, this will be the number of bytes received by the other client.
/// </para>
/// </summary>
[FieldOffset(18)]
public readonly int MessageSize = messageSize;
}
}

0 comments on commit c59ea7f

Please sign in to comment.