-
Notifications
You must be signed in to change notification settings - Fork 0
⚡ Bolt: Optimize packet classification branch prediction #64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| ## 2024-05-24 - Branch Misprediction Optimization for Packet Classification | ||
| **Learning:** In Zig, standard `switch` statements on integers compile to jump tables. Extracting the dominant case (like data-plane packets `wg_transport` which is 99.9% of traffic) into an explicit `if` branch before a `switch` improves branch prediction and avoids jump table overhead. | ||
| **Action:** When classifying packets on the hot path, explicitly use an `if` statement for the most common case, falling back to a `switch` for the less common ones. Also, use the `inline` keyword on small, frequently called classification functions to eliminate function call overhead. | ||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2486,7 +2486,29 @@ fn processIncomingPacket( | |||||||||||
| ) void { | ||||||||||||
| const Device = lib.wireguard.Device; | ||||||||||||
|
|
||||||||||||
| switch (Device.PacketType.classify(pkt)) { | ||||||||||||
| const pkt_type = Device.PacketType.classify(pkt); | ||||||||||||
|
|
||||||||||||
| // Optimization: Fast path for data plane transport to avoid jump table overhead | ||||||||||||
| if (pkt_type == .wg_transport) { | ||||||||||||
| if (n_decrypted.* < 64) { | ||||||||||||
| if (wg_dev.decryptTransport(pkt, &decrypt_storage[n_decrypted.*])) |result| { | ||||||||||||
| // Check service filter before buffering | ||||||||||||
| const PolicyMod = lib.services.Policy; | ||||||||||||
| if (PolicyMod.parseTransportHeader(decrypt_storage[n_decrypted.*][0..result.len])) |ti| { | ||||||||||||
| if (wg_dev.peers[result.slot]) |peer| { | ||||||||||||
| const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null; | ||||||||||||
| if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) return; | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| decrypt_lens[n_decrypted.*] = result.len; | ||||||||||||
| decrypt_slots[n_decrypted.*] = result.slot; | ||||||||||||
| n_decrypted.* += 1; | ||||||||||||
| } else |_| {} | ||||||||||||
| } | ||||||||||||
| return; | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+2491
to
+2509
|
||||||||||||
|
|
||||||||||||
| switch (pkt_type) { | ||||||||||||
| .wg_handshake_init => { | ||||||||||||
| if (pkt.len >= @sizeOf(lib.wireguard.Noise.HandshakeInitiation)) { | ||||||||||||
| const msg: *const lib.wireguard.Noise.HandshakeInitiation = @ptrCast(@alignCast(pkt.ptr)); | ||||||||||||
|
|
@@ -2509,27 +2531,10 @@ fn processIncomingPacket( | |||||||||||
| } else |_| {} | ||||||||||||
| } | ||||||||||||
| }, | ||||||||||||
| .wg_transport => { | ||||||||||||
| if (n_decrypted.* < 64) { | ||||||||||||
| if (wg_dev.decryptTransport(pkt, &decrypt_storage[n_decrypted.*])) |result| { | ||||||||||||
| // Check service filter before buffering | ||||||||||||
| const PolicyMod = lib.services.Policy; | ||||||||||||
| if (PolicyMod.parseTransportHeader(decrypt_storage[n_decrypted.*][0..result.len])) |ti| { | ||||||||||||
| if (wg_dev.peers[result.slot]) |peer| { | ||||||||||||
| const org_pk = if (swim.membership.peers.getPtr(peer.identity_key)) |mp| mp.org_pubkey else null; | ||||||||||||
| if (!service_filter.check(peer.identity_key, org_pk, ti.proto, ti.dst_port)) return; | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| decrypt_lens[n_decrypted.*] = result.len; | ||||||||||||
| decrypt_slots[n_decrypted.*] = result.slot; | ||||||||||||
| n_decrypted.* += 1; | ||||||||||||
| } else |_| {} | ||||||||||||
| } | ||||||||||||
| }, | ||||||||||||
| .wg_cookie => {}, | ||||||||||||
| .stun => swim.feedPacket(pkt, sender_addr, sender_port), | ||||||||||||
| .swim => swim.feedPacket(pkt, sender_addr, sender_port), | ||||||||||||
| .unknown => {}, | ||||||||||||
| .wg_transport, .unknown => {}, | ||||||||||||
|
||||||||||||
| .wg_transport, .unknown => {}, | |
| .wg_transport => { | |
| // handled above in fast path | |
| }, | |
| .unknown => {}, |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,16 +24,20 @@ pub const PacketType = enum { | |||||||||||||||||||||||||||||||||||||||||||||||||||
| stun, // STUN binding response | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| unknown, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| pub fn classify(data: []const u8) PacketType { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Optimization: Inline and extract dominant case to avoid jump table overhead | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| pub inline fn classify(data: []const u8) PacketType { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (data.len < 4) return .unknown; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| // WireGuard messages: first byte is type, next 3 are zeros | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| const msg_type = std.mem.readInt(u32, data[0..4], .little); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Fast path for data-plane transport packets (99.9% of traffic) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (msg_type == 4) return .wg_transport; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| return switch (msg_type) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1 => .wg_handshake_init, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2 => .wg_handshake_resp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| 3 => .wg_cookie, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+27
to
40
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Optimization: Inline and extract dominant case to avoid jump table overhead | |
| pub inline fn classify(data: []const u8) PacketType { | |
| if (data.len < 4) return .unknown; | |
| // WireGuard messages: first byte is type, next 3 are zeros | |
| const msg_type = std.mem.readInt(u32, data[0..4], .little); | |
| // Fast path for data-plane transport packets (99.9% of traffic) | |
| if (msg_type == 4) return .wg_transport; | |
| return switch (msg_type) { | |
| 1 => .wg_handshake_init, | |
| 2 => .wg_handshake_resp, | |
| 3 => .wg_cookie, | |
| pub fn classify(data: []const u8) PacketType { | |
| if (data.len < 4) return .unknown; | |
| // WireGuard messages: first byte is type, next 3 are zeros | |
| const msg_type = std.mem.readInt(u32, data[0..4], .little); | |
| return switch (msg_type) { | |
| 1 => .wg_handshake_init, | |
| 2 => .wg_handshake_resp, | |
| 3 => .wg_cookie, | |
| 4 => .wg_transport, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The claim that "In Zig, standard
switchstatements on integers compile to jump tables" is an oversimplification. Zig uses LLVM as its backend in release modes, and LLVM'sSimplifyCFGpass decides whether to use a jump table, branch chain, or lookup table based on case density and count. For a small, dense switch on values 1–4, LLVM will typically not generate a jump table. This documentation could mislead future developers into unnecessary manual optimizations. Consider qualifying the statement (e.g., "for larger or sparser switches") or removing the unverified claim.