From 374cbffd7789589b15ea0cd6fff33b2318386b35 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:17:01 -0700 Subject: [PATCH] [NetKAT] Intern decision nodes via per-field unique tables keyed by node index. Previously, the unique tables (packet_by_node_, transformer_by_node_) stored a full copy of every decision node as the map key, duplicating all node storage. This was especially costly for PacketTransformerManager, whose nodes own heap-allocated btree maps. Now the tables store 4-byte node indices and hash/compare by dereferencing into nodes_, so each node is stored exactly once. Splitting the tables by packet field keeps each table - and thus its probe sequences - small, and lays the groundwork for storing the nodes themselves by field/level (standard practice in BDD packages, where it improves locality and enables variable reordering). A follow-up change builds on this. Implementation notes: * Interning probes the table with the candidate node via heterogeneous lookup (no node is added to nodes_ unless it is new), keeping the common already-interned case copy- and allocation-free. * The table functors reference nodes_ through a heap-allocated location slot so that they survive manager moves; move operations repoint the slot. Benchmarks are flat overall: first-time compilation is unchanged, recompile microbenchmarks regress by ~8% (~50ns) due to the indirection in the unique tables, in exchange for halving node-related memory. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 2 + netkat/packet_field.h | 5 +++ netkat/packet_set.cc | 83 +++++++++++++++++++++++++++------- netkat/packet_set.h | 68 +++++++++++++++++++++++++--- netkat/packet_transformer.cc | 86 ++++++++++++++++++++++++++++-------- netkat/packet_transformer.h | 69 ++++++++++++++++++++++++++--- 6 files changed, 265 insertions(+), 48 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..33d04a0 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -132,6 +132,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -376,6 +377,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/netkat/packet_field.h b/netkat/packet_field.h index 35400cf..2bccbbc 100644 --- a/netkat/packet_field.h +++ b/netkat/packet_field.h @@ -78,6 +78,11 @@ class [[nodiscard]] PacketFieldHandle { // 2^16 ~= 65k fields. uint16_t index_; explicit PacketFieldHandle(uint16_t index) : index_(index) {} + + // `PacketSetManager` and `PacketTransformerManager` organize their node + // storage by field, using `index_` to address the per-field data structures. + friend class PacketSetManager; + friend class PacketTransformerManager; }; // Protect against regressions in the memory layout, as it affects performance. diff --git a/netkat/packet_set.cc b/netkat/packet_set.cc index 2c7bd1c..0b14380 100644 --- a/netkat/packet_set.cc +++ b/netkat/packet_set.cc @@ -63,6 +63,25 @@ std::string PacketSetHandle::ToString() const { } } +PacketSetManager::PacketSetManager(PacketSetManager&& other) + : nodes_(std::move(other.nodes_)), + nodes_location_(std::move(other.nodes_location_)), + unique_table_by_field_(std::move(other.unique_table_by_field_)), + packet_set_by_hash_(std::move(other.packet_set_by_hash_)), + field_manager_(std::move(other.field_manager_)) { + *nodes_location_ = &nodes_; +} + +PacketSetManager& PacketSetManager::operator=(PacketSetManager&& other) { + nodes_ = std::move(other.nodes_); + nodes_location_ = std::move(other.nodes_location_); + unique_table_by_field_ = std::move(other.unique_table_by_field_); + packet_set_by_hash_ = std::move(other.packet_set_by_hash_); + field_manager_ = std::move(other.field_manager_); + *nodes_location_ = &nodes_; + return *this; +} + PacketSetHandle PacketSetManager::EmptySet() const { return PacketSetHandle(SentinelNodeIndex::kEmptySet); } @@ -109,16 +128,33 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { } #endif - auto [it, inserted] = - packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Probe the unique table by node content first (heterogeneous lookup): the + // common case is that an equal node has already been interned, and probing + // with the candidate node avoids touching `nodes_` in that case. + UniqueNodeTable& unique_table = GetOrCreateUniqueTable(node.field); + if (auto it = unique_table.find(node); it != unique_table.end()) { + return PacketSetHandle(*it); } - return it->second; + uint32_t node_index = nodes_.size(); + nodes_.push_back(std::move(node)); + unique_table.insert(node_index); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return PacketSetHandle(node_index); +} + +PacketSetManager::UniqueNodeTable& PacketSetManager::GetOrCreateUniqueTable( + PacketFieldHandle field) { + if (field.index_ >= unique_table_by_field_.size()) { + unique_table_by_field_.resize( + field.index_ + 1, + UniqueNodeTable(/*bucket_count=*/0, + InternedNodeHash{nodes_location_.get()}, + InternedNodeEq{nodes_location_.get()})); + } + return unique_table_by_field_[field.index_]; } bool PacketSetManager::Contains(PacketSetHandle packet_set, @@ -483,16 +519,29 @@ absl::Status PacketSetManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - for (const auto& [node, packet] : packet_by_node_) { - RET_CHECK(packet.node_index_ < nodes_.size()); - RET_CHECK(nodes_[packet.node_index_] == node); + // Invariant: `unique_table_by_field_[f]` contains `i` iff + // `nodes_[i].field.index_ == f`. Every valid node index is in exactly one + // table, and no two interned nodes are equal. + size_t total_table_size = 0; + for (size_t f = 0; f < unique_table_by_field_.size(); ++f) { + const UniqueNodeTable& unique_table = unique_table_by_field_[f]; + total_table_size += unique_table.size(); + for (uint32_t node_index : unique_table) { + RET_CHECK(node_index < nodes_.size()); + RET_CHECK(nodes_[node_index].field.index_ == f); + } } - for (int i = 0; i < nodes_.size(); ++i) { + RET_CHECK(total_table_size == nodes_.size()); + for (uint32_t i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = packet_by_node_.find(node); - RET_CHECK(it != packet_by_node_.end()); - RET_CHECK(it->second == PacketSetHandle(i)); + RET_CHECK(node.field.index_ < unique_table_by_field_.size()); + const UniqueNodeTable& unique_table = + unique_table_by_field_[node.field.index_]; + // Looking up `i` probes by node content; finding exactly `i` proves that + // no other interned node has the same content. + auto it = unique_table.find(i); + RET_CHECK(it != unique_table.end()); + RET_CHECK(*it == i); } // Node Invariants. diff --git a/netkat/packet_set.h b/netkat/packet_set.h index f11088f..c9237bb 100644 --- a/netkat/packet_set.h +++ b/netkat/packet_set.h @@ -46,12 +46,15 @@ #include #include +#include #include #include #include #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -139,10 +142,13 @@ class PacketSetManager { PacketSetManager() = default; // The class is move-only: not copyable, but movable. + // Moves are implemented manually (in the cc file) because the unique tables + // reference `nodes_` through `nodes_location_`, which must be repointed at + // the new `nodes_` member on move. PacketSetManager(const PacketSetManager&) = delete; PacketSetManager& operator=(const PacketSetManager&) = delete; - PacketSetManager(PacketSetManager&&) = default; - PacketSetManager& operator=(PacketSetManager&&) = default; + PacketSetManager(PacketSetManager&&); + PacketSetManager& operator=(PacketSetManager&&); // Returns true iff this packet set represents the empty set of packets. bool IsEmptySet(PacketSetHandle packet_set) const; @@ -356,11 +362,61 @@ class PacketSetManager { // `And`, `Or`, `Not`). The class also avoids expensive relocations. PagedStableVector nodes_; - // A so called "unique table" to ensure each node is only added to `nodes_` - // once, and thus has a unique `PacketSetHandle::node_index`. + // The location of `nodes_`, behind a level of indirection that remains + // stable when the manager object is moved: the unique tables below hash and + // compare node indices by dereferencing into `nodes_`, and their + // hasher/equality functors would otherwise dangle on move. Move operations + // repoint the location at the new manager's `nodes_` member. + std::unique_ptr*> + nodes_location_ = std::make_unique< + const PagedStableVector*>(&nodes_); + + // Hasher and equality for unique table entries, which are indices into + // `nodes_`. Hashing/comparing the *node* (rather than the index) is what + // makes the tables deduplicate by node content. + // The `DecisionNode` overloads enable heterogeneous lookup, so that a + // candidate node can be probed before it is added to `nodes_`. + struct InternedNodeHash { + using is_transparent = void; + const PagedStableVector* const* nodes; + size_t operator()(uint32_t node_index) const { + return absl::HashOf((**nodes)[node_index]); + } + size_t operator()(const DecisionNode& node) const { + return absl::HashOf(node); + } + }; + struct InternedNodeEq { + using is_transparent = void; + const PagedStableVector* const* nodes; + bool operator()(uint32_t a, uint32_t b) const { + return (**nodes)[a] == (**nodes)[b]; + } + bool operator()(uint32_t a, const DecisionNode& b) const { + return (**nodes)[a] == b; + } + bool operator()(const DecisionNode& a, uint32_t b) const { + return (**nodes)[b] == a; + } + }; + using UniqueNodeTable = + absl::flat_hash_set; + + // So called "unique tables" to ensure each node is only added to `nodes_` + // once, and thus has a unique `PacketSetHandle::node_index`. One table per + // packet field: a node's entry lives in the table of the field it branches + // on. Splitting by field keeps the tables, and thus probe sequences, small, + // and storing indices (instead of node copies) keeps each node stored + // exactly once, in `nodes_`. // - // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map packet_by_node_; + // INVARIANT: `unique_table_by_field_[f]` contains `i` iff + // `nodes_[i].field.index_ == f`. Every valid node index is contained in + // exactly one table. + std::vector unique_table_by_field_; + + // Returns the unique table for nodes branching on `field`, creating it (and + // any tables for smaller fields) if it does not exist yet. + UniqueNodeTable& GetOrCreateUniqueTable(PacketFieldHandle field); // A map of a given `PredicateProto` to a `PacketSetHandle`. // diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..9a9915e 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -61,6 +61,27 @@ enum SentinelNodeIndex : uint32_t { PacketTransformerHandle::PacketTransformerHandle() : node_index_(SentinelNodeIndex::kDeny) {} +PacketTransformerManager::PacketTransformerManager( + PacketTransformerManager&& other) + : nodes_(std::move(other.nodes_)), + nodes_location_(std::move(other.nodes_location_)), + unique_table_by_field_(std::move(other.unique_table_by_field_)), + transformer_by_hash_(std::move(other.transformer_by_hash_)), + packet_set_manager_(std::move(other.packet_set_manager_)) { + *nodes_location_ = &nodes_; +} + +PacketTransformerManager& PacketTransformerManager::operator=( + PacketTransformerManager&& other) { + nodes_ = std::move(other.nodes_); + nodes_location_ = std::move(other.nodes_location_); + unique_table_by_field_ = std::move(other.unique_table_by_field_); + transformer_by_hash_ = std::move(other.transformer_by_hash_); + packet_set_manager_ = std::move(other.packet_set_manager_); + *nodes_location_ = &nodes_; + return *this; +} + std::string PacketTransformerHandle::ToString() const { if (node_index_ == SentinelNodeIndex::kDeny) { return "PacketTransformerHandle"; @@ -172,16 +193,33 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node.default_branch_by_field_modification.empty()) return node.default_branch; - auto [it, inserted] = transformer_by_node_.try_emplace( - node, PacketTransformerHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Probe the unique table by node content first (heterogeneous lookup): the + // common case is that an equal node has already been interned, and probing + // with the candidate node avoids touching `nodes_` in that case. + UniqueNodeTable& unique_table = GetOrCreateUniqueTable(node.field); + if (auto it = unique_table.find(node); it != unique_table.end()) { + return PacketTransformerHandle(*it); } - return it->second; + uint32_t node_index = nodes_.size(); + nodes_.push_back(std::move(node)); + unique_table.insert(node_index); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return PacketTransformerHandle(node_index); +} + +PacketTransformerManager::UniqueNodeTable& +PacketTransformerManager::GetOrCreateUniqueTable(PacketFieldHandle field) { + if (field.index_ >= unique_table_by_field_.size()) { + unique_table_by_field_.resize( + field.index_ + 1, + UniqueNodeTable(/*bucket_count=*/0, + InternedNodeHash{nodes_location_.get()}, + InternedNodeEq{nodes_location_.get()})); + } + return unique_table_by_field_[field.index_]; } bool PacketTransformerManager::IsDeny( @@ -1104,17 +1142,29 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == - // n`. - for (const auto& [node, transformer] : transformer_by_node_) { - RET_CHECK(transformer.node_index_ < nodes_.size()); - RET_CHECK(nodes_[transformer.node_index_] == node); + // Invariant: `unique_table_by_field_[f]` contains `i` iff + // `nodes_[i].field.index_ == f`. Every valid node index is in exactly one + // table, and no two interned nodes are equal. + size_t total_table_size = 0; + for (size_t f = 0; f < unique_table_by_field_.size(); ++f) { + const UniqueNodeTable& unique_table = unique_table_by_field_[f]; + total_table_size += unique_table.size(); + for (uint32_t node_index : unique_table) { + RET_CHECK(node_index < nodes_.size()); + RET_CHECK(nodes_[node_index].field.index_ == f); + } } - for (int i = 0; i < nodes_.size(); ++i) { + RET_CHECK(total_table_size == nodes_.size()); + for (uint32_t i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = transformer_by_node_.find(node); - RET_CHECK(it != transformer_by_node_.end()); - RET_CHECK(it->second == PacketTransformerHandle(i)); + RET_CHECK(node.field.index_ < unique_table_by_field_.size()); + const UniqueNodeTable& unique_table = + unique_table_by_field_[node.field.index_]; + // Looking up `i` probes by node content; finding exactly `i` proves that + // no other interned node has the same content. + auto it = unique_table.find(i); + RET_CHECK(it != unique_table.end()); + RET_CHECK(*it == i); } // Node Invariants. diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..7a3d8be 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -39,12 +39,15 @@ #include #include +#include #include #include +#include #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -147,10 +150,13 @@ class PacketTransformerManager { // The class is move-only: not copyable, but movable. // `PacketSetHandles` and `PacketTransformerHandles` returned by this class // are not invalidated on move. + // Moves are implemented manually (in the cc file) because the unique tables + // reference `nodes_` through `nodes_location_`, which must be repointed at + // the new `nodes_` member on move. PacketTransformerManager(const PacketTransformerManager&) = delete; PacketTransformerManager& operator=(const PacketTransformerManager&) = delete; - PacketTransformerManager(PacketTransformerManager&&) = default; - PacketTransformerManager& operator=(PacketTransformerManager&&) = default; + PacketTransformerManager(PacketTransformerManager&&); + PacketTransformerManager& operator=(PacketTransformerManager&&); // Returns the `PacketSetManager` used by this object to compile // predicates. @@ -424,12 +430,61 @@ class PacketTransformerManager { // expensive relocations. PagedStableVector nodes_; - // A so called "unique table" to ensure each node is only added to `nodes_` - // once, and thus has a unique `PacketTransformerHandle::node_index`. + // The location of `nodes_`, behind a level of indirection that remains + // stable when the manager object is moved: the unique tables below hash and + // compare node indices by dereferencing into `nodes_`, and their + // hasher/equality functors would otherwise dangle on move. Move operations + // repoint the location at the new manager's `nodes_` member. + std::unique_ptr*> + nodes_location_ = std::make_unique< + const PagedStableVector*>(&nodes_); + + // Hasher and equality for unique table entries, which are indices into + // `nodes_`. Hashing/comparing the *node* (rather than the index) is what + // makes the tables deduplicate by node content. + // The `DecisionNode` overloads enable heterogeneous lookup, so that a + // candidate node can be probed before it is added to `nodes_`. + struct InternedNodeHash { + using is_transparent = void; + const PagedStableVector* const* nodes; + size_t operator()(uint32_t node_index) const { + return absl::HashOf((**nodes)[node_index]); + } + size_t operator()(const DecisionNode& node) const { + return absl::HashOf(node); + } + }; + struct InternedNodeEq { + using is_transparent = void; + const PagedStableVector* const* nodes; + bool operator()(uint32_t a, uint32_t b) const { + return (**nodes)[a] == (**nodes)[b]; + } + bool operator()(uint32_t a, const DecisionNode& b) const { + return (**nodes)[a] == b; + } + bool operator()(const DecisionNode& a, uint32_t b) const { + return (**nodes)[b] == a; + } + }; + using UniqueNodeTable = + absl::flat_hash_set; + + // So called "unique tables" to ensure each node is only added to `nodes_` + // once, and thus has a unique `PacketTransformerHandle::node_index`. One + // table per packet field: a node's entry lives in the table of the field it + // branches on. Splitting by field keeps the tables, and thus probe + // sequences, small, and storing indices (instead of node copies) keeps each + // node stored exactly once, in `nodes_`. // - // INVARIANT: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map - transformer_by_node_; + // INVARIANT: `unique_table_by_field_[f]` contains `i` iff + // `nodes_[i].field.index_ == f`. Every valid node index is contained in + // exactly one table. + std::vector unique_table_by_field_; + + // Returns the unique table for nodes branching on `field`, creating it (and + // any tables for smaller fields) if it does not exist yet. + UniqueNodeTable& GetOrCreateUniqueTable(PacketFieldHandle field); // A map of a given `PolicyProto` to a `PacketTransformerHandle`. //