diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index 298f37b..678e44b 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -662,12 +662,9 @@ PacketSetHandle PacketTransformerManager::GetAllPossibleOutputPackets( // Case 2: Output packets that hit an explicit branch and got modified. // Implements the `b_B` in the `fwd` function in section C.3 Push and Pull in // KATch: A Fast Symbolic Verifier for NetKAT. - // absl::flat_hash_set branch_match_values; absl::flat_hash_set branch_modify_values; for (const auto& [match_value, branch_by_modify] : node.modify_branch_by_field_match) { - // branch_match_values.insert(match_value); - // Maps modify values --> fwd(branch). for (const auto& [modify_value, branch] : branch_by_modify) { branch_modify_values.insert(modify_value); add_to_output_by_field_value(modify_value, @@ -730,6 +727,87 @@ PacketSetHandle PacketTransformerManager::Push( Sequence(FromPacketSetHandle(input_packets), transformer)); } +PacketSetHandle +PacketTransformerManager::GetAllInputPacketsThatProduceAnyOutput( + PacketTransformerHandle transformer) { + if (IsAccept(transformer)) return PacketSetManager().FullSet(); + if (IsDeny(transformer)) return PacketSetManager().EmptySet(); + + const DecisionNode& node = GetNodeOrDie(transformer); + + // Case 1: Input packets that hit the default branch and got modified. + // Implements the `d'` in the `bwd` function in section C.3 Push and Pull in + // KATch: A Fast Symbolic Verifier for NetKAT. + PacketSetHandle default_branch_output_packets; + for (const auto& [modify_value, branch] : + node.default_branch_by_field_modification) { + default_branch_output_packets = + packet_set_manager_.Or(default_branch_output_packets, + GetAllInputPacketsThatProduceAnyOutput(branch)); + } + + // Case 2: Input packets that hit an explicit branch and got modified. + // Implements the `b_A` in the `bwd` function in section C.3 Push and Pull in + // KATch: A Fast Symbolic Verifier for NetKAT. + absl::flat_hash_map branch_by_field_value_map; + for (const auto& [match_value, branch_by_modify] : + node.modify_branch_by_field_match) { + PacketSetHandle union_of_branches; + for (const auto& [modify_value, branch] : branch_by_modify) { + union_of_branches = packet_set_manager_.Or( + union_of_branches, GetAllInputPacketsThatProduceAnyOutput(branch)); + } + branch_by_field_value_map[match_value] = union_of_branches; + } + + // Case 3: Input packets that do not get matched on an explicit branch, but + // do get modified. + // Implements the `b_B` in the `bwd` function in section C.3 Push and Pull in + // KATch: A Fast Symbolic Verifier for NetKAT. + for (const auto& [modify_value, unused] : + node.default_branch_by_field_modification) { + if (!node.modify_branch_by_field_match.contains(modify_value)) { + branch_by_field_value_map[modify_value] = default_branch_output_packets; + } + } + + PacketSetHandle default_branch = packet_set_manager_.Or( + default_branch_output_packets, + GetAllInputPacketsThatProduceAnyOutput(node.default_branch)); + int num_branches = 0; + for (const auto& [value, branch] : branch_by_field_value_map) { + if (branch != default_branch) num_branches++; + } + absl::FixedArray, 0> + branch_by_field_value_list(num_branches); + int i = 0; + for (const auto& [value, branch] : branch_by_field_value_map) { + // Skips `default_branch` because an invariant of `DecisionNode` is that no + // branch in `branch_by_field_value` can be a duplicate of the default + // branch. + if (branch == default_branch) continue; + branch_by_field_value_list[i++] = std::make_pair(value, branch); + } + + // Required to sort `branch_by_field_value_list` to ensure that it meets the + // invariant of the `DecisionNode`'s `branch_by_field_value`. + absl::c_sort(branch_by_field_value_list, [](auto& left, auto& right) { + return left.first < right.first; + }); + + return packet_set_manager_.NodeToPacket({ + .field = node.field, + .default_branch = default_branch, + .branch_by_field_value = std::move(branch_by_field_value_list), + }); +} + +PacketSetHandle PacketTransformerManager::Pull( + PacketTransformerHandle transformer, PacketSetHandle output_packets) { + return GetAllInputPacketsThatProduceAnyOutput( + Sequence(transformer, FromPacketSetHandle(output_packets))); +} + std::string PacketTransformerManager::ToString(const DecisionNode& node) const { std::string result; std::vector work_list; diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 046e5dd..9983e89 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -222,11 +222,24 @@ class PacketTransformerManager { PacketSetHandle GetAllPossibleOutputPackets( PacketTransformerHandle transformer); + // Computes the set of possible input packets that when run through the given + // transformer produce a non-empty set of outputs. Equivalent to + // `Pull(transformer, manager::FullSet())`. + PacketSetHandle GetAllInputPacketsThatProduceAnyOutput( + PacketTransformerHandle transformer); + // Returns set of output packets obtained by applying the given `transformer` // to the given `input_packets`. PacketSetHandle Push(PacketSetHandle input_packets, PacketTransformerHandle transformer); + // Returns the set of input packets obtained by applying the given + // `transformer` in reverse on the given `output_packets`. More formally, + // returns the set of input packets that produce one or more output packets + // contained in `output_packets`. + PacketSetHandle Pull(PacketTransformerHandle transformer, + PacketSetHandle output_packets); + // TODO(b/398373935): There are many additional operations supported by this // data structure, but not currently implemented. Add them as needed. Examples // below include Intersection, Difference, and SymmetricDifference. diff --git a/netkat/packet_transformer_test.cc b/netkat/packet_transformer_test.cc index 8ff7113..c3e6d8d 100644 --- a/netkat/packet_transformer_test.cc +++ b/netkat/packet_transformer_test.cc @@ -63,6 +63,7 @@ using ::testing::ContainerEq; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::StartsWith; +using ::testing::Truly; using ::testing::UnorderedElementsAre; // After executing all tests, we check once that no invariants are violated, for @@ -518,6 +519,16 @@ TEST(PacketTransformerManagerTest, PushThroughModifyIsCorrect) { packet_set_manager.And(g_24, f_42)); } +TEST(PacketTransformerManagerTest, PullThroughModifyIsCorrect) { + PacketSetManager& packet_set_manager = Manager().GetPacketSetManager(); + PacketSetHandle f_24 = packet_set_manager.Match("f", 24); + PacketSetHandle f_42 = packet_set_manager.Match("f", 42); + PacketTransformerHandle modify_f_42 = Manager().Modification("f", 42); + + EXPECT_THAT(Manager().Pull(modify_f_42, f_42), packet_set_manager.FullSet()); + EXPECT_THAT(Manager().Pull(modify_f_42, f_24), packet_set_manager.EmptySet()); +} + TEST(PacketTransformerManagerTest, PacketsPushedThroughSequenceAndUnionTransformersAreCorrect) { PacketSetManager& packet_set_manager = Manager().GetPacketSetManager(); @@ -592,6 +603,83 @@ TEST(PacketTransformerManagerTest, packet_set_manager.EmptySet()); } +TEST(PacketTransformerManagerTest, + PacketsPulledThroughSequenceAndUnionTransformersAreCorrect) { + PacketSetManager& packet_set_manager = Manager().GetPacketSetManager(); + + // a=1 ; a:=0 + PacketTransformerHandle check_a = Manager().Compile(SequenceProto( + FilterProto(MatchProto("a", 1)), ModificationProto("a", 0))); + + // Does `a:=1` exactly once. + // !(once=1) ; a:=1 ; once:=1 + PacketTransformerHandle a_once = Manager().Compile(SequenceProto( + FilterProto(NotProto(MatchProto("once", 1))), + SequenceProto(ModificationProto("a", 1), ModificationProto("once", 1)))); + + PacketTransformerHandle check_a_and_a_once_transformer = + Manager().Union(check_a, a_once); + + PacketSetHandle packet_with_once_0 = + packet_set_manager.Compile(MatchProto("once", 0)); + PacketSetHandle packet_with_once_1 = + packet_set_manager.Compile(MatchProto("once", 1)); + PacketSetHandle packet_with_a_0 = + packet_set_manager.Compile(MatchProto("a", 0)); + PacketSetHandle packet_with_a_1 = + packet_set_manager.Compile(MatchProto("a", 1)); + PacketSetHandle packet_with_a_0_or_1 = + packet_set_manager.Or(packet_with_a_0, packet_with_a_1); + + PacketSetHandle packet_with_once_1_and_a_1 = + packet_set_manager.And(packet_with_once_1, packet_with_a_1); + PacketSetHandle packet_with_once_0_and_a_0 = + packet_set_manager.And(packet_with_once_0, packet_with_a_0); + + // Test `check_a_and_a_once_transformer`. + EXPECT_THAT(Manager().Pull(check_a_and_a_once_transformer, packet_with_a_0), + packet_with_a_1); + + PacketSetHandle expected_packet_set1 = + packet_set_manager.Not(packet_with_once_1); + EXPECT_THAT(Manager().Pull(check_a_and_a_once_transformer, packet_with_a_1), + expected_packet_set1); + EXPECT_THAT(Manager().Pull(check_a_and_a_once_transformer, + packet_with_once_1_and_a_1), + expected_packet_set1); + + PacketSetHandle expected_packet_set2 = + packet_set_manager.And(packet_with_once_0, packet_with_a_1); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, packet_with_once_0), + expected_packet_set2); + EXPECT_THAT(Manager().Pull(check_a_and_a_once_transformer, + packet_with_once_0_and_a_0), + expected_packet_set2); + + PacketSetHandle expected_packet_set3 = + packet_set_manager.Or(expected_packet_set1, packet_with_a_1); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, packet_with_once_1), + expected_packet_set3); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, packet_with_a_0_or_1), + expected_packet_set3); + + // Pull the results through again! + PacketSetHandle expected_packet_set4 = + packet_set_manager.And(expected_packet_set1, packet_with_a_1); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, expected_packet_set1), + expected_packet_set4); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, expected_packet_set2), + Manager().GetPacketSetManager().EmptySet()); + EXPECT_THAT( + Manager().Pull(check_a_and_a_once_transformer, expected_packet_set3), + expected_packet_set1); +} + TEST(PacketTransformerManagerTest, AllTransformedPacketBelongsToPushedPacketSet) { // predicate := (a=5 && b=2) || (b!=5 && c=5) @@ -625,6 +713,47 @@ TEST(PacketTransformerManagerTest, } } +TEST(PacketTransformerManagerTest, + ConcretePacketFromPullGetsRunThroughTransformerBelongsToInputPacketSet) { + // predicate := (a=3 && b=4) || (b!=5 && c=5) + PredicateProto predicate = + OrProto(AndProto(MatchProto("a", 3), MatchProto("b", 4)), + AndProto(NotProto(MatchProto("b", 5)), MatchProto("c", 5))); + PacketSetHandle packet_set = + Manager().GetPacketSetManager().Compile(predicate); + + // policy := (a=5 + b=2);(b:=1 + c=5) + PolicyProto policy = SequenceProto( + UnionProto(FilterProto(MatchProto("a", 5)), + FilterProto(MatchProto("b", 2))), + UnionProto(ModificationProto("b", 1), FilterProto(MatchProto("c", 5)))); + PacketTransformerHandle transformer = Manager().Compile(policy); + + // Get all concrete packets from Pull on a transformer and packet set. + std::vector pulled_concrete_packets = + Manager().GetPacketSetManager().GetConcretePackets( + Manager().Pull(transformer, packet_set)); + + if (pulled_concrete_packets.empty()) { + LOG(INFO) << "SKIPPED: no concrete pulled packets were obtained"; + return; + } + + for (Packet& concrete_packet : pulled_concrete_packets) { + // Run the pulled concrete packets through the transformer. There exists at + // least one transformed concrete packet from the transformed packets that + // belongs to the packet set. + bool packet_exist_in_pulled_packet_set = false; + for (const Packet& transformed_packet : + Manager().Run(transformer, concrete_packet)) { + packet_exist_in_pulled_packet_set |= + Manager().GetPacketSetManager().Contains(packet_set, + transformed_packet); + } + EXPECT_TRUE(packet_exist_in_pulled_packet_set); + } +} + void PacketsFromRunAreInPushPacketSet(PredicateProto predicate, PolicyProto policy) { PacketSetHandle packet_set = @@ -656,6 +785,45 @@ FUZZ_TEST(PacketTransformerManagerTest, PacketsFromRunAreInPushPacketSet) .WithStringFields(ElementOf({"f", "g"})) .WithInt32Fields(ElementOf({1, 2, 3}))); +void PulledPacketGetsRunThroughTransformerBelongsToInputPacketSet( + PredicateProto predicate, PolicyProto policy) { + PacketSetHandle packet_set = + Manager().GetPacketSetManager().Compile(predicate); + PacketTransformerHandle transformer = Manager().Compile(policy); + std::vector pulled_concrete_packets = + Manager().GetPacketSetManager().GetConcretePackets( + Manager().Pull(transformer, packet_set)); + + if (pulled_concrete_packets.empty()) { + LOG(INFO) << "SKIPPED: no concrete pulled packets were obtained"; + return; + } + + for (Packet& concrete_packet : pulled_concrete_packets) { + // Run the pulled concrete packets through the transformer. There exists at + // least one transformed concrete packet from the transformed packets that + // belongs to the packet set. + EXPECT_THAT(Manager().Run(transformer, concrete_packet), + Contains(Truly([&](const Packet& output_packet) { + return Manager().GetPacketSetManager().Contains( + packet_set, output_packet); + }))); + } +} +FUZZ_TEST(PacketTransformerManagerTest, + PulledPacketGetsRunThroughTransformerBelongsToInputPacketSet) + // We restrict to two field names and three field value to increases the + // likelihood for coverage for policies that modify the same field several + // times. + .WithDomains(Arbitrary() + .WithFieldsAlwaysSet() + .WithStringFields(ElementOf({"f", "g"})) + .WithInt32Fields(ElementOf({1, 2, 3})), + Arbitrary() + .WithFieldsAlwaysSet() + .WithStringFields(ElementOf({"f", "g"})) + .WithInt32Fields(ElementOf({1, 2, 3}))); + void PushOnFilterIsSameAsAnd(PredicateProto left, PredicateProto right) { PacketSetHandle left_set = Manager().GetPacketSetManager().Compile(left); PacketSetHandle right_set = Manager().GetPacketSetManager().Compile(right); @@ -675,6 +843,23 @@ FUZZ_TEST(PacketTransformerManagerTest, PushOnFilterIsSameAsAnd) .WithStringFields(ElementOf({"f", "g"})) .WithInt32Fields(ElementOf({1, 2, 3}))); +void PushAndPullRoundTrippingHoldsForFullSet(PolicyProto policy) { + PacketTransformerHandle transformer = Manager().Compile(policy); + PacketSetHandle full_set = Manager().GetPacketSetManager().FullSet(); + EXPECT_EQ(Manager().Push(full_set, transformer), + Manager().Push(Manager().Pull(transformer, full_set), transformer)); + EXPECT_EQ(Manager().Pull(transformer, full_set), + Manager().Pull(transformer, Manager().Push(full_set, transformer))); +} +FUZZ_TEST(PacketTransformerManagerTest, PushAndPullRoundTrippingHoldsForFullSet) + // We restrict to two field names and three field value to increases the + // likelihood for coverage for policies that modify the same field several + // times. + .WithDomains(Arbitrary() + .WithFieldsAlwaysSet() + .WithStringFields(ElementOf({"f", "g"})) + .WithInt32Fields(ElementOf({1, 2, 3}))); + } // namespace // Test peer class to access private methods.