Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

peer_connection: Make sure all packets are read through interceptor #648

Merged
merged 1 commit into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions srtp/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
pub struct Session {
local_context: Arc<Mutex<Context>>,
streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
new_stream_rx: Arc<Mutex<mpsc::Receiver<(Arc<Stream>, Option<rtp::header::Header>)>>>,
close_stream_tx: mpsc::Sender<u32>,
close_session_tx: mpsc::Sender<()>,
pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
Expand Down Expand Up @@ -128,7 +128,7 @@ impl Session {
buf: &mut [u8],
streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
close_stream_tx: &mpsc::Sender<u32>,
new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
new_stream_tx: &mut mpsc::Sender<(Arc<Stream>, Option<rtp::header::Header>)>,
remote_context: &mut Context,
is_rtp: bool,
) -> Result<()> {
Expand All @@ -144,24 +144,28 @@ impl Session {
};

let mut buf = &decrypted[..];
let ssrcs = if is_rtp {
vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
let (ssrcs, header) = if is_rtp {
let header = rtp::header::Header::unmarshal(&mut buf)?;
(vec![header.ssrc], Some(header))
} else {
let pkts = rtcp::packet::unmarshal(&mut buf)?;
destination_ssrc(&pkts)
(destination_ssrc(&pkts), None)
};

for ssrc in ssrcs {
let (stream, is_new) =
Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
.await;

if is_new {
log::trace!(
"srtp session got new {} stream {}",
if is_rtp { "rtp" } else { "rtcp" },
ssrc
);
new_stream_tx.send(Arc::clone(&stream)).await?;
new_stream_tx
.send((Arc::clone(&stream), header.clone()))
.await?;
}

match stream.buffer.write(&decrypted).await {
Expand Down Expand Up @@ -210,14 +214,13 @@ impl Session {
}

/// accept returns a stream to handle RTCP for a single SSRC
pub async fn accept(&self) -> Result<Arc<Stream>> {
pub async fn accept(&self) -> Result<(Arc<Stream>, Option<rtp::header::Header>)> {
let mut new_stream_rx = self.new_stream_rx.lock().await;
let result = new_stream_rx.recv().await;
if let Some(stream) = result {
Ok(stream)
} else {
Err(Error::SessionSrtpAlreadyClosed)
}

new_stream_rx
.recv()
.await
.ok_or(Error::SessionSrtpAlreadyClosed)
}

pub async fn close(&self) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion srtp/src/session/session_rtcp_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async fn test_session_srtcp_accept() -> Result<()> {
let test_payload = rtcp_packet.marshal()?;
sa.write_rtcp(&rtcp_packet).await?;

let read_stream = sb.accept().await?;
let (read_stream, _) = sb.accept().await?;
let ssrc = read_stream.get_ssrc();
assert_eq!(
ssrc, TEST_SSRC,
Expand Down
5 changes: 4 additions & 1 deletion srtp/src/session/session_rtp_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,22 @@ async fn test_session_srtp_accept() -> Result<()> {

let packet = rtp::packet::Packet {
header: rtp::header::Header {
version: 2,
ssrc: TEST_SSRC,
payload_type: 96,
..Default::default()
},
payload: test_payload.clone(),
};
sa.write_rtp(&packet).await?;

let read_stream = sb.accept().await?;
let (read_stream, header) = sb.accept().await?;
let ssrc = read_stream.get_ssrc();
assert_eq!(
ssrc, TEST_SSRC,
"SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})"
);
assert_eq!(header, Some(packet.header));

read_stream.read(&mut read_buffer).await?;

Expand Down
3 changes: 1 addition & 2 deletions webrtc/src/peer_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver;
use crate::rtp_transceiver::rtp_sender::RTCRtpSender;
use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection;
use crate::rtp_transceiver::{
find_by_mid, handle_unknown_rtp_packet, satisfy_type_and_direction, RTCRtpTransceiver,
RTCRtpTransceiverInit, SSRC,
find_by_mid, satisfy_type_and_direction, RTCRtpTransceiver, RTCRtpTransceiverInit, SSRC,
};
use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities;
use crate::sctp_transport::sctp_transport_state::RTCSctpTransportState;
Expand Down
124 changes: 71 additions & 53 deletions webrtc/src/peer_connection/peer_connection_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::VecDeque;
use std::sync::Weak;

use super::*;
use crate::rtp_transceiver::create_stream_info;
use crate::rtp_transceiver::{create_stream_info, PayloadType};
use crate::stats::stats_collector::StatsCollector;
use crate::stats::{
InboundRTPStats, OutboundRTPStats, RTCStatsType, RemoteInboundRTPStats, RemoteOutboundRTPStats,
Expand All @@ -15,7 +15,6 @@ use arc_swap::ArcSwapOption;
use portable_atomic::AtomicIsize;
use smol_str::SmolStr;
use tokio::time::Instant;
use util::Unmarshal;

pub(crate) struct PeerConnectionInternal {
/// a value containing the last known greater mid value
Expand Down Expand Up @@ -309,8 +308,12 @@ impl PeerConnectionInternal {
}
};

let stream = match srtp_session.accept().await {
Ok(stream) => stream,
let (stream, header) = match srtp_session.accept().await {
Ok((stream, Some(header))) => (stream, header),
Ok((_, None)) => {
log::error!("Accepting RTP session, without RTP header?");
return;
}
Err(err) => {
log::warn!("Failed to accept RTP {}", err);
return;
Expand Down Expand Up @@ -338,16 +341,16 @@ impl PeerConnectionInternal {
let pci = Arc::clone(&pci);
tokio::spawn(async move {
let ssrc = stream.get_ssrc();

dtls_transport
.store_simulcast_stream(ssrc, Arc::clone(&stream))
.await;

if let Err(err) = pci.handle_incoming_ssrc(stream, ssrc).await {
if let Err(err) = pci
.handle_incoming_rtp_stream(stream, header.payload_type)
.await
{
log::warn!(
"Incoming unhandled RTP ssrc({}), on_track will not be fired. {}",
ssrc,
err
"Incoming unhandled RTP ssrc({ssrc}), on_track will not be fired. {err}"
);
}

Expand All @@ -370,17 +373,18 @@ impl PeerConnectionInternal {
}
};

let stream = match srtcp_session.accept().await {
Ok(stream) => stream,
match srtcp_session.accept().await {
Ok((stream, _)) => {
let ssrc = stream.get_ssrc();
log::warn!(
"Incoming unhandled RTCP ssrc({ssrc}), on_track will not be fired"
);
}
Err(err) => {
log::warn!("Failed to accept RTCP {}", err);
log::warn!("Failed to accept RTCP {err}");
return;
}
};
log::warn!(
"Incoming unhandled RTCP ssrc({}), on_track will not be fired",
stream.get_ssrc()
);
}
});
}
Expand Down Expand Up @@ -1002,18 +1006,18 @@ impl PeerConnectionInternal {
Ok(true)
}

async fn handle_incoming_ssrc(
async fn handle_incoming_rtp_stream(
self: &Arc<Self>,
rtp_stream: Arc<Stream>,
ssrc: SSRC,
payload_type: PayloadType,
) -> Result<()> {
let ssrc = rtp_stream.get_ssrc();
let parsed = match self.remote_description().await.and_then(|rd| rd.parsed) {
Some(r) => r,
None => return Err(Error::ErrPeerConnRemoteDescriptionNil),
};
// If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared
let handled = self.handle_undeclared_ssrc(ssrc, &parsed).await?;
if handled {
if self.handle_undeclared_ssrc(ssrc, &parsed).await? {
return Ok(());
}

Expand Down Expand Up @@ -1046,26 +1050,6 @@ impl PeerConnectionInternal {
})
.await;

// Packets that we read as part of simulcast probing that we need to make available
// if we do find a track later.
let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default();

let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()];
let n = rtp_stream.read(&mut buf).await?;
let mut b = &buf[..n];

let (mut mid, mut rid, mut rsid, payload_type) = handle_unknown_rtp_packet(
b,
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;

let packet = rtp::packet::Packet::unmarshal(&mut b).unwrap();

// TODO: Can we have attributes on the first packets?
buffered_packets.push_back((packet, Attributes::new()));

let params = self
.media_engine
.get_rtp_parameters_by_payload_type(payload_type)
Expand All @@ -1089,21 +1073,24 @@ impl PeerConnectionInternal {
.streams_for_ssrc(ssrc, &stream_info, &icpr)
.await?;

let a = Attributes::new();
// Packets that we read as part of simulcast probing that we need to make available
// if we do find a track later.
let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default();
let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()];

for _ in 0..=SIMULCAST_PROBE_COUNT {
let (pkt, a) = rtp_interceptor
.read(&mut buf, &stream_info.attributes)
.await?;
let (mid, rid, rsid) = get_stream_mid_rid(
&pkt.header,
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;
buffered_packets.push_back((pkt, a.clone()));

if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) {
let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?;
let (m, r, rs, _) = handle_unknown_rtp_packet(
&buf[..n],
mid_extension_id as u8,
sid_extension_id as u8,
rsid_extension_id as u8,
)?;
mid = m;
rid = r;
rsid = rs;

buffered_packets.push_back((pkt, a.clone()));
continue;
}

Expand Down Expand Up @@ -1544,3 +1531,34 @@ fn capitalize(s: &str) -> String {

result
}

fn get_stream_mid_rid(
header: &rtp::header::Header,
mid_extension_id: u8,
sid_extension_id: u8,
rsid_extension_id: u8,
) -> Result<(String, String, String)> {
if !header.extension {
return Ok((String::new(), String::new(), String::new()));
}

let mid = if let Some(payload) = header.get_extension(mid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let rid = if let Some(payload) = header.get_extension(sid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let srid = if let Some(payload) = header.get_extension(rsid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

Ok((mid, rid, srid))
}
39 changes: 0 additions & 39 deletions webrtc/src/rtp_transceiver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use portable_atomic::{AtomicBool, AtomicU8};
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::sync::{Mutex, OnceCell};
use util::Unmarshal;

use crate::api::media_engine::MediaEngine;
use crate::error::{Error, Result};
Expand Down Expand Up @@ -523,41 +522,3 @@ pub(crate) async fn satisfy_type_and_direction(

None
}

/// handle_unknown_rtp_packet consumes a single RTP Packet and returns information that is helpful
/// for demuxing and handling an unknown SSRC (usually for Simulcast)
pub(crate) fn handle_unknown_rtp_packet(
buf: &[u8],
mid_extension_id: u8,
sid_extension_id: u8,
rsid_extension_id: u8,
) -> Result<(String, String, String, PayloadType)> {
let mut reader = buf;
let rp = rtp::packet::Packet::unmarshal(&mut reader)?;

if !rp.header.extension {
return Ok((String::new(), String::new(), String::new(), 0));
}

let payload_type = rp.header.payload_type;

let mid = if let Some(payload) = rp.header.get_extension(mid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let rid = if let Some(payload) = rp.header.get_extension(sid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

let srid = if let Some(payload) = rp.header.get_extension(rsid_extension_id) {
String::from_utf8(payload.to_vec())?
} else {
String::new()
};

Ok((mid, rid, srid, payload_type))
}
Loading