Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion netkat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ cc_library(
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -375,7 +376,7 @@ cc_library(
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -427,6 +428,7 @@ cc_binary(
deps = [
":netkat_cc_proto",
":netkat_proto_constructors",
":packet_set",
":packet_transformer",
"@com_google_absl//absl/strings",
"@com_google_benchmark//:benchmark_main",
Expand All @@ -440,7 +442,11 @@ cc_test(
linkstatic = True,
deps = [
":netkat_proto_constructors",
":packet_field",
":packet_transformer",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],
)

Expand Down
43 changes: 30 additions & 13 deletions netkat/packet_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -87,6 +88,14 @@ const PacketSetManager::DecisionNode& PacketSetManager::GetNodeOrDie(
return nodes_[packet_set.node_index_];
}

size_t PacketSetManager::NodeHash::operator()(const DecisionNode* node) const {
return absl::HashOf(*node);
}

size_t PacketSetManager::NodeHash::operator()(const DecisionNode& node) const {
return absl::HashOf(node);
}

PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) {
if (node.branch_by_field_value.empty()) return node.default_branch;

Expand All @@ -109,16 +118,19 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) {
}
#endif

auto [it, inserted] =
packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size()));
if (inserted) {
nodes_.push_back(std::move(node));
LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel)
<< "Internal invariant violated: Proper and sentinel node indices must "
"be disjoint. This indicates that we allocated more nodes than are "
"supported (> 2^32 - 2).";
// Look up the node by value via the transparent `NodeHash`/`NodeEq`
// functors; only new nodes get stored (exactly once, in `nodes_`).
if (auto it = packet_by_node_.find(node); it != packet_by_node_.end()) {
return it->second;
}
return it->second;
PacketSetHandle packet(nodes_.size());
nodes_.push_back(std::move(node));
packet_by_node_.insert({&nodes_[packet.node_index_], packet});
LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel)
<< "Internal invariant violated: Proper and sentinel node indices must "
"be disjoint. This indicates that we allocated more nodes than are "
"supported (> 2^32 - 2).";
return packet;
}

bool PacketSetManager::Contains(PacketSetHandle packet_set,
Expand Down Expand Up @@ -483,14 +495,19 @@ absl::Status PacketSetManager::CheckInternalInvariants() const {
// Invariant: Proper and sentinel node indices are disjoint.
RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel);

// Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`.
for (const auto& [node, packet] : packet_by_node_) {
// Invariant: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`.
for (const auto& [node_ptr, packet] : packet_by_node_) {
RET_CHECK(packet.node_index_ < nodes_.size());
RET_CHECK(nodes_[packet.node_index_] == node);
RET_CHECK(node_ptr == &nodes_[packet.node_index_]);
}
for (int i = 0; i < nodes_.size(); ++i) {
const DecisionNode& node = nodes_[i];
auto it = packet_by_node_.find(node);
// Look up both by pointer and by value (exercising the transparent
// functors used by `NodeToPacket`).
auto it = packet_by_node_.find(&node);
RET_CHECK(it != packet_by_node_.end());
RET_CHECK(it->second == PacketSetHandle(i));
it = packet_by_node_.find(node);
RET_CHECK(it != packet_by_node_.end());
RET_CHECK(it->second == PacketSetHandle(i));
}
Expand Down
31 changes: 29 additions & 2 deletions netkat/packet_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,38 @@ class PacketSetManager {
// `And`, `Or`, `Not`). The class also avoids expensive relocations.
PagedStableVector<DecisionNode, kPageSize> nodes_;

// Transparent hash and equality functors for the unique table
// (`packet_by_node_`), which is keyed by stable `DecisionNode*` pointers
// into `nodes_` (so each node is stored only once). Lookups work directly
// with a not-yet-interned `DecisionNode` value. Both functors are
// stateless: keys are pointers, and the pages holding the nodes are stable
// across moves of the manager.
struct NodeHash {
using is_transparent = void;
size_t operator()(const DecisionNode* node) const;
size_t operator()(const DecisionNode& node) const;
};
struct NodeEq {
using is_transparent = void;
bool operator()(const DecisionNode* a, const DecisionNode* b) const {
return a == b || *a == *b;
}
bool operator()(const DecisionNode* a, const DecisionNode& b) const {
return *a == b;
}
bool operator()(const DecisionNode& a, const DecisionNode* b) const {
return a == *b;
}
};

// A so called "unique table" to ensure each node is only added to `nodes_`
// once, and thus has a unique `PacketSetHandle::node_index`.
// Keyed by pointers into `nodes_` (stable, see `PagedStableVector`), so
// nodes are not stored twice.
//
// INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`.
absl::flat_hash_map<DecisionNode, PacketSetHandle> packet_by_node_;
// INVARIANT: `packet_by_node_[p] = s` iff `p == &nodes_[s.node_index_]`.
absl::flat_hash_map<const DecisionNode*, PacketSetHandle, NodeHash, NodeEq>
packet_by_node_;

// A map of a given `PredicateProto` to a `PacketSetHandle`.
//
Expand Down
Loading