From 7bde1f4dd65a4c2c3211ff339a69ebc90818a855 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:23:00 +0800 Subject: [PATCH 1/5] feat: Zig-native multipart/form-data and urlencoded parser - Add zig/src/multipart.zig: RFC 2046 multipart parser + RFC 7578 form-data handling + application/x-www-form-urlencoded parser - Zero-copy part splitting, boundary extraction, Content-Disposition parsing for name/filename/content-type - Wire into server.zig callPythonHandler: detect Content-Type header, parse multipart or urlencoded bodies, pass form_fields dict and file_fields list (with raw bytes) to Python kwargs - File fields passed as [{name, filename, content_type, body}, ...] so Python can populate UploadFile instances with byte-identical data Co-authored-by: trilokagent <275208033+trilokagent@users.noreply.github.com> --- zig/src/multipart.zig | 339 ++++++++++++++++++++++++++++++++++++++++++ zig/src/server.zig | 260 ++++++++++++++++++++++---------- 2 files changed, 518 insertions(+), 81 deletions(-) create mode 100644 zig/src/multipart.zig diff --git a/zig/src/multipart.zig b/zig/src/multipart.zig new file mode 100644 index 0000000..e7af22c --- /dev/null +++ b/zig/src/multipart.zig @@ -0,0 +1,339 @@ +const std = @import("std"); + +pub const FormField = struct { + name: []const u8, + value: []const u8, +}; + +pub const FileField = struct { + name: []const u8, + filename: []const u8, + content_type: []const u8, + body: []const u8, +}; + +pub const MultipartResult = struct { + fields: []FormField, + files: []FileField, + + pub fn deinit(self: *const MultipartResult, alloc: std.mem.Allocator) void { + alloc.free(self.fields); + alloc.free(self.files); + } +}; + +pub const UrlencodedResult = struct { + fields: []FormField, + + pub fn deinit(self: *const UrlencodedResult, alloc: std.mem.Allocator) void { + alloc.free(self.fields); + } +}; + +pub fn extractBoundary(content_type: []const u8) ?[]const u8 { + const marker: []const u8 = "boundary="; + const idx = std.mem.indexOf(u8, content_type, marker) orelse return null; + var boundary = content_type[idx + marker.len ..]; + + if (boundary.len > 1 and boundary[0] == '"') { + boundary = boundary[1..]; + const end_quote = std.mem.indexOfScalar(u8, boundary, '"') orelse boundary.len; + boundary = boundary[0..end_quote]; + } else { + const semi = std.mem.indexOfScalar(u8, boundary, ';') orelse boundary.len; + boundary = boundary[0..semi]; + boundary = std.mem.trim(u8, boundary, " \t"); + } + + if (boundary.len == 0) return null; + return boundary; +} + +pub fn parseMultipart(alloc: std.mem.Allocator, body: []const u8, boundary: []const u8) !MultipartResult { + var fields: std.ArrayListUnmanaged(FormField) = .empty; + var files: std.ArrayListUnmanaged(FileField) = .empty; + errdefer { + fields.deinit(alloc); + files.deinit(alloc); + } + + var delim_buf: [128]u8 = undefined; + if (boundary.len + 2 > delim_buf.len) return error.BoundaryTooLong; + @memcpy(delim_buf[0..2], "--"); + @memcpy(delim_buf[2 .. boundary.len + 2], boundary); + const delim: []const u8 = delim_buf[0 .. boundary.len + 2]; + + var pos: usize = 0; + + const first_delim = std.mem.indexOf(u8, body, delim) orelse return MultipartResult{ .fields = fields.items, .files = files.items }; + pos = first_delim + delim.len; + pos += std.mem.indexOf(u8, body[pos..], "\r\n") orelse 2; + pos += 2; + + while (pos < body.len) { + if (pos + delim.len + 2 <= body.len) { + const candidate = body[pos .. pos + delim.len]; + if (std.mem.eql(u8, candidate, delim)) { + if (pos + delim.len + 2 <= body.len and + body[pos + delim.len] == '-' and body[pos + delim.len + 1] == '-') + { + break; + } + pos += delim.len; + if (pos + 2 <= body.len and body[pos] == '\r' and body[pos + 1] == '\n') { + pos += 2; + } + continue; + } + } + + const header_end = std.mem.indexOf(u8, body[pos..], "\r\n\r\n") orelse break; + const part_headers = body[pos .. pos + header_end]; + const part_body_start = pos + header_end + 4; + + const next_delim = std.mem.indexOf(u8, body[part_body_start..], delim) orelse break; + var part_body = body[part_body_start .. part_body_start + next_delim]; + if (part_body.len >= 2 and part_body[part_body.len - 2] == '\r' and part_body[part_body.len - 1] == '\n') { + part_body = part_body[0 .. part_body.len - 2]; + } + + pos = part_body_start + next_delim + delim.len; + if (pos + 2 <= body.len and body[pos] == '\r' and body[pos + 1] == '\n') { + pos += 2; + } + + var field_name: []const u8 = ""; + var filename: ?[]const u8 = null; + var ct: []const u8 = "application/octet-stream"; + + var hdr_pos: usize = 0; + while (hdr_pos < part_headers.len) { + const line_end = std.mem.indexOf(u8, part_headers[hdr_pos..], "\r\n") orelse part_headers.len - hdr_pos; + const line = part_headers[hdr_pos .. hdr_pos + line_end]; + hdr_pos += line_end + 2; + + const colon = std.mem.indexOfScalar(u8, line, ':') orelse continue; + const hname = std.mem.trim(u8, line[0..colon], " \t"); + const hvalue = std.mem.trim(u8, line[colon + 1 ..], " \t"); + + if (std.ascii.eqlIgnoreCase(hname, "content-disposition")) { + const name_marker: []const u8 = "name=\""; + if (std.mem.indexOf(u8, hvalue, name_marker)) |ni| { + const start = ni + name_marker.len; + const end_quote = std.mem.indexOfScalar(u8, hvalue[start..], '"') orelse hvalue.len - start; + field_name = hvalue[start .. start + end_quote]; + } + const filename_marker: []const u8 = "filename=\""; + if (std.mem.indexOf(u8, hvalue, filename_marker)) |fi| { + const start = fi + filename_marker.len; + const end_quote = std.mem.indexOfScalar(u8, hvalue[start..], '"') orelse hvalue.len - start; + filename = hvalue[start .. start + end_quote]; + } + } else if (std.ascii.eqlIgnoreCase(hname, "content-type")) { + ct = hvalue; + } + } + + if (filename) |fn_val| { + try files.append(alloc, .{ + .name = field_name, + .filename = fn_val, + .content_type = ct, + .body = part_body, + }); + } else { + try fields.append(alloc, .{ + .name = field_name, + .value = part_body, + }); + } + } + + return MultipartResult{ + .fields = fields.items, + .files = files.items, + }; +} + +fn hexVal(c: u8) ?u8 { + return switch (c) { + '0'...'9' => c - '0', + 'a'...'f' => c - 'a' + 10, + 'A'...'F' => c - 'A' + 10, + else => null, + }; +} + +pub fn parseUrlencoded(alloc: std.mem.Allocator, body: []const u8) !UrlencodedResult { + var fields: std.ArrayListUnmanaged(FormField) = .empty; + errdefer fields.deinit(alloc); + + if (body.len == 0) return UrlencodedResult{ .fields = fields.items }; + + var pos: usize = 0; + while (pos < body.len) { + const amp = std.mem.indexOfScalar(u8, body[pos..], '&') orelse body.len - pos; + const pair = body[pos .. pos + amp]; + pos += amp + 1; + + if (pair.len == 0) continue; + + const eq = std.mem.indexOfScalar(u8, pair, '=') orelse pair.len; + const key_raw = pair[0..eq]; + const val_raw = if (eq < pair.len) pair[eq + 1 ..] else pair[0..0]; + + const decoded_key = try percentDecodeAlloc(alloc, key_raw); + const decoded_val = try percentDecodeAlloc(alloc, val_raw); + + try fields.append(alloc, .{ .name = decoded_key, .value = decoded_val }); + } + + return UrlencodedResult{ .fields = fields.items }; +} + +pub fn percentDecodeAlloc(alloc: std.mem.Allocator, src: []const u8) ![]const u8 { + var result: std.ArrayListUnmanaged(u8) = .empty; + defer result.deinit(alloc); + try result.ensureTotalCapacity(alloc, src.len); + + var i: usize = 0; + while (i < src.len) : (i += 1) { + if (src[i] == '%' and i + 2 < src.len) { + const hi = hexVal(src[i + 1]) orelse { + result.appendAssumeCapacity(src[i]); + continue; + }; + const lo = hexVal(src[i + 2]) orelse { + result.appendAssumeCapacity(src[i]); + continue; + }; + result.appendAssumeCapacity(hi << 4 | lo); + i += 2; + } else if (src[i] == '+') { + result.appendAssumeCapacity(' '); + } else { + result.appendAssumeCapacity(src[i]); + } + } + + return result.toOwnedSlice(alloc); +} + +fn percentDecodeInPlace(src: []const u8, dst: []u8) []const u8 { + var si: usize = 0; + var di: usize = 0; + while (si < src.len) : (si += 1) { + if (src[si] == '%' and si + 2 < src.len) { + const hi = hexVal(src[si + 1]) orelse { + dst[di] = src[si]; + di += 1; + continue; + }; + const lo = hexVal(src[si + 2]) orelse { + dst[di] = src[si]; + di += 1; + continue; + }; + dst[di] = hi << 4 | lo; + di += 1; + si += 2; + } else if (src[si] == '+') { + dst[di] = ' '; + di += 1; + } else { + dst[di] = src[si]; + di += 1; + } + } + return dst[0..di]; +} + +test "multipart: simple form field" { + const body = "--boundary\r\n" ++ + \\Content-Disposition: form-data; name="username" + ++ + "\r\n" ++ + \\john + ++ + "\r\n--boundary--\r\n"; + var result = try parseMultipart(std.testing.allocator, body, "boundary"); + defer result.deinit(std.testing.allocator); + try std.testing.expectEqual(1, result.fields.len); + try std.testing.expectEqualStrings("username", result.fields[0].name); + try std.testing.expectEqualStrings("john", result.fields[0].value); + try std.testing.expectEqual(0, result.files.len); +} + +test "multipart: file upload" { + const body = "--boundary\r\n" ++ + \\Content-Disposition: form-data; name="file"; filename="test.txt" + ++ + "\r\n" ++ + \\Content-Type: text/plain + ++ + "\r\n" ++ + \\hello world + ++ + "\r\n--boundary--\r\n"; + var result = try parseMultipart(std.testing.allocator, body, "boundary"); + defer result.deinit(std.testing.allocator); + try std.testing.expectEqual(0, result.fields.len); + try std.testing.expectEqual(1, result.files.len); + try std.testing.expectEqualStrings("file", result.files[0].name); + try std.testing.expectEqualStrings("test.txt", result.files[0].filename); + try std.testing.expectEqualStrings("text/plain", result.files[0].content_type); + try std.testing.expectEqualStrings("hello world", result.files[0].body); +} + +test "multipart: mixed fields and files" { + const body = "--boundary\r\n" ++ + \\Content-Disposition: form-data; name="title" + ++ + "\r\n" ++ + \\My Doc + ++ + "\r\n--boundary\r\n" ++ + \\Content-Disposition: form-data; name="upload"; filename="doc.pdf" + ++ + "\r\n" ++ + \\Content-Type: application/pdf + ++ + "\r\n" ++ + \\%PDF-1.4 + ++ + "\r\n--boundary--\r\n"; + var result = try parseMultipart(std.testing.allocator, body, "boundary"); + defer result.deinit(std.testing.allocator); + try std.testing.expectEqual(1, result.fields.len); + try std.testing.expectEqualStrings("title", result.fields[0].name); + try std.testing.expectEqualStrings("My Doc", result.fields[0].value); + try std.testing.expectEqual(1, result.files.len); + try std.testing.expectEqualStrings("upload", result.files[0].name); + try std.testing.expectEqualStrings("doc.pdf", result.files[0].filename); +} + +test "urlencoded: simple" { + const body = "name=alice&age=30"; + var result = try parseUrlencoded(std.testing.allocator, body); + defer result.deinit(std.testing.allocator); + try std.testing.expectEqual(2, result.fields.len); + try std.testing.expectEqualStrings("name", result.fields[0].name); + try std.testing.expectEqualStrings("alice", result.fields[0].value); + try std.testing.expectEqualStrings("age", result.fields[1].name); + try std.testing.expectEqualStrings("30", result.fields[1].value); +} + +test "urlencoded: percent encoding" { + const body = "q=hello+world&email=test%40example.com"; + var result = try parseUrlencoded(std.testing.allocator, body); + defer result.deinit(std.testing.allocator); + try std.testing.expectEqual(2, result.fields.len); + try std.testing.expectEqualStrings("hello world", result.fields[0].value); + try std.testing.expectEqualStrings("test@example.com", result.fields[1].value); +} + +test "boundary extraction" { + try std.testing.expectEqualStrings("----WebKitFormBoundary7MA4YWxkTrZu0gW", extractBoundary("multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW").?); + try std.testing.expectEqualStrings("abc123", extractBoundary("multipart/form-data; boundary=\"abc123\"").?); + try std.testing.expect(extractBoundary("application/json") == null); +} diff --git a/zig/src/server.zig b/zig/src/server.zig index c865e7a..1fb52d6 100644 --- a/zig/src/server.zig +++ b/zig/src/server.zig @@ -9,6 +9,7 @@ const core = @import("turboapi-core"); const router_mod = core.router; const dhi = @import("dhi_validator.zig"); const db = @import("db.zig"); +const multipart_mod = @import("multipart.zig"); const allocator = std.heap.c_allocator; @@ -597,7 +598,8 @@ pub fn server_add_static_route(_: ?*c.PyObject, args: ?*c.PyObject) callconv(.c) const st: u16 = if (status >= 100 and status <= 599) @intCast(status) else 200; const status_text = statusText(st); - const response_bytes = std.fmt.allocPrint(allocator, + const response_bytes = std.fmt.allocPrint( + allocator, "HTTP/1.1 {d} {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nConnection: keep-alive\r\n\r\n{s}", .{ st, status_text, ct_s, body_s.len, body_s }, ) catch return null; @@ -650,12 +652,13 @@ pub fn server_configure_cors(_: ?*c.PyObject, args: ?*c.PyObject) callconv(.c) ? var age_buf: [16]u8 = undefined; const age_str = std.fmt.bufPrint(&age_buf, "{d}", .{max_age}) catch "600"; - cors_headers = std.fmt.allocPrint(allocator, + cors_headers = std.fmt.allocPrint( + allocator, "\r\nAccess-Control-Allow-Origin: {s}" ++ - "\r\nAccess-Control-Allow-Methods: {s}" ++ - "\r\nAccess-Control-Allow-Headers: {s}" ++ - "{s}" ++ - "\r\nAccess-Control-Max-Age: {s}", + "\r\nAccess-Control-Allow-Methods: {s}" ++ + "\r\nAccess-Control-Allow-Headers: {s}" ++ + "{s}" ++ + "\r\nAccess-Control-Max-Age: {s}", .{ origins_s, methods_s, hdrs_s, cred_hdr, age_str }, ) catch return null; cors_enabled = true; @@ -703,10 +706,11 @@ fn renderResponse(status: u16, content_type: []const u8, body: []const u8) ?[]co const dw = [7][]const u8{ "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" }; const mn = [12][]const u8{ "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; const dt = std.fmt.bufPrint(&date_buf, "{s}, {d:0>2} {s} {d} {d:0>2}:{d:0>2}:{d:0>2} GMT", .{ - dw[di], md.day_index + 1, mn[@intFromEnum(md.month) - 1], yd.year, + dw[di], md.day_index + 1, mn[@intFromEnum(md.month) - 1], yd.year, ds.getHoursIntoDay(), ds.getMinutesIntoHour(), ds.getSecondsIntoMinute(), }) catch "Thu, 01 Jan 2026 00:00:00 GMT"; - return std.fmt.allocPrint(allocator, + return std.fmt.allocPrint( + allocator, "HTTP/1.1 {d} {s}\r\nServer: TurboAPI\r\nDate: {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nConnection: keep-alive{s}\r\n\r\n{s}", .{ status, statusText(status), dt, content_type, body.len, cors, body }, ) catch null; @@ -827,7 +831,6 @@ pub fn server_run(_: ?*c.PyObject, _: ?*c.PyObject) callconv(.c) ?*c.PyObject { return py.pyNone(); } - const HeaderList = std.ArrayListUnmanaged(HeaderPair); fn parseHeaders(request_data: []const u8, first_line_end: usize, header_end_pos: usize) HeaderList { @@ -1112,7 +1115,9 @@ fn handleOneRequest(stream: std.net.Stream, tstate: ?*anyopaque) !void { if (ms.get(match.handler_key)) |schema| { const vr = dhi.validateJsonRetainParsed(body, &schema); switch (vr) { - .ok => |parsed| { cached_parse = parsed; }, + .ok => |parsed| { + cached_parse = parsed; + }, .err => |ve| { defer ve.deinit(); std.debug.print("[DHI] validation failed for {s}\n", .{match.handler_key}); @@ -1798,6 +1803,93 @@ fn callPythonHandler(tstate: ?*anyopaque, entry: HandlerEntry, method: []const u } _ = c.PyDict_SetItemString(kwargs, "path_params", py_path_params); + // ── Multipart / urlencoded form parsing ── + var content_type_lower: [256]u8 = undefined; + var ct_len: usize = 0; + for (headers) |h| { + if (std.ascii.eqlIgnoreCase(h.name, "content-type")) { + ct_len = @min(h.value.len, content_type_lower.len); + for (h.value[0..ct_len], 0..) |ch, i| { + content_type_lower[i] = std.ascii.toLower(ch); + } + break; + } + } + const req_ct_slice = content_type_lower[0..ct_len]; + + if (std.mem.startsWith(u8, req_ct_slice, "multipart/form-data")) { + if (multipart_mod.extractBoundary(req_ct_slice)) |boundary| { + if (body.len > 0) { + const mp_opt = multipart_mod.parseMultipart(allocator, body, boundary) catch null; + if (mp_opt) |mp| { + defer mp.deinit(allocator); + + // form_fields dict: {name: value, ...} + const py_form_fields = c.PyDict_New() orelse return errorResponse(err_ct, err_body); + defer c.Py_DecRef(py_form_fields); + for (mp.fields) |f| { + const fk = py.newString(f.name) orelse continue; + const fv = py.newString(f.value) orelse { + c.Py_DecRef(fk); + continue; + }; + _ = c.PyDict_SetItem(py_form_fields, fk, fv); + c.Py_DecRef(fk); + c.Py_DecRef(fv); + } + _ = c.PyDict_SetItemString(kwargs, "form_fields", py_form_fields); + + // file_fields list: [{name, filename, content_type, body}, ...] + const py_file_list = c.PyList_New(@intCast(mp.files.len)) orelse return errorResponse(err_ct, err_body); + defer c.Py_DecRef(py_file_list); + for (mp.files, 0..) |f, i| { + const file_dict = c.PyDict_New() orelse continue; + if (py.newString(f.name)) |v| { + _ = c.PyDict_SetItemString(file_dict, "name", v); + c.Py_DecRef(v); + } + if (py.newString(f.filename)) |v| { + _ = c.PyDict_SetItemString(file_dict, "filename", v); + c.Py_DecRef(v); + } + if (py.newString(f.content_type)) |v| { + _ = c.PyDict_SetItemString(file_dict, "content_type", v); + c.Py_DecRef(v); + } + const file_bytes = c.PyBytes_FromStringAndSize(@ptrCast(f.body.ptr), @intCast(f.body.len)); + if (file_bytes) |fb| { + _ = c.PyDict_SetItemString(file_dict, "body", fb); + c.Py_DecRef(fb); + } + _ = c.PyList_SetItem(py_file_list, @intCast(i), file_dict); + } + _ = c.PyDict_SetItemString(kwargs, "file_fields", py_file_list); + } + } + } + } else if (std.mem.startsWith(u8, req_ct_slice, "application/x-www-form-urlencoded")) { + if (body.len > 0) { + const ue_opt = multipart_mod.parseUrlencoded(allocator, body) catch null; + if (ue_opt) |ue| { + defer ue.deinit(allocator); + + const py_form_fields = c.PyDict_New() orelse return errorResponse(err_ct, err_body); + defer c.Py_DecRef(py_form_fields); + for (ue.fields) |f| { + const fk = py.newString(f.name) orelse continue; + const fv = py.newString(f.value) orelse { + c.Py_DecRef(fk); + continue; + }; + _ = c.PyDict_SetItem(py_form_fields, fk, fv); + c.Py_DecRef(fk); + c.Py_DecRef(fv); + } + _ = c.PyDict_SetItemString(kwargs, "form_fields", py_form_fields); + } + } + } + // ── Call handler with PyObject_Call(handler, empty_tuple, kwargs) ── const empty_tuple = c.PyTuple_New(0) orelse return errorResponse(err_ct, err_body); defer c.Py_DecRef(empty_tuple); @@ -1965,12 +2057,13 @@ pub fn sendResponse(stream: std.net.Stream, status: u16, content_type: []const u const mon_str = mon_names[@intFromEnum(month_day.month) - 1]; // RFC 2822: "Wed, 19 Mar 2026 11:30:27 GMT" const date_str = std.fmt.bufPrint(&date_buf, "{s}, {d:0>2} {s} {d} {d:0>2}:{d:0>2}:{d:0>2} GMT", .{ - dow_str, month_day.day_index + 1, mon_str, year_day.year, + dow_str, month_day.day_index + 1, mon_str, year_day.year, day_secs.getHoursIntoDay(), day_secs.getMinutesIntoHour(), day_secs.getSecondsIntoMinute(), }) catch "Thu, 01 Jan 2026 00:00:00 GMT"; var header_buf: [512]u8 = undefined; - const header = std.fmt.bufPrint(&header_buf, + const header = std.fmt.bufPrint( + &header_buf, "HTTP/1.1 {d} {s}\r\nServer: TurboAPI\r\nDate: {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nConnection: keep-alive", .{ status, statusText(status), date_str, content_type, body.len }, ) catch return; @@ -1982,15 +2075,15 @@ pub fn sendResponse(stream: std.net.Stream, status: u16, content_type: []const u if (total <= 4096) { var resp_buf: [4096]u8 = undefined; var pos: usize = 0; - @memcpy(resp_buf[pos..pos + header.len], header); + @memcpy(resp_buf[pos .. pos + header.len], header); pos += header.len; if (cors.len > 0) { - @memcpy(resp_buf[pos..pos + cors.len], cors); + @memcpy(resp_buf[pos .. pos + cors.len], cors); pos += cors.len; } - @memcpy(resp_buf[pos..pos + trailer.len], trailer); + @memcpy(resp_buf[pos .. pos + trailer.len], trailer); pos += trailer.len; - @memcpy(resp_buf[pos..pos + body.len], body); + @memcpy(resp_buf[pos .. pos + body.len], body); pos += body.len; stream.writeAll(resp_buf[0..pos]) catch return; } else { @@ -2045,7 +2138,6 @@ test "response cache is safe under concurrent access" { try std.testing.expectEqualStrings("{\"item_id\":2}", getCachedResponse("GET /items/2").?); } - // ── Fuzz tests ─────────────────────────────────────────────────────────────── // Run: zig build fuzz-http (then execute the binary with --fuzz) // @@ -2061,54 +2153,58 @@ fn fuzz_percentDecode(_: void, input: []const u8) anyerror!void { try std.testing.expect(out.len <= buf.len); // Output must be a subslice of buf const buf_start = @intFromPtr(&buf); - const buf_end = buf_start + buf.len; + const buf_end = buf_start + buf.len; const out_start = @intFromPtr(out.ptr); try std.testing.expect(out_start >= buf_start and out_start <= buf_end); } test "fuzz: percentDecode — output bounded, no OOB" { - try std.testing.fuzz({}, fuzz_percentDecode, .{ .corpus = &.{ - "%00", // null byte - "%GG", // invalid hex digits - "%", // bare percent at end of input - "%2", // truncated percent sequence - "hello+world", // plus → space - "a%20b%20c", // spaces - "%FF%FE%FD", // high bytes - &([_]u8{'%'} ** 200), // 200 bare percents - "%2F%2F..%2F..%2Fetc%2Fpasswd", // path traversal - "%00%00%00", // three null bytes - }}); + try std.testing.fuzz({}, fuzz_percentDecode, .{ + .corpus = &.{ + "%00", // null byte + "%GG", // invalid hex digits + "%", // bare percent at end of input + "%2", // truncated percent sequence + "hello+world", // plus → space + "a%20b%20c", // spaces + "%FF%FE%FD", // high bytes + &([_]u8{'%'} ** 200), // 200 bare percents + "%2F%2F..%2F..%2Fetc%2Fpasswd", // path traversal + "%00%00%00", // three null bytes + }, + }); } fn fuzz_queryStringGet(_: void, input: []const u8) anyerror!void { // Split: first 16 bytes = key, remainder = query string const split = @min(input.len, 16); const key = input[0..split]; - const qs = if (split < input.len) input[split..] else ""; + const qs = if (split < input.len) input[split..] else ""; const result = queryStringGet(qs, key); if (result) |v| { // Returned slice must be within the query string buffer const qs_start = @intFromPtr(qs.ptr); - const qs_end = qs_start + qs.len; - const v_start = @intFromPtr(v.ptr); + const qs_end = qs_start + qs.len; + const v_start = @intFromPtr(v.ptr); try std.testing.expect(v_start >= qs_start and v_start <= qs_end); } } test "fuzz: queryStringGet — result is within input, no panic" { - try std.testing.fuzz({}, fuzz_queryStringGet, .{ .corpus = &.{ - "key" ++ "key=value", - "x" ++ "x=1&y=2&z=3", - "a" ++ "a=&b=c", - "k" ++ "k", - "" ++ "=value", - "foo" ++ "foo=bar&foo=baz", // duplicate key - "q" ++ "q=" ++ ("A" ** 2000), // very long value - "k" ++ "k=\x00\xFF", // binary values - "k" ++ "&&&&&", // no values, only separators - }}); + try std.testing.fuzz({}, fuzz_queryStringGet, .{ + .corpus = &.{ + "key" ++ "key=value", + "x" ++ "x=1&y=2&z=3", + "a" ++ "a=&b=c", + "k" ++ "k", + "" ++ "=value", + "foo" ++ "foo=bar&foo=baz", // duplicate key + "q" ++ "q=" ++ ("A" ** 2000), // very long value + "k" ++ "k=\x00\xFF", // binary values + "k" ++ "&&&&&", // no values, only separators + }, + }); } fn fuzz_requestLineParsing(_: void, input: []const u8) anyerror!void { @@ -2123,13 +2219,13 @@ fn fuzz_requestLineParsing(_: void, input: []const u8) anyerror!void { const first_line = input[0..first_line_end]; var parts = std.mem.splitScalar(u8, first_line, ' '); - const method = parts.next() orelse return; + const method = parts.next() orelse return; const raw_path = parts.next() orelse return; _ = method; // Split path from query string at '?' - const q_idx = std.mem.indexOf(u8, raw_path, "?"); - const path = if (q_idx) |i| raw_path[0..i] else raw_path; + const q_idx = std.mem.indexOf(u8, raw_path, "?"); + const path = if (q_idx) |i| raw_path[0..i] else raw_path; const query_string = if (q_idx) |i| raw_path[i + 1 ..] else ""; _ = path; _ = query_string; @@ -2150,38 +2246,40 @@ fn fuzz_requestLineParsing(_: void, input: []const u8) anyerror!void { } test "fuzz: HTTP request-line and header parsing — no panic on malformed input" { - try std.testing.fuzz({}, fuzz_requestLineParsing, .{ .corpus = &.{ - // Minimal valid GET - "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", - // Valid POST with body - "POST /items HTTP/1.1\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}", - // Missing HTTP version token - "GET /\r\n\r\n", - // Empty method - " / HTTP/1.1\r\n\r\n", - // Huge Content-Length (parser must cap it) - "POST / HTTP/1.1\r\nContent-Length: 99999999999999999999\r\n\r\n", - // Negative Content-Length (parseInt → error → 0) - "POST / HTTP/1.1\r\nContent-Length: -1\r\n\r\n", - // CRLF injection attempt in header value - "GET / HTTP/1.1\r\nX-Header: value\r\nInjected: header\r\n\r\n", - // Header with no colon (should be skipped) - "GET / HTTP/1.1\r\nMalformedHeaderLine\r\n\r\n", - // Null byte in path - "GET /\x00secret HTTP/1.1\r\n\r\n", - // Very long path (> 8KB header buffer) - "GET /" ++ ("a" ** 7000) ++ " HTTP/1.1\r\n\r\n", - // Very long header value - "GET / HTTP/1.1\r\nX-Custom: " ++ ("B" ** 7000) ++ "\r\n\r\n", - // Bare \n instead of \r\n - "GET / HTTP/1.1\nHost: x\n\n", - // No path at all - "GET HTTP/1.1\r\n\r\n", - // Method with no space - "GETHTTP/1.1\r\n\r\n", - // Percent-encoded path - "GET /users%2F42 HTTP/1.1\r\n\r\n", - // Query string with adversarial chars - "GET /search?q=%00&limit=-1&page=\xFF HTTP/1.1\r\n\r\n", - }}); + try std.testing.fuzz({}, fuzz_requestLineParsing, .{ + .corpus = &.{ + // Minimal valid GET + "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", + // Valid POST with body + "POST /items HTTP/1.1\r\nContent-Type: application/json\r\nContent-Length: 2\r\n\r\n{}", + // Missing HTTP version token + "GET /\r\n\r\n", + // Empty method + " / HTTP/1.1\r\n\r\n", + // Huge Content-Length (parser must cap it) + "POST / HTTP/1.1\r\nContent-Length: 99999999999999999999\r\n\r\n", + // Negative Content-Length (parseInt → error → 0) + "POST / HTTP/1.1\r\nContent-Length: -1\r\n\r\n", + // CRLF injection attempt in header value + "GET / HTTP/1.1\r\nX-Header: value\r\nInjected: header\r\n\r\n", + // Header with no colon (should be skipped) + "GET / HTTP/1.1\r\nMalformedHeaderLine\r\n\r\n", + // Null byte in path + "GET /\x00secret HTTP/1.1\r\n\r\n", + // Very long path (> 8KB header buffer) + "GET /" ++ ("a" ** 7000) ++ " HTTP/1.1\r\n\r\n", + // Very long header value + "GET / HTTP/1.1\r\nX-Custom: " ++ ("B" ** 7000) ++ "\r\n\r\n", + // Bare \n instead of \r\n + "GET / HTTP/1.1\nHost: x\n\n", + // No path at all + "GET HTTP/1.1\r\n\r\n", + // Method with no space + "GETHTTP/1.1\r\n\r\n", + // Percent-encoded path + "GET /users%2F42 HTTP/1.1\r\n\r\n", + // Query string with adversarial chars + "GET /search?q=%00&limit=-1&page=\xFF HTTP/1.1\r\n\r\n", + }, + }); } From 39b7e8d16a6947f077a3d7bfe10f0d067c5233e1 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:26:11 +0800 Subject: [PATCH 2/5] feat: wire Form/File/UploadFile resolution in Python + TestClient files= support - Add Form()/File()/UploadFile parameter resolution in both sync and async enhanced_handler (step 3.5 between headers and JSON body) - Skip JSON body parsing when form_fields or file_fields are present - Populate UploadFile from Zig-parsed file parts: write raw bytes to SpooledTemporaryFile, set filename/content_type/size - Support Form() with alias, File() with alias, bare UploadFile type annotation - Add TestClient files= parameter with multipart body encoding - Support mixed files + data fields in multipart requests - Feature-flag pre-check for _has_form_params to avoid overhead on routes that don't use form/file parameters Co-authored-by: trilokagent <275208033+trilokagent@users.noreply.github.com> --- python/turboapi/request_handler.py | 144 +++++++++++++++++++++++++++-- python/turboapi/testclient.py | 74 ++++++++++++--- 2 files changed, 195 insertions(+), 23 deletions(-) diff --git a/python/turboapi/request_handler.py b/python/turboapi/request_handler.py index 8f35ac4..503befc 100644 --- a/python/turboapi/request_handler.py +++ b/python/turboapi/request_handler.py @@ -577,7 +577,6 @@ def format_json_response( return ResponseHandler.format_response(content, status_code, content_type) - _json_dumps = __import__("json").dumps @@ -636,6 +635,7 @@ def create_enhanced_handler(original_handler, route_definition): _param_names = set(sig.parameters.keys()) _has_dependencies = False _has_header_params = False + _has_form_params = False from turboapi.datastructures import Header try: @@ -644,17 +644,31 @@ def create_enhanced_handler(original_handler, route_definition): _has_security = True except ImportError: _has_security = False + try: + from turboapi.datastructures import File, Form + from turboapi.datastructures import UploadFile as _UploadFile + + _has_form_types = True + except ImportError: + _has_form_types = False + for pname, param in sig.parameters.items(): if isinstance(param.default, Header): _has_header_params = True + elif _has_form_types and isinstance(param.default, (Form, File)): + _has_form_params = True + elif ( + _has_form_types + and param.annotation is _UploadFile + or (isinstance(param.annotation, type) and issubclass(param.annotation, _UploadFile)) + ): + _has_form_params = True elif not ( _has_security and ( - isinstance(param.default, (Depends, SecurityBase)) - or get_depends(param) is not None + isinstance(param.default, (Depends, SecurityBase)) or get_depends(param) is not None ) ): - # Plain param (no explicit non-header marker) — may be an implicit header _has_header_params = True if _has_security and ( isinstance(param.default, (Depends, SecurityBase)) or get_depends(param) is not None @@ -691,11 +705,66 @@ async def enhanced_handler(**kwargs): header_params = HeaderParser.parse_headers(headers_dict, sig) parsed_params.update(header_params) - # 4. Parse request body (JSON) + # 3.5. Resolve Form / File / UploadFile parameters from Zig-parsed data + _form_fields = kwargs.get("form_fields", {}) + _file_fields = kwargs.get("file_fields", []) + if _has_form_params: + for pname, param in sig.parameters.items(): + if _has_form_types and isinstance(param.default, Form): + key = param.default.alias or pname + if key in _form_fields: + parsed_params[pname] = _form_fields[key] + elif param.default.default is not ...: + parsed_params[pname] = param.default.default + elif _has_form_types and isinstance(param.default, File): + key = param.default.alias or pname + matched = None + for ff in _file_fields: + if ff.get("name") == key: + matched = ff + break + if matched: + uf = _UploadFile( + filename=matched.get("filename"), + content_type=matched.get( + "content_type", "application/octet-stream" + ), + size=len(matched.get("body", b"")), + ) + uf.file.write(matched.get("body", b"")) + uf.file.seek(0) + parsed_params[pname] = uf + elif param.default.default is not ...: + parsed_params[pname] = param.default.default + elif _has_form_types and ( + param.annotation is _UploadFile + or ( + isinstance(param.annotation, type) + and issubclass(param.annotation, _UploadFile) + ) + ): + key = pname + matched = None + for ff in _file_fields: + if ff.get("name") == key: + matched = ff + break + if matched: + uf = _UploadFile( + filename=matched.get("filename"), + content_type=matched.get( + "content_type", "application/octet-stream" + ), + size=len(matched.get("body", b"")), + ) + uf.file.write(matched.get("body", b"")) + uf.file.seek(0) + parsed_params[pname] = uf + + # 4. Parse request body (JSON) — skip if form data was already parsed if "body" in kwargs: body_data = kwargs["body"] - - if body_data: # Only parse if body is not empty + if body_data and not (_form_fields or _file_fields): parsed_body = RequestBodyParser.parse_json_body(body_data, sig) # Merge parsed body params (body params take precedence) parsed_params.update(parsed_body) @@ -793,9 +862,65 @@ def enhanced_handler(**kwargs): header_params = HeaderParser.parse_headers(headers_dict, sig) parsed_params.update(header_params) - # 4. Parse request body (JSON) + # 3.5. Resolve Form / File / UploadFile parameters from Zig-parsed data + _form_fields = kwargs.get("form_fields", {}) + _file_fields = kwargs.get("file_fields", []) + if _has_form_params: + for pname, param in sig.parameters.items(): + if _has_form_types and isinstance(param.default, Form): + key = param.default.alias or pname + if key in _form_fields: + parsed_params[pname] = _form_fields[key] + elif param.default.default is not ...: + parsed_params[pname] = param.default.default + elif _has_form_types and isinstance(param.default, File): + key = param.default.alias or pname + matched = None + for ff in _file_fields: + if ff.get("name") == key: + matched = ff + break + if matched: + uf = _UploadFile( + filename=matched.get("filename"), + content_type=matched.get( + "content_type", "application/octet-stream" + ), + size=len(matched.get("body", b"")), + ) + uf.file.write(matched.get("body", b"")) + uf.file.seek(0) + parsed_params[pname] = uf + elif param.default.default is not ...: + parsed_params[pname] = param.default.default + elif _has_form_types and ( + param.annotation is _UploadFile + or ( + isinstance(param.annotation, type) + and issubclass(param.annotation, _UploadFile) + ) + ): + key = pname + matched = None + for ff in _file_fields: + if ff.get("name") == key: + matched = ff + break + if matched: + uf = _UploadFile( + filename=matched.get("filename"), + content_type=matched.get( + "content_type", "application/octet-stream" + ), + size=len(matched.get("body", b"")), + ) + uf.file.write(matched.get("body", b"")) + uf.file.seek(0) + parsed_params[pname] = uf + + # 4. Parse request body (JSON) — skip if form data was already parsed body_data = kwargs.get("body", b"") - if body_data: + if body_data and not (_form_fields or _file_fields): parsed_body = RequestBodyParser.parse_json_body(body_data, sig) parsed_params.update(parsed_body) @@ -858,6 +983,7 @@ def enhanced_handler(**kwargs): return enhanced_handler + def create_pos_handler(original_handler): """Minimal positional wrapper for PyObject_Vectorcall dispatch. diff --git a/python/turboapi/testclient.py b/python/turboapi/testclient.py index f5a1ddd..3cdf84b 100644 --- a/python/turboapi/testclient.py +++ b/python/turboapi/testclient.py @@ -6,6 +6,7 @@ import inspect import json +import uuid from typing import Any from urllib.parse import parse_qs, urlencode, urlparse @@ -118,6 +119,7 @@ def _request( params: dict | None = None, json: Any = None, data: dict | None = None, + files: dict | None = None, headers: dict | None = None, cookies: dict | None = None, content: bytes | None = None, @@ -125,23 +127,58 @@ def _request( """Execute a request against the app.""" import asyncio - # Parse URL parsed = urlparse(url) path = parsed.path or "/" query_string = parsed.query or "" - # Add query params if params: if query_string: query_string += "&" + urlencode(params) else: query_string = urlencode(params) - # Build request body body = b"" request_headers = dict(headers or {}) - if json is not None: + if files is not None: + boundary = f"----TurboAPIBoundary{uuid.uuid4().hex[:16]}" + parts = [] + for field_name, file_info in files.items(): + if isinstance(file_info, tuple): + filename, file_content = file_info + if isinstance(file_content, str): + file_content = file_content.encode("utf-8") + file_ct = "application/octet-stream" + if len(file_info) > 2: + file_ct = file_info[2] + elif isinstance(file_info, dict): + filename = file_info.get("filename", "upload") + file_content = file_info.get("content", b"") + if isinstance(file_content, str): + file_content = file_content.encode("utf-8") + file_ct = file_info.get("content_type", "application/octet-stream") + else: + filename = "upload" + file_content = file_info + file_ct = "application/octet-stream" + parts.append( + f"--{boundary}\r\n".encode() + + f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'.encode() + + f"Content-Type: {file_ct}\r\n\r\n".encode() + + file_content + + b"\r\n" + ) + if data: + for k, v in data.items(): + parts.append( + f"--{boundary}\r\n".encode() + + f'Content-Disposition: form-data; name="{k}"\r\n\r\n'.encode() + + str(v).encode("utf-8") + + b"\r\n" + ) + body = b"".join(parts) + f"--{boundary}--\r\n".encode() + request_headers.setdefault("content-type", f"multipart/form-data; boundary={boundary}") + elif json is not None: import json as json_module body = json_module.dumps(json).encode("utf-8") @@ -161,12 +198,12 @@ def _request( request_headers["cookie"] = cookie_str # Issue #104: Check mounted apps (e.g. StaticFiles) before route matching - if hasattr(self.app, '_mounts'): + if hasattr(self.app, "_mounts"): for mount_path, mount_info in self.app._mounts.items(): if path.startswith(mount_path + "/") or path == mount_path: - sub_path = path[len(mount_path):] + sub_path = path[len(mount_path) :] mounted_app = mount_info["app"] - if hasattr(mounted_app, 'get_file'): + if hasattr(mounted_app, "get_file"): result = mounted_app.get_file(sub_path) if result is not None: content_bytes, content_type, size = result @@ -177,17 +214,26 @@ def _request( ) # Issue #102: Serve docs and openapi URLs - if hasattr(self.app, 'openapi_url') and self.app.openapi_url and path == self.app.openapi_url: + if ( + hasattr(self.app, "openapi_url") + and self.app.openapi_url + and path == self.app.openapi_url + ): import json as json_module + schema = self.app.openapi() body = json_module.dumps(schema).encode("utf-8") - return TestResponse(status_code=200, content=body, headers={"content-type": "application/json"}) + return TestResponse( + status_code=200, content=body, headers={"content-type": "application/json"} + ) - if hasattr(self.app, 'docs_url') and self.app.docs_url and path == self.app.docs_url: + if hasattr(self.app, "docs_url") and self.app.docs_url and path == self.app.docs_url: html = f""" {self.app.title} - Swagger UI
""" - return TestResponse(status_code=200, content=html.encode("utf-8"), headers={"content-type": "text/html"}) + return TestResponse( + status_code=200, content=html.encode("utf-8"), headers={"content-type": "text/html"} + ) # Find matching route route, path_params = self._find_route(method.upper(), path) @@ -195,9 +241,9 @@ def _request( return TestResponse(status_code=404, content=b'{"detail":"Not Found"}') # Issue #103: Enforce router-level dependencies - if hasattr(route, 'dependencies') and route.dependencies: + if hasattr(route, "dependencies") and route.dependencies: for dep in route.dependencies: - dep_fn = dep.dependency if hasattr(dep, 'dependency') else dep + dep_fn = dep.dependency if hasattr(dep, "dependency") else dep if dep_fn is not None: try: if inspect.iscoroutinefunction(dep_fn): @@ -289,7 +335,7 @@ def _request( result = handler(**kwargs) except Exception as e: # Issue #100: Check registered custom exception handlers first - if hasattr(self.app, '_exception_handlers'): + if hasattr(self.app, "_exception_handlers"): for exc_class, exc_handler in self.app._exception_handlers.items(): if isinstance(e, exc_class): result = exc_handler(None, e) From 3198036c8410b959c879b422ebafffc0c005badd Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:27:19 +0800 Subject: [PATCH 3/5] test: add byte-level parity tests for multipart file uploads Co-authored-by: trilokagent <275208033+trilokagent@users.noreply.github.com> --- tests/test_multipart_file_upload.py | 193 ++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 tests/test_multipart_file_upload.py diff --git a/tests/test_multipart_file_upload.py b/tests/test_multipart_file_upload.py new file mode 100644 index 0000000..a3d80fc --- /dev/null +++ b/tests/test_multipart_file_upload.py @@ -0,0 +1,193 @@ +import pytest +from turboapi import File, Form, TurboAPI, UploadFile +from turboapi.testclient import TestClient + + +@pytest.fixture +def app(): + return TurboAPI() + + +@pytest.fixture +def client(app): + return TestClient(app) + + +class TestFormParsing: + def test_form_field_urlencoded(self, app, client): + @app.post("/login") + def login(username: str = Form(), password: str = Form()): + return {"username": username, "password": password} + + resp = client.post("/login", data={"username": "alice", "password": "secret"}) + assert resp.status_code == 200 + body = resp.json() + assert body["username"] == "alice" + assert body["password"] == "secret" + + def test_form_field_with_default(self, app, client): + @app.post("/search") + def search(q: str = Form(), limit: int = Form(default=10)): + return {"q": q, "limit": limit} + + resp = client.post("/search", data={"q": "hello"}) + assert resp.status_code == 200 + body = resp.json() + assert body["q"] == "hello" + + def test_form_field_with_alias(self, app, client): + @app.post("/items") + def create_item(item_name: str = Form(alias="item-name")): + return {"item_name": item_name} + + resp = client.post("/items", data={"item-name": "Widget"}) + assert resp.status_code == 200 + assert resp.json()["item_name"] == "Widget" + + +class TestFileUpload: + def test_file_upload_basic(self, app, client): + @app.post("/upload") + def upload(file: UploadFile = File()): + return {"filename": file.filename, "size": file.size} + + content = b"hello world from turboapi" + resp = client.post("/upload", files={"file": ("test.txt", content, "text/plain")}) + assert resp.status_code == 200 + body = resp.json() + assert body["filename"] == "test.txt" + assert body["size"] == len(content) + + def test_file_upload_byte_identical(self, app, client): + @app.post("/upload") + def upload(file: UploadFile = File()): + return file.file.read() + + raw_bytes = bytes(range(256)) + resp = client.post( + "/upload", files={"file": ("binary.bin", raw_bytes, "application/octet-stream")} + ) + assert resp.status_code == 200 + assert resp.content == raw_bytes + + def test_file_upload_with_form_field(self, app, client): + @app.post("/upload") + def upload(file: UploadFile = File(), description: str = Form()): + return {"filename": file.filename, "description": description} + + content = b"file content here" + resp = client.post( + "/upload", + files={"file": ("doc.txt", content, "text/plain")}, + data={"description": "My document"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["filename"] == "doc.txt" + assert body["description"] == "My document" + + def test_multiple_file_uploads(self, app, client): + @app.post("/upload") + def upload(file1: UploadFile = File(), file2: UploadFile = File()): + return { + "file1": file1.filename, + "file2": file2.filename, + } + + resp = client.post( + "/upload", + files={ + "file1": ("a.txt", b"aaa", "text/plain"), + "file2": ("b.txt", b"bbb", "text/plain"), + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["file1"] == "a.txt" + assert body["file2"] == "b.txt" + + def test_file_upload_binary_content(self, app, client): + @app.post("/upload") + def upload(file: UploadFile = File()): + return file.file.read() + + png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + resp = client.post("/upload", files={"file": ("image.png", png_header, "image/png")}) + assert resp.status_code == 200 + assert resp.content == png_header + + def test_upload_file_type_annotation(self, app, client): + @app.post("/upload") + def upload(file: UploadFile): + return {"filename": file.filename, "size": file.size} + + content = b"typed upload" + resp = client.post("/upload", files={"file": ("data.bin", content)}) + assert resp.status_code == 200 + body = resp.json() + assert body["filename"] == "data.bin" + assert body["size"] == len(content) + + +class TestByteLevelParity: + def test_text_file_round_trip(self, app, client): + @app.post("/echo") + def echo(file: UploadFile = File()): + return file.file.read() + + text = "Hello, World! \u00e9\u00e8\u00ea\nLine 2\tTabbed" + encoded = text.encode("utf-8") + resp = client.post("/echo", files={"file": ("hello.txt", encoded, "text/plain")}) + assert resp.status_code == 200 + assert resp.content == encoded + + def test_large_file_round_trip(self, app, client): + @app.post("/echo") + def echo(file: UploadFile = File()): + return file.file.read() + + large = b"A" * 100_000 + resp = client.post("/echo", files={"file": ("large.bin", large)}) + assert resp.status_code == 200 + assert resp.content == large + + def test_empty_file_upload(self, app, client): + @app.post("/upload") + def upload(file: UploadFile = File()): + return {"filename": file.filename, "size": file.size} + + resp = client.post("/upload", files={"file": ("empty.txt", b"", "text/plain")}) + assert resp.status_code == 200 + body = resp.json() + assert body["filename"] == "empty.txt" + assert body["size"] == 0 + + def test_null_bytes_in_file(self, app, client): + @app.post("/echo") + def echo(file: UploadFile = File()): + return file.file.read() + + content = b"\x00\x01\x02\x00\xff\xfe" + resp = client.post("/echo", files={"file": ("nulls.bin", content)}) + assert resp.status_code == 200 + assert resp.content == content + + +class TestUrlencodedParsing: + def test_special_characters(self, app, client): + @app.post("/form") + def form_handler(q: str = Form()): + return {"q": q} + + resp = client.post("/form", data={"q": "hello world & more=yes"}) + assert resp.status_code == 200 + assert resp.json()["q"] == "hello world & more=yes" + + def test_percent_encoded(self, app, client): + @app.post("/form") + def form_handler(email: str = Form()): + return {"email": email} + + resp = client.post("/form", data={"email": "test@example.com"}) + assert resp.status_code == 200 + assert resp.json()["email"] == "test@example.com" From 07b5d7bc4fdf302e2c9d81d5502788dc2c8f9680 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Sat, 11 Apr 2026 23:27:07 +0800 Subject: [PATCH 4/5] feat: optimized form_sync/file_sync handler dispatch Co-authored-by: trilokagent <275208033+trilokagent@users.noreply.github.com> --- python/turboapi/zig_integration.py | 104 ++++++++++++++++++++++++----- zig/src/server.zig | 12 +++- 2 files changed, 97 insertions(+), 19 deletions(-) diff --git a/python/turboapi/zig_integration.py b/python/turboapi/zig_integration.py index 12fbaa9..f36297b 100644 --- a/python/turboapi/zig_integration.py +++ b/python/turboapi/zig_integration.py @@ -31,7 +31,7 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: Returns: (handler_type, param_types, model_info) where: - - handler_type: "simple_sync" | "body_sync" | "model_sync" | "simple_async" | "body_async" | "enhanced" + - handler_type: "simple_sync" | "body_sync" | "model_sync" | "form_sync" | "file_sync" | "simple_async" | "body_async" | "enhanced" - param_types: dict mapping param_name -> type hint string - model_info: dict with "param_name" and "model_class" for model handlers """ @@ -41,6 +41,8 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: param_types = {} needs_body = False has_depends = False + has_form = False + has_file = False model_info = {} # Check for Depends/SecurityBase — forces enhanced path @@ -57,10 +59,33 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: if has_depends: return "enhanced", {}, {} + # Check for Form/File/UploadFile markers + try: + from .datastructures import File, Form, UploadFile + + for pname, param in sig.parameters.items(): + if isinstance(param.default, Form): + has_form = True + param_types[pname] = "form" + elif isinstance(param.default, File): + has_file = True + param_types[pname] = "file" + elif param.annotation is UploadFile or ( + isinstance(param.annotation, type) and issubclass(param.annotation, UploadFile) + ): + has_file = True + param_types[pname] = "file" + except ImportError: + pass + has_implicit_header_params = False for param_name, param in sig.parameters.items(): annotation = param.annotation + if param_types.get(param_name) in ("form", "file"): + needs_body = True + continue + # Check for dhi/Pydantic BaseModel try: if ( @@ -68,13 +93,10 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: and inspect.isclass(annotation) and issubclass(annotation, BaseModel) ): - # Found a model parameter - use fast model path (sync only for now) model_info = {"param_name": param_name, "model_class": annotation} - # For async handlers, model parsing needs the enhanced path - # since Zig-side model parsing only supports sync handlers if is_async: needs_body = True - continue # Don't add to param_types + continue except TypeError: pass @@ -93,14 +115,19 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: param_types[param_name] = "bool" elif annotation is str or annotation is inspect.Parameter.empty: param_types[param_name] = "str" - # Optional str params (with defaults) may be implicit header params. - # The Zig vectorcall path only extracts path/query params, so route - # these handlers through the enhanced path which also checks headers. if param.default is not inspect.Parameter.empty: has_implicit_header_params = True method = route.method.value.upper() if hasattr(route, "method") else "GET" + # Form/file handlers use dedicated dispatch (Zig parses multipart, fast resolve) + if has_file and not is_async and not has_depends: + if method in ("POST", "PUT", "PATCH", "DELETE"): + return "file_sync", param_types, {} + if has_form and not has_file and not is_async and not has_depends: + if method in ("POST", "PUT", "PATCH", "DELETE"): + return "form_sync", param_types, {} + # Model handlers use fast model path (simd-json + model_validate) - sync only if model_info and not is_async: if method in ("POST", "PUT", "PATCH", "DELETE"): @@ -131,6 +158,7 @@ def classify_handler(handler, route) -> tuple[str, dict[str, str], dict]: return "simple_sync_noargs", param_types, {} return "simple_sync", param_types, {} + def _extract_model_schema(model_class) -> str | None: """Extract a JSON schema descriptor from a dhi BaseModel class for Zig-native validation. @@ -463,7 +491,15 @@ def decorator(func): return decorator - def db_query(self, method: str, path: str, *, sql: str, params: list[str] | None = None, single: bool = False): + def db_query( + self, + method: str, + path: str, + *, + sql: str, + params: list[str] | None = None, + single: bool = False, + ): """Zig-native custom SQL query. Supports pgvector, JSONB, full-text search, CTEs. Args: @@ -520,7 +556,10 @@ def _initialize_zig_server(self, host: str = "127.0.0.1", port: int = 8000): # Use Zig-native CORS — pre-rendered headers, zero per-request overhead. # Routes stay on the fast path (no downgrade to enhanced). origins = kwargs.get("allow_origins", ["*"]) - methods_list = kwargs.get("allow_methods", ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "HEAD"]) + methods_list = kwargs.get( + "allow_methods", + ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "HEAD"], + ) hdrs_list = kwargs.get("allow_headers", ["*"]) max_age = kwargs.get("max_age", 600) creds = kwargs.get("allow_credentials", False) @@ -560,7 +599,10 @@ def _initialize_zig_server(self, host: str = "127.0.0.1", port: int = 8000): # Skip CORSMiddleware if handled natively by Zig self._middleware_instances = [] for middleware_class, kwargs in self.middleware_stack: - if getattr(self, "_zig_cors_enabled", False) and middleware_class.__name__ == "CORSMiddleware": + if ( + getattr(self, "_zig_cors_enabled", False) + and middleware_class.__name__ == "CORSMiddleware" + ): continue # handled in Zig self._middleware_instances.append(middleware_class(**kwargs)) @@ -590,7 +632,9 @@ def _initialize_zig_server(self, host: str = "127.0.0.1", port: int = 8000): # Enable response caching for noargs handlers (auto-cache after first call) # Disable with TURBO_DISABLE_CACHE=1 (e.g. for TFB compliance) - if not os.environ.get("TURBO_DISABLE_CACHE") and hasattr(self.zig_server, "enable_response_cache"): + if not os.environ.get("TURBO_DISABLE_CACHE") and hasattr( + self.zig_server, "enable_response_cache" + ): self.zig_server.enable_response_cache() native_count = len(getattr(self, "_native_routes", [])) @@ -663,6 +707,7 @@ def _sanitize_header_component(value: str) -> str: raw_content = result.get("content", "") if isinstance(raw_content, (dict, list)): import json as _json + raw_content = _json.dumps(raw_content) response = Response( content=raw_content, @@ -680,9 +725,7 @@ def _sanitize_header_component(value: str) -> str: if response.headers: # Filter out headers that Zig already emits from its fixed set _ZIG_OWNED = frozenset({"content-length", "server", "date", "connection"}) - extra = { - k: v for k, v in response.headers.items() if k.lower() not in _ZIG_OWNED - } + extra = {k: v for k, v in response.headers.items() if k.lower() not in _ZIG_OWNED} if extra: # Inject extra headers via content_type — Zig emits # "Content-Type: " so we append "\r\nKey: Value" pairs. @@ -733,7 +776,9 @@ def _register_routes_with_zig(self): "{}", route.handler, ) - print(f"{CHECK_MARK} [model_sync+middleware→enhanced] {route.method.value} {route.path}") + print( + f"{CHECK_MARK} [model_sync+middleware→enhanced] {route.method.value} {route.path}" + ) else: # Minimal handler: json.loads → Model(**data) → handler(model) → json.dumps enhanced_handler = create_fast_model_handler( @@ -754,7 +799,9 @@ def _register_routes_with_zig(self): route.handler, schema_json, ) - print(f"{CHECK_MARK} [model_sync+dhi] {route.method.value} {route.path}") + print( + f"{CHECK_MARK} [model_sync+dhi] {route.method.value} {route.path}" + ) else: self.zig_server.add_route_model( route.method.value, @@ -823,6 +870,25 @@ def _register_routes_with_zig(self): route.handler, # Original async handler ) print(f"{CHECK_MARK} [{handler_type}] {route.method.value} {route.path}") + elif handler_type in ("form_sync", "file_sync"): + # FORM/FILE PATH: Enhanced handler with dedicated Zig dispatch + # Zig skips DHI validation and parses multipart/urlencoded natively + enhanced_handler = create_enhanced_handler(route.handler, route) + if self._middleware_instances: + enhanced_handler = self._wrap_with_middleware(enhanced_handler) + registered_type = "enhanced" + else: + registered_type = handler_type + param_types_json = json.dumps(param_types) + self.zig_server.add_route_fast( + route.method.value, + route.path, + enhanced_handler, + registered_type, + param_types_json, + route.handler, + ) + print(f"{CHECK_MARK} [{registered_type}] {route.method.value} {route.path}") else: # ENHANCED PATH: Full Python wrapper needed enhanced_handler = create_enhanced_handler(route.handler, route) @@ -903,7 +969,9 @@ def run(self, host: str = "127.0.0.1", port: int = 8000, **kwargs): # Initialize Zig server if not self._initialize_zig_server(host, port): print(f"{CROSS_MARK} Failed to initialize Zig server") - print(f" Use an ASGI server as fallback: uvicorn main:app --host {host} --port {port}") + print( + f" Use an ASGI server as fallback: uvicorn main:app --host {host} --port {port}" + ) return # Print integration info diff --git a/zig/src/server.zig b/zig/src/server.zig index 1fb52d6..706fc39 100644 --- a/zig/src/server.zig +++ b/zig/src/server.zig @@ -114,6 +114,8 @@ const HandlerType = enum(u8) { simple_sync, model_sync, body_sync, + form_sync, + file_sync, enhanced, }; @@ -122,6 +124,8 @@ fn parseHandlerType(s: []const u8) HandlerType { if (std.mem.eql(u8, s, "simple_sync")) return .simple_sync; if (std.mem.eql(u8, s, "model_sync")) return .model_sync; if (std.mem.eql(u8, s, "body_sync")) return .body_sync; + if (std.mem.eql(u8, s, "form_sync")) return .form_sync; + if (std.mem.eql(u8, s, "file_sync")) return .file_sync; return .enhanced; } @@ -1107,10 +1111,11 @@ fn handleOneRequest(stream: std.net.Stream, tstate: ?*anyopaque) !void { } // DHI validation for model_sync — single parse, retain tree + // Skip for form/file routes (body is multipart/urlencoded, not JSON) var cached_parse: ?std.json.Parsed(std.json.Value) = null; defer if (cached_parse) |*cp| cp.deinit(); - if (body.len > 0) { + if (body.len > 0 and entry.handler_tag != .form_sync and entry.handler_tag != .file_sync) { const ms = getModelSchemas(); if (ms.get(match.handler_key)) |schema| { const vr = dhi.validateJsonRetainParsed(body, &schema); @@ -1145,6 +1150,11 @@ fn handleOneRequest(stream: std.net.Stream, tstate: ?*anyopaque) !void { .body_sync => { callPythonHandlerDirect(tstate, entry, query_string, body, &match.params, stream); }, + .form_sync, .file_sync => { + const resp = callPythonHandler(tstate, entry, method, path, query_string, body, headers.items, &match.params); + defer resp.deinit(); + sendResponse(stream, resp.status_code, resp.content_type, resp.body); + }, .enhanced => { const resp = callPythonHandler(tstate, entry, method, path, query_string, body, headers.items, &match.params); defer resp.deinit(); From 560155f50a7f332002a74680c424f915e874b834 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Sat, 11 Apr 2026 23:36:00 +0800 Subject: [PATCH 5/5] fix: add missing blank line in multipart test bodies + free urlencoded field strings Multipart format requires \r\n\r\n between headers and body. Tests were missing the extra \r\n, so parseMultipart found 0 fields. Also fix memory leak: UrlencodedResult.deinit now frees individual field name/value strings allocated by percentDecodeAlloc. Co-authored-by: trilokagent <275208033+trilokagent@users.noreply.github.com> --- zig/src/multipart.zig | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/zig/src/multipart.zig b/zig/src/multipart.zig index e7af22c..9d2f048 100644 --- a/zig/src/multipart.zig +++ b/zig/src/multipart.zig @@ -26,6 +26,10 @@ pub const UrlencodedResult = struct { fields: []FormField, pub fn deinit(self: *const UrlencodedResult, alloc: std.mem.Allocator) void { + for (self.fields) |f| { + alloc.free(f.name); + alloc.free(f.value); + } alloc.free(self.fields); } }; @@ -252,7 +256,7 @@ test "multipart: simple form field" { const body = "--boundary\r\n" ++ \\Content-Disposition: form-data; name="username" ++ - "\r\n" ++ + "\r\n\r\n" ++ \\john ++ "\r\n--boundary--\r\n"; @@ -271,7 +275,7 @@ test "multipart: file upload" { "\r\n" ++ \\Content-Type: text/plain ++ - "\r\n" ++ + "\r\n\r\n" ++ \\hello world ++ "\r\n--boundary--\r\n"; @@ -289,7 +293,7 @@ test "multipart: mixed fields and files" { const body = "--boundary\r\n" ++ \\Content-Disposition: form-data; name="title" ++ - "\r\n" ++ + "\r\n\r\n" ++ \\My Doc ++ "\r\n--boundary\r\n" ++ @@ -298,7 +302,7 @@ test "multipart: mixed fields and files" { "\r\n" ++ \\Content-Type: application/pdf ++ - "\r\n" ++ + "\r\n\r\n" ++ \\%PDF-1.4 ++ "\r\n--boundary--\r\n";