diff --git a/netkat/evaluator.cc b/netkat/evaluator.cc index 3a50ea7..bc94d2c 100644 --- a/netkat/evaluator.cc +++ b/netkat/evaluator.cc @@ -107,6 +107,19 @@ absl::flat_hash_set Evaluate(const PolicyProto& policy, [&](const Packet& p) { return subtrahend.contains(p); }); return result; } + case PolicyProto::kSymmetricDifferenceOp: { + absl::flat_hash_set result = + Evaluate(policy.symmetric_difference_op().left(), packet); + for (const Packet& packet : + Evaluate(policy.symmetric_difference_op().right(), packet)) { + if (result.contains(packet)) { + result.erase(packet); + } else { + result.insert(packet); + } + } + return result; + } case PolicyProto::POLICY_NOT_SET: // Unset policy is treated as Deny. return {}; diff --git a/netkat/evaluator_test.cc b/netkat/evaluator_test.cc index 5e2c332..80ae91b 100644 --- a/netkat/evaluator_test.cc +++ b/netkat/evaluator_test.cc @@ -271,6 +271,31 @@ void DifferenceRemoves(Packet packet, PolicyProto left, PolicyProto right) { } FUZZ_TEST(EvaluatePolicyProtoTest, DifferenceRemoves); +void SymmetricDifferenceIsXor(Packet packet, PolicyProto left, + PolicyProto right) { + absl::flat_hash_set expected_packets = Evaluate(left, packet); + for (const Packet& packet : Evaluate(right, packet)) { + if (expected_packets.contains(packet)) { + expected_packets.erase(packet); + } else { + expected_packets.insert(packet); + } + } + + EXPECT_THAT(Evaluate(SymmetricDifferenceProto(left, right), packet), + ContainerEq(expected_packets)); +} +FUZZ_TEST(EvaluatePolicyProtoTest, SymmetricDifferenceIsXor); + +void SymmetricDifferenceByDefinition(Packet packet, PolicyProto left, + PolicyProto right) { + EXPECT_EQ(Evaluate(SymmetricDifferenceProto(left, right), packet), + Evaluate(UnionProto(DifferenceProto(left, right), + DifferenceProto(right, left)), + packet)); +} +FUZZ_TEST(EvaluatePolicyProtoTest, SymmetricDifferenceByDefinition); + void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) { absl::flat_hash_set expected_packets = Evaluate(right, Evaluate(left, packet)); diff --git a/netkat/frontend.cc b/netkat/frontend.cc index e51ad91..f173e80 100644 --- a/netkat/frontend.cc +++ b/netkat/frontend.cc @@ -136,6 +136,16 @@ absl::Status RecursivelyCheckIsValid(const PolicyProto& policy_proto) { .SetPrepend() << "PolicyProto::DifferenceOp::right is invalid: "; return absl::OkStatus(); + case PolicyProto::kSymmetricDifferenceOp: + RETURN_IF_ERROR(RecursivelyCheckIsValid( + policy_proto.symmetric_difference_op().left())) + .SetPrepend() + << "PolicyProto::SymmetricDifferenceOp::left is invalid: "; + RETURN_IF_ERROR(RecursivelyCheckIsValid( + policy_proto.symmetric_difference_op().right())) + .SetPrepend() + << "PolicyProto::SymmetricDifferenceOp::right is invalid: "; + return absl::OkStatus(); case PolicyProto::POLICY_NOT_SET: return absl::InvalidArgumentError("Unset Policy case is invalid"); } @@ -194,6 +204,11 @@ Policy Difference(Policy left, Policy right) { DifferenceProto(std::move(left).ToProto(), std::move(right).ToProto())); } +Policy SymmetricDifference(Policy left, Policy right) { + return Policy(SymmetricDifferenceProto(std::move(left).ToProto(), + std::move(right).ToProto())); +} + Policy Policy::Accept() { return Filter(Predicate::True()); } Policy Policy::Deny() { return Filter(Predicate::False()); } diff --git a/netkat/frontend.h b/netkat/frontend.h index 238994c..3270c57 100644 --- a/netkat/frontend.h +++ b/netkat/frontend.h @@ -202,6 +202,7 @@ class Policy { friend Policy Union(std::vector); friend Policy Iterate(Policy); friend Policy Difference(Policy, Policy); + friend Policy SymmetricDifference(Policy, Policy); friend Policy Record(); // Policies that conceptually represent a program that should accept or @@ -323,6 +324,13 @@ Policy Iterate(Policy policy); // be semantically different. Policy Difference(Policy, Policy); +// Performs a symmetric difference operation on the given policies. +// +// For example, SymmetricDifference(p0, p1) we compute the symmetric set +// difference between the outputs of p0 and p1. Symmetric difference is both +// associative and commutative. +Policy SymmetricDifference(Policy, Policy); + // Records the packet into the packet history. Referred to as 'dup' in the // literature. // diff --git a/netkat/frontend_test.cc b/netkat/frontend_test.cc index dc25c10..aa33736 100644 --- a/netkat/frontend_test.cc +++ b/netkat/frontend_test.cc @@ -173,6 +173,9 @@ void ExpectFromProtoToFailWithInvalidPolicyProto(PolicyProto policy_proto) { case PolicyProto::kDifferenceOp: policy_proto.mutable_difference_op()->clear_left(); break; + case PolicyProto::kSymmetricDifferenceOp: + policy_proto.mutable_symmetric_difference_op()->clear_left(); + break; // Unset policy is invalid. case PolicyProto::POLICY_NOT_SET: break; @@ -220,6 +223,15 @@ FUZZ_TEST(FrontEndTest, DifferenceToProtoIsCorrect) .WithDomains(/*policy=*/AtomicDupFreePolicyDomain(), /*policy=*/AtomicDupFreePolicyDomain()); +void SymmetricDifferenceToProtoIsCorrect(Policy left, Policy right) { + EXPECT_THAT( + SymmetricDifference(left, right).ToProto(), + EqualsProto(SymmetricDifferenceProto(left.ToProto(), right.ToProto()))); +} +FUZZ_TEST(FrontEndTest, SymmetricDifferenceToProtoIsCorrect) + .WithDomains(/*policy=*/AtomicDupFreePolicyDomain(), + /*policy=*/AtomicDupFreePolicyDomain()); + TEST(FrontEndTest, SequenceWithNoElementsIsAccept) { EXPECT_THAT(Sequence().ToProto(), EqualsProto(AcceptProto())); } diff --git a/netkat/netkat.proto b/netkat/netkat.proto index ccc5779..51730ae 100644 --- a/netkat/netkat.proto +++ b/netkat/netkat.proto @@ -117,6 +117,7 @@ message PolicyProto { Union union_op = 5; Iterate iterate_op = 6; Difference difference_op = 7; + SymmetricDifference symmetric_difference_op = 8; } // Sets the field to the given value. @@ -155,6 +156,12 @@ message PolicyProto { PolicyProto right = 2; } + // Represents the symmetric difference of two policies, i.e. a Δ b. + message SymmetricDifference { + PolicyProto left = 1; + PolicyProto right = 2; + } + // Records the packet, at the given point, into the history. Referred to as // "dup" in the literature. message Record {} diff --git a/netkat/netkat_proto_constructors.cc b/netkat/netkat_proto_constructors.cc index 33d9664..d44bf84 100644 --- a/netkat/netkat_proto_constructors.cc +++ b/netkat/netkat_proto_constructors.cc @@ -120,6 +120,13 @@ PolicyProto DifferenceProto(PolicyProto left, PolicyProto right) { return policy; } +PolicyProto SymmetricDifferenceProto(PolicyProto left, PolicyProto right) { + PolicyProto policy; + *policy.mutable_symmetric_difference_op()->mutable_left() = std::move(left); + *policy.mutable_symmetric_difference_op()->mutable_right() = std::move(right); + return policy; +} + // -- Derived Policy constructors ---------------------------------------------- PolicyProto DenyProto() { return FilterProto(FalseProto()); } @@ -176,6 +183,11 @@ std::string AsShorthandString(PolicyProto policy) { return absl::StrFormat("(%s - %s)", AsShorthandString(policy.difference_op().left()), AsShorthandString(policy.difference_op().right())); + case PolicyProto::kSymmetricDifferenceOp: + return absl::StrFormat( + "(%s (+) %s)", + AsShorthandString(policy.symmetric_difference_op().left()), + AsShorthandString(policy.symmetric_difference_op().right())); case PolicyProto::POLICY_NOT_SET: return "deny"; } diff --git a/netkat/netkat_proto_constructors.h b/netkat/netkat_proto_constructors.h index c3c188b..b08f89c 100644 --- a/netkat/netkat_proto_constructors.h +++ b/netkat/netkat_proto_constructors.h @@ -48,6 +48,7 @@ PolicyProto SequenceProto(PolicyProto left, PolicyProto right); PolicyProto UnionProto(PolicyProto left, PolicyProto right); PolicyProto IterateProto(PolicyProto iterable); PolicyProto DifferenceProto(PolicyProto left, PolicyProto right); +PolicyProto SymmetricDifferenceProto(PolicyProto left, PolicyProto right); // -- Derived Policy constructors ---------------------------------------------- @@ -65,6 +66,7 @@ PolicyProto AcceptProto(); // Policy Or -> '+' // Iterate -> '*' // Difference -> '-' +// SymmetricDifference -> '(+)' // Record -> 'record' // Match -> '@field==value' // Modify -> '@field:=value' diff --git a/netkat/netkat_proto_constructors_test.cc b/netkat/netkat_proto_constructors_test.cc index 34609bf..3cc6330 100644 --- a/netkat/netkat_proto_constructors_test.cc +++ b/netkat/netkat_proto_constructors_test.cc @@ -136,6 +136,17 @@ void DifferenceProtoReturnsDifference(PolicyProto left, PolicyProto right) { } FUZZ_TEST(PolicyProtoTest, DifferenceProtoReturnsDifference); +void SymmetricDifferenceProtoReturnsSymmetricDifference(PolicyProto left, + PolicyProto right) { + PolicyProto expected_policy; + *expected_policy.mutable_symmetric_difference_op()->mutable_left() = left; + *expected_policy.mutable_symmetric_difference_op()->mutable_right() = right; + + EXPECT_THAT(SymmetricDifferenceProto(left, right), + EqualsProto(expected_policy)); +} +FUZZ_TEST(PolicyProtoTest, SymmetricDifferenceProtoReturnsSymmetricDifference); + // -- Derived Policy tests ----------------------------------------------------- TEST(PolicyProtoTest, DenyProtoFiltersOnFalse) { diff --git a/netkat/netkat_test.cc b/netkat/netkat_test.cc index 8477b69..2a54ab8 100644 --- a/netkat/netkat_test.cc +++ b/netkat/netkat_test.cc @@ -109,6 +109,12 @@ TEST(NetkatProtoTest, PolicyOneOfFieldNamesDontRequireUnderscores) { LOG(INFO) << "difference: " << difference_op; break; } + case PolicyProto::kSymmetricDifferenceOp: { + const PolicyProto::SymmetricDifference& symmetric_difference_op = + policy.symmetric_difference_op(); + LOG(INFO) << "symmetric_difference: " << symmetric_difference_op; + break; + } case PolicyProto::POLICY_NOT_SET: break; } diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index 070d012..a6291c6 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -311,6 +311,14 @@ PacketTransformerHandle PacketTransformerManager::Compile( return transformer_by_hash_[key] = Difference(key.lhs_child, key.rhs_child); } + case PolicyProto::kSymmetricDifferenceOp: { + key.lhs_child = Compile(policy.symmetric_difference_op().left()); + key.rhs_child = Compile(policy.symmetric_difference_op().right()); + auto it = transformer_by_hash_.find(key); + if (it != transformer_by_hash_.end()) return it->second; + return transformer_by_hash_[key] = + SymmetricDifference(key.lhs_child, key.rhs_child); + } // By convention, uninitialized policies must be treated like the Deny // policy. case PolicyProto::POLICY_NOT_SET: { @@ -756,6 +764,11 @@ PacketTransformerHandle PacketTransformerManager::Difference( return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); } +PacketTransformerHandle PacketTransformerManager::SymmetricDifference( + PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(Difference(left, right), Difference(right, left)); +} + PacketTransformerHandle PacketTransformerManager::Iterate( PacketTransformerHandle iterable) { PacketTransformerHandle previous_approximation = Accept(); diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index eb2e8a1..b20accc 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -257,8 +257,8 @@ class PacketTransformerManager { // Returns the transformer that describes the packets produced by the `left` // transformer or the `right` transformer, but not both. - PacketTransformerHandle SymmetricDifference( - PacketTransformerHandle left, PacketTransformerHandle right) = delete; + PacketTransformerHandle SymmetricDifference(PacketTransformerHandle left, + PacketTransformerHandle right); // Dynamically checks all class invariants. Exposed for testing only. absl::Status CheckInternalInvariants() const; diff --git a/netkat/packet_transformer_test.cc b/netkat/packet_transformer_test.cc index 73d2ab9..55a3ab5 100644 --- a/netkat/packet_transformer_test.cc +++ b/netkat/packet_transformer_test.cc @@ -186,6 +186,15 @@ void DifferenceCompilesToDifference(PolicyProto left, PolicyProto right) { } FUZZ_TEST(PacketTransformerManagerTest, DifferenceCompilesToDifference); +void SymmetricDifferenceCompilesToSymmetricDifference(PolicyProto left, + PolicyProto right) { + EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(left, right)), + Manager().SymmetricDifference(Manager().Compile(left), + Manager().Compile(right))); +} +FUZZ_TEST(PacketTransformerManagerTest, + SymmetricDifferenceCompilesToSymmetricDifference); + /*--- Kleene algebra axioms and equivalences ---------------------------------*/ void UnionIsAssociative(PolicyProto a, PolicyProto b, PolicyProto c) { @@ -313,6 +322,63 @@ void DifferenceOfPolicyIsSubsetOfSelf(PolicyProto a, PolicyProto b) { } FUZZ_TEST(PacketTransformerManagerTest, DifferenceOfPolicyIsSubsetOfSelf); +void SymmetricDifferenceIsAssociative(PolicyProto a, PolicyProto b, + PolicyProto c) { + EXPECT_EQ(Manager().Compile( + SymmetricDifferenceProto(a, SymmetricDifferenceProto(b, c))), + Manager().Compile( + SymmetricDifferenceProto(SymmetricDifferenceProto(a, b), c))); +} +FUZZ_TEST(PacketTransformerManagerTest, SymmetricDifferenceIsAssociative); + +void SymmetricDifferenceIsCommutative(PolicyProto a, PolicyProto b) { + EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(a, b)), + Manager().Compile(SymmetricDifferenceProto(b, a))); +} +FUZZ_TEST(PacketTransformerManagerTest, SymmetricDifferenceIsCommutative); + +void SymmetricDifferenceOfPolicyAndDenyIsIdentity(PolicyProto policy) { + EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(policy, DenyProto())), + Manager().Compile(policy)); +} +FUZZ_TEST(PacketTransformerManagerTest, + SymmetricDifferenceOfPolicyAndDenyIsIdentity); + +void SymmetricDifferenceOfPolicyAndSelfIsDeny(PolicyProto policy) { + EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(policy, policy)), + Manager().Compile(DenyProto())); +} +FUZZ_TEST(PacketTransformerManagerTest, + SymmetricDifferenceOfPolicyAndSelfIsDeny); + +// TODO(matthewtlam): Uncomment this once Intersection is implemented. +// void IntersectionDistributesOverSymmetricDifference(PolicyProto a, +// PolicyProto b, +// PolicyProto c) { +// EXPECT_EQ( +// Manager().Compile(IntersectionProto(a, SymmetricDifferenceProto(b, +// c))), Manager().Compile(SymmetricDifferenceProto(IntersectionProto(a, +// b), +// IntersectionProto(a, c)))); +// } +// FUZZ_TEST(PacketTransformerManagerTest, +// IntersectionDistributesOverSymmetricDifference); + +void SymmetricDifferenceDefinition(PolicyProto a, PolicyProto b) { + EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(a, b)), + Manager().Compile( + UnionProto(DifferenceProto(a, b), DifferenceProto(b, a)))); +} +FUZZ_TEST(PacketTransformerManagerTest, SymmetricDifferenceDefinition); + +// TODO(matthewtlam): Uncomment this once Intersection is implemented. +// void SymmetricDifference2(PolicyProto a, PolicyProto b) { +// EXPECT_EQ(Manager().Compile(SymmetricDifferenceProto(a, b)), +// Manager().Compile( +// DifferenceProto(UnionProto(a, b), IntersectionProto(a, b)))); +// } +// FUZZ_TEST(PacketTransformerManagerTest, SymmetricDifference2); + /*--- Tests with concrete protos ---------------------------------------------*/ TEST(PacketTransformerManagerTest, KatchPaperFig5) { diff --git a/netkat/table.cc b/netkat/table.cc index 4b19b4b..4666b3e 100644 --- a/netkat/table.cc +++ b/netkat/table.cc @@ -56,6 +56,10 @@ absl::Status VerifyActionHasNoPredicate(const Policy& action) { stack.push_back(&policy->difference_op().left()); stack.push_back(&policy->difference_op().right()); break; + case PolicyProto::PolicyCase::kSymmetricDifferenceOp: + stack.push_back(&policy->symmetric_difference_op().left()); + stack.push_back(&policy->symmetric_difference_op().right()); + break; case PolicyProto::PolicyCase::kSequenceOp: stack.push_back(&policy->sequence_op().left()); stack.push_back(&policy->sequence_op().right());