diff --git a/src/httpz.zig b/src/httpz.zig index f74f46a..f76105a 100644 --- a/src/httpz.zig +++ b/src/httpz.zig @@ -557,6 +557,12 @@ pub fn Server(comptime H: type) type { res.write() catch { conn.handover = .close; }; + + if (req.unread_body > 0 and conn.handover == .keepalive) { + drain(&req) catch { + conn.handover = .close; + }; + } } pub fn middleware(self: *Self, comptime M: type, config: M.Config) !Middleware(H) { @@ -714,6 +720,21 @@ const FallbackAllocator = struct { } }; +// Called when we have unread bytes on the request and want to keepalive the +// connection. Only happens when lazy_read_size is configured and the client +// didn't read the [whole] body +// There should already be a receive timeout on the socket since the only +// way for this to be +fn drain(req: *Request) !void { + var r = try req.reader(2000); + var buf: [4096]u8 = undefined; + while (true) { + if (try r.read(&buf) == 0) { + return; + } + } +} + const t = @import("t.zig"); var global_test_allocator = std.heap.GeneralPurposeAllocator(.{}){}; @@ -739,13 +760,10 @@ test "tests:beforeAll" { const ga = global_test_allocator.allocator(); { - default_server = try Server(void).init(ga, .{ - .port = 5992, - .request = .{ - .lazy_read_size = 4_096, - .max_body_size = 1_048_576, - } - }, {}); + default_server = try Server(void).init(ga, .{ .port = 5992, .request = .{ + .lazy_read_size = 4_096, + .max_body_size = 1_048_576, + } }, {}); // only need to do this because we're using listenInNewThread instead // of blocking here. So the array to hold the middleware needs to outlive @@ -772,7 +790,7 @@ test "tests:beforeAll" { router.method("PING", "/test/method", TestDummyHandler.method, .{}); router.get("/test/query", TestDummyHandler.reqQuery, .{}); router.get("/test/stream", TestDummyHandler.eventStream, .{}); - router.get("/test/req_stream", TestDummyHandler.reqStream, .{}); + router.get("/test/req_reader", TestDummyHandler.reqReader, .{}); router.get("/test/chunked", TestDummyHandler.chunked, .{}); router.get("/test/route_data", TestDummyHandler.routeData, .{ .data = &TestDummyHandler.RouteData{ .power = 12345 } }); router.all("/test/cors", TestDummyHandler.jsonRes, .{ .middlewares = cors }); @@ -1306,12 +1324,12 @@ test "httpz: custom handle" { try t.expectString("HTTP/1.1 200 \r\nContent-Length: 9\r\n\r\nhello teg", testReadAll(stream, &buf)); } -test "httpz: request body streaming" { +test "httpz: request body reader" { { // no body const stream = testStream(5992); defer stream.close(); - try stream.writeAll("GET /test/req_stream HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); + try stream.writeAll("GET /test/req_reader HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); var res = testReadParsed(stream); defer res.deinit(); @@ -1322,7 +1340,7 @@ test "httpz: request body streaming" { // small body const stream = testStream(5992); defer stream.close(); - try stream.writeAll("GET /test/req_stream HTTP/1.1\r\nContent-Length: 4\r\n\r\n123z"); + try stream.writeAll("GET /test/req_reader HTTP/1.1\r\nContent-Length: 4\r\n\r\n123z"); var res = testReadParsed(stream); defer res.deinit(); @@ -1336,7 +1354,7 @@ test "httpz: request body streaming" { for (0..10) |_| { const stream = testStream(5992); defer stream.close(); - var req: []const u8 = "GET /test/req_stream HTTP/1.1\r\nContent-Length: 20000\r\n\r\n" ++ ("a" ** 20_000); + var req: []const u8 = "GET /test/req_reader HTTP/1.1\r\nContent-Length: 20000\r\n\r\n" ++ ("a" ** 20_000); while (req.len > 0) { const len = random.uintAtMost(usize, req.len - 1) + 1; const n = stream.write(req[0..len]) catch |err| switch (err) { @@ -1351,7 +1369,6 @@ test "httpz: request body streaming" { defer res.deinit(); try res.expectJson(.{ .length = 20_000 }); } - } test "websocket: invalid request" { @@ -1549,16 +1566,22 @@ const TestDummyHandler = struct { try res.startEventStream(StreamContext{ .data = "hello" }, StreamContext.handle); } - fn reqStream(req: *Request, res: *Response) !void { - var stream = try req.streamBody(); - defer stream.deinit(); + fn reqReader(req: *Request, res: *Response) !void { + var reader = try req.reader(2000); var l: usize = 0; var buf: [1024]u8 = undefined; - while (try stream.read(&buf)) |data| { - l += data.len; + while (true) { + const n = try reader.read(&buf); + if (n == 0) { + break; + } + if (req.body_len > 10 and std.mem.indexOfNonePos(u8, buf[0..n], 0, "a") != null) { + return error.InvalidData; + } + l += n; } - return res.json(.{.length = l}, .{}); + return res.json(.{ .length = l }, .{}); } const StreamContext = struct { diff --git a/src/request.zig b/src/request.zig index ab685ff..80f639f 100644 --- a/src/request.zig +++ b/src/request.zig @@ -48,10 +48,10 @@ pub const Request = struct { body_buffer: ?buffer.Buffer = null, body_len: usize = 0, - // True if we haven't read the [full] body yet. This can only happen when + // The number of unread bytes from the body. This can only happen when // lazy_read_size is configured and the request is larger that this value. // There can still be _part_ of the body in body_buffer. - lazy_body: bool, + unread_body: usize, // cannot use an optional on qs, because it's pre-allocated so always exists qs_read: bool = false, @@ -95,7 +95,7 @@ pub const Request = struct { .fd = &state.fd, .mfd = &state.mfd, .method = state.method.?, - .lazy_body = state.lazy_body, + .unread_body = state.unread_body, .method_string = state.method_string orelse "", .protocol = state.protocol.?, .url = Url.parse(state.url.?), @@ -175,7 +175,9 @@ pub const Request = struct { return self.parseMultiFormData(); } - pub fn streamBody(self: *Request) !Stream { + pub const Reader = std.io.Reader(*BodyReader, BodyReader.Error, BodyReader.read); + + pub fn reader(self: *Request, timeout_ms: usize) !Reader { var buf: []const u8 = &.{}; if (self.body_buffer) |bb| { std.debug.assert(bb.type == .static); @@ -183,16 +185,24 @@ pub const Request = struct { } const conn = self.conn; - if (self.lazy_body == true) { + if (self.unread_body > 0) { try conn.blockingMode(); + const timeval = std.mem.toBytes(std.posix.timeval{ + .sec = @intCast(@divTrunc(timeout_ms, 1000)), + .usec = @intCast(@mod(timeout_ms, 1000) * 1000), + }); + try std.posix.setsockopt(conn.stream.handle, std.posix.SOL.SOCKET, std.posix.SO.RCVTIMEO, &timeval); } - return .{ + const r = try self.arena.create(BodyReader); + r.* = .{ .req = self, .buffer = buf, .remaining = self.body_len, .socket = conn.stream.handle, }; + + return .{ .context = r}; } // OK, this is a bit complicated. @@ -493,39 +503,35 @@ pub const Request = struct { }; } - pub const Stream = struct { + pub const BodyReader = struct { req: *Request, + socket: std.posix.socket_t, remaining: usize, buffer: []const u8, - socket: std.posix.socket_t, - pub fn deinit(self: *Stream) void { - self.req.conn.nonblockingMode() catch {}; - } + pub const Error = std.posix.ReadError; - pub fn read(self: *Stream, into: []u8) !?[]u8 { + pub fn read(self: *BodyReader, into: []u8) Error!usize { const b = self.buffer; const remaining = self.remaining; + if (b.len != 0) { const l = @min(b.len, into.len); - - const buf = into[0..l]; - @memcpy(buf, b[0..l]); - + @memcpy(into[0..l], b[0..l]); self.buffer = b[l..]; self.remaining = remaining - l; - return buf; + return l; } if (remaining == 0) { - return null; + return 0; } - var buf = if (into.len > remaining) into[0..remaining] else into; + const buf = if (into.len > remaining) into[0..remaining] else into; const n = try std.posix.read(self.socket, buf); self.remaining = remaining - n; - return if (n == 0) null else buf[0..n]; + return n; } }; }; @@ -595,9 +601,9 @@ pub const State = struct { // know what it is from the content-length header body_len: usize, - // True if we aren't reading the body. Happens when lazy_read_size is enabled - // and we get a large body. It'll be up to the app to read it! - lazy_body: bool, + // Happens when lazy_read_size is enabled and we get a large body. + // It'll be up to the app to read it! + unread_body: usize, middlewares: std.StringHashMap(*anyopaque), @@ -614,7 +620,7 @@ pub const State = struct { .method = null, .method_string = "", .protocol = null, - .lazy_body = false, + .unread_body = 0, .buffer_pool = buffer_pool, .lazy_read_size = config.lazy_read_size, .max_body_size = config.max_body_size orelse 1_048_576, @@ -641,7 +647,7 @@ pub const State = struct { self.len = 0; self.url = null; self.method = null; - self.lazy_body = false; + self.unread_body = 0; self.method_string = null; self.protocol = null; @@ -794,7 +800,7 @@ pub const State = struct { self.pos = space + 1; self.method = .OTHER; self.method_string = candidate; - } + }, } return true; } @@ -953,30 +959,31 @@ pub const State = struct { const buf = self.buf; // how much (if any) of the body we've already read + const read = len - pos; + + if (read > cl) { + return error.InvalidContentLength; + } + + // how much of the body are we missing + const missing = cl - read; + if (self.lazy_read_size) |lazy_read| { if (cl >= lazy_read) { self.pos = len; - self.lazy_body = true; + self.unread_body = missing; self.body = .{ .type = .static, .data = buf[pos..len] }; return true; } } - const read = len - pos; - if (read == cl) { + if (missing == 0) { // we've read the entire body into buf, point to that. self.pos = len; self.body = .{ .type = .static, .data = buf[pos..len] }; return true; } - if (read > cl) { - return error.InvalidContentLength; - } - - // how much of the body are we missing - const missing = cl - read; - // how much spare space we have in our static buffer const spare = buf.len - len; if (missing < spare) { @@ -1606,7 +1613,7 @@ test "request: fuzz" { const number_of_requests = random.uintAtMost(u8, 10) + 1; for (0..number_of_requests) |_| { - defer ctx.conn.requestDone(4096); + defer ctx.conn.requestDone(4096, true) catch unreachable; const method = randomMethod(random); const url = t.randomString(random, aa, 20); diff --git a/src/response.zig b/src/response.zig index 9ff9dc3..b344289 100644 --- a/src/response.zig +++ b/src/response.zig @@ -407,7 +407,6 @@ pub const Response = struct { }; }; - // All the upfront memory allocation that we can do. Gets re-used from request // to request. pub const State = struct { diff --git a/src/t.zig b/src/t.zig index d8e9de9..366fc44 100644 --- a/src/t.zig +++ b/src/t.zig @@ -111,6 +111,7 @@ pub const Context = struct { .ws_worker = undefined, .conn_arena = ctx_arena, .req_arena = std.heap.ArenaAllocator.init(aa), + ._io_mode = if (httpz.blockingMode()) .blocking else .nonblocking, }; return .{ diff --git a/src/worker.zig b/src/worker.zig index dc83819..6a2bd04 100644 --- a/src/worker.zig +++ b/src/worker.zig @@ -212,7 +212,8 @@ pub fn Blocking(comptime S: type, comptime WSH: type) type { var is_keepalive = false; while (true) { - defer conn.requestDone(self.retain_allocated_bytes_keepalive); + // impossible for this to fail in blocking mode + defer conn.requestDone(self.retain_allocated_bytes_keepalive, false) catch unreachable; switch (self.handleRequest(conn, is_keepalive, thread_buf) catch .close) { .keepalive => { is_keepalive = true; @@ -481,7 +482,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { .max_conn = config.workers.max_conn orelse 8_192, .timeout_request = config.timeout.request orelse MAX_TIMEOUT, .timeout_keepalive = config.timeout.keepalive orelse MAX_TIMEOUT, - .retain_allocated_bytes = config.workers.retain_allocated_bytes orelse 8192, + .retain_allocated_bytes = config.workers.retain_allocated_bytes orelse 8192, }; } @@ -572,8 +573,11 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { const stream = http_conn.stream; const done = http_conn.req_state.parse(http_conn.req_arena.allocator(), stream) catch |err| { + // maybe a write fail or something, doesn't matter, we're closing the connection requestError(http_conn, err) catch {}; - http_conn.requestDone(self.retain_allocated_bytes); + + // impossible to fail when false is passed + http_conn.requestDone(self.retain_allocated_bytes, false) catch unreachable; conn.close(); self.disown(conn); continue; @@ -586,7 +590,6 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { self.swapList(conn, .active); thread_pool.spawn(.{ self, now, conn }); - }, .websocket => thread_pool.spawn(.{ self, now, conn }), }, @@ -671,6 +674,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { const http_conn = try self.http_conn_pool.acquire(); http_conn.request_count = 1; http_conn._state = .request; + http_conn._io_mode = .nonblocking; http_conn.address = address; http_conn.socket_flags = socket_flags; http_conn.stream = .{ .handle = socket }; @@ -771,9 +775,16 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { metrics.request(); http_conn.request_count += 1; self.server.handleRequest(http_conn, thread_buf); - http_conn.requestDone(self.retain_allocated_bytes); - switch (http_conn.handover) { + var handover = http_conn.handover; + http_conn.requestDone(self.retain_allocated_bytes, handover == .keepalive or handover == .websocket) catch { + // This means we failed to put the connection into + // nonblocking mode. Rare, but safer to clos the connection + // at this point. + handover = .close; + }; + + switch (handover) { .keepalive => { http_conn.timeout = now + self.timeout_keepalive; self.swapList(conn, .keepalive); @@ -827,7 +838,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { } // Enforces timeouts, and returns when the next timeout should be checked. - fn prepareToWait(self: *Self, now: u32) struct {bool, ?i32} { + fn prepareToWait(self: *Self, now: u32) struct { bool, ?i32 } { const request_timed_out, const request_count, const request_timeout = collectTimedOut(&self.request_list, now); const keepalive_timed_out, const keepalive_count, const keepalive_timeout = blk: { @@ -850,17 +861,17 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { } if (request_timeout == null and keepalive_timeout == null) { - return .{closed, null}; + return .{ closed, null }; } const next = @min(request_timeout orelse MAX_TIMEOUT, keepalive_timeout orelse MAX_TIMEOUT); if (next < now) { // can happen if a socket was just about to timeout when prepareToWait // was called - return .{closed, 1}; + return .{ closed, 1 }; } - return .{closed, @intCast(next - now)}; + return .{ closed, @intCast(next - now) }; } // lists are ordered from soonest to timeout to last, as soon as we find @@ -1047,7 +1058,7 @@ fn KQueue(comptime WSH: type) type { try self.change(conn.getSocket(), @intFromPtr(conn), posix.system.EVFILT.READ, posix.system.EV.ADD | posix.system.EV.ENABLE, 0); } - fn rearmRead(self: *Self, conn: *Conn(WSH)) !void{ + fn rearmRead(self: *Self, conn: *Conn(WSH)) !void { // called from the worker thread, can't use change_buffer _ = try posix.kevent(self.fd, &.{.{ .ident = @intCast(conn.getSocket()), @@ -1215,7 +1226,7 @@ fn EPoll(comptime WSH: type) type { return posix.epoll_ctl(self.fd, linux.EPOLL.CTL_ADD, conn.getSocket(), &event); } - fn rearmRead(self: *Self, conn: *Conn(WSH)) !void{ + fn rearmRead(self: *Self, conn: *Conn(WSH)) !void { var event = linux.epoll_event{ .data = .{ .ptr = @intFromPtr(conn) }, .events = linux.EPOLL.IN | linux.EPOLL.RDHUP | linux.EPOLL.ONESHOT, @@ -1483,11 +1494,18 @@ pub const HTTPConn = struct { websocket: *anyopaque, }; - // can be concurrently accessed, use getState and setState + pub const IOMode = enum { + blocking, + nonblocking, + }; + + // can be concurrently accessed, use getState _state: State, _mut: Thread.Mutex, + _io_mode: IOMode, + handover: Handover, // unix timestamp (seconds) where this connection should timeout @@ -1554,6 +1572,7 @@ pub const HTTPConn = struct { .res_state = res_state, .req_arena = req_arena, .conn_arena = conn_arena, + ._io_mode = if (httpz.blockingMode()) .blocking else .nonblocking, }; } @@ -1568,20 +1587,16 @@ pub const HTTPConn = struct { allocator.destroy(self.conn_arena); } - pub fn requestDone(self: *HTTPConn, retain_allocated_bytes: usize) void { + pub fn requestDone(self: *HTTPConn, retain_allocated_bytes: usize, revert_blocking: bool) !void { self.req_state.reset(); self.res_state.reset(); _ = self.req_arena.reset(.{ .retain_with_limit = retain_allocated_bytes }); + if (revert_blocking) { + try self.nonblockingMode(); + } } - // getting put back into the pool - pub fn reset(self: *HTTPConn) void { - self.handover = .unknown; - self.stream = undefined; - self.address = undefined; - } - - pub fn writeAll(self: *const HTTPConn, data: []const u8) !void { + pub fn writeAll(self: *HTTPConn, data: []const u8) !void { const socket = self.stream.handle; var i: usize = 0; @@ -1602,25 +1617,16 @@ pub const HTTPConn = struct { std.debug.assert(n != 0); i += n; } - - // if write fails, and we're in blocking, it doesn't really matter - // we're going to be closing connction anyways - if (blocking) { - try self.nonblockingMode(); - } } - pub fn writeAllIOVec(self: *const HTTPConn, vec: []posix.iovec_const) !void { + pub fn writeAllIOVec(self: *HTTPConn, vec: []posix.iovec_const) !void { const socket = self.stream.handle; var i: usize = 0; - var blocking = false; - while (true) { var n = posix.writev(socket, vec[i..]) catch |err| switch (err) { error.WouldBlock => { try self.blockingMode(); - blocking = true; continue; }, else => return err, @@ -1630,9 +1636,6 @@ pub const HTTPConn = struct { n -= vec[i].len; i += 1; if (i >= vec.len) { - if (blocking) { - try self.nonblockingMode(); - } return; } } @@ -1641,22 +1644,30 @@ pub const HTTPConn = struct { } } - pub fn blockingMode(self: *const HTTPConn) !void { + pub fn blockingMode(self: *HTTPConn) !void { if (comptime httpz.blockingMode() == true) { // When httpz is in blocking mode, than we always keep the socket in // blocking mode return; } + if (self._io_mode == .blocking) { + return; + } _ = try posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))); + self._io_mode = .blocking; } - pub fn nonblockingMode(self: *const HTTPConn) !void { + pub fn nonblockingMode(self: *HTTPConn) !void { if (comptime httpz.blockingMode() == true) { // When httpz is in blocking mode, than we always keep the socket in // blocking mode return; } + if (self._io_mode == .nonblocking) { + return; + } _ = try posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags); + self._io_mode = .nonblocking; } }; @@ -1792,7 +1803,7 @@ const TestNode = struct { fn alloc(id: i32) *TestNode { const tn = t.allocator.create(TestNode) catch unreachable; - tn.* = .{.id = id}; + tn.* = .{ .id = id }; return tn; } };