From 6aeb3425b01301cba57a659585b9ba780fa226e1 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Thu, 21 Feb 2019 14:53:35 -0800 Subject: [PATCH] Move to new Track API See v2.0.0 Release Notes[0] for all changes Resolves #405 [0] https://github.com/pions/webrtc/wiki/v2.0.0-Release-Notes#media-api --- datachannel_test.go | 4 +- examples/gstreamer-receive/main.go | 17 +- examples/gstreamer-send-offer/main.go | 9 +- examples/gstreamer-send/main.go | 9 +- examples/internal/gstreamer-src/gst.go | 11 +- examples/janus-gateway/streaming/main.go | 12 +- examples/janus-gateway/video-room/main.go | 9 +- examples/save-to-disk/main.go | 15 +- examples/sfu/main.go | 54 ++--- lossy_stream.go | 93 +++++++++ peerconnection.go | 171 ++++++++-------- peerconnection_media_test.go | 73 +++++-- peerconnection_test.go | 59 +----- pkg/ice/ice_test.go | 51 +++++ rtpreceiver.go | 233 +++++++++------------- rtpsender.go | 200 +++++++++---------- rtptranceiver.go | 6 +- track.go | 218 +++++++++++++++----- 18 files changed, 711 insertions(+), 533 deletions(-) create mode 100644 lossy_stream.go create mode 100644 pkg/ice/ice_test.go diff --git a/datachannel_test.go b/datachannel_test.go index 3145668c6c0..f6e24e0afbd 100644 --- a/datachannel_test.go +++ b/datachannel_test.go @@ -156,8 +156,10 @@ func TestDataChannel_MessagesAreOrdered(t *testing.T) { out := make(chan int) inner := func(msg DataChannelMessage) { // randomly sleep - // NB: The big.Int/crypto.Rand is overkill but makes the linter happy + // math/rand a weak RNG, but this does not need to be secure. Ignore with #nosec + /* #nosec */ randInt, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + /* #nosec */ if err != nil { t.Fatalf("Failed to get random sleep duration: %s", err) } diff --git a/examples/gstreamer-receive/main.go b/examples/gstreamer-receive/main.go index cdf28dface9..52e8f45907c 100644 --- a/examples/gstreamer-receive/main.go +++ b/examples/gstreamer-receive/main.go @@ -34,26 +34,31 @@ func gstreamerReceiveMain() { // Set a handler for when a new remote track starts, this handler creates a gstreamer pipeline // for the given codec - peerConnection.OnTrack(func(track *webrtc.Track) { + peerConnection.OnTrack(func(track *webrtc.Track, receiver *webrtc.RTPReceiver) { // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it go func() { ticker := time.NewTicker(time.Second * 3) for range ticker.C { - err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: track.SSRC}) + err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: track.SSRC()}) if err != nil { fmt.Println(err) } } }() - codec := track.Codec - fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType, codec.Name) + codec := track.Codec() + fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), codec.Name) pipeline := gst.CreatePipeline(codec.Name) pipeline.Start() + buf := make([]byte, 1400) for { - p := <-track.Packets - pipeline.Push(p.Raw) + i, err := track.Read(buf) + if err != nil { + panic(err) + } + + pipeline.Push(buf[:i]) } }) diff --git a/examples/gstreamer-send-offer/main.go b/examples/gstreamer-send-offer/main.go index 5732c81d65b..33aca9f271e 100644 --- a/examples/gstreamer-send-offer/main.go +++ b/examples/gstreamer-send-offer/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "math/rand" "github.com/pions/webrtc" @@ -34,7 +35,7 @@ func main() { }) // Create a audio track - opusTrack, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1") + opusTrack, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeOpus, rand.Uint32(), "audio", "pion1") if err != nil { panic(err) } @@ -44,7 +45,7 @@ func main() { } // Create a video track - vp8Track, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeVP8, "video", "pion2") + vp8Track, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeVP8, rand.Uint32(), "video", "pion2") if err != nil { panic(err) } @@ -79,8 +80,8 @@ func main() { } // Start pushing buffers on these tracks - gst.CreatePipeline(webrtc.Opus, opusTrack.Samples, "audiotestsrc").Start() - gst.CreatePipeline(webrtc.VP8, vp8Track.Samples, "videotestsrc").Start() + gst.CreatePipeline(webrtc.Opus, opusTrack, "audiotestsrc").Start() + gst.CreatePipeline(webrtc.VP8, vp8Track, "videotestsrc").Start() // Block forever select {} diff --git a/examples/gstreamer-send/main.go b/examples/gstreamer-send/main.go index f8fb89b4f27..4218f242389 100644 --- a/examples/gstreamer-send/main.go +++ b/examples/gstreamer-send/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "math/rand" "github.com/pions/webrtc" @@ -39,7 +40,7 @@ func main() { }) // Create a audio track - opusTrack, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1") + opusTrack, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeOpus, rand.Uint32(), "audio", "pion1") if err != nil { panic(err) } @@ -49,7 +50,7 @@ func main() { } // Create a video track - vp8Track, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeVP8, "video", "pion2") + vp8Track, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeVP8, rand.Uint32(), "video", "pion2") if err != nil { panic(err) } @@ -84,8 +85,8 @@ func main() { fmt.Println(signal.Encode(answer)) // Start pushing buffers on these tracks - gst.CreatePipeline(webrtc.Opus, opusTrack.Samples, *audioSrc).Start() - gst.CreatePipeline(webrtc.VP8, vp8Track.Samples, *videoSrc).Start() + gst.CreatePipeline(webrtc.Opus, opusTrack, *audioSrc).Start() + gst.CreatePipeline(webrtc.VP8, vp8Track, *videoSrc).Start() // Block forever select {} diff --git a/examples/internal/gstreamer-src/gst.go b/examples/internal/gstreamer-src/gst.go index c717ae20e48..a3baf968b0e 100644 --- a/examples/internal/gstreamer-src/gst.go +++ b/examples/internal/gstreamer-src/gst.go @@ -23,7 +23,7 @@ func init() { // Pipeline is a wrapper for a GStreamer Pipeline type Pipeline struct { Pipeline *C.GstElement - in chan<- media.Sample + track *webrtc.Track // stop acts as a signal that this pipeline is stopped // any pending sends to Pipeline.in should be cancelled stop chan interface{} @@ -35,7 +35,7 @@ var pipelines = make(map[int]*Pipeline) var pipelinesLock sync.Mutex // CreatePipeline creates a GStreamer Pipeline -func CreatePipeline(codecName string, in chan<- media.Sample, pipelineSrc string) *Pipeline { +func CreatePipeline(codecName string, track *webrtc.Track, pipelineSrc string) *Pipeline { pipelineStr := "appsink name=appsink" switch codecName { case webrtc.VP8: @@ -60,7 +60,7 @@ func CreatePipeline(codecName string, in chan<- media.Sample, pipelineSrc string pipeline := &Pipeline{ Pipeline: C.gstreamer_send_create_pipeline(pipelineStrUnsafe), - in: in, + track: track, id: len(pipelines), codecName: codecName, } @@ -105,9 +105,8 @@ func goHandlePipelineBuffer(buffer unsafe.Pointer, bufferLen C.int, duration C.i } // We need to be able to cancel this function even f pipeline.in isn't being serviced // When pipeline.stop is closed the sending of data will be cancelled. - select { - case pipeline.in <- media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}: - case <-pipeline.stop: + if err := pipeline.track.WriteSample(media.Sample{Data: C.GoBytes(buffer, bufferLen), Samples: samples}); err != nil { + panic(err) } } else { fmt.Printf("discarding buffer, no pipeline with id %d", int(pipelineID)) diff --git a/examples/janus-gateway/streaming/main.go b/examples/janus-gateway/streaming/main.go index 44bbac50588..ccdef1dc2eb 100644 --- a/examples/janus-gateway/streaming/main.go +++ b/examples/janus-gateway/streaming/main.go @@ -52,8 +52,8 @@ func main() { fmt.Printf("Connection State has changed %s \n", connectionState.String()) }) - peerConnection.OnTrack(func(track *webrtc.Track) { - if track.Codec.Name == webrtc.Opus { + peerConnection.OnTrack(func(track *webrtc.Track, receiver *webrtc.RTPReceiver) { + if track.Codec().Name == webrtc.Opus { return } @@ -62,8 +62,14 @@ func main() { if err != nil { panic(err) } + for { - err = i.AddPacket(<-track.Packets) + packet, err := track.ReadRTP() + if err != nil { + panic(err) + } + + err = i.AddPacket(packet) if err != nil { panic(err) } diff --git a/examples/janus-gateway/video-room/main.go b/examples/janus-gateway/video-room/main.go index 343aeee615a..186df458b9b 100644 --- a/examples/janus-gateway/video-room/main.go +++ b/examples/janus-gateway/video-room/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "log" + "math/rand" janus "github.com/notedit/janus-go" "github.com/pions/webrtc" @@ -54,7 +55,7 @@ func main() { }) // Create a audio track - opusTrack, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeOpus, "audio", "pion1") + opusTrack, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeOpus, rand.Uint32(), "audio", "pion1") if err != nil { panic(err) } @@ -64,7 +65,7 @@ func main() { } // Create a video track - vp8Track, err := peerConnection.NewSampleTrack(webrtc.DefaultPayloadTypeVP8, "video", "pion2") + vp8Track, err := peerConnection.NewTrack(webrtc.DefaultPayloadTypeVP8, rand.Uint32(), "video", "pion2") if err != nil { panic(err) } @@ -134,8 +135,8 @@ func main() { } // Start pushing buffers on these tracks - gst.CreatePipeline(webrtc.Opus, opusTrack.Samples, "audiotestsrc").Start() - gst.CreatePipeline(webrtc.VP8, vp8Track.Samples, "videotestsrc").Start() + gst.CreatePipeline(webrtc.Opus, opusTrack, "audiotestsrc").Start() + gst.CreatePipeline(webrtc.VP8, vp8Track, "videotestsrc").Start() } select {} diff --git a/examples/save-to-disk/main.go b/examples/save-to-disk/main.go index 09f4429a680..64cfe0494ff 100644 --- a/examples/save-to-disk/main.go +++ b/examples/save-to-disk/main.go @@ -44,27 +44,32 @@ func main() { // Set a handler for when a new remote track starts, this handler saves buffers to disk as // an ivf file, since we could have multiple video tracks we provide a counter. // In your application this is where you would handle/process video - peerConnection.OnTrack(func(track *webrtc.Track) { + peerConnection.OnTrack(func(track *webrtc.Track, receiver *webrtc.RTPReceiver) { // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it go func() { ticker := time.NewTicker(time.Second * 3) for range ticker.C { - err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: track.SSRC}) + err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: track.SSRC()}) if err != nil { fmt.Println(err) } } }() - if track.Codec.Name == webrtc.VP8 { + if track.Codec().Name == webrtc.VP8 { fmt.Println("Got VP8 track, saving to disk as output.ivf") i, err := ivfwriter.New("output.ivf") if err != nil { panic(err) } + for { - err = i.AddPacket(<-track.Packets) + packet, err := track.ReadRTP() + if err != nil { + panic(err) + } + + err = i.AddPacket(packet) if err != nil { panic(err) } diff --git a/examples/sfu/main.go b/examples/sfu/main.go index ed0d57920a2..c14da78c727 100644 --- a/examples/sfu/main.go +++ b/examples/sfu/main.go @@ -7,11 +7,9 @@ import ( "io/ioutil" "net/http" "strconv" - "sync" "time" "github.com/pions/rtcp" - "github.com/pions/rtp" "github.com/pions/webrtc" "github.com/pions/webrtc/examples/internal/signal" @@ -84,41 +82,38 @@ func main() { panic(err) } - inboundSSRC := make(chan uint32) - inboundPayloadType := make(chan uint8) - - outboundRTP := []chan<- *rtp.Packet{} - var outboundRTPLock sync.RWMutex + localTrackChan := make(chan *webrtc.Track) // Set a handler for when a new remote track starts, this just distributes all our packets // to connected peers - peerConnection.OnTrack(func(track *webrtc.Track) { + peerConnection.OnTrack(func(remoteTrack *webrtc.Track, receiver *webrtc.RTPReceiver) { // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI when a viewer requests it go func() { ticker := time.NewTicker(rtcpPLIInterval) for range ticker.C { - if err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: track.SSRC}); err != nil { + if err := peerConnection.SendRTCP(&rtcp.PictureLossIndication{MediaSSRC: remoteTrack.SSRC()}); err != nil { fmt.Println(err) } } }() - inboundSSRC <- track.SSRC - inboundPayloadType <- track.PayloadType + // Create a local track, all our SFU clients will be fed via this track + localTrack, err := peerConnection.NewTrack(remoteTrack.PayloadType(), remoteTrack.SSRC(), "video", "pion") + if err != nil { + panic(err) + } + localTrackChan <- localTrack + rtpBuf := make([]byte, 1400) for { - rtpPacket := <-track.Packets - - outboundRTPLock.RLock() - for _, outChan := range outboundRTP { - outPacket := rtpPacket - outPacket.Payload = append([]byte{}, outPacket.Payload...) - select { - case outChan <- outPacket: - default: - } + i, err := remoteTrack.Read(rtpBuf) + if err != nil { + panic(err) + } + + if _, err = localTrack.Write(rtpBuf[:i]); err != nil { + panic(err) } - outboundRTPLock.RUnlock() } }) @@ -143,8 +138,7 @@ func main() { // Get the LocalDescription and take it to base64 so we can paste in browser fmt.Println(signal.Encode(answer)) - outboundSSRC := <-inboundSSRC - outboundPayloadType := <-inboundPayloadType + localTrack := <-localTrackChan for { fmt.Println("") fmt.Println("Curl an base64 SDP to start sendonly peer connection") @@ -158,21 +152,11 @@ func main() { panic(err) } - // Create a single VP8 Track to send videa - vp8Track, err := peerConnection.NewRawRTPTrack(outboundPayloadType, outboundSSRC, "video", "pion") - if err != nil { - panic(err) - } - - _, err = peerConnection.AddTrack(vp8Track) + _, err = peerConnection.AddTrack(localTrack) if err != nil { panic(err) } - outboundRTPLock.Lock() - outboundRTP = append(outboundRTP, vp8Track.RawRTP) - outboundRTPLock.Unlock() - // Set the remote SessionDescription err = peerConnection.SetRemoteDescription(recvOnlyOffer) if err != nil { diff --git a/lossy_stream.go b/lossy_stream.go new file mode 100644 index 00000000000..99cb0b7b56d --- /dev/null +++ b/lossy_stream.go @@ -0,0 +1,93 @@ +package webrtc + +import ( + "fmt" + "io" + "sync" +) + +// lossyReader wraps an io.Reader and discards data if it isn't read in time +// Allowing us to only deliver the newest data to the caller +type lossyReadCloser struct { + nextReader io.ReadCloser + mu sync.RWMutex + + incomingBuf chan []byte + amountRead chan int + + readError error + hasErrored chan interface{} + + closed chan interface{} +} + +func newLossyReadCloser(nextReader io.ReadCloser) *lossyReadCloser { + l := &lossyReadCloser{ + nextReader: nextReader, + + closed: make(chan interface{}), + + incomingBuf: make(chan []byte), + hasErrored: make(chan interface{}), + amountRead: make(chan int), + } + + go func() { + readBuf := make([]byte, receiveMTU) + for { + i, err := nextReader.Read(readBuf) + if err != nil { + l.mu.Lock() + l.readError = err + l.mu.Unlock() + + close(l.hasErrored) + break + } + + select { + case in := <-l.incomingBuf: + copy(in, readBuf[:i]) + l.amountRead <- i + default: // Discard if we have no inbound read + } + } + }() + + return l +} + +func (l *lossyReadCloser) Read(b []byte) (n int, err error) { + select { + case <-l.closed: + return 0, fmt.Errorf("lossyReadCloser is closed") + case <-l.hasErrored: + l.mu.RLock() + defer l.mu.RUnlock() + return 0, l.readError + + case l.incomingBuf <- b: + } + + select { + case <-l.closed: + return 0, fmt.Errorf("lossyReadCloser is closed") + case <-l.hasErrored: + l.mu.RLock() + defer l.mu.RUnlock() + return 0, l.readError + + case i := <-l.amountRead: + return i, nil + } +} + +func (l *lossyReadCloser) Close() error { + select { + case <-l.closed: + return fmt.Errorf("lossyReader is already closed") + default: + } + close(l.closed) + return l.nextReader.Close() +} diff --git a/peerconnection.go b/peerconnection.go index e77ce82b167..2237212feb4 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -13,7 +13,6 @@ import ( "time" "github.com/pions/rtcp" - "github.com/pions/rtp" "github.com/pions/sdp/v2" "github.com/pions/webrtc/pkg/ice" "github.com/pions/webrtc/pkg/logging" @@ -103,7 +102,7 @@ type PeerConnection struct { onSignalingStateChangeHandler func(SignalingState) onICEConnectionStateChangeHandler func(ICEConnectionState) - onTrackHandler func(*Track) + onTrackHandler func(*Track, *RTPReceiver) onDataChannelHandler func(*DataChannel) iceGatherer *ICEGatherer @@ -279,13 +278,13 @@ func (pc *PeerConnection) OnDataChannel(f func(*DataChannel)) { // OnTrack sets an event handler which is called when remote track // arrives from a remote peer. -func (pc *PeerConnection) OnTrack(f func(*Track)) { +func (pc *PeerConnection) OnTrack(f func(*Track, *RTPReceiver)) { pc.mu.Lock() defer pc.mu.Unlock() pc.onTrackHandler = f } -func (pc *PeerConnection) onTrack(t *Track) (done chan struct{}) { +func (pc *PeerConnection) onTrack(t *Track, r *RTPReceiver) (done chan struct{}) { pc.mu.RLock() hdlr := pc.onTrackHandler pc.mu.RUnlock() @@ -298,7 +297,7 @@ func (pc *PeerConnection) onTrack(t *Track) (done chan struct{}) { } go func() { - hdlr(t) + hdlr(t, r) close(done) }() @@ -856,18 +855,21 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { return } - if pc.onTrackHandler != nil { - pc.openSRTP() - } else { - pcLog.Warnf("OnTrack unset, unable to handle incoming media streams") - } + pc.openSRTP() for _, tranceiver := range pc.rtpTransceivers { if tranceiver.Sender != nil { - tranceiver.Sender.Send(RTPSendParameters{ + err = tranceiver.Sender.Send(RTPSendParameters{ encodings: RTPEncodingParameters{ - RTPCodingParameters{SSRC: tranceiver.Sender.Track.SSRC, PayloadType: tranceiver.Sender.Track.PayloadType}, + RTPCodingParameters{ + SSRC: tranceiver.Sender.track.SSRC(), + PayloadType: tranceiver.Sender.track.PayloadType(), + }, }}) + + if err != nil { + pcLog.Warnf("Failed to start Sender: %s", err) + } } } @@ -931,15 +933,37 @@ func (pc *PeerConnection) openSRTP() { for i := range incomingSSRCes { go func(ssrc uint32, codecType RTPCodecType) { - receiver := pc.api.NewRTPReceiver(codecType, pc.dtlsTransport) - <-receiver.Receive(RTPReceiveParameters{ + receiver, err := pc.api.NewRTPReceiver(codecType, pc.dtlsTransport) + if err != nil { + pcLog.Warnf("Could not create RTPReceiver %s", err) + return + } + + if err = receiver.Receive(RTPReceiveParameters{ encodings: RTPDecodingParameters{ RTPCodingParameters{SSRC: ssrc}, - }}) + }}); err != nil { + pcLog.Warnf("RTPReceiver Receive failed %s", err) + return + } + + pc.newRTPTransceiver( + receiver, + nil, + RTPTransceiverDirectionRecvonly, + ) - sdpCodec, err := pc.CurrentLocalDescription.parsed.GetCodecForPayloadType(receiver.Track.PayloadType) + if err = receiver.Track().determinePayloadType(); err != nil { + pcLog.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC()) + return + } + + pc.mu.RLock() + defer pc.mu.RUnlock() + + sdpCodec, err := pc.CurrentLocalDescription.parsed.GetCodecForPayloadType(receiver.Track().PayloadType()) if err != nil { - pcLog.Warnf("no codec could be found in RemoteDescription for payloadType %d", receiver.Track.PayloadType) + pcLog.Warnf("no codec could be found in RemoteDescription for payloadType %d", receiver.Track().PayloadType()) return } @@ -949,18 +973,18 @@ func (pc *PeerConnection) openSRTP() { return } - receiver.Track.Kind = codec.Type - receiver.Track.Codec = codec - pc.newRTPTransceiver( - receiver, - nil, - RTPTransceiverDirectionRecvonly, - ) + receiver.Track().mu.Lock() + receiver.Track().kind = codec.Type + receiver.Track().codec = codec + receiver.Track().mu.Unlock() - pc.onTrack(receiver.Track) + if pc.onTrackHandler != nil { + pc.onTrack(receiver.Track(), receiver) + } else { + pcLog.Warnf("OnTrack unset, unable to handle incoming media streams") + } }(i, incomingSSRCes[i]) } - } // drainSRTP pulls and discards RTP/RTCP packets that don't match any SRTP @@ -984,20 +1008,14 @@ func (pc *PeerConnection) drainSRTP() { go func() { rtpBuf := make([]byte, receiveMTU) - rtpPacket := &rtp.Packet{} - for { - i, err := r.Read(rtpBuf) + _, rtpHeader, err := r.ReadRTP(rtpBuf) if err != nil { pcLog.Warnf("Failed to read, drainSRTP done for: %v %d \n", err, ssrc) return } - if err := rtpPacket.Unmarshal(rtpBuf[:i]); err != nil { - pcLog.Warnf("Failed to unmarshal RTP packet, discarding: %v \n", err) - continue - } - pcLog.Debugf("got RTP: %+v", rtpPacket) + pcLog.Debugf("got RTP: %+v", rtpHeader) } }() } @@ -1019,18 +1037,12 @@ func (pc *PeerConnection) drainSRTP() { go func() { rtcpBuf := make([]byte, receiveMTU) for { - i, err := r.Read(rtcpBuf) + _, header, err := r.ReadRTCP(rtcpBuf) if err != nil { pcLog.Warnf("Failed to read, drainSRTCP done for: %v %d \n", err, ssrc) return } - - rtcpPacket, _, err := rtcp.Unmarshal(rtcpBuf[:i]) - if err != nil { - pcLog.Warnf("Failed to unmarshal RTCP packet, discarding: %v \n", err) - continue - } - pcLog.Debugf("got RTCP: %+v", rtcpPacket) + pcLog.Debugf("got RTCP: %+v", header) } }() } @@ -1087,10 +1099,10 @@ func (pc *PeerConnection) GetSenders() []*RTPSender { pc.mu.Lock() defer pc.mu.Unlock() - result := make([]*RTPSender, len(pc.rtpTransceivers)) - for i, tranceiver := range pc.rtpTransceivers { + result := []*RTPSender{} + for _, tranceiver := range pc.rtpTransceivers { if tranceiver.Sender != nil { - result[i] = tranceiver.Sender + result = append(result, tranceiver.Sender) } } return result @@ -1101,11 +1113,10 @@ func (pc *PeerConnection) GetReceivers() []*RTPReceiver { pc.mu.Lock() defer pc.mu.Unlock() - result := make([]*RTPReceiver, len(pc.rtpTransceivers)) - for i, tranceiver := range pc.rtpTransceivers { + result := []*RTPReceiver{} + for _, tranceiver := range pc.rtpTransceivers { if tranceiver.Receiver != nil { - result[i] = tranceiver.Receiver - + result = append(result, tranceiver.Receiver) } } return result @@ -1125,10 +1136,10 @@ func (pc *PeerConnection) AddTrack(track *Track) (*RTPSender, error) { return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} } for _, transceiver := range pc.rtpTransceivers { - if transceiver.Sender.Track == nil { + if transceiver.Sender.track == nil { continue } - if track.ID == transceiver.Sender.Track.ID { + if track.ID() == transceiver.Sender.track.ID() { return nil, &rtcerr.InvalidAccessError{Err: ErrExistingTrack} } } @@ -1136,9 +1147,9 @@ func (pc *PeerConnection) AddTrack(track *Track) (*RTPSender, error) { for _, t := range pc.rtpTransceivers { if !t.stopped && // t.Sender == nil && // TODO: check that the sender has never sent - t.Sender.Track == nil && - t.Receiver.Track != nil && - t.Receiver.Track.Kind == track.Kind { + t.Sender.track == nil && + t.Receiver.Track() != nil && + t.Receiver.Track().Kind() == track.Kind() { transceiver = t break } @@ -1148,7 +1159,10 @@ func (pc *PeerConnection) AddTrack(track *Track) (*RTPSender, error) { return nil, err } } else { - sender := pc.api.NewRTPSender(track, pc.dtlsTransport) + sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport) + if err != nil { + return nil, err + } transceiver = pc.newRTPTransceiver( nil, sender, @@ -1156,7 +1170,7 @@ func (pc *PeerConnection) AddTrack(track *Track) (*RTPSender, error) { ) } - transceiver.Mid = track.Kind.String() // TODO: Mid generation + transceiver.Mid = track.Kind().String() // TODO: Mid generation return transceiver.Sender, nil } @@ -1328,16 +1342,16 @@ func (pc *PeerConnection) Close() error { // Conn if one of the endpoints is closed down. To // continue the chain the Mux has to be closed. - if err := pc.dtlsTransport.Stop(); err != nil { - closeErrs = append(closeErrs, err) - } - for _, t := range pc.rtpTransceivers { if err := t.Stop(); err != nil { closeErrs = append(closeErrs, err) } } + if err := pc.dtlsTransport.Stop(); err != nil { + closeErrs = append(closeErrs, err) + } + if pc.sctpTransport != nil { if err := pc.sctpTransport.Stop(); err != nil { closeErrs = append(closeErrs, err) @@ -1421,13 +1435,13 @@ func (pc *PeerConnection) addRTPMediaSection(d *sdp.SessionDescription, codecTyp weSend := false for _, transceiver := range pc.rtpTransceivers { if transceiver.Sender == nil || - transceiver.Sender.Track == nil || - transceiver.Sender.Track.Kind != codecType { + transceiver.Sender.track == nil || + transceiver.Sender.track.Kind() != codecType { continue } weSend = true - track := transceiver.Sender.Track - media = media.WithMediaSource(track.SSRC, track.Label /* cname */, track.Label /* streamLabel */, track.Label) + track := transceiver.Sender.track + media = media.WithMediaSource(track.SSRC(), track.Label() /* cname */, track.Label() /* streamLabel */, track.Label()) } media = media.WithPropertyAttribute(localDirection(weSend, peerDirection).String()) @@ -1479,24 +1493,8 @@ func (pc *PeerConnection) addDataMediaSection(d *sdp.SessionDescription, midValu d.WithMedia(media) } -// NewRawRTPTrack Creates a new Track -// -// See NewSampleTrack for documentation -func (pc *PeerConnection) NewRawRTPTrack(payloadType uint8, ssrc uint32, id, label string) (*Track, error) { - codec, err := pc.api.mediaEngine.getCodec(payloadType) - if err != nil { - return nil, err - } else if codec.Payloader == nil { - return nil, errors.New("codec payloader not set") - } - - return NewRawRTPTrack(payloadType, ssrc, id, label, codec) -} - -// NewSampleTrack Creates a new Track -// -// See NewSampleTrack for documentation -func (pc *PeerConnection) NewSampleTrack(payloadType uint8, id, label string) (*Track, error) { +// NewTrack Creates a new Track +func (pc *PeerConnection) NewTrack(payloadType uint8, ssrc uint32, id, label string) (*Track, error) { codec, err := pc.api.mediaEngine.getCodec(payloadType) if err != nil { return nil, err @@ -1504,14 +1502,7 @@ func (pc *PeerConnection) NewSampleTrack(payloadType uint8, id, label string) (* return nil, errors.New("codec payloader not set") } - return NewSampleTrack(payloadType, id, label, codec) -} - -// NewTrack is used to create a new Track -// -// Deprecated: Use NewSampleTrack() instead -func (pc *PeerConnection) NewTrack(payloadType uint8, id, label string) (*Track, error) { - return pc.NewSampleTrack(payloadType, id, label) + return NewTrack(payloadType, ssrc, id, label, codec) } func (pc *PeerConnection) newRTPTransceiver( diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index dc25577cd8d..ce2d2f52499 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -2,6 +2,8 @@ package webrtc import ( "bytes" + "fmt" + "math/rand" "sync" "testing" "time" @@ -32,14 +34,14 @@ func TestPeerConnection_Media_Sample(t *testing.T) { awaitRTCPSenderRecv := make(chan bool) awaitRTCPSenderSend := make(chan error) - awaitRTCPRecieverRecv := make(chan bool) + awaitRTCPRecieverRecv := make(chan error) awaitRTCPRecieverSend := make(chan error) - pcAnswer.OnTrack(func(track *Track) { + pcAnswer.OnTrack(func(track *Track, receiver *RTPReceiver) { go func() { for { time.Sleep(time.Millisecond * 100) - if routineErr := pcAnswer.SendRTCP(&rtcp.RapidResynchronizationRequest{SenderSSRC: track.SSRC, MediaSSRC: track.SSRC}); routineErr != nil { + if routineErr := pcAnswer.SendRTCP(&rtcp.RapidResynchronizationRequest{SenderSSRC: track.SSRC(), MediaSSRC: track.SSRC()}); routineErr != nil { awaitRTCPRecieverSend <- routineErr return } @@ -54,14 +56,18 @@ func TestPeerConnection_Media_Sample(t *testing.T) { }() go func() { - <-track.RTCPPackets - close(awaitRTCPRecieverRecv) + _, routineErr := receiver.Read(make([]byte, 1400)) + if routineErr != nil { + awaitRTCPRecieverRecv <- routineErr + } else { + close(awaitRTCPRecieverRecv) + } }() haveClosedAwaitRTPRecv := false for { - p, ok := <-track.Packets - if !ok { + p, routineErr := track.ReadRTP() + if routineErr != nil { close(awaitRTPRecvClosed) return } else if bytes.Equal(p.Payload, []byte{0x10, 0x00}) && !haveClosedAwaitRTPRecv { @@ -71,18 +77,21 @@ func TestPeerConnection_Media_Sample(t *testing.T) { } }) - vp8Track, err := pcOffer.NewSampleTrack(DefaultPayloadTypeVP8, "video", "pion") + vp8Track, err := pcOffer.NewTrack(DefaultPayloadTypeVP8, rand.Uint32(), "video", "pion") if err != nil { t.Fatal(err) } - if _, err = pcOffer.AddTrack(vp8Track); err != nil { + rtpReceiver, err := pcOffer.AddTrack(vp8Track) + if err != nil { t.Fatal(err) } go func() { for { time.Sleep(time.Millisecond * 100) - vp8Track.Samples <- media.Sample{Data: []byte{0x00}, Samples: 1} + if routineErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Samples: 1}); routineErr != nil { + fmt.Println(routineErr) + } select { case <-awaitRTPRecv: @@ -96,7 +105,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { go func() { for { time.Sleep(time.Millisecond * 100) - if routineErr := pcOffer.SendRTCP(&rtcp.PictureLossIndication{SenderSSRC: vp8Track.SSRC, MediaSSRC: vp8Track.SSRC}); routineErr != nil { + if routineErr := pcOffer.SendRTCP(&rtcp.PictureLossIndication{SenderSSRC: vp8Track.SSRC(), MediaSSRC: vp8Track.SSRC()}); routineErr != nil { awaitRTCPSenderSend <- routineErr } @@ -110,8 +119,9 @@ func TestPeerConnection_Media_Sample(t *testing.T) { }() go func() { - <-vp8Track.RTCPPackets - close(awaitRTCPSenderRecv) + if _, routineErr := rtpReceiver.Read(make([]byte, 1400)); routineErr == nil { + close(awaitRTCPSenderRecv) + } }() err = signalPair(pcOffer, pcAnswer) @@ -158,38 +168,41 @@ This test adds an input track and asserts func TestPeerConnection_Media_Shutdown(t *testing.T) { iceComplete := make(chan bool) - api := NewAPI() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() - api.mediaEngine.RegisterDefaultCodecs() - pcOffer, pcAnswer, err := api.newPair() + pcOffer, err := NewPeerConnection(Configuration{}) + if err != nil { + t.Fatal(err) + } + + pcAnswer, err := NewPeerConnection(Configuration{}) if err != nil { t.Fatal(err) } - opusTrack, err := pcOffer.NewSampleTrack(DefaultPayloadTypeOpus, "audio", "pion1") + opusTrack, err := pcOffer.NewTrack(DefaultPayloadTypeOpus, rand.Uint32(), "audio", "pion1") if err != nil { t.Fatal(err) } - vp8Track, err := pcOffer.NewSampleTrack(DefaultPayloadTypeVP8, "video", "pion2") + vp8Track, err := pcOffer.NewTrack(DefaultPayloadTypeVP8, rand.Uint32(), "video", "pion2") if err != nil { t.Fatal(err) } if _, err = pcOffer.AddTrack(opusTrack); err != nil { t.Fatal(err) - } else if _, err = pcOffer.AddTrack(vp8Track); err != nil { + } else if _, err = pcAnswer.AddTrack(vp8Track); err != nil { t.Fatal(err) } var onTrackFiredLock sync.RWMutex onTrackFired := false - pcAnswer.OnTrack(func(track *Track) { + pcAnswer.OnTrack(func(track *Track, receiver *RTPReceiver) { onTrackFiredLock.Lock() defer onTrackFiredLock.Unlock() onTrackFired = true @@ -208,9 +221,26 @@ func TestPeerConnection_Media_Shutdown(t *testing.T) { if err != nil { t.Fatal(err) } - <-iceComplete + // Each PeerConnection should have one sender, one receiver and two transceivers + for _, pc := range []*PeerConnection{pcOffer, pcAnswer} { + senders := pc.GetSenders() + if len(senders) != 1 { + t.Errorf("Each PeerConnection should have one RTPSender, we have %d", len(senders)) + } + + receivers := pc.GetReceivers() + if len(receivers) != 1 { + t.Errorf("Each PeerConnection should have one RTPReceiver, we have %d", len(receivers)) + } + + transceivers := pc.GetTransceivers() + if len(transceivers) != 2 { + t.Errorf("Each PeerConnection should have two RTPTransceivers, we have %d", len(transceivers)) + } + } + err = pcOffer.Close() if err != nil { t.Fatal(err) @@ -226,5 +256,4 @@ func TestPeerConnection_Media_Shutdown(t *testing.T) { t.Fatalf("PeerConnection OnTrack fired even though we got no packets") } onTrackFiredLock.Unlock() - } diff --git a/peerconnection_test.go b/peerconnection_test.go index 932bb23121d..8a0ad4bcd71 100644 --- a/peerconnection_test.go +++ b/peerconnection_test.go @@ -10,9 +10,7 @@ import ( "testing" "time" - "github.com/pions/rtp" "github.com/pions/webrtc/pkg/ice" - "github.com/pions/webrtc/pkg/media" "github.com/pions/webrtc/pkg/rtcerr" "github.com/stretchr/testify/assert" @@ -431,55 +429,6 @@ func TestCreateOfferAnswer(t *testing.T) { } } -func TestPeerConnection_NewRawRTPTrack(t *testing.T) { - api := NewAPI() - api.mediaEngine.RegisterDefaultCodecs() - - pc, err := api.NewPeerConnection(Configuration{}) - assert.Nil(t, err) - - _, err = pc.NewRawRTPTrack(DefaultPayloadTypeH264, 0, "trackId", "trackLabel") - assert.NotNil(t, err) - - track, err := pc.NewRawRTPTrack(DefaultPayloadTypeH264, 123456, "trackId", "trackLabel") - assert.Nil(t, err) - - _, err = pc.AddTrack(track) - assert.Nil(t, err) - - // This channel should not be set up for a RawRTP track - assert.Panics(t, func() { - track.Samples <- media.Sample{} - }) - - assert.NotPanics(t, func() { - track.RawRTP <- &rtp.Packet{} - }) -} - -func TestPeerConnection_NewSampleTrack(t *testing.T) { - api := NewAPI() - api.mediaEngine.RegisterDefaultCodecs() - - pc, err := api.NewPeerConnection(Configuration{}) - assert.Nil(t, err) - - track, err := pc.NewSampleTrack(DefaultPayloadTypeH264, "trackId", "trackLabel") - assert.Nil(t, err) - - _, err = pc.AddTrack(track) - assert.Nil(t, err) - - // This channel should not be set up for a Sample track - assert.Panics(t, func() { - track.RawRTP <- &rtp.Packet{} - }) - - assert.NotPanics(t, func() { - track.Samples <- media.Sample{} - }) -} - func TestPeerConnection_EventHandlers(t *testing.T) { api := NewAPI() pc, err := api.NewPeerConnection(Configuration{}) @@ -490,10 +439,10 @@ func TestPeerConnection_EventHandlers(t *testing.T) { onDataChannelCalled := make(chan bool) // Verify that the noop case works - assert.NotPanics(t, func() { pc.onTrack(nil) }) + assert.NotPanics(t, func() { pc.onTrack(nil, nil) }) assert.NotPanics(t, func() { pc.onICEConnectionStateChange(ice.ConnectionStateNew) }) - pc.OnTrack(func(t *Track) { + pc.OnTrack(func(t *Track, r *RTPReceiver) { onTrackCalled <- true }) @@ -506,11 +455,11 @@ func TestPeerConnection_EventHandlers(t *testing.T) { }) // Verify that the handlers deal with nil inputs - assert.NotPanics(t, func() { pc.onTrack(nil) }) + assert.NotPanics(t, func() { pc.onTrack(nil, nil) }) assert.NotPanics(t, func() { go pc.onDataChannelHandler(nil) }) // Verify that the set handlers are called - assert.NotPanics(t, func() { pc.onTrack(&Track{}) }) + assert.NotPanics(t, func() { pc.onTrack(&Track{}, &RTPReceiver{}) }) assert.NotPanics(t, func() { pc.onICEConnectionStateChange(ice.ConnectionStateNew) }) assert.NotPanics(t, func() { go pc.onDataChannelHandler(&DataChannel{api: api}) }) diff --git a/pkg/ice/ice_test.go b/pkg/ice/ice_test.go new file mode 100644 index 00000000000..727ae5e1ace --- /dev/null +++ b/pkg/ice/ice_test.go @@ -0,0 +1,51 @@ +package ice + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectedState_String(t *testing.T) { + testCases := []struct { + connectionState ConnectionState + expectedString string + }{ + {ConnectionState(Unknown), "Invalid"}, + {ConnectionStateNew, "New"}, + {ConnectionStateChecking, "Checking"}, + {ConnectionStateConnected, "Connected"}, + {ConnectionStateCompleted, "Completed"}, + {ConnectionStateFailed, "Failed"}, + {ConnectionStateDisconnected, "Disconnected"}, + {ConnectionStateClosed, "Closed"}, + } + + for i, testCase := range testCases { + assert.Equal(t, + testCase.expectedString, + testCase.connectionState.String(), + "testCase: %d %v", i, testCase, + ) + } +} + +func TestGatheringState_String(t *testing.T) { + testCases := []struct { + gatheringState GatheringState + expectedString string + }{ + {GatheringState(Unknown), ErrUnknownType.Error()}, + {GatheringStateNew, "new"}, + {GatheringStateGathering, "gathering"}, + {GatheringStateComplete, "complete"}, + } + + for i, testCase := range testCases { + assert.Equal(t, + testCase.expectedString, + testCase.gatheringState.String(), + "testCase: %d %v", i, testCase, + ) + } +} diff --git a/rtpreceiver.go b/rtpreceiver.go index acfae4e77b3..96fa0e55456 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -5,8 +5,6 @@ import ( "sync" "github.com/pions/rtcp" - "github.com/pions/rtp" - "github.com/pions/srtp" ) // RTPReceiver allows an application to inspect the receipt of a Track @@ -14,149 +12,101 @@ type RTPReceiver struct { kind RTPCodecType transport *DTLSTransport - hasRecv chan bool + track *Track - Track *Track + closed, received chan interface{} + mu sync.RWMutex - closed bool - mu sync.Mutex - - rtpOut chan *rtp.Packet - rtpReadStream *srtp.ReadStreamSRTP - rtpOutDone chan struct{} - - rtcpOut chan rtcp.Packet - rtcpReadStream *srtp.ReadStreamSRTCP - rtcpOutDone chan struct{} + rtpReadStream, rtcpReadStream *lossyReadCloser // A reference to the associated api object api *API } // NewRTPReceiver constructs a new RTPReceiver -func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) *RTPReceiver { +func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RTPReceiver, error) { + if transport == nil { + return nil, fmt.Errorf("DTLSTransport must not be nil") + } + return &RTPReceiver{ kind: kind, transport: transport, + api: api, + closed: make(chan interface{}), + received: make(chan interface{}), + }, nil +} - rtpOut: make(chan *rtp.Packet, 15), - rtpOutDone: make(chan struct{}), +// Track returns the RTCRtpTransceiver track +func (r *RTPReceiver) Track() *Track { + r.mu.RLock() + defer r.mu.RUnlock() + return r.track +} - rtcpOut: make(chan rtcp.Packet, 15), - rtcpOutDone: make(chan struct{}), +// Receive initialize the track and starts all the transports +func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { + r.mu.Lock() + defer r.mu.Unlock() + select { + case <-r.received: + return fmt.Errorf("Receive has already been called") + default: + } + close(r.received) - hasRecv: make(chan bool), + r.track = &Track{ + kind: r.kind, + ssrc: parameters.encodings.SSRC, + receiver: r, + } - api: api, + srtpSession, err := r.transport.getSRTPSession() + if err != nil { + return err } -} -// Receive blocks until the Track is available -func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) chan bool { - // TODO atomic only allow this to fire once - r.Track = &Track{ - Kind: r.kind, - SSRC: parameters.encodings.SSRC, - Packets: r.rtpOut, - RTCPPackets: r.rtcpOut, + srtpReadStream, err := srtpSession.OpenReadStream(parameters.encodings.SSRC) + if err != nil { + return err } - // RTP ReadLoop - go func() { - payloadSet := false - defer func() { - if !payloadSet { - close(r.hasRecv) - } - close(r.rtpOut) - close(r.rtpOutDone) - }() - - srtpSession, err := r.transport.getSRTPSession() - if err != nil { - pcLog.Warnf("Failed to open SRTPSession, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } + srtcpSession, err := r.transport.getSRTCPSession() + if err != nil { + return err + } - readStream, err := srtpSession.OpenReadStream(parameters.encodings.SSRC) - if err != nil { - pcLog.Warnf("Failed to open RTCP ReadStream, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } - r.mu.Lock() - r.rtpReadStream = readStream - r.mu.Unlock() - - readBuf := make([]byte, receiveMTU) - for { - rtpLen, err := readStream.Read(readBuf) - if err != nil { - pcLog.Warnf("Failed to read, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } - - var rtpPacket rtp.Packet - if err = rtpPacket.Unmarshal(append([]byte{}, readBuf[:rtpLen]...)); err != nil { - pcLog.Warnf("Failed to unmarshal RTP packet, discarding: %v \n", err) - continue - } - - if !payloadSet { - r.Track.PayloadType = rtpPacket.PayloadType - payloadSet = true - close(r.hasRecv) - } - - select { - case r.rtpOut <- &rtpPacket: - default: - } - } - }() - - // RTCP ReadLoop - go func() { - defer func() { - close(r.rtcpOut) - close(r.rtcpOutDone) - }() - - srtcpSession, err := r.transport.getSRTCPSession() - if err != nil { - pcLog.Warnf("Failed to open SRTCPSession, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } + srtcpReadStream, err := srtcpSession.OpenReadStream(parameters.encodings.SSRC) + if err != nil { + return err + } - readStream, err := srtcpSession.OpenReadStream(parameters.encodings.SSRC) - if err != nil { - pcLog.Warnf("Failed to open RTCP ReadStream, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } - r.mu.Lock() - r.rtcpReadStream = readStream - r.mu.Unlock() - - readBuf := make([]byte, receiveMTU) - for { - rtcpLen, err := readStream.Read(readBuf) - if err != nil { - pcLog.Warnf("Failed to read, Track done for: %v %d \n", err, parameters.encodings.SSRC) - return - } - - rtcpPacket, _, err := rtcp.Unmarshal(append([]byte{}, readBuf[:rtcpLen]...)) - if err != nil { - pcLog.Warnf("Failed to unmarshal RTCP packet, discarding: %v \n", err) - continue - } - select { - case r.rtcpOut <- rtcpPacket: - default: - } - } - }() + r.rtpReadStream = newLossyReadCloser(srtpReadStream) + r.rtcpReadStream = newLossyReadCloser(srtcpReadStream) + return nil +} + +// Read reads incoming RTCP for this RTPReceiver +func (r *RTPReceiver) Read(b []byte) (n int, err error) { + select { + case <-r.closed: + return 0, fmt.Errorf("RTPSender has been stopped") + case <-r.received: + return r.rtcpReadStream.Read(b) + } +} + +// ReadRTCP is a convenience method that wraps Read and unmarshals for you +func (r *RTPReceiver) ReadRTCP() (rtcp.Packet, error) { + b := make([]byte, receiveMTU) + i, err := r.Read(b) + if err != nil { + return nil, err + } - return r.hasRecv + pkt, _, err := rtcp.Unmarshal(b[:i]) + return pkt, err } // Stop irreversibly stops the RTPReceiver @@ -164,26 +114,33 @@ func (r *RTPReceiver) Stop() error { r.mu.Lock() defer r.mu.Unlock() - if r.closed { - return fmt.Errorf("RTPReceiver has already been closed") - } - select { - case <-r.hasRecv: + case <-r.closed: + return nil default: - return fmt.Errorf("RTPReceiver has not been started") } - if err := r.rtcpReadStream.Close(); err != nil { - return err - } - if err := r.rtpReadStream.Close(); err != nil { - return err + select { + case <-r.received: + if err := r.rtcpReadStream.Close(); err != nil { + return err + } + if err := r.rtpReadStream.Close(); err != nil { + return err + } + default: } - <-r.rtcpOutDone - <-r.rtpOutDone - - r.closed = true + close(r.closed) return nil } + +// readRTP should only be called by a track, this only exists so we can keep state in one place +func (r *RTPReceiver) readRTP(b []byte) (n int, err error) { + select { + case <-r.closed: + return 0, fmt.Errorf("RTPSender has been stopped") + case <-r.received: + return r.rtpReadStream.Read(b) + } +} diff --git a/rtpsender.go b/rtpsender.go index 9c71f119ae8..021d61ee768 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -1,154 +1,146 @@ package webrtc import ( + "fmt" + "sync" + "github.com/pions/rtcp" - "github.com/pions/rtp" - "github.com/pions/webrtc/pkg/media" ) -const rtpOutboundMTU = 1400 - // RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer type RTPSender struct { - Track *Track + track *Track + rtcpReadStream *lossyReadCloser transport *DTLSTransport // A reference to the associated api object api *API + + mu sync.RWMutex + sendCalled, stopCalled chan interface{} } // NewRTPSender constructs a new RTPSender -func (api *API) NewRTPSender(track *Track, transport *DTLSTransport) *RTPSender { - r := &RTPSender{ - Track: track, - transport: transport, - api: api, +func (api *API) NewRTPSender(track *Track, transport *DTLSTransport) (*RTPSender, error) { + if track == nil { + return nil, fmt.Errorf("Track must not be nil") + } else if transport == nil { + return nil, fmt.Errorf("DTLSTransport must not be nil") } - r.Track.sampleInput = make(chan media.Sample, 15) // Is the buffering needed? - r.Track.rawInput = make(chan *rtp.Packet, 15) // Is the buffering needed? - r.Track.rtcpInput = make(chan rtcp.Packet, 15) // Is the buffering needed? - - r.Track.Samples = r.Track.sampleInput - r.Track.RawRTP = r.Track.rawInput - r.Track.RTCPPackets = r.Track.rtcpInput - - if r.Track.isRawRTP { - close(r.Track.Samples) - } else { - close(r.Track.RawRTP) + track.mu.RLock() + defer track.mu.RUnlock() + if track.receiver != nil { + return nil, fmt.Errorf("RTPSender can not be constructed with remote track") } - return r + return &RTPSender{ + track: track, + transport: transport, + api: api, + sendCalled: make(chan interface{}), + stopCalled: make(chan interface{}), + }, nil } // Send Attempts to set the parameters controlling the sending of media. -func (r *RTPSender) Send(parameters RTPSendParameters) { - if r.Track.isRawRTP { - go r.handleRawRTP(r.Track.rawInput) - } else { - go r.handleSampleRTP(r.Track.sampleInput) +func (r *RTPSender) Send(parameters RTPSendParameters) error { + r.mu.Lock() + defer r.mu.Unlock() + select { + case <-r.sendCalled: + return fmt.Errorf("Send has already been called") + default: + } + + srtcpSession, err := r.transport.getSRTCPSession() + if err != nil { + return err } + srtcpReadStream, err := srtcpSession.OpenReadStream(parameters.encodings.SSRC) + if err != nil { + return err + } + r.rtcpReadStream = newLossyReadCloser(srtcpReadStream) - go r.handleRTCP(r.transport, r.Track.rtcpInput) + r.track.mu.Lock() + r.track.senders = append(r.track.senders, r) + r.track.mu.Unlock() + + close(r.sendCalled) + return nil } // Stop irreversibly stops the RTPSender -func (r *RTPSender) Stop() { - if r.Track.isRawRTP { - close(r.Track.RawRTP) - } else { - close(r.Track.Samples) +func (r *RTPSender) Stop() error { + r.mu.Lock() + defer r.mu.Unlock() + + select { + case <-r.stopCalled: + return nil + default: } - // TODO properly tear down all loops (and test that) -} - -func (r *RTPSender) handleRawRTP(rtpPackets chan *rtp.Packet) { - for { - p, ok := <-rtpPackets - if !ok { - return + r.track.mu.Lock() + defer r.track.mu.Unlock() + filtered := []*RTPSender{} + for _, s := range r.track.senders { + if s != r { + filtered = append(filtered, s) } - - r.sendRTP(p) } -} + r.track.senders = filtered -func (r *RTPSender) handleSampleRTP(rtpPackets chan media.Sample) { - packetizer := rtp.NewPacketizer( - rtpOutboundMTU, - r.Track.PayloadType, - r.Track.SSRC, - r.Track.Codec.Payloader, - rtp.NewRandomSequencer(), - r.Track.Codec.ClockRate, - ) - - for { - in, ok := <-rtpPackets - if !ok { - return - } - packets := packetizer.Packetize(in.Data, in.Samples) - for _, p := range packets { - r.sendRTP(p) - } + select { + case <-r.sendCalled: + return r.rtcpReadStream.Close() + default: } + close(r.stopCalled) + return nil } -func (r *RTPSender) handleRTCP(transport *DTLSTransport, rtcpPackets chan rtcp.Packet) { - srtcpSession, err := transport.getSRTCPSession() - if err != nil { - pcLog.Warnf("Failed to open SRTCPSession, Track done for: %v %d \n", err, r.Track.SSRC) - return +// Read reads incoming RTCP for this RTPReceiver +func (r *RTPSender) Read(b []byte) (n int, err error) { + select { + case <-r.stopCalled: + return 0, fmt.Errorf("RTPSender has been stopped") + case <-r.sendCalled: + return r.rtcpReadStream.Read(b) } +} - readStream, err := srtcpSession.OpenReadStream(r.Track.SSRC) +// ReadRTCP is a convenience method that wraps Read and unmarshals for you +func (r *RTPSender) ReadRTCP() (rtcp.Packet, error) { + b := make([]byte, receiveMTU) + i, err := r.Read(b) if err != nil { - pcLog.Warnf("Failed to open RTCP ReadStream, Track done for: %v %d \n", err, r.Track.SSRC) - return + return nil, err } - var rtcpPacket rtcp.Packet - for { - rtcpBuf := make([]byte, receiveMTU) - i, err := readStream.Read(rtcpBuf) - if err != nil { - pcLog.Warnf("Failed to read, Track done for: %v %d \n", err, r.Track.SSRC) - return - } + pkt, _, err := rtcp.Unmarshal(b[:i]) + return pkt, err +} - rtcpPacket, _, err = rtcp.Unmarshal(rtcpBuf[:i]) +// sendRTP should only be called by a track, this only exists so we can keep state in one place +func (r *RTPSender) sendRTP(b []byte) (int, error) { + select { + case <-r.stopCalled: + return 0, fmt.Errorf("RTPSender has been stopped") + case <-r.sendCalled: + srtpSession, err := r.transport.getSRTPSession() if err != nil { - pcLog.Warnf("Failed to unmarshal RTCP packet, discarding: %v \n", err) - continue + return 0, err } - select { - case rtcpPackets <- rtcpPacket: - default: + writeStream, err := srtpSession.OpenWriteStream() + if err != nil { + return 0, err } - } - -} - -func (r *RTPSender) sendRTP(packet *rtp.Packet) { - srtpSession, err := r.transport.getSRTPSession() - if err != nil { - pcLog.Warnf("SendRTP failed to open SrtpSession: %v", err) - return - } - - writeStream, err := srtpSession.OpenWriteStream() - if err != nil { - pcLog.Warnf("SendRTP failed to open WriteStream: %v", err) - return - } - if _, err := writeStream.WriteRTP(&packet.Header, packet.Payload); err != nil { - pcLog.Warnf("SendRTP failed to write: %v", err) + return writeStream.Write(b) } } diff --git a/rtptranceiver.go b/rtptranceiver.go index de524ddc153..e090b77d7e7 100644 --- a/rtptranceiver.go +++ b/rtptranceiver.go @@ -17,7 +17,7 @@ type RTPTransceiver struct { } func (t *RTPTransceiver) setSendingTrack(track *Track) error { - t.Sender.Track = track + t.Sender.track = track switch t.Direction { case RTPTransceiverDirectionRecvonly: @@ -33,7 +33,9 @@ func (t *RTPTransceiver) setSendingTrack(track *Track) error { // Stop irreversibly stops the RTPTransceiver func (t *RTPTransceiver) Stop() error { if t.Sender != nil { - t.Sender.Stop() + if err := t.Sender.Stop(); err != nil { + return err + } } if t.Receiver != nil { if err := t.Receiver.Stop(); err != nil { diff --git a/track.go b/track.go index 4253392296a..d3f04f518b5 100644 --- a/track.go +++ b/track.go @@ -1,76 +1,186 @@ package webrtc import ( - "crypto/rand" - "encoding/binary" + "fmt" + "sync" - "github.com/pions/rtcp" "github.com/pions/rtp" "github.com/pions/webrtc/pkg/media" - "github.com/pkg/errors" ) -// Track represents a track that is communicated +const rtpOutboundMTU = 1400 + +// Track represents a single media track type Track struct { - isRawRTP bool - sampleInput chan media.Sample - rawInput chan *rtp.Packet - rtcpInput chan rtcp.Packet - - ID string - PayloadType uint8 - Kind RTPCodecType - Label string - SSRC uint32 - Codec *RTPCodec - - Packets <-chan *rtp.Packet - RTCPPackets <-chan rtcp.Packet - - Samples chan<- media.Sample - RawRTP chan<- *rtp.Packet + mu sync.RWMutex + + id string + payloadType uint8 + kind RTPCodecType + label string + ssrc uint32 + codec *RTPCodec + + packetizer rtp.Packetizer + receiver *RTPReceiver + senders []*RTPSender } -// NewRawRTPTrack initializes a new *Track configured to accept raw *rtp.Packet -// -// NB: If the source RTP stream is being broadcast to multiple tracks, each track -// must receive its own copies of the source packets in order to avoid packet corruption. -func NewRawRTPTrack(payloadType uint8, ssrc uint32, id, label string, codec *RTPCodec) (*Track, error) { - if ssrc == 0 { - return nil, errors.New("SSRC supplied to NewRawRTPTrack() must be non-zero") +// ID gets the ID of the track +func (t *Track) ID() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.id +} + +// PayloadType gets the PayloadType of the track +func (t *Track) PayloadType() uint8 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.payloadType +} + +// Kind gets the Kind of the track +func (t *Track) Kind() RTPCodecType { + t.mu.RLock() + defer t.mu.RUnlock() + return t.kind +} + +// Label gets the Label of the track +func (t *Track) Label() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.label +} + +// SSRC gets the SSRC of the track +func (t *Track) SSRC() uint32 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.ssrc +} + +// Codec gets the Codec of the track +func (t *Track) Codec() *RTPCodec { + t.mu.RLock() + defer t.mu.RUnlock() + return t.codec +} + +// Read reads data from the track. If this is a local track this will error +func (t *Track) Read(b []byte) (n int, err error) { + t.mu.RLock() + if len(t.senders) != 0 { + t.mu.RUnlock() + return 0, fmt.Errorf("this is a local track and must not be read from") } + r := t.receiver + t.mu.RUnlock() - return &Track{ - isRawRTP: true, - - ID: id, - PayloadType: payloadType, - Kind: codec.Type, - Label: label, - SSRC: ssrc, - Codec: codec, - }, nil + return r.readRTP(b) } -// NewSampleTrack initializes a new *Track configured to accept media.Sample -func NewSampleTrack(payloadType uint8, id, label string, codec *RTPCodec) (*Track, error) { - if codec == nil { - return nil, errors.New("codec supplied to NewSampleTrack() must not be nil") +// ReadRTP is a convenience method that wraps Read and unmarshals for you +func (t *Track) ReadRTP() (*rtp.Packet, error) { + b := make([]byte, receiveMTU) + i, err := t.Read(b) + if err != nil { + return nil, err } - buf := make([]byte, 4) - if _, err := rand.Read(buf); err != nil { - return nil, errors.New("failed to generate random value") + r := &rtp.Packet{} + if err := r.Unmarshal(b[:i]); err != nil { + return nil, err } + return r, nil +} + +// Write writes data to the track. If this is a remote track this will error +func (t *Track) Write(b []byte) (n int, err error) { + t.mu.RLock() + if t.receiver != nil { + t.mu.RUnlock() + return 0, fmt.Errorf("this is a remote track and must not be written to") + } + senders := t.senders + t.mu.RUnlock() + + for _, s := range senders { + if _, err := s.sendRTP(b); err != nil { + return 0, err + } + } + + return len(b), nil +} + +// WriteSample packetizes and writes to the track +func (t *Track) WriteSample(s media.Sample) error { + packets := t.packetizer.Packetize(s.Data, s.Samples) + for _, p := range packets { + buf, err := p.Marshal() + if err != nil { + return err + } + if _, err := t.Write(buf); err != nil { + return err + } + } + + return nil +} + +// WriteRTP writes RTP packets to the track +func (t *Track) WriteRTP(p *rtp.Packet) error { + buf, err := p.Marshal() + if err != nil { + return err + } + if _, err := t.Write(buf); err != nil { + return err + } + + return nil +} + +// NewTrack initializes a new *Track +func NewTrack(payloadType uint8, ssrc uint32, id, label string, codec *RTPCodec) (*Track, error) { + if ssrc == 0 { + return nil, fmt.Errorf("SSRC supplied to NewTrack() must be non-zero") + } + + packetizer := rtp.NewPacketizer( + rtpOutboundMTU, + payloadType, + ssrc, + codec.Payloader, + rtp.NewRandomSequencer(), + codec.ClockRate, + ) return &Track{ - isRawRTP: false, - - ID: id, - PayloadType: payloadType, - Kind: codec.Type, - Label: label, - SSRC: binary.LittleEndian.Uint32(buf), - Codec: codec, + id: id, + payloadType: payloadType, + kind: codec.Type, + label: label, + ssrc: ssrc, + codec: codec, + packetizer: packetizer, }, nil } + +// determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track +// this is useful if we are dealing with a remote track and we can't announce it to the user until we know the payloadType +func (t *Track) determinePayloadType() error { + r, err := t.ReadRTP() + if err != nil { + return err + } + + t.mu.Lock() + t.payloadType = r.PayloadType + defer t.mu.Unlock() + + return nil +}