diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..7ceec4c 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -132,6 +132,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -375,7 +376,7 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -427,6 +428,7 @@ cc_binary( deps = [ ":netkat_cc_proto", ":netkat_proto_constructors", + ":packet_set", ":packet_transformer", "@com_google_absl//absl/strings", "@com_google_benchmark//:benchmark_main", @@ -440,7 +442,11 @@ cc_test( linkstatic = True, deps = [ ":netkat_proto_constructors", + ":packet_field", ":packet_transformer", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/netkat/packet_set.cc b/netkat/packet_set.cc index 2c7bd1c..fc111fe 100644 --- a/netkat/packet_set.cc +++ b/netkat/packet_set.cc @@ -24,6 +24,7 @@ #include "absl/algorithm/container.h" #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -87,6 +88,14 @@ const PacketSetManager::DecisionNode& PacketSetManager::GetNodeOrDie( return nodes_[packet_set.node_index_]; } +size_t PacketSetManager::NodeHash::operator()(const DecisionNode* node) const { + return absl::HashOf(*node); +} + +size_t PacketSetManager::NodeHash::operator()(const 
DecisionNode& node) const { + return absl::HashOf(node); +} + PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { if (node.branch_by_field_value.empty()) return node.default_branch; @@ -109,16 +118,19 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { } #endif - auto [it, inserted] = - packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the node by value via the transparent `NodeHash`/`NodeEq` + // functors; only new nodes get stored (exactly once, in `nodes_`). + if (auto it = packet_by_node_.find(node); it != packet_by_node_.end()) { + return it->second; } - return it->second; + PacketSetHandle packet(nodes_.size()); + nodes_.push_back(std::move(node)); + packet_by_node_.insert({&nodes_[packet.node_index_], packet}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return packet; } bool PacketSetManager::Contains(PacketSetHandle packet_set, @@ -483,14 +495,19 @@ absl::Status PacketSetManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - for (const auto& [node, packet] : packet_by_node_) { + // Invariant: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. 
+ for (const auto& [node_ptr, packet] : packet_by_node_) { RET_CHECK(packet.node_index_ < nodes_.size()); - RET_CHECK(nodes_[packet.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[packet.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - auto it = packet_by_node_.find(node); + // Look up both by pointer and by value (exercising the transparent + // functors used by `NodeToPacket`). + auto it = packet_by_node_.find(&node); + RET_CHECK(it != packet_by_node_.end()); + RET_CHECK(it->second == PacketSetHandle(i)); + it = packet_by_node_.find(node); RET_CHECK(it != packet_by_node_.end()); RET_CHECK(it->second == PacketSetHandle(i)); } diff --git a/netkat/packet_set.h b/netkat/packet_set.h index f11088f..696ae55 100644 --- a/netkat/packet_set.h +++ b/netkat/packet_set.h @@ -356,11 +356,38 @@ class PacketSetManager { // `And`, `Or`, `Not`). The class also avoids expensive relocations. PagedStableVector nodes_; + // Transparent hash and equality functors for the unique table + // (`packet_by_node_`), which is keyed by stable `DecisionNode*` pointers + // into `nodes_` (so each node is stored only once). Lookups work directly + // with a not-yet-interned `DecisionNode` value. Both functors are + // stateless: keys are pointers, and the pages holding the nodes are stable + // across moves of the manager. 
+ struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNode& node) const; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNode& b) const { + return *a == b; + } + bool operator()(const DecisionNode& a, const DecisionNode* b) const { + return a == *b; + } + }; + // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketSetHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. // - // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map packet_by_node_; + // INVARIANT: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map + packet_by_node_; // A map of a given `PredicateProto` to a `PacketSetHandle`. // diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..c093d63 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -14,8 +14,8 @@ #include "netkat/packet_transformer.h" +#include #include -#include #include #include #include @@ -28,7 +28,7 @@ #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -78,25 +78,132 @@ PacketTransformerManager::GetNodeOrDie( return nodes_[transformer.node_index_]; } -// TODO(dilo): Creating as many map copies as this method facilitates is -// probably going to cause terrible performance, and needs to be revisited. 
-absl::btree_map -PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); +template +bool PacketTransformerManager::ForEachFlatEntry(const DecisionNode& node, + MatchFn&& match, + ModifyFn&& modify) { + for (const auto& [match_value, end_offset] : node.matches) { + if (!match(match_value, end_offset)) return false; + } + for (const auto& [modify_value, branch] : node.modifies) { + if (!modify(modify_value, branch)) return false; + } + return true; +} - absl::btree_map result = - node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; +template +bool PacketTransformerManager::ForEachFlatEntry( + const DecisionNodeBuilder& builder, MatchFn&& match, ModifyFn&& modify) { + uint32_t end_offset = 0; + for (const auto& [match_value, modify_branch_by_value] : + builder.modify_branch_by_field_match) { + end_offset += modify_branch_by_value.size(); + if (!match(match_value, end_offset)) return false; + } + for (const auto& [match_value, modify_branch_by_value] : + builder.modify_branch_by_field_match) { + for (const auto& [modify_value, branch] : modify_branch_by_value) { + if (!modify(modify_value, branch)) return false; + } + } + for (const auto& [modify_value, branch] : + builder.default_branch_by_field_modification) { + if (!modify(modify_value, branch)) return false; + } + return true; +} - // Otherwise, add a mapping from `value` to the default branch, then return. 
- result[value] = node.default_branch; - return result; +PacketTransformerManager::DecisionNode PacketTransformerManager::Flatten( + DecisionNodeBuilder&& builder) { + size_t num_matches = 0; + size_t num_modifies = 0; + ForEachFlatEntry( + builder, + [&](int, uint32_t) { + ++num_matches; + return true; + }, + [&](int, PacketTransformerHandle) { + ++num_modifies; + return true; + }); + DecisionNode node{ + .field = builder.field, + .default_branch = builder.default_branch, + .matches{num_matches}, + .modifies{num_modifies}, + }; + size_t i = 0; + size_t j = 0; + ForEachFlatEntry( + builder, + [&](int match_value, uint32_t end_offset) { + node.matches[i++] = {match_value, end_offset}; + return true; + }, + [&](int modify_value, PacketTransformerHandle branch) { + node.modifies[j++] = {modify_value, branch}; + return true; + }); + return node; +} + +PacketTransformerManager::DecisionNodeBuilder +PacketTransformerManager::ToBuilder(const DecisionNode& node) { + DecisionNodeBuilder builder{ + .field = node.field, + .default_branch = node.default_branch, + }; + for (const DecisionNode::Match& match : node.Matches()) { + absl::btree_map& modify_branch_by_value = + builder.modify_branch_by_field_match[match.value]; + for (const DecisionNode::ModifyEntry& entry : match.modifies) { + modify_branch_by_value.insert(modify_branch_by_value.end(), entry); + } + } + for (const DecisionNode::ModifyEntry& entry : node.DefaultModifies()) { + builder.default_branch_by_field_modification.insert( + builder.default_branch_by_field_modification.end(), entry); + } + return builder; +} + +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNode* node) const { + return absl::HashOf(FlatSequenceView{*node}); +} + +size_t PacketTransformerManager::NodeHash::operator()( + const DecisionNodeBuilder& builder) const { + return absl::HashOf(FlatSequenceView{builder}); +} + +bool PacketTransformerManager::NodeEq::operator()( + const DecisionNode* a, const DecisionNodeBuilder& b) 
const { + if (a->field != b.field || a->default_branch != b.default_branch) { + return false; + } + // Walk b's canonical flat sequence and compare it element-wise against a's. + size_t i = 0; + size_t j = 0; + bool prefixes_equal = ForEachFlatEntry( + b, + [&](int match_value, uint32_t end_offset) { + return i < a->matches.size() && + a->matches[i++] == + std::pair(match_value, end_offset); + }, + [&](int modify_value, PacketTransformerHandle branch) { + return j < a->modifies.size() && + a->modifies[j++] == + DecisionNode::ModifyEntry(modify_value, branch); + }); + return prefixes_equal && i == a->matches.size() && j == a->modifies.size(); } // Canonicalizes a decision node and returns a transformer. PacketTransformerHandle PacketTransformerManager::NodeToTransformer( - DecisionNode&& node) { + DecisionNodeBuilder&& node) { // Remove any default branches pointing to Deny, saving the value. absl::flat_hash_set deny_values; for (const auto& [modify_value, branch] : @@ -138,9 +245,6 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( absl::flat_hash_set redundant_values; for (auto& [match_value, modification_map] : node.modify_branch_by_field_match) { - // TODO(dilo): Consider if this can make use of GetMapAtValue. Perhaps by - // calling it on a copy of DecisionNode without any - // `modify_branch_by_field_match` mappings? 
if (skip_default_branch || node.default_branch_by_field_modification.contains(match_value)) { // Compare the modification map to the default branch modification map, @@ -172,16 +276,20 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node.default_branch_by_field_modification.empty()) return node.default_branch; - auto [it, inserted] = transformer_by_node_.try_emplace( - node, PacketTransformerHandle(nodes_.size())); - if (inserted) { - nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) - << "Internal invariant violated: Proper and sentinel node indices must " - "be disjoint. This indicates that we allocated more nodes than are " - "supported (> 2^32 - 2)."; + // Look up the builder directly (without flattening) via the transparent + // `NodeHash`/`NodeEq` functors; only new nodes pay for flattening. + if (auto it = transformer_by_node_.find(node); + it != transformer_by_node_.end()) { + return it->second; } - return it->second; + PacketTransformerHandle transformer(nodes_.size()); + nodes_.push_back(Flatten(std::move(node))); + transformer_by_node_.insert({&nodes_[transformer.node_index_], transformer}); + LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + << "Internal invariant violated: Proper and sentinel node indices must " + "be disjoint. This indicates that we allocated more nodes than are " + "supported (> 2^32 - 2)."; + return transformer; } bool PacketTransformerManager::IsDeny( @@ -239,11 +347,11 @@ absl::flat_hash_set PacketTransformerManager::Run( if (initial_field_value.has_value()) { // If it exists, see if there is a value match for it and follow every // corresponding branch with value modified appropriately. 
- if (auto mod_map_it = - node.modify_branch_by_field_match.find(*initial_field_value); - mod_map_it != node.modify_branch_by_field_match.end()) { + if (std::optional match_index = + node.FindMatch(*initial_field_value); + match_index.has_value()) { matched = true; - for (const auto& [value, branch] : mod_map_it->second) { + for (const auto& [value, branch] : node.MatchModifies(*match_index)) { result.merge( RunWithNewValueThenReset(*this, branch, packet, field, value)); } @@ -255,8 +363,7 @@ absl::flat_hash_set PacketTransformerManager::Run( if (matched) return result; // Otherwise, follow the default branches. - for (const auto& [value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [value, branch] : node.DefaultModifies()) { // If the original packet already had this field with the same value as // this modified branch, then we should not also attempt the default // branch. @@ -337,7 +444,7 @@ PacketTransformerHandle PacketTransformerManager::FromPacketSetHandle( const PacketSetManager::DecisionNode& packet_node = packet_set_manager_.GetNodeOrDie(packet_set); - DecisionNode transformer_node{ + DecisionNodeBuilder transformer_node{ .field = packet_node.field, // This starts out empty and will be populated below. 
.modify_branch_by_field_match = {}, @@ -366,7 +473,7 @@ PacketTransformerHandle PacketTransformerManager::Filter( PacketTransformerHandle PacketTransformerManager::Modification( absl::string_view field, int value) { - return NodeToTransformer(DecisionNode{ + return NodeToTransformer(DecisionNodeBuilder{ .field = packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( field), .modify_branch_by_field_match = {}, @@ -376,160 +483,221 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, - const absl::btree_map& right, - absl::AnyInvocable - combiner, - PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); - } - } - for (const auto& [value, branch] : right) { - if (!result.contains(value)) - result[value] = combiner(default_value, branch); - } - return result; + +// Alias for the modify-map type of +// `PacketTransformerManager::DecisionNodeBuilder`, used to accumulate +// combinator results before interning. +using ModifyBranchMap = absl::btree_map; + +// The public entry types of the flat `PacketTransformerManager::DecisionNode` +// arrays (respelled here since the node type itself is private). +using ModifyEntry = std::pair; +using MatchEntry = std::pair; + +// Returns true iff `transformer` is the Deny transformer, which by documented +// contract is the default-constructed handle. (Unlike +// `PacketTransformerManager::IsDeny`, this is callable without a manager.) +bool IsDenyHandle(PacketTransformerHandle transformer) { + return transformer == PacketTransformerHandle(); } -} // namespace +// Returns true iff `entries` (sorted by modify value) contains an entry with +// the given modify value. 
+bool ContainsModifyValue(absl::Span entries, + int modify_value) { + auto it = std::lower_bound( + entries.begin(), entries.end(), modify_value, + [](const ModifyEntry& entry, int value) { return entry.first < value; }); + return it != entries.end() && it->first == modify_value; +} -PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Sequence( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); +// A cheap, non-owning view of a decision node as seen "at" a field no larger +// than the node's own field: either the node itself (if the node branches on +// that field), or the trivial expansion "for every value of the field, leave +// it unmodified and fall through to the node" (if the node branches on a +// strictly larger field, or is Accept). This lets the binary operations below +// combine operands with distinct fields without materializing — and then +// re-interning — the trivial expansion as a `DecisionNode`. +// +// `Node` is always `PacketTransformerManager::DecisionNode`; it is a template +// parameter only because that type is private and cannot be named here. +template +struct DecisionNodeView { + // The viewed node, or null for the trivial expansion. + const Node* node; + + // The "leave field unmodified" fall-through: the node's default branch, or + // the viewed transformer itself for the trivial expansion. + PacketTransformerHandle default_branch; + + // The (match value, end offset) headers; empty for the trivial expansion. + absl::Span Matches() const { + if (node == nullptr) return {}; + return absl::MakeConstSpan(node->matches); } - // left.field < right.field: Expand the right node, reducing to the - // inductive case. 
- if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Sequence(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); + // The default (modify value -> branch) entries; empty for the trivial + // expansion. + absl::Span DefaultModifies() const { + if (node == nullptr) return {}; + return node->DefaultModifies(); } +}; - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch = Sequence(left.default_branch, right.default_branch), +// Returns the view of `node` (the decision node of `transformer`, or null if +// `transformer` is Accept) at `field`, which must be <= `node->field`. +template +DecisionNodeView ViewAtField(PacketFieldHandle field, + PacketTransformerHandle transformer, + const Node* node) { + if (node != nullptr && node->field == field) { + return DecisionNodeView{ + .node = node, + .default_branch = node->default_branch, + }; + } + return DecisionNodeView{ + .node = nullptr, + .default_branch = transformer, }; +} - // Construct the possible results of applying the right node to packets - // gotten by taken default modification branches in the left node. 
- absl::btree_map - right_applied_to_left_modifications; - for (const auto& [value, branch] : - left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( - right_applied_to_left_modifications, right_at_value_with_sequence, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); +// Returns the smallest field branched on by the given decision nodes, at +// least one of which must be non-null (null encodes Accept, which branches on +// no field). See `DecisionNodeView` regarding the `Node` template parameter. +template +PacketFieldHandle SmallestField(const Node* left, const Node* right) { + DCHECK(left != nullptr || right != nullptr); + if (left == nullptr) return right->field; + if (right == nullptr) return left->field; + return std::min(left->field, right->field); +} + +// The two operands of a binary combinator, viewed at the smallest field +// branched on by either: an operand that is Accept, or branches on a strictly +// larger field, is viewed as a trivial node at that field. +template +struct OperandViews { + PacketFieldHandle field; + DecisionNodeView left; + DecisionNodeView right; +}; + +// Views the operands `left` and `right`, whose decision nodes are `left_node` +// and `right_node` (null encoding Accept; at least one must be non-null), at +// the smallest field branched on by either. 
+template +OperandViews ViewOperandsAtSmallestField(PacketTransformerHandle left, + const Node* left_node, + PacketTransformerHandle right, + const Node* right_node) { + const PacketFieldHandle field = SmallestField(left_node, right_node); + return OperandViews{ + .field = field, + .left = ViewAtField(field, left, left_node), + .right = ViewAtField(field, right, right_node), + }; +} + +// A cheap, non-owning view of a logical (modify value -> branch) map, +// represented as a base range of entries (sorted by modify value) plus at +// most one extra entry whose key must not be a key of the base range. +class ModifyBranchesView { + public: + using Entry = ModifyEntry; + + explicit ModifyBranchesView(absl::Span base) + : base_(base) {} + ModifyBranchesView(absl::Span base, Entry extra) + : base_(base), extra_(extra) { + DCHECK(!ContainsModifyValue(base, extra.first)); } - result_node.default_branch_by_field_modification = CombineModifyBranches( - right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); - - // Collect every value mapped in each node. 
- absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); - // An empty map is equivalent to a map with a single entry of - // , but the latter is not always canonical. However, an - // empty map won't work correctly for the merges below (an in fact, the - // whole for-loop would be skipped), so we expand it here if necessary. 
- if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + // Invokes `fn(modify_value, branch)` for each entry: the base entries in + // increasing key order, then the extra entry, if any. + template + void ForEach(Fn&& fn) const { + for (const auto& [value, branch] : base_) fn(value, branch); + if (extra_.has_value()) fn(extra_->first, extra_->second); + } + + std::optional Find(int value) const { + if (extra_.has_value() && extra_->first == value) return extra_->second; + auto it = std::lower_bound( + base_.begin(), base_.end(), value, + [](const ModifyEntry& entry, int v) { return entry.first < v; }); + if (it != base_.end() && it->first == value) return it->second; + return std::nullopt; + } + + bool empty() const { return base_.empty() && !extra_.has_value(); } + + private: + absl::Span base_; + std::optional extra_; +}; + +// Returns a view of the logical (modify value -> branch) map that `node` +// applies to packets whose field is equal to `value`: the matching entry of +// the node's match branches if there is one; otherwise the default +// modifications, plus the unmodified fall-through to `default_branch` (keyed +// by `value`, since the field keeps its value) unless that branch is Deny or +// shadowed by a default modification to `value`. 
+template +ModifyBranchesView ModifyBranchesAtValue(const DecisionNodeView& view, + int value) { + if (view.node != nullptr) { + if (std::optional match_index = view.node->FindMatch(value); + match_index.has_value()) { + return ModifyBranchesView(view.node->MatchModifies(*match_index)); } } + absl::Span defaults = view.DefaultModifies(); + if (ContainsModifyValue(defaults, value) || + IsDenyHandle(view.default_branch)) { + return ModifyBranchesView(defaults); + } + return ModifyBranchesView(defaults, {value, view.default_branch}); +} - return NodeToTransformer(std::move(result_node)); +// Combines two (modify value -> branch) maps key-wise into a new map, using +// `combiner(left_branch, right_branch)` for shared keys and substituting +// `default_value` for the missing side otherwise. +template +ModifyBranchMap CombineModifyBranches(const ModifyBranchesView& left, + const ModifyBranchesView& right, + Combiner&& combiner, + PacketTransformerHandle default_value) { + ModifyBranchMap result; + left.ForEach([&](int value, PacketTransformerHandle left_branch) { + // Keys arrive in near-sorted order, so the `end()` hint is mostly exact. + result.try_emplace( + result.end(), value, + combiner(left_branch, right.Find(value).value_or(default_value))); + }); + right.ForEach([&](int value, PacketTransformerHandle right_branch) { + if (left.Find(value).has_value()) return; + result.try_emplace(value, combiner(default_value, right_branch)); + }); + return result; } +// Returns the union of the keys of the given maps, sorted and deduplicated. +template +std::vector SortedUniqueKeys(const Maps&... maps) { + std::vector keys; + keys.reserve((maps.size() + ... 
+ 0)); + auto append_keys = [&keys](const auto& map) { + for (const auto& entry : map) keys.push_back(entry.first); + }; + (append_keys(maps), ...); + absl::c_sort(keys); + keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); + return keys; +} + +} // namespace + PacketTransformerHandle PacketTransformerManager::Sequence( PacketTransformerHandle left, PacketTransformerHandle right) { // Base cases. @@ -537,88 +705,109 @@ PacketTransformerHandle PacketTransformerManager::Sequence( if (IsAccept(left)) return right; if (IsAccept(right)) return left; - // If neither node is accept or deny, then sequence the nodes directly. - return Sequence(GetNodeOrDie(left), GetNodeOrDie(right)); -} + // Both operands are decision nodes. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, &GetNodeOrDie(left), right, &GetNodeOrDie(right)); -PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, - DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Union( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); + // Unions `branch` into `branches[modify_value]`, treating absence as Deny. + auto union_into = [this](ModifyBranchMap& branches, int modify_value, + PacketTransformerHandle branch) { + auto [it, inserted] = branches.try_emplace(modify_value, branch); + if (!inserted) it->second = Union(it->second, branch); + }; + + DecisionNodeBuilder result_node{ + .field = field, + .default_branch = + Sequence(left_view.default_branch, right_view.default_branch), + }; + + // Construct the possible results of applying the right node to packets + // obtained by taking the default modification branches in the left node... 
+ ModifyBranchMap& result_modifications = + result_node.default_branch_by_field_modification; + for (const auto& [value, branch] : left_view.DefaultModifies()) { + ModifyBranchesAtValue(right_view, value) + .ForEach([&, branch = branch](int modify_value, + PacketTransformerHandle right_branch) { + union_into(result_modifications, modify_value, + Sequence(branch, right_branch)); + }); + } + // ... and of applying the right node's default modifications to packets + // falling through the left node unmodified. + for (const auto& [modify_value, right_branch] : + right_view.DefaultModifies()) { + union_into(result_modifications, modify_value, + Sequence(left_view.default_branch, right_branch)); } - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Union(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(left_view.Matches(), right_view.Matches(), + left_view.DefaultModifies(), + right_view.DefaultModifies(), result_modifications)) { + ModifyBranchesView left_branches_at_value = + ModifyBranchesAtValue(left_view, value); + // An empty map is equivalent to a map with a single entry of + // (value, Deny), but the latter is not always canonical. However, an + // empty map won't work correctly for the merges below (and in fact, the + // whole for-loop would be skipped), so we expand it here if necessary. + if (left_branches_at_value.empty()) { + left_branches_at_value = ModifyBranchesView( + /*base=*/{}, /*extra=*/{value, Deny()}); + } + + // `value`s arrive in increasing order, so inserting at `end()` is O(1). 
+ ModifyBranchMap& result_branches = + result_node.modify_branch_by_field_match + .try_emplace(result_node.modify_branch_by_field_match.end(), value) + ->second; + left_branches_at_value.ForEach([&](int left_value, + PacketTransformerHandle left_branch) { + ModifyBranchesAtValue(right_view, left_value) + .ForEach([&](int modify_value, PacketTransformerHandle right_branch) { + union_into(result_branches, modify_value, + Sequence(left_branch, right_branch)); + }); + }); } - // left.field == right.field: branch on shared field. - DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, + return NodeToTransformer(std::move(result_node)); +} + +template +PacketTransformerHandle PacketTransformerManager::PointwiseCombine( + PacketTransformerHandle left, PacketTransformerHandle right, + Combiner&& combiner) { + // Neither operand is Deny and at most one is Accept, so at least one is a + // decision node. + const auto [field, left_view, right_view] = ViewOperandsAtSmallestField( + left, IsAccept(left) ? nullptr : &GetNodeOrDie(left), right, + IsAccept(right) ? nullptr : &GetNodeOrDie(right)); + + DecisionNodeBuilder result_node{ + .field = field, .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, + ModifyBranchesView(left_view.DefaultModifies()), + ModifyBranchesView(right_view.DefaultModifies()), combiner, /*default_value=*/Deny()), - .default_branch = Union(left.default_branch, right.default_branch), + .default_branch = + combiner(left_view.default_branch, right_view.default_branch), }; - // Collect every value in mapped in each node. 
- absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. - // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()); + // For every value mapped in either node, construct the proper new branch. + for (int value : + SortedUniqueKeys(left_view.Matches(), right_view.Matches(), + left_view.DefaultModifies(), + right_view.DefaultModifies())) { + // `value`s arrive in increasing order, so inserting at `end()` is O(1). 
+ result_node.modify_branch_by_field_match.try_emplace( + result_node.modify_branch_by_field_match.end(), value, + CombineModifyBranches(ModifyBranchesAtValue(left_view, value), + ModifyBranchesAtValue(right_view, value), + combiner, + /*default_value=*/Deny())); } return NodeToTransformer(std::move(result_node)); @@ -631,99 +820,11 @@ PacketTransformerHandle PacketTransformerManager::Union( if (IsDeny(right)) return left; if (IsDeny(left)) return right; - // If either node is accept, then expand it before merging. - if (IsAccept(left) || IsAccept(right)) { - const DecisionNode& other_node = - GetNodeOrDie(IsAccept(left) ? right : left); - return Union( - DecisionNode{ - .field = other_node.field, - .default_branch = Accept(), - }, - other_node); - } - - // If neither node is accept or deny, then union the nodes directly. - return Union(GetNodeOrDie(left), GetNodeOrDie(right)); -} - -PacketTransformerHandle PacketTransformerManager::Difference( - DecisionNode left, DecisionNode right) { - // left.field > right.field: Expand the left node, reducing to the inductive - // case. - if (left.field > right.field) { - PacketFieldHandle first_field = right.field; - return Difference( - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(left)), - }, - std::move(right)); - } - - // left.field < right.field: Expand the right node, reducing to the - // inductive case. - if (left.field < right.field) { - PacketFieldHandle first_field = left.field; - return Difference(std::move(left), - DecisionNode{ - .field = first_field, - .default_branch = NodeToTransformer(std::move(right)), - }); - } - - // left.field == right.field: branch on shared field. 
- DCHECK(left.field == right.field); - DecisionNode result_node{ - .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, - /*default_value=*/Deny()), - .default_branch = Difference(left.default_branch, right.default_branch), - }; - - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - - // For every value in mapped in each node, construct the proper new branch. 
- for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Difference(left, right); - }, - /*default_value=*/Deny()); - } - - return NodeToTransformer(std::move(result_node)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Difference( @@ -733,27 +834,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( if (IsDeny(left)) return Deny(); if (IsDeny(right)) return left; - // If either node is accept, then expand it before merging. - if (IsAccept(left)) { - const DecisionNode& right_node = GetNodeOrDie(right); - return Difference( - DecisionNode{ - .field = right_node.field, - .default_branch = Accept(), - }, - right_node); - } - - if (IsAccept(right)) { - const DecisionNode& left_node = GetNodeOrDie(left); - return Difference(left_node, DecisionNode{ - .field = left_node.field, - .default_branch = Accept(), - }); - } - - // If neither node is accept or deny, then difference the nodes directly. - return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); + return PointwiseCombine( + left, right, + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }); } PacketTransformerHandle PacketTransformerManager::Iterate( @@ -786,8 +871,7 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Case 1: Output packets that hit the default branch and got modified. // Implements the `b_A` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (const auto& [modify_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { add_to_output_by_field_value(modify_value, GetAllPossibleOutputPackets(branch)); } @@ -796,9 +880,8 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Implements the `b_B` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. absl::flat_hash_set branch_modify_values; - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const DecisionNode::Match& match : node.Matches()) { + for (const auto& [modify_value, branch] : match.modifies) { branch_modify_values.insert(modify_value); add_to_output_by_field_value(modify_value, GetAllPossibleOutputPackets(branch)); @@ -810,8 +893,9 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Implements the `b_C` in the `fwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. for (int modify_value : branch_modify_values) { - if (!node.modify_branch_by_field_match.contains(modify_value) && - !node.default_branch_by_field_modification.contains(modify_value)) { + if (!node.FindMatch(modify_value).has_value() && + !DecisionNode::ContainsModifyValue(node.DefaultModifies(), + modify_value)) { add_to_output_by_field_value(modify_value, default_output); } } @@ -819,7 +903,7 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Case 4: Output packets that got matched on an explicit branch, but did // not get modified. Implements the `b_D` in the `fwd` function in section // C.3 Push and Pull in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (auto& [match_value, unused] : node.modify_branch_by_field_match) { + for (const auto& [match_value, unused_end_offset] : node.matches) { if (!branch_modify_values.contains(match_value)) { add_to_output_by_field_value(match_value, PacketSetManager().EmptySet()); } @@ -871,8 +955,7 @@ PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( // Implements the `d'` in the `bwd` function in section C.3 Push and Pull in // KATch: A Fast Symbolic Verifier for NetKAT. PacketSetHandle default_branch_output_packets; - for (const auto& [modify_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { default_branch_output_packets = packet_set_manager_.Or(default_branch_output_packets, GetAllInputPacketsThatProduceAnyOutput(branch)); @@ -882,23 +965,21 @@ PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( // Implements the `b_A` in the `bwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. absl::flat_hash_map branch_by_field_value_map; - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { + for (const DecisionNode::Match& match : node.Matches()) { PacketSetHandle union_of_branches; - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const auto& [modify_value, branch] : match.modifies) { union_of_branches = packet_set_manager_.Or( union_of_branches, GetAllInputPacketsThatProduceAnyOutput(branch)); } - branch_by_field_value_map[match_value] = union_of_branches; + branch_by_field_value_map[match.value] = union_of_branches; } // Case 3: Input packets that do not get matched on an explicit branch, but // do get modified. // Implements the `b_B` in the `bwd` function in section C.3 Push and Pull // in KATch: A Fast Symbolic Verifier for NetKAT. 
- for (const auto& [modify_value, unused] : - node.default_branch_by_field_modification) { - if (!node.modify_branch_by_field_match.contains(modify_value)) { + for (const auto& [modify_value, unused_branch] : node.DefaultModifies()) { + if (!node.FindMatch(modify_value).has_value()) { branch_by_field_value_map[modify_value] = default_branch_output_packets; } } @@ -946,7 +1027,7 @@ std::string PacketTransformerManager::ToString(const DecisionNode& node) const { auto pretty_print_map = [&](absl::string_view field, - const absl::btree_map& map) { + absl::Span map) { for (const auto& [value, branch] : map) { absl::StrAppendFormat(&result, " %s := %d -> %v\n", field, value, branch); @@ -959,12 +1040,12 @@ std::string PacketTransformerManager::ToString(const DecisionNode& node) const { absl::CEscape( packet_set_manager_.field_manager_.GetFieldName(node.field))); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - absl::StrAppendFormat(&result, " %s == %d:\n", field, value); - pretty_print_map(field, modify_map); + for (const DecisionNode::Match& match : node.Matches()) { + absl::StrAppendFormat(&result, " %s == %d:\n", field, match.value); + pretty_print_map(field, match.modifies); } absl::StrAppendFormat(&result, " %s == *:\n", field); - pretty_print_map(field, node.default_branch_by_field_modification); + pretty_print_map(field, node.DefaultModifies()); PacketTransformerHandle fallthrough = node.default_branch; absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); if (!IsAccept(fallthrough) && !IsDeny(fallthrough)) @@ -986,7 +1067,7 @@ std::string PacketTransformerManager::ToString( auto pretty_print_map = [&](absl::string_view field, - const absl::btree_map& map) { + absl::Span map) { for (const auto& [value, branch] : map) { absl::StrAppendFormat(&result, " %s := %d -> %v\n", field, value, branch); @@ -1008,12 +1089,12 @@ std::string PacketTransformerManager::ToString( "%v:'%s'", node.field, absl::CEscape( 
packet_set_manager_.field_manager_.GetFieldName(node.field))); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - absl::StrAppendFormat(&result, " %s == %d:\n", field, value); - pretty_print_map(field, modify_map); + for (const DecisionNode::Match& match : node.Matches()) { + absl::StrAppendFormat(&result, " %s == %d:\n", field, match.value); + pretty_print_map(field, match.modifies); } absl::StrAppendFormat(&result, " %s == *:\n", field); - pretty_print_map(field, node.default_branch_by_field_modification); + pretty_print_map(field, node.DefaultModifies()); PacketTransformerHandle fallthrough = node.default_branch; absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); if (IsAccept(fallthrough) || IsDeny(fallthrough)) continue; @@ -1063,13 +1144,14 @@ std::string PacketTransformerManager::ToDot( packet_set_manager_.field_manager_.GetFieldName(node.field); absl::StrAppendFormat(&result, " %d [label=\"%s\"]\n", transformer.node_index_, field); - for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { - if (modify_map.empty()) { + for (const DecisionNode::Match& match : node.Matches()) { + int value = match.value; + if (match.modifies.empty()) { absl::StrAppendFormat(&result, " %d -> %d [label=\"%s==%s\"]\n", transformer.node_index_, SentinelNodeIndex::kDeny, field, absl::StrCat(value)); } - for (const auto& [new_value, branch] : modify_map) { + for (const auto& [new_value, branch] : match.modifies) { absl::StrAppendFormat(&result, " %d -> %d [label=\"%s==%s; %s:=%d\"]\n", transformer.node_index_, branch.node_index_, @@ -1080,8 +1162,7 @@ std::string PacketTransformerManager::ToDot( } } - for (const auto& [new_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [new_value, branch] : node.DefaultModifies()) { absl::StrAppendFormat( &result, " %d -> %d [label=\"%s:=%d\" style=dashed]\n", transformer.node_index_, branch.node_index_, field, new_value); @@ -1104,15 +1185,24 @@ 
absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); - // Invariant: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == - // n`. - for (const auto& [node, transformer] : transformer_by_node_) { + // Invariant: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + for (const auto& [node_ptr, transformer] : transformer_by_node_) { RET_CHECK(transformer.node_index_ < nodes_.size()); - RET_CHECK(nodes_[transformer.node_index_] == node); + RET_CHECK(node_ptr == &nodes_[transformer.node_index_]); } for (int i = 0; i < nodes_.size(); ++i) { - const DecisionNode& node = nodes_[i]; - auto it = transformer_by_node_.find(node); + auto it = transformer_by_node_.find(&nodes_[i]); + RET_CHECK(it != transformer_by_node_.end()); + RET_CHECK(it->second == PacketTransformerHandle(i)); + } + + // Invariant: `NodeHash` and `NodeEq` treat a builder and its flattened + // node identically, as required for transparent unique table lookups. + for (int i = 0; i < nodes_.size(); ++i) { + DecisionNodeBuilder builder = ToBuilder(nodes_[i]); + RET_CHECK(NodeHash()(builder) == NodeHash()(&nodes_[i])); + RET_CHECK(NodeEq()(&nodes_[i], builder)); + auto it = transformer_by_node_.find(builder); RET_CHECK(it != transformer_by_node_.end()); RET_CHECK(it->second == PacketTransformerHandle(i)); } @@ -1120,11 +1210,39 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Node Invariants. for (int i = 0; i < nodes_.size(); ++i) { const DecisionNode& node = nodes_[i]; - // Invariant: `modify_branch_by_field_match` or - // `default_branch_by_field_modification` is non-empty. + // Invariant: `matches` or `DefaultModifies()` is non-empty. // Maintained by `NodeToTransformer`. 
- RET_CHECK(!node.modify_branch_by_field_match.empty() || - !node.default_branch_by_field_modification.empty()); + RET_CHECK(!node.matches.empty() || !node.modifies.empty()); + + // Invariants of the flat encoding: match values strictly increase, end + // offsets are monotone and bounded, and each ModifyEntry range is sorted + // by strictly increasing modify value. + uint32_t previous_end_offset = 0; + for (const auto& [match_value, end_offset] : node.matches) { + RET_CHECK(end_offset >= previous_end_offset) << ":\n" << ToString(node); + RET_CHECK(end_offset <= node.modifies.size()) << ":\n" << ToString(node); + previous_end_offset = end_offset; + } + for (size_t j = 1; j < node.matches.size(); ++j) { + RET_CHECK(node.matches[j - 1].first < node.matches[j].first) + << ":\n" + << ToString(node); + } + auto is_strictly_sorted_by_value = + [](absl::Span entries) { + for (size_t j = 1; j < entries.size(); ++j) { + if (entries[j - 1].first >= entries[j].first) return false; + } + return true; + }; + for (const DecisionNode::Match& match : node.Matches()) { + RET_CHECK(is_strictly_sorted_by_value(match.modifies)) + << ":\n" + << ToString(node); + } + RET_CHECK(is_strictly_sorted_by_value(node.DefaultModifies())) + << ":\n" + << ToString(node); // Invariant: node field is strictly smaller than sub-node fields. RET_CHECK(IsAccept(node.default_branch) || IsDeny(node.default_branch) || @@ -1132,12 +1250,11 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { << ":\n" << ToString(node); - for (const auto& [match_value, branch_by_modify] : - node.modify_branch_by_field_match) { - for (const auto& [modify_value, branch] : branch_by_modify) { + for (const DecisionNode::Match& match : node.Matches()) { + for (const auto& [modify_value, branch] : match.modifies) { // Invariant: Modify branches are not Deny unless `modify_value == - // match_value`. - RET_CHECK(!IsDeny(branch) || modify_value == match_value) + // match.value`. 
+ RET_CHECK(!IsDeny(branch) || modify_value == match.value) << ":\n" << ToString(node); @@ -1149,8 +1266,7 @@ absl::Status PacketTransformerManager::CheckInternalInvariants() const { } } - for (const auto& [match_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { // Invariant: Default modify branches are not Deny. RET_CHECK(!IsDeny(branch)); diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..531a564 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -37,17 +37,21 @@ #ifndef GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_H_ #define GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_H_ +#include #include #include +#include #include #include #include "absl/container/btree_map.h" +#include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "netkat/netkat.pb.h" #include "netkat/packet.h" #include "netkat/packet_field.h" @@ -306,7 +310,21 @@ class PacketTransformerManager { // non-deterministically set field -> value_d_1 then branch_d_1 // non-deterministically set field -> value_d_2 then branch_d_2 // non-deterministically LEAVE field UNMODIFIED then default_branch + // + // CHOICE OF DATA STRUCTURE: + // Logically, a node is a map of maps, match value -> (modify value -> + // branch), plus a map of default modifications, modify value -> branch. We + // store all of this in two flat, sorted arrays to optimize memory layout + // (contiguous, compact, flat), exploiting that nodes are immutable once + // interned. This makes the hashing, equality comparison, and copying done + // by the unique table (`transformer_by_node_`) cheap scans over contiguous + // memory, and shrinks node storage. 
A mutation-friendly map-based + // representation exists separately as `DecisionNodeBuilder`, used only + // transiently while constructing nodes. struct DecisionNode { + // A single "set field to `first`, then continue with `second`" entry. + using ModifyEntry = std::pair; + // The packet field whose value this decision node branches on. // // INVARIANTS: @@ -315,54 +333,217 @@ class PacketTransformerManager { // * Interned by `field_manager_`. PacketFieldHandle field; - // The "if" branches of the decision node, "keyed" by the value they branch - // on. Each element of the map is a (match_value, Map)-pair encoding - // "if (field == match_value) then non-deterministically choose a - // (modify_value, branch) pair from `Map`, modify field to modify_value and - // follow branch". + // The "leave field unmodified" consequent of the "else" branch. + PacketTransformerHandle default_branch; + + // The "if" branches of the decision node. `matches[i]` is a + // (match_value, end_offset) pair encoding "if (field == match_value) then + // non-deterministically choose a ModifyEntry from `MatchModifies(i)`, + // modify field to its modify value and follow its branch". The end offsets + // delimit the per-match ranges of `modifies`: `MatchModifies(i)` is + // `modifies[matches[i-1].end_offset, matches[i].end_offset)`. + // A match with an empty ModifyEntry range denies all packets whose field + // equals its match value. // // INVARIANTS: - // 1. Maintained by `NodeToTransformer`: `modify_branch_by_field_match` and - // `default_branch_by_field_modification` below are not both empty. - // (If they were both empty, the decision node gets replaced by - // `default_branch`.) - // 2. For every v, v', and b such that (v,(v',b)) is in - // `modify_branch_by_field_match`, either v == v' or b is not Deny. 
- absl::btree_map> - modify_branch_by_field_match; - - // The "else" branch of this decision node, "keyed" by the value they modify - // the field to (or not keyed at all for the `default_branch`). + // 1. Maintained by `NodeToTransformer`: `matches` and `DefaultModifies()` + // are not both empty. (If they were both empty, the decision node gets + // replaced by `default_branch`.) + // 2. For every entry (v', b) in `MatchModifies(i)` with match value v, + // either v == v' or b is not Deny. + // 3. Sorted by strictly increasing match value; end offsets are + // non-decreasing and bounded by `modifies.size()`. + absl::FixedArray, + /*use_heap_allocation_above_size=*/0> + matches; + + // The ModifyEntry ranges of all matches, in match order, followed by the + // "else" modifications (see `DefaultModifies()`): entries encoding "if no + // match value applies, non-deterministically set field -> entry.first and + // follow entry.second". // // INVARIANTS: - // 1. For every v and b such that (v,b) is in - // `default_branch_by_field_modification`, b is not Deny. - absl::btree_map - default_branch_by_field_modification; - PacketTransformerHandle default_branch; + // 1. Each per-match range and the default range is sorted by strictly + // increasing modify value, without duplicates. + // 2. For every entry (v, b) in `DefaultModifies()`, b is not Deny. + absl::FixedArray + modifies; + + // The ModifyEntry range of `matches[i]`. + absl::Span MatchModifies(size_t i) const { + uint32_t begin = i == 0 ? 0 : matches[i - 1].second; + return absl::MakeConstSpan(modifies.data() + begin, + matches[i].second - begin); + } - // Protect against regressions in memory layout, as it affects performance. - static_assert(sizeof(modify_branch_by_field_match) == 24); - static_assert(sizeof(default_branch_by_field_modification) == 24); + // A single "if (field == value)" branch: the match value together with + // its ModifyEntry range. 
+ struct Match { + int value; + absl::Span modifies; + }; + + // Iterates the "if" branches as `Match` views, in order of strictly + // increasing match value. Allows range-for loops over the branches + // without manual index bookkeeping. + class MatchIterator { + public: + MatchIterator(const DecisionNode* node, size_t index) + : node_(node), index_(index) {} + Match operator*() const { + return {node_->matches[index_].first, node_->MatchModifies(index_)}; + } + MatchIterator& operator++() { + ++index_; + return *this; + } + friend bool operator==(const MatchIterator& a, + const MatchIterator& b) = default; + + private: + const DecisionNode* node_; + size_t index_; + }; + struct MatchRange { + const DecisionNode* node; + MatchIterator begin() const { return {node, 0}; } + MatchIterator end() const { return {node, node->matches.size()}; } + }; + MatchRange Matches() const { return {this}; } + + // The ModifyEntry range of the "else" branch. + absl::Span DefaultModifies() const { + uint32_t begin = matches.empty() ? 0 : matches.back().second; + return absl::MakeConstSpan(modifies.data() + begin, + modifies.size() - begin); + } + + // Returns the index into `matches` with the given match value, if any. + std::optional FindMatch(int match_value) const { + auto it = std::lower_bound( + matches.begin(), matches.end(), match_value, + [](const auto& match, int value) { return match.first < value; }); + if (it == matches.end() || it->first != match_value) return std::nullopt; + return it - matches.begin(); + } + + // Returns true iff `entries` (sorted by modify value) contains an entry + // with the given modify value. 
+ static bool ContainsModifyValue(absl::Span entries, + int modify_value) { + auto it = std::lower_bound(entries.begin(), entries.end(), modify_value, + [](const ModifyEntry& entry, int value) { + return entry.first < value; + }); + return it != entries.end() && it->first == modify_value; + } friend auto operator<=>(const DecisionNode& a, const DecisionNode& b) = default; - // Hashing, see https://abseil.io/docs/cpp/guides/hash. - template - friend H AbslHashValue(H h, const DecisionNode& node) { - return H::combine(std::move(h), node.field, node.default_branch, - node.default_branch_by_field_modification, - node.modify_branch_by_field_match); - } + // NOTE: Hashing is deliberately NOT defined on this struct. The unique + // table must hash flat nodes and `DecisionNodeBuilder`s identically, so + // there is a single hash definition for both: `NodeHash`. }; // Protect against regressions in memory layout, as it affects performance. - // TODO(dilo): Is this still important with this simpler data structure, or - // should we remove it until we optimize? - static_assert(sizeof(DecisionNode) == 64); + static_assert(sizeof(DecisionNode) == 40); static_assert(alignof(DecisionNode) == 8); + // A mutable, map-based representation of a `DecisionNode`, used only + // transiently while constructing nodes (by the combinators and the golden + // test runner's canonicalizing copy). Finished builders are canonicalized, + // flattened into `DecisionNode`s, and interned by `NodeToTransformer`. The + // members mirror `DecisionNode`; see there for semantics and invariants. + struct DecisionNodeBuilder { + PacketFieldHandle field; + + // Match value -> (modify value -> branch). See `DecisionNode::matches`. + absl::btree_map> + modify_branch_by_field_match; + + // Modify value -> branch. See `DecisionNode::DefaultModifies()`. 
+ absl::btree_map + default_branch_by_field_modification; + PacketTransformerHandle default_branch; + }; + + // Invokes `match(match_value, end_offset)` for each match header and then + // `modify(modify_value, branch)` for each modify entry of the given node or + // builder, in the canonical flat order of `DecisionNode::matches` and + // `DecisionNode::modifies`. Stops and returns false as soon as a callback + // returns false; returns true if all elements were visited. + // + // This is the single definition of a node's flat element sequence: + // `Flatten`, `NodeHash`, and `NodeEq` are all written against it, which + // keeps the two node representations consistent by construction. + template + static bool ForEachFlatEntry(const DecisionNode& node, MatchFn&& match, + ModifyFn&& modify); + template + static bool ForEachFlatEntry(const DecisionNodeBuilder& builder, + MatchFn&& match, ModifyFn&& modify); + + // Transparent hash and equality functors for the unique table + // (`transformer_by_node_`), which is keyed by stable `DecisionNode*` + // pointers into `nodes_` (so each node is stored only once). Lookups work + // directly with a `DecisionNodeBuilder` — without flattening it — keeping + // the hot path of `NodeToTransformer`, re-deriving a node that already + // exists, free of allocations; flattening only happens for genuinely new + // nodes. Both functors are stateless: keys are pointers, and the pages + // holding the nodes are stable across moves of the manager. + // + // INVARIANT: A builder and its flattened node are treated identically: + // `NodeHash()(b) == NodeHash()(&Flatten(b))` and `NodeEq()(&Flatten(b), b)`. + // Maintained by defining both functors in terms of `ForEachFlatEntry`; + // checked by `CheckInternalInvariants`. 
+ struct NodeHash { + using is_transparent = void; + size_t operator()(const DecisionNode* node) const; + size_t operator()(const DecisionNodeBuilder& builder) const; + + private: + // Adapter implementing both overloads: hashes the canonical flat element + // sequence (via `ForEachFlatEntry`) in a single streaming pass, so flat + // nodes and builders with the same logical content hash identically. + template + struct FlatSequenceView { + const NodeOrBuilder& node; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, const FlatSequenceView& view) { + size_t num_matches = 0; + size_t num_modifies = 0; + h = H::combine(std::move(h), view.node.field, + view.node.default_branch); + ForEachFlatEntry( + view.node, + [&](int match_value, uint32_t end_offset) { + h = H::combine(std::move(h), match_value, end_offset); + ++num_matches; + return true; + }, + [&](int modify_value, PacketTransformerHandle branch) { + h = H::combine(std::move(h), modify_value, branch); + ++num_modifies; + return true; + }); + return H::combine(std::move(h), num_matches, num_modifies); + } + }; + }; + struct NodeEq { + using is_transparent = void; + bool operator()(const DecisionNode* a, const DecisionNode* b) const { + return a == b || *a == *b; + } + bool operator()(const DecisionNode* a, const DecisionNodeBuilder& b) const; + bool operator()(const DecisionNodeBuilder& a, const DecisionNode* b) const { + return (*this)(b, a); + } + }; + // A key for efficiently hashing a `PolicyProto` to a // `PacketTransformerHandle`. This works as a recursive hash, such that we // only internally compile unique messages exactly once. 
@@ -383,11 +564,12 @@ class PacketTransformerManager { template friend H AbslHashValue(H h, const ProtoHashKey& key) { - return H::combine(std::move(h), key.lhs_child, key.rhs_child); + return H::combine(std::move(h), key.policy_case, key.lhs_child, + key.rhs_child); } }; - PacketTransformerHandle NodeToTransformer(DecisionNode&& node); + PacketTransformerHandle NodeToTransformer(DecisionNodeBuilder&& node); // Returns the `DecisionNode` corresponding to the given // `PacketTransformerHandle`, or crashes if the `transformer` is @@ -404,17 +586,22 @@ class PacketTransformerManager { // enough to avoid excessive memory overhead. static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); - // Helper functions to deal with DecisionNodes directly. - // TODO(dilo): Is there a convenient way to either avoid these or avoid making - // copies of the nodes? - PacketTransformerHandle Union(DecisionNode left, DecisionNode right); - PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); - PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); - - // Internal helper function to get a map of possible modification values to - // branches for a given input value at `node`. - absl::btree_map GetMapAtValue( - const DecisionNode& node, int value); + // Shared implementation of `Union` and `Difference`, which differ only in + // their base cases and the operation applied to corresponding branches: + // combines `left` and `right` by applying `combiner` — which must be the + // handle-level operation itself, e.g. `Union` — to corresponding branches. + // Both operands must be Accept or decision nodes (i.e., the base cases must + // already have been handled). + template + PacketTransformerHandle PointwiseCombine(PacketTransformerHandle left, + PacketTransformerHandle right, + Combiner&& combiner); + + // Conversions between the interned (flat) and builder (map-based) + // representations of decision nodes. 
`ToBuilder` is only used by + // `CheckInternalInvariants`, to validate `NodeHash`/`NodeEq` consistency. + static DecisionNode Flatten(DecisionNodeBuilder&& builder); + static DecisionNodeBuilder ToBuilder(const DecisionNode& node); // The decision nodes forming the BDD-style DAG representation of packets. // `PacketTransformerHandle::node_index_` indexes into this vector. @@ -426,9 +613,14 @@ class PacketTransformerManager { // A so called "unique table" to ensure each node is only added to `nodes_` // once, and thus has a unique `PacketTransformerHandle::node_index`. + // Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so + // nodes are not stored twice. The transparent `NodeHash`/`NodeEq` functors + // support lookup by `DecisionNodeBuilder` without flattening, see their + // documentation. // - // INVARIANT: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. - absl::flat_hash_map + // INVARIANT: `transformer_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`. + absl::flat_hash_map transformer_by_node_; // A map of a given `PolicyProto` to a `PacketTransformerHandle`. diff --git a/netkat/packet_transformer_benchmark.cc b/netkat/packet_transformer_benchmark.cc index d927ffc..30d1a01 100644 --- a/netkat/packet_transformer_benchmark.cc +++ b/netkat/packet_transformer_benchmark.cc @@ -18,6 +18,7 @@ #include "benchmark/benchmark.h" #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_set.h" #include "netkat/packet_transformer.h" namespace netkat { @@ -99,6 +100,31 @@ void BM_FirstTimeCompileOverlappingPolicy(benchmark::State& state) { } BENCHMARK(BM_FirstTimeCompileOverlappingPolicy); +// Benchmarks the read-path operations that the analysis engine is built on: +// pushing/pulling packet sets through an already-compiled transformer. 
+// After the first iteration all nodes exist, so steady-state iterations +// exercise DAG traversal and unique-table hits rather than first-time node +// creation. +void BM_PushAndPullFullSetThroughPolicy(benchmark::State& state) { + PacketTransformerManager manager; + PolicyProto policy = CreateFixedArbitraryPolicyProto(0); + // NOTE: Without memoization of the recursive transformer operations, + // Push/Pull cost grows exponentially with policy size; keep the number of + // unioned sub-policies small so the benchmark stays tractable. + for (int i = 1; i < 2; ++i) { + policy = UnionProto( + policy, SequenceProto(CreateFixedArbitraryPolicyProto(i), + CreateFixedArbitraryPolicyProto(i + 8))); + } + PacketTransformerHandle transformer = manager.Compile(policy); + PacketSetHandle full_set = manager.GetPacketSetManager().FullSet(); + for (auto s : state) { + benchmark::DoNotOptimize(manager.Push(full_set, transformer)); + benchmark::DoNotOptimize(manager.Pull(transformer, full_set)); + } +} +BENCHMARK(BM_PushAndPullFullSetThroughPolicy); + // Benchmarks the cost of compiling a policy, with overlapping substructures, // that has already been compiled once before. Excludes the initial cost of // compilation. diff --git a/netkat/packet_transformer_test.cc b/netkat/packet_transformer_test.cc index 0e7b0fc..061d213 100644 --- a/netkat/packet_transformer_test.cc +++ b/netkat/packet_transformer_test.cc @@ -968,9 +968,8 @@ class PacketTransformerManagerTestPeer { value); }; // Case 1: Output from explicit match+modify branches. 
- for (const auto& [match_value, branch_by_modify_value] : - node.modify_branch_by_field_match) { - for (const auto& [modify_value, branch] : branch_by_modify_value) { + for (const auto& match : node.Matches()) { + for (const auto& [modify_value, branch] : match.modifies) { add_to_output( and_fn(match_fn(field, modify_value), GetAllPossibleOutputPacketsReferenceImplementation(branch))); @@ -978,8 +977,7 @@ class PacketTransformerManagerTestPeer { } // Case 2: Output from default-modify branches. - for (const auto& [modify_value, branch] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, branch] : node.DefaultModifies()) { add_to_output( and_fn(match_fn(field, modify_value), GetAllPossibleOutputPacketsReferenceImplementation(branch))); @@ -995,13 +993,11 @@ class PacketTransformerManagerTestPeer { // output.field != match_value for all explicit match branches by (0) // 2. output.field != modify_value for all default-modify branches PacketSetHandle fallthrough_output = PacketSetManager().FullSet(); - for (const auto& [match_value, unused] : - node.modify_branch_by_field_match) { + for (const auto& [match_value, unused_end_offset] : node.matches) { fallthrough_output = and_fn(fallthrough_output, not_fn(match_fn(field, match_value))); } - for (const auto& [modify_value, unused] : - node.default_branch_by_field_modification) { + for (const auto& [modify_value, unused_branch] : node.DefaultModifies()) { fallthrough_output = and_fn(fallthrough_output, not_fn(match_fn(field, modify_value))); } diff --git a/netkat/packet_transformer_test.expected b/netkat/packet_transformer_test.expected index 9f01d94..ca6b248 100644 --- a/netkat/packet_transformer_test.expected +++ b/netkat/packet_transformer_test.expected @@ -42,22 +42,22 @@ digraph { Test case: p := (a=5 + b=2);(b:=1 + c=5). Example from Katch paper Fig 5. 
================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<8>: +PacketTransformerHandle<3>: PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<6> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<1> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<7> -PacketTransformerHandle<6>: + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<1>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<5> -PacketTransformerHandle<7>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<0> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<5> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<5>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 5 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -68,50 +68,50 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 8 [label="a"] - 8 -> 6 [label="a==5; a:=5"] - 8 -> 7 [style=dashed] - 6 [label="b"] - 6 -> 4294967294 [label="b:=1" style=dashed] - 6 -> 5 [style=dashed] - 7 [label="b"] - 7 -> 4294967294 [label="b==2; b:=1"] - 7 -> 5 [label="b==2; b:=2"] - 7 -> 4294967295 [style=dashed] - 5 [label="c"] - 5 -> 4294967294 [label="c==5; c:=5"] - 5 -> 4294967295 [style=dashed] + 3 [label="a"] + 3 -> 1 [label="a==5; a:=5"] + 3 -> 2 [style=dashed] + 1 [label="b"] + 1 -> 4294967294 [label="b:=1" style=dashed] + 1 -> 0 [style=dashed] + 2 [label="b"] + 2 -> 
4294967294 [label="b==2; b:=1"] + 2 -> 0 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c==5; c:=5"] + 0 -> 4294967295 [style=dashed] } ================================================================================ Test case: q := (b=1 + c:=4 + a:=5;b:=1). Example from Katch paper Fig 5. ================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<17>: +PacketTransformerHandle<5>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<14> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<16> -PacketTransformerHandle<14>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<4> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<4>: + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<16>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == 1: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<10> -PacketTransformerHandle<13>: + PacketFieldHandle<1>:'b' == * 
-> PacketTransformerHandle<1> +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<10>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle @@ -121,64 +121,64 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 17 [label="a"] - 17 -> 14 [label="a==1; a:=1"] - 17 -> 4 [label="a:=1" style=dashed] - 17 -> 16 [style=dashed] - 14 [label="b"] - 14 -> 13 [label="b==1; b:=1"] - 14 -> 4294967294 [label="b:=1" style=dashed] - 14 -> 10 [style=dashed] + 5 [label="a"] + 5 -> 2 [label="a==1; a:=1"] + 5 -> 3 [label="a:=1" style=dashed] + 5 -> 4 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==1; b:=1"] + 2 -> 4294967294 [label="b:=1" style=dashed] + 2 -> 1 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 16 [label="b"] - 16 -> 13 [label="b==1; b:=1"] - 16 -> 10 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 10 [label="c"] - 10 -> 4294967294 [label="c:=4" style=dashed] - 10 -> 4294967295 [style=dashed] + 4 -> 0 [label="b==1; b:=1"] + 4 -> 1 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 [style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c:=4" style=dashed] + 1 -> 4294967295 [style=dashed] } ================================================================================ Test case: p;q. Example from Katch paper Fig 5. 
================================================================================ -- STRING ---------------------------------------------------------------------- -PacketTransformerHandle<22>: +PacketTransformerHandle<6>: PacketFieldHandle<0>:'a' == 1: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<19> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<2> PacketFieldHandle<0>:'a' == 5: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<4> - PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<21> + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<3> + PacketFieldHandle<0>:'a' := 5 -> PacketTransformerHandle<4> PacketFieldHandle<0>:'a' == *: - PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<20> - PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<19> -PacketTransformerHandle<19>: + PacketFieldHandle<0>:'a' := 1 -> PacketTransformerHandle<5> + PacketFieldHandle<0>:'a' == * -> PacketTransformerHandle<2> +PacketTransformerHandle<2>: PacketFieldHandle<1>:'b' == 2: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<18> + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' := 2 -> PacketTransformerHandle<1> PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<4>: +PacketTransformerHandle<3>: PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<21>: +PacketTransformerHandle<4>: PacketFieldHandle<1>:'b' == *: - PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<13> - PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<18> -PacketTransformerHandle<20>: + PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle<0> + PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle<1> +PacketTransformerHandle<5>: PacketFieldHandle<1>:'b' == 2: 
PacketFieldHandle<1>:'b' := 1 -> PacketTransformerHandle PacketFieldHandle<1>:'b' == *: PacketFieldHandle<1>:'b' == * -> PacketTransformerHandle -PacketTransformerHandle<13>: +PacketTransformerHandle<0>: PacketFieldHandle<2>:'c' == *: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == * -> PacketTransformerHandle -PacketTransformerHandle<18>: +PacketTransformerHandle<1>: PacketFieldHandle<2>:'c' == 5: PacketFieldHandle<2>:'c' := 4 -> PacketTransformerHandle PacketFieldHandle<2>:'c' == *: @@ -189,29 +189,29 @@ digraph { edge [fontsize = 12] 4294967294 [label="T" shape=box] 4294967295 [label="F" shape=box] - 22 [label="a"] - 22 -> 19 [label="a==1; a:=1"] - 22 -> 4 [label="a==5; a:=1"] - 22 -> 21 [label="a==5; a:=5"] - 22 -> 20 [label="a:=1" style=dashed] - 22 -> 19 [style=dashed] - 19 [label="b"] - 19 -> 13 [label="b==2; b:=1"] - 19 -> 18 [label="b==2; b:=2"] - 19 -> 4294967295 [style=dashed] + 6 [label="a"] + 6 -> 2 [label="a==1; a:=1"] + 6 -> 3 [label="a==5; a:=1"] + 6 -> 4 [label="a==5; a:=5"] + 6 -> 5 [label="a:=1" style=dashed] + 6 -> 2 [style=dashed] + 2 [label="b"] + 2 -> 0 [label="b==2; b:=1"] + 2 -> 1 [label="b==2; b:=2"] + 2 -> 4294967295 [style=dashed] + 3 [label="b"] + 3 -> 4294967294 [label="b:=1" style=dashed] + 3 -> 4294967295 [style=dashed] 4 [label="b"] - 4 -> 4294967294 [label="b:=1" style=dashed] - 4 -> 4294967295 [style=dashed] - 21 [label="b"] - 21 -> 13 [label="b:=1" style=dashed] - 21 -> 18 [style=dashed] - 20 [label="b"] - 20 -> 4294967294 [label="b==2; b:=1"] - 20 -> 4294967295 [style=dashed] - 13 [label="c"] - 13 -> 4294967294 [label="c:=4" style=dashed] - 13 -> 4294967294 [style=dashed] - 18 [label="c"] - 18 -> 4294967294 [label="c==5; c:=4"] - 18 -> 4294967295 [style=dashed] + 4 -> 0 [label="b:=1" style=dashed] + 4 -> 1 [style=dashed] + 5 [label="b"] + 5 -> 4294967294 [label="b==2; b:=1"] + 5 -> 4294967295 [style=dashed] + 0 [label="c"] + 0 -> 4294967294 [label="c:=4" style=dashed] + 0 -> 4294967294 
[style=dashed] + 1 [label="c"] + 1 -> 4294967294 [label="c==5; c:=4"] + 1 -> 4294967295 [style=dashed] } diff --git a/netkat/packet_transformer_test_runner.cc b/netkat/packet_transformer_test_runner.cc index 653c320..53115f8 100644 --- a/netkat/packet_transformer_test_runner.cc +++ b/netkat/packet_transformer_test_runner.cc @@ -21,10 +21,106 @@ #include #include +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "netkat/netkat_proto_constructors.h" +#include "netkat/packet_field.h" #include "netkat/packet_transformer.h" namespace netkat { + +// Test peer to access `PacketTransformerManager` internals, allowing this +// runner to renumber nodes canonically before printing. +class PacketTransformerManagerTestPeer { + public: + // Copies `transformer` into the fresh manager `to`, interning fields and + // nodes in a deterministic traversal order. This makes the printed handle + // numbers depend only on the structure of `transformer` — not on the + // interning order of `from`, which is an implementation detail that changes + // under refactorings of the manager (e.g. reordering its operations). + static PacketTransformerHandle CanonicalCopy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to) { + // Intern the fields of all reachable nodes in their original relative + // order, which the node invariants ("fields increase strictly along each + // path") depend on, and record the handle translation for `Copy`. 
+ absl::btree_set fields; + absl::flat_hash_set visited; + CollectFields(from, transformer, visited, fields); + absl::flat_hash_map field_translation; + for (PacketFieldHandle field : fields) { + field_translation.try_emplace( + field, + to.packet_set_manager_.field_manager_.GetOrCreatePacketFieldHandle( + from.packet_set_manager_.field_manager_.GetFieldName(field))); + } + + absl::flat_hash_map + copy_by_original; + return Copy(from, transformer, to, field_translation, copy_by_original); + } + + private: + static void CollectFields( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + absl::flat_hash_set& visited, + absl::btree_set& fields) { + if (from.IsDeny(transformer) || from.IsAccept(transformer)) return; + if (!visited.insert(transformer).second) return; + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + fields.insert(node.field); + // All match and default modify entries live in the node's flat `modifies` + // array, so one pass visits every branch. 
+ for (const auto& [modify_value, branch] : node.modifies) { + CollectFields(from, branch, visited, fields); + } + CollectFields(from, node.default_branch, visited, fields); + } + + static PacketTransformerHandle Copy( + const PacketTransformerManager& from, PacketTransformerHandle transformer, + PacketTransformerManager& to, + const absl::flat_hash_map& + field_translation, + absl::flat_hash_map& + copy_by_original) { + if (from.IsDeny(transformer)) return to.Deny(); + if (from.IsAccept(transformer)) return to.Accept(); + if (auto it = copy_by_original.find(transformer); + it != copy_by_original.end()) { + return it->second; + } + + const PacketTransformerManager::DecisionNode& node = + from.GetNodeOrDie(transformer); + PacketTransformerManager::DecisionNodeBuilder copy{ + .field = field_translation.at(node.field), + }; + for (const auto& match : node.Matches()) { + // `operator[]` keeps entries with empty branch maps, which are + // meaningful: they deny packets matching `match.value`. + auto& copy_branch_by_modify = + copy.modify_branch_by_field_match[match.value]; + for (const auto& [modify_value, branch] : match.modifies) { + copy_branch_by_modify[modify_value] = + Copy(from, branch, to, field_translation, copy_by_original); + } + } + for (const auto& [modify_value, branch] : node.DefaultModifies()) { + copy.default_branch_by_field_modification[modify_value] = + Copy(from, branch, to, field_translation, copy_by_original); + } + copy.default_branch = Copy(from, node.default_branch, to, field_translation, + copy_by_original); + + PacketTransformerHandle result = to.NodeToTransformer(std::move(copy)); + copy_by_original.emplace(transformer, result); + return result; + } +}; + namespace { constexpr char kBanner[] = @@ -91,16 +187,22 @@ std::vector TestCases() { } void main() { - // This test needs a deterministic field interning order, and thus must start - // from a fresh manager. 
PacketTransformerManager manager; for (const TestCase& test_case : TestCases()) { netkat::PacketTransformerHandle packet_transformer = manager.Compile(test_case.policy); + // Renumber nodes and fields canonically, in a fresh manager per test case, + // so that the printed output depends only on the structure of the compiled + // transformer and not on the interning order of `manager`. + PacketTransformerManager canonical_manager; + netkat::PacketTransformerHandle canonical_transformer = + PacketTransformerManagerTestPeer::CanonicalCopy( + manager, packet_transformer, canonical_manager); std::cout << kBanner << "Test case: " << test_case.description << std::endl << kBanner; - std::cout << kStringHeader << manager.ToString(packet_transformer); - std::cout << kDotHeader << manager.ToDot(packet_transformer); + std::cout << kStringHeader + << canonical_manager.ToString(canonical_transformer); + std::cout << kDotHeader << canonical_manager.ToDot(canonical_transformer); } }