30
30
31
31
class VideoStreamTrack (MediaStreamTrack ):
32
32
kind = "video"
33
+ def __init__ (self , track : MediaStreamTrack , pipeline ):
34
+ super ().__init__ ()
35
+ self .track = track
36
+ self .pipeline = pipeline
37
+ asyncio .create_task (self .collect_frames ())
38
+
39
+ async def collect_frames (self ):
40
+ while True :
41
+ frame = await self .track .recv ()
42
+ await self .pipeline .put_video_frame (frame )
33
43
44
+ async def recv (self ):
45
+ return await self .pipeline .get_processed_video_frame ()
46
+
47
+
48
+ class AudioStreamTrack (MediaStreamTrack ):
49
+ kind = "audio"
34
50
def __init__ (self , track : MediaStreamTrack , pipeline ):
35
51
super ().__init__ ()
36
52
self .track = track
37
53
self .pipeline = pipeline
54
+ asyncio .create_task (self .collect_frames ())
55
+
56
+ async def collect_frames (self ):
57
+ while True :
58
+ frame = await self .track .recv ()
59
+ await self .pipeline .put_audio_frame (frame )
38
60
39
61
async def recv (self ):
40
- frame = await self .track .recv ()
41
- return await self .pipeline (frame )
62
+ return await self .pipeline .get_processed_audio_frame ()
42
63
43
64
44
65
def force_codec (pc , sender , forced_codec ):
@@ -87,8 +108,7 @@ async def offer(request):
87
108
88
109
params = await request .json ()
89
110
90
- pipeline .set_prompt (params ["prompt" ])
91
- await pipeline .warm ()
111
+ await pipeline .set_prompts (params ["prompts" ])
92
112
93
113
offer_params = params ["offer" ]
94
114
offer = RTCSessionDescription (sdp = offer_params ["sdp" ], type = offer_params ["type" ])
@@ -103,17 +123,19 @@ async def offer(request):
103
123
104
124
pcs .add (pc )
105
125
106
- tracks = {"video" : None }
126
+ tracks = {"video" : None , "audio" : None }
107
127
108
- # Prefer h264
109
- transceiver = pc .addTransceiver ("video" )
110
- caps = RTCRtpSender .getCapabilities ("video" )
111
- prefs = list (filter (lambda x : x .name == "H264" , caps .codecs ))
112
- transceiver .setCodecPreferences (prefs )
128
+ # Only add video transceiver if video is present in the offer
129
+ if "m=video" in offer .sdp :
130
+ # Prefer h264
131
+ transceiver = pc .addTransceiver ("video" )
132
+ caps = RTCRtpSender .getCapabilities ("video" )
133
+ prefs = list (filter (lambda x : x .name == "H264" , caps .codecs ))
134
+ transceiver .setCodecPreferences (prefs )
113
135
114
- # Monkey patch max and min bitrate to ensure constant bitrate
115
- h264 .MAX_BITRATE = MAX_BITRATE
116
- h264 .MIN_BITRATE = MIN_BITRATE
136
+ # Monkey patch max and min bitrate to ensure constant bitrate
137
+ h264 .MAX_BITRATE = MAX_BITRATE
138
+ h264 .MIN_BITRATE = MIN_BITRATE
117
139
118
140
# Handle control channel from client
119
141
@pc .on ("datachannel" )
@@ -131,13 +153,13 @@ async def on_message(message):
131
153
"nodes" : nodes_info
132
154
}
133
155
channel .send (json .dumps (response ))
134
- elif params .get ("type" ) == "update_prompt " :
135
- if "prompt " not in params :
156
+ elif params .get ("type" ) == "update_prompts " :
157
+ if "prompts " not in params :
136
158
logger .warning ("[Control] Missing prompt in update_prompt message" )
137
159
return
138
- pipeline .set_prompt (params ["prompt " ])
160
+ await pipeline .update_prompts (params ["prompts " ])
139
161
response = {
140
- "type" : "prompt_updated " ,
162
+ "type" : "prompts_updated " ,
141
163
"success" : True
142
164
}
143
165
channel .send (json .dumps (response ))
@@ -158,6 +180,10 @@ def on_track(track):
158
180
159
181
codec = "video/H264"
160
182
force_codec (pc , sender , codec )
183
+ elif track .kind == "audio" :
184
+ audioTrack = AudioStreamTrack (track , pipeline )
185
+ tracks ["audio" ] = audioTrack
186
+ pc .addTrack (audioTrack )
161
187
162
188
@track .on ("ended" )
163
189
async def on_ended ():
@@ -175,6 +201,11 @@ async def on_connectionstatechange():
175
201
176
202
await pc .setRemoteDescription (offer )
177
203
204
+ if "m=audio" in pc .remoteDescription .sdp :
205
+ await pipeline .warm_audio ()
206
+ if "m=video" in pc .remoteDescription .sdp :
207
+ await pipeline .warm_video ()
208
+
178
209
answer = await pc .createAnswer ()
179
210
await pc .setLocalDescription (answer )
180
211
@@ -190,7 +221,7 @@ async def set_prompt(request):
190
221
pipeline = request .app ["pipeline" ]
191
222
192
223
prompt = await request .json ()
193
- pipeline .set_prompt (prompt )
224
+ await pipeline .set_prompts (prompt )
194
225
195
226
return web .Response (content_type = "application/json" , text = "OK" )
196
227
0 commit comments