From 1ce23ab0507ef02d0faf9744f42826cb840508db Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:13:50 -0700 Subject: [PATCH 1/5] [NetKAT] Eliminate node and map copies in packet transformer combinators. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Union, Sequence, and Difference copied entire DecisionNodes — including their nested btree maps — at every recursion step, re-interned expanded nodes through the unique table just to recover handles the caller already had, and copied the per-value modification map (GetMapAtValue) inside per-value loops. Several TODOs flagged this as a likely performance problem. Restructure the combinators around cheap, non-owning views instead: * DecisionNodeView views an operand "at" a field: either the node's own maps, or the trivial "fall through to this handle" expansion when the node branches on a larger field (or is Accept). This removes both the by-value node copies and the materialize-and-re-intern expansion step. * ModifyBranchesView replaces GetMapAtValue's map copies with a view of the base map plus at most one extra entry, iterated in sorted order. * CombineModifyBranches becomes a template over the combiner, removing AnyInvocable type-erasure and double lookups, and inserts with an end hint since keys arrive sorted. * The value-collection loops now iterate in sorted order (previously nondeterministic flat_hash_set order), making combination order — and thus node numbering — deterministic. Compilation benchmarks improve by roughly 1.5-2.2x (FirstTimeCompile*), with larger wins expected on policies with wider modification maps. Behavior is unchanged: all tests pass, and the golden-file diff test output is byte-identical, including node numbering. 
Co-Authored-By: Claude Fable 5 --- netkat/packet_transformer.cc | 654 +++++++++++++++++------------------ netkat/packet_transformer.h | 12 - 2 files changed, 322 insertions(+), 344 deletions(-) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..c2807da 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -14,8 +14,8 @@ #include "netkat/packet_transformer.h" +#include #include -#include #include #include #include @@ -28,7 +28,6 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,22 +77,6 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } -// TODO(dilo): Creating as many map copies as this method facilitates is -// probably going to cause terrible performance, and needs to be revisited. -absl::btree_map -PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); - - absl::btree_map result = - node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; - - // Otherwise, add a mapping from `value` to the default branch, then return. - result[value] = node.default_branch; - return result; -} - // Canonicalizes a decision node and returns a transformer. 
PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -376,249 +359,288 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, - const absl::btree_map& right, - absl::AnyInvocable - combiner, - PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); + +// Aliases for the map types of `PacketTransformerManager::DecisionNode`. +using ModifyBranchMap = absl::btree_map<int, PacketTransformerHandle>; +using MatchBranchMap = absl::btree_map<int, ModifyBranchMap>; + +const MatchBranchMap& EmptyMatchBranchMap() { + static const MatchBranchMap* const kEmpty = new MatchBranchMap; + return *kEmpty; +} + +const ModifyBranchMap& EmptyModifyBranchMap() { + static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + return *kEmpty; +} + +// A cheap, non-owning view of a decision node as seen "at" a field no larger +// than the node's own field: either the node's own maps (if the node branches +// on that field), or the trivial expansion "for every value of the field, +// leave it unmodified and fall through to the node" (if the node branches on +// a strictly larger field, or is Accept). This lets the binary operations +// below combine operands with distinct fields without materializing — and +// then re-interning — the trivial expansion as a `DecisionNode`. +struct DecisionNodeView { + const MatchBranchMap* modify_branch_by_field_match; + const ModifyBranchMap* default_branch_by_field_modification; + PacketTransformerHandle default_branch; +}; + +// Returns the view of `node` (the decision node of `transformer`, or null if +// `transformer` is Accept) at `field`, which must be <= `node->field`. 
+// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template <typename Node> +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { + if (node != nullptr && node->field == field) { + return DecisionNodeView{ + .modify_branch_by_field_match = &node->modify_branch_by_field_match, + .default_branch_by_field_modification = + &node->default_branch_by_field_modification, + .default_branch = node->default_branch, + }; + } + return DecisionNodeView{ + .modify_branch_by_field_match = &EmptyMatchBranchMap(), + .default_branch_by_field_modification = &EmptyModifyBranchMap(), + .default_branch = transformer, + }; +} + +// Returns the smallest field branched on by the given decision nodes, at +// least one of which must be non-null (null encodes Accept, which branches on +// no field). See `ViewAtField` regarding the `Node` template parameter. +template <typename Node> +PacketFieldHandle SmallestField(const Node* left, const Node* right) { + DCHECK(left != nullptr || right != nullptr); + if (left == nullptr) return right->field; + if (right == nullptr) return left->field; + return std::min(left->field, right->field); +} + +// A cheap, non-owning view of a logical (modify value -> branch) map, +// represented as a base map plus at most one extra entry whose key must not +// be a key of the base map. Iteration is in increasing key order, like the +// underlying btree map. +class ModifyBranchesView { + public: + using Entry = std::pair<int, PacketTransformerHandle>; + + explicit ModifyBranchesView(const ModifyBranchMap& base) : base_(&base) {} + ModifyBranchesView(const ModifyBranchMap& base, Entry extra) + : base_(&base), extra_(extra) { + DCHECK(!base.contains(extra.first)); + } + + class const_iterator { + public: + Entry operator*() const { + return ExtraIsNext() ? 
*extra_ : Entry(it_->first, it_->second); + } + const_iterator& operator++() { + if (ExtraIsNext()) { + extra_ = nullptr; + } else { + ++it_; + } + return *this; + } + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.it_ == b.it_ && a.extra_ == b.extra_; } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == b); + } + + private: + friend class ModifyBranchesView; + const_iterator(ModifyBranchMap::const_iterator it, + ModifyBranchMap::const_iterator end, const Entry* extra) + : it_(it), end_(end), extra_(extra) {} + + // The extra entry comes next iff it has not been consumed and its key + // precedes the next base entry's key. + bool ExtraIsNext() const { + return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); + } + + ModifyBranchMap::const_iterator it_, end_; + const Entry* extra_; // Null if absent or already consumed. + }; + + const_iterator begin() const { + return const_iterator(base_->begin(), base_->end(), + extra_.has_value() ? 
&*extra_ : nullptr); } - for (const auto& [value, branch] : right) { - if (!result.contains(value)) - result[value] = combiner(default_value, branch); + const_iterator end() const { + return const_iterator(base_->end(), base_->end(), nullptr); + } + + bool empty() const { return base_->empty() && !extra_.has_value(); } + bool contains(int value) const { + return (extra_.has_value() && extra_->first == value) || + base_->contains(value); + } + std::optional<PacketTransformerHandle> find(int value) const { + if (extra_.has_value() && extra_->first == value) return extra_->second; + if (auto it = base_->find(value); it != base_->end()) return it->second; + return std::nullopt; + } + + private: + const ModifyBranchMap* base_; + std::optional<Entry> extra_; +}; + +// Returns a view of the logical (modify value -> branch) map that `node` +// applies to packets whose field is equal to `value`: the matching entry of +// `modify_branch_by_field_match` if there is one; otherwise the default +// modifications, plus the unmodified fall-through to `default_branch` (keyed +// by `value`, since the field keeps its value) unless that branch is Deny or +// shadowed by a default modification to `value`. +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, + int value) { + if (auto it = node.modify_branch_by_field_match->find(value); + it != node.modify_branch_by_field_match->end()) { + return ModifyBranchesView(it->second); + } + const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; + // `PacketTransformerHandle()` is the Deny transformer. + if (defaults.contains(value) || + node.default_branch == PacketTransformerHandle()) { + return ModifyBranchesView(defaults); + } + return ModifyBranchesView(defaults, {value, node.default_branch}); +} + +// Combines two (modify value -> branch) maps key-wise into a new map, using +// `combiner(left_branch, right_branch)` for shared keys and substituting +// `default_value` for the missing side otherwise. 
+template <typename Combiner> +ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, + const ModifyBranchesView& right, + Combiner&& combiner, + PacketTransformerHandle default_value) { + ModifyBranchMap result; + for (auto [value, left_branch] : left) { + // Keys arrive in increasing order, so inserting at `end()` is O(1). + result.try_emplace( + result.end(), value, + combiner(left_branch, right.find(value).value_or(default_value))); + } + for (auto [value, right_branch] : right) { + if (left.contains(value)) continue; + result.try_emplace(value, combiner(default_value, right_branch)); } return result; } +// Returns the union of the keys of the given maps, sorted and deduplicated. +template <typename... Maps> +std::vector<int> SortedUniqueKeys(const Maps&... maps) { + std::vector<int> keys; + keys.reserve((maps.size() + ... + 0)); + auto append_keys = [&keys](const auto& map) { + for (const auto& entry : map) keys.push_back(entry.first); + }; + (append_keys(maps), ...); + absl::c_sort(keys); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + return keys; +} + } // namespace -PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Sequence( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } +PacketTransformerHandle PacketTransformerManager::Sequence( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (IsDeny(left) || IsDeny(right)) return Deny(); + if (IsAccept(left)) return right; + if (IsAccept(right)) return left; - // left.field < right.field: Expand the right node, reducing to the - // inductive case. 
- if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Sequence(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Both operands are decision nodes. Combine them at the smaller of their + // fields; an operand branching on a strictly larger field is viewed as a + // trivial node at the combined field. + const DecisionNode* left_node = &GetNodeOrDie(left); + const DecisionNode* right_node = &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto sequence_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Sequence(left, right); + }; + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; + const ModifyBranchesView empty_view(EmptyModifyBranchMap()); - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, - .default_branch = Sequence(left.default_branch, right.default_branch), + .field = field, + .default_branch = + Sequence(left_view.default_branch, right_view.default_branch), }; // Construct the possible results of applying the right node to packets // gotten by taken default modification branches in the left node. 
- absl::btree_map - right_applied_to_left_modifications; + ModifyBranchMap right_applied_to_left_modifications; for (const auto& [value, branch] : - left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); + *left_view.default_branch_by_field_modification) { + ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, + /*default_value=*/branch); right_applied_to_left_modifications = CombineModifyBranches( - right_applied_to_left_modifications, right_at_value_with_sequence, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(right_at_value_with_sequence), union_combiner, /*default_value=*/Deny()); } + ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( + empty_view, + ModifyBranchesView(*right_view.default_branch_by_field_modification), + sequence_combiner, + /*default_value=*/left_view.default_branch); result_node.default_branch_by_field_modification = CombineModifyBranches( - right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(sequenced_right_modifications), union_combiner, /*default_value=*/Deny()); - // Collect every value mapped in each node. 
- absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification, + right_applied_to_left_modifications)) { + ModifyBranchesView left_branches_at_value = + ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of // , but the latter is not always canonical. 
However, an // empty map won't work correctly for the merges below (an in fact, the // whole for-loop would be skipped), so we expand it here if necessary. - if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + if (left_branches_at_value.empty()) { + left_branches_at_value = + ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } - } - - return NodeToTransformer(std::move(result_node)); -} - -PacketTransformerHandle PacketTransformerManager::Sequence( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (IsDeny(left) || IsDeny(right)) return Deny(); - if (IsAccept(left)) return right; - if (IsAccept(right)) return left; - - // If neither node is accept or deny, then sequence the nodes directly. - return Sequence(GetNodeOrDie(left), GetNodeOrDie(right)); -} -PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Union( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. 
- if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Union(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } - - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()), - .default_branch = Union(left.default_branch, right.default_branch), - }; - - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. 
- // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + ModifyBranchMap& result_branches = + result_node.modify_branch_by_field_match[value]; + for (auto [left_value, left_branch] : left_branches_at_value) { + ModifyBranchMap sequenced = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, left_value), + sequence_combiner, + /*default_value=*/left_branch); + result_branches = + CombineModifyBranches(ModifyBranchesView(result_branches), + ModifyBranchesView(sequenced), union_combiner, + /*default_value=*/Deny()); + } } return NodeToTransformer(std::move(result_node)); @@ -631,95 +653,43 @@ PacketTransformerHandle PacketTransformerManager::Union( if (IsDeny(right)) return left; if (IsDeny(left)) return right; - // If either node is accept, then expand it before merging. - if (IsAccept(left) || IsAccept(right)) { - const DecisionNode& other_node = - GetNodeOrDie(IsAccept(left) ? right : left); - return Union( - DecisionNode{ - .field = other_node.field, - .default_branch = Accept(), - }, - other_node); - } - - // If neither node is accept or deny, then union the nodes directly. - return Union(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Difference( - DecisionNode left, DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. 
- if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Difference( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Difference(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; - // left.field == right.field: branch on shared field. 
- DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, + .field = field, .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + union_combiner, /*default_value=*/Deny()), - .default_branch = Difference(left.default_branch, right.default_branch), + .default_branch = + Union(left_view.default_branch, right_view.default_branch), }; - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { + // For every value mapped in either node, construct the proper new branch. 
+ for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), union_combiner, /*default_value=*/Deny()); } @@ -733,27 +703,47 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // If either node is accept, then expand it before merging. - if (IsAccept(left)) { - const DecisionNode& right_node = GetNodeOrDie(right); - return Difference( - DecisionNode{ - .field = right_node.field, - .default_branch = Accept(), - }, - right_node); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? 
nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto difference_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Difference(left, right); + }; + + DecisionNode result_node{ + .field = field, + .default_branch_by_field_modification = CombineModifyBranches( + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + difference_combiner, + /*default_value=*/Deny()), + .default_branch = + Difference(left_view.default_branch, right_view.default_branch), + }; - if (IsAccept(right)) { - const DecisionNode& left_node = GetNodeOrDie(left); - return Difference(left_node, DecisionNode{ - .field = left_node.field, - .default_branch = Accept(), - }); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { + result_node.modify_branch_by_field_match[value] = CombineModifyBranches( + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), difference_combiner, + /*default_value=*/Deny()); } - // If neither node is accept or deny, then difference the nodes directly. 
- return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); + return NodeToTransformer(std::move(result_node)); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..b63bf19 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,18 +404,6 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); - // Helper functions to deal with DecisionNodes directly. - // TODO(dilo): Is there a convenient way to either avoid these or avoid making - // copies of the nodes? - PacketTransformerHandle Union(DecisionNode left, DecisionNode right); - PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); - PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); - - // Internal helper function to get a map of possible modification values to - // branches for a given input value at `node`. - absl::btree_map GetMapAtValue( - const DecisionNode& node, int value); - // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // From bf0e8ae8dfd8196e6af9e29893cc11db5ffe79c4 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:29:02 -0700 Subject: [PATCH 2/5] [NetKAT] Follow-up cleanups: in-place Sequence accumulation, op dedup, stable goldens. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-ups to the copy-elimination refactor: * Sequence accumulated its result maps by rebuilding them from scratch for every left-branch entry, making accumulation quadratic in the number of branches. It now unions each sequenced branch into the result map in place, which also removes the remaining intermediate maps entirely. 
* Union and Difference were near-identical after the refactor; their shared body (pointwise combination of two operands viewed at a common field) moves into a PointwiseCombine template parameterized on the operation. * The golden diff test asserted raw node-interning indices, so any reordering inside the manager — including the change above, and planned work like operation memoization — churned the golden file without any semantic change. The test runner now renumbers nodes and fields canonically before printing, by copying each compiled transformer into a fresh manager in deterministic traversal order, so the golden output depends only on transformer structure. The golden file is regenerated accordingly (verified: the previous diff was pure renumbering, with structure, fields, and labels identical). Compilation benchmarks improve a further ~9% on top of the previous commit (~1.8x total vs the original code on FirstTimeCompile NonOverlappingPolicy). All tests pass. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 + netkat/packet_transformer.cc | 144 +++++++----------- netkat/packet_transformer.h | 11 ++ netkat/packet_transformer_test.expected | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 117 +++++++++++++- 5 files changed, 273 insertions(+), 189 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..a456370 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -440,7 +440,11 @@ cc_test( linkstatic = True, deps = [ ":netkat_proto_constructors", + ":packet_field", ":packet_transformer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index c2807da..efd421f 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -571,15 +571,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( const 
DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto sequence_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); + // Unions `branch` into `branches[modify_value]`, treating absence as Deny. + auto union_into = [this](ModifyBranchMap& branches, int modify_value, + PacketTransformerHandle branch) { + auto [it, inserted] = branches.try_emplace(modify_value, branch); + if (!inserted) it->second = Union(it->second, branch); }; - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - const ModifyBranchesView empty_view(EmptyModifyBranchMap()); DecisionNode result_node{ .field = field, @@ -588,28 +585,24 @@ PacketTransformerHandle PacketTransformerManager::Sequence( }; // Construct the possible results of applying the right node to packets - // gotten by taken default modification branches in the left node. - ModifyBranchMap right_applied_to_left_modifications; + // gotten by taken default modification branches in the left node... + ModifyBranchMap& result_modifications = + result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(right_at_value_with_sequence), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, value)) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + } + } + // ... 
and of applying the right node's default modifications to packets + // falling through the left node unmodified. + for (const auto& [modify_value, right_branch] : + *right_view.default_branch_by_field_modification) { + union_into(result_modifications, modify_value, + Sequence(left_view.default_branch, right_branch)); } - - ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( - empty_view, - ModifyBranchesView(*right_view.default_branch_by_field_modification), - sequence_combiner, - /*default_value=*/left_view.default_branch); - result_node.default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(sequenced_right_modifications), union_combiner, - /*default_value=*/Deny()); // For every value mapped in either node, construct the proper new branch. for (int value : @@ -617,7 +610,7 @@ PacketTransformerHandle PacketTransformerManager::Sequence( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification, - right_applied_to_left_modifications)) { + result_modifications)) { ModifyBranchesView left_branches_at_value = ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of @@ -632,27 +625,21 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchMap& result_branches = result_node.modify_branch_by_field_match[value]; for (auto [left_value, left_branch] : left_branches_at_value) { - ModifyBranchMap sequenced = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, left_value), - sequence_combiner, - /*default_value=*/left_branch); - result_branches = - CombineModifyBranches(ModifyBranchesView(result_branches), - ModifyBranchesView(sequenced), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, left_value)) { + union_into(result_branches, 
modify_value, + Sequence(left_branch, right_branch)); + } } } return NodeToTransformer(std::move(result_node)); } -PacketTransformerHandle PacketTransformerManager::Union( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (left == right) return left; - if (IsDeny(right)) return left; - if (IsDeny(left)) return right; - +template +PacketTransformerHandle PacketTransformerManager::PointwiseCombine( + PacketTransformerHandle left, PacketTransformerHandle right, + Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a // decision node. Combine the operands at the smallest field branched on by // either; an operand that is Accept, or branches on a strictly larger @@ -665,20 +652,15 @@ PacketTransformerHandle PacketTransformerManager::Union( const DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - DecisionNode result_node{ .field = field, .default_branch_by_field_modification = CombineModifyBranches( ModifyBranchesView(*left_view.default_branch_by_field_modification), ModifyBranchesView(*right_view.default_branch_by_field_modification), - union_combiner, + combiner, /*default_value=*/Deny()), .default_branch = - Union(left_view.default_branch, right_view.default_branch), + combiner(left_view.default_branch, right_view.default_branch), }; // For every value mapped in either node, construct the proper new branch. 
@@ -689,13 +671,27 @@ PacketTransformerHandle PacketTransformerManager::Union( *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), union_combiner, + ModifyBranchesAtValue(right_view, value), combiner, /*default_value=*/Deny()); } return NodeToTransformer(std::move(result_node)); } +PacketTransformerHandle PacketTransformerManager::Union( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (left == right) return left; + if (IsDeny(right)) return left; + if (IsDeny(left)) return right; + + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }); +} + PacketTransformerHandle PacketTransformerManager::Difference( PacketTransformerHandle left, PacketTransformerHandle right) { // Base cases. @@ -703,47 +699,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? 
nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); - - auto difference_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Difference(left, right); - }; - - DecisionNode result_node{ - .field = field, - .default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(*left_view.default_branch_by_field_modification), - ModifyBranchesView(*right_view.default_branch_by_field_modification), - difference_combiner, - /*default_value=*/Deny()), - .default_branch = - Difference(left_view.default_branch, right_view.default_branch), - }; - - // For every value mapped in either node, construct the proper new branch. - for (int value : - SortedUniqueKeys(*left_view.modify_branch_by_field_match, - *right_view.modify_branch_by_field_match, - *left_view.default_branch_by_field_modification, - *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), difference_combiner, - /*default_value=*/Deny()); - } - - return NodeToTransformer(std::move(result_node)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index b63bf19..ed0e074 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,6 +404,17 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. 
static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); + // Shared implementation of `Union` and `Difference`, which differ only in + // their base cases and the operation applied to corresponding branches: + // combines `left` and `right` by applying `combiner` — which must be the + // handle-level operation itself, e.g. `Union` — to corresponding branches. + // Both operands must be Accept or decision nodes (i.e., the base cases must + // already have been handled). + template + PacketTransformerHandle PointwiseCombine(PacketTransformerHandle left, + PacketTransformerHandle right, + Combiner&& combiner); + // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // diff --git a/netkat/packet_transformer_test.expected b/netkat/packet_transformer_test.expected index 9f01d94..ca6b248 100644 --- a/netkat/packet_transformer_test.expected +++ b/netkat/packet_transformer_test.expected @@ -42,22 +42,22 @@ digraph { Test case: p := (a=5 + b=2);(b:=1 + c=5). Example from Katch paper Fig 5. 
================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<8>: +PacketTransformerHandle<3>: PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<6> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<1> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<7> -PacketTransformerHandle<6>: + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<1>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<5> -PacketTransformerHandle<7>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<0> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<5> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<5>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 5 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -68,50 +68,50 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 8 [label="a"] - 8 -> 6 [label="a==5; a:=5"] - 8 -> 7 [style=dashed] - 6 [label="b"] - 6 -> 4294967294 [label="b:=1" style=dashed] - 6 -> 5 [style=dashed] - 7 [label="b"] - 7 -> 4294967294 [label="b==2; b:=1"] - 7 -> 5 [label="b==2; b:=2"] - 7 -> 4294967295 [style=dashed] - 5 [label="c"] - 5 -> 4294967294 [label="c==5; c:=5"] - 5 -> 4294967295 [style=dashed] + 3 [label="a"] + 3 -> 1 [label="a==5; a:=5"] + 3 -> 2 [style=dashed] + 1 [label="b"] + 1 -> 4294967294 [label="b:=1" style=dashed] + 1 -> 0 [style=dashed] + 2 [label="b"] + 2 -> 
4294967294 [label="b==2; b:=1"] + 2 -> 0 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c==5; c:=5"] + 0 -> 4294967295 [style=dashed] } ================================================================================ Test case: q := (b=1 + c:=4 + a:=5;b:=1). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<17>: +PacketTransformerHandle<5>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<14> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<16> -PacketTransformerHandle<14>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<4> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<4>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<16>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<13>: + PacketFieldHandle<1>:'b' == * 
-> PacketTransformerHandle<1> +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<10>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle @@ -121,64 +121,64 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 17 [label="a"] - 17 -> 14 [label="a==1; a:=1"] - 17 -> 4 [label="a:=1" style=dashed] - 17 -> 16 [style=dashed] - 14 [label="b"] - 14 -> 13 [label="b==1; b:=1"] - 14 -> 4294967294 [label="b:=1" style=dashed] - 14 -> 10 [style=dashed] + 5 [label="a"] + 5 -> 2 [label="a==1; a:=1"] + 5 -> 3 [label="a:=1" style=dashed] + 5 -> 4 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==1; b:=1"] + 2 -> 4294967294 [label="b:=1" style=dashed] + 2 -> 1 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 16 [label="b"] - 16 -> 13 [label="b==1; b:=1"] - 16 -> 10 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 10 [label="c"] - 10 -> 4294967294 [label="c:=4" style=dashed] - 10 -> 4294967295 [style=dashed] + 4 -> 0 [label="b==1; b:=1"] + 4 -> 1 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c:=4" style=dashed] + 1 -> 4294967295 [style=dashed] } ================================================================================ Test case: p;q. Example from Katch paper Fig 5. 
================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<22>: +PacketTransformerHandle<6>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<19> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<21> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<4> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<20> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<19> -PacketTransformerHandle<19>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<5> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<18> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<1> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<4>: +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<21>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<18> -PacketTransformerHandle<20>: + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<5>: PacketFieldHandle<1>:'b' == 2: 
PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<13>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<18>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -189,29 +189,29 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 22 [label="a"] - 22 -> 19 [label="a==1; a:=1"] - 22 -> 4 [label="a==5; a:=1"] - 22 -> 21 [label="a==5; a:=5"] - 22 -> 20 [label="a:=1" style=dashed] - 22 -> 19 [style=dashed] - 19 [label="b"] - 19 -> 13 [label="b==2; b:=1"] - 19 -> 18 [label="b==2; b:=2"] - 19 -> 4294967295 [style=dashed] + 6 [label="a"] + 6 -> 2 [label="a==1; a:=1"] + 6 -> 3 [label="a==5; a:=1"] + 6 -> 4 [label="a==5; a:=5"] + 6 -> 5 [label="a:=1" style=dashed] + 6 -> 2 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==2; b:=1"] + 2 -> 1 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 21 [label="b"] - 21 -> 13 [label="b:=1" style=dashed] - 21 -> 18 [style=dashed] - 20 [label="b"] - 20 -> 4294967294 [label="b==2; b:=1"] - 20 -> 4294967295 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 18 [label="c"] - 18 -> 4294967294 [label="c==5; c:=4"] - 18 -> 4294967295 [style=dashed] + 4 -> 0 [label="b:=1" style=dashed] + 4 -> 1 [style=dashed] + 5 [label="b"] + 5 -> 4294967294 [label="b==2; b:=1"] + 5 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 
[style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c==5; c:=4"] + 1 -> 4294967295 [style=dashed] } diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index 653c320..c4eacfc 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,15 +16,118 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` +#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_field.h" #include "netkat/packet_transformer.h" namespace netkat { + +// Test peer to access `PacketTransformerManager` internals, allowing this +// runner to renumber nodes canonically before printing. +class PacketTransformerManagerTestPeer { + public: + // Copies `transformer` into the fresh manager `to`, interning fields and + // nodes in a deterministic traversal order. This makes the printed handle + // numbers depend only on the structure of `transformer` — not on the + // interning order of `from`, which is an implementation detail that changes + // under refactorings of the manager (e.g. reordering its operations). + static PacketTransformerHandle CanonicalCopy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to) { + // Intern the fields of all reachable nodes in their original relative + // order, which the node invariants ("fields increase strictly along each + // path") depend on. 
+ std::vector fields; + absl::flat_hash_set visited; + CollectFields(from, transformer, visited, fields); + absl::c_sort(fields); + fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + for (PacketFieldHandle field : fields) { + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field)); + } + + absl::flat_hash_map + copy_by_original; + return Copy(from, transformer, to, copy_by_original); + } + + private: + static void CollectFields( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + absl::flat_hash_set& visited, + std::vector& fields) { + if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; + if (!visited.insert(transformer).second) return; + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + fields.push_back(node.field); + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : branch_by_modify) { + CollectFields(from, branch, visited, fields); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + CollectFields(from, branch, visited, fields); + } + CollectFields(from, node.default_branch, visited, fields); + } + + static PacketTransformerHandle Copy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to, + absl::flat_hash_map& + copy_by_original) { + if (from.IsDeny(transformer)) return to.Deny(); + if (from.IsAccept(transformer)) return to.Accept(); + if (auto it = copy_by_original.find(transformer); + it != copy_by_original.end()) { + return it->second; + } + + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + PacketTransformerManager::DecisionNode copy{ + .field = + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + 
from.packet_set_manager_.field_manager_.GetFieldName( + node.field)), + }; + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + // `operator[]` keeps entries with empty branch maps, which are + // meaningful: they deny packets matching `match_value`. + auto& copy_branch_by_modify = + copy.modify_branch_by_field_match[match_value]; + for (const auto& [modify_value, branch] : branch_by_modify) { + copy_branch_by_modify[modify_value] = + Copy(from, branch, to, copy_by_original); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + copy.default_branch_by_field_modification[modify_value] = + Copy(from, branch, to, copy_by_original); + } + copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + + PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); + copy_by_original.emplace(transformer, result); + return result; + } +}; + namespace { constexpr char kBanner[] = @@ -91,16 +194,22 @@ std::vector TestCases() { } void main() { - // This test needs a deterministic field interning order, and thus must start - // from a fresh manager. PacketTransformerManager manager; for (const TestCase& test_case : TestCases()) { netkat::PacketTransformerHandle packet_transformer = manager.Compile(test_case.policy); + // Renumber nodes and fields canonically, in a fresh manager per test case, + // so that the printed output depends only on the structure of the compiled + // transformer and not on the interning order of `manager`. 
+ PacketTransformerManager canonical_manager; + netkat::PacketTransformerHandle canonical_transformer = + PacketTransformerManagerTestPeer::CanonicalCopy( + manager, packet_transformer, canonical_manager); std::cout << kBanner << "Test case: " << test_case.description << std::endl << kBanner; - std::cout << kStringHeader << manager.ToString(packet_transformer); - std::cout << kDotHeader << manager.ToDot(packet_transformer); + std::cout << kStringHeader + << canonical_manager.ToString(canonical_transformer); + std::cout << kDotHeader << canonical_manager.ToDot(canonical_transformer); } } From 37aa3a4b24aca9140eb7885a5678af5e1ac2252e Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 04:07:20 -0700 Subject: [PATCH 3/5] [NetKAT] Apply review cleanups to the combinator refactor. Findings from a four-angle cleanup review (reuse, simplification, efficiency, altitude): * Replace ModifyBranchesView's hand-rolled merging iterator (~45 lines, the subtlest code in the refactor) with a simple ForEach that visits base entries then the extra entry. No caller needed globally sorted iteration: map contents are order-independent and handle-level results are canonical regardless of combination order. * Deduplicate the operand-view prologue shared by Sequence and PointwiseCombine into ViewOperandsAtSmallestField. * Use absl::NoDestructor for the empty-map singletons (the repo's established idiom) and name the default-handle-is-Deny contract in one place (IsDenyHandle) instead of re-deriving it inline. * Hint the per-value btree insertions in Sequence/PointwiseCombine with end(), since SortedUniqueKeys emits values in increasing order. * Test runner: collect fields into a btree_set instead of sort+unique, and translate field handles through a precomputed map instead of a per-node string round-trip. * Drop a stale TODO referencing the deleted GetMapAtValue, and the now-unused any_invocable dependency. 
No behavior change: all tests pass and the golden file is unchanged. Benchmarks are neutral. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 +- netkat/packet_transformer.cc | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 36 ++--- 3 files changed, 106 insertions(+), 120 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index a456370..5bffd85 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -371,11 +371,11 @@ cc_library( ":packet_set", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -442,7 +442,7 @@ cc_test( ":netkat_proto_constructors", ":packet_field", ":packet_transformer", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index efd421f..008fe6d 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -24,6 +24,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" @@ -121,9 +122,6 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( absl::flat_hash_set redundant_values; for (auto& [match_value, modification_map] : node.modify_branch_by_field_match) { - // TODO(dilo): Consider if this can make use of GetMapAtValue. 
Perhaps by - // calling it on a copy of DecisionNode without any - // `modify_branch_by_field_match` mappings? if (skip_default_branch || node.default_branch_by_field_modification.contains(match_value)) { // Compare the modification map to the default branch modification map, @@ -365,15 +363,22 @@ using ModifyBranchMap = absl::btree_map; using MatchBranchMap = absl::btree_map; const MatchBranchMap& EmptyMatchBranchMap() { - static const MatchBranchMap* const kEmpty = new MatchBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } const ModifyBranchMap& EmptyModifyBranchMap() { - static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } +// Returns true iff `transformer` is the Deny transformer, which by documented +// contract is the default-constructed handle. (Unlike +// `PacketTransformerManager::IsDeny`, this is callable without a manager.) +bool IsDenyHandle(PacketTransformerHandle transformer) { + return transformer == PacketTransformerHandle(); +} + // A cheap, non-owning view of a decision node as seen "at" a field no larger // than the node's own field: either the node's own maps (if the node branches // on that field), or the trivial expansion "for every value of the field, @@ -422,10 +427,35 @@ PacketFieldHandle SmallestField(const Node* left, const Node* right) { return std::min(left->field, right->field); } +// The two operands of a binary combinator, viewed at the smallest field +// branched on by either: an operand that is Accept, or branches on a strictly +// larger field, is viewed as a trivial node at that field. +struct OperandViews { + PacketFieldHandle field; + DecisionNodeView left; + DecisionNodeView right; +}; + +// Views the operands `left` and `right`, whose decision nodes are `left_node` +// and `right_node` (null encoding Accept; at least one must be non-null), at +// the smallest field branched on by either. 
See `ViewAtField` regarding the +// `Node` template parameter. +template +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { + const PacketFieldHandle field = SmallestField(left_node, right_node); + return OperandViews{ + .field = field, + .left = ViewAtField(field, left, left_node), + .right = ViewAtField(field, right, right_node), + }; +} + // A cheap, non-owning view of a logical (modify value -> branch) map, // represented as a base map plus at most one extra entry whose key must not -// be a key of the base map. Iteration is in increasing key order, like the -// underlying btree map. +// be a key of the base map. class ModifyBranchesView { public: using Entry = std::pair; @@ -436,61 +466,22 @@ class ModifyBranchesView { DCHECK(!base.contains(extra.first)); } - class const_iterator { - public: - Entry operator*() const { - return ExtraIsNext() ? *extra_ : Entry(it_->first, it_->second); - } - const_iterator& operator++() { - if (ExtraIsNext()) { - extra_ = nullptr; - } else { - ++it_; - } - return *this; - } - friend bool operator==(const const_iterator& a, const const_iterator& b) { - return a.it_ == b.it_ && a.extra_ == b.extra_; - } - friend bool operator!=(const const_iterator& a, const const_iterator& b) { - return !(a == b); - } - - private: - friend class ModifyBranchesView; - const_iterator(ModifyBranchMap::const_iterator it, - ModifyBranchMap::const_iterator end, const Entry* extra) - : it_(it), end_(end), extra_(extra) {} - - // The extra entry comes next iff it has not been consumed and its key - // precedes the next base entry's key. - bool ExtraIsNext() const { - return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); - } - - ModifyBranchMap::const_iterator it_, end_; - const Entry* extra_; // Null if absent or already consumed. 
- }; - - const_iterator begin() const { - return const_iterator(base_->begin(), base_->end(), - extra_.has_value() ? &*extra_ : nullptr); - } - const_iterator end() const { - return const_iterator(base_->end(), base_->end(), nullptr); + // Invokes `fn(modify_value, branch)` for each entry: the base map entries + // in increasing key order, then the extra entry, if any. + template + void ForEach(Fn&& fn) const { + for (const auto& [value, branch] : *base_) fn(value, branch); + if (extra_.has_value()) fn(extra_->first, extra_->second); } - bool empty() const { return base_->empty() && !extra_.has_value(); } - bool contains(int value) const { - return (extra_.has_value() && extra_->first == value) || - base_->contains(value); - } - std::optional find(int value) const { + std::optional Find(int value) const { if (extra_.has_value() && extra_->first == value) return extra_->second; if (auto it = base_->find(value); it != base_->end()) return it->second; return std::nullopt; } + bool empty() const { return base_->empty() && !extra_.has_value(); } + private: const ModifyBranchMap* base_; std::optional extra_; @@ -509,9 +500,7 @@ ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, return ModifyBranchesView(it->second); } const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; - // `PacketTransformerHandle()` is the Deny transformer. - if (defaults.contains(value) || - node.default_branch == PacketTransformerHandle()) { + if (defaults.contains(value) || IsDenyHandle(node.default_branch)) { return ModifyBranchesView(defaults); } return ModifyBranchesView(defaults, {value, node.default_branch}); @@ -526,16 +515,16 @@ ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, Combiner&& combiner, PacketTransformerHandle default_value) { ModifyBranchMap result; - for (auto [value, left_branch] : left) { - // Keys arrive in increasing order, so inserting at `end()` is O(1). 
+ left.ForEach([&](int value, PacketTransformerHandle left_branch) { + // Keys arrive in near-sorted order, so the `end()` hint is mostly exact. result.try_emplace( result.end(), value, - combiner(left_branch, right.find(value).value_or(default_value))); - } - for (auto [value, right_branch] : right) { - if (left.contains(value)) continue; + combiner(left_branch, right.Find(value).value_or(default_value))); + }); + right.ForEach([&](int value, PacketTransformerHandle right_branch) { + if (left.Find(value).has_value()) return; result.try_emplace(value, combiner(default_value, right_branch)); - } + }); return result; } @@ -562,14 +551,9 @@ PacketTransformerHandle PacketTransformerManager::Sequence( if (IsAccept(left)) return right; if (IsAccept(right)) return left; - // Both operands are decision nodes. Combine them at the smaller of their - // fields; an operand branching on a strictly larger field is viewed as a - // trivial node at the combined field. - const DecisionNode* left_node = &GetNodeOrDie(left); - const DecisionNode* right_node = &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // Both operands are decision nodes. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, &GetNodeOrDie(left), right, &GetNodeOrDie(right)); // Unions `branch` into `branches[modify_value]`, treating absence as Deny. 
auto union_into = [this](ModifyBranchMap& branches, int modify_value, @@ -590,11 +574,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, value)) { - union_into(result_modifications, modify_value, - Sequence(branch, right_branch)); - } + ModifyBranchesAtValue(right_view, value) + .ForEach([&, branch = branch](int modify_value, + PacketTransformerHandle right_branch) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + }); } // ... and of applying the right node's default modifications to packets // falling through the left node unmodified. @@ -622,15 +607,19 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } + // `value`s arrive in increasing order, so inserting at `end()` is O(1). 
ModifyBranchMap& result_branches = - result_node.modify_branch_by_field_match[value]; - for (auto [left_value, left_branch] : left_branches_at_value) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, left_value)) { - union_into(result_branches, modify_value, - Sequence(left_branch, right_branch)); - } - } + result_node.modify_branch_by_field_match + .try_emplace(result_node.modify_branch_by_field_match.end(), value) + ->second; + left_branches_at_value.ForEach([&](int left_value, + PacketTransformerHandle left_branch) { + ModifyBranchesAtValue(right_view, left_value) + .ForEach([&](int modify_value, PacketTransformerHandle right_branch) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + }); + }); } return NodeToTransformer(std::move(result_node)); @@ -641,16 +630,10 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( PacketTransformerHandle left, PacketTransformerHandle right, Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // decision node. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, + IsAccept(right) ? 
nullptr : &GetNodeOrDie(right)); DecisionNode result_node{ .field = field, @@ -669,10 +652,13 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), combiner, - /*default_value=*/Deny()); + // `value`s arrive in increasing order, so inserting at `end()` is O(1). + result_node.modify_branch_by_field_match.try_emplace( + result_node.modify_branch_by_field_match.end(), value, + CombineModifyBranches(ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), + combiner, + /*default_value=*/Deny())); } return NodeToTransformer(std::move(result_node)); diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index c4eacfc..e6a4ad7 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,13 +16,12 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` -#include #include #include #include #include -#include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" @@ -45,32 +44,33 @@ class PacketTransformerManagerTestPeer { PacketTransformerManager& to) { // Intern the fields of all reachable nodes in their original relative // order, which the node invariants ("fields increase strictly along each - // path") depend on. - std::vector fields; + // path") depend on, and record the handle translation for `Copy`. 
+ absl::btree_set fields; absl::flat_hash_set visited; CollectFields(from, transformer, visited, fields); - absl::c_sort(fields); - fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + absl::flat_hash_map field_translation; for (PacketFieldHandle field : fields) { - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - from.packet_set_manager_.field_manager_.GetFieldName(field)); + field_translation.try_emplace( + field, + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field))); } absl::flat_hash_map copy_by_original; - return Copy(from, transformer, to, copy_by_original); + return Copy(from, transformer, to, field_translation, copy_by_original); } private: static void CollectFields( const PacketTransformerManager& from, PacketTransformerHandle transformer, absl::flat_hash_set& visited, - std::vector& fields) { + absl::btree_set& fields) { if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; if (!visited.insert(transformer).second) return; const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); - fields.push_back(node.field); + fields.insert(node.field); for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { for (const auto& [modify_value, branch] : branch_by_modify) { @@ -87,6 +87,8 @@ class PacketTransformerManagerTestPeer { static PacketTransformerHandle Copy( const PacketTransformerManager& from, PacketTransformerHandle transformer, PacketTransformerManager& to, + const absl::flat_hash_map& + field_translation, absl::flat_hash_map& copy_by_original) { if (from.IsDeny(transformer)) return to.Deny(); @@ -99,10 +101,7 @@ class PacketTransformerManagerTestPeer { const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); PacketTransformerManager::DecisionNode copy{ - .field = - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - 
from.packet_set_manager_.field_manager_.GetFieldName( - node.field)), + .field = field_translation.at(node.field), }; for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { @@ -112,15 +111,16 @@ class PacketTransformerManagerTestPeer { copy.modify_branch_by_field_match[match_value]; for (const auto& [modify_value, branch] : branch_by_modify) { copy_branch_by_modify[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } } for (const auto& [modify_value, branch] : node.default_branch_by_field_modification) { copy.default_branch_by_field_modification[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } - copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + copy.default_branch = Copy(from, node.default_branch, to, field_translation, + copy_by_original); PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); copy_by_original.emplace(transformer, result); From 622487145b97eb101c4f71a5c2300ea41423ce37 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 08:51:57 -0700 Subject: [PATCH 4/5] [NetKAT] Store each interned decision node once: pointer-keyed unique tables. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The unique tables of both managers stored every node twice — once in the node arena (nodes_) and once as the hash map key — doubling node memory and copying each new node (nested btree maps and all, for transformer nodes) on insertion. The tables are now keyed by pointers into the arena, which PagedStableVector keeps stable across both growth and manager moves; transparent hash/equality functors allow lookup by node value before interning. Benchmarks improve ~2-9% across the board on top of #98, mostly from no longer copying freshly-created nodes into the key slot. 
Also: * Adds BM_PushAndPullFullSetThroughPolicy, covering the read-path operations (Push/Pull) that the analysis engine is built on; the existing benchmarks only measured compilation. Sizing note: without memoization, Push/Pull cost grows exponentially in policy size (8 unioned sub-policies ran for 20+ minutes), so the benchmark deliberately stays small — and the planned operation-memoization work is where that changes. * Includes ProtoHashKey.policy_case in the proto-cache hash; it was part of equality but omitted from the hash, causing avoidable collisions across policy kinds. * CheckInternalInvariants verifies pointer-table integrity and exercises both lookup paths in both managers. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 3 ++ netkat/packet_set.cc | 43 ++++++++++++++++------- netkat/packet_set.h | 31 +++++++++++++++-- netkat/packet_transformer.cc | 47 ++++++++++++++++++-------- netkat/packet_transformer.h | 36 ++++++++++++++++++-- netkat/packet_transformer_benchmark.cc | 26 ++++++++++++++ 6 files changed, 154 insertions(+), 32 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 5bffd85..a72f5c7 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -132,6 +132,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -376,6 +377,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -427,6 +429,7 @@ cc_binary( deps = [ ":netkat_cc_proto", ":netkat_proto_constructors", + ":packet_set", ":packet_transformer", "@com_google_absl//absl/strings", 
"@com_google_benchmark//:benchmark_main", diff --git a/netkat/packet_set.cc b/netkat/packet_set.cc index 2c7bd1c..fc111fe 100644 --- a/netkat/packet_set.cc +++ b/netkat/packet_set.cc @@ -24,6 +24,7 @@ #include "absl/algorithm/container.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -87,6 +88,14 @@ const PacketSetManager::DecisionNode& PacketSetManager::GetNodeOrDie( return nodes_[packet_set.node_index_]; } +size_t PacketSetManager::NodeHash::operator()(const DecisionNode* node) const { + return absl::HashOf(*node); +} + +size_t PacketSetManager::NodeHash::operator()(const DecisionNode& node) const { + return absl::HashOf(node); +} + PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { if (node.branch_by_field_value.empty()) return node.default_branch; @@ -109,16 +118,19 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { } #endif - auto [it, inserted] = - packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the node by value via the transparent `NodeHash`/`NodeEq` + // functors; only new nodes get stored (exactly once, in `nodes_`). + if (auto it = packet_by_node_.find(node); it != packet_by_node_.end()) { + return it->second; } - return it->second; + PacketSetHandle packet(nodes_.size()); + nodes_.push_back(std::move(node)); + packet_by_node_.insert({&nodes_[packet.node_index_], packet}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. 
This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return packet; } bool PacketSetManager::Contains(PacketSetHandle packet_set, @@ -483,14 +495,19 @@ absl::Status PacketSetManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - for (const auto& [node, packet] : packet_by_node_) { + // Invariant: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + for (const auto& [node_ptr, packet] : packet_by_node_) { RET_CHECK(packet.node_index_ < nodes_.size()); - RET_CHECK(nodes_[packet.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[packet.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = packet_by_node_.find(node); + // Look up both by pointer and by value (exercising the transparent + // functors used by `NodeToPacket`). + auto it = packet_by_node_.find(&node); + RET_CHECK(it != packet_by_node_.end()); + RET_CHECK(it->second == PacketSetHandle(i)); + it = packet_by_node_.find(node); RET_CHECK(it != packet_by_node_.end()); RET_CHECK(it->second == PacketSetHandle(i)); } diff --git a/netkat/packet_set.h b/netkat/packet_set.h index f11088f..696ae55 100644 --- a/netkat/packet_set.h +++ b/netkat/packet_set.h @@ -356,11 +356,38 @@ class PacketSetManager { // `And`, `Or`, `Not`). The class also avoids expensive relocations. PagedStableVector nodes_; + // Transparent hash and equality functors for the unique table + // (`packet_by_node_`), which is keyed by stable `DecisionNode*` pointers + // into `nodes_` (so each node is stored only once). Lookups work directly + // with a not-yet-interned `DecisionNode` value. Both functors are + // stateless: keys are pointers, and the pages holding the nodes are stable + // across moves of the manager. 
+ struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNode& node) const; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNode& b) const { + return *a == b; + } + bool operator()(const DecisionNode& a, const DecisionNode* b) const { + return a == *b; + } + }; + // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketSetHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. // - // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map packet_by_node_; + // INVARIANT: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map + packet_by_node_; // A map of a given `PredicateProto` to a `PacketSetHandle`. // diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index 008fe6d..2dc1a80 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -29,6 +29,7 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,6 +79,16 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNode* node) const { + return absl::HashOf(*node); +} + +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNode& node) const { + return absl::HashOf(node); +} + // Canonicalizes a decision node and returns a transformer. 
PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -153,16 +164,20 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node.default_branch_by_field_modification.empty()) return node.default_branch; - auto [it, inserted] = transformer_by_node_.try_emplace( - node, PacketTransformerHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the node by value via the transparent `NodeHash`/`NodeEq` + // functors; only new nodes get stored (exactly once, in `nodes_`). + if (auto it = transformer_by_node_.find(node); + it != transformer_by_node_.end()) { + return it->second; } - return it->second; + PacketTransformerHandle transformer(nodes_.size()); + nodes_.push_back(std::move(node)); + transformer_by_node_.insert({&nodes_[transformer.node_index_], transformer}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return transformer; } bool PacketTransformerManager::IsDeny( @@ -1040,15 +1055,19 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == - // n`. - for (const auto& [node, transformer] : transformer_by_node_) { + // Invariant: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. 
+ for (const auto& [node_ptr, transformer] : transformer_by_node_) { RET_CHECK(transformer.node_index_ < nodes_.size()); - RET_CHECK(nodes_[transformer.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[transformer.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = transformer_by_node_.find(node); + // Look up both by pointer and by value (exercising the transparent + // functors used by `NodeToTransformer`). + auto it = transformer_by_node_.find(&node); + RET_CHECK(it != transformer_by_node_.end()); + RET_CHECK(it->second == PacketTransformerHandle(i)); + it = transformer_by_node_.find(node); RET_CHECK(it != transformer_by_node_.end()); RET_CHECK(it->second == PacketTransformerHandle(i)); } diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index ed0e074..d723be9 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -383,7 +383,33 @@ class PacketTransformerManager { template friend H AbslHashValue(H h, const ProtoHashKey& key) { - return H::combine(std::move(h), key.lhs_child, key.rhs_child); + return H::combine(std::move(h), key.policy_case, key.lhs_child, + key.rhs_child); + } + }; + + // Transparent hash and equality functors for the unique table + // (`transformer_by_node_`), which is keyed by stable `DecisionNode*` + // pointers into `nodes_` so that each node is stored only once (rather than + // twice: once in `nodes_` and once as a map key). Lookups work directly + // with a not-yet-interned `DecisionNode` value. Both functors are + // stateless: keys are pointers, and the pages holding the nodes are stable + // across moves of the manager. 
+ struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNode& node) const; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNode& b) const { + return *a == b; + } + bool operator()(const DecisionNode& a, const DecisionNode* b) const { + return a == *b; } }; @@ -425,9 +451,13 @@ class PacketTransformerManager { // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketTransformerHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. The transparent `NodeHash`/`NodeEq` functors + // support lookup by node value, see their documentation. // - // INVARIANT: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map + // INVARIANT: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map transformer_by_node_; // A map of a given `PolicyProto` to a `PacketTransformerHandle`. diff --git a/netkat/packet_transformer_benchmark.cc b/netkat/packet_transformer_benchmark.cc index d927ffc..30d1a01 100644 --- a/netkat/packet_transformer_benchmark.cc +++ b/netkat/packet_transformer_benchmark.cc @@ -18,6 +18,7 @@ #include "benchmark/benchmark.h" #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_set.h" #include "netkat/packet_transformer.h" namespace netkat { @@ -99,6 +100,31 @@ void BM_FirstTimeCompileOverlappingPolicy(benchmark::State& state) { } BENCHMARK(BM_FirstTimeCompileOverlappingPolicy); +// Benchmarks the read-path operations that the analysis engine is built on: +// pushing/pulling packet sets through an already-compiled transformer. 
+// After the first iteration all nodes exist, so steady-state iterations +// exercise DAG traversal and unique-table hits rather than first-time node +// creation. +void BM_PushAndPullFullSetThroughPolicy(benchmark::State& state) { + PacketTransformerManager manager; + PolicyProto policy = CreateFixedArbitraryPolicyProto(0); + // NOTE: Without memoization of the recursive transformer operations, + // Push/Pull cost grows exponentially with policy size; keep the number of + // unioned sub-policies small so the benchmark stays tractable. + for (int i = 1; i < 2; ++i) { + policy = UnionProto( + policy, SequenceProto(CreateFixedArbitraryPolicyProto(i), + CreateFixedArbitraryPolicyProto(i + 8))); + } + PacketTransformerHandle transformer = manager.Compile(policy); + PacketSetHandle full_set = manager.GetPacketSetManager().FullSet(); + for (auto s : state) { + benchmark::DoNotOptimize(manager.Push(full_set, transformer)); + benchmark::DoNotOptimize(manager.Pull(transformer, full_set)); + } +} +BENCHMARK(BM_PushAndPullFullSetThroughPolicy); + // Benchmarks the cost of compiling a policy, with overlapping substructures, // that has already been compiled once before. Excludes the initial cost of // compilation. From b962f40f47de836bcb0f64beaf2716bd4eb9cdf9 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 04:20:31 -0700 Subject: [PATCH 5/5] [NetKAT] Rebase flat interned DecisionNodes onto the copy-elimination combinators. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reworks the flat-node representation change to stack on top of the view-based combinator refactor (#98), composing the two designs into the intended end state: * Interned nodes are flat (two contiguous sorted arrays, 40 bytes, two allocations) and stored exactly once, in a unique table keyed by stable pointers into the node arena with transparent lookup-by-builder (and the same table design ported to PacketSetManager, where nodes are even more numerous). * The combinators keep #98's structure but their operand views now read flat spans: DecisionNodeView wraps a flat node (or the trivial fall-through), ModifyBranchesView is a span plus at most one extra entry, and per-value lookups are binary searches. Only result accumulation still uses btree maps (DecisionNodeBuilder), which NodeToTransformer canonicalizes and flattens on interning. * The builder-expansion machinery from the previous iteration of this branch (ToBuilder on every combinator call, handle-level field alignment) is gone — operands are never expanded, completing the deletion that #98's views started. Also carries over from the previous iteration: the canonical flat-sequence visitor (ForEachFlatEntry) shared by Flatten/NodeHash/ NodeEq, streaming hashing, the Matches() iteration API, strengthened internal invariants (flat-encoding structure, hash/eq consistency, pointer-table integrity), the Push/Pull read-path benchmark, and the ProtoHashKey.policy_case hash fix. Verified structure-preserving: the golden file of the canonicalizing diff test is byte-identical to #98's. All 17 test targets pass. 
Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 1 - netkat/packet_transformer.cc | 453 +++++++++++++++-------- netkat/packet_transformer.h | 285 +++++++++++--- netkat/packet_transformer_test.cc | 14 +- netkat/packet_transformer_test_runner.cc | 25 +- 5 files changed, 545 insertions(+), 233 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index a72f5c7..7ceec4c 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -372,7 +372,6 @@ cc_library( ":packet_set", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index 2dc1a80..c093d63 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -24,7 +24,6 @@ #include #include "absl/algorithm/container.h" -#include "absl/base/no_destructor.h" #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" @@ -79,19 +78,132 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } +template +bool PacketTransformerManager::ForEachFlatEntry(const DecisionNode& node, + MatchFn&& match, + ModifyFn&& modify) { + for (const auto& [match_value, end_offset] : node.matches) { + if (!match(match_value, end_offset)) return false; + } + for (const auto& [modify_value, branch] : node.modifies) { + if (!modify(modify_value, branch)) return false; + } + return true; +} + +template +bool PacketTransformerManager::ForEachFlatEntry( + const DecisionNodeBuilder& builder, MatchFn&& match, ModifyFn&& modify) { + uint32_t end_offset = 0; + for (const auto& [match_value, modify_branch_by_value] : + builder.modify_branch_by_field_match) { + end_offset += modify_branch_by_value.size(); + if (!match(match_value, end_offset)) return false; + } 
+ for (const auto& [match_value, modify_branch_by_value] : + builder.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : modify_branch_by_value) { + if (!modify(modify_value, branch)) return false; + } + } + for (const auto& [modify_value, branch] : + builder.default_branch_by_field_modification) { + if (!modify(modify_value, branch)) return false; + } + return true; +} + +PacketTransformerManager::DecisionNode PacketTransformerManager::Flatten( + DecisionNodeBuilder&& builder) { + size_t num_matches = 0; + size_t num_modifies = 0; + ForEachFlatEntry( + builder, + [&](int, uint32_t) { + ++num_matches; + return true; + }, + [&](int, PacketTransformerHandle) { + ++num_modifies; + return true; + }); + DecisionNode node{ + .field = builder.field, + .default_branch = builder.default_branch, + .matches{num_matches}, + .modifies{num_modifies}, + }; + size_t i = 0; + size_t j = 0; + ForEachFlatEntry( + builder, + [&](int match_value, uint32_t end_offset) { + node.matches[i++] = {match_value, end_offset}; + return true; + }, + [&](int modify_value, PacketTransformerHandle branch) { + node.modifies[j++] = {modify_value, branch}; + return true; + }); + return node; +} + +PacketTransformerManager::DecisionNodeBuilder +PacketTransformerManager::ToBuilder(const DecisionNode& node) { + DecisionNodeBuilder builder{ + .field = node.field, + .default_branch = node.default_branch, + }; + for (const DecisionNode::Match& match : node.Matches()) { + absl::btree_map& modify_branch_by_value = + builder.modify_branch_by_field_match[match.value]; + for (const DecisionNode::ModifyEntry& entry : match.modifies) { + modify_branch_by_value.insert(modify_branch_by_value.end(), entry); + } + } + for (const DecisionNode::ModifyEntry& entry : node.DefaultModifies()) { + builder.default_branch_by_field_modification.insert( + builder.default_branch_by_field_modification.end(), entry); + } + return builder; +} + size_t PacketTransformerManager::NodeHash::operator()( const 
DecisionNode* node) const { - return absl::HashOf(*node); + return absl::HashOf(FlatSequenceView{*node}); } size_t PacketTransformerManager::NodeHash::operator()( - const DecisionNode& node) const { - return absl::HashOf(node); + const DecisionNodeBuilder& builder) const { + return absl::HashOf(FlatSequenceView{builder}); +} + +bool PacketTransformerManager::NodeEq::operator()( + const DecisionNode* a, const DecisionNodeBuilder& b) const { + if (a->field != b.field || a->default_branch != b.default_branch) { + return false; + } + // Walk b's canonical flat sequence and compare it element-wise against a's. + size_t i = 0; + size_t j = 0; + bool prefixes_equal = ForEachFlatEntry( + b, + [&](int match_value, uint32_t end_offset) { + return i < a->matches.size() && + a->matches[i++] == + std::pair(match_value, end_offset); + }, + [&](int modify_value, PacketTransformerHandle branch) { + return j < a->modifies.size() && + a->modifies[j++] == + DecisionNode::ModifyEntry(modify_value, branch); + }); + return prefixes_equal && i == a->matches.size() && j == a->modifies.size(); } // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( - DecisionNode&& node) { + DecisionNodeBuilder&& node) { // Remove any default branches pointing to Deny, saving the value. absl::flat_hash_set deny_values; for (const auto& [modify_value, branch] : @@ -164,14 +276,14 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node.default_branch_by_field_modification.empty()) return node.default_branch; - // Look up the node by value via the transparent `NodeHash`/`NodeEq` - // functors; only new nodes get stored (exactly once, in `nodes_`). + // Look up the builder directly (without flattening) via the transparent + // `NodeHash`/`NodeEq` functors; only new nodes pay for flattening. 
if (auto it = transformer_by_node_.find(node); it != transformer_by_node_.end()) { return it->second; } PacketTransformerHandle transformer(nodes_.size()); - nodes_.push_back(std::move(node)); + nodes_.push_back(Flatten(std::move(node))); transformer_by_node_.insert({&nodes_[transformer.node_index_], transformer}); LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) << "Internal invariant violated: Proper and sentinel node indices must " @@ -235,11 +347,11 @@ absl::flat_hash_set PacketTransformerManager::Run( if (initial_field_value.has_value()) { // If it exists, see if there is a value match for it and follow every // corresponding branch with value modified appropriately. - if (auto mod_map_it = - node.modify_branch_by_field_match.find(*initial_field_value); - mod_map_it != node.modify_branch_by_field_match.end()) { + if (std::optional match_index = + node.FindMatch(*initial_field_value); + match_index.has_value()) { matched = true; - for (const auto& [value, branch] : mod_map_it->second) { + for (const auto& [value, branch] : node.MatchModifies(*match_index)) { result.merge( RunWithNewValueThenReset(*this, branch, packet, field, value)); } @@ -251,8 +363,7 @@ absl::flat_hash_set PacketTransformerManager::Run( if (matched) return result; // Otherwise, follow the default branches. - for (const auto& [value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [value, branch] : node.DefaultModifies()) { // If the original packet already had this field with the same value as // this modified branch, then we should not also attempt the default // branch. @@ -333,7 +444,7 @@ PacketTransformerHandle PacketTransformerManager::FromPacketSetHandle( const PacketSetManager::DecisionNode& packet_node = packet_set_manager_.GetNodeOrDie(packet_set); - DecisionNode transformer_node{ + DecisionNodeBuilder transformer_node{ .field = packet_node.field, // This starts out empty and will be populated below. 
.modify_branch_by_field_match = {}, @@ -362,7 +473,7 @@ PacketTransformerHandle PacketTransformerManager::Filter( PacketTransformerHandle PacketTransformerManager::Modification( absl::string_view field, int value) { - return NodeToTransformer(DecisionNode{ + return NodeToTransformer(DecisionNodeBuilder{ .field = packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( field), .modify_branch_by_field_match = {}, @@ -373,19 +484,15 @@ PacketTransformerHandle PacketTransformerManager::Modification( namespace { -// Aliases for the map types of `PacketTransformerManager::DecisionNode`. +// Alias for the modify-map type of +// `PacketTransformerManager::DecisionNodeBuilder`, used to accumulate +// combinator results before interning. using ModifyBranchMap = absl::btree_map; -using MatchBranchMap = absl::btree_map; -const MatchBranchMap& EmptyMatchBranchMap() { - static const absl::NoDestructor kEmpty; - return *kEmpty; -} - -const ModifyBranchMap& EmptyModifyBranchMap() { - static const absl::NoDestructor kEmpty; - return *kEmpty; -} +// The public entry types of the flat `PacketTransformerManager::DecisionNode` +// arrays (respelled here since the node type itself is private). +using ModifyEntry = std::pair; +using MatchEntry = std::pair; // Returns true iff `transformer` is the Deny transformer, which by documented // contract is the default-constructed handle. (Unlike @@ -394,46 +501,70 @@ bool IsDenyHandle(PacketTransformerHandle transformer) { return transformer == PacketTransformerHandle(); } +// Returns true iff `entries` (sorted by modify value) contains an entry with +// the given modify value. 
+bool ContainsModifyValue(absl::Span entries, + int modify_value) { + auto it = std::lower_bound( + entries.begin(), entries.end(), modify_value, + [](const ModifyEntry& entry, int value) { return entry.first < value; }); + return it != entries.end() && it->first == modify_value; +} + // A cheap, non-owning view of a decision node as seen "at" a field no larger -// than the node's own field: either the node's own maps (if the node branches -// on that field), or the trivial expansion "for every value of the field, -// leave it unmodified and fall through to the node" (if the node branches on -// a strictly larger field, or is Accept). This lets the binary operations -// below combine operands with distinct fields without materializing — and -// then re-interning — the trivial expansion as a `DecisionNode`. +// than the node's own field: either the node itself (if the node branches on +// that field), or the trivial expansion "for every value of the field, leave +// it unmodified and fall through to the node" (if the node branches on a +// strictly larger field, or is Accept). This lets the binary operations below +// combine operands with distinct fields without materializing — and then +// re-interning — the trivial expansion as a `DecisionNode`. +// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template struct DecisionNodeView { - const MatchBranchMap* modify_branch_by_field_match; - const ModifyBranchMap* default_branch_by_field_modification; + // The viewed node, or null for the trivial expansion. + const Node* node; + + // The "leave field unmodified" fall-through: the node's default branch, or + // the viewed transformer itself for the trivial expansion. PacketTransformerHandle default_branch; + + // The (match value, end offset) headers; empty for the trivial expansion. 
+ absl::Span Matches() const { + if (node == nullptr) return {}; + return absl::MakeConstSpan(node->matches); + } + + // The default (modify value -> branch) entries; empty for the trivial + // expansion. + absl::Span DefaultModifies() const { + if (node == nullptr) return {}; + return node->DefaultModifies(); + } }; // Returns the view of `node` (the decision node of `transformer`, or null if // `transformer` is Accept) at `field`, which must be <= `node->field`. -// -// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template -// parameter only because that type is private and cannot be named here. template -DecisionNodeView ViewAtField(PacketFieldHandle field, - PacketTransformerHandle transformer, - const Node* node) { +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { if (node != nullptr && node->field == field) { - return DecisionNodeView{ - .modify_branch_by_field_match = &node->modify_branch_by_field_match, - .default_branch_by_field_modification = - &node->default_branch_by_field_modification, + return DecisionNodeView{ + .node = node, .default_branch = node->default_branch, }; } - return DecisionNodeView{ - .modify_branch_by_field_match = &EmptyMatchBranchMap(), - .default_branch_by_field_modification = &EmptyModifyBranchMap(), + return DecisionNodeView{ + .node = nullptr, .default_branch = transformer, }; } // Returns the smallest field branched on by the given decision nodes, at // least one of which must be non-null (null encodes Accept, which branches on -// no field). See `ViewAtField` regarding the `Node` template parameter. +// no field). See `DecisionNodeView` regarding the `Node` template parameter. 
template PacketFieldHandle SmallestField(const Node* left, const Node* right) { DCHECK(left != nullptr || right != nullptr); @@ -445,23 +576,23 @@ PacketFieldHandle SmallestField(const Node* left, const Node* right) { // The two operands of a binary combinator, viewed at the smallest field // branched on by either: an operand that is Accept, or branches on a strictly // larger field, is viewed as a trivial node at that field. +template struct OperandViews { PacketFieldHandle field; - DecisionNodeView left; - DecisionNodeView right; + DecisionNodeView left; + DecisionNodeView right; }; // Views the operands `left` and `right`, whose decision nodes are `left_node` // and `right_node` (null encoding Accept; at least one must be non-null), at -// the smallest field branched on by either. See `ViewAtField` regarding the -// `Node` template parameter. +// the smallest field branched on by either. template -OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, - const Node* left_node, - PacketTransformerHandle right, - const Node* right_node) { +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { const PacketFieldHandle field = SmallestField(left_node, right_node); - return OperandViews{ + return OperandViews{ .field = field, .left = ViewAtField(field, left, left_node), .right = ViewAtField(field, right, right_node), @@ -469,56 +600,64 @@ OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, } // A cheap, non-owning view of a logical (modify value -> branch) map, -// represented as a base map plus at most one extra entry whose key must not -// be a key of the base map. +// represented as a base range of entries (sorted by modify value) plus at +// most one extra entry whose key must not be a key of the base range. 
class ModifyBranchesView { public: - using Entry = std::pair; + using Entry = ModifyEntry; - explicit ModifyBranchesView(const ModifyBranchMap& base) : base_(&base) {} - ModifyBranchesView(const ModifyBranchMap& base, Entry extra) - : base_(&base), extra_(extra) { - DCHECK(!base.contains(extra.first)); + explicit ModifyBranchesView(absl::Span base) + : base_(base) {} + ModifyBranchesView(absl::Span base, Entry extra) + : base_(base), extra_(extra) { + DCHECK(!ContainsModifyValue(base, extra.first)); } - // Invokes `fn(modify_value, branch)` for each entry: the base map entries - // in increasing key order, then the extra entry, if any. + // Invokes `fn(modify_value, branch)` for each entry: the base entries in + // increasing key order, then the extra entry, if any. template void ForEach(Fn&& fn) const { - for (const auto& [value, branch] : *base_) fn(value, branch); + for (const auto& [value, branch] : base_) fn(value, branch); if (extra_.has_value()) fn(extra_->first, extra_->second); } std::optional Find(int value) const { if (extra_.has_value() && extra_->first == value) return extra_->second; - if (auto it = base_->find(value); it != base_->end()) return it->second; + auto it = std::lower_bound( + base_.begin(), base_.end(), value, + [](const ModifyEntry& entry, int v) { return entry.first < v; }); + if (it != base_.end() && it->first == value) return it->second; return std::nullopt; } - bool empty() const { return base_->empty() && !extra_.has_value(); } + bool empty() const { return base_.empty() && !extra_.has_value(); } private: - const ModifyBranchMap* base_; + absl::Span base_; std::optional extra_; }; // Returns a view of the logical (modify value -> branch) map that `node` // applies to packets whose field is equal to `value`: the matching entry of -// `modify_branch_by_field_match` if there is one; otherwise the default +// the node's match branches if there is one; otherwise the default // modifications, plus the unmodified fall-through to 
`default_branch` (keyed // by `value`, since the field keeps its value) unless that branch is Deny or // shadowed by a default modification to `value`. -ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, +template +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& view, int value) { - if (auto it = node.modify_branch_by_field_match->find(value); - it != node.modify_branch_by_field_match->end()) { - return ModifyBranchesView(it->second); + if (view.node != nullptr) { + if (std::optional match_index = view.node->FindMatch(value); + match_index.has_value()) { + return ModifyBranchesView(view.node->MatchModifies(*match_index)); + } } - const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; - if (defaults.contains(value) || IsDenyHandle(node.default_branch)) { + absl::Span defaults = view.DefaultModifies(); + if (ContainsModifyValue(defaults, value) || + IsDenyHandle(view.default_branch)) { return ModifyBranchesView(defaults); } - return ModifyBranchesView(defaults, {value, node.default_branch}); + return ModifyBranchesView(defaults, {value, view.default_branch}); } // Combines two (modify value -> branch) maps key-wise into a new map, using @@ -577,7 +716,7 @@ PacketTransformerHandle PacketTransformerManager::Sequence( if (!inserted) it->second = Union(it->second, branch); }; - DecisionNode result_node{ + DecisionNodeBuilder result_node{ .field = field, .default_branch = Sequence(left_view.default_branch, right_view.default_branch), @@ -587,8 +726,7 @@ PacketTransformerHandle PacketTransformerManager::Sequence( // gotten by taken default modification branches in the left node... 
ModifyBranchMap& result_modifications = result_node.default_branch_by_field_modification; - for (const auto& [value, branch] : - *left_view.default_branch_by_field_modification) { + for (const auto& [value, branch] : left_view.DefaultModifies()) { ModifyBranchesAtValue(right_view, value) .ForEach([&, branch = branch](int modify_value, PacketTransformerHandle right_branch) { @@ -599,18 +737,16 @@ PacketTransformerHandle PacketTransformerManager::Sequence( // ... and of applying the right node's default modifications to packets // falling through the left node unmodified. for (const auto& [modify_value, right_branch] : - *right_view.default_branch_by_field_modification) { + right_view.DefaultModifies()) { union_into(result_modifications, modify_value, Sequence(left_view.default_branch, right_branch)); } // For every value mapped in either node, construct the proper new branch. for (int value : - SortedUniqueKeys(*left_view.modify_branch_by_field_match, - *right_view.modify_branch_by_field_match, - *left_view.default_branch_by_field_modification, - *right_view.default_branch_by_field_modification, - result_modifications)) { + SortedUniqueKeys(left_view.Matches(), right_view.Matches(), + left_view.DefaultModifies(), + right_view.DefaultModifies(), result_modifications)) { ModifyBranchesView left_branches_at_value = ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of @@ -618,8 +754,8 @@ PacketTransformerHandle PacketTransformerManager::Sequence( // empty map won't work correctly for the merges below (an in fact, the // whole for-loop would be skipped), so we expand it here if necessary. if (left_branches_at_value.empty()) { - left_branches_at_value = - ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); + left_branches_at_value = ModifyBranchesView( + /*base=*/{}, /*extra=*/{value, Deny()}); } // `value`s arrive in increasing order, so inserting at `end()` is O(1). 
@@ -650,12 +786,11 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, IsAccept(right) ? nullptr : &GetNodeOrDie(right)); - DecisionNode result_node{ + DecisionNodeBuilder result_node{ .field = field, .default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(*left_view.default_branch_by_field_modification), - ModifyBranchesView(*right_view.default_branch_by_field_modification), - combiner, + ModifyBranchesView(left_view.DefaultModifies()), + ModifyBranchesView(right_view.DefaultModifies()), combiner, /*default_value=*/Deny()), .default_branch = combiner(left_view.default_branch, right_view.default_branch), @@ -663,10 +798,9 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( // For every value mapped in either node, construct the proper new branch. for (int value : - SortedUniqueKeys(*left_view.modify_branch_by_field_match, - *right_view.modify_branch_by_field_match, - *left_view.default_branch_by_field_modification, - *right_view.default_branch_by_field_modification)) { + SortedUniqueKeys(left_view.Matches(), right_view.Matches(), + left_view.DefaultModifies(), + right_view.DefaultModifies())) { // `value`s arrive in increasing order, so inserting at `end()` is O(1). result_node.modify_branch_by_field_match.try_emplace( result_node.modify_branch_by_field_match.end(), value, @@ -737,8 +871,7 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Case 1: Output packets that hit the default branch and got modified. // Implements the `b_A` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (const auto& [modify_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { add_to_output_by_field_value(modify_value, GetAllPossibleOutputPackets(branch)); } @@ -747,9 +880,8 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Implements the `b_B` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. absl::flat_hash_set branch_modify_values; - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const DecisionNode::Match& match : node.Matches()) { + for (const auto& [modify_value, branch] : match.modifies) { branch_modify_values.insert(modify_value); add_to_output_by_field_value(modify_value, GetAllPossibleOutputPackets(branch)); @@ -761,8 +893,9 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Implements the `b_C` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. for (int modify_value : branch_modify_values) { - if (!node.modify_branch_by_field_match.contains(modify_value) && - !node.default_branch_by_field_modification.contains(modify_value)) { + if (!node.FindMatch(modify_value).has_value() && + !DecisionNode::ContainsModifyValue(node.DefaultModifies(), + modify_value)) { add_to_output_by_field_value(modify_value, default_output); } } @@ -770,7 +903,7 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Case 4: Output packets that got matched on an explicit branch, but did // not get modified. Implements the `b_D` in the `fwd` function in section // C.3 Push and Pull in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (auto& [match_value, unused] : node.modify_branch_by_field_match) { + for (const auto& [match_value, unused_end_offset] : node.matches) { if (!branch_modify_values.contains(match_value)) { add_to_output_by_field_value(match_value, PacketSetManager().EmptySet()); } @@ -822,8 +955,7 @@ PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( // Implements the `d'` in the `bwd` function in section C.3 Push and Pull in // KATch: A Fast Symbolic Verifier for NetKAT. PacketSetHandle default_branch_output_packets; - for (const auto& [modify_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { default_branch_output_packets = packet_set_manager_.Or(default_branch_output_packets, GetAllInputPacketsThatProduceAnyOutput(branch)); @@ -833,23 +965,21 @@ PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( // Implements the `b_A` in the `bwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. absl::flat_hash_map branch_by_field_value_map; - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { + for (const DecisionNode::Match& match : node.Matches()) { PacketSetHandle union_of_branches; - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const auto& [modify_value, branch] : match.modifies) { union_of_branches = packet_set_manager_.Or( union_of_branches, GetAllInputPacketsThatProduceAnyOutput(branch)); } - branch_by_field_value_map[match_value] = union_of_branches; + branch_by_field_value_map[match.value] = union_of_branches; } // Case 3: Input packets that do not get matched on an explicit branch, but // do get modified. // Implements the `b_B` in the `bwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (const auto& [modify_value, unused] : - node.default_branch_by_field_modification) { - if (!node.modify_branch_by_field_match.contains(modify_value)) { + for (const auto& [modify_value, unused_branch] : node.DefaultModifies()) { + if (!node.FindMatch(modify_value).has_value()) { branch_by_field_value_map[modify_value] = default_branch_output_packets; } } @@ -897,7 +1027,7 @@ std::string PacketTransformerManager::ToString(const DecisionNode& node) const { auto pretty_print_map = [&](absl::string_view field, - const absl::btree_map& map) { + absl::Span map) { for (const auto& [value, branch] : map) { absl::StrAppendFormat(&result, " %s := %d -> %v\n", field, value, branch); @@ -910,12 +1040,12 @@ std::string PacketTransformerManager::ToString(const DecisionNode& node) const { absl::CEscape( packet_set_manager_.field_manager_.GetFieldName(node.field))); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - absl::StrAppendFormat(&result, " %s == %d:\n", field, value); - pretty_print_map(field, modify_map); + for (const DecisionNode::Match& match : node.Matches()) { + absl::StrAppendFormat(&result, " %s == %d:\n", field, match.value); + pretty_print_map(field, match.modifies); } absl::StrAppendFormat(&result, " %s == *:\n", field); - pretty_print_map(field, node.default_branch_by_field_modification); + pretty_print_map(field, node.DefaultModifies()); PacketTransformerHandle fallthrough = node.default_branch; absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); if (!IsAccept(fallthrough) && !IsDeny(fallthrough)) @@ -937,7 +1067,7 @@ std::string PacketTransformerManager::ToString( auto pretty_print_map = [&](absl::string_view field, - const absl::btree_map& map) { + absl::Span map) { for (const auto& [value, branch] : map) { absl::StrAppendFormat(&result, " %s := %d -> %v\n", field, value, branch); @@ -959,12 +1089,12 @@ std::string PacketTransformerManager::ToString( "%v:'%s'", node.field, absl::CEscape( 
packet_set_manager_.field_manager_.GetFieldName(node.field))); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - absl::StrAppendFormat(&result, " %s == %d:\n", field, value); - pretty_print_map(field, modify_map); + for (const DecisionNode::Match& match : node.Matches()) { + absl::StrAppendFormat(&result, " %s == %d:\n", field, match.value); + pretty_print_map(field, match.modifies); } absl::StrAppendFormat(&result, " %s == *:\n", field); - pretty_print_map(field, node.default_branch_by_field_modification); + pretty_print_map(field, node.DefaultModifies()); PacketTransformerHandle fallthrough = node.default_branch; absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); if (IsAccept(fallthrough) || IsDeny(fallthrough)) continue; @@ -1014,13 +1144,14 @@ std::string PacketTransformerManager::ToDot( packet_set_manager_.field_manager_.GetFieldName(node.field); absl::StrAppendFormat(&result, " %d [label=\"%s\"]\n", transformer.node_index_, field); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - if (modify_map.empty()) { + for (const DecisionNode::Match& match : node.Matches()) { + int value = match.value; + if (match.modifies.empty()) { absl::StrAppendFormat(&result, " %d -> %d [label=\"%s==%s\"]\n", transformer.node_index_, SentinelNodeIndex::kDeny, field, absl::StrCat(value)); } - for (const auto& [new_value, branch] : modify_map) { + for (const auto& [new_value, branch] : match.modifies) { absl::StrAppendFormat(&result, " %d -> %d [label=\"%s==%s; %s:=%d\"]\n", transformer.node_index_, branch.node_index_, @@ -1031,8 +1162,7 @@ std::string PacketTransformerManager::ToDot( } } - for (const auto& [new_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [new_value, branch] : node.DefaultModifies()) { absl::StrAppendFormat( &result, " %d -> %d [label=\"%s:=%d\" style=dashed]\n", transformer.node_index_, branch.node_index_, field, new_value); @@ -1061,13 +1191,18 @@ 
absl::Status PacketTransformerManager::CheckInternalInvariants() const { RET_CHECK(node_ptr == &nodes_[transformer.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { - const DecisionNode& node = nodes_[i]; - // Look up both by pointer and by value (exercising the transparent - // functors used by `NodeToTransformer`). - auto it = transformer_by_node_.find(&node); + auto it = transformer_by_node_.find(&nodes_[i]); RET_CHECK(it != transformer_by_node_.end()); RET_CHECK(it->second == PacketTransformerHandle(i)); - it = transformer_by_node_.find(node); + } + + // Invariant: `NodeHash` and `NodeEq` treat a builder and its flattened + // node identically, as required for transparent unique table lookups. + for (int i = 0; i < nodes_.size(); ++i) { + DecisionNodeBuilder builder = ToBuilder(nodes_[i]); + RET_CHECK(NodeHash()(builder) == NodeHash()(&nodes_[i])); + RET_CHECK(NodeEq()(&nodes_[i], builder)); + auto it = transformer_by_node_.find(builder); RET_CHECK(it != transformer_by_node_.end()); RET_CHECK(it->second == PacketTransformerHandle(i)); } @@ -1075,11 +1210,39 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Node Invariants. for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - // Invariant: `modify_branch_by_field_match` or - // `default_branch_by_field_modification` is non-empty. + // Invariant: `matches` or `DefaultModifies()` is non-empty. // Maintained by `NodeToTransformer`. - RET_CHECK(!node.modify_branch_by_field_match.empty() || - !node.default_branch_by_field_modification.empty()); + RET_CHECK(!node.matches.empty() || !node.modifies.empty()); + + // Invariants of the flat encoding: match values strictly increase, end + // offsets are monotone and bounded, and each ModifyEntry range is sorted + // by strictly increasing modify value. 
+ uint32_t previous_end_offset = 0; + for (const auto& [match_value, end_offset] : node.matches) { + RET_CHECK(end_offset >= previous_end_offset) << ":\n" << ToString(node); + RET_CHECK(end_offset <= node.modifies.size()) << ":\n" << ToString(node); + previous_end_offset = end_offset; + } + for (size_t j = 1; j < node.matches.size(); ++j) { + RET_CHECK(node.matches[j - 1].first < node.matches[j].first) + << ":\n" + << ToString(node); + } + auto is_strictly_sorted_by_value = + [](absl::Span entries) { + for (size_t j = 1; j < entries.size(); ++j) { + if (entries[j - 1].first >= entries[j].first) return false; + } + return true; + }; + for (const DecisionNode::Match& match : node.Matches()) { + RET_CHECK(is_strictly_sorted_by_value(match.modifies)) + << ":\n" + << ToString(node); + } + RET_CHECK(is_strictly_sorted_by_value(node.DefaultModifies())) + << ":\n" + << ToString(node); // Invariant: node field is strictly smaller than sub-node fields. RET_CHECK(IsAccept(node.default_branch) || IsDeny(node.default_branch) || @@ -1087,12 +1250,11 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { << ":\n" << ToString(node); - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const DecisionNode::Match& match : node.Matches()) { + for (const auto& [modify_value, branch] : match.modifies) { // Invariant: Modify branches are not Deny unless `modify_value == - // match_value`. - RET_CHECK(!IsDeny(branch) || modify_value == match_value) + // match.value`. + RET_CHECK(!IsDeny(branch) || modify_value == match.value) << ":\n" << ToString(node); @@ -1104,8 +1266,7 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { } } - for (const auto& [match_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { // Invariant: Default modify branches are not Deny. 
RET_CHECK(!IsDeny(branch)); diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index d723be9..531a564 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -37,17 +37,21 @@ #ifndef GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_H_ #define GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_H_ +#include #include #include +#include #include #include #include "absl/container/btree_map.h" +#include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "netkat/netkat.pb.h" #include "netkat/packet.h" #include "netkat/packet_field.h" @@ -306,7 +310,21 @@ class PacketTransformerManager { // non-deterministically set field -> value_d_1 then branch_d_1 // non-deterministically set field -> value_d_2 then branch_d_2 // non-deterministically LEAVE field UNMODIFIED then default_branch + // + // CHOICE OF DATA STRUCTURE: + // Logically, a node is a map of maps, match value -> (modify value -> + // branch), plus a map of default modifications, modify value -> branch. We + // store all of this in two flat, sorted arrays to optimize memory layout + // (contiguous, compact, flat), exploiting that nodes are immutable once + // interned. This makes the hashing, equality comparison, and copying done + // by the unique table (`transformer_by_node_`) cheap scans over contiguous + // memory, and shrinks node storage. A mutation-friendly map-based + // representation exists separately as `DecisionNodeBuilder`, used only + // transiently while constructing nodes. struct DecisionNode { + // A single "set field to `first`, then continue with `second`" entry. + using ModifyEntry = std::pair; + // The packet field whose value this decision node branches on. // // INVARIANTS: @@ -315,54 +333,217 @@ class PacketTransformerManager { // * Interned by `field_manager_`. 
PacketFieldHandle field; - // The "if" branches of the decision node, "keyed" by the value they branch - // on. Each element of the map is a (match_value, Map)-pair encoding - // "if (field == match_value) then non-deterministically choose a - // (modify_value, branch) pair from `Map`, modify field to modify_value and - // follow branch". + // The "leave field unmodified" consequent of the "else" branch. + PacketTransformerHandle default_branch; + + // The "if" branches of the decision node. `matches[i]` is a + // (match_value, end_offset) pair encoding "if (field == match_value) then + // non-deterministically choose a ModifyEntry from `MatchModifies(i)`, + // modify field to its modify value and follow its branch". The end offsets + // delimit the per-match ranges of `modifies`: `MatchModifies(i)` is + // `modifies[matches[i-1].end_offset, matches[i].end_offset)`. + // A match with an empty ModifyEntry range denies all packets whose field + // equals its match value. // // INVARIANTS: - // 1. Maintained by `NodeToTransformer`: `modify_branch_by_field_match` and - // `default_branch_by_field_modification` below are not both empty. - // (If they were both empty, the decision node gets replaced by - // `default_branch`.) - // 2. For every v, v', and b such that (v,(v',b)) is in - // `modify_branch_by_field_match`, either v == v' or b is not Deny. - absl::btree_map> - modify_branch_by_field_match; - - // The "else" branch of this decision node, "keyed" by the value they modify - // the field to (or not keyed at all for the `default_branch`). + // 1. Maintained by `NodeToTransformer`: `matches` and `DefaultModifies()` + // are not both empty. (If they were both empty, the decision node gets + // replaced by `default_branch`.) + // 2. For every entry (v', b) in `MatchModifies(i)` with match value v, + // either v == v' or b is not Deny. + // 3. Sorted by strictly increasing match value; end offsets are + // non-decreasing and bounded by `modifies.size()`. 
+ absl::FixedArray, + /*use_heap_allocation_above_size=*/0> + matches; + + // The ModifyEntry ranges of all matches, in match order, followed by the + // "else" modifications (see `DefaultModifies()`): entries encoding "if no + // match value applies, non-deterministically set field -> entry.first and + // follow entry.second". // // INVARIANTS: - // 1. For every v and b such that (v,b) is in - // `default_branch_by_field_modification`, b is not Deny. - absl::btree_map - default_branch_by_field_modification; - PacketTransformerHandle default_branch; + // 1. Each per-match range and the default range is sorted by strictly + // increasing modify value, without duplicates. + // 2. For every entry (v, b) in `DefaultModifies()`, b is not Deny. + absl::FixedArray + modifies; + + // The ModifyEntry range of `matches[i]`. + absl::Span MatchModifies(size_t i) const { + uint32_t begin = i == 0 ? 0 : matches[i - 1].second; + return absl::MakeConstSpan(modifies.data() + begin, + matches[i].second - begin); + } + + // A single "if (field == value)" branch: the match value together with + // its ModifyEntry range. + struct Match { + int value; + absl::Span modifies; + }; + + // Iterates the "if" branches as `Match` views, in order of strictly + // increasing match value. Allows range-for loops over the branches + // without manual index bookkeeping. 
+ class MatchIterator { + public: + MatchIterator(const DecisionNode* node, size_t index) + : node_(node), index_(index) {} + Match operator*() const { + return {node_->matches[index_].first, node_->MatchModifies(index_)}; + } + MatchIterator& operator++() { + ++index_; + return *this; + } + friend bool operator==(const MatchIterator& a, + const MatchIterator& b) = default; + + private: + const DecisionNode* node_; + size_t index_; + }; + struct MatchRange { + const DecisionNode* node; + MatchIterator begin() const { return {node, 0}; } + MatchIterator end() const { return {node, node->matches.size()}; } + }; + MatchRange Matches() const { return {this}; } + + // The ModifyEntry range of the "else" branch. + absl::Span DefaultModifies() const { + uint32_t begin = matches.empty() ? 0 : matches.back().second; + return absl::MakeConstSpan(modifies.data() + begin, + modifies.size() - begin); + } + + // Returns the index into `matches` with the given match value, if any. + std::optional FindMatch(int match_value) const { + auto it = std::lower_bound( + matches.begin(), matches.end(), match_value, + [](const auto& match, int value) { return match.first < value; }); + if (it == matches.end() || it->first != match_value) return std::nullopt; + return it - matches.begin(); + } - // Protect against regressions in memory layout, as it affects performance. - static_assert(sizeof(modify_branch_by_field_match) == 24); - static_assert(sizeof(default_branch_by_field_modification) == 24); + // Returns true iff `entries` (sorted by modify value) contains an entry + // with the given modify value. 
+ static bool ContainsModifyValue(absl::Span entries, + int modify_value) { + auto it = std::lower_bound(entries.begin(), entries.end(), modify_value, + [](const ModifyEntry& entry, int value) { + return entry.first < value; + }); + return it != entries.end() && it->first == modify_value; + } friend auto operator<=>(const DecisionNode& a, const DecisionNode& b) = default; - // Hashing, see https://abseil.io/docs/cpp/guides/hash. - template - friend H AbslHashValue(H h, const DecisionNode& node) { - return H::combine(std::move(h), node.field, node.default_branch, - node.default_branch_by_field_modification, - node.modify_branch_by_field_match); - } + // NOTE: Hashing is deliberately NOT defined on this struct. The unique + // table must hash flat nodes and `DecisionNodeBuilder`s identically, so + // there is a single hash definition for both: `NodeHash`. }; // Protect against regressions in memory layout, as it affects performance. - // TODO(dilo): Is this still important with this simpler data structure, or - // should we remove it until we optimize? - static_assert(sizeof(DecisionNode) == 64); + static_assert(sizeof(DecisionNode) == 40); static_assert(alignof(DecisionNode) == 8); + // A mutable, map-based representation of a `DecisionNode`, used only + // transiently while constructing nodes (by the combinators and the golden + // test runner's canonicalizing copy). Finished builders are canonicalized, + // flattened into `DecisionNode`s, and interned by `NodeToTransformer`. The + // members mirror `DecisionNode`; see there for semantics and invariants. + struct DecisionNodeBuilder { + PacketFieldHandle field; + + // Match value -> (modify value -> branch). See `DecisionNode::matches`. + absl::btree_map> + modify_branch_by_field_match; + + // Modify value -> branch. See `DecisionNode::DefaultModifies()`. 
+ absl::btree_map + default_branch_by_field_modification; + PacketTransformerHandle default_branch; + }; + + // Invokes `match(match_value, end_offset)` for each match header and then + // `modify(modify_value, branch)` for each modify entry of the given node or + // builder, in the canonical flat order of `DecisionNode::matches` and + // `DecisionNode::modifies`. Stops and returns false as soon as a callback + // returns false; returns true if all elements were visited. + // + // This is the single definition of a node's flat element sequence: + // `Flatten`, `NodeHash`, and `NodeEq` are all written against it, which + // keeps the two node representations consistent by construction. + template + static bool ForEachFlatEntry(const DecisionNode& node, MatchFn&& match, + ModifyFn&& modify); + template + static bool ForEachFlatEntry(const DecisionNodeBuilder& builder, + MatchFn&& match, ModifyFn&& modify); + + // Transparent hash and equality functors for the unique table + // (`transformer_by_node_`), which is keyed by stable `DecisionNode*` + // pointers into `nodes_` (so each node is stored only once). Lookups work + // directly with a `DecisionNodeBuilder` — without flattening it — keeping + // the hot path of `NodeToTransformer`, re-deriving a node that already + // exists, free of allocations; flattening only happens for genuinely new + // nodes. Both functors are stateless: keys are pointers, and the pages + // holding the nodes are stable across moves of the manager. + // + // INVARIANT: A builder and its flattened node are treated identically: + // `NodeHash()(b) == NodeHash()(&Flatten(b))` and `NodeEq()(&Flatten(b), b)`. + // Maintained by defining both functors in terms of `ForEachFlatEntry`; + // checked by `CheckInternalInvariants`. 
+ struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNodeBuilder& builder) const; + + private: + // Adapter implementing both overloads: hashes the canonical flat element + // sequence (via `ForEachFlatEntry`) in a single streaming pass, so flat + // nodes and builders with the same logical content hash identically. + template + struct FlatSequenceView { + const NodeOrBuilder& node; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, const FlatSequenceView& view) { + size_t num_matches = 0; + size_t num_modifies = 0; + h = H::combine(std::move(h), view.node.field, + view.node.default_branch); + ForEachFlatEntry( + view.node, + [&](int match_value, uint32_t end_offset) { + h = H::combine(std::move(h), match_value, end_offset); + ++num_matches; + return true; + }, + [&](int modify_value, PacketTransformerHandle branch) { + h = H::combine(std::move(h), modify_value, branch); + ++num_modifies; + return true; + }); + return H::combine(std::move(h), num_matches, num_modifies); + } + }; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNodeBuilder& b) const; + bool operator()(const DecisionNodeBuilder& a, const DecisionNode* b) const { + return (*this)(b, a); + } + }; + // A key for efficiently hashing a `PolicyProto` to a // `PacketTransformerHandle`. This works as a recursive hash, such that we // only internally compile unique messages exactly once. @@ -388,32 +569,7 @@ class PacketTransformerManager { } }; - // Transparent hash and equality functors for the unique table - // (`transformer_by_node_`), which is keyed by stable `DecisionNode*` - // pointers into `nodes_` so that each node is stored only once (rather than - // twice: once in `nodes_` and once as a map key). 
Lookups work directly - // with a not-yet-interned `DecisionNode` value. Both functors are - // stateless: keys are pointers, and the pages holding the nodes are stable - // across moves of the manager. - struct NodeHash { - using is_transparent = void; - size_t operator()(const DecisionNode* node) const; - size_t operator()(const DecisionNode& node) const; - }; - struct NodeEq { - using is_transparent = void; - bool operator()(const DecisionNode* a, const DecisionNode* b) const { - return a == b || *a == *b; - } - bool operator()(const DecisionNode* a, const DecisionNode& b) const { - return *a == b; - } - bool operator()(const DecisionNode& a, const DecisionNode* b) const { - return a == *b; - } - }; - - PacketTransformerHandle NodeToTransformer(DecisionNode&& node); + PacketTransformerHandle NodeToTransformer(DecisionNodeBuilder&& node); // Returns the `DecisionNode` corresponding to the given // `PacketTransformerHandle`, or crashes if the `transformer` is @@ -441,6 +597,12 @@ class PacketTransformerManager { PacketTransformerHandle right, Combiner&& combiner); + // Conversions between the interned (flat) and builder (map-based) + // representations of decision nodes. `ToBuilder` is only used by + // `CheckInternalInvariants`, to validate `NodeHash`/`NodeEq` consistency. + static DecisionNode Flatten(DecisionNodeBuilder&& builder); + static DecisionNodeBuilder ToBuilder(const DecisionNode& node); + // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // @@ -453,7 +615,8 @@ class PacketTransformerManager { // once, and thus has a unique `PacketTransformerHandle::node_index`. // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so // nodes are not stored twice. The transparent `NodeHash`/`NodeEq` functors - // support lookup by node value, see their documentation. + // support lookup by `DecisionNodeBuilder` without flattening, see their + // documentation. 
// // INVARIANT: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. absl::flat_hash_map