diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..5bffd85 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -371,11 +371,11 @@ cc_library( ":packet_set", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -440,7 +440,11 @@ cc_test( linkstatic = True, deps = [ ":netkat_proto_constructors", + ":packet_field", ":packet_transformer", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..008fe6d 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -14,8 +14,8 @@ #include "netkat/packet_transformer.h" +#include #include -#include #include #include #include @@ -24,11 +24,11 @@ #include #include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,22 +78,6 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } -// TODO(dilo): Creating as many map copies as this method facilitates is -// probably going to cause terrible performance, and needs to be revisited. -absl::btree_map -PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); - - absl::btree_map result = - node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; - - // Otherwise, add a mapping from `value` to the default branch, then return. - result[value] = node.default_branch; - return result; -} - // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( DecisionNode&& node) { @@ -138,9 +122,6 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( absl::flat_hash_set redundant_values; for (auto& [match_value, modification_map] : node.modify_branch_by_field_match) { - // TODO(dilo): Consider if this can make use of GetMapAtValue. Perhaps by - // calling it on a copy of DecisionNode without any - // `modify_branch_by_field_match` mappings? if (skip_default_branch || node.default_branch_by_field_modification.contains(match_value)) { // Compare the modification map to the default branch modification map, @@ -376,249 +357,308 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, - const absl::btree_map& right, - absl::AnyInvocable - combiner, - PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); - } + +// Aliases for the map types of `PacketTransformerManager::DecisionNode`. +using ModifyBranchMap = absl::btree_map; +using MatchBranchMap = absl::btree_map; + +const MatchBranchMap& EmptyMatchBranchMap() { + static const absl::NoDestructor kEmpty; + return *kEmpty; +} + +const ModifyBranchMap& EmptyModifyBranchMap() { + static const absl::NoDestructor kEmpty; + return *kEmpty; +} + +// Returns true iff `transformer` is the Deny transformer, which by documented +// contract is the default-constructed handle. (Unlike +// `PacketTransformerManager::IsDeny`, this is callable without a manager.) +bool IsDenyHandle(PacketTransformerHandle transformer) { + return transformer == PacketTransformerHandle(); +} + +// A cheap, non-owning view of a decision node as seen "at" a field no larger +// than the node's own field: either the node's own maps (if the node branches +// on that field), or the trivial expansion "for every value of the field, +// leave it unmodified and fall through to the node" (if the node branches on +// a strictly larger field, or is Accept). This lets the binary operations +// below combine operands with distinct fields without materializing — and +// then re-interning — the trivial expansion as a `DecisionNode`. +struct DecisionNodeView { + const MatchBranchMap* modify_branch_by_field_match; + const ModifyBranchMap* default_branch_by_field_modification; + PacketTransformerHandle default_branch; +}; + +// Returns the view of `node` (the decision node of `transformer`, or null if +// `transformer` is Accept) at `field`, which must be <= `node->field`. +// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { + if (node != nullptr && node->field == field) { + return DecisionNodeView{ + .modify_branch_by_field_match = &node->modify_branch_by_field_match, + .default_branch_by_field_modification = + &node->default_branch_by_field_modification, + .default_branch = node->default_branch, + }; + } + return DecisionNodeView{ + .modify_branch_by_field_match = &EmptyMatchBranchMap(), + .default_branch_by_field_modification = &EmptyModifyBranchMap(), + .default_branch = transformer, + }; +} + +// Returns the smallest field branched on by the given decision nodes, at +// least one of which must be non-null (null encodes Accept, which branches on +// no field). See `ViewAtField` regarding the `Node` template parameter. +template +PacketFieldHandle SmallestField(const Node* left, const Node* right) { + DCHECK(left != nullptr || right != nullptr); + if (left == nullptr) return right->field; + if (right == nullptr) return left->field; + return std::min(left->field, right->field); +} + +// The two operands of a binary combinator, viewed at the smallest field +// branched on by either: an operand that is Accept, or branches on a strictly +// larger field, is viewed as a trivial node at that field. +struct OperandViews { + PacketFieldHandle field; + DecisionNodeView left; + DecisionNodeView right; +}; + +// Views the operands `left` and `right`, whose decision nodes are `left_node` +// and `right_node` (null encoding Accept; at least one must be non-null), at +// the smallest field branched on by either. See `ViewAtField` regarding the +// `Node` template parameter. +template +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { + const PacketFieldHandle field = SmallestField(left_node, right_node); + return OperandViews{ + .field = field, + .left = ViewAtField(field, left, left_node), + .right = ViewAtField(field, right, right_node), + }; +} + +// A cheap, non-owning view of a logical (modify value -> branch) map, +// represented as a base map plus at most one extra entry whose key must not +// be a key of the base map. +class ModifyBranchesView { + public: + using Entry = std::pair; + + explicit ModifyBranchesView(const ModifyBranchMap& base) : base_(&base) {} + ModifyBranchesView(const ModifyBranchMap& base, Entry extra) + : base_(&base), extra_(extra) { + DCHECK(!base.contains(extra.first)); + } + + // Invokes `fn(modify_value, branch)` for each entry: the base map entries + // in increasing key order, then the extra entry, if any. + template + void ForEach(Fn&& fn) const { + for (const auto& [value, branch] : *base_) fn(value, branch); + if (extra_.has_value()) fn(extra_->first, extra_->second); + } + + std::optional Find(int value) const { + if (extra_.has_value() && extra_->first == value) return extra_->second; + if (auto it = base_->find(value); it != base_->end()) return it->second; + return std::nullopt; + } + + bool empty() const { return base_->empty() && !extra_.has_value(); } + + private: + const ModifyBranchMap* base_; + std::optional extra_; +}; + +// Returns a view of the logical (modify value -> branch) map that `node` +// applies to packets whose field is equal to `value`: the matching entry of +// `modify_branch_by_field_match` if there is one; otherwise the default +// modifications, plus the unmodified fall-through to `default_branch` (keyed +// by `value`, since the field keeps its value) unless that branch is Deny or +// shadowed by a default modification to `value`. +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& node, + int value) { + if (auto it = node.modify_branch_by_field_match->find(value); + it != node.modify_branch_by_field_match->end()) { + return ModifyBranchesView(it->second); } - for (const auto& [value, branch] : right) { - if (!result.contains(value)) - result[value] = combiner(default_value, branch); + const ModifyBranchMap& defaults = *node.default_branch_by_field_modification; + if (defaults.contains(value) || IsDenyHandle(node.default_branch)) { + return ModifyBranchesView(defaults); } + return ModifyBranchesView(defaults, {value, node.default_branch}); +} + +// Combines two (modify value -> branch) maps key-wise into a new map, using +// `combiner(left_branch, right_branch)` for shared keys and substituting +// `default_value` for the missing side otherwise. +template +ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, + const ModifyBranchesView& right, + Combiner&& combiner, + PacketTransformerHandle default_value) { + ModifyBranchMap result; + left.ForEach([&](int value, PacketTransformerHandle left_branch) { + // Keys arrive in near-sorted order, so the `end()` hint is mostly exact. + result.try_emplace( + result.end(), value, + combiner(left_branch, right.Find(value).value_or(default_value))); + }); + right.ForEach([&](int value, PacketTransformerHandle right_branch) { + if (left.Find(value).has_value()) return; + result.try_emplace(value, combiner(default_value, right_branch)); + }); return result; } +// Returns the union of the keys of the given maps, sorted and deduplicated. +template +std::vector SortedUniqueKeys(const Maps&... maps) { + std::vector keys; + keys.reserve((maps.size() + ... + 0)); + auto append_keys = [&keys](const auto& map) { + for (const auto& entry : map) keys.push_back(entry.first); + }; + (append_keys(maps), ...); + absl::c_sort(keys); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + return keys; +} + } // namespace -PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Sequence( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } +PacketTransformerHandle PacketTransformerManager::Sequence( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (IsDeny(left) || IsDeny(right)) return Deny(); + if (IsAccept(left)) return right; + if (IsAccept(right)) return left; - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Sequence(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } + // Both operands are decision nodes. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, &GetNodeOrDie(left), right, &GetNodeOrDie(right)); + + // Unions `branch` into `branches[modify_value]`, treating absence as Deny. + auto union_into = [this](ModifyBranchMap& branches, int modify_value, + PacketTransformerHandle branch) { + auto [it, inserted] = branches.try_emplace(modify_value, branch); + if (!inserted) it->second = Union(it->second, branch); + }; - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, - .default_branch = Sequence(left.default_branch, right.default_branch), + .field = field, + .default_branch = + Sequence(left_view.default_branch, right_view.default_branch), }; // Construct the possible results of applying the right node to packets - // gotten by taken default modification branches in the left node. - absl::btree_map - right_applied_to_left_modifications; + // gotten by taken default modification branches in the left node... + ModifyBranchMap& result_modifications = + result_node.default_branch_by_field_modification; for (const auto& [value, branch] : - left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( - right_applied_to_left_modifications, right_at_value_with_sequence, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + *left_view.default_branch_by_field_modification) { + ModifyBranchesAtValue(right_view, value) + .ForEach([&, branch = branch](int modify_value, + PacketTransformerHandle right_branch) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + }); + } + // ... and of applying the right node's default modifications to packets + // falling through the left node unmodified. + for (const auto& [modify_value, right_branch] : + *right_view.default_branch_by_field_modification) { + union_into(result_modifications, modify_value, + Sequence(left_view.default_branch, right_branch)); } - result_node.default_branch_by_field_modification = CombineModifyBranches( - right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); - - // Collect every value mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification, + result_modifications)) { + ModifyBranchesView left_branches_at_value = + ModifyBranchesAtValue(left_view, value); // An empty map is equivalent to a map with a single entry of // , but the latter is not always canonical. However, an // empty map won't work correctly for the merges below (an in fact, the // whole for-loop would be skipped), so we expand it here if necessary. - if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + if (left_branches_at_value.empty()) { + left_branches_at_value = + ModifyBranchesView(EmptyModifyBranchMap(), {value, Deny()}); } + + // `value`s arrive in increasing order, so inserting at `end()` is O(1). + ModifyBranchMap& result_branches = + result_node.modify_branch_by_field_match + .try_emplace(result_node.modify_branch_by_field_match.end(), value) + ->second; + left_branches_at_value.ForEach([&](int left_value, + PacketTransformerHandle left_branch) { + ModifyBranchesAtValue(right_view, left_value) + .ForEach([&](int modify_value, PacketTransformerHandle right_branch) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + }); + }); } return NodeToTransformer(std::move(result_node)); } -PacketTransformerHandle PacketTransformerManager::Sequence( - PacketTransformerHandle left, PacketTransformerHandle right) { - // Base cases. - if (IsDeny(left) || IsDeny(right)) return Deny(); - if (IsAccept(left)) return right; - if (IsAccept(right)) return left; - - // If neither node is accept or deny, then sequence the nodes directly. - return Sequence(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Union( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Union(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } +template +PacketTransformerHandle PacketTransformerManager::PointwiseCombine( + PacketTransformerHandle left, PacketTransformerHandle right, + Combiner&& combiner) { + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, + IsAccept(right) ? nullptr : &GetNodeOrDie(right)); - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); DecisionNode result_node{ - .field = left.field, + .field = field, .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(*left_view.default_branch_by_field_modification), + ModifyBranchesView(*right_view.default_branch_by_field_modification), + combiner, /*default_value=*/Deny()), - .default_branch = Union(left.default_branch, right.default_branch), + .default_branch = + combiner(left_view.default_branch, right_view.default_branch), }; - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(*left_view.modify_branch_by_field_match, + *right_view.modify_branch_by_field_match, + *left_view.default_branch_by_field_modification, + *right_view.default_branch_by_field_modification)) { + // `value`s arrive in increasing order, so inserting at `end()` is O(1). + result_node.modify_branch_by_field_match.try_emplace( + result_node.modify_branch_by_field_match.end(), value, + CombineModifyBranches(ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), + combiner, + /*default_value=*/Deny())); } return NodeToTransformer(std::move(result_node)); @@ -631,99 +671,11 @@ PacketTransformerHandle PacketTransformerManager::Union( if (IsDeny(right)) return left; if (IsDeny(left)) return right; - // If either node is accept, then expand it before merging. - if (IsAccept(left) || IsAccept(right)) { - const DecisionNode& other_node = - GetNodeOrDie(IsAccept(left) ? right : left); - return Union( - DecisionNode{ - .field = other_node.field, - .default_branch = Accept(), - }, - other_node); - } - - // If neither node is accept or deny, then union the nodes directly. - return Union(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Difference( - DecisionNode left, DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Difference( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Difference(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } - - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, - /*default_value=*/Deny()), - .default_branch = Difference(left.default_branch, right.default_branch), - }; - - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, - /*default_value=*/Deny()); - } - - return NodeToTransformer(std::move(result_node)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Difference( @@ -733,27 +685,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // If either node is accept, then expand it before merging. - if (IsAccept(left)) { - const DecisionNode& right_node = GetNodeOrDie(right); - return Difference( - DecisionNode{ - .field = right_node.field, - .default_branch = Accept(), - }, - right_node); - } - - if (IsAccept(right)) { - const DecisionNode& left_node = GetNodeOrDie(left); - return Difference(left_node, DecisionNode{ - .field = left_node.field, - .default_branch = Accept(), - }); - } - - // If neither node is accept or deny, then difference the nodes directly. - return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Iterate( diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..ed0e074 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -404,17 +404,16 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); - // Helper functions to deal with DecisionNodes directly. - // TODO(dilo): Is there a convenient way to either avoid these or avoid making - // copies of the nodes? - PacketTransformerHandle Union(DecisionNode left, DecisionNode right); - PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); - PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); - - // Internal helper function to get a map of possible modification values to - // branches for a given input value at `node`. - absl::btree_map GetMapAtValue( - const DecisionNode& node, int value); + // Shared implementation of `Union` and `Difference`, which differ only in + // their base cases and the operation applied to corresponding branches: + // combines `left` and `right` by applying `combiner` — which must be the + // handle-level operation itself, e.g. `Union` — to corresponding branches. + // Both operands must be Accept or decision nodes (i.e., the base cases must + // already have been handled). + template + PacketTransformerHandle PointwiseCombine(PacketTransformerHandle left, + PacketTransformerHandle right, + Combiner&& combiner); // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. diff --git a/netkat/packet_transformer_test.expected b/netkat/packet_transformer_test.expected index 9f01d94..ca6b248 100644 --- a/netkat/packet_transformer_test.expected +++ b/netkat/packet_transformer_test.expected @@ -42,22 +42,22 @@ digraph { Test case: p := (a=5 + b=2);(b:=1 + c=5). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<8>: +PacketTransformerHandle<3>: PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<6> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<1> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<7> -PacketTransformerHandle<6>: + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<1>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<5> -PacketTransformerHandle<7>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<0> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<5> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<5>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 5 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -68,50 +68,50 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 8 [label="a"] - 8 -> 6 [label="a==5; a:=5"] - 8 -> 7 [style=dashed] - 6 [label="b"] - 6 -> 4294967294 [label="b:=1" style=dashed] - 6 -> 5 [style=dashed] - 7 [label="b"] - 7 -> 4294967294 [label="b==2; b:=1"] - 7 -> 5 [label="b==2; b:=2"] - 7 -> 4294967295 [style=dashed] - 5 [label="c"] - 5 -> 4294967294 [label="c==5; c:=5"] - 5 -> 4294967295 [style=dashed] + 3 [label="a"] + 3 -> 1 [label="a==5; a:=5"] + 3 -> 2 [style=dashed] + 1 [label="b"] + 1 -> 4294967294 [label="b:=1" style=dashed] + 1 -> 0 [style=dashed] + 2 [label="b"] + 2 -> 4294967294 [label="b==2; b:=1"] + 2 -> 0 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c==5; c:=5"] + 0 -> 4294967295 [style=dashed] } ================================================================================ Test case: q := (b=1 + c:=4 + a:=5;b:=1). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<17>: +PacketTransformerHandle<5>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<14> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<16> -PacketTransformerHandle<14>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<4> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<4>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<16>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<13>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<10>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle @@ -121,64 +121,64 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 17 [label="a"] - 17 -> 14 [label="a==1; a:=1"] - 17 -> 4 [label="a:=1" style=dashed] - 17 -> 16 [style=dashed] - 14 [label="b"] - 14 -> 13 [label="b==1; b:=1"] - 14 -> 4294967294 [label="b:=1" style=dashed] - 14 -> 10 [style=dashed] + 5 [label="a"] + 5 -> 2 [label="a==1; a:=1"] + 5 -> 3 [label="a:=1" style=dashed] + 5 -> 4 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==1; b:=1"] + 2 -> 4294967294 [label="b:=1" style=dashed] + 2 -> 1 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 16 [label="b"] - 16 -> 13 [label="b==1; b:=1"] - 16 -> 10 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 10 [label="c"] - 10 -> 4294967294 [label="c:=4" style=dashed] - 10 -> 4294967295 [style=dashed] + 4 -> 0 [label="b==1; b:=1"] + 4 -> 1 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c:=4" style=dashed] + 1 -> 4294967295 [style=dashed] } ================================================================================ Test case: p;q. Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<22>: +PacketTransformerHandle<6>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<19> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<21> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<4> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<20> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<19> -PacketTransformerHandle<19>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<5> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<18> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<1> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<4>: +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<21>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<18> -PacketTransformerHandle<20>: + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<5>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<13>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<18>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -189,29 +189,29 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 22 [label="a"] - 22 -> 19 [label="a==1; a:=1"] - 22 -> 4 [label="a==5; a:=1"] - 22 -> 21 [label="a==5; a:=5"] - 22 -> 20 [label="a:=1" style=dashed] - 22 -> 19 [style=dashed] - 19 [label="b"] - 19 -> 13 [label="b==2; b:=1"] - 19 -> 18 [label="b==2; b:=2"] - 19 -> 4294967295 [style=dashed] + 6 [label="a"] + 6 -> 2 [label="a==1; a:=1"] + 6 -> 3 [label="a==5; a:=1"] + 6 -> 4 [label="a==5; a:=5"] + 6 -> 5 [label="a:=1" style=dashed] + 6 -> 2 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==2; b:=1"] + 2 -> 1 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 21 [label="b"] - 21 -> 13 [label="b:=1" style=dashed] - 21 -> 18 [style=dashed] - 20 [label="b"] - 20 -> 4294967294 [label="b==2; b:=1"] - 20 -> 4294967295 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 18 [label="c"] - 18 -> 4294967294 [label="c==5; c:=4"] - 18 -> 4294967295 [style=dashed] + 4 -> 0 [label="b:=1" style=dashed] + 4 -> 1 [style=dashed] + 5 [label="b"] + 5 -> 4294967294 [label="b==2; b:=1"] + 5 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c==5; c:=4"] + 1 -> 4294967295 [style=dashed] } diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index 653c320..e6a4ad7 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -21,10 +21,113 @@ #include #include +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_field.h" #include "netkat/packet_transformer.h" namespace netkat { + +// Test peer to access `PacketTransformerManager` internals, allowing this +// runner to renumber nodes canonically before printing. +class PacketTransformerManagerTestPeer { + public: + // Copies `transformer` into the fresh manager `to`, interning fields and + // nodes in a deterministic traversal order. This makes the printed handle + // numbers depend only on the structure of `transformer` — not on the + // interning order of `from`, which is an implementation detail that changes + // under refactorings of the manager (e.g. reordering its operations). + static PacketTransformerHandle CanonicalCopy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to) { + // Intern the fields of all reachable nodes in their original relative + // order, which the node invariants ("fields increase strictly along each + // path") depend on, and record the handle translation for `Copy`. + absl::btree_set fields; + absl::flat_hash_set visited; + CollectFields(from, transformer, visited, fields); + absl::flat_hash_map field_translation; + for (PacketFieldHandle field : fields) { + field_translation.try_emplace( + field, + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field))); + } + + absl::flat_hash_map + copy_by_original; + return Copy(from, transformer, to, field_translation, copy_by_original); + } + + private: + static void CollectFields( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + absl::flat_hash_set& visited, + absl::btree_set& fields) { + if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; + if (!visited.insert(transformer).second) return; + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + fields.insert(node.field); + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : branch_by_modify) { + CollectFields(from, branch, visited, fields); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + CollectFields(from, branch, visited, fields); + } + CollectFields(from, node.default_branch, visited, fields); + } + + static PacketTransformerHandle Copy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to, + const absl::flat_hash_map& + field_translation, + absl::flat_hash_map& + copy_by_original) { + if (from.IsDeny(transformer)) return to.Deny(); + if (from.IsAccept(transformer)) return to.Accept(); + if (auto it = copy_by_original.find(transformer); + it != copy_by_original.end()) { + return it->second; + } + + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + PacketTransformerManager::DecisionNode copy{ + .field = field_translation.at(node.field), + }; + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + // `operator[]` keeps entries with empty branch maps, which are + // meaningful: they deny packets matching `match_value`. + auto& copy_branch_by_modify = + copy.modify_branch_by_field_match[match_value]; + for (const auto& [modify_value, branch] : branch_by_modify) { + copy_branch_by_modify[modify_value] = + Copy(from, branch, to, field_translation, copy_by_original); + } + } + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + copy.default_branch_by_field_modification[modify_value] = + Copy(from, branch, to, field_translation, copy_by_original); + } + copy.default_branch = Copy(from, node.default_branch, to, field_translation, + copy_by_original); + + PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); + copy_by_original.emplace(transformer, result); + return result; + } +}; + namespace { constexpr char kBanner[] = @@ -91,16 +194,22 @@ std::vector TestCases() { } void main() { - // This test needs a deterministic field interning order, and thus must start - // from a fresh manager. PacketTransformerManager manager; for (const TestCase& test_case : TestCases()) { netkat::PacketTransformerHandle packet_transformer = manager.Compile(test_case.policy); + // Renumber nodes and fields canonically, in a fresh manager per test case, + // so that the printed output depends only on the structure of the compiled + // transformer and not on the interning order of `manager`. + PacketTransformerManager canonical_manager; + netkat::PacketTransformerHandle canonical_transformer = + PacketTransformerManagerTestPeer::CanonicalCopy( + manager, packet_transformer, canonical_manager); std::cout << kBanner << "Test case: " << test_case.description << std::endl << kBanner; - std::cout << kStringHeader << manager.ToString(packet_transformer); - std::cout << kDotHeader << manager.ToDot(packet_transformer); + std::cout << kStringHeader + << canonical_manager.ToString(canonical_transformer); + std::cout << kDotHeader << canonical_manager.ToDot(canonical_transformer); } }