From 1ce23ab0507ef02d0faf9744f42826cb840508db Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:13:50 -0700 Subject: [PATCH 1/4] [NetKAT] Eliminate node and map copies in packet transformer combinators. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Union, Sequence, and Difference copied entire DecisionNodes — including their nested btree maps — at every recursion step, re-interned expanded nodes through the unique table just to recover handles the caller already had, and copied the per-value modification map (GetMapAtValue) inside per-value loops. Several TODOs flagged this as a likely performance problem. Restructure the combinators around cheap, non-owning views instead: * DecisionNodeView views an operand "at" a field: either the node's own maps, or the trivial "fall through to this handle" expansion when the node branches on a larger field (or is Accept). This removes both the by-value node copies and the materialize-and-re-intern expansion step. * ModifyBranchesView replaces GetMapAtValue's map copies with a view of the base map plus at most one extra entry, iterated in sorted order. * CombineModifyBranches becomes a template over the combiner, removing AnyInvocable type-erasure and double lookups, and inserts with an end hint since keys arrive sorted. * The value-collection loops now iterate in sorted order (previously nondeterministic flat_hash_set order), making combination order — and thus node numbering — deterministic. Compilation benchmarks improve by roughly 1.5-2.2x (FirstTimeCompile*), with larger wins expected on policies with wider modification maps. Behavior is unchanged: all tests pass, and the golden-file diff test output is byte-identical, including node numbering. Co-Authored-By: Claude Fable 5 --- netkat/packet_transformer.cc | 654 +++++++++++++++++------------------ netkat/packet_transformer.h | 12 - 2 files changed, 322 insertions(+), 344 deletions(-) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..c2807da 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -14,8 +14,8 @@ #include "netkat/packet_transformer.h" +#include #include -#include #include #include #include @@ -28,7 +28,6 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,22 +77,6 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } -// TODO(dilo): Creating as many map copies as this method facilitates is -// probably going to cause terrible performance, and needs to be revisited. -absl::btree_map -PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); - - absl::btree_map result = - node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; - - // Otherwise, add a mapping from `value` to the default branch, then return. - result[value] = node.default_branch; - return result; -} - // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -376,249 +359,288 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, - const absl::btree_map& right, - absl::AnyInvocable - combiner, - PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); + +// Aliases for the map types of `PacketTransformerManager::DecisionNode`. +using ModifyBranchMap = absl::btree_map; +using MatchBranchMap = absl::btree_map; + +const MatchBranchMap& EmptyMatchBranchMap() { + static const MatchBranchMap* const kEmpty = new MatchBranchMap; + return *kEmpty; +} + +const ModifyBranchMap& EmptyModifyBranchMap() { + static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + return *kEmpty; +} + +// A cheap, non-owning view of a decision node as seen "at" a field no larger +// than the node's own field: either the node's own maps (if the node branches +// on that field), or the trivial expansion "for every value of the field, +// leave it unmodified and fall through to the node" (if the node branches on +// a strictly larger field, or is Accept). This lets the binary operations +// below combine operands with distinct fields without materializing — and +// then re-interning — the trivial expansion as a `DecisionNode`. +struct DecisionNodeView { + const MatchBranchMap* modify_branch_by_field_match; + const ModifyBranchMap* default_branch_by_field_modification; + PacketTransformerHandle default_branch; +}; + +// Returns the view of `node` (the decision node of `transformer`, or null if +// `transformer` is Accept) at `field`, which must be <= `node->field`. +// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { + if (node != nullptr && node->field == field) { + return DecisionNodeView{ + .modify_branch_by_field_match = &node->modify_branch_by_field_match, + .default_branch_by_field_modification = + &node->default_branch_by_field_modification, + .default_branch = node->default_branch, + }; + } + return DecisionNodeView{ + .modify_branch_by_field_match = &EmptyMatchBranchMap(), + .default_branch_by_field_modification = &EmptyModifyBranchMap(), + .default_branch = transformer, + }; +} + +// Returns the smallest field branched on by the given decision nodes, at +// least one of which must be non-null (null encodes Accept, which branches on +// no field). See `ViewAtField` regarding the `Node` template parameter. +template +PacketFieldHandle SmallestField(const Node* left, const Node* right) { + DCHECK(left != nullptr || right != nullptr); + if (left == nullptr) return right->field; + if (right == nullptr) return left->field; + return std::min(left->field, right->field); +} + +// A cheap, non-owning view of a logical (modify value -> branch) map, +// represented as a base map plus at most one extra entry whose key must not +// be a key of the base map. Iteration is in increasing key order, like the +// underlying btree map. +class ModifyBranchesView { + public: + using Entry = std::pair; + + explicit ModifyBranchesView(const ModifyBranchMap& base) : base_(&base) {} + ModifyBranchesView(const ModifyBranchMap& base, Entry extra) + : base_(&base), extra_(extra) { + DCHECK(!base.contains(extra.first)); + } + + class const_iterator { + public: + Entry operator*() const { + return ExtraIsNext() ? *extra_ : Entry(it_->first, it_->second); + } + const_iterator& operator++() { + if (ExtraIsNext()) { + extra_ = nullptr; + } else { + ++it_; + } + return *this; + } + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.it_ == b.it_ && a.extra_ == b.extra_; } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == b); + } + + private: + friend class ModifyBranchesView; + const_iterator(ModifyBranchMap::const_iterator it, + ModifyBranchMap::const_iterator end, const Entry* extra) + : it_(it), end_(end), extra_(extra) {} + + // The extra entry comes next iff it has not been consumed and its key + // precedes the next base entry's key. + bool ExtraIsNext() const { + return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); + } + + ModifyBranchMap::const_iterator it_, end_; + const Entry* extra_; // Null if absent or already consumed. + }; + + const_iterator begin() const { + return const_iterator(base_->begin(), base_->end(), + extra_.has_value() ? &*extra_ : nullptr); } - for (const auto& [value, branch] : right) { - if (!result.contains(value)) - result[value] = combiner(default_value, branch); + const_iterator end() const { + return const_iterator(base_->end(), base_->end(), nullptr); + } + + bool empty() const { return base_->empty() && !extra_.has_value(); } + bool contains(int value) const { + return (extra_.has_value() && extra_->first == value) || + base_->contains(value); + } + std::optional find(int value) const { + if (extra_.has_value() && extra_->first == value) return extra_->second; + if (auto it = base_->find(value); it != base_->end()) return it->second; + return std::nullopt; + } + + private: + const ModifyBranchMap* base_; + std::optional extra_; +}; + +// Returns a view of the logical (modify value -> branch) map that `node` +// applies to packets whose field is equal to `value`: the matching entry of +// `modify_branch_by_field_match` if there is one; otherwise the default +// modifications, plus the unmodified fall-through to `default_branch` (keyed +// by `value`, since the field keeps its value) unless that branch is Deny or +// shadowed by a default modification to `value`. +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, + int value) { + if (auto it = node.modify_branch_by_field_match->find(value); + it != node.modify_branch_by_field_match->end()) { + return ModifyBranchesView(it->second); + } + const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; + // `PacketTransformerHandle()` is the Deny transformer. + if (defaults.contains(value) || + node.default_branch == PacketTransformerHandle()) { + return ModifyBranchesView(defaults); + } + return ModifyBranchesView(defaults, {value, node.default_branch}); +} + +// Combines two (modify value -> branch) maps key-wise into a new map, using +// `combiner(left_branch, right_branch)` for shared keys and substituting +// `default_value` for the missing side otherwise. +template +ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, + const ModifyBranchesView& right, + Combiner&& combiner, + PacketTransformerHandle default_value) { + ModifyBranchMap result; + for (auto [value, left_branch] : left) { + // Keys arrive in increasing order, so inserting at `end()` is O(1). + result.try_emplace( + result.end(), value, + combiner(left_branch, right.find(value).value_or(default_value))); + } + for (auto [value, right_branch] : right) { + if (left.contains(value)) continue; + result.try_emplace(value, combiner(default_value, right_branch)); } return result; } +// Returns the union of the keys of the given maps, sorted and deduplicated. +template +std::vector SortedUniqueKeys(const Maps&... maps) { + std::vector keys; + keys.reserve((maps.size() + ... + 0)); + auto append_keys = [&keys](const auto& map) { + for (const auto& entry : map) keys.push_back(entry.first); + }; + (append_keys(maps), ...); + absl::c_sort(keys); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + return keys; +} + } // namespace -PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Sequence( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } +PacketTransformerHandle PacketTransformerManager::Sequence( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (IsDeny(left) || IsDeny(right)) return Deny(); + if (IsAccept(left)) return right; + if (IsAccept(right)) return left; - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Sequence(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Both operands are decision nodes. Combine them at the smaller of their + // fields; an operand branching on a strictly larger field is viewed as a + // trivial node at the combined field. + const DecisionNode* left_node = &GetNodeOrDie(left); + const DecisionNode* right_node = &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto sequence_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Sequence(left, right); + }; + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; + const ModifyBranchesView empty_view(EmptyModifyBranchMap()); - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, - .default_branch = Sequence(left.default_branch, right.default_branch), + .field = field, + .default_branch = + Sequence(left_view.default_branch, right_view.default_branch), }; // Construct the possible results of applying the right node to packets // gotten by taken default modification branches in the left node. - absl::btree_map - right_applied_to_left_modifications; + ModifyBranchMap right_applied_to_left_modifications; for (const auto& [value, branch] : - left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); + *left_view.default_branch_by_field_modification) { + ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, + /*default_value=*/branch); right_applied_to_left_modifications = CombineModifyBranches( - right_applied_to_left_modifications, right_at_value_with_sequence, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(right_at_value_with_sequence), union_combiner, /*default_value=*/Deny()); } + ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( + empty_view, + ModifyBranchesView(*right_view.default_branch_by_field_modification), + sequence_combiner, + /*default_value=*/left_view.default_branch); result_node.default_branch_by_field_modification = CombineModifyBranches( - right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(sequenced_right_modifications), union_combiner, /*default_value=*/Deny()); - // Collect every value mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification, + right_applied_to_left_modifications)) { + ModifyBranchesView left_branches_at_value = + ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of // , but the latter is not always canonical. However, an // empty map won't work correctly for the merges below (an in fact, the // whole for-loop would be skipped), so we expand it here if necessary. - if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + if (left_branches_at_value.empty()) { + left_branches_at_value = + ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } - } - - return NodeToTransformer(std::move(result_node)); -} - -PacketTransformerHandle PacketTransformerManager::Sequence( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (IsDeny(left) || IsDeny(right)) return Deny(); - if (IsAccept(left)) return right; - if (IsAccept(right)) return left; - - // If neither node is accept or deny, then sequence the nodes directly. - return Sequence(GetNodeOrDie(left), GetNodeOrDie(right)); -} -PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Union( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Union(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } - - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()), - .default_branch = Union(left.default_branch, right.default_branch), - }; - - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + ModifyBranchMap& result_branches = + result_node.modify_branch_by_field_match[value]; + for (auto [left_value, left_branch] : left_branches_at_value) { + ModifyBranchMap sequenced = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, left_value), + sequence_combiner, + /*default_value=*/left_branch); + result_branches = + CombineModifyBranches(ModifyBranchesView(result_branches), + ModifyBranchesView(sequenced), union_combiner, + /*default_value=*/Deny()); + } } return NodeToTransformer(std::move(result_node)); @@ -631,95 +653,43 @@ PacketTransformerHandle PacketTransformerManager::Union( if (IsDeny(right)) return left; if (IsDeny(left)) return right; - // If either node is accept, then expand it before merging. - if (IsAccept(left) || IsAccept(right)) { - const DecisionNode& other_node = - GetNodeOrDie(IsAccept(left) ? right : left); - return Union( - DecisionNode{ - .field = other_node.field, - .default_branch = Accept(), - }, - other_node); - } - - // If neither node is accept or deny, then union the nodes directly. - return Union(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Difference( - DecisionNode left, DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Difference( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Difference(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, + .field = field, .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + union_combiner, /*default_value=*/Deny()), - .default_branch = Difference(left.default_branch, right.default_branch), + .default_branch = + Union(left_view.default_branch, right_view.default_branch), }; - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), union_combiner, /*default_value=*/Deny()); } @@ -733,27 +703,47 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // If either node is accept, then expand it before merging. - if (IsAccept(left)) { - const DecisionNode& right_node = GetNodeOrDie(right); - return Difference( - DecisionNode{ - .field = right_node.field, - .default_branch = Accept(), - }, - right_node); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto difference_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Difference(left, right); + }; + + DecisionNode result_node{ + .field = field, + .default_branch_by_field_modification = CombineModifyBranches( + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + difference_combiner, + /*default_value=*/Deny()), + .default_branch = + Difference(left_view.default_branch, right_view.default_branch), + }; - if (IsAccept(right)) { - const DecisionNode& left_node = GetNodeOrDie(left); - return Difference(left_node, DecisionNode{ - .field = left_node.field, - .default_branch = Accept(), - }); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { + result_node.modify_branch_by_field_match[value] = CombineModifyBranches( + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), difference_combiner, + /*default_value=*/Deny()); } - // If neither node is accept or deny, then difference the nodes directly. - return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); + return NodeToTransformer(std::move(result_node)); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..b63bf19 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,18 +404,6 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); - // Helper functions to deal with DecisionNodes directly. - // TODO(dilo): Is there a convenient way to either avoid these or avoid making - // copies of the nodes? - PacketTransformerHandle Union(DecisionNode left, DecisionNode right); - PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); - PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); - - // Internal helper function to get a map of possible modification values to - // branches for a given input value at `node`. - absl::btree_map GetMapAtValue( - const DecisionNode& node, int value); - // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // From bf0e8ae8dfd8196e6af9e29893cc11db5ffe79c4 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:29:02 -0700 Subject: [PATCH 2/4] [NetKAT] Follow-up cleanups: in-place Sequence accumulation, op dedup, stable goldens. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-ups to the copy-elimination refactor: * Sequence accumulated its result maps by rebuilding them from scratch for every left-branch entry, making accumulation quadratic in the number of branches. It now unions each sequenced branch into the result map in place, which also removes the remaining intermediate maps entirely. * Union and Difference were near-identical after the refactor; their shared body (pointwise combination of two operands viewed at a common field) moves into a PointwiseCombine template parameterized on the operation. * The golden diff test asserted raw node-interning indices, so any reordering inside the manager — including the change above, and planned work like operation memoization — churned the golden file without any semantic change. The test runner now renumbers nodes and fields canonically before printing, by copying each compiled transformer into a fresh manager in deterministic traversal order, so the golden output depends only on transformer structure. The golden file is regenerated accordingly (verified: the previous diff was pure renumbering, with structure, fields, and labels identical). Compilation benchmarks improve a further ~9% on top of the previous commit (~1.8x total vs the original code on FirstTimeCompile NonOverlappingPolicy). All tests pass. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 + netkat/packet_transformer.cc | 144 +++++++----------- netkat/packet_transformer.h | 11 ++ netkat/packet_transformer_test.expected | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 117 +++++++++++++- 5 files changed, 273 insertions(+), 189 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..a456370 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -440,7 +440,11 @@ cc_test( linkstatic = True, deps = [ ":netkat_proto_constructors", + ":packet_field", ":packet_transformer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index c2807da..efd421f 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -571,15 +571,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( const DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto sequence_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); + // Unions `branch` into `branches[modify_value]`, treating absence as Deny. + auto union_into = [this](ModifyBranchMap& branches, int modify_value, + PacketTransformerHandle branch) { + auto [it, inserted] = branches.try_emplace(modify_value, branch); + if (!inserted) it->second = Union(it->second, branch); }; - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - const ModifyBranchesView empty_view(EmptyModifyBranchMap()); DecisionNode result_node{ .field = field, @@ -588,28 +585,24 @@ PacketTransformerHandle PacketTransformerManager::Sequence( }; // Construct the possible results of applying the right node to packets - // gotten by taken default modification branches in the left node. - ModifyBranchMap right_applied_to_left_modifications; + // gotten by taken default modification branches in the left node... + ModifyBranchMap& result_modifications = + result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(right_at_value_with_sequence), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, value)) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + } + } + // ... and of applying the right node's default modifications to packets + // falling through the left node unmodified. + for (const auto& [modify_value, right_branch] : + *right_view.default_branch_by_field_modification) { + union_into(result_modifications, modify_value, + Sequence(left_view.default_branch, right_branch)); } - - ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( - empty_view, - ModifyBranchesView(*right_view.default_branch_by_field_modification), - sequence_combiner, - /*default_value=*/left_view.default_branch); - result_node.default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(sequenced_right_modifications), union_combiner, - /*default_value=*/Deny()); // For every value mapped in either node, construct the proper new branch. for (int value : @@ -617,7 +610,7 @@ PacketTransformerHandle PacketTransformerManager::Sequence( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification, - right_applied_to_left_modifications)) { + result_modifications)) { ModifyBranchesView left_branches_at_value = ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of @@ -632,27 +625,21 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchMap& result_branches = result_node.modify_branch_by_field_match[value]; for (auto [left_value, left_branch] : left_branches_at_value) { - ModifyBranchMap sequenced = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, left_value), - sequence_combiner, - /*default_value=*/left_branch); - result_branches = - CombineModifyBranches(ModifyBranchesView(result_branches), - ModifyBranchesView(sequenced), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, left_value)) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + } } } return NodeToTransformer(std::move(result_node)); } -PacketTransformerHandle PacketTransformerManager::Union( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (left == right) return left; - if (IsDeny(right)) return left; - if (IsDeny(left)) return right; - +template +PacketTransformerHandle PacketTransformerManager::PointwiseCombine( + PacketTransformerHandle left, PacketTransformerHandle right, + Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a // decision node. Combine the operands at the smallest field branched on by // either; an operand that is Accept, or branches on a strictly larger @@ -665,20 +652,15 @@ PacketTransformerHandle PacketTransformerManager::Union( const DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - DecisionNode result_node{ .field = field, .default_branch_by_field_modification = CombineModifyBranches( ModifyBranchesView(*left_view.default_branch_by_field_modification), ModifyBranchesView(*right_view.default_branch_by_field_modification), - union_combiner, + combiner, /*default_value=*/Deny()), .default_branch = - Union(left_view.default_branch, right_view.default_branch), + combiner(left_view.default_branch, right_view.default_branch), }; // For every value mapped in either node, construct the proper new branch. @@ -689,13 +671,27 @@ PacketTransformerHandle PacketTransformerManager::Union( *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), union_combiner, + ModifyBranchesAtValue(right_view, value), combiner, /*default_value=*/Deny()); } return NodeToTransformer(std::move(result_node)); } +PacketTransformerHandle PacketTransformerManager::Union( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (left == right) return left; + if (IsDeny(right)) return left; + if (IsDeny(left)) return right; + + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }); +} + PacketTransformerHandle PacketTransformerManager::Difference( PacketTransformerHandle left, PacketTransformerHandle right) { // Base cases. @@ -703,47 +699,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); - - auto difference_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Difference(left, right); - }; - - DecisionNode result_node{ - .field = field, - .default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(*left_view.default_branch_by_field_modification), - ModifyBranchesView(*right_view.default_branch_by_field_modification), - difference_combiner, - /*default_value=*/Deny()), - .default_branch = - Difference(left_view.default_branch, right_view.default_branch), - }; - - // For every value mapped in either node, construct the proper new branch. - for (int value : - SortedUniqueKeys(*left_view.modify_branch_by_field_match, - *right_view.modify_branch_by_field_match, - *left_view.default_branch_by_field_modification, - *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), difference_combiner, - /*default_value=*/Deny()); - } - - return NodeToTransformer(std::move(result_node)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index b63bf19..ed0e074 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,6 +404,17 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); + // Shared implementation of `Union` and `Difference`, which differ only in + // their base cases and the operation applied to corresponding branches: + // combines `left` and `right` by applying `combiner` — which must be the + // handle-level operation itself, e.g. `Union` — to corresponding branches. + // Both operands must be Accept or decision nodes (i.e., the base cases must + // already have been handled). + template + PacketTransformerHandle PointwiseCombine(PacketTransformerHandle left, + PacketTransformerHandle right, + Combiner&& combiner); + // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // diff --git a/netkat/packet_transformer_test.expected b/netkat/packet_transformer_test.expected index 9f01d94..ca6b248 100644 --- a/netkat/packet_transformer_test.expected +++ b/netkat/packet_transformer_test.expected @@ -42,22 +42,22 @@ digraph { Test case: p := (a=5 + b=2);(b:=1 + c=5). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<8>: +PacketTransformerHandle<3>: PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<6> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<1> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<7> -PacketTransformerHandle<6>: + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<1>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<5> -PacketTransformerHandle<7>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<0> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<5> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<5>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 5 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -68,50 +68,50 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 8 [label="a"] - 8 -> 6 [label="a==5; a:=5"] - 8 -> 7 [style=dashed] - 6 [label="b"] - 6 -> 4294967294 [label="b:=1" style=dashed] - 6 -> 5 [style=dashed] - 7 [label="b"] - 7 -> 4294967294 [label="b==2; b:=1"] - 7 -> 5 [label="b==2; b:=2"] - 7 -> 4294967295 [style=dashed] - 5 [label="c"] - 5 -> 4294967294 [label="c==5; c:=5"] - 5 -> 4294967295 [style=dashed] + 3 [label="a"] + 3 -> 1 [label="a==5; a:=5"] + 3 -> 2 [style=dashed] + 1 [label="b"] + 1 -> 4294967294 [label="b:=1" style=dashed] + 1 -> 0 [style=dashed] + 2 [label="b"] + 2 -> 4294967294 [label="b==2; b:=1"] + 2 -> 0 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c==5; c:=5"] + 0 -> 4294967295 [style=dashed] } ================================================================================ Test case: q := (b=1 + c:=4 + a:=5;b:=1). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<17>: +PacketTransformerHandle<5>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<14> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<16> -PacketTransformerHandle<14>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<4> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<4>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<16>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<13>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<10>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle @@ -121,64 +121,64 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 17 [label="a"] - 17 -> 14 [label="a==1; a:=1"] - 17 -> 4 [label="a:=1" style=dashed] - 17 -> 16 [style=dashed] - 14 [label="b"] - 14 -> 13 [label="b==1; b:=1"] - 14 -> 4294967294 [label="b:=1" style=dashed] - 14 -> 10 [style=dashed] + 5 [label="a"] + 5 -> 2 [label="a==1; a:=1"] + 5 -> 3 [label="a:=1" style=dashed] + 5 -> 4 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==1; b:=1"] + 2 -> 4294967294 [label="b:=1" style=dashed] + 2 -> 1 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 16 [label="b"] - 16 -> 13 [label="b==1; b:=1"] - 16 -> 10 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 10 [label="c"] - 10 -> 4294967294 [label="c:=4" style=dashed] - 10 -> 4294967295 [style=dashed] + 4 -> 0 [label="b==1; b:=1"] + 4 -> 1 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c:=4" style=dashed] + 1 -> 4294967295 [style=dashed] } ================================================================================ Test case: p;q. Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<22>: +PacketTransformerHandle<6>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<19> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<21> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<4> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<20> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<19> -PacketTransformerHandle<19>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<5> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<18> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<1> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<4>: +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<21>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<18> -PacketTransformerHandle<20>: + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<5>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<13>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<18>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -189,29 +189,29 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 22 [label="a"] - 22 -> 19 [label="a==1; a:=1"] - 22 -> 4 [label="a==5; a:=1"] - 22 -> 21 [label="a==5; a:=5"] - 22 -> 20 [label="a:=1" style=dashed] - 22 -> 19 [style=dashed] - 19 [label="b"] - 19 -> 13 [label="b==2; b:=1"] - 19 -> 18 [label="b==2; b:=2"] - 19 -> 4294967295 [style=dashed] + 6 [label="a"] + 6 -> 2 [label="a==1; a:=1"] + 6 -> 3 [label="a==5; a:=1"] + 6 -> 4 [label="a==5; a:=5"] + 6 -> 5 [label="a:=1" style=dashed] + 6 -> 2 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==2; b:=1"] + 2 -> 1 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 21 [label="b"] - 21 -> 13 [label="b:=1" style=dashed] - 21 -> 18 [style=dashed] - 20 [label="b"] - 20 -> 4294967294 [label="b==2; b:=1"] - 20 -> 4294967295 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 18 [label="c"] - 18 -> 4294967294 [label="c==5; c:=4"] - 18 -> 4294967295 [style=dashed] + 4 -> 0 [label="b:=1" style=dashed] + 4 -> 1 [style=dashed] + 5 [label="b"] + 5 -> 4294967294 [label="b==2; b:=1"] + 5 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c==5; c:=4"] + 1 -> 4294967295 [style=dashed] } diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index 653c320..c4eacfc 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,15 +16,118 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` +#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_field.h" #include "netkat/packet_transformer.h" namespace netkat { + +// Test peer to access `PacketTransformerManager` internals, allowing this +// runner to renumber nodes canonically before printing. +class PacketTransformerManagerTestPeer { + public: + // Copies `transformer` into the fresh manager `to`, interning fields and + // nodes in a deterministic traversal order. This makes the printed handle + // numbers depend only on the structure of `transformer` — not on the + // interning order of `from`, which is an implementation detail that changes + // under refactorings of the manager (e.g. reordering its operations). + static PacketTransformerHandle CanonicalCopy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to) { + // Intern the fields of all reachable nodes in their original relative + // order, which the node invariants ("fields increase strictly along each + // path") depend on. + std::vector fields; + absl::flat_hash_set visited; + CollectFields(from, transformer, visited, fields); + absl::c_sort(fields); + fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + for (PacketFieldHandle field : fields) { + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field)); + } + + absl::flat_hash_map + copy_by_original; + return Copy(from, transformer, to, copy_by_original); + } + + private: + static void CollectFields( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + absl::flat_hash_set& visited, + std::vector& fields) { + if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; + if (!visited.insert(transformer).second) return; + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + fields.push_back(node.field); + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : branch_by_modify) { + CollectFields(from, branch, visited, fields); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + CollectFields(from, branch, visited, fields); + } + CollectFields(from, node.default_branch, visited, fields); + } + + static PacketTransformerHandle Copy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to, + absl::flat_hash_map& + copy_by_original) { + if (from.IsDeny(transformer)) return to.Deny(); + if (from.IsAccept(transformer)) return to.Accept(); + if (auto it = copy_by_original.find(transformer); + it != copy_by_original.end()) { + return it->second; + } + + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + PacketTransformerManager::DecisionNode copy{ + .field = + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName( + node.field)), + }; + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + // `operator[]` keeps entries with empty branch maps, which are + // meaningful: they deny packets matching `match_value`. + auto& copy_branch_by_modify = + copy.modify_branch_by_field_match[match_value]; + for (const auto& [modify_value, branch] : branch_by_modify) { + copy_branch_by_modify[modify_value] = + Copy(from, branch, to, copy_by_original); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + copy.default_branch_by_field_modification[modify_value] = + Copy(from, branch, to, copy_by_original); + } + copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + + PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); + copy_by_original.emplace(transformer, result); + return result; + } +}; + namespace { constexpr char kBanner[] = @@ -91,16 +194,22 @@ std::vector TestCases() { } void main() { - // This test needs a deterministic field interning order, and thus must start - // from a fresh manager. PacketTransformerManager manager; for (const TestCase& test_case : TestCases()) { netkat::PacketTransformerHandle packet_transformer = manager.Compile(test_case.policy); + // Renumber nodes and fields canonically, in a fresh manager per test case, + // so that the printed output depends only on the structure of the compiled + // transformer and not on the interning order of `manager`. + PacketTransformerManager canonical_manager; + netkat::PacketTransformerHandle canonical_transformer = + PacketTransformerManagerTestPeer::CanonicalCopy( + manager, packet_transformer, canonical_manager); std::cout << kBanner << "Test case: " << test_case.description << std::endl << kBanner; - std::cout << kStringHeader << manager.ToString(packet_transformer); - std::cout << kDotHeader << manager.ToDot(packet_transformer); + std::cout << kStringHeader + << canonical_manager.ToString(canonical_transformer); + std::cout << kDotHeader << canonical_manager.ToDot(canonical_transformer); } } From 37aa3a4b24aca9140eb7885a5678af5e1ac2252e Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 04:07:20 -0700 Subject: [PATCH 3/4] [NetKAT] Apply review cleanups to the combinator refactor. Findings from a four-angle cleanup review (reuse, simplification, efficiency, altitude): * Replace ModifyBranchesView's hand-rolled merging iterator (~45 lines, the subtlest code in the refactor) with a simple ForEach that visits base entries then the extra entry. No caller needed globally sorted iteration: map contents are order-independent and handle-level results are canonical regardless of combination order. * Deduplicate the operand-view prologue shared by Sequence and PointwiseCombine into ViewOperandsAtSmallestField. * Use absl::NoDestructor for the empty-map singletons (the repo's established idiom) and name the default-handle-is-Deny contract in one place (IsDenyHandle) instead of re-deriving it inline. * Hint the per-value btree insertions in Sequence/PointwiseCombine with end(), since SortedUniqueKeys emits values in increasing order. * Test runner: collect fields into a btree_set instead of sort+unique, and translate field handles through a precomputed map instead of a per-node string round-trip. * Drop a stale TODO referencing the deleted GetMapAtValue, and the now-unused any_invocable dependency. No behavior change: all tests pass and the golden file is unchanged. Benchmarks are neutral. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 +- netkat/packet_transformer.cc | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 36 ++--- 3 files changed, 106 insertions(+), 120 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index a456370..5bffd85 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -371,11 +371,11 @@ cc_library( ":packet_set", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -442,7 +442,7 @@ cc_test( ":netkat_proto_constructors", ":packet_field", ":packet_transformer", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index efd421f..008fe6d 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -24,6 +24,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" @@ -121,9 +122,6 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( absl::flat_hash_set redundant_values; for (auto& [match_value, modification_map] : node.modify_branch_by_field_match) { - // TODO(dilo): Consider if this can make use of GetMapAtValue. Perhaps by - // calling it on a copy of DecisionNode without any - // `modify_branch_by_field_match` mappings? if (skip_default_branch || node.default_branch_by_field_modification.contains(match_value)) { // Compare the modification map to the default branch modification map, @@ -365,15 +363,22 @@ using ModifyBranchMap = absl::btree_map; using MatchBranchMap = absl::btree_map; const MatchBranchMap& EmptyMatchBranchMap() { - static const MatchBranchMap* const kEmpty = new MatchBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } const ModifyBranchMap& EmptyModifyBranchMap() { - static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } +// Returns true iff `transformer` is the Deny transformer, which by documented +// contract is the default-constructed handle. (Unlike +// `PacketTransformerManager::IsDeny`, this is callable without a manager.) +bool IsDenyHandle(PacketTransformerHandle transformer) { + return transformer == PacketTransformerHandle(); +} + // A cheap, non-owning view of a decision node as seen "at" a field no larger // than the node's own field: either the node's own maps (if the node branches // on that field), or the trivial expansion "for every value of the field, @@ -422,10 +427,35 @@ PacketFieldHandle SmallestField(const Node* left, const Node* right) { return std::min(left->field, right->field); } +// The two operands of a binary combinator, viewed at the smallest field +// branched on by either: an operand that is Accept, or branches on a strictly +// larger field, is viewed as a trivial node at that field. +struct OperandViews { + PacketFieldHandle field; + DecisionNodeView left; + DecisionNodeView right; +}; + +// Views the operands `left` and `right`, whose decision nodes are `left_node` +// and `right_node` (null encoding Accept; at least one must be non-null), at +// the smallest field branched on by either. See `ViewAtField` regarding the +// `Node` template parameter. +template +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { + const PacketFieldHandle field = SmallestField(left_node, right_node); + return OperandViews{ + .field = field, + .left = ViewAtField(field, left, left_node), + .right = ViewAtField(field, right, right_node), + }; +} + // A cheap, non-owning view of a logical (modify value -> branch) map, // represented as a base map plus at most one extra entry whose key must not -// be a key of the base map. Iteration is in increasing key order, like the -// underlying btree map. +// be a key of the base map. class ModifyBranchesView { public: using Entry = std::pair; @@ -436,61 +466,22 @@ class ModifyBranchesView { DCHECK(!base.contains(extra.first)); } - class const_iterator { - public: - Entry operator*() const { - return ExtraIsNext() ? *extra_ : Entry(it_->first, it_->second); - } - const_iterator& operator++() { - if (ExtraIsNext()) { - extra_ = nullptr; - } else { - ++it_; - } - return *this; - } - friend bool operator==(const const_iterator& a, const const_iterator& b) { - return a.it_ == b.it_ && a.extra_ == b.extra_; - } - friend bool operator!=(const const_iterator& a, const const_iterator& b) { - return !(a == b); - } - - private: - friend class ModifyBranchesView; - const_iterator(ModifyBranchMap::const_iterator it, - ModifyBranchMap::const_iterator end, const Entry* extra) - : it_(it), end_(end), extra_(extra) {} - - // The extra entry comes next iff it has not been consumed and its key - // precedes the next base entry's key. - bool ExtraIsNext() const { - return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); - } - - ModifyBranchMap::const_iterator it_, end_; - const Entry* extra_; // Null if absent or already consumed. - }; - - const_iterator begin() const { - return const_iterator(base_->begin(), base_->end(), - extra_.has_value() ? &*extra_ : nullptr); - } - const_iterator end() const { - return const_iterator(base_->end(), base_->end(), nullptr); + // Invokes `fn(modify_value, branch)` for each entry: the base map entries + // in increasing key order, then the extra entry, if any. + template + void ForEach(Fn&& fn) const { + for (const auto& [value, branch] : *base_) fn(value, branch); + if (extra_.has_value()) fn(extra_->first, extra_->second); } - bool empty() const { return base_->empty() && !extra_.has_value(); } - bool contains(int value) const { - return (extra_.has_value() && extra_->first == value) || - base_->contains(value); - } - std::optional find(int value) const { + std::optional Find(int value) const { if (extra_.has_value() && extra_->first == value) return extra_->second; if (auto it = base_->find(value); it != base_->end()) return it->second; return std::nullopt; } + bool empty() const { return base_->empty() && !extra_.has_value(); } + private: const ModifyBranchMap* base_; std::optional extra_; @@ -509,9 +500,7 @@ ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, return ModifyBranchesView(it->second); } const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; - // `PacketTransformerHandle()` is the Deny transformer. - if (defaults.contains(value) || - node.default_branch == PacketTransformerHandle()) { + if (defaults.contains(value) || IsDenyHandle(node.default_branch)) { return ModifyBranchesView(defaults); } return ModifyBranchesView(defaults, {value, node.default_branch}); @@ -526,16 +515,16 @@ ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, Combiner&& combiner, PacketTransformerHandle default_value) { ModifyBranchMap result; - for (auto [value, left_branch] : left) { - // Keys arrive in increasing order, so inserting at `end()` is O(1). + left.ForEach([&](int value, PacketTransformerHandle left_branch) { + // Keys arrive in near-sorted order, so the `end()` hint is mostly exact. result.try_emplace( result.end(), value, - combiner(left_branch, right.find(value).value_or(default_value))); - } - for (auto [value, right_branch] : right) { - if (left.contains(value)) continue; + combiner(left_branch, right.Find(value).value_or(default_value))); + }); + right.ForEach([&](int value, PacketTransformerHandle right_branch) { + if (left.Find(value).has_value()) return; result.try_emplace(value, combiner(default_value, right_branch)); - } + }); return result; } @@ -562,14 +551,9 @@ PacketTransformerHandle PacketTransformerManager::Sequence( if (IsAccept(left)) return right; if (IsAccept(right)) return left; - // Both operands are decision nodes. Combine them at the smaller of their - // fields; an operand branching on a strictly larger field is viewed as a - // trivial node at the combined field. - const DecisionNode* left_node = &GetNodeOrDie(left); - const DecisionNode* right_node = &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // Both operands are decision nodes. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, &GetNodeOrDie(left), right, &GetNodeOrDie(right)); // Unions `branch` into `branches[modify_value]`, treating absence as Deny. auto union_into = [this](ModifyBranchMap& branches, int modify_value, @@ -590,11 +574,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, value)) { - union_into(result_modifications, modify_value, - Sequence(branch, right_branch)); - } + ModifyBranchesAtValue(right_view, value) + .ForEach([&, branch = branch](int modify_value, + PacketTransformerHandle right_branch) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + }); } // ... and of applying the right node's default modifications to packets // falling through the left node unmodified. @@ -622,15 +607,19 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } + // `value`s arrive in increasing order, so inserting at `end()` is O(1). ModifyBranchMap& result_branches = - result_node.modify_branch_by_field_match[value]; - for (auto [left_value, left_branch] : left_branches_at_value) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, left_value)) { - union_into(result_branches, modify_value, - Sequence(left_branch, right_branch)); - } - } + result_node.modify_branch_by_field_match + .try_emplace(result_node.modify_branch_by_field_match.end(), value) + ->second; + left_branches_at_value.ForEach([&](int left_value, + PacketTransformerHandle left_branch) { + ModifyBranchesAtValue(right_view, left_value) + .ForEach([&](int modify_value, PacketTransformerHandle right_branch) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + }); + }); } return NodeToTransformer(std::move(result_node)); @@ -641,16 +630,10 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( PacketTransformerHandle left, PacketTransformerHandle right, Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // decision node. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, + IsAccept(right) ? nullptr : &GetNodeOrDie(right)); DecisionNode result_node{ .field = field, @@ -669,10 +652,13 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), combiner, - /*default_value=*/Deny()); + // `value`s arrive in increasing order, so inserting at `end()` is O(1). + result_node.modify_branch_by_field_match.try_emplace( + result_node.modify_branch_by_field_match.end(), value, + CombineModifyBranches(ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), + combiner, + /*default_value=*/Deny())); } return NodeToTransformer(std::move(result_node)); diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index c4eacfc..e6a4ad7 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,13 +16,12 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` -#include #include #include #include #include -#include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" @@ -45,32 +44,33 @@ class PacketTransformerManagerTestPeer { PacketTransformerManager& to) { // Intern the fields of all reachable nodes in their original relative // order, which the node invariants ("fields increase strictly along each - // path") depend on. - std::vector fields; + // path") depend on, and record the handle translation for `Copy`. + absl::btree_set fields; absl::flat_hash_set visited; CollectFields(from, transformer, visited, fields); - absl::c_sort(fields); - fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + absl::flat_hash_map field_translation; for (PacketFieldHandle field : fields) { - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - from.packet_set_manager_.field_manager_.GetFieldName(field)); + field_translation.try_emplace( + field, + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field))); } absl::flat_hash_map copy_by_original; - return Copy(from, transformer, to, copy_by_original); + return Copy(from, transformer, to, field_translation, copy_by_original); } private: static void CollectFields( const PacketTransformerManager& from, PacketTransformerHandle transformer, absl::flat_hash_set& visited, - std::vector& fields) { + absl::btree_set& fields) { if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; if (!visited.insert(transformer).second) return; const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); - fields.push_back(node.field); + fields.insert(node.field); for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { for (const auto& [modify_value, branch] : branch_by_modify) { @@ -87,6 +87,8 @@ class PacketTransformerManagerTestPeer { static PacketTransformerHandle Copy( const PacketTransformerManager& from, PacketTransformerHandle transformer, PacketTransformerManager& to, + const absl::flat_hash_map& + field_translation, absl::flat_hash_map& copy_by_original) { if (from.IsDeny(transformer)) return to.Deny(); @@ -99,10 +101,7 @@ class PacketTransformerManagerTestPeer { const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); PacketTransformerManager::DecisionNode copy{ - .field = - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - from.packet_set_manager_.field_manager_.GetFieldName( - node.field)), + .field = field_translation.at(node.field), }; for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { @@ -112,15 +111,16 @@ class PacketTransformerManagerTestPeer { copy.modify_branch_by_field_match[match_value]; for (const auto& [modify_value, branch] : branch_by_modify) { copy_branch_by_modify[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } } for (const auto& [modify_value, branch] : node.default_branch_by_field_modification) { copy.default_branch_by_field_modification[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } - copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + copy.default_branch = Copy(from, node.default_branch, to, field_translation, + copy_by_original); PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); copy_by_original.emplace(transformer, result); From 622487145b97eb101c4f71a5c2300ea41423ce37 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 08:51:57 -0700 Subject: [PATCH 4/4] [NetKAT] Store each interned decision node once: pointer-keyed unique tables. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The unique tables of both managers stored every node twice — once in the node arena (nodes_) and once as the hash map key — doubling node memory and copying each new node (nested btree maps and all, for transformer nodes) on insertion. The tables are now keyed by pointers into the arena, which PagedStableVector keeps stable across both growth and manager moves; transparent hash/equality functors allow lookup by node value before interning. Benchmarks improve ~2-9% across the board on top of #98, mostly from no longer copying freshly-created nodes into the key slot. Also: * Adds BM_PushAndPullFullSetThroughPolicy, covering the read-path operations (Push/Pull) that the analysis engine is built on; the existing benchmarks only measured compilation. Sizing note: without memoization, Push/Pull cost grows exponentially in policy size (8 unioned sub-policies ran for 20+ minutes), so the benchmark deliberately stays small — and the planned operation-memoization work is where that changes. * Includes ProtoHashKey.policy_case in the proto-cache hash; it was part of equality but omitted from the hash, causing avoidable collisions across policy kinds. * CheckInternalInvariants verifies pointer-table integrity and exercises both lookup paths in both managers. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 3 ++ netkat/packet_set.cc | 43 ++++++++++++++++------- netkat/packet_set.h | 31 +++++++++++++++-- netkat/packet_transformer.cc | 47 ++++++++++++++++++-------- netkat/packet_transformer.h | 36 ++++++++++++++++++-- netkat/packet_transformer_benchmark.cc | 26 ++++++++++++++ 6 files changed, 154 insertions(+), 32 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 5bffd85..a72f5c7 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -132,6 +132,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -376,6 +377,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -427,6 +429,7 @@ cc_binary( deps = [ ":netkat_cc_proto", ":netkat_proto_constructors", + ":packet_set", ":packet_transformer", "@com_google_absl//absl/strings", "@com_google_benchmark//:benchmark_main", diff --git a/netkat/packet_set.cc b/netkat/packet_set.cc index 2c7bd1c..fc111fe 100644 --- a/netkat/packet_set.cc +++ b/netkat/packet_set.cc @@ -24,6 +24,7 @@ #include "absl/algorithm/container.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -87,6 +88,14 @@ const PacketSetManager::DecisionNode& PacketSetManager::GetNodeOrDie( return nodes_[packet_set.node_index_]; } +size_t PacketSetManager::NodeHash::operator()(const DecisionNode* node) const { + return absl::HashOf(*node); +} + +size_t PacketSetManager::NodeHash::operator()(const DecisionNode& node) const { + return absl::HashOf(node); +} + PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { if (node.branch_by_field_value.empty()) return node.default_branch; @@ -109,16 +118,19 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { } #endif - auto [it, inserted] = - packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the node by value via the transparent `NodeHash`/`NodeEq` + // functors; only new nodes get stored (exactly once, in `nodes_`). + if (auto it = packet_by_node_.find(node); it != packet_by_node_.end()) { + return it->second; } - return it->second; + PacketSetHandle packet(nodes_.size()); + nodes_.push_back(std::move(node)); + packet_by_node_.insert({&nodes_[packet.node_index_], packet}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return packet; } bool PacketSetManager::Contains(PacketSetHandle packet_set, @@ -483,14 +495,19 @@ absl::Status PacketSetManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - for (const auto& [node, packet] : packet_by_node_) { + // Invariant: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + for (const auto& [node_ptr, packet] : packet_by_node_) { RET_CHECK(packet.node_index_ < nodes_.size()); - RET_CHECK(nodes_[packet.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[packet.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = packet_by_node_.find(node); + // Look up both by pointer and by value (exercising the transparent + // functors used by `NodeToPacket`). + auto it = packet_by_node_.find(&node); + RET_CHECK(it != packet_by_node_.end()); + RET_CHECK(it->second == PacketSetHandle(i)); + it = packet_by_node_.find(node); RET_CHECK(it != packet_by_node_.end()); RET_CHECK(it->second == PacketSetHandle(i)); } diff --git a/netkat/packet_set.h b/netkat/packet_set.h index f11088f..696ae55 100644 --- a/netkat/packet_set.h +++ b/netkat/packet_set.h @@ -356,11 +356,38 @@ class PacketSetManager { // `And`, `Or`, `Not`). The class also avoids expensive relocations. PagedStableVector nodes_; + // Transparent hash and equality functors for the unique table + // (`packet_by_node_`), which is keyed by stable `DecisionNode*` pointers + // into `nodes_` (so each node is stored only once). Lookups work directly + // with a not-yet-interned `DecisionNode` value. Both functors are + // stateless: keys are pointers, and the pages holding the nodes are stable + // across moves of the manager. + struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNode& node) const; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNode& b) const { + return *a == b; + } + bool operator()(const DecisionNode& a, const DecisionNode* b) const { + return a == *b; + } + }; + // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketSetHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. // - // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map packet_by_node_; + // INVARIANT: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map + packet_by_node_; // A map of a given `PredicateProto` to a `PacketSetHandle`. // diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index 008fe6d..2dc1a80 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -29,6 +29,7 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,6 +79,16 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNode* node) const { + return absl::HashOf(*node); +} + +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNode& node) const { + return absl::HashOf(node); +} + // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -153,16 +164,20 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node.default_branch_by_field_modification.empty()) return node.default_branch; - auto [it, inserted] = transformer_by_node_.try_emplace( - node, PacketTransformerHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the node by value via the transparent `NodeHash`/`NodeEq` + // functors; only new nodes get stored (exactly once, in `nodes_`). + if (auto it = transformer_by_node_.find(node); + it != transformer_by_node_.end()) { + return it->second; } - return it->second; + PacketTransformerHandle transformer(nodes_.size()); + nodes_.push_back(std::move(node)); + transformer_by_node_.insert({&nodes_[transformer.node_index_], transformer}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return transformer; } bool PacketTransformerManager::IsDeny( @@ -1040,15 +1055,19 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == - // n`. - for (const auto& [node, transformer] : transformer_by_node_) { + // Invariant: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + for (const auto& [node_ptr, transformer] : transformer_by_node_) { RET_CHECK(transformer.node_index_ < nodes_.size()); - RET_CHECK(nodes_[transformer.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[transformer.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = transformer_by_node_.find(node); + // Look up both by pointer and by value (exercising the transparent + // functors used by `NodeToTransformer`). + auto it = transformer_by_node_.find(&node); + RET_CHECK(it != transformer_by_node_.end()); + RET_CHECK(it->second == PacketTransformerHandle(i)); + it = transformer_by_node_.find(node); RET_CHECK(it != transformer_by_node_.end()); RET_CHECK(it->second == PacketTransformerHandle(i)); } diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index ed0e074..d723be9 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -383,7 +383,33 @@ class PacketTransformerManager { template friend H AbslHashValue(H h, const ProtoHashKey& key) { - return H::combine(std::move(h), key.lhs_child, key.rhs_child); + return H::combine(std::move(h), key.policy_case, key.lhs_child, + key.rhs_child); + } + }; + + // Transparent hash and equality functors for the unique table + // (`transformer_by_node_`), which is keyed by stable `DecisionNode*` + // pointers into `nodes_` so that each node is stored only once (rather than + // twice: once in `nodes_` and once as a map key). Lookups work directly + // with a not-yet-interned `DecisionNode` value. Both functors are + // stateless: keys are pointers, and the pages holding the nodes are stable + // across moves of the manager. + struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNode& node) const; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNode& b) const { + return *a == b; + } + bool operator()(const DecisionNode& a, const DecisionNode* b) const { + return a == *b; } }; @@ -425,9 +451,13 @@ class PacketTransformerManager { // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketTransformerHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. The transparent `NodeHash`/`NodeEq` functors + // support lookup by node value, see their documentation. // - // INVARIANT: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map + // INVARIANT: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map transformer_by_node_; // A map of a given `PolicyProto` to a `PacketTransformerHandle`. diff --git a/netkat/packet_transformer_benchmark.cc b/netkat/packet_transformer_benchmark.cc index d927ffc..30d1a01 100644 --- a/netkat/packet_transformer_benchmark.cc +++ b/netkat/packet_transformer_benchmark.cc @@ -18,6 +18,7 @@ #include "benchmark/benchmark.h" #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_set.h" #include "netkat/packet_transformer.h" namespace netkat { @@ -99,6 +100,31 @@ void BM_FirstTimeCompileOverlappingPolicy(benchmark::State& state) { } BENCHMARK(BM_FirstTimeCompileOverlappingPolicy); +// Benchmarks the read-path operations that the analysis engine is built on: +// pushing/pulling packet sets through an already-compiled transformer. +// After the first iteration all nodes exist, so steady-state iterations +// exercise DAG traversal and unique-table hits rather than first-time node +// creation. +void BM_PushAndPullFullSetThroughPolicy(benchmark::State& state) { + PacketTransformerManager manager; + PolicyProto policy = CreateFixedArbitraryPolicyProto(0); + // NOTE: Without memoization of the recursive transformer operations, + // Push/Pull cost grows exponentially with policy size; keep the number of + // unioned sub-policies small so the benchmark stays tractable. + for (int i = 1; i < 2; ++i) { + policy = UnionProto( + policy, SequenceProto(CreateFixedArbitraryPolicyProto(i), + CreateFixedArbitraryPolicyProto(i + 8))); + } + PacketTransformerHandle transformer = manager.Compile(policy); + PacketSetHandle full_set = manager.GetPacketSetManager().FullSet(); + for (auto s : state) { + benchmark::DoNotOptimize(manager.Push(full_set, transformer)); + benchmark::DoNotOptimize(manager.Pull(transformer, full_set)); + } +} +BENCHMARK(BM_PushAndPullFullSetThroughPolicy); + // Benchmarks the cost of compiling a policy, with overlapping substructures, // that has already been compiled once before. Excludes the initial cost of // compilation.