diff --git a/netkat/symbolic_packet.cc b/netkat/symbolic_packet.cc index 42003fd..23ca3ef 100644 --- a/netkat/symbolic_packet.cc +++ b/netkat/symbolic_packet.cc @@ -79,7 +79,9 @@ bool SymbolicPacketManager::IsFullSet(SymbolicPacket packet) const { const SymbolicPacketManager::DecisionNode& SymbolicPacketManager::GetNodeOrDie( SymbolicPacket packet) const { - CHECK_LT(packet.node_index_, nodes_.size()); // Crash ok + CHECK_LT(packet.node_index_, nodes_.size()) + << "Did you call this function on a leaf node (i.e. FullSet() or " + "EmptySet())? "; // Crash ok return nodes_[packet.node_index_]; } @@ -335,6 +337,48 @@ SymbolicPacket SymbolicPacketManager::Xor(SymbolicPacket left, return Or(And(Not(left), right), And(left, Not(right))); } +SymbolicPacket SymbolicPacketManager::Exists(absl::string_view field, + SymbolicPacket packet) { + if (IsFullSet(packet) || IsEmptySet(packet)) return packet; + + // Compute result the hard way. + const DecisionNode node = GetNodeOrDie(packet); + std::string node_field = field_manager_.GetFieldName(node.field); + + // Case 1: `packet` is a member of `Exists(field, *)`: remove the current node + // and return the OR-ing of all branches. + if (node_field == field) { + SymbolicPacket result_packet = node.default_branch; + for (const auto& [_, branch] : node.branch_by_field_value) { + result_packet = Or(result_packet, branch); + } + return result_packet; + } + + // Case 2: calls `packet` is a member of `Exists(field, *)` for all it's + // branches: keep current node and call `Exists` on all branches and exclude a + // branch if it is the same as the default branch. + SymbolicPacket default_branch = Exists(field, node.default_branch); + absl::FixedArray> + non_default_branches_by_field_value(node.branch_by_field_value.size()); + int num_branches = 0; + for (const auto& [value, branch] : node.branch_by_field_value) { + SymbolicPacket non_default_branch = Exists(field, branch); + if (non_default_branch == default_branch) continue; + non_default_branches_by_field_value[num_branches++] = + std::make_pair(value, non_default_branch); + } + + return NodeToPacket(DecisionNode{ + .field = node.field, + .default_branch = default_branch, + .branch_by_field_value{ + non_default_branches_by_field_value.begin(), + non_default_branches_by_field_value.begin() + num_branches, + }, + }); +} + std::string SymbolicPacketManager::ToString(SymbolicPacket packet) const { std::string result; std::queue work_list{{packet}}; diff --git a/netkat/symbolic_packet.h b/netkat/symbolic_packet.h index 6f11fbf..2062066 100644 --- a/netkat/symbolic_packet.h +++ b/netkat/symbolic_packet.h @@ -227,6 +227,10 @@ class SymbolicPacketManager { // set, but not in both. Also known as symmetric set difference. SymbolicPacket Xor(SymbolicPacket left, SymbolicPacket right); + // Return the set of packets whose given `field` gets satisfied by some + // predicate. + SymbolicPacket Exists(absl::string_view field, SymbolicPacket packet); + // Returns a human-readable string representation of the given `packet`, // intended for debugging. [[nodiscard]] std::string ToString(SymbolicPacket packet) const; diff --git a/netkat/symbolic_packet_test.cc b/netkat/symbolic_packet_test.cc index 413028f..1db2005 100644 --- a/netkat/symbolic_packet_test.cc +++ b/netkat/symbolic_packet_test.cc @@ -21,7 +21,6 @@ #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" @@ -331,5 +330,32 @@ void XorIsAssociative(const PredicateProto& a, const PredicateProto& b, } FUZZ_TEST(SymbolicPacketManagerTest, XorIsAssociative); +TEST(SymbolicPacketManagerTest, ExistsOnPacketWithSingleFieldReturnsFullSet) { + const std::string field = "a"; + EXPECT_EQ(Manager().Exists(field, Manager().Compile(MatchProto(field, 3))), + Manager().FullSet()); +} + +TEST(SymbolicPacketManagerTest, ExistOnFieldRemovesPacketFieldProperty) { + const std::string field = "a"; + constexpr int value = 3; + // p = (a=3 && b=4) || (b!=5 && c=5) + SymbolicPacket symbolic_packet = Manager().Compile( + OrProto(AndProto(MatchProto(field, value), MatchProto("b", 4)), + AndProto(NotProto(MatchProto("b", 5)), MatchProto("c", 5)))); + SymbolicPacket symbolic_packet_without_field = + Manager().Exists(field, symbolic_packet); + EXPECT_FALSE(Manager().Contains(symbolic_packet_without_field, + Packet{{field, value}})); +} + +TEST(SymbolicPacketManagerTest, ExistsOnFieldNotInPacketIsIdentity) { + // p = (a=3 && b=4) || (b!=5 && c=5) + SymbolicPacket symbolic_packet = Manager().Compile( + OrProto(AndProto(MatchProto("a", 3), MatchProto("b", 4)), + AndProto(NotProto(MatchProto("b", 5)), MatchProto("c", 5)))); + EXPECT_EQ(symbolic_packet, Manager().Exists("d", symbolic_packet)); +} + } // namespace } // namespace netkat