27
27
from .addressing import parse_host_port , unparse_host_port
28
28
from .core import Comm , Connector , Listener , CommClosedError , FatalCommClosedError
29
29
from .utils import to_frames , from_frames , get_tcp_server_address , ensure_concrete_host
30
+ from ..protocol .utils import pack_frames_prelude , unpack_frames
30
31
31
32
32
33
logger = logging .getLogger (__name__ )
@@ -187,19 +188,16 @@ async def read(self, deserializers=None):
187
188
if stream is None :
188
189
raise CommClosedError
189
190
191
+ fmt = "Q"
192
+ fmt_size = struct .calcsize (fmt )
193
+
190
194
try :
191
- n_frames = await stream .read_bytes (8 )
192
- n_frames = struct .unpack ("Q" , n_frames )[0 ]
193
- lengths = await stream .read_bytes (8 * n_frames )
194
- lengths = struct .unpack ("Q" * n_frames , lengths )
195
-
196
- frames = []
197
- for length in lengths :
198
- frame = bytearray (length )
199
- if length :
200
- n = await stream .read_into (frame )
201
- assert n == length , (n , length )
202
- frames .append (frame )
195
+ frames_nbytes = await stream .read_bytes (fmt_size )
196
+ (frames_nbytes ,) = struct .unpack (fmt , frames_nbytes )
197
+
198
+ frames = bytearray (frames_nbytes )
199
+ n = await stream .read_into (frames )
200
+ assert n == frames_nbytes , (n , frames_nbytes )
203
201
except StreamClosedError as e :
204
202
self .stream = None
205
203
self ._closed = True
@@ -214,6 +212,8 @@ async def read(self, deserializers=None):
214
212
raise
215
213
else :
216
214
try :
215
+ frames = unpack_frames (frames )
216
+
217
217
msg = await from_frames (
218
218
frames ,
219
219
deserialize = self .deserialize ,
@@ -243,30 +243,28 @@ async def write(self, msg, serializers=None, on_error="message"):
243
243
** self .handshake_options ,
244
244
},
245
245
)
246
+ frames_nbytes = sum (map (nbytes , frames ))
246
247
247
- try :
248
- nframes = len (frames )
249
- lengths = [nbytes (frame ) for frame in frames ]
250
- length_bytes = struct .pack (f"Q{ nframes } Q" , nframes , * lengths )
248
+ header = pack_frames_prelude (frames )
249
+ header = struct .pack ("Q" , nbytes (header ) + frames_nbytes ) + header
251
250
252
- frames = [length_bytes , * frames ]
253
- lengths = [ len ( length_bytes ), * lengths ]
251
+ frames = [header , * frames ]
252
+ frames_nbytes += nbytes ( header )
254
253
255
- if sum (lengths ) < 2 ** 17 : # 128kiB
256
- # small enough, send in one go
257
- stream .write (b"" .join (frames ))
258
- else :
259
- # avoid large memcpy, send in many
260
- for frame , frame_bytes in zip (frames , lengths ):
261
- # Can't wait for the write() Future as it may be lost
262
- # ("If write is called again before that Future has resolved,
263
- # the previous future will be orphaned and will never resolve")
264
- if frame_bytes :
265
- future = stream .write (frame )
266
- bytes_since_last_yield += frame_bytes
267
- if bytes_since_last_yield > 32e6 :
268
- await future
269
- bytes_since_last_yield = 0
254
+ if frames_nbytes < 2 ** 17 : # 128kiB
255
+ # small enough, send in one go
256
+ frames = [b"" .join (frames )]
257
+
258
+ try :
259
+ # trick to enque all frames for writing beforehand
260
+ for each_frame in frames :
261
+ each_frame_nbytes = nbytes (each_frame )
262
+ if each_frame_nbytes :
263
+ stream ._write_buffer .append (each_frame )
264
+ stream ._total_write_index += each_frame_nbytes
265
+
266
+ # start writing frames
267
+ stream .write (b"" )
270
268
except StreamClosedError as e :
271
269
self .stream = None
272
270
self ._closed = True
@@ -282,7 +280,7 @@ async def write(self, msg, serializers=None, on_error="message"):
282
280
self .abort ()
283
281
raise
284
282
285
- return sum ( lengths )
283
+ return frames_nbytes
286
284
287
285
@gen .coroutine
288
286
def close (self ):
0 commit comments