diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index a9aba35050ad..1c5ace05fc0e 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -161,7 +161,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To) - if request.Method == "POST" && sessionId != "" { + if request.Method == "POST" && sessionId != "" { // stream-up, packet-up seq := "" if len(subpath) > 1 { seq = subpath[1] @@ -173,8 +173,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.WriteHeader(http.StatusBadRequest) return } + uploadDone := done.New() err = currentSession.uploadQueue.Push(Packet{ - Reader: request.Body, + Reader: &httpRequestBodyReader{ + requestReader: request.Body, + uploadDone: uploadDone, + }, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)") @@ -199,8 +203,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } }() } - <-request.Context().Done() + select { + case <-request.Context().Done(): + case <-uploadDone.Wait(): + } } + uploadDone.Close() return } @@ -243,7 +251,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } writer.WriteHeader(http.StatusOK) - } else if request.Method == "GET" || sessionId == "" { + } else if request.Method == "GET" || sessionId == "" { // stream-down, stream-one responseFlusher, ok := writer.(http.Flusher) if !ok { panic("expected http.ResponseWriter to be an http.Flusher") @@ -283,7 +291,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req reader: request.Body, remoteAddr: remoteAddr, } - if sessionId != "" { + if sessionId != "" { // if not stream-one conn.reader = currentSession.uploadQueue } @@ -302,6 +310,20 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } } +type httpRequestBodyReader struct { + requestReader io.ReadCloser + uploadDone *done.Instance +} + +func (c *httpRequestBodyReader) Read(b []byte) (int, error) { + return c.requestReader.Read(b) +} + +func (c *httpRequestBodyReader) Close() error { + defer c.uploadDone.Close() + return c.requestReader.Close() +} + type httpResponseBodyWriter struct { sync.Mutex responseWriter http.ResponseWriter