From 1ce23ab0507ef02d0faf9744f42826cb840508db Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:13:50 -0700 Subject: [PATCH 1/3] [NetKAT] Eliminate node and map copies in packet transformer combinators. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Union, Sequence, and Difference copied entire DecisionNodes — including their nested btree maps — at every recursion step, re-interned expanded nodes through the unique table just to recover handles the caller already had, and copied the per-value modification map (GetMapAtValue) inside per-value loops. Several TODOs flagged this as a likely performance problem. Restructure the combinators around cheap, non-owning views instead: * DecisionNodeView views an operand "at" a field: either the node's own maps, or the trivial "fall through to this handle" expansion when the node branches on a larger field (or is Accept). This removes both the by-value node copies and the materialize-and-re-intern expansion step. * ModifyBranchesView replaces GetMapAtValue's map copies with a view of the base map plus at most one extra entry, iterated in sorted order. * CombineModifyBranches becomes a template over the combiner, removing AnyInvocable type-erasure and double lookups, and inserts with an end hint since keys arrive sorted. * The value-collection loops now iterate in sorted order (previously nondeterministic flat_hash_set order), making combination order — and thus node numbering — deterministic. Compilation benchmarks improve by roughly 1.5-2.2x (FirstTimeCompile*), with larger wins expected on policies with wider modification maps. Behavior is unchanged: all tests pass, and the golden-file diff test output is byte-identical, including node numbering. Co-Authored-By: Claude Fable 5 --- netkat/packet_transformer.cc | 654 +++++++++++++++++------------------ netkat/packet_transformer.h | 12 - 2 files changed, 322 insertions(+), 344 deletions(-) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..c2807da 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -14,8 +14,8 @@ #include "netkat/packet_transformer.h" +#include #include -#include #include #include #include @@ -28,7 +28,6 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,22 +77,6 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } -// TODO(dilo): Creating as many map copies as this method facilitates is -// probably going to cause terrible performance, and needs to be revisited. -absl::btree_map -PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); - - absl::btree_map result = - node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; - - // Otherwise, add a mapping from `value` to the default branch, then return. - result[value] = node.default_branch; - return result; -} - // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -376,249 +359,288 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, - const absl::btree_map& right, - absl::AnyInvocable - combiner, - PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); + +// Aliases for the map types of `PacketTransformerManager::DecisionNode`. +using ModifyBranchMap = absl::btree_map; +using MatchBranchMap = absl::btree_map; + +const MatchBranchMap& EmptyMatchBranchMap() { + static const MatchBranchMap* const kEmpty = new MatchBranchMap; + return *kEmpty; +} + +const ModifyBranchMap& EmptyModifyBranchMap() { + static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + return *kEmpty; +} + +// A cheap, non-owning view of a decision node as seen "at" a field no larger +// than the node's own field: either the node's own maps (if the node branches +// on that field), or the trivial expansion "for every value of the field, +// leave it unmodified and fall through to the node" (if the node branches on +// a strictly larger field, or is Accept). This lets the binary operations +// below combine operands with distinct fields without materializing — and +// then re-interning — the trivial expansion as a `DecisionNode`. +struct DecisionNodeView { + const MatchBranchMap* modify_branch_by_field_match; + const ModifyBranchMap* default_branch_by_field_modification; + PacketTransformerHandle default_branch; +}; + +// Returns the view of `node` (the decision node of `transformer`, or null if +// `transformer` is Accept) at `field`, which must be <= `node->field`. +// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { + if (node != nullptr && node->field == field) { + return DecisionNodeView{ + .modify_branch_by_field_match = &node->modify_branch_by_field_match, + .default_branch_by_field_modification = + &node->default_branch_by_field_modification, + .default_branch = node->default_branch, + }; + } + return DecisionNodeView{ + .modify_branch_by_field_match = &EmptyMatchBranchMap(), + .default_branch_by_field_modification = &EmptyModifyBranchMap(), + .default_branch = transformer, + }; +} + +// Returns the smallest field branched on by the given decision nodes, at +// least one of which must be non-null (null encodes Accept, which branches on +// no field). See `ViewAtField` regarding the `Node` template parameter. +template +PacketFieldHandle SmallestField(const Node* left, const Node* right) { + DCHECK(left != nullptr || right != nullptr); + if (left == nullptr) return right->field; + if (right == nullptr) return left->field; + return std::min(left->field, right->field); +} + +// A cheap, non-owning view of a logical (modify value -> branch) map, +// represented as a base map plus at most one extra entry whose key must not +// be a key of the base map. Iteration is in increasing key order, like the +// underlying btree map. +class ModifyBranchesView { + public: + using Entry = std::pair; + + explicit ModifyBranchesView(const ModifyBranchMap& base) : base_(&base) {} + ModifyBranchesView(const ModifyBranchMap& base, Entry extra) + : base_(&base), extra_(extra) { + DCHECK(!base.contains(extra.first)); + } + + class const_iterator { + public: + Entry operator*() const { + return ExtraIsNext() ? *extra_ : Entry(it_->first, it_->second); + } + const_iterator& operator++() { + if (ExtraIsNext()) { + extra_ = nullptr; + } else { + ++it_; + } + return *this; + } + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.it_ == b.it_ && a.extra_ == b.extra_; } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == b); + } + + private: + friend class ModifyBranchesView; + const_iterator(ModifyBranchMap::const_iterator it, + ModifyBranchMap::const_iterator end, const Entry* extra) + : it_(it), end_(end), extra_(extra) {} + + // The extra entry comes next iff it has not been consumed and its key + // precedes the next base entry's key. + bool ExtraIsNext() const { + return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); + } + + ModifyBranchMap::const_iterator it_, end_; + const Entry* extra_; // Null if absent or already consumed. + }; + + const_iterator begin() const { + return const_iterator(base_->begin(), base_->end(), + extra_.has_value() ? &*extra_ : nullptr); } - for (const auto& [value, branch] : right) { - if (!result.contains(value)) - result[value] = combiner(default_value, branch); + const_iterator end() const { + return const_iterator(base_->end(), base_->end(), nullptr); + } + + bool empty() const { return base_->empty() && !extra_.has_value(); } + bool contains(int value) const { + return (extra_.has_value() && extra_->first == value) || + base_->contains(value); + } + std::optional find(int value) const { + if (extra_.has_value() && extra_->first == value) return extra_->second; + if (auto it = base_->find(value); it != base_->end()) return it->second; + return std::nullopt; + } + + private: + const ModifyBranchMap* base_; + std::optional extra_; +}; + +// Returns a view of the logical (modify value -> branch) map that `node` +// applies to packets whose field is equal to `value`: the matching entry of +// `modify_branch_by_field_match` if there is one; otherwise the default +// modifications, plus the unmodified fall-through to `default_branch` (keyed +// by `value`, since the field keeps its value) unless that branch is Deny or +// shadowed by a default modification to `value`. +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, + int value) { + if (auto it = node.modify_branch_by_field_match->find(value); + it != node.modify_branch_by_field_match->end()) { + return ModifyBranchesView(it->second); + } + const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; + // `PacketTransformerHandle()` is the Deny transformer. + if (defaults.contains(value) || + node.default_branch == PacketTransformerHandle()) { + return ModifyBranchesView(defaults); + } + return ModifyBranchesView(defaults, {value, node.default_branch}); +} + +// Combines two (modify value -> branch) maps key-wise into a new map, using +// `combiner(left_branch, right_branch)` for shared keys and substituting +// `default_value` for the missing side otherwise. +template +ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, + const ModifyBranchesView& right, + Combiner&& combiner, + PacketTransformerHandle default_value) { + ModifyBranchMap result; + for (auto [value, left_branch] : left) { + // Keys arrive in increasing order, so inserting at `end()` is O(1). + result.try_emplace( + result.end(), value, + combiner(left_branch, right.find(value).value_or(default_value))); + } + for (auto [value, right_branch] : right) { + if (left.contains(value)) continue; + result.try_emplace(value, combiner(default_value, right_branch)); } return result; } +// Returns the union of the keys of the given maps, sorted and deduplicated. +template +std::vector SortedUniqueKeys(const Maps&... maps) { + std::vector keys; + keys.reserve((maps.size() + ... + 0)); + auto append_keys = [&keys](const auto& map) { + for (const auto& entry : map) keys.push_back(entry.first); + }; + (append_keys(maps), ...); + absl::c_sort(keys); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + return keys; +} + } // namespace -PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Sequence( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } +PacketTransformerHandle PacketTransformerManager::Sequence( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (IsDeny(left) || IsDeny(right)) return Deny(); + if (IsAccept(left)) return right; + if (IsAccept(right)) return left; - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Sequence(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Both operands are decision nodes. Combine them at the smaller of their + // fields; an operand branching on a strictly larger field is viewed as a + // trivial node at the combined field. + const DecisionNode* left_node = &GetNodeOrDie(left); + const DecisionNode* right_node = &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto sequence_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Sequence(left, right); + }; + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; + const ModifyBranchesView empty_view(EmptyModifyBranchMap()); - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, - .default_branch = Sequence(left.default_branch, right.default_branch), + .field = field, + .default_branch = + Sequence(left_view.default_branch, right_view.default_branch), }; // Construct the possible results of applying the right node to packets // gotten by taken default modification branches in the left node. - absl::btree_map - right_applied_to_left_modifications; + ModifyBranchMap right_applied_to_left_modifications; for (const auto& [value, branch] : - left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); + *left_view.default_branch_by_field_modification) { + ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, + /*default_value=*/branch); right_applied_to_left_modifications = CombineModifyBranches( - right_applied_to_left_modifications, right_at_value_with_sequence, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(right_at_value_with_sequence), union_combiner, /*default_value=*/Deny()); } + ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( + empty_view, + ModifyBranchesView(*right_view.default_branch_by_field_modification), + sequence_combiner, + /*default_value=*/left_view.default_branch); result_node.default_branch_by_field_modification = CombineModifyBranches( - right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(right_applied_to_left_modifications), + ModifyBranchesView(sequenced_right_modifications), union_combiner, /*default_value=*/Deny()); - // Collect every value mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification, + right_applied_to_left_modifications)) { + ModifyBranchesView left_branches_at_value = + ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of // , but the latter is not always canonical. However, an // empty map won't work correctly for the merges below (an in fact, the // whole for-loop would be skipped), so we expand it here if necessary. - if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + if (left_branches_at_value.empty()) { + left_branches_at_value = + ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } - } - - return NodeToTransformer(std::move(result_node)); -} - -PacketTransformerHandle PacketTransformerManager::Sequence( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (IsDeny(left) || IsDeny(right)) return Deny(); - if (IsAccept(left)) return right; - if (IsAccept(right)) return left; - - // If neither node is accept or deny, then sequence the nodes directly. - return Sequence(GetNodeOrDie(left), GetNodeOrDie(right)); -} -PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Union( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Union(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } - - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()), - .default_branch = Union(left.default_branch, right.default_branch), - }; - - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + ModifyBranchMap& result_branches = + result_node.modify_branch_by_field_match[value]; + for (auto [left_value, left_branch] : left_branches_at_value) { + ModifyBranchMap sequenced = CombineModifyBranches( + empty_view, ModifyBranchesAtValue(right_view, left_value), + sequence_combiner, + /*default_value=*/left_branch); + result_branches = + CombineModifyBranches(ModifyBranchesView(result_branches), + ModifyBranchesView(sequenced), union_combiner, + /*default_value=*/Deny()); + } } return NodeToTransformer(std::move(result_node)); @@ -631,95 +653,43 @@ PacketTransformerHandle PacketTransformerManager::Union( if (IsDeny(right)) return left; if (IsDeny(left)) return right; - // If either node is accept, then expand it before merging. - if (IsAccept(left) || IsAccept(right)) { - const DecisionNode& other_node = - GetNodeOrDie(IsAccept(left) ? right : left); - return Union( - DecisionNode{ - .field = other_node.field, - .default_branch = Accept(), - }, - other_node); - } - - // If neither node is accept or deny, then union the nodes directly. - return Union(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Difference( - DecisionNode left, DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Difference( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Difference(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto union_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Union(left, right); + }; - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, + .field = field, .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + union_combiner, /*default_value=*/Deny()), - .default_branch = Difference(left.default_branch, right.default_branch), + .default_branch = + Union(left_view.default_branch, right_view.default_branch), }; - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), union_combiner, /*default_value=*/Deny()); } @@ -733,27 +703,47 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // If either node is accept, then expand it before merging. - if (IsAccept(left)) { - const DecisionNode& right_node = GetNodeOrDie(right); - return Difference( - DecisionNode{ - .field = right_node.field, - .default_branch = Accept(), - }, - right_node); - } + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. Combine the operands at the smallest field branched on by + // either; an operand that is Accept, or branches on a strictly larger + // field, is viewed as a trivial node at that field. + const DecisionNode* left_node = + IsAccept(left) ? nullptr : &GetNodeOrDie(left); + const DecisionNode* right_node = + IsAccept(right) ? nullptr : &GetNodeOrDie(right); + const PacketFieldHandle field = SmallestField(left_node, right_node); + const DecisionNodeView left_view = ViewAtField(field, left, left_node); + const DecisionNodeView right_view = ViewAtField(field, right, right_node); + + auto difference_combiner = [this](PacketTransformerHandle left, + PacketTransformerHandle right) { + return Difference(left, right); + }; + + DecisionNode result_node{ + .field = field, + .default_branch_by_field_modification = CombineModifyBranches( + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + difference_combiner, + /*default_value=*/Deny()), + .default_branch = + Difference(left_view.default_branch, right_view.default_branch), + }; - if (IsAccept(right)) { - const DecisionNode& left_node = GetNodeOrDie(left); - return Difference(left_node, DecisionNode{ - .field = left_node.field, - .default_branch = Accept(), - }); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { + result_node.modify_branch_by_field_match[value] = CombineModifyBranches( + ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), difference_combiner, + /*default_value=*/Deny()); } - // If neither node is accept or deny, then difference the nodes directly. - return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); + return NodeToTransformer(std::move(result_node)); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..b63bf19 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,18 +404,6 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); - // Helper functions to deal with DecisionNodes directly. - // TODO(dilo): Is there a convenient way to either avoid these or avoid making - // copies of the nodes? - PacketTransformerHandle Union(DecisionNode left, DecisionNode right); - PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); - PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); - - // Internal helper function to get a map of possible modification values to - // branches for a given input value at `node`. - absl::btree_map GetMapAtValue( - const DecisionNode& node, int value); - // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // From bf0e8ae8dfd8196e6af9e29893cc11db5ffe79c4 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 03:29:02 -0700 Subject: [PATCH 2/3] [NetKAT] Follow-up cleanups: in-place Sequence accumulation, op dedup, stable goldens. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-ups to the copy-elimination refactor: * Sequence accumulated its result maps by rebuilding them from scratch for every left-branch entry, making accumulation quadratic in the number of branches. It now unions each sequenced branch into the result map in place, which also removes the remaining intermediate maps entirely. * Union and Difference were near-identical after the refactor; their shared body (pointwise combination of two operands viewed at a common field) moves into a PointwiseCombine template parameterized on the operation. * The golden diff test asserted raw node-interning indices, so any reordering inside the manager — including the change above, and planned work like operation memoization — churned the golden file without any semantic change. The test runner now renumbers nodes and fields canonically before printing, by copying each compiled transformer into a fresh manager in deterministic traversal order, so the golden output depends only on transformer structure. The golden file is regenerated accordingly (verified: the previous diff was pure renumbering, with structure, fields, and labels identical). Compilation benchmarks improve a further ~9% on top of the previous commit (~1.8x total vs the original code on FirstTimeCompile NonOverlappingPolicy). All tests pass. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 + netkat/packet_transformer.cc | 144 +++++++----------- netkat/packet_transformer.h | 11 ++ netkat/packet_transformer_test.expected | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 117 +++++++++++++- 5 files changed, 273 insertions(+), 189 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..a456370 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -440,7 +440,11 @@ cc_test( linkstatic = True, deps = [ ":netkat_proto_constructors", + ":packet_field", ":packet_transformer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index c2807da..efd421f 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -571,15 +571,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( const DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto sequence_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); + // Unions `branch` into `branches[modify_value]`, treating absence as Deny. + auto union_into = [this](ModifyBranchMap& branches, int modify_value, + PacketTransformerHandle branch) { + auto [it, inserted] = branches.try_emplace(modify_value, branch); + if (!inserted) it->second = Union(it->second, branch); }; - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - const ModifyBranchesView empty_view(EmptyModifyBranchMap()); DecisionNode result_node{ .field = field, @@ -588,28 +585,24 @@ PacketTransformerHandle PacketTransformerManager::Sequence( }; // Construct the possible results of applying the right node to packets - // gotten by taken default modification branches in the left node. - ModifyBranchMap right_applied_to_left_modifications; + // gotten by taken default modification branches in the left node... + ModifyBranchMap& result_modifications = + result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - ModifyBranchMap right_at_value_with_sequence = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, value), sequence_combiner, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(right_at_value_with_sequence), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, value)) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + } + } + // ... and of applying the right node's default modifications to packets + // falling through the left node unmodified. + for (const auto& [modify_value, right_branch] : + *right_view.default_branch_by_field_modification) { + union_into(result_modifications, modify_value, + Sequence(left_view.default_branch, right_branch)); } - - ModifyBranchMap sequenced_right_modifications = CombineModifyBranches( - empty_view, - ModifyBranchesView(*right_view.default_branch_by_field_modification), - sequence_combiner, - /*default_value=*/left_view.default_branch); - result_node.default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(right_applied_to_left_modifications), - ModifyBranchesView(sequenced_right_modifications), union_combiner, - /*default_value=*/Deny()); // For every value mapped in either node, construct the proper new branch. for (int value : @@ -617,7 +610,7 @@ PacketTransformerHandle PacketTransformerManager::Sequence( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification, - right_applied_to_left_modifications)) { + result_modifications)) { ModifyBranchesView left_branches_at_value = ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of @@ -632,27 +625,21 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchMap& result_branches = result_node.modify_branch_by_field_match[value]; for (auto [left_value, left_branch] : left_branches_at_value) { - ModifyBranchMap sequenced = CombineModifyBranches( - empty_view, ModifyBranchesAtValue(right_view, left_value), - sequence_combiner, - /*default_value=*/left_branch); - result_branches = - CombineModifyBranches(ModifyBranchesView(result_branches), - ModifyBranchesView(sequenced), union_combiner, - /*default_value=*/Deny()); + for (auto [modify_value, right_branch] : + ModifyBranchesAtValue(right_view, left_value)) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + } } } return NodeToTransformer(std::move(result_node)); } -PacketTransformerHandle PacketTransformerManager::Union( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (left == right) return left; - if (IsDeny(right)) return left; - if (IsDeny(left)) return right; - +template +PacketTransformerHandle PacketTransformerManager::PointwiseCombine( + PacketTransformerHandle left, PacketTransformerHandle right, + Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a // decision node. Combine the operands at the smallest field branched on by // either; an operand that is Accept, or branches on a strictly larger @@ -665,20 +652,15 @@ PacketTransformerHandle PacketTransformerManager::Union( const DecisionNodeView left_view = ViewAtField(field, left, left_node); const DecisionNodeView right_view = ViewAtField(field, right, right_node); - auto union_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Union(left, right); - }; - DecisionNode result_node{ .field = field, .default_branch_by_field_modification = CombineModifyBranches( ModifyBranchesView(*left_view.default_branch_by_field_modification), ModifyBranchesView(*right_view.default_branch_by_field_modification), - union_combiner, + combiner, /*default_value=*/Deny()), .default_branch = - Union(left_view.default_branch, right_view.default_branch), + combiner(left_view.default_branch, right_view.default_branch), }; // For every value mapped in either node, construct the proper new branch. @@ -689,13 +671,27 @@ PacketTransformerHandle PacketTransformerManager::Union( *right_view.default_branch_by_field_modification)) { result_node.modify_branch_by_field_match[value] = CombineModifyBranches( ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), union_combiner, + ModifyBranchesAtValue(right_view, value), combiner, /*default_value=*/Deny()); } return NodeToTransformer(std::move(result_node)); } +PacketTransformerHandle PacketTransformerManager::Union( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (left == right) return left; + if (IsDeny(right)) return left; + if (IsDeny(left)) return right; + + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }); +} + PacketTransformerHandle PacketTransformerManager::Difference( PacketTransformerHandle left, PacketTransformerHandle right) { // Base cases. @@ -703,47 +699,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); - - auto difference_combiner = [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Difference(left, right); - }; - - DecisionNode result_node{ - .field = field, - .default_branch_by_field_modification = CombineModifyBranches( - ModifyBranchesView(*left_view.default_branch_by_field_modification), - ModifyBranchesView(*right_view.default_branch_by_field_modification), - difference_combiner, - /*default_value=*/Deny()), - .default_branch = - Difference(left_view.default_branch, right_view.default_branch), - }; - - // For every value mapped in either node, construct the proper new branch. - for (int value : - SortedUniqueKeys(*left_view.modify_branch_by_field_match, - *right_view.modify_branch_by_field_match, - *left_view.default_branch_by_field_modification, - *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), difference_combiner, - /*default_value=*/Deny()); - } - - return NodeToTransformer(std::move(result_node)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index b63bf19..ed0e074 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,6 +404,17 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); + // Shared implementation of `Union` and `Difference`, which differ only in + // their base cases and the operation applied to corresponding branches: + // combines `left` and `right` by applying `combiner` — which must be the + // handle-level operation itself, e.g. `Union` — to corresponding branches. + // Both operands must be Accept or decision nodes (i.e., the base cases must + // already have been handled). + template + PacketTransformerHandle PointwiseCombine(PacketTransformerHandle left, + PacketTransformerHandle right, + Combiner&& combiner); + // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. // diff --git a/netkat/packet_transformer_test.expected b/netkat/packet_transformer_test.expected index 9f01d94..ca6b248 100644 --- a/netkat/packet_transformer_test.expected +++ b/netkat/packet_transformer_test.expected @@ -42,22 +42,22 @@ digraph { Test case: p := (a=5 + b=2);(b:=1 + c=5). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<8>: +PacketTransformerHandle<3>: PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<6> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<1> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<7> -PacketTransformerHandle<6>: + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<1>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<5> -PacketTransformerHandle<7>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<0> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<5> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<5>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 5 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -68,50 +68,50 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 8 [label="a"] - 8 -> 6 [label="a==5; a:=5"] - 8 -> 7 [style=dashed] - 6 [label="b"] - 6 -> 4294967294 [label="b:=1" style=dashed] - 6 -> 5 [style=dashed] - 7 [label="b"] - 7 -> 4294967294 [label="b==2; b:=1"] - 7 -> 5 [label="b==2; b:=2"] - 7 -> 4294967295 [style=dashed] - 5 [label="c"] - 5 -> 4294967294 [label="c==5; c:=5"] - 5 -> 4294967295 [style=dashed] + 3 [label="a"] + 3 -> 1 [label="a==5; a:=5"] + 3 -> 2 [style=dashed] + 1 [label="b"] + 1 -> 4294967294 [label="b:=1" style=dashed] + 1 -> 0 [style=dashed] + 2 [label="b"] + 2 -> 4294967294 [label="b==2; b:=1"] + 2 -> 0 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c==5; c:=5"] + 0 -> 4294967295 [style=dashed] } ================================================================================ Test case: q := (b=1 + c:=4 + a:=5;b:=1). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<17>: +PacketTransformerHandle<5>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<14> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<16> -PacketTransformerHandle<14>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<4> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<4>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<16>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<13>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<10>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle @@ -121,64 +121,64 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 17 [label="a"] - 17 -> 14 [label="a==1; a:=1"] - 17 -> 4 [label="a:=1" style=dashed] - 17 -> 16 [style=dashed] - 14 [label="b"] - 14 -> 13 [label="b==1; b:=1"] - 14 -> 4294967294 [label="b:=1" style=dashed] - 14 -> 10 [style=dashed] + 5 [label="a"] + 5 -> 2 [label="a==1; a:=1"] + 5 -> 3 [label="a:=1" style=dashed] + 5 -> 4 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==1; b:=1"] + 2 -> 4294967294 [label="b:=1" style=dashed] + 2 -> 1 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 16 [label="b"] - 16 -> 13 [label="b==1; b:=1"] - 16 -> 10 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 10 [label="c"] - 10 -> 4294967294 [label="c:=4" style=dashed] - 10 -> 4294967295 [style=dashed] + 4 -> 0 [label="b==1; b:=1"] + 4 -> 1 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c:=4" style=dashed] + 1 -> 4294967295 [style=dashed] } ================================================================================ Test case: p;q. Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<22>: +PacketTransformerHandle<6>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<19> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<21> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<4> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<20> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<19> -PacketTransformerHandle<19>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<5> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<18> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<1> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<4>: +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<21>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<18> -PacketTransformerHandle<20>: + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<5>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<13>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<18>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -189,29 +189,29 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 22 [label="a"] - 22 -> 19 [label="a==1; a:=1"] - 22 -> 4 [label="a==5; a:=1"] - 22 -> 21 [label="a==5; a:=5"] - 22 -> 20 [label="a:=1" style=dashed] - 22 -> 19 [style=dashed] - 19 [label="b"] - 19 -> 13 [label="b==2; b:=1"] - 19 -> 18 [label="b==2; b:=2"] - 19 -> 4294967295 [style=dashed] + 6 [label="a"] + 6 -> 2 [label="a==1; a:=1"] + 6 -> 3 [label="a==5; a:=1"] + 6 -> 4 [label="a==5; a:=5"] + 6 -> 5 [label="a:=1" style=dashed] + 6 -> 2 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==2; b:=1"] + 2 -> 1 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 21 [label="b"] - 21 -> 13 [label="b:=1" style=dashed] - 21 -> 18 [style=dashed] - 20 [label="b"] - 20 -> 4294967294 [label="b==2; b:=1"] - 20 -> 4294967295 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 18 [label="c"] - 18 -> 4294967294 [label="c==5; c:=4"] - 18 -> 4294967295 [style=dashed] + 4 -> 0 [label="b:=1" style=dashed] + 4 -> 1 [style=dashed] + 5 [label="b"] + 5 -> 4294967294 [label="b==2; b:=1"] + 5 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c==5; c:=4"] + 1 -> 4294967295 [style=dashed] } diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index 653c320..c4eacfc 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,15 +16,118 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` +#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_field.h" #include "netkat/packet_transformer.h" namespace netkat { + +// Test peer to access `PacketTransformerManager` internals, allowing this +// runner to renumber nodes canonically before printing. +class PacketTransformerManagerTestPeer { + public: + // Copies `transformer` into the fresh manager `to`, interning fields and + // nodes in a deterministic traversal order. This makes the printed handle + // numbers depend only on the structure of `transformer` — not on the + // interning order of `from`, which is an implementation detail that changes + // under refactorings of the manager (e.g. reordering its operations). + static PacketTransformerHandle CanonicalCopy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to) { + // Intern the fields of all reachable nodes in their original relative + // order, which the node invariants ("fields increase strictly along each + // path") depend on. + std::vector fields; + absl::flat_hash_set visited; + CollectFields(from, transformer, visited, fields); + absl::c_sort(fields); + fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + for (PacketFieldHandle field : fields) { + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field)); + } + + absl::flat_hash_map + copy_by_original; + return Copy(from, transformer, to, copy_by_original); + } + + private: + static void CollectFields( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + absl::flat_hash_set& visited, + std::vector& fields) { + if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; + if (!visited.insert(transformer).second) return; + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + fields.push_back(node.field); + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : branch_by_modify) { + CollectFields(from, branch, visited, fields); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + CollectFields(from, branch, visited, fields); + } + CollectFields(from, node.default_branch, visited, fields); + } + + static PacketTransformerHandle Copy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to, + absl::flat_hash_map& + copy_by_original) { + if (from.IsDeny(transformer)) return to.Deny(); + if (from.IsAccept(transformer)) return to.Accept(); + if (auto it = copy_by_original.find(transformer); + it != copy_by_original.end()) { + return it->second; + } + + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + PacketTransformerManager::DecisionNode copy{ + .field = + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName( + node.field)), + }; + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + // `operator[]` keeps entries with empty branch maps, which are + // meaningful: they deny packets matching `match_value`. + auto& copy_branch_by_modify = + copy.modify_branch_by_field_match[match_value]; + for (const auto& [modify_value, branch] : branch_by_modify) { + copy_branch_by_modify[modify_value] = + Copy(from, branch, to, copy_by_original); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + copy.default_branch_by_field_modification[modify_value] = + Copy(from, branch, to, copy_by_original); + } + copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + + PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); + copy_by_original.emplace(transformer, result); + return result; + } +}; + namespace { constexpr char kBanner[] = @@ -91,16 +194,22 @@ std::vector TestCases() { } void main() { - // This test needs a deterministic field interning order, and thus must start - // from a fresh manager. PacketTransformerManager manager; for (const TestCase& test_case : TestCases()) { netkat::PacketTransformerHandle packet_transformer = manager.Compile(test_case.policy); + // Renumber nodes and fields canonically, in a fresh manager per test case, + // so that the printed output depends only on the structure of the compiled + // transformer and not on the interning order of `manager`. + PacketTransformerManager canonical_manager; + netkat::PacketTransformerHandle canonical_transformer = + PacketTransformerManagerTestPeer::CanonicalCopy( + manager, packet_transformer, canonical_manager); std::cout << kBanner << "Test case: " << test_case.description << std::endl << kBanner; - std::cout << kStringHeader << manager.ToString(packet_transformer); - std::cout << kDotHeader << manager.ToDot(packet_transformer); + std::cout << kStringHeader + << canonical_manager.ToString(canonical_transformer); + std::cout << kDotHeader << canonical_manager.ToDot(canonical_transformer); } } From 37aa3a4b24aca9140eb7885a5678af5e1ac2252e Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Wed, 10 Jun 2026 04:07:20 -0700 Subject: [PATCH 3/3] [NetKAT] Apply review cleanups to the combinator refactor. Findings from a four-angle cleanup review (reuse, simplification, efficiency, altitude): * Replace ModifyBranchesView's hand-rolled merging iterator (~45 lines, the subtlest code in the refactor) with a simple ForEach that visits base entries then the extra entry. No caller needed globally sorted iteration: map contents are order-independent and handle-level results are canonical regardless of combination order. * Deduplicate the operand-view prologue shared by Sequence and PointwiseCombine into ViewOperandsAtSmallestField. * Use absl::NoDestructor for the empty-map singletons (the repo's established idiom) and name the default-handle-is-Deny contract in one place (IsDenyHandle) instead of re-deriving it inline. * Hint the per-value btree insertions in Sequence/PointwiseCombine with end(), since SortedUniqueKeys emits values in increasing order. * Test runner: collect fields into a btree_set instead of sort+unique, and translate field handles through a precomputed map instead of a per-node string round-trip. * Drop a stale TODO referencing the deleted GetMapAtValue, and the now-unused any_invocable dependency. No behavior change: all tests pass and the golden file is unchanged. Benchmarks are neutral. Co-Authored-By: Claude Fable 5 --- netkat/BUILD.bazel | 4 +- netkat/packet_transformer.cc | 186 +++++++++++------------ netkat/packet_transformer_test_runner.cc | 36 ++--- 3 files changed, 106 insertions(+), 120 deletions(-) diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index a456370..5bffd85 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -371,11 +371,11 @@ cc_library( ":packet_set", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -442,7 +442,7 @@ cc_test( ":netkat_proto_constructors", ":packet_field", ":packet_transformer", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index efd421f..008fe6d 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -24,6 +24,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" @@ -121,9 +122,6 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( absl::flat_hash_set redundant_values; for (auto& [match_value, modification_map] : node.modify_branch_by_field_match) { - // TODO(dilo): Consider if this can make use of GetMapAtValue. Perhaps by - // calling it on a copy of DecisionNode without any - // `modify_branch_by_field_match` mappings? if (skip_default_branch || node.default_branch_by_field_modification.contains(match_value)) { // Compare the modification map to the default branch modification map, @@ -365,15 +363,22 @@ using ModifyBranchMap = absl::btree_map; using MatchBranchMap = absl::btree_map; const MatchBranchMap& EmptyMatchBranchMap() { - static const MatchBranchMap* const kEmpty = new MatchBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } const ModifyBranchMap& EmptyModifyBranchMap() { - static const ModifyBranchMap* const kEmpty = new ModifyBranchMap; + static const absl::NoDestructor kEmpty; return *kEmpty; } +// Returns true iff `transformer` is the Deny transformer, which by documented +// contract is the default-constructed handle. (Unlike +// `PacketTransformerManager::IsDeny`, this is callable without a manager.) +bool IsDenyHandle(PacketTransformerHandle transformer) { + return transformer == PacketTransformerHandle(); +} + // A cheap, non-owning view of a decision node as seen "at" a field no larger // than the node's own field: either the node's own maps (if the node branches // on that field), or the trivial expansion "for every value of the field, @@ -422,10 +427,35 @@ PacketFieldHandle SmallestField(const Node* left, const Node* right) { return std::min(left->field, right->field); } +// The two operands of a binary combinator, viewed at the smallest field +// branched on by either: an operand that is Accept, or branches on a strictly +// larger field, is viewed as a trivial node at that field. +struct OperandViews { + PacketFieldHandle field; + DecisionNodeView left; + DecisionNodeView right; +}; + +// Views the operands `left` and `right`, whose decision nodes are `left_node` +// and `right_node` (null encoding Accept; at least one must be non-null), at +// the smallest field branched on by either. See `ViewAtField` regarding the +// `Node` template parameter. +template +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { + const PacketFieldHandle field = SmallestField(left_node, right_node); + return OperandViews{ + .field = field, + .left = ViewAtField(field, left, left_node), + .right = ViewAtField(field, right, right_node), + }; +} + // A cheap, non-owning view of a logical (modify value -> branch) map, // represented as a base map plus at most one extra entry whose key must not -// be a key of the base map. Iteration is in increasing key order, like the -// underlying btree map. +// be a key of the base map. class ModifyBranchesView { public: using Entry = std::pair; @@ -436,61 +466,22 @@ class ModifyBranchesView { DCHECK(!base.contains(extra.first)); } - class const_iterator { - public: - Entry operator*() const { - return ExtraIsNext() ? *extra_ : Entry(it_->first, it_->second); - } - const_iterator& operator++() { - if (ExtraIsNext()) { - extra_ = nullptr; - } else { - ++it_; - } - return *this; - } - friend bool operator==(const const_iterator& a, const const_iterator& b) { - return a.it_ == b.it_ && a.extra_ == b.extra_; - } - friend bool operator!=(const const_iterator& a, const const_iterator& b) { - return !(a == b); - } - - private: - friend class ModifyBranchesView; - const_iterator(ModifyBranchMap::const_iterator it, - ModifyBranchMap::const_iterator end, const Entry* extra) - : it_(it), end_(end), extra_(extra) {} - - // The extra entry comes next iff it has not been consumed and its key - // precedes the next base entry's key. - bool ExtraIsNext() const { - return extra_ != nullptr && (it_ == end_ || extra_->first < it_->first); - } - - ModifyBranchMap::const_iterator it_, end_; - const Entry* extra_; // Null if absent or already consumed. - }; - - const_iterator begin() const { - return const_iterator(base_->begin(), base_->end(), - extra_.has_value() ? &*extra_ : nullptr); - } - const_iterator end() const { - return const_iterator(base_->end(), base_->end(), nullptr); + // Invokes `fn(modify_value, branch)` for each entry: the base map entries + // in increasing key order, then the extra entry, if any. + template + void ForEach(Fn&& fn) const { + for (const auto& [value, branch] : *base_) fn(value, branch); + if (extra_.has_value()) fn(extra_->first, extra_->second); } - bool empty() const { return base_->empty() && !extra_.has_value(); } - bool contains(int value) const { - return (extra_.has_value() && extra_->first == value) || - base_->contains(value); - } - std::optional find(int value) const { + std::optional Find(int value) const { if (extra_.has_value() && extra_->first == value) return extra_->second; if (auto it = base_->find(value); it != base_->end()) return it->second; return std::nullopt; } + bool empty() const { return base_->empty() && !extra_.has_value(); } + private: const ModifyBranchMap* base_; std::optional extra_; @@ -509,9 +500,7 @@ ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, return ModifyBranchesView(it->second); } const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; - // `PacketTransformerHandle()` is the Deny transformer. - if (defaults.contains(value) || - node.default_branch == PacketTransformerHandle()) { + if (defaults.contains(value) || IsDenyHandle(node.default_branch)) { return ModifyBranchesView(defaults); } return ModifyBranchesView(defaults, {value, node.default_branch}); @@ -526,16 +515,16 @@ ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, Combiner&& combiner, PacketTransformerHandle default_value) { ModifyBranchMap result; - for (auto [value, left_branch] : left) { - // Keys arrive in increasing order, so inserting at `end()` is O(1). + left.ForEach([&](int value, PacketTransformerHandle left_branch) { + // Keys arrive in near-sorted order, so the `end()` hint is mostly exact. result.try_emplace( result.end(), value, - combiner(left_branch, right.find(value).value_or(default_value))); - } - for (auto [value, right_branch] : right) { - if (left.contains(value)) continue; + combiner(left_branch, right.Find(value).value_or(default_value))); + }); + right.ForEach([&](int value, PacketTransformerHandle right_branch) { + if (left.Find(value).has_value()) return; result.try_emplace(value, combiner(default_value, right_branch)); - } + }); return result; } @@ -562,14 +551,9 @@ PacketTransformerHandle PacketTransformerManager::Sequence( if (IsAccept(left)) return right; if (IsAccept(right)) return left; - // Both operands are decision nodes. Combine them at the smaller of their - // fields; an operand branching on a strictly larger field is viewed as a - // trivial node at the combined field. - const DecisionNode* left_node = &GetNodeOrDie(left); - const DecisionNode* right_node = &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // Both operands are decision nodes. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, &GetNodeOrDie(left), right, &GetNodeOrDie(right)); // Unions `branch` into `branches[modify_value]`, treating absence as Deny. auto union_into = [this](ModifyBranchMap& branches, int modify_value, @@ -590,11 +574,12 @@ PacketTransformerHandle PacketTransformerManager::Sequence( result_node.default_branch_by_field_modification; for (const auto& [value, branch] : *left_view.default_branch_by_field_modification) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, value)) { - union_into(result_modifications, modify_value, - Sequence(branch, right_branch)); - } + ModifyBranchesAtValue(right_view, value) + .ForEach([&, branch = branch](int modify_value, + PacketTransformerHandle right_branch) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + }); } // ... and of applying the right node's default modifications to packets // falling through the left node unmodified. @@ -622,15 +607,19 @@ PacketTransformerHandle PacketTransformerManager::Sequence( ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } + // `value`s arrive in increasing order, so inserting at `end()` is O(1). ModifyBranchMap& result_branches = - result_node.modify_branch_by_field_match[value]; - for (auto [left_value, left_branch] : left_branches_at_value) { - for (auto [modify_value, right_branch] : - ModifyBranchesAtValue(right_view, left_value)) { - union_into(result_branches, modify_value, - Sequence(left_branch, right_branch)); - } - } + result_node.modify_branch_by_field_match + .try_emplace(result_node.modify_branch_by_field_match.end(), value) + ->second; + left_branches_at_value.ForEach([&](int left_value, + PacketTransformerHandle left_branch) { + ModifyBranchesAtValue(right_view, left_value) + .ForEach([&](int modify_value, PacketTransformerHandle right_branch) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + }); + }); } return NodeToTransformer(std::move(result_node)); @@ -641,16 +630,10 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( PacketTransformerHandle left, PacketTransformerHandle right, Combiner&& combiner) { // Neither operand is Deny and at most one is Accept, so at least one is a - // decision node. Combine the operands at the smallest field branched on by - // either; an operand that is Accept, or branches on a strictly larger - // field, is viewed as a trivial node at that field. - const DecisionNode* left_node = - IsAccept(left) ? nullptr : &GetNodeOrDie(left); - const DecisionNode* right_node = - IsAccept(right) ? nullptr : &GetNodeOrDie(right); - const PacketFieldHandle field = SmallestField(left_node, right_node); - const DecisionNodeView left_view = ViewAtField(field, left, left_node); - const DecisionNodeView right_view = ViewAtField(field, right, right_node); + // decision node. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, + IsAccept(right) ? nullptr : &GetNodeOrDie(right)); DecisionNode result_node{ .field = field, @@ -669,10 +652,13 @@ PacketTransformerHandle PacketTransformerManager::PointwiseCombine( *right_view.modify_branch_by_field_match, *left_view.default_branch_by_field_modification, *right_view.default_branch_by_field_modification)) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - ModifyBranchesAtValue(left_view, value), - ModifyBranchesAtValue(right_view, value), combiner, - /*default_value=*/Deny()); + // `value`s arrive in increasing order, so inserting at `end()` is O(1). + result_node.modify_branch_by_field_match.try_emplace( + result_node.modify_branch_by_field_match.end(), value, + CombineModifyBranches(ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), + combiner, + /*default_value=*/Deny())); } return NodeToTransformer(std::move(result_node)); diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index c4eacfc..e6a4ad7 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -16,13 +16,12 @@ // `bazel run //netkat:packet_transformer_diff_test // -- --update` -#include #include #include #include #include -#include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" @@ -45,32 +44,33 @@ class PacketTransformerManagerTestPeer { PacketTransformerManager& to) { // Intern the fields of all reachable nodes in their original relative // order, which the node invariants ("fields increase strictly along each - // path") depend on. - std::vector fields; + // path") depend on, and record the handle translation for `Copy`. + absl::btree_set fields; absl::flat_hash_set visited; CollectFields(from, transformer, visited, fields); - absl::c_sort(fields); - fields.erase(std::unique(fields.begin(), fields.end()), fields.end()); + absl::flat_hash_map field_translation; for (PacketFieldHandle field : fields) { - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - from.packet_set_manager_.field_manager_.GetFieldName(field)); + field_translation.try_emplace( + field, + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field))); } absl::flat_hash_map copy_by_original; - return Copy(from, transformer, to, copy_by_original); + return Copy(from, transformer, to, field_translation, copy_by_original); } private: static void CollectFields( const PacketTransformerManager& from, PacketTransformerHandle transformer, absl::flat_hash_set& visited, - std::vector& fields) { + absl::btree_set& fields) { if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; if (!visited.insert(transformer).second) return; const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); - fields.push_back(node.field); + fields.insert(node.field); for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { for (const auto& [modify_value, branch] : branch_by_modify) { @@ -87,6 +87,8 @@ class PacketTransformerManagerTestPeer { static PacketTransformerHandle Copy( const PacketTransformerManager& from, PacketTransformerHandle transformer, PacketTransformerManager& to, + const absl::flat_hash_map& + field_translation, absl::flat_hash_map& copy_by_original) { if (from.IsDeny(transformer)) return to.Deny(); @@ -99,10 +101,7 @@ class PacketTransformerManagerTestPeer { const PacketTransformerManager::DecisionNode& node = from.GetNodeOrDie(transformer); PacketTransformerManager::DecisionNode copy{ - .field = - to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( - from.packet_set_manager_.field_manager_.GetFieldName( - node.field)), + .field = field_translation.at(node.field), }; for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { @@ -112,15 +111,16 @@ class PacketTransformerManagerTestPeer { copy.modify_branch_by_field_match[match_value]; for (const auto& [modify_value, branch] : branch_by_modify) { copy_branch_by_modify[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } } for (const auto& [modify_value, branch] : node.default_branch_by_field_modification) { copy.default_branch_by_field_modification[modify_value] = - Copy(from, branch, to, copy_by_original); + Copy(from, branch, to, field_translation, copy_by_original); } - copy.default_branch = Copy(from, node.default_branch, to, copy_by_original); + copy.default_branch = Copy(from, node.default_branch, to, field_translation, + copy_by_original); PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); copy_by_original.emplace(transformer, result);