From 79062bd7719672fa3f3353615f216a82ff21baa4 Mon Sep 17 00:00:00 2001 From: Anthony Roy Date: Mon, 15 Jun 2026 15:23:38 -0700 Subject: [PATCH] [NetKAT] Implement Pull in the frontend. PiperOrigin-RevId: 932697780 --- netkat/BUILD.bazel | 54 ++++++---- netkat/analysis_engine_test.cc | 10 +- netkat/evaluator.cc | 3 + netkat/evaluator_test.cc | 118 +++++++++++++++----- netkat/frontend.cc | 20 ++++ netkat/frontend.h | 12 +++ netkat/frontend_test.cc | 12 +++ netkat/gtest_utils.cc | 46 ++++++-- netkat/gtest_utils.h | 12 +++ netkat/manager_handle_pattern.md | 29 +++-- netkat/netkat.proto | 7 ++ netkat/netkat_proto_constructors.cc | 11 ++ netkat/netkat_proto_constructors.h | 2 + netkat/netkat_proto_constructors_test.cc | 15 +++ netkat/netkat_test.cc | 5 + netkat/packet_set.cc | 54 ++++------ netkat/packet_set.h | 89 ++++------------ netkat/packet_set_benchmark.cc | 13 ++- netkat/packet_set_handle.h | 122 +++++++++++++++++++++ netkat/packet_set_test.cc | 13 ++- netkat/packet_set_test_runner.cc | 4 +- netkat/packet_transformer.cc | 78 +++++++------- netkat/packet_transformer.h | 87 ++------------- netkat/packet_transformer_handle.h | 124 +++++++++++++++++++++ netkat/packet_transformer_test.cc | 130 +++++++++++++---------- netkat/table_test.cc | 8 +- 26 files changed, 719 insertions(+), 359 deletions(-) create mode 100644 netkat/packet_set_handle.h create mode 100644 netkat/packet_transformer_handle.h diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 9121fef..79ecf50 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -119,26 +119,9 @@ cc_test( ], ) -cc_library( +alias( name = "packet_set", - srcs = ["packet_set.cc"], - hdrs = ["packet_set.h"], - deps = [ - ":netkat_cc_proto", - ":packet", - ":packet_field", - ":paged_stable_vector", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_gutil//gutil:status", - ], + actual = ":packet_transformer", ) cc_test( @@ -147,6 +130,7 @@ cc_test( shard_count = 8, deps = [ ":evaluator", + ":gtest_utils", ":netkat_proto_constructors", ":packet", ":packet_set", @@ -229,6 +213,7 @@ cc_test( shard_count = 8, deps = [ ":evaluator", + ":gtest_utils", ":netkat_cc_proto", ":netkat_proto_constructors", ":packet", @@ -360,15 +345,38 @@ cc_test( ], ) +cc_library( + name = "packet_transformer_handle", + hdrs = ["packet_transformer_handle.h"], + deps = [ + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "packet_set_handle", + hdrs = ["packet_set_handle.h"], + deps = [ + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "packet_transformer", - srcs = ["packet_transformer.cc"], - hdrs = ["packet_transformer.h"], + srcs = [ + "packet_set.cc", + "packet_transformer.cc", + ], + hdrs = [ + "packet_set.h", + "packet_transformer.h", + ], deps = [ ":netkat_cc_proto", ":packet", ":packet_field", - ":packet_set", + ":packet_set_handle", + ":packet_transformer_handle", ":paged_stable_vector", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -403,11 +411,13 @@ cc_test( shard_count = 5, deps = [ ":evaluator", + ":gtest_utils", ":netkat_cc_proto", ":netkat_proto_constructors", ":packet", ":packet_set", ":packet_transformer", + ":packet_transformer_handle", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/netkat/analysis_engine_test.cc b/netkat/analysis_engine_test.cc index cd13be2..9e7700b 100644 --- a/netkat/analysis_engine_test.cc +++ b/netkat/analysis_engine_test.cc @@ -30,6 +30,10 @@ namespace { using ::gutil::IsOk; using ::testing::Not; +// TODO: anthonyroy - Revert to ArbitraryValidPredicateProto once Pull is +// implemented. +using ::netkat::netkat_test::ArbitraryValidPredicateProtoWithoutPull; + // We include only a single `CheckEquivalent` test as a smoke test since the // function is implemented in terms of `PacketSetManager`, which is tested // thoroughly in its own unit tests. @@ -167,7 +171,7 @@ void DenyPacketsAreAlwaysDropped(PredicateProto predicate_proto) { EXPECT_TRUE(analyzer.ProgramDropsAllPackets(Policy::Accept(), predicate)); } FUZZ_TEST(AnalysisEngineTest, DenyPacketsAreAlwaysDropped) - .WithDomains(netkat_test::ArbitraryValidPredicateProto()); + .WithDomains(ArbitraryValidPredicateProtoWithoutPull()); void DenyProgramProgramDropsAllPackets(PredicateProto predicate_proto) { AnalysisEngine analyzer; @@ -178,7 +182,7 @@ void DenyProgramProgramDropsAllPackets(PredicateProto predicate_proto) { EXPECT_TRUE(analyzer.ProgramDropsAllPackets(Policy::Deny(), predicate)); } FUZZ_TEST(AnalysisEngineTest, DenyProgramProgramDropsAllPackets) - .WithDomains(netkat_test::ArbitraryValidPredicateProto()); + .WithDomains(ArbitraryValidPredicateProtoWithoutPull()); TEST(AnalysisEngineTest, PartialMatchingProgramForwardsSomePackets) { AnalysisEngine analyzer; @@ -216,7 +220,7 @@ void DenyProgramAlwaysProducesNoOutput(PredicateProto predicate_proto) { /*output_packets=*/Predicate::False())); } FUZZ_TEST(AnalysisEngineTest, DenyProgramAlwaysProducesNoOutput) - .WithDomains(netkat_test::ArbitraryValidPredicateProto()); + .WithDomains(ArbitraryValidPredicateProtoWithoutPull()); TEST(CheckInputProducesExactOutputTest, AcceptProgramReflectsInputPacket) { AnalysisEngine analyzer; diff --git a/netkat/evaluator.cc b/netkat/evaluator.cc index 3a50ea7..316da9a 100644 --- a/netkat/evaluator.cc +++ b/netkat/evaluator.cc @@ -23,6 +23,9 @@ namespace netkat { bool Evaluate(const PredicateProto& predicate, const Packet& packet) { switch (predicate.predicate_case()) { + case PredicateProto::kPullOp: + // TODO: anthonyroy - Implement Pull in the evaluator. + LOG(FATAL) << "Pull evaluation not implemented"; case PredicateProto::kBoolConstant: return predicate.bool_constant().value(); case PredicateProto::kNotOp: diff --git a/netkat/evaluator_test.cc b/netkat/evaluator_test.cc index 95ad001..39da8c7 100644 --- a/netkat/evaluator_test.cc +++ b/netkat/evaluator_test.cc @@ -19,6 +19,7 @@ #include "fuzztest/fuzztest.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "netkat/gtest_utils.h" #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" #include "netkat/packet.h" @@ -28,6 +29,10 @@ namespace { using ::fuzztest::Arbitrary; using ::fuzztest::InRange; +// TODO: anthonyroy - Revert to ArbitraryValidPredicateProto / PolicyProto once +// Pull is implemented. +using ::netkat::netkat_test::ArbitraryValidPolicyProtoWithoutPull; +using ::netkat::netkat_test::ArbitraryValidPredicateProtoWithoutPull; using ::testing::ContainerEq; using ::testing::IsEmpty; using ::testing::IsSupersetOf; @@ -53,7 +58,9 @@ FUZZ_TEST(EvaluatePredicateProtoTest, EmptyPredicateIsFalseOnAnyPackets); void NotIsLogicalNot(Packet packet, PredicateProto negand) { EXPECT_EQ(Evaluate(NotProto(negand), packet), !Evaluate(negand, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, NotIsLogicalNot); +FUZZ_TEST(EvaluatePredicateProtoTest, NotIsLogicalNot) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void MatchOnlyMatchesPacketsWithCorrectValueAndField(Packet packet, std::string field, @@ -74,13 +81,17 @@ void AndIsLogicalAnd(Packet packet, PredicateProto left, PredicateProto right) { EXPECT_EQ(Evaluate(AndProto(left, right), packet), Evaluate(left, packet) && Evaluate(right, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndIsLogicalAnd); +FUZZ_TEST(EvaluatePredicateProtoTest, AndIsLogicalAnd) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void OrIsLogicalOr(Packet packet, PredicateProto left, PredicateProto right) { EXPECT_EQ(Evaluate(OrProto(left, right), packet), Evaluate(left, packet) || Evaluate(right, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrIsLogicalOr); +FUZZ_TEST(EvaluatePredicateProtoTest, OrIsLogicalOr) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); /*--- Boolean algebra axioms and equivalences --------------------------------*/ @@ -88,37 +99,49 @@ void PredOrItsNegationIsTrue(const Packet& packet, const PredicateProto& predicate) { EXPECT_TRUE(Evaluate(OrProto(predicate, NotProto(predicate)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, PredOrItsNegationIsTrue); +FUZZ_TEST(EvaluatePredicateProtoTest, PredOrItsNegationIsTrue) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void PredAndItsNegationIsFalse(const Packet& packet, const PredicateProto& predicate) { EXPECT_FALSE(Evaluate(AndProto(predicate, NotProto(predicate)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, PredAndItsNegationIsFalse); +FUZZ_TEST(EvaluatePredicateProtoTest, PredAndItsNegationIsFalse) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void AndIsIdempotent(const Packet& packet, const PredicateProto& predicate) { EXPECT_EQ(Evaluate(AndProto(predicate, predicate), packet), Evaluate(predicate, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndIsIdempotent); +FUZZ_TEST(EvaluatePredicateProtoTest, AndIsIdempotent) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void AndTrueIsIdentity(const Packet& packet, const PredicateProto& predicate) { EXPECT_EQ(Evaluate(AndProto(predicate, TrueProto()), packet), Evaluate(predicate, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndTrueIsIdentity); +FUZZ_TEST(EvaluatePredicateProtoTest, AndTrueIsIdentity) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void AndFalseIsFalse(const Packet& packet, const PredicateProto& predicate) { EXPECT_FALSE(Evaluate(AndProto(predicate, FalseProto()), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndFalseIsFalse); +FUZZ_TEST(EvaluatePredicateProtoTest, AndFalseIsFalse) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void AndIsCommutative(const Packet& packet, const PredicateProto& left, const PredicateProto& right) { EXPECT_EQ(Evaluate(AndProto(left, right), packet), Evaluate(AndProto(right, left), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndIsCommutative); +FUZZ_TEST(EvaluatePredicateProtoTest, AndIsCommutative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void AndIsAssociative(const Packet& packet, const PredicateProto& left, const PredicateProto& middle, @@ -126,31 +149,42 @@ void AndIsAssociative(const Packet& packet, const PredicateProto& left, EXPECT_EQ(Evaluate(AndProto(AndProto(left, middle), right), packet), Evaluate(AndProto(left, AndProto(middle, right)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, AndIsAssociative); +FUZZ_TEST(EvaluatePredicateProtoTest, AndIsAssociative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void OrIsIdempotent(const Packet& packet, const PredicateProto& predicate) { EXPECT_EQ(Evaluate(OrProto(predicate, predicate), packet), Evaluate(predicate, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrIsIdempotent); +FUZZ_TEST(EvaluatePredicateProtoTest, OrIsIdempotent) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void OrFalseIsIdentity(const Packet& packet, const PredicateProto& predicate) { EXPECT_EQ(Evaluate(OrProto(predicate, FalseProto()), packet), Evaluate(predicate, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrFalseIsIdentity); +FUZZ_TEST(EvaluatePredicateProtoTest, OrFalseIsIdentity) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void OrTrueIsTrue(const Packet& packet, const PredicateProto& predicate) { EXPECT_TRUE(Evaluate(OrProto(predicate, TrueProto()), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrTrueIsTrue); +FUZZ_TEST(EvaluatePredicateProtoTest, OrTrueIsTrue) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void OrIsCommutative(const Packet& packet, const PredicateProto& left, const PredicateProto& right) { EXPECT_EQ(Evaluate(OrProto(left, right), packet), Evaluate(OrProto(right, left), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrIsCommutative); +FUZZ_TEST(EvaluatePredicateProtoTest, OrIsCommutative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void OrIsAssociative(const Packet& packet, const PredicateProto& left, const PredicateProto& middle, @@ -158,32 +192,44 @@ void OrIsAssociative(const Packet& packet, const PredicateProto& left, EXPECT_EQ(Evaluate(OrProto(OrProto(left, middle), right), packet), Evaluate(OrProto(left, OrProto(middle, right)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, OrIsAssociative); +FUZZ_TEST(EvaluatePredicateProtoTest, OrIsAssociative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void XorFalseIsIdentity(const Packet& packet, const PredicateProto& predicate) { EXPECT_EQ(Evaluate(XorProto(predicate, FalseProto()), packet), Evaluate(predicate, packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, XorFalseIsIdentity); +FUZZ_TEST(EvaluatePredicateProtoTest, XorFalseIsIdentity) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void XorSelfIsFalse(const Packet& packet, const PredicateProto& pred) { EXPECT_FALSE(Evaluate(XorProto(pred, pred), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, XorSelfIsFalse); +FUZZ_TEST(EvaluatePredicateProtoTest, XorSelfIsFalse) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void XorIsCommutative(const Packet& packet, const PredicateProto& left, PredicateProto right) { EXPECT_EQ(Evaluate(XorProto(left, right), packet), Evaluate(XorProto(right, left), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, XorIsCommutative); +FUZZ_TEST(EvaluatePredicateProtoTest, XorIsCommutative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void XorIsAssociative(const Packet& packet, const PredicateProto& left, const PredicateProto& middle, PredicateProto right) { EXPECT_EQ(Evaluate(XorProto(XorProto(left, middle), right), packet), Evaluate(XorProto(left, XorProto(middle, right)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, XorIsAssociative); +FUZZ_TEST(EvaluatePredicateProtoTest, XorIsAssociative) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void DistributiveLawHolds(const Packet& packet, const PredicateProto& first, const PredicateProto& second, @@ -198,7 +244,10 @@ void DistributiveLawHolds(const Packet& packet, const PredicateProto& first, Evaluate(AndProto(OrProto(first, third), OrProto(second, third)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, DistributiveLawHolds); +FUZZ_TEST(EvaluatePredicateProtoTest, DistributiveLawHolds) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); void DeMorganHolds(const Packet& packet, const PredicateProto& left, const PredicateProto& right) { @@ -210,7 +259,9 @@ void DeMorganHolds(const Packet& packet, const PredicateProto& left, EXPECT_EQ(Evaluate(NotProto(OrProto(left, right)), packet), Evaluate(AndProto(NotProto(left), NotProto(right)), packet)); } -FUZZ_TEST(EvaluatePredicateProtoTest, DeMorganHolds); +FUZZ_TEST(EvaluatePredicateProtoTest, DeMorganHolds) + .WithDomains(Arbitrary(), ArbitraryValidPredicateProtoWithoutPull(), + ArbitraryValidPredicateProtoWithoutPull()); /*--- Basic policy properties ------------------------------------------------*/ @@ -222,7 +273,9 @@ void LiftedEvaluationIsCorrect(absl::flat_hash_set packets, } EXPECT_THAT(Evaluate(policy, packets), ContainerEq(expected_packets)); } -FUZZ_TEST(EvaluatePolicyProtoTest, LiftedEvaluationIsCorrect); +FUZZ_TEST(EvaluatePolicyProtoTest, LiftedEvaluationIsCorrect) + .WithDomains(Arbitrary>(), + ArbitraryValidPolicyProtoWithoutPull()); void RecordIsAccept(Packet packet) { EXPECT_THAT(Evaluate(RecordProto(), packet), UnorderedElementsAre(packet)); @@ -242,7 +295,9 @@ void FilterIsCorrect(Packet packet, PredicateProto predicate) { EXPECT_THAT(Evaluate(FilterProto(predicate), packet), IsEmpty()); } } -FUZZ_TEST(EvaluatePolicyProtoTest, FilterIsCorrect); +FUZZ_TEST(EvaluatePolicyProtoTest, FilterIsCorrect) + .WithDomains(Arbitrary(), + ArbitraryValidPredicateProtoWithoutPull()); void ModifyModifies(Packet packet, std::string field, int value) { Packet expected_packet = packet; @@ -259,7 +314,9 @@ void UnionCombines(Packet packet, PolicyProto left, PolicyProto right) { EXPECT_THAT(Evaluate(UnionProto(left, right), packet), ContainerEq(expected_packets)); } -FUZZ_TEST(EvaluatePolicyProtoTest, UnionCombines); +FUZZ_TEST(EvaluatePolicyProtoTest, UnionCombines) + .WithDomains(Arbitrary(), ArbitraryValidPolicyProtoWithoutPull(), + ArbitraryValidPolicyProtoWithoutPull()); void DifferenceRemoves(Packet packet, PolicyProto left, PolicyProto right) { absl::flat_hash_set expected_packets = Evaluate(left, packet); @@ -270,7 +327,9 @@ void DifferenceRemoves(Packet packet, PolicyProto left, PolicyProto right) { EXPECT_THAT(Evaluate(DifferenceProto(left, right), packet), ContainerEq(expected_packets)); } -FUZZ_TEST(EvaluatePolicyProtoTest, DifferenceRemoves); +FUZZ_TEST(EvaluatePolicyProtoTest, DifferenceRemoves) + .WithDomains(Arbitrary(), ArbitraryValidPolicyProtoWithoutPull(), + ArbitraryValidPolicyProtoWithoutPull()); void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) { absl::flat_hash_set expected_packets = @@ -279,7 +338,9 @@ void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) { EXPECT_THAT(Evaluate(SequenceProto(left, right), packet), ContainerEq(expected_packets)); } -FUZZ_TEST(EvaluatePolicyProtoTest, SequenceSequences); +FUZZ_TEST(EvaluatePolicyProtoTest, SequenceSequences) + .WithDomains(Arbitrary(), ArbitraryValidPolicyProtoWithoutPull(), + ArbitraryValidPolicyProtoWithoutPull()); PolicyProto UnionUpToNthPower(PolicyProto iterable, int n) { PolicyProto union_policy = AcceptProto(); @@ -298,7 +359,7 @@ void IterateIsSupersetOfUnionOfNSequences(Packet packet, PolicyProto iterable, } FUZZ_TEST(EvaluatePolicyProtoTest, IterateIsSupersetOfUnionOfNSequences) .WithDomains(/*packet=*/Arbitrary(), - /*iterable=*/Arbitrary(), + /*iterable=*/ArbitraryValidPolicyProtoWithoutPull(), /*n=*/InRange(0, 100)); void IterateIsUnionOfNSequencesForSomeN(Packet packet, PolicyProto iterable) { @@ -318,7 +379,8 @@ void IterateIsUnionOfNSequencesForSomeN(Packet packet, PolicyProto iterable) { EXPECT_THAT(iterate_output_packets, ContainerEq(union_output_packets)); } -FUZZ_TEST(EvaluatePolicyProtoTest, IterateIsUnionOfNSequencesForSomeN); +FUZZ_TEST(EvaluatePolicyProtoTest, IterateIsUnionOfNSequencesForSomeN) + .WithDomains(Arbitrary(), ArbitraryValidPolicyProtoWithoutPull()); TEST(EvaluatePolicyProtoTest, SimpleIterateThroughFiltersAndModifies) { // f == 0; f:=1 + f == 1; f := 2 + f == 2; f := 3 diff --git a/netkat/frontend.cc b/netkat/frontend.cc index e51ad91..c64101a 100644 --- a/netkat/frontend.cc +++ b/netkat/frontend.cc @@ -11,9 +11,24 @@ #include "netkat/netkat_proto_constructors.h" namespace netkat { +// Forward declaration needed for mutual recursion: PredicateProto::Pull +// contains a PolicyProto, which in turn can contain Filter policies containing +// PredicateProtos. +absl::Status RecursivelyCheckIsValid(const PolicyProto& policy_proto); + // Recursively checks whether `predicate_proto` is valid. absl::Status RecursivelyCheckIsValid(const PredicateProto& predicate_proto) { switch (predicate_proto.predicate_case()) { + case PredicateProto::kPullOp: { + RETURN_IF_ERROR(RecursivelyCheckIsValid(predicate_proto.pull_op().left())) + .SetPrepend() + << "PredicateProto::Pull's policy (left) is invalid: "; + RETURN_IF_ERROR( + RecursivelyCheckIsValid(predicate_proto.pull_op().right())) + .SetPrepend() + << "PredicateProto::Pull's predicate (right) is invalid: "; + return absl::OkStatus(); + } case PredicateProto::PREDICATE_NOT_SET: return absl::InvalidArgumentError("Unset Predicate case is invalid"); case PredicateProto::kMatch: @@ -91,6 +106,11 @@ Predicate Match(absl::string_view field, int value) { return Predicate(MatchProto(field, value)); } +Predicate Pull(Policy policy, Predicate predicate) { + return Predicate( + PullProto(std::move(policy).ToProto(), std::move(predicate).ToProto())); +} + absl::Status RecursivelyCheckIsValid(const PolicyProto& policy_proto) { switch (policy_proto.policy_case()) { case PolicyProto::kFilter: diff --git a/netkat/frontend.h b/netkat/frontend.h index 238994c..d0b01a2 100644 --- a/netkat/frontend.h +++ b/netkat/frontend.h @@ -113,6 +113,7 @@ class Predicate { // Match operation for a Predicate. See below for the full definition. We // utilize friend association to ensure program construction is well-formed. friend Predicate Match(absl::string_view, int); + friend Predicate Pull(class Policy, Predicate); private: // Hide default proto construction to hinder building of ill-formed programs. @@ -221,6 +222,17 @@ class Policy { PolicyProto policy_; }; +// Pulls a `predicate` back through a `policy`. +// +// Semantically, `Pull(policy, predicate)` is a predicate that matches an input +// packet if and only if processing that packet with `policy` can yield at +// least one output packet that satisfies `predicate`. +// +// For example, `Pull(Modify("f", 1), Match("f", 1))` is equivalent to `True`, +// because modifying "f" to 1 will always produce a packet that matches "f==1". +// Conversely, `Pull(Modify("f", 2), Match("f", 1))` is equivalent to `False`. +Predicate Pull(Policy policy, Predicate predicate); + // Returns a policy that filters packets by `predicate`. Policy Filter(Predicate predicate); diff --git a/netkat/frontend_test.cc b/netkat/frontend_test.cc index dc25c10..e041e78 100644 --- a/netkat/frontend_test.cc +++ b/netkat/frontend_test.cc @@ -70,6 +70,9 @@ PredicateProto InvalidPredicateProto(PredicateProto predicate_proto) { case PredicateProto::kBoolConstant: predicate_proto.Clear(); break; + case PredicateProto::kPullOp: + predicate_proto.mutable_pull_op()->clear_right(); + break; } return predicate_proto; } @@ -126,6 +129,15 @@ FUZZ_TEST(FrontEndTest, XorToProtoIsCorrect) .WithDomains(/*lhs=*/AtomicPredicateDomain(), /*rhs=*/AtomicPredicateDomain()); +void PullToProtoIsCorrect(Policy policy, Predicate predicate) { + Predicate pull_pred = Pull(policy, predicate); + EXPECT_THAT(pull_pred.ToProto(), + EqualsProto(PullProto(policy.ToProto(), predicate.ToProto()))); +} +FUZZ_TEST(FrontEndTest, PullToProtoIsCorrect) + .WithDomains(/*policy=*/AtomicDupFreePolicyDomain(), + /*predicate=*/AtomicPredicateDomain()); + void OperationOrderIsPreserved(Predicate a, Predicate b, Predicate c) { Predicate abc = !(a || b) && c || a; EXPECT_THAT( diff --git a/netkat/gtest_utils.cc b/netkat/gtest_utils.cc index d1ff939..57d27db 100644 --- a/netkat/gtest_utils.cc +++ b/netkat/gtest_utils.cc @@ -12,15 +12,6 @@ using ::fuzztest::Just; using ::fuzztest::Map; using ::fuzztest::OneOf; -namespace { - -template -bool FieldTypeIs(const google::protobuf::FieldDescriptor* field) { - return field->message_type() == T::descriptor(); -}; - -} // namespace - fuzztest::Domain ArbitraryValidPredicateProto() { return fuzztest::Arbitrary() // The domain will recursively set all fields. This ensures @@ -31,7 +22,27 @@ fuzztest::Domain ArbitraryValidPredicateProto() { .WithProtobufFields( FieldTypeIs, fuzztest::Arbitrary().WithStringFieldAlwaysSet( - "field", fuzztest::String().WithMinSize(1))); + "field", fuzztest::String().WithMinSize(1))) + // The domain will ensure all PolicyProto::Modification::field will be + // non-empty (needed for Pull). + .WithProtobufFields(FieldTypeIs, + fuzztest::Arbitrary() + .WithStringFieldAlwaysSet( + "field", fuzztest::String().WithMinSize(1))); +} + +fuzztest::Domain ArbitraryValidPredicateProtoWithoutPull() { + return fuzztest::Arbitrary() + // The domain will recursively set all fields. This ensures + // PredicateProto will have its members PredicateProto set. + .WithFieldsAlwaysSet() + // The domain will ensure all PredicateProto::Match::field will be + // non-empty. + .WithProtobufFields( + FieldTypeIs, + fuzztest::Arbitrary().WithStringFieldAlwaysSet( + "field", fuzztest::String().WithMinSize(1))) + .WithFieldsUnset(FieldTypeIs); } fuzztest::Domain ArbitraryValidPolicyProto() { @@ -49,6 +60,21 @@ fuzztest::Domain ArbitraryValidPolicyProto() { ArbitraryValidPredicateProto()); } +fuzztest::Domain ArbitraryValidPolicyProtoWithoutPull() { + return fuzztest::Arbitrary() + // The domain will recursively set all fields. This ensures + // PolicyProto will have its members PolicyProto set. + .WithFieldsAlwaysSet() + // The domain will ensure all PolicyProto::Modification::field will be + // non-empty. + .WithProtobufFields(FieldTypeIs, + fuzztest::Arbitrary() + .WithStringFieldAlwaysSet( + "field", fuzztest::String().WithMinSize(1))) + .WithProtobufFields(FieldTypeIs, + ArbitraryValidPredicateProtoWithoutPull()); +} + fuzztest::Domain AtomicPredicateDomain() { return OneOf(Just(Predicate::True()), Just(Predicate::False()), Map([](absl::string_view field, diff --git a/netkat/gtest_utils.h b/netkat/gtest_utils.h index 6bf9768..35218bc 100644 --- a/netkat/gtest_utils.h +++ b/netkat/gtest_utils.h @@ -23,10 +23,16 @@ #define GOOGLE_NETKAT_NETKAT_GTEST_UTILS_H_ #include "fuzztest/fuzztest.h" +#include "google/protobuf/descriptor.h" #include "netkat/frontend.h" namespace netkat::netkat_test { +template +bool FieldTypeIs(const google::protobuf::FieldDescriptor* field) { + return field->message_type() == T::descriptor(); +} + // Returns a FUZZ_TEST domain for an arbitrary valid PredicateProto. // See netkat::Predicate::FromProto for the definition of a valid // PredicateProto. @@ -34,12 +40,18 @@ namespace netkat::netkat_test { // defined to mean false. fuzztest::Domain ArbitraryValidPredicateProto(); +// Same as ArbitraryValidPredicateProto but without Pull. +fuzztest::Domain ArbitraryValidPredicateProtoWithoutPull(); + // Returns a FUZZ_TEST domain for an arbitrary valid PolicyProto. // See netkat::Policy::FromProto for the definition of a valid PolicyProto. // Nonetheless, invalid protos are accepted in the backend where empty is // defined to mean DENY policy. fuzztest::Domain ArbitraryValidPolicyProto(); +// Same as ArbitraryValidPolicyProto but without Pull. +fuzztest::Domain ArbitraryValidPolicyProtoWithoutPull(); + // Returns a FUZZ_TEST domain for an arbitrary, atomic Predicate. I.e., the // predicate may be any of: an arbitrary Match, or the True/False predicates. fuzztest::Domain AtomicPredicateDomain(); diff --git a/netkat/manager_handle_pattern.md b/netkat/manager_handle_pattern.md index e2c81b4..98b8437 100644 --- a/netkat/manager_handle_pattern.md +++ b/netkat/manager_handle_pattern.md @@ -26,19 +26,28 @@ combining them, or inspecting the underlying sets, one must call methods on the manager class, which acts as an arena allocator that owns all memory associated with the handles. -For example: ``` // We need a manager to construct handles. PacketSetManager -manager; PacketSetHandle a = manager.EmptySet() PacketSetHandle b = -manager.Match("src_mac", 0xFF'FF'FF'FF); +For example: -// Handles can be compared and hashed without the help of the manager, // but -that's about it. CHECK(a != b); absl::flat_hash_map ab_set{a, -b}; +``` +// We need a manager to construct handles. +PacketSetManager manager; +PacketSetHandle a = manager.EmptySet() +PacketSetHandle b = manager.Match("src_mac", 0xFF'FF'FF'FF); + +// Handles can be compared and hashed without the help of the manager, +// but that's about it. +CHECK(a != b); +absl::flat_hash_map ab_set{a, b}; // To do interesting things with the handles, we need the manager. -PacketSetHandle c = manager.And(a, b); // The set union of `a` and `b`. -PacketSetHandle not_c = manager.Not(c); // The set complement of `c`. if -(manager.Contains(c, packet)) { CHECK(!manager.Contains(not_c, packet)); } else -{ CHECK(manager.Contains(not_c, packet)); } ``` +PacketSetHandle c = manager.And(a, b); // The set union of `a` and `b`. +PacketSetHandle not_c = manager.Not(c); // The set complement of `c`. +if (manager.Contains(c, packet)) { + CHECK(!manager.Contains(not_c, packet)); +} else { + CHECK(manager.Contains(not_c, packet)); +} +``` ## Motivation for Using the Pattern diff --git a/netkat/netkat.proto b/netkat/netkat.proto index ccc5779..1c89ee5 100644 --- a/netkat/netkat.proto +++ b/netkat/netkat.proto @@ -62,6 +62,7 @@ message PredicateProto { Or or_op = 4; Not not_op = 5; Xor xor_op = 6; + Pull pull_op = 7; } // A boolean constant, i.e. true or false. Equivalent to Accept/Deny. @@ -100,6 +101,12 @@ message PredicateProto { PredicateProto left = 1; PredicateProto right = 2; } + + // Pull policy into predicate, i.e. pull(left, right) + message Pull { + PolicyProto left = 1; + PredicateProto right = 2; + } } // The intermediate representation of a NetKAT policy. diff --git a/netkat/netkat_proto_constructors.cc b/netkat/netkat_proto_constructors.cc index 33d9664..cdd630d 100644 --- a/netkat/netkat_proto_constructors.cc +++ b/netkat/netkat_proto_constructors.cc @@ -71,6 +71,13 @@ PredicateProto XorProto(PredicateProto left, PredicateProto right) { *xor_op.mutable_right() = std::move(right); return proto; } +PredicateProto PullProto(PolicyProto left, PredicateProto right) { + PredicateProto proto; + PredicateProto::Pull& pull = *proto.mutable_pull_op(); + *pull.mutable_left() = std::move(left); + *pull.mutable_right() = std::move(right); + return proto; +} // -- Basic Policy constructors ------------------------------------------------ @@ -130,6 +137,10 @@ std::string AsShorthandString(PredicateProto predicate) { switch (predicate.predicate_case()) { case PredicateProto::kBoolConstant: return predicate.bool_constant().value() ? "true" : "false"; + case PredicateProto::kPullOp: + return absl::StrFormat("pull(%s, %s)", + AsShorthandString(predicate.pull_op().left()), + AsShorthandString(predicate.pull_op().right())); case PredicateProto::kMatch: return absl::StrFormat("@%s==%d", predicate.match().field(), predicate.match().value()); diff --git a/netkat/netkat_proto_constructors.h b/netkat/netkat_proto_constructors.h index c3c188b..1e723c5 100644 --- a/netkat/netkat_proto_constructors.h +++ b/netkat/netkat_proto_constructors.h @@ -38,6 +38,7 @@ PredicateProto AndProto(PredicateProto left, PredicateProto right); PredicateProto OrProto(PredicateProto left, PredicateProto right); PredicateProto NotProto(PredicateProto negand); PredicateProto XorProto(PredicateProto left, PredicateProto right); +PredicateProto PullProto(PolicyProto left, PredicateProto right); // -- Basic Policy constructors ------------------------------------------------ @@ -61,6 +62,7 @@ PolicyProto AcceptProto(); // Predicate Or -> '||' // Predicate Not -> '!' // Predicate Xor -> '(+)' +// Predicate Pull -> 'pull(left, right)' // Policy Sequence -> ';' // Policy Or -> '+' // Iterate -> '*' diff --git a/netkat/netkat_proto_constructors_test.cc b/netkat/netkat_proto_constructors_test.cc index 34609bf..08b6bfb 100644 --- a/netkat/netkat_proto_constructors_test.cc +++ b/netkat/netkat_proto_constructors_test.cc @@ -79,6 +79,15 @@ void XorProtoReturnsXor(PredicateProto left, PredicateProto right) { } FUZZ_TEST(XorProtoTest, XorProtoReturnsXor); +void PullProtoReturnsPull(PolicyProto left, PredicateProto right) { + PredicateProto pull_proto; + PredicateProto::Pull& pull = *pull_proto.mutable_pull_op(); + *pull.mutable_left() = left; + *pull.mutable_right() = right; + EXPECT_THAT(PullProto(left, right), EqualsProto(pull_proto)); +} +FUZZ_TEST(PullProtoTest, PullProtoReturnsPull); + // -- Basic Policy constructors ------------------------------------------------ void FilterProtoReturnsFilter(PredicateProto filter) { @@ -198,6 +207,12 @@ TEST(AsShorthandStringTest, NegationIsOkay) { "!(true || false)"); } +TEST(AsShorthandStringTest, PullIsCorrect) { + EXPECT_EQ(AsShorthandString(PullProto(ModificationProto("field", 2), + MatchProto("field", 1))), + "pull(@field:=2, @field==1)"); +} + TEST(AsShorthandStringTest, ModifyIsCorrect) { EXPECT_EQ(AsShorthandString(ModificationProto("field", 2)), "@field:=2"); } diff --git a/netkat/netkat_test.cc b/netkat/netkat_test.cc index 8477b69..613d82e 100644 --- a/netkat/netkat_test.cc +++ b/netkat/netkat_test.cc @@ -66,6 +66,11 @@ TEST(NetkatProtoTest, PredicateOneOfFieldNamesDontRequireUnderscores) { LOG(INFO) << "bool_constant: " << bool_constant; break; } + case PredicateProto::kPullOp: { + const PredicateProto::Pull& pull_op = predicate.pull_op(); + LOG(INFO) << "pull_op: " << pull_op; + break; + } case PredicateProto::PREDICATE_NOT_SET: break; } diff --git a/netkat/packet_set.cc b/netkat/packet_set.cc index 2c7bd1c..5905b1c 100644 --- a/netkat/packet_set.cc +++ b/netkat/packet_set.cc @@ -33,42 +33,20 @@ #include "absl/strings/string_view.h" #include "gutil/status.h" #include "netkat/packet.h" +#include "netkat/packet_transformer.h" +#include "netkat/packet_transformer_handle.h" namespace netkat { -// The empty and full set of packets are not decision nodes, and thus we cannot -// associate an index into the `nodes_` vector with them. Instead, we represent -// them using sentinel values, chosen maximally to avoid collisions with proper -// indices. -enum SentinelNodeIndex : uint32_t { - // Encodes the empty set of packets. - kEmptySet = std::numeric_limits::max(), - // Encodes the full set of packets. - kFullSet = std::numeric_limits::max() - 1, - // The minimum sentinel node index. - // Smaller values are reserved for proper indices into the `nodes_` vector. - kMinSentinel = kFullSet, -}; - -PacketSetHandle::PacketSetHandle() - : node_index_(SentinelNodeIndex::kEmptySet) {} - -std::string PacketSetHandle::ToString() const { - if (node_index_ == SentinelNodeIndex::kEmptySet) { - return "PacketSetHandle"; - } else if (node_index_ == SentinelNodeIndex::kFullSet) { - return "PacketSetHandle"; - } else { - return absl::StrFormat("PacketSetHandle<%d>", node_index_); - } -} +PacketSetManager::PacketSetManager(PacketTransformerManager& transformer) + : transformer_(&transformer) {} PacketSetHandle PacketSetManager::EmptySet() const { - return PacketSetHandle(SentinelNodeIndex::kEmptySet); + return PacketSetHandle(PacketSetHandle::kEmptySet); } PacketSetHandle PacketSetManager::FullSet() const { - return PacketSetHandle(SentinelNodeIndex::kFullSet); + return PacketSetHandle(PacketSetHandle::kFullSet); } bool PacketSetManager::IsEmptySet(PacketSetHandle packet_set) const { @@ -113,7 +91,7 @@ PacketSetHandle PacketSetManager::NodeToPacket(DecisionNode&& node) { packet_by_node_.try_emplace(node, PacketSetHandle(nodes_.size())); if (inserted) { nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + LOG_IF(DFATAL, nodes_.size() > PacketSetHandle::kMinSentinel) << "Internal invariant violated: Proper and sentinel node indices must " "be disjoint. This indicates that we allocated more nodes than are " "supported (> 2^32 - 2)."; @@ -147,21 +125,21 @@ std::string PacketSetManager::ToDot(PacketSetHandle packet_set) const { work_list.push(packet_set); if (IsFullSet(packet_set)) { absl::StrAppendFormat(&result, " %d [label=\"T\" shape=box]\n", - SentinelNodeIndex::kFullSet); + PacketSetHandle::kFullSet); absl::StrAppend(&result, "}\n"); return result; } if (IsEmptySet(packet_set)) { absl::StrAppendFormat(&result, " %d [label=\"F\" shape=box]\n", - SentinelNodeIndex::kEmptySet); + PacketSetHandle::kEmptySet); absl::StrAppend(&result, "}\n"); return result; } absl::flat_hash_set visited = {packet_set}; absl::StrAppendFormat(&result, " %d [label=\"T\" shape=box]\n", - SentinelNodeIndex::kFullSet); + PacketSetHandle::kFullSet); absl::StrAppendFormat(&result, " %d [label=\"F\" shape=box]\n", - SentinelNodeIndex::kEmptySet); + PacketSetHandle::kEmptySet); while (!work_list.empty()) { PacketSetHandle packet_set = work_list.front(); @@ -194,6 +172,14 @@ std::string PacketSetManager::ToDot(PacketSetHandle packet_set) const { PacketSetHandle PacketSetManager::Compile(const PredicateProto& pred) { ProtoHashKey key = {.predicate_case = pred.predicate_case()}; switch (pred.predicate_case()) { + case PredicateProto::kPullOp: { + key.lhs_policy_handle = transformer_->Compile(pred.pull_op().left()); + key.rhs_child = Compile(pred.pull_op().right()); + auto it = packet_set_by_hash_.find(key); + if (it != packet_set_by_hash_.end()) return it->second; + return packet_set_by_hash_[key] = + transformer_->Pull(key.lhs_policy_handle, key.rhs_child); + } case PredicateProto::kBoolConstant: { return pred.bool_constant().value() ? FullSet() : EmptySet(); } @@ -481,7 +467,7 @@ std::string PacketSetManager::ToString(const DecisionNode& node) const { absl::Status PacketSetManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. - RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); + RET_CHECK(nodes_.size() <= PacketSetHandle::kMinSentinel); // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. for (const auto& [node, packet] : packet_by_node_) { diff --git a/netkat/packet_set.h b/netkat/packet_set.h index f11088f..108093c 100644 --- a/netkat/packet_set.h +++ b/netkat/packet_set.h @@ -16,7 +16,7 @@ // File: packet_set.h // ----------------------------------------------------------------------------- // -// Defines `PacketSetHandle` and its companion class `PacketSetManager` +// Defines `PacketSetManager`, the companion class to `PacketSetHandle` // following the manager-handle pattern described in // `manager_handle_pattern.md`. Together, they provide an often compact and // efficient representation of large and even infinite packet sets, exploiting @@ -58,69 +58,12 @@ #include "netkat/netkat.pb.h" #include "netkat/packet.h" #include "netkat/packet_field.h" +#include "netkat/packet_set_handle.h" +#include "netkat/packet_transformer_handle.h" #include "netkat/paged_stable_vector.h" namespace netkat { -// A lightweight handle (32 bits) representing a set of packets. The -// representation can efficiently encode typical large and even infinite sets -// seen in practice. -// -// The APIs of this object are almost entirely defined as methods of the -// companion class `PacketSetManager`, following the manager-handle pattern -// described in `manager_handle_pattern.md`. -// -// CAUTION: Each `PacketSetHandle` is implicitly associated with the manager -// object that created it; using it with a different manager has undefined -// behavior. -// -// This data structure enjoys the following powerful *canonicity property*: two -// handles represent the same set if and only if they have the same memory -// representation. Since the memory representation is just 32 bits, semantic set -// equality is cheap: O(1)! -class [[nodiscard]] PacketSetHandle { - public: - // Default constructor: the empty set of packets. - PacketSetHandle(); - - // Two packet set handles compare equal iff they represent the same set of - // concrete packets. Comparison is O(1), thanks to interning/hash-consing. - friend auto operator<=>(PacketSetHandle a, PacketSetHandle b) = default; - - // Hashing, see https://abseil.io/docs/cpp/guides/hash. - template - friend H AbslHashValue(H h, PacketSetHandle packet_set) { - return H::combine(std::move(h), packet_set.node_index_); - } - - // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. - // NOTE: These functions do not produce particularly useful output. Instead, - // use `PacketSetManager::ToString(packet_set)` whenever possible. - template - friend void AbslStringify(Sink& sink, PacketSetHandle packet_set) { - absl::Format(&sink, "%s", packet_set.ToString()); - } - std::string ToString() const; - - private: - // An index into the `nodes_` vector of the `PacketSetManager` object - // associated with this `PacketSetHandle`. The semantics of this packet set - // is entirely determined by the node `nodes_[node_index_]`. The index is - // otherwise arbitrary and meaningless. - // - // We use a 32-bit index as a tradeoff between minimizing memory usage and - // maximizing the number of `PacketSetHandle`s that can be created, both - // aspects that impact how well we scale to large NetKAT models. We expect - // millions, but not billions, of packet sets in practice, and 2^32 ~= 4 - // billion. - uint32_t node_index_; - explicit PacketSetHandle(uint32_t node_index) : node_index_(node_index) {} - friend class PacketSetManager; -}; - -// Protect against regressions in the memory layout, as it affects performance. -static_assert(sizeof(PacketSetHandle) <= 4); - // An "arena" in which `PacketSetHandle`s can be created and manipulated // (following the manager-handle pattern, see `manager_handle_pattern.md`). // @@ -131,18 +74,20 @@ static_assert(sizeof(PacketSetHandle) <= 4); // CAUTION: Using a `PacketSetHandle` returned by one `PacketSetManager` // object with a different manager is undefined behavior. // +// This class is not constructible or movable publicly. It must be managed +// by a `PacketTransformerManager` (which owns a `PacketSetManager` by value). +// This restriction ensures that the `PacketSetManager` always has a valid +// `PacketTransformerManager` context, which is required to compile `Pull` +// operations (as they compile down to policies). +// // TODO(b/398303840): Persistent use of an `PacketSetManager` object can // incur unbounded memory growth. Consider adding some garbage collection // mechanism. class PacketSetManager { public: - PacketSetManager() = default; - - // The class is move-only: not copyable, but movable. + // The class is move-only: not copyable. PacketSetManager(const PacketSetManager&) = delete; PacketSetManager& operator=(const PacketSetManager&) = delete; - PacketSetManager(PacketSetManager&&) = default; - PacketSetManager& operator=(PacketSetManager&&) = default; // Returns true iff this packet set represents the empty set of packets. bool IsEmptySet(PacketSetHandle packet_set) const; @@ -304,10 +249,14 @@ class PacketSetManager { // The `PredicateProto` oneof case. int predicate_case; - // The left child, if `predicate_case` is an operation. In the case + // The left child, if `predicate_case` is a predicate operation. In the case // `predicate_case` is unary, e.g. Not, this will be the child. PacketSetHandle lhs_child; + // The left child policy, if `predicate_case` is a policy-predicate + // operation (e.g., Pull). Otherwise defaulted. + PacketTransformerHandle lhs_policy_handle; + // The right child, if `predicate_case` is an operation. In the case // `predicate_case` is unary, e.g. Not, this will be defaulted. PacketSetHandle rhs_child; @@ -318,7 +267,7 @@ class PacketSetManager { template friend H AbslHashValue(H h, const ProtoHashKey& key) { return H::combine(std::move(h), key.predicate_case, key.lhs_child, - key.rhs_child); + key.lhs_policy_handle, key.rhs_child); } }; @@ -372,6 +321,12 @@ class PacketSetManager { // INVARIANT: All `DecisionNode` fields are interned by this manager. PacketFieldManager field_manager_; + explicit PacketSetManager(class PacketTransformerManager& transformer); + PacketSetManager(PacketSetManager&&) = default; + PacketSetManager& operator=(PacketSetManager&&) = default; + + class PacketTransformerManager* transformer_ = nullptr; + // Allow `PacketTransformerManager` to access private methods. friend class PacketTransformerManager; friend class PacketTransformerManagerTestPeer; diff --git a/netkat/packet_set_benchmark.cc b/netkat/packet_set_benchmark.cc index 7de2fd0..5c9a6ad 100644 --- a/netkat/packet_set_benchmark.cc +++ b/netkat/packet_set_benchmark.cc @@ -19,6 +19,7 @@ #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" #include "netkat/packet_set.h" +#include "netkat/packet_transformer.h" namespace netkat { // Create an arbitrary fixed policy with some relative complexity. In this @@ -56,7 +57,8 @@ void BM_FirstTimeCompileNonOverlappingPredicate(benchmark::State& state) { PredicateProto policy = OrProto(sub_policy1, sub_policy2); for (auto s : state) { - PacketSetManager manager; + PacketTransformerManager transformer; + PacketSetManager& manager = transformer.GetPacketSetManager(); PacketSetHandle handle = manager.Compile(policy); benchmark::DoNotOptimize(handle); } @@ -75,7 +77,8 @@ void BM_ReCompileNonOverlappingPredicate(benchmark::State& state) { CreateFixedArbitraryPredicateProto(/*id_suffix=*/4)); PredicateProto policy = OrProto(sub_policy1, sub_policy2); - PacketSetManager manager; + PacketTransformerManager transformer; + PacketSetManager& manager = transformer.GetPacketSetManager(); PacketSetHandle handle = manager.Compile(policy); for (auto s : state) { handle = manager.Compile(policy); @@ -94,7 +97,8 @@ void BM_FirstTimeCompileOverlappingPredicate(benchmark::State& state) { PredicateProto policy = OrProto(sub_policy1, sub_policy2); for (auto s : state) { - PacketSetManager manager; + PacketTransformerManager transformer; + PacketSetManager& manager = transformer.GetPacketSetManager(); PacketSetHandle handle = manager.Compile(policy); benchmark::DoNotOptimize(handle); } @@ -111,7 +115,8 @@ void BM_ReCompileOverlappingPredicate(benchmark::State& state) { CreateFixedArbitraryPredicateProto()); PredicateProto policy = OrProto(sub_policy1, sub_policy2); - PacketSetManager manager; + PacketTransformerManager transformer; + PacketSetManager& manager = transformer.GetPacketSetManager(); PacketSetHandle handle = manager.Compile(policy); for (auto s : state) { handle = manager.Compile(policy); diff --git a/netkat/packet_set_handle.h b/netkat/packet_set_handle.h new file mode 100644 index 0000000..c5fd8f3 --- /dev/null +++ b/netkat/packet_set_handle.h @@ -0,0 +1,122 @@ +// Copyright 2025 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: packet_set_handle.h +// ----------------------------------------------------------------------------- +// +// Defines `PacketSetHandle`, a lightweight handle representing a set of +// packets. +// +// Together with its companion class `PacketSetManager` (defined in +// `packet_set.h`), they provide an often compact and efficient representation +// of large and even infinite packet sets, exploiting structural properties that +// packet sets seen in practice typically exhibit. + +#ifndef GOOGLE_NETKAT_NETKAT_PACKET_SET_HANDLE_H_ +#define GOOGLE_NETKAT_NETKAT_PACKET_SET_HANDLE_H_ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" + +namespace netkat { + +// A lightweight handle (32 bits) representing a set of packets. The +// representation can efficiently encode typical large and even infinite sets +// seen in practice. +// +// The APIs of this object are almost entirely defined as methods of the +// companion class `PacketSetManager`, following the manager-handle pattern +// described in `manager_handle_pattern.md`. +// +// CAUTION: Each `PacketSetHandle` is implicitly associated with the manager +// object that created it; using it with a different manager has undefined +// behavior. +// +// This data structure enjoys the following powerful *canonicity property*: two +// handles represent the same set if and only if they have the same memory +// representation. Since the memory representation is just 32 bits, semantic set +// equality is cheap: O(1)! +class [[nodiscard]] PacketSetHandle { + public: + // The empty and full set of packets are not decision nodes, and thus we + // cannot associate an index into the `nodes_` vector with them. Instead, we + // represent them using sentinel values, chosen maximally to avoid collisions + // with proper indices. + enum Sentinel : uint32_t { + // Encodes the empty set of packets. + kEmptySet = std::numeric_limits::max(), + // Encodes the full set of packets. + kFullSet = std::numeric_limits::max() - 1, + // The minimum sentinel node index. + // Smaller values are reserved for proper indices into the `nodes_` vector. + kMinSentinel = kFullSet, + }; + + // Default constructor: the empty set of packets. + PacketSetHandle() : node_index_(kEmptySet) {} + + // Two packet set handles compare equal iff they represent the same set of + // concrete packets. Comparison is O(1), thanks to interning/hash-consing. + friend auto operator<=>(PacketSetHandle a, PacketSetHandle b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, PacketSetHandle packet_set) { + return H::combine(std::move(h), packet_set.node_index_); + } + + // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. + // NOTE: These functions do not produce particularly useful output. Instead, + // use `PacketSetManager::ToString(packet_set)` whenever possible. + template + friend void AbslStringify(Sink& sink, PacketSetHandle packet_set) { + absl::Format(&sink, "%s", packet_set.ToString()); + } + std::string ToString() const { + if (node_index_ == kEmptySet) { + return "PacketSetHandle"; + } else if (node_index_ == kFullSet) { + return "PacketSetHandle"; + } else { + return absl::StrFormat("PacketSetHandle<%d>", node_index_); + } + } + + private: + // An index into the `nodes_` vector of the `PacketSetManager` object + // associated with this `PacketSetHandle`. The semantics of this packet set + // is entirely determined by the node `nodes_[node_index_]`. The index is + // otherwise arbitrary and meaningless. + // + // We use a 32-bit index as a tradeoff between minimizing memory usage and + // maximizing the number of `PacketSetHandle`s that can be created, both + // aspects that impact how well we scale to large NetKAT models. We expect + // millions, but not billions, of packet sets in practice, and 2^32 ~= 4 + // billion. + uint32_t node_index_; + explicit PacketSetHandle(uint32_t node_index) : node_index_(node_index) {} + friend class PacketSetManager; +}; + +// Protect against regressions in the memory layout, as it affects performance. +static_assert(sizeof(PacketSetHandle) <= 4); + +} // namespace netkat + +#endif // GOOGLE_NETKAT_NETKAT_PACKET_SET_HANDLE_H_ diff --git a/netkat/packet_set_test.cc b/netkat/packet_set_test.cc index 40731f6..2ef3a77 100644 --- a/netkat/packet_set_test.cc +++ b/netkat/packet_set_test.cc @@ -30,8 +30,10 @@ #include "gtest/gtest.h" #include "gutil/status_matchers.h" // IWYU pragma: keep #include "netkat/evaluator.h" +#include "netkat/gtest_utils.h" #include "netkat/netkat_proto_constructors.h" #include "netkat/packet.h" +#include "netkat/packet_transformer.h" #include "re2/re2.h" namespace netkat { @@ -40,8 +42,8 @@ namespace netkat { // test cases. This also enables better pretty printing for debugging, see // `PrintTo`. PacketSetManager& Manager() { - static absl::NoDestructor manager; - return *manager; + static absl::NoDestructor transformer; + return transformer->GetPacketSetManager(); } // The default `PacketSetHandle` pretty printer sucks! It does not have access @@ -55,6 +57,9 @@ void PrintTo(PacketSetHandle packet, std::ostream* os) { namespace { +// TODO: anthonyroy - Revert CompilationPreservesSemantics to +// ArbitraryValidPredicateProto once Pull is implemented in the evaluator. +using ::netkat::netkat_test::ArbitraryValidPredicateProtoWithoutPull; using ::testing::Ge; using ::testing::Pair; using ::testing::SizeIs; @@ -141,7 +146,9 @@ void CompilationPreservesSemantics(const PredicateProto& pred, EXPECT_EQ(Manager().Contains(Manager().Compile(pred), packet), Evaluate(pred, packet)); } -FUZZ_TEST(PacketSetManagerTest, CompilationPreservesSemantics); +FUZZ_TEST(PacketSetManagerTest, CompilationPreservesSemantics) + .WithDomains(ArbitraryValidPredicateProtoWithoutPull(), + fuzztest::Arbitrary()); void GetConcretePacketsReturnsNonEmptyListForNonEmptySet( const PredicateProto& pred) { diff --git a/netkat/packet_set_test_runner.cc b/netkat/packet_set_test_runner.cc index 30fd670..c7dfeb9 100644 --- a/netkat/packet_set_test_runner.cc +++ b/netkat/packet_set_test_runner.cc @@ -22,6 +22,7 @@ #include "netkat/netkat_proto_constructors.h" #include "netkat/packet_set.h" +#include "netkat/packet_transformer.h" namespace netkat { namespace { @@ -84,7 +85,8 @@ std::vector TestCases() { void main() { // This test needs a deterministic field interning order, and thus must start // from a fresh manager. - PacketSetManager manager; + PacketTransformerManager transformer; + PacketSetManager& manager = transformer.GetPacketSetManager(); for (const TestCase& test_case : TestCases()) { netkat::PacketSetHandle packet_set = manager.Compile(test_case.predicate); std::cout << kBanner << "Test case: " << test_case.description << std::endl diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index a125c8b..d49c565 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -41,34 +41,32 @@ #include "netkat/packet.h" #include "netkat/packet_field.h" #include "netkat/packet_set.h" +#include "netkat/packet_transformer_handle.h" namespace netkat { -// The Deny and Accept transformers are not decision nodes, and thus we cannot -// associate an index into the `nodes_` vector with them. Instead, we represent -// them using sentinel values, chosen maximally to avoid collisions with proper -// indices. -enum SentinelNodeIndex : uint32_t { - // Encodes the Deny transformer. - kDeny = std::numeric_limits::max(), - // Encodes the Accept transformer. - kAccept = std::numeric_limits::max() - 1, - // The minimum sentinel node index. - // Smaller values are reserved for proper indices into the `nodes_` vector. - kMinSentinel = kAccept, -}; - -PacketTransformerHandle::PacketTransformerHandle() - : node_index_(SentinelNodeIndex::kDeny) {} - -std::string PacketTransformerHandle::ToString() const { - if (node_index_ == SentinelNodeIndex::kDeny) { - return "PacketTransformerHandle"; - } else if (node_index_ == SentinelNodeIndex::kAccept) { - return "PacketTransformerHandle"; - } else { - return absl::StrFormat("PacketTransformerHandle<%d>", node_index_); +PacketTransformerManager::PacketTransformerManager() + : packet_set_manager_(*this) {} + +PacketTransformerManager::PacketTransformerManager( + PacketTransformerManager&& other) + : nodes_(std::move(other.nodes_)), + transformer_by_node_(std::move(other.transformer_by_node_)), + transformer_by_hash_(std::move(other.transformer_by_hash_)), + packet_set_manager_(std::move(other.packet_set_manager_)) { + packet_set_manager_.transformer_ = this; +} + +PacketTransformerManager& PacketTransformerManager::operator=( + PacketTransformerManager&& other) { + if (this != &other) { + nodes_ = std::move(other.nodes_); + transformer_by_node_ = std::move(other.transformer_by_node_); + transformer_by_hash_ = std::move(other.transformer_by_hash_); + packet_set_manager_ = std::move(other.packet_set_manager_); + packet_set_manager_.transformer_ = this; } + return *this; } const PacketTransformerManager::DecisionNode& @@ -176,7 +174,7 @@ PacketTransformerHandle PacketTransformerManager::NodeToTransformer( node, PacketTransformerHandle(nodes_.size())); if (inserted) { nodes_.push_back(std::move(node)); - LOG_IF(DFATAL, nodes_.size() > SentinelNodeIndex::kMinSentinel) + LOG_IF(DFATAL, nodes_.size() > PacketTransformerHandle::kMinSentinel) << "Internal invariant violated: Proper and sentinel node indices must " "be disjoint. This indicates that we allocated more nodes than are " "supported (> 2^32 - 2)."; @@ -322,11 +320,11 @@ PacketTransformerHandle PacketTransformerManager::Compile( } PacketTransformerHandle PacketTransformerManager::Deny() const { - return PacketTransformerHandle(SentinelNodeIndex::kDeny); + return PacketTransformerHandle(PacketTransformerHandle::kDeny); } PacketTransformerHandle PacketTransformerManager::Accept() const { - return PacketTransformerHandle(SentinelNodeIndex::kAccept); + return PacketTransformerHandle(PacketTransformerHandle::kAccept); } PacketTransformerHandle PacketTransformerManager::FromPacketSetHandle( @@ -771,8 +769,8 @@ PacketTransformerHandle PacketTransformerManager::Iterate( PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( PacketTransformerHandle transformer) { - if (IsAccept(transformer)) return PacketSetManager().FullSet(); - if (IsDeny(transformer)) return PacketSetManager().EmptySet(); + if (IsAccept(transformer)) return packet_set_manager_.FullSet(); + if (IsDeny(transformer)) return packet_set_manager_.EmptySet(); const DecisionNode& node = GetNodeOrDie(transformer); PacketSetHandle default_output = @@ -821,7 +819,7 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // C.3 Push and Pull in KATch: A Fast Symbolic Verifier for NetKAT. for (auto& [match_value, unused] : node.modify_branch_by_field_match) { if (!branch_modify_values.contains(match_value)) { - add_to_output_by_field_value(match_value, PacketSetManager().EmptySet()); + add_to_output_by_field_value(match_value, packet_set_manager_.EmptySet()); } } @@ -862,8 +860,8 @@ PacketSetHandle PacketTransformerManager::Push( PacketSetHandle PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( PacketTransformerHandle transformer) { - if (IsAccept(transformer)) return PacketSetManager().FullSet(); - if (IsDeny(transformer)) return PacketSetManager().EmptySet(); + if (IsAccept(transformer)) return packet_set_manager_.FullSet(); + if (IsDeny(transformer)) return packet_set_manager_.EmptySet(); const DecisionNode& node = GetNodeOrDie(transformer); @@ -1034,20 +1032,20 @@ std::string PacketTransformerManager::ToDot( if (IsAccept(transformer)) { absl::StrAppendFormat(&result, " %d [label=\"T\" shape=box]\n", - SentinelNodeIndex::kAccept); + PacketTransformerHandle::kAccept); absl::StrAppend(&result, "}\n"); return result; } if (IsDeny(transformer)) { absl::StrAppendFormat(&result, " %d [label=\"F\" shape=box]\n", - SentinelNodeIndex::kDeny); + PacketTransformerHandle::kDeny); absl::StrAppend(&result, "}\n"); return result; } absl::StrAppendFormat(&result, " %d [label=\"T\" shape=box]\n", - SentinelNodeIndex::kAccept); + PacketTransformerHandle::kAccept); absl::StrAppendFormat(&result, " %d [label=\"F\" shape=box]\n", - SentinelNodeIndex::kDeny); + PacketTransformerHandle::kDeny); std::queue work_list; work_list.push(transformer); absl::flat_hash_set visited = {transformer}; @@ -1065,9 +1063,9 @@ std::string PacketTransformerManager::ToDot( transformer.node_index_, field); for (const auto& [value, modify_map] : node.modify_branch_by_field_match) { if (modify_map.empty()) { - absl::StrAppendFormat(&result, " %d -> %d [label=\"%s==%s\"]\n", - transformer.node_index_, SentinelNodeIndex::kDeny, - field, absl::StrCat(value)); + absl::StrAppendFormat( + &result, " %d -> %d [label=\"%s==%s\"]\n", transformer.node_index_, + PacketTransformerHandle::kDeny, field, absl::StrCat(value)); } for (const auto& [new_value, branch] : modify_map) { absl::StrAppendFormat(&result, @@ -1102,7 +1100,7 @@ std::string PacketTransformerManager::ToDot( absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. - RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel); + RET_CHECK(nodes_.size() <= PacketTransformerHandle::kMinSentinel); // Invariant: `transformer_by_node_[n] = s` iff `nodes_[s.node_index_] == // n`. diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 4c9c90d..e1cf487 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -16,8 +16,8 @@ // File: packet_transformer.h // ----------------------------------------------------------------------------- // -// Defines `PacketTransformerHandle` and its companion class -// `PacketTransformerManager` following the manager-class pattern described in +// Defines `PacketTransformerManager`, the companion class to +// `PacketTransformerHandle` following the manager-handle pattern described in // `manager_handle_pattern.md`. Together, they provide a compact and efficient // representation of record-free policies allowing for fast semantic equality // checks. Semantically, a `PacketTransformerHandle` represents a function that @@ -52,80 +52,11 @@ #include "netkat/packet.h" #include "netkat/packet_field.h" #include "netkat/packet_set.h" +#include "netkat/packet_transformer_handle.h" #include "netkat/paged_stable_vector.h" namespace netkat { -// A "packet transformer" is a lightweight handle (32 bits) that -// represents a record-free policy (or functions from packets to sets of output -// packets). Handles can only be created by a `PacketTransformerManager` -// object, which owns the graph-based representation of the set. The -// representation can efficiently encode typical large and even infinite sets -// seen in practice. -// -// The APIs of this object are almost entirely defined as methods of the -// companion class `PacketTransformerManager` following the -// manager-handle pattern, see `manager_handle_pattern.md`. -// -// CAUTION: Each `PacketTransformerHandle` is implicitly associated with the -// manager object that created it; using it with a different manager has -// undefined behavior. -// -// This data structure enjoys the following powerful *canonicity property*: two -// packet transformers represent the policy if and only if they have -// the same memory representation. Since the memory representation is just 32 -// bits, semantic policy equality is cheap: O(1)! -// -// Compared to NetKAT policies, packet transformers have a few -// advantages: -// * Cheap to store, copy, hash, and compare: O(1) -// * Cheap to check semantic equality: O(1) -class [[nodiscard]] PacketTransformerHandle { - public: - // Default constructor: the Deny policy. - PacketTransformerHandle(); - - // Two packet transformers compare equal iff they represent the same - // record-free policy (semantically). That is, two policies are equal iff they - // are semantically equivalent when Record is replaced by Accept. Comparison - // is O(1), thanks to interning/hash-consing. - friend auto operator<=>(PacketTransformerHandle a, - PacketTransformerHandle b) = default; - - // Hashing, see https://abseil.io/docs/cpp/guides/hash. - template - friend H AbslHashValue(H h, PacketTransformerHandle transformer) { - return H::combine(std::move(h), transformer.node_index_); - } - - // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. - // NOTE: These functions do not produce particularly useful output. Instead, - // use `PacketTransformerManager::ToString(transformer)` whenever - // possible. - template - friend void AbslStringify(Sink& sink, PacketTransformerHandle transformer) { - absl::Format(&sink, "%s", transformer.ToString()); - } - std::string ToString() const; - - private: - // An index into the `nodes_` vector of the `PacketTransformerManager` - // object associated with this `PacketTransformerHandle`. The semantics of - // this packet transformer is entirely determined by the node - // `nodes_[node_index_]`. The index is otherwise arbitrary and meaningless. - // - // We use a 32-bit index as a tradeoff between minimizing memory usage and - // maximizing the number of `PacketTransformerHandle`s that can be created, - // both aspects that impact how well we scale to large NetKAT models. - uint32_t node_index_; - explicit PacketTransformerHandle(uint32_t node_index) - : node_index_(node_index) {} - friend class PacketTransformerManager; -}; - -// Protect against regressions in the memory layout, as it affects performance. -static_assert(sizeof(PacketTransformerHandle) <= 4); - // An "arena" in which `PacketTransformerHandle`s can be created and // manipulated, following the manager-handle pattern (see // `manager_handle_pattern.md`). @@ -138,23 +69,25 @@ static_assert(sizeof(PacketTransformerHandle) <= 4); // `PacketTransformerManager` object with a different manager is // undefined behavior. `PacketSetHandles` and `PacketTransformerHandles` // returned by this class are not invalidated on move. + class PacketTransformerManager { public: - PacketTransformerManager() = default; - explicit PacketTransformerManager(PacketSetManager&& manager) - : packet_set_manager_(std::move(manager)) {}; + PacketTransformerManager(); // The class is move-only: not copyable, but movable. // `PacketSetHandles` and `PacketTransformerHandles` returned by this class // are not invalidated on move. PacketTransformerManager(const PacketTransformerManager&) = delete; PacketTransformerManager& operator=(const PacketTransformerManager&) = delete; - PacketTransformerManager(PacketTransformerManager&&) = default; - PacketTransformerManager& operator=(PacketTransformerManager&&) = default; + PacketTransformerManager(PacketTransformerManager&& other); + PacketTransformerManager& operator=(PacketTransformerManager&& other); // Returns the `PacketSetManager` used by this object to compile // predicates. PacketSetManager& GetPacketSetManager() { return packet_set_manager_; } + const PacketSetManager& GetPacketSetManager() const { + return packet_set_manager_; + } // Returns true iff this transformer represents the Deny policy. bool IsDeny(PacketTransformerHandle transformer) const; diff --git a/netkat/packet_transformer_handle.h b/netkat/packet_transformer_handle.h new file mode 100644 index 0000000..0b587a7 --- /dev/null +++ b/netkat/packet_transformer_handle.h @@ -0,0 +1,124 @@ +// Copyright 2025 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: packet_transformer_handle.h +// ----------------------------------------------------------------------------- +// +// Defines `PacketTransformerHandle`, a lightweight handle representing a +// record-free policy. +// +// Together with its companion class `PacketTransformerManager` (defined in +// `packet_transformer.h`), they provide a compact and efficient representation +// of record-free policies allowing for fast semantic equality checks. + +#ifndef GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_HANDLE_H_ +#define GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_HANDLE_H_ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" + +namespace netkat { + +// A "packet transformer" is a lightweight handle (32 bits) that +// represents a record-free policy (or functions from packets to sets of output +// packets). Handles can only be created by a `PacketTransformerManager` +// object, which owns the graph-based representation of the set. The +// representation can efficiently encode typical large and even infinite sets +// seen in practice. +// +// The APIs of this object are almost entirely defined as methods of the +// companion class `PacketTransformerManager` following the +// manager-handle pattern, see `manager_handle_pattern.md`. +// +// CAUTION: Each `PacketTransformerHandle` is implicitly associated with the +// manager object that created it; using it with a different manager has +// undefined behavior. +// +// This data structure enjoys the following powerful *canonicity property*: two +// packet transformers represent the policy if and only if they have +// the same memory representation. Since the memory representation is just 32 +// bits, semantic policy equality is cheap: O(1)! +// +// Compared to NetKAT policies, packet transformers have a few +// advantages: +// * Cheap to store, copy, hash, and compare: O(1) +// * Cheap to check semantic equality: O(1) +class [[nodiscard]] PacketTransformerHandle { + public: + // The deny and accept policies are not decision nodes, and thus we cannot + // associate an index into the `nodes_` vector with them. Instead, we + // represent them using sentinel values, chosen maximally to avoid collisions + // with proper indices. + enum Sentinel : uint32_t { + // Encodes the Deny transformer. + kDeny = std::numeric_limits::max(), + // Encodes the Accept transformer. + kAccept = std::numeric_limits::max() - 1, + // The minimum sentinel node index. + // Smaller values are reserved for proper indices into the `nodes_` vector. + kMinSentinel = kAccept, + }; + + // Default constructor: the Deny policy. + PacketTransformerHandle() : node_index_(kDeny) {} + + // Two packet transformers compare equal iff they represent the same + // record-free policy (semantically). + friend auto operator<=>(PacketTransformerHandle a, + PacketTransformerHandle b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, PacketTransformerHandle transformer) { + return H::combine(std::move(h), transformer.node_index_); + } + + // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. + template + friend void AbslStringify(Sink& sink, PacketTransformerHandle transformer) { + absl::Format(&sink, "%s", transformer.ToString()); + } + std::string ToString() const { + if (node_index_ == kDeny) { + return "PacketTransformerHandle"; + } else if (node_index_ == kAccept) { + return "PacketTransformerHandle"; + } else { + return absl::StrFormat("PacketTransformerHandle<%d>", node_index_); + } + } + + private: + // An index into the `nodes_` vector of the `PacketTransformerManager` + // object associated with this `PacketTransformerHandle`. + uint32_t node_index_; + + explicit PacketTransformerHandle(uint32_t node_index) + : node_index_(node_index) {} + + friend class PacketTransformerManager; + friend class PacketSetManager; +}; + +// Protect against regressions in the memory layout, as it affects performance. +static_assert(sizeof(PacketTransformerHandle) <= 4); + +} // namespace netkat + +#endif // GOOGLE_NETKAT_NETKAT_PACKET_TRANSFORMER_HANDLE_H_ diff --git a/netkat/packet_transformer_test.cc b/netkat/packet_transformer_test.cc index 0e7b0fc..de7cee9 100644 --- a/netkat/packet_transformer_test.cc +++ b/netkat/packet_transformer_test.cc @@ -30,10 +30,12 @@ #include "gtest/gtest.h" #include "gutil/status_matchers.h" // IWYU pragma: keep #include "netkat/evaluator.h" +#include "netkat/gtest_utils.h" #include "netkat/netkat.pb.h" #include "netkat/netkat_proto_constructors.h" #include "netkat/packet.h" #include "netkat/packet_set.h" +#include "netkat/packet_transformer_handle.h" #include "re2/re2.h" namespace netkat { @@ -60,6 +62,25 @@ namespace { using ::fuzztest::Arbitrary; using ::fuzztest::ElementOf; +using ::netkat::netkat_test::ArbitraryValidPolicyProtoWithoutPull; +using ::netkat::netkat_test::FieldTypeIs; + +fuzztest::Domain PredicateWithRestrictedFields() { + return fuzztest::Arbitrary() + .WithFieldsAlwaysSet() + .WithStringFields(ElementOf({"f", "g"})) + .WithInt32Fields(ElementOf({1, 2, 3})); +} + +fuzztest::Domain PolicyWithRestrictedFields() { + auto predicate_domain = PredicateWithRestrictedFields(); + return fuzztest::Arbitrary() + .WithFieldsAlwaysSet() + .WithStringFields(ElementOf({"f", "g"})) + .WithInt32Fields(ElementOf({1, 2, 3})) + .WithProtobufFields(FieldTypeIs, predicate_domain); +} +using ::fuzztest::ElementOf; using ::testing::ContainerEq; using ::testing::IsEmpty; using ::testing::Pair; @@ -114,6 +135,35 @@ TEST(PacketTransformerManagerTest, AbslHashValueWorks) { EXPECT_EQ(set.size(), 2); } +TEST(PacketTransformerManagerTest, MoveConstructorPreservesState) { + PacketTransformerManager manager1; + PacketTransformerHandle h1 = manager1.Compile(ModificationProto("f", 1)); + PacketTransformerHandle h2 = manager1.Compile(ModificationProto("f", 2)); + + PacketTransformerManager manager2 = std::move(manager1); + + EXPECT_THAT(manager2.ToString(h1), StartsWith("PacketTransformerHandle")); + EXPECT_THAT(manager2.ToString(h2), StartsWith("PacketTransformerHandle")); + + EXPECT_EQ(manager2.Compile(ModificationProto("f", 1)), h1); + EXPECT_EQ(manager2.Compile(ModificationProto("f", 2)), h2); +} + +TEST(PacketTransformerManagerTest, MoveAssignmentPreservesState) { + PacketTransformerManager manager1; + PacketTransformerHandle h1 = manager1.Compile(ModificationProto("f", 1)); + PacketTransformerHandle h2 = manager1.Compile(ModificationProto("f", 2)); + + PacketTransformerManager manager2; + manager2 = std::move(manager1); + + EXPECT_THAT(manager2.ToString(h1), StartsWith("PacketTransformerHandle")); + EXPECT_THAT(manager2.ToString(h2), StartsWith("PacketTransformerHandle")); + + EXPECT_EQ(manager2.Compile(ModificationProto("f", 1)), h1); + EXPECT_EQ(manager2.Compile(ModificationProto("f", 2)), h2); +} + TEST(PacketTransformerManagerTest, EmptyPolicyCompilesToDeny) { EXPECT_TRUE(Manager().IsDeny(Manager().Compile(PolicyProto()))); } @@ -128,13 +178,6 @@ void CompileIsSameAsOfCompiledPacketSetHandle(PredicateProto predicate) { PacketSetHandle set_1 = Manager().GetPacketSetManager().Compile(predicate); EXPECT_EQ(Manager().Compile(FilterProto(predicate)), Manager().FromPacketSetHandle(set_1)); - - // Using a newly constructed PacketSetManager. - PacketSetManager packet_set_manager; - PacketSetHandle set_2 = packet_set_manager.Compile(predicate); - PacketTransformerManager manager(std::move(packet_set_manager)); - EXPECT_EQ(manager.Compile(FilterProto(predicate)), - manager.FromPacketSetHandle(set_2)); } FUZZ_TEST(PacketTransformerManagerTest, CompileIsSameAsOfCompiledPacketSetHandle); @@ -442,7 +485,8 @@ void RunIsSameAsEvaluate(PolicyProto policy, Packet packet) { ContainerEq(Evaluate(policy, packet))); EXPECT_EQ(packet, original_packet); } -FUZZ_TEST(PacketTransformerManagerTest, RunIsSameAsEvaluate); +FUZZ_TEST(PacketTransformerManagerTest, RunIsSameAsEvaluate) + .WithDomains(ArbitraryValidPolicyProtoWithoutPull(), Arbitrary()); TEST(PacketTransformerManagerTest, SimpleSequenceRunTest1) { // !(once=1) ; a:=1 ; once:=1 @@ -843,17 +887,7 @@ void PacketsFromRunAreInPushPacketSet(PredicateProto predicate, } } FUZZ_TEST(PacketTransformerManagerTest, PacketsFromRunAreInPushPacketSet) - // We restrict to two field names and three field value to increases the - // likelihood for coverage for predicates/policies that match/modify the - // same field several times. - .WithDomains(Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3})), - Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3}))); + .WithDomains(PredicateWithRestrictedFields(), PolicyWithRestrictedFields()); void PulledPacketGetsRunThroughTransformerBelongsToInputPacketSet( PredicateProto predicate, PolicyProto policy) { @@ -882,17 +916,7 @@ void PulledPacketGetsRunThroughTransformerBelongsToInputPacketSet( } FUZZ_TEST(PacketTransformerManagerTest, PulledPacketGetsRunThroughTransformerBelongsToInputPacketSet) - // We restrict to two field names and three field value to increases the - // likelihood for coverage for policies that modify the same field several - // times. - .WithDomains(Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3})), - Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3}))); + .WithDomains(PredicateWithRestrictedFields(), PolicyWithRestrictedFields()); void PushOnFilterIsSameAsAnd(PredicateProto left, PredicateProto right) { PacketSetHandle left_set = Manager().GetPacketSetManager().Compile(left); @@ -901,17 +925,8 @@ void PushOnFilterIsSameAsAnd(PredicateProto left, PredicateProto right) { Manager().GetPacketSetManager().And(left_set, right_set)); } FUZZ_TEST(PacketTransformerManagerTest, PushOnFilterIsSameAsAnd) - // We restrict to two field names and three field value to increases the - // likelihood for coverage for policies that modify the same field several - // times. - .WithDomains(Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3})), - Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3}))); + .WithDomains(PredicateWithRestrictedFields(), + PredicateWithRestrictedFields()); void PushAndPullRoundTrippingHoldsForFullSet(PolicyProto policy) { PacketTransformerHandle transformer = Manager().Compile(policy); @@ -922,13 +937,7 @@ void PushAndPullRoundTrippingHoldsForFullSet(PolicyProto policy) { Manager().Pull(transformer, Manager().Push(full_set, transformer))); } FUZZ_TEST(PacketTransformerManagerTest, PushAndPullRoundTrippingHoldsForFullSet) - // We restrict to two field names and three field value to increases the - // likelihood for coverage for policies that modify the same field several - // times. - .WithDomains(Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3}))); + .WithDomains(PolicyWithRestrictedFields()); } // namespace @@ -942,9 +951,9 @@ class PacketTransformerManagerTestPeer { PacketSetHandle GetAllPossibleOutputPacketsReferenceImplementation( PacketTransformerHandle transformer) { if (packet_transformer_manager_->IsAccept(transformer)) - return PacketSetManager().FullSet(); + return packet_transformer_manager_->GetPacketSetManager().FullSet(); if (packet_transformer_manager_->IsDeny(transformer)) - return PacketSetManager().EmptySet(); + return packet_transformer_manager_->GetPacketSetManager().EmptySet(); const PacketTransformerManager::DecisionNode& node = packet_transformer_manager_->GetNodeOrDie(transformer); const std::string field = packet_transformer_manager_->GetPacketSetManager() @@ -994,7 +1003,8 @@ class PacketTransformerManagerTestPeer { // 1. input.field != match_value for all explicit match branches, thus // output.field != match_value for all explicit match branches by (0) // 2. output.field != modify_value for all default-modify branches - PacketSetHandle fallthrough_output = PacketSetManager().FullSet(); + PacketSetHandle fallthrough_output = + packet_transformer_manager_->GetPacketSetManager().FullSet(); for (const auto& [match_value, unused] : node.modify_branch_by_field_match) { fallthrough_output = @@ -1027,12 +1037,16 @@ void GetAllPossibleOutputPacketsIsSameAsReferenceImplementation( } FUZZ_TEST(PacketTransformerManagerTest, GetAllPossibleOutputPacketsIsSameAsReferenceImplementation) - // We restrict to two field names and three field value to increases the - // likelihood for coverage for policies that modify the same field several - // times. - .WithDomains(Arbitrary() - .WithFieldsAlwaysSet() - .WithStringFields(ElementOf({"f", "g"})) - .WithInt32Fields(ElementOf({1, 2, 3}))); + .WithDomains(PolicyWithRestrictedFields()); + +void CompilePullIsCorrect(PolicyProto policy, PredicateProto predicate) { + EXPECT_EQ( + Manager().GetPacketSetManager().Compile(PullProto(policy, predicate)), + Manager().Pull(Manager().Compile(policy), + Manager().GetPacketSetManager().Compile(predicate))); +} +FUZZ_TEST(PacketTransformerManagerTest, CompilePullIsCorrect) + .WithDomains(PolicyWithRestrictedFields(), PredicateWithRestrictedFields()); + } // namespace } // namespace netkat diff --git a/netkat/table_test.cc b/netkat/table_test.cc index f6fcb64..157201a 100644 --- a/netkat/table_test.cc +++ b/netkat/table_test.cc @@ -34,6 +34,10 @@ namespace netkat { namespace { +// TODO: anthonyroy - Revert to ArbitraryValidPredicateProto once Pull is +// implemented. +using ::netkat::netkat_test::ArbitraryValidPredicateProtoWithoutPull; + using ::gutil::EqualsProto; using ::gutil::IsOk; using ::gutil::StatusIs; @@ -153,7 +157,7 @@ void RuleWithFilterIsInvalid(PredicateProto predicate) { StatusIs(absl::StatusCode::kInvalidArgument)); } FUZZ_TEST(NetkatTableTest, RuleWithFilterIsInvalid) - .WithDomains(netkat_test::ArbitraryValidPredicateProto()); + .WithDomains(ArbitraryValidPredicateProtoWithoutPull()); void RuleWithDropActionIsValid(PredicateProto match) { ASSERT_OK_AND_ASSIGN(Predicate pred, Predicate::FromProto(match)); @@ -162,7 +166,7 @@ void RuleWithDropActionIsValid(PredicateProto match) { EXPECT_THAT(table.AddRule(/*priority=*/10, pred, Policy::Deny()), IsOk()); } FUZZ_TEST(NetkatTableTest, RuleWithDropActionIsValid) - .WithDomains(netkat_test::ArbitraryValidPredicateProto()); + .WithDomains(ArbitraryValidPredicateProtoWithoutPull()); TEST(NetkatTable, NonDeterministicRuleRejected) { NetkatTable table;