Skip to content

Commit

Permalink
peer_connection: Make sure all packets are read through interceptor (#…
Browse files Browse the repository at this point in the history
…648)

Instead of reading first packet when probing simulcast directly from RTP
stream, rather let SRTP session accept() fn return the RTP header for
the first packet. This makes us able to configure interceptor and let
all packets travel through the regular path.

Fixes #391
  • Loading branch information
haaspors authored Jan 26, 2025
1 parent 2a5393d commit 0aa7c07
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 109 deletions.
29 changes: 16 additions & 13 deletions srtp/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
pub struct Session {
local_context: Arc<Mutex<Context>>,
streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
new_stream_rx: Arc<Mutex<mpsc::Receiver<(Arc<Stream>, Option<rtp::header::Header>)>>>,
close_stream_tx: mpsc::Sender<u32>,
close_session_tx: mpsc::Sender<()>,
pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
Expand Down Expand Up @@ -128,7 +128,7 @@ impl Session {
buf: &mut [u8],
streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
close_stream_tx: &mpsc::Sender<u32>,
new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
new_stream_tx: &mut mpsc::Sender<(Arc<Stream>, Option<rtp::header::Header>)>,
remote_context: &mut Context,
is_rtp: bool,
) -> Result<()> {
Expand All @@ -144,24 +144,28 @@ impl Session {
};

let mut buf = &decrypted[..];
let ssrcs = if is_rtp {
vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
let (ssrcs, header) = if is_rtp {
let header = rtp::header::Header::unmarshal(&mut buf)?;
(vec![header.ssrc], Some(header))
} else {
let pkts = rtcp::packet::unmarshal(&mut buf)?;
destination_ssrc(&pkts)
(destination_ssrc(&pkts), None)
};

for ssrc in ssrcs {
let (stream, is_new) =
Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
.await;

if is_new {
log::trace!(
"srtp session got new {} stream {}",
if is_rtp { "rtp" } else { "rtcp" },
ssrc
);
new_stream_tx.send(Arc::clone(&stream)).await?;
new_stream_tx
.send((Arc::clone(&stream), header.clone()))
.await?;
}

match stream.buffer.write(&decrypted).await {
Expand Down Expand Up @@ -210,14 +214,13 @@ impl Session {
}

/// accept returns a stream to handle RTCP for a single SSRC
pub async fn accept(&self) -> Result<Arc<Stream>> {
pub async fn accept(&self) -> Result<(Arc<Stream>, Option<rtp::header::Header>)> {
let mut new_stream_rx = self.new_stream_rx.lock().await;
let result = new_stream_rx.recv().await;
if let Some(stream) = result {
Ok(stream)
} else {
Err(Error::SessionSrtpAlreadyClosed)
}

new_stream_rx
.recv()
.await
.ok_or(Error::SessionSrtpAlreadyClosed)
}

pub async fn close(&self) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion srtp/src/session/session_rtcp_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async fn test_session_srtcp_accept() -> Result<()> {
let test_payload = rtcp_packet.marshal()?;
sa.write_rtcp(&rtcp_packet).await?;

let read_stream = sb.accept().await?;
let (read_stream, _) = sb.accept().await?;
let ssrc = read_stream.get_ssrc();
assert_eq!(
ssrc, TEST_SSRC,
Expand Down
5 changes: 4 additions & 1 deletion srtp/src/session/session_rtp_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,22 @@ async fn test_session_srtp_accept() -> Result<()> {

let packet = rtp::packet::Packet {
header: rtp::header::Header {
version: 2,
ssrc: TEST_SSRC,
payload_type: 96,
..Default::default()
},
payload: test_payload.clone(),
};
sa.write_rtp(&packet).await?;

let read_stream = sb.accept().await?;
let (read_stream, header) = sb.accept().await?;
let ssrc = read_stream.get_ssrc();
assert_eq!(
ssrc, TEST_SSRC,
"SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})"
);
assert_eq!(header, Some(packet.header));

read_stream.read(&mut read_buffer).await?;

Expand Down
3 changes: 1 addition & 2 deletions webrtc/src/peer_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver;
use crate::rtp_transceiver::rtp_sender::RTCRtpSender;
use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection;
use crate::rtp_transceiver::{
find_by_mid, handle_unknown_rtp_packet, satisfy_type_and_direction, RTCRtpTransceiver,
RTCRtpTransceiverInit, SSRC,
find_by_mid, satisfy_type_and_direction, RTCRtpTransceiver, RTCRtpTransceiverInit, SSRC,
};
use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities;
use crate::sctp_transport::sctp_transport_state::RTCSctpTransportState;
Expand Down
124 changes: 71 additions & 53 deletions webrtc/src/peer_connection/peer_connection_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::sync::Weak;

use super::*;
use crate::rtp_transceiver::create_stream_info;
use crate::rtp_transceiver::{create_stream_info, PayloadType};
use crate::stats::stats_collector::StatsCollector;
use crate::stats::{
InboundRTPStats, OutboundRTPStats, RTCStatsType, RemoteInboundRTPStats, RemoteOutboundRTPStats,
Expand All @@ -15,7 +15,6 @@ use arc_swap::ArcSwapOption;
use portable_atomic::AtomicIsize;
use smol_str::SmolStr;
use tokio::time::Instant;
use util::Unmarshal;

pub(crate) struct PeerConnectionInternal {
/// a value containing the last known greater mid value
Expand Down Expand Up @@ -309,8 +308,12 @@ impl PeerConnectionInternal {
}
};

let stream = match srtp_session.accept().await {
Ok(stream) => stream,
let (stream, header) = match srtp_session.accept().await {
Ok((stream, Some(header))) => (stream, header),
Ok((_, None)) => {
log::error!("Accepting RTP session, without RTP header?");
return;
}
Err(err) => {
log::warn!("Failed to accept RTP {}", err);
return;
Expand Down Expand Up @@ -338,16 +341,16 @@ impl PeerConnectionInternal {
let pci = Arc::clone(&pci);
tokio::spawn(async move {
let ssrc = stream.get_ssrc();

dtls_transport
.store_simulcast_stream(ssrc, Arc::clone(&stream))
.await;

if let Err(err) = pci.handle_incoming_ssrc(stream, ssrc).await {
if let Err(err) = pci
.handle_incoming_rtp_stream(stream, header.payload_type)
.await
{
log::warn!(
"Incoming unhandled RTP ssrc({}), on_track will not be fired. {}",
ssrc,
err
"Incoming unhandled RTP ssrc({ssrc}), on_track will not be fired. {err}"
);
}

Expand All @@ -370,17 +373,18 @@ impl PeerConnectionInternal {
}
};

let stream = match srtcp_session.accept().await {
Ok(stream) => stream,
match srtcp_session.accept().await {
Ok((stream, _)) => {
let ssrc = stream.get_ssrc();
log::warn!(
"Incoming unhandled RTCP ssrc({ssrc}), on_track will not be fired"
);
}
Err(err) => {
log::warn!("Failed to accept RTCP {}", err);
log::warn!("Failed to accept RTCP {err}");
return;
}
};
log::warn!(
"Incoming unhandled RTCP ssrc({}), on_track will not be fired",
stream.get_ssrc()
);
}
});
}
Expand Down Expand Up @@ -1002,18 +1006,18 @@ impl PeerConnectionInternal {
Ok(true)
}

async fn handle_incoming_ssrc(
async fn handle_incoming_rtp_stream(
self: &Arc<Self>,
rtp_stream: Arc<Stream>,
ssrc: SSRC,
payload_type: PayloadType,
) -> Result<()> {
let ssrc = rtp_stream.get_ssrc();
let parsed = match self.remote_description().await.and_then(|rd| rd.parsed) {
Some(r) => r,
None => return Err(Error::ErrPeerConnRemoteDescriptionNil),
};
// If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared
let handled = self.handle_undeclared_ssrc(ssrc, &parsed).await?;
if handled {
if self.handle_undeclared_ssrc(ssrc, &parsed).await? {
return Ok(());
}

Expand Down Expand Up @@ -1046,26 +1050,6 @@ impl PeerConnectionInternal {
})
.await;

// Packets that we read as part of simulcast probing that we need to make available
// if we do find a track later.
let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default();

let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()];
let n = rtp_stream.read(&mut buf).await?;
let mut b = &buf[..n];

let (mut mid, mut rid, mut rsid, payload_type) = handle_unknown_rtp_packet(
b,
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;

let packet = rtp::packet::Packet::unmarshal(&mut b).unwrap();

// TODO: Can we have attributes on the first packets?
buffered_packets.push_back((packet, Attributes::new()));

let params = self
.media_engine
.get_rtp_parameters_by_payload_type(payload_type)
Expand All @@ -1089,21 +1073,24 @@ impl PeerConnectionInternal {
.streams_for_ssrc(ssrc, &stream_info, &icpr)
.await?;

let a = Attributes::new();
// Packets that we read as part of simulcast probing that we need to make available
// if we do find a track later.
let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default();
let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()];

for _ in 0..=SIMULCAST_PROBE_COUNT {
let (pkt, a) = rtp_interceptor
.read(&mut buf, &stream_info.attributes)
.await?;
let (mid, rid, rsid) = get_stream_mid_rid(
&pkt.header,
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;
buffered_packets.push_back((pkt, a.clone()));

if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) {
let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?;
let (m, r, rs, _) = handle_unknown_rtp_packet(
&buf[..n],
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;
mid = m;
rid = r;
rsid = rs;

buffered_packets.push_back((pkt, a.clone()));
continue;
}

Expand Down Expand Up @@ -1544,3 +1531,34 @@ fn capitalize(s: &str) -> String {

result
}

fn get_stream_mid_rid(
header: &rtp::header::Header,
mid_extension_id: u8,
sid_extension_id: u8,
rsid_extension_id: u8,
) -> Result<(String, String, String)> {
if !header.extension {
return Ok((String::new(), String::new(), String::new()));
}

let mid = if let Some(payload) = header.get_extension(mid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let rid = if let Some(payload) = header.get_extension(sid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let srid = if let Some(payload) = header.get_extension(rsid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

Ok((mid, rid, srid))
}
39 changes: 0 additions & 39 deletions webrtc/src/rtp_transceiver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use portable_atomic::{AtomicBool, AtomicU8};
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::sync::{Mutex, OnceCell};
use util::Unmarshal;

use crate::api::media_engine::MediaEngine;
use crate::error::{Error, Result};
Expand Down Expand Up @@ -523,41 +522,3 @@ pub(crate) async fn satisfy_type_and_direction(

None
}

/// handle_unknown_rtp_packet consumes a single RTP Packet and returns information that is helpful
/// for demuxing and handling an unknown SSRC (usually for Simulcast)
pub(crate) fn handle_unknown_rtp_packet(
buf: &[u8],
mid_extension_id: u8,
sid_extension_id: u8,
rsid_extension_id: u8,
) -> Result<(String, String, String, PayloadType)> {
let mut reader = buf;
let rp = rtp::packet::Packet::unmarshal(&mut reader)?;

if !rp.header.extension {
return Ok((String::new(), String::new(), String::new(), 0));
}

let payload_type = rp.header.payload_type;

let mid = if let Some(payload) = rp.header.get_extension(mid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let rid = if let Some(payload) = rp.header.get_extension(sid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let srid = if let Some(payload) = rp.header.get_extension(rsid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

Ok((mid, rid, srid, payload_type))
}

0 comments on commit 0aa7c07

Please sign in to comment.