Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

## 2024-05-19 - Optimization: Dominant packet path extraction
**Learning:** Zig switch statements over integer values compile down to jump tables or sequential branches depending on the optimizer. In hot loops like `PacketType.classify(pkt)` checking, an enum `switch` can cause a pipeline stall on jump evaluation.
**Action:** Extracting the dominant data-plane case (e.g. `if (pkt_type == .wg_transport)`) explicitly before the `switch` statement forces the compiler to emit a direct branch instruction, which the CPU branch predictor handles much more efficiently, avoiding jump table overhead on the hot path. Remember to include the unreachable case within the switch so the code compiles.
Comment on lines +2 to +4
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This note mixes “switch statements over integer values” with “an enum switch”, which is a bit inconsistent/confusing relative to the actual implementation (the hot switch in PacketType.classify is over an integer msg_type, and the event loops switch over the PacketType enum). Consider rewording to consistently describe which switch is being optimized (integer msg_type vs PacketType enum) so future readers don’t misapply the guidance.

Copilot uses AI. Check for mistakes.
102 changes: 54 additions & 48 deletions src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2533,7 +2533,25 @@ fn processIncomingPacket(
) void {
const Device = lib.wireguard.Device;

switch (Device.PacketType.classify(pkt)) {
const pkt_type = Device.PacketType.classify(pkt);
// Optimization: Extract dominant data-plane case to explicit if branch
if (pkt_type == .wg_transport) {
if (n_decrypted.* < 64) {
if (wg_dev.decryptTransport(pkt, &decrypt_storage[n_decrypted.*])) |result| {
// Check service filter before buffering
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_storage[n_decrypted.*][0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) return;
}
}
decrypt_lens[n_decrypted.*] = result.len;
decrypt_slots[n_decrypted.*] = result.slot;
n_decrypted.* += 1;
} else |_| {}
}
} else switch (pkt_type) {
.wg_handshake_init => {
if (pkt.len >= @sizeOf(lib.wireguard.Noise.HandshakeInitiation)) {
const msg: *const lib.wireguard.Noise.HandshakeInitiation = @ptrCast(@alignCast(pkt.ptr));
Expand All @@ -2556,23 +2574,7 @@ fn processIncomingPacket(
} else |_| {}
}
},
.wg_transport => {
if (n_decrypted.* < 64) {
if (wg_dev.decryptTransport(pkt, &decrypt_storage[n_decrypted.*])) |result| {
// Check service filter before buffering
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_storage[n_decrypted.*][0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) return;
}
}
decrypt_lens[n_decrypted.*] = result.len;
decrypt_slots[n_decrypted.*] = result.slot;
n_decrypted.* += 1;
} else |_| {}
}
},
.wg_transport => unreachable,
.wg_cookie => {},
.stun => swim.feedPacket(pkt, sender_addr, sender_port),
.swim => swim.feedPacket(pkt, sender_addr, sender_port),
Expand Down Expand Up @@ -2615,7 +2617,22 @@ fn windowsEventLoop(
const recv = (udp_sock.recvFrom(&udp_recv_buf) catch break) orelse break;
const pkt = recv.data;

switch (Device.PacketType.classify(pkt)) {
const pkt_type = Device.PacketType.classify(pkt);
// Optimization: Extract dominant data-plane case to explicit if branch
if (pkt_type == .wg_transport) {
// Decrypt WG transport → write plaintext to Wintun
if (wg_dev.decryptTransport(pkt, &decrypt_buf)) |result| {
// Apply service filter before writing to TUN
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_buf[0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) continue;
}
}
tun_dev.write(decrypt_buf[0..result.len]) catch {};
} else |_| {}
} else switch (pkt_type) {
.wg_handshake_init => {
if (pkt.len >= @sizeOf(Noise.HandshakeInitiation)) {
const msg: *const Noise.HandshakeInitiation = @ptrCast(@alignCast(pkt.ptr));
Expand All @@ -2638,20 +2655,7 @@ fn windowsEventLoop(
} else |_| {}
}
},
.wg_transport => {
// Decrypt WG transport → write plaintext to Wintun
if (wg_dev.decryptTransport(pkt, &decrypt_buf)) |result| {
// Apply service filter before writing to TUN
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_buf[0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) continue;
}
}
tun_dev.write(decrypt_buf[0..result.len]) catch {};
} else |_| {}
},
.wg_transport => unreachable,
.wg_cookie => {},
// SWIM and STUN packets: feed to SWIM via feedPacket (non-blocking)
.stun => swim.feedPacket(pkt, recv.sender_addr, recv.sender_port),
Expand Down Expand Up @@ -2750,7 +2754,22 @@ fn macosEventLoop(
const recv = (udp_sock.recvFrom(&udp_recv_buf) catch break) orelse break;
const pkt = recv.data;

switch (Device.PacketType.classify(pkt)) {
const pkt_type = Device.PacketType.classify(pkt);
// Optimization: Extract dominant data-plane case to explicit if branch
if (pkt_type == .wg_transport) {
// Decrypt WG transport → write plaintext to utun
if (wg_dev.decryptTransport(pkt, &decrypt_buf)) |result| {
// Apply service filter before writing to TUN
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_buf[0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) continue;
}
}
tun_dev.write(decrypt_buf[0..result.len]) catch {};
} else |_| {}
} else switch (pkt_type) {
.wg_handshake_init => {
if (pkt.len >= @sizeOf(Noise.HandshakeInitiation)) {
const msg: *const Noise.HandshakeInitiation = @ptrCast(@alignCast(pkt.ptr));
Expand All @@ -2773,20 +2792,7 @@ fn macosEventLoop(
} else |_| {}
}
},
.wg_transport => {
// Decrypt WG transport → write plaintext to utun
if (wg_dev.decryptTransport(pkt, &decrypt_buf)) |result| {
// Apply service filter before writing to TUN
const PolicyMod = lib.services.Policy;
if (PolicyMod.parseTransportHeader(decrypt_buf[0..result.len])) |ti| {
if (wg_dev.peers[result.slot]) |peer| {
const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null;
if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) continue;
}
}
tun_dev.write(decrypt_buf[0..result.len]) catch {};
} else |_| {}
},
.wg_transport => unreachable,
.wg_cookie => {},
.stun => swim.feedPacket(pkt, recv.sender_addr, recv.sender_port),
.swim => swim.feedPacket(pkt, recv.sender_addr, recv.sender_port),
Expand Down
8 changes: 6 additions & 2 deletions src/wireguard/device.zig
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,20 @@ pub const PacketType = enum {
stun, // STUN binding response
unknown,

pub fn classify(data: []const u8) PacketType {
pub inline fn classify(data: []const u8) PacketType {
if (data.len < 4) return .unknown;

// WireGuard messages: first byte is type, next 3 are zeros
const msg_type = std.mem.readInt(u32, data[0..4], .little);

// Optimization: Extract dominant data-plane case to explicit if branch to avoid jump table
if (msg_type == 4) return .wg_transport;
Comment on lines +27 to +34
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

classify is marked inline, which in Zig is a forced inline and (per repo-wide search) this is the only inline fn usage under src/. Forced inlining can increase code size/compile time and makes this public API harder to use in contexts where a function pointer might be needed. Unless there’s a measured win specifically from forced inlining, consider reverting to pub fn classify and relying on the optimizer (or keep it non-inline and benchmark both).

Copilot uses AI. Check for mistakes.

return switch (msg_type) {
1 => .wg_handshake_init,
2 => .wg_handshake_resp,
3 => .wg_cookie,
4 => .wg_transport,
4 => unreachable,
else => blk: {
// STUN: check for magic cookie at bytes 4-7
if (data.len >= 8) {
Expand Down