From b748dc6ea17ddea2b6ab13f8f27b021d7d7ca001 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Tue, 9 Aug 2022 15:14:28 +0000 Subject: [PATCH 01/12] Visitor for authorization logic AST --- src/ir/auth_logic/BUILD | 27 ++ src/ir/auth_logic/ast.h | 101 ++++++ .../auth_logic_ast_traversing_visitor.h | 322 ++++++++++++++++++ .../auth_logic_ast_traversing_visitor_test.cc | 79 +++++ src/ir/auth_logic/auth_logic_ast_visitor.h | 53 +++ src/ir/ir_visitor.h | 2 +- 6 files changed, 583 insertions(+), 1 deletion(-) create mode 100644 src/ir/auth_logic/auth_logic_ast_traversing_visitor.h create mode 100644 src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc create mode 100644 src/ir/auth_logic/auth_logic_ast_visitor.h diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index cb0117dd1..403ea1ab1 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -36,6 +36,21 @@ cc_library( ], ) +cc_library( + name = "auth_logic_ast_traversing_visitor", + hdrs = [ + "auth_logic_ast_traversing_visitor.h", + "auth_logic_ast_visitor.h", + ], + deps = [ + ":ast", + "//src/common/logging", + "//src/common/utils:fold", + "//src/common/utils:overloaded", + "//src/common/utils:types", + ], +) + cc_library( name = "lowering_ast_datalog", srcs = ["lowering_ast_datalog.cc"], @@ -94,6 +109,18 @@ cc_test( ], ) +cc_test( + name = "ast_visitor_test", + srcs = ["auth_logic_ast_traversing_visitor_test.cc"], + deps = [ + ":ast", + ":auth_logic_ast_traversing_visitor", + "//src/common/testing:gtest", + "//src/ir/datalog:program", + "@absl//absl/container:flat_hash_set", + ], +) + cc_library( name = "ast_construction", srcs = ["ast_construction.cc"], diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index 78163f179..bcd798c38 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -25,6 +25,7 @@ #include #include "absl/hash/hash.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" #include "src/ir/datalog/program.h" namespace raksha::ir::auth_logic { @@ -34,6 +35,16 @@ class Principal { explicit Principal(std::string name) : name_(std::move(name)) {} const std::string& name() const { return name_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::string name_; }; @@ -47,6 +58,16 @@ class Attribute { const Principal& principal() const { return principal_; } const datalog::Predicate& predicate() const { return predicate_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal principal_; datalog::Predicate predicate_; @@ -62,6 +83,16 @@ class CanActAs { const Principal& left_principal() const { return left_principal_; } const Principal& right_principal() const { return right_principal_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal left_principal_; Principal right_principal_; @@ -85,6 +116,16 @@ class BaseFact { explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){}; const BaseFactVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: BaseFactVariantType value_; }; @@ -103,6 +144,16 @@ class Fact { const BaseFact& base_fact() const { return base_fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::forward_list delegation_chain_; BaseFact base_fact_; @@ -118,6 +169,16 @@ class ConditionalAssertion { const Fact& lhs() const { return lhs_; } const std::vector& rhs() const { return rhs_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Fact lhs_; std::vector rhs_; @@ -135,6 +196,16 @@ class Assertion { explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {} const AssertionVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: AssertionVariantType value_; }; @@ -147,6 +218,16 @@ class SaysAssertion { const Principal& principal() const { return principal_; } const std::vector& assertions() const { return assertions_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal principal_; std::vector assertions_; @@ -164,6 +245,16 @@ class Query { const Principal& principal() const { return principal_; } const Fact& fact() const { return fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::string name_; Principal principal_; @@ -191,6 +282,16 @@ class Program { const std::vector& queries() const { return queries_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::vector relation_declarations_; std::vector says_assertions_; diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h new file mode 100644 index 000000000..6077e3af8 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -0,0 +1,322 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ + +#include + +#include "src/common/utils/fold.h" +#include "src/common/utils/overloaded.h" +#include "src/common/utils/types.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" +#include "src/ir/datalog/program.h" + +// The implementation of this visitor over the AST nodes of authorizaiton logic +// directly follows the one for the IR in /src/ir/ir_traversing_visitor.h + +namespace raksha::ir::auth_logic { + +// A visitor that also traverses the children of a node and allows performing +// different actions before (PreVisit) and after (PostVisit) the children are +// visited. Override any of the `PreVisit` and `PostVisit` methods as needed. +template +class AuthLogicAstTraversingVisitor + : public AuthLogicAstVisitor { + private: + template + struct DefaultValueGetter { + static ValueType Get() { + LOG(FATAL) << "Override required for non-default-constructible type."; + } + }; + + template + struct DefaultValueGetter< + ValueType, std::enable_if_t>> { + static ValueType Get() { return ValueType(); } + }; + + public: + virtual ~AuthLogicAstTraversingVisitor() {} + + // Gives a default value for all 'PreVisit's to start with. + // Should be over-ridden if the Result is not default constructable. + virtual Result GetDefaultValue() { return DefaultValueGetter::Get(); } + + // Used to accumulate child results from the node's children. + // Should discard or merge `child_result` into the `accumulator`. + virtual Result FoldResult(Result accumulator, Result child_result) { + return accumulator; + } + // Invoked before all the children of `principal` are visited. + virtual Result PreVisit(CopyConst& principal) { + return GetDefaultValue(); + } + // Invoked after all the children of `principal` are visited. + virtual Result PostVisit(CopyConst& principal, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `attribute` are visited. + virtual Result PreVisit(CopyConst& attribute) { + return GetDefaultValue(); + } + // Invoked after all the children of `attribute` are visited. + virtual Result PostVisit(CopyConst& attribute, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `canActAs` are visited. + virtual Result PreVisit(CopyConst& canActAs) { + return GetDefaultValue(); + } + // Invoked after all the children of `canActAs` are visited. + virtual Result PostVisit(CopyConst& canActAs, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `baseFact` are visited. + virtual Result PreVisit(CopyConst& baseFact) { + return GetDefaultValue(); + } + // Invoked after all the children of `baseFact` are visited. + virtual Result PostVisit(CopyConst& baseFact, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `fact` are visited. + virtual Result PreVisit(CopyConst& fact) { + return GetDefaultValue(); + } + // Invoked after all the children of `fact` are visited. + virtual Result PostVisit(CopyConst& fact, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `conditionalAssertion` are visited. + virtual Result PreVisit( + CopyConst& conditionalAssertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `conditionalAssertion` are visited. + virtual Result PostVisit( + CopyConst& conditionalAssertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `assertion` are visited. + virtual Result PreVisit(CopyConst& assertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `assertion` are visited. + virtual Result PostVisit(CopyConst& assertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `saysAssertion` are visited. + virtual Result PreVisit(CopyConst& saysAssertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `saysAssertion` are visited. + virtual Result PostVisit(CopyConst& saysAssertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `query` are visited. + virtual Result PreVisit(CopyConst& query) { + return GetDefaultValue(); + } + // Invoked after all the children of `query` are visited. + virtual Result PostVisit(CopyConst& query, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `program` are visited. + virtual Result PreVisit(CopyConst& program) { + return GetDefaultValue(); + } + // Invoked after all the children of `program` are visited. + virtual Result PostVisit(CopyConst& program, + Result in_order_result) { + return in_order_result; + } + + // TODO (#644) aferr + // The Visits for the Datalog IR classes (RelationDeclaration, Predciate) + // are here temporarily until these AST classes are refactored out + // of the Datalog IR. + + virtual Result Visit( + CopyConst& relationDeclaration) { + return GetDefaultValue(); + } + + virtual Result Visit(CopyConst& predicate) { + return GetDefaultValue(); + } + + // The remaining Visits are meant to follow the convention + + Result Visit(CopyConst& principal) final override { + Result pre_visit_result = PreVisit(principal); + return PostVisit(principal, std::move(pre_visit_result)); + } + + Result Visit(CopyConst& attribute) final override { + Result pre_visit_result = PreVisit(attribute); + Result fold_result = + FoldResult(FoldResult(std::move(pre_visit_result), + attribute.principal().Accept(*this)), + // TODO(#644 aferr): fix this to use predicate().Accept once + // predicate has been refactored into ast.h + Visit(attribute.predicate())); + return PostVisit(attribute, std::move(fold_result)); + } + + Result Visit(CopyConst& canActAs) final override { + Result pre_visit_result = PreVisit(canActAs); + Result fold_result = + FoldResult(FoldResult(std::move(pre_visit_result), + canActAs.left_principal().Accept(*this)), + canActAs.right_principal().Accept(*this)); + return PostVisit(canActAs, std::move(fold_result)); + } + + Result Visit(CopyConst& baseFact) final override { + Result pre_visit_result = PreVisit(baseFact); + Result variant_visit_result = std::visit( + raksha::utils::overloaded{ + [this](const datalog::Predicate& pred) { + return VariantVisit(pred); + }, + [this](const Attribute& attrib) { return VariantVisit(attrib); }, + [this](const CanActAs& canActAs) { + return VariantVisit(canActAs); + }}, + baseFact.GetValue()); + Result fold_result = FoldResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(baseFact, std::move(fold_result)); + } + + Result Visit(CopyConst& fact) final override { + Result pre_visit_result = PreVisit(fact); + Result base_fact_result = + FoldResult(std::move(pre_visit_result), fact.base_fact().Accept(*this)); + Result fold_result = common::utils::fold( + fact.delegation_chain(), std::move(base_fact_result), + [this](Result acc, CopyConst principal) { + return FoldResult(std::move(acc), principal.Accept(*this)); + }); + return PostVisit(fact, std::move(fold_result)); + } + + Result Visit(CopyConst& conditionalAssertion) + final override { + Result pre_visit_result = PreVisit(conditionalAssertion); + Result lhs_result = FoldResult(std::move(pre_visit_result), + conditionalAssertion.lhs().Accept(*this)); + Result fold_result = common::utils::fold( + conditionalAssertion.rhs(), std::move(lhs_result), + [this](Result acc, CopyConst baseFact) { + return FoldResult(std::move(acc), baseFact.Accept(*this)); + }); + return PostVisit(conditionalAssertion, std::move(fold_result)); + } + + Result Visit(CopyConst& assertion) final override { + Result pre_visit_result = PreVisit(assertion); + Result variant_visit_result = + std::visit(raksha::utils::overloaded{ + [this](const Fact& fact) { return VariantVisit(fact); }, + [this](const ConditionalAssertion& condAssertion) { + return VariantVisit(condAssertion); + }}, + assertion.GetValue()); + Result fold_result = FoldResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(assertion, std::move(fold_result)); + } + + Result Visit( + CopyConst& saysAssertion) final override { + Result pre_visit_result = PreVisit(saysAssertion); + Result principal_result = FoldResult( + std::move(pre_visit_result), saysAssertion.principal().Accept(*this)); + Result fold_result = common::utils::fold( + saysAssertion.assertions(), std::move(principal_result), + [this](Result acc, CopyConst assertion) { + return FoldResult(std::move(acc), assertion.Accept(*this)); + }); + return PostVisit(saysAssertion, fold_result); + } + + Result Visit(CopyConst& query) final override { + Result pre_visit_result = PreVisit(query); + Result fold_result = FoldResult(std::move(pre_visit_result), + FoldResult(query.principal().Accept(*this), + query.fact().Accept(*this))); + return PostVisit(query, fold_result); + } + + Result Visit(CopyConst& program) final override { + Result pre_visit_result = PreVisit(program); + Result declarations_result = common::utils::fold( + program.relation_declarations(), std::move(pre_visit_result), + [this](Result acc, CopyConst + relationDeclaration) { + // TODO(#644 aferr) Fix this to accept once once relationDeclaration + // has been refactored into ast.h + return FoldResult(std::move(acc), Visit(relationDeclaration)); + }); + Result says_assertions_result = common::utils::fold( + program.says_assertions(), std::move(declarations_result), + [this](Result acc, CopyConst saysAssertion) { + return FoldResult(std::move(acc), saysAssertion.Accept(*this)); + }); + Result queries_result = common::utils::fold( + program.queries(), std::move(says_assertions_result), + [this](Result acc, CopyConst query) { + return FoldResult(std::move(acc), query.Accept(*this)); + }); + return PostVisit(program, queries_result); + } + + // The VariantVisit methods use overloading to help visit + // the alternatives for the underlying std::variants in the AST + + // For BaseFactVariantType + Result VariantVisit(datalog::Predicate predicate) { + // TODO(#644 aferr) once a separate predicate has been added to ast.h + // this should use predicate.Accept(*this); + return Visit(predicate); + } + Result VariantVisit(Attribute attribute) { return attribute.Accept(*this); } + Result VariantVisit(CanActAs canActAs) { return canActAs.Accept(*this); } + + // For AssertionVariantType + Result VariantVisit(Fact fact) { return fact.Accept(*this); } + Result VariantVisit(ConditionalAssertion conditionalAssertion) { + return conditionalAssertion.Accept(*this); + } + + private: +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc new file mode 100644 index 000000000..9b8650853 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -0,0 +1,79 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- + +#include "src/ir/auth_logic/auth_logic_ast_traversing_visitor.h" + +#include "absl/container/flat_hash_set.h" +#include "src/common/testing/gtest.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/datalog/program.h" + +namespace raksha::ir::auth_logic { +namespace { + +// A visitor that makes a set of all the names of principals in the program +class PrincipalNameCollectorVisitor + : public AuthLogicAstTraversingVisitor> { + public: + PrincipalNameCollectorVisitor() {} + + absl::flat_hash_set GetDefaultValue() override { return {}; } + + absl::flat_hash_set FoldResult( + absl::flat_hash_set acc, + absl::flat_hash_set child_result) { + acc.merge(std::move(child_result)); + return std::move(acc); + } + + absl::flat_hash_set PreVisit( + const Principal& principal) override { + return {principal.name()}; + } +}; + +Program BuildTestProgram1() { + SaysAssertion assertion1 = SaysAssertion( + Principal("PrincipalA"), + {Assertion(Fact({}, BaseFact(datalog::Predicate("foo", {"bar", "baz"}, + datalog::kPositive))))}); + SaysAssertion assertion2 = SaysAssertion( + Principal("PrincipalA"), + {Assertion( + Fact({}, BaseFact(datalog::Predicate("foo", {"barbar", "bazbaz"}, + datalog::kPositive))))}); + SaysAssertion assertion3 = SaysAssertion( + Principal("PrincipalB"), + {Assertion(Fact({}, BaseFact(CanActAs(Principal("PrincipalA"), + Principal("PrincipalC")))))}); + std::vector assertion_list = { + std::move(assertion1), std::move(assertion2), std::move(assertion3)}; + return Program({}, std::move(assertion_list), {}); +} + +TEST(AuthLogicAstTraversingVisitorTest, PrincipalNameCollectorTest) { + Program test_prog = BuildTestProgram1(); + PrincipalNameCollectorVisitor collector_visitor; + const absl::flat_hash_set result = + test_prog.Accept(collector_visitor); + const absl::flat_hash_set expected = {"PrincipalA", "PrincipalB", + "PrincipalC"}; + EXPECT_EQ(result, expected); +} + +} // namespace +} // namespace raksha::ir::auth_logic diff --git a/src/ir/auth_logic/auth_logic_ast_visitor.h b/src/ir/auth_logic/auth_logic_ast_visitor.h new file mode 100644 index 000000000..b15eebbc8 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_visitor.h @@ -0,0 +1,53 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ + +#include "src/common/utils/types.h" + +namespace raksha::ir::auth_logic { + +class Principal; +class Attribute; +class CanActAs; +class BaseFact; +class Fact; +class ConditionalAssertion; +class Assertion; +class SaysAssertion; +class Query; +class Program; + +template +class AuthLogicAstVisitor { + public: + virtual ~AuthLogicAstVisitor() {} + virtual Result Visit(CopyConst& principal) = 0; + virtual Result Visit(CopyConst& attribute) = 0; + virtual Result Visit(CopyConst& canActAs) = 0; + virtual Result Visit(CopyConst& baseFact) = 0; + virtual Result Visit(CopyConst& fact) = 0; + virtual Result Visit( + CopyConst& conditionalAssertion) = 0; + virtual Result Visit(CopyConst& assertion) = 0; + virtual Result Visit(CopyConst& saysAssertion) = 0; + virtual Result Visit(CopyConst& query) = 0; + virtual Result Visit(CopyConst& program) = 0; +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ \ No newline at end of file diff --git a/src/ir/ir_visitor.h b/src/ir/ir_visitor.h index a412a6743..4fdd88260 100644 --- a/src/ir/ir_visitor.h +++ b/src/ir/ir_visitor.h @@ -31,7 +31,7 @@ class IRVisitor { public: virtual ~IRVisitor() {} virtual Result Visit(CopyConst& module) = 0; - virtual Result Visit(CopyConst& operation) = 0; + virtual Result Visit(CopyConst& block) = 0; virtual Result Visit(CopyConst& operation) = 0; }; From c2a3746ee7a3b578bfe940bb31d31db89f6f66bb Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Tue, 9 Aug 2022 16:33:51 +0000 Subject: [PATCH 02/12] Fix BUILD --- src/ir/auth_logic/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index 403ea1ab1..efa159e5d 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -24,12 +24,14 @@ package( cc_library( name = "ast", hdrs = [ + "auth_logic_ast_visitor.h", "ast.h", ], visibility = ["//visibility:private"], deps = [ "//src/common/logging", "//src/common/utils:overloaded", + "//src/common/utils:types", "//src/ir/datalog:program", "@absl//absl/hash", "@absl//absl/strings:str_format", From 122d48ee3db0a5b20b7da72bf3b86b912a582a15 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Tue, 9 Aug 2022 16:44:32 +0000 Subject: [PATCH 03/12] Fix build --- src/ir/auth_logic/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index efa159e5d..39296a8a7 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -24,8 +24,8 @@ package( cc_library( name = "ast", hdrs = [ - "auth_logic_ast_visitor.h", "ast.h", + "auth_logic_ast_visitor.h", ], visibility = ["//visibility:private"], deps = [ From c8ec5462cfb09fb54bfbfcb9938647c2f07494b6 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Tue, 9 Aug 2022 16:51:27 +0000 Subject: [PATCH 04/12] Make clang happy again --- src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc index 9b8650853..ee17e7b70 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -35,7 +35,7 @@ class PrincipalNameCollectorVisitor absl::flat_hash_set FoldResult( absl::flat_hash_set acc, - absl::flat_hash_set child_result) { + absl::flat_hash_set child_result) override { acc.merge(std::move(child_result)); return std::move(acc); } From c6f9d2d5dcd4534bc2909e33a5e80f0f35b54ef8 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Tue, 9 Aug 2022 16:56:51 +0000 Subject: [PATCH 05/12] Remove extra move --- src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc index ee17e7b70..e7b5114fa 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -37,7 +37,7 @@ class PrincipalNameCollectorVisitor absl::flat_hash_set acc, absl::flat_hash_set child_result) override { acc.merge(std::move(child_result)); - return std::move(acc); + return acc; } absl::flat_hash_set PreVisit( From a07e77e12a5061c43c0fd5af78a08f4ca800ac44 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Wed, 10 Aug 2022 11:51:02 +0100 Subject: [PATCH 06/12] Some PR fixups --- src/ir/auth_logic/ast.h | 40 ++++----- .../auth_logic_ast_traversing_visitor.h | 87 +++++++++---------- src/ir/auth_logic/auth_logic_ast_visitor.h | 10 ++- 3 files changed, 71 insertions(+), 66 deletions(-) diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index bcd798c38..bbbfa1368 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -36,12 +36,12 @@ class Principal { const std::string& name() const { return name_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -59,12 +59,12 @@ class Attribute { const datalog::Predicate& predicate() const { return predicate_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -84,12 +84,12 @@ class CanActAs { const Principal& right_principal() const { return right_principal_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -117,12 +117,12 @@ class BaseFact { const BaseFactVariantType& GetValue() const { return value_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -145,12 +145,12 @@ class Fact { const BaseFact& base_fact() const { return base_fact_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -170,12 +170,12 @@ class ConditionalAssertion { const std::vector& rhs() const { return rhs_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -197,12 +197,12 @@ class Assertion { const AssertionVariantType& GetValue() const { return value_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -219,12 +219,12 @@ class SaysAssertion { const std::vector& assertions() const { return assertions_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -246,12 +246,12 @@ class Query { const Fact& fact() const { return fact_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } @@ -283,12 +283,12 @@ class Program { const std::vector& queries() const { return queries_; } template - Result Accept(AuthLogicAstVisitor& visitor) { + Result Accept(AuthLogicAstVisitor& visitor) { return visitor.Visit(*this); } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept(AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index 6077e3af8..47e399b82 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -33,7 +33,7 @@ namespace raksha::ir::auth_logic { // A visitor that also traverses the children of a node and allows performing // different actions before (PreVisit) and after (PostVisit) the children are // visited. Override any of the `PreVisit` and `PostVisit` methods as needed. -template +template class AuthLogicAstTraversingVisitor : public AuthLogicAstVisitor { private: @@ -80,21 +80,21 @@ class AuthLogicAstTraversingVisitor Result in_order_result) { return in_order_result; } - // Invoked before all the children of `canActAs` are visited. - virtual Result PreVisit(CopyConst& canActAs) { + // Invoked before all the children of `can_act_as` are visited. + virtual Result PreVisit(CopyConst& can_act_as) { return GetDefaultValue(); } // Invoked after all the children of `canActAs` are visited. - virtual Result PostVisit(CopyConst& canActAs, + virtual Result PostVisit(CopyConst& can_act_as, Result in_order_result) { return in_order_result; } // Invoked before all the children of `baseFact` are visited. - virtual Result PreVisit(CopyConst& baseFact) { + virtual Result PreVisit(CopyConst& base_fact) { return GetDefaultValue(); } // Invoked after all the children of `baseFact` are visited. - virtual Result PostVisit(CopyConst& baseFact, + virtual Result PostVisit(CopyConst& base_fact, Result in_order_result) { return in_order_result; } @@ -109,12 +109,12 @@ class AuthLogicAstTraversingVisitor } // Invoked before all the children of `conditionalAssertion` are visited. virtual Result PreVisit( - CopyConst& conditionalAssertion) { + CopyConst& conditional_assertion) { return GetDefaultValue(); } // Invoked after all the children of `conditionalAssertion` are visited. virtual Result PostVisit( - CopyConst& conditionalAssertion, + CopyConst& conditional_assertion, Result in_order_result) { return in_order_result; } @@ -128,11 +128,11 @@ class AuthLogicAstTraversingVisitor return in_order_result; } // Invoked before all the children of `saysAssertion` are visited. - virtual Result PreVisit(CopyConst& saysAssertion) { + virtual Result PreVisit(CopyConst& says_assertion) { return GetDefaultValue(); } // Invoked after all the children of `saysAssertion` are visited. - virtual Result PostVisit(CopyConst& saysAssertion, + virtual Result PostVisit(CopyConst& says_assertion, Result in_order_result) { return in_order_result; } @@ -161,7 +161,7 @@ class AuthLogicAstTraversingVisitor // of the Datalog IR. virtual Result Visit( - CopyConst& relationDeclaration) { + CopyConst& relation_declaration) { return GetDefaultValue(); } @@ -187,30 +187,30 @@ class AuthLogicAstTraversingVisitor return PostVisit(attribute, std::move(fold_result)); } - Result Visit(CopyConst& canActAs) final override { - Result pre_visit_result = PreVisit(canActAs); + Result Visit(CopyConst& can_act_as) final override { + Result pre_visit_result = PreVisit(can_act_as); Result fold_result = FoldResult(FoldResult(std::move(pre_visit_result), - canActAs.left_principal().Accept(*this)), - canActAs.right_principal().Accept(*this)); - return PostVisit(canActAs, std::move(fold_result)); + can_act_as.left_principal().Accept(*this)), + can_act_as.right_principal().Accept(*this)); + return PostVisit(can_act_as, std::move(fold_result)); } - Result Visit(CopyConst& baseFact) final override { - Result pre_visit_result = PreVisit(baseFact); + Result Visit(CopyConst& base_fact) final override { + Result pre_visit_result = PreVisit(base_fact); Result variant_visit_result = std::visit( raksha::utils::overloaded{ [this](const datalog::Predicate& pred) { return VariantVisit(pred); }, [this](const Attribute& attrib) { return VariantVisit(attrib); }, - [this](const CanActAs& canActAs) { - return VariantVisit(canActAs); + [this](const CanActAs& can_act_as) { + return VariantVisit(can_act_as); }}, - baseFact.GetValue()); + base_fact.GetValue()); Result fold_result = FoldResult(std::move(pre_visit_result), std::move(variant_visit_result)); - return PostVisit(baseFact, std::move(fold_result)); + return PostVisit(base_fact, std::move(fold_result)); } Result Visit(CopyConst& fact) final override { @@ -225,17 +225,17 @@ class AuthLogicAstTraversingVisitor return PostVisit(fact, std::move(fold_result)); } - Result Visit(CopyConst& conditionalAssertion) + Result Visit(CopyConst& conditional_assertion) final override { - Result pre_visit_result = PreVisit(conditionalAssertion); + Result pre_visit_result = PreVisit(conditional_assertion); Result lhs_result = FoldResult(std::move(pre_visit_result), - conditionalAssertion.lhs().Accept(*this)); + conditional_assertion.lhs().Accept(*this)); Result fold_result = common::utils::fold( - conditionalAssertion.rhs(), std::move(lhs_result), - [this](Result acc, CopyConst baseFact) { - return FoldResult(std::move(acc), baseFact.Accept(*this)); + conditional_assertion.rhs(), std::move(lhs_result), + [this](Result acc, CopyConst base_fact) { + return FoldResult(std::move(acc), base_fact.Accept(*this)); }); - return PostVisit(conditionalAssertion, std::move(fold_result)); + return PostVisit(conditional_assertion, std::move(fold_result)); } Result Visit(CopyConst& assertion) final override { @@ -243,8 +243,8 @@ class AuthLogicAstTraversingVisitor Result variant_visit_result = std::visit(raksha::utils::overloaded{ [this](const Fact& fact) { return VariantVisit(fact); }, - [this](const ConditionalAssertion& condAssertion) { - return VariantVisit(condAssertion); + [this](const ConditionalAssertion& cond_assertion) { + return VariantVisit(cond_assertion); }}, assertion.GetValue()); Result fold_result = FoldResult(std::move(pre_visit_result), @@ -253,16 +253,16 @@ class AuthLogicAstTraversingVisitor } Result Visit( - CopyConst& saysAssertion) final override { - Result pre_visit_result = PreVisit(saysAssertion); + CopyConst& says_assertion) final override { + Result pre_visit_result = PreVisit(says_assertion); Result principal_result = FoldResult( - std::move(pre_visit_result), saysAssertion.principal().Accept(*this)); + std::move(pre_visit_result), says_assertion.principal().Accept(*this)); Result fold_result = common::utils::fold( - saysAssertion.assertions(), std::move(principal_result), + says_assertion.assertions(), std::move(principal_result), [this](Result acc, CopyConst assertion) { return FoldResult(std::move(acc), assertion.Accept(*this)); }); - return PostVisit(saysAssertion, fold_result); + return PostVisit(says_assertion, fold_result); } Result Visit(CopyConst& query) final override { @@ -278,15 +278,15 @@ class AuthLogicAstTraversingVisitor Result declarations_result = common::utils::fold( program.relation_declarations(), std::move(pre_visit_result), [this](Result acc, CopyConst - relationDeclaration) { + relation_declaration) { // TODO(#644 aferr) Fix this to accept once once relationDeclaration // has been refactored into ast.h - return FoldResult(std::move(acc), Visit(relationDeclaration)); + return FoldResult(std::move(acc), Visit(relation_declaration)); }); Result says_assertions_result = common::utils::fold( program.says_assertions(), std::move(declarations_result), - [this](Result acc, CopyConst saysAssertion) { - return FoldResult(std::move(acc), saysAssertion.Accept(*this)); + [this](Result acc, CopyConst says_assertion) { + return FoldResult(std::move(acc), says_assertion.Accept(*this)); }); Result queries_result = common::utils::fold( program.queries(), std::move(says_assertions_result), @@ -306,15 +306,14 @@ class AuthLogicAstTraversingVisitor return Visit(predicate); } Result VariantVisit(Attribute attribute) { return attribute.Accept(*this); } - Result VariantVisit(CanActAs canActAs) { return canActAs.Accept(*this); } + Result VariantVisit(CanActAs can_act_as) { return can_act_as.Accept(*this); } // For AssertionVariantType Result VariantVisit(Fact fact) { return fact.Accept(*this); } - Result VariantVisit(ConditionalAssertion conditionalAssertion) { - return conditionalAssertion.Accept(*this); + Result VariantVisit(ConditionalAssertion conditional_assertion) { + return conditional_assertion.Accept(*this); } - private: }; } // namespace raksha::ir::auth_logic diff --git a/src/ir/auth_logic/auth_logic_ast_visitor.h b/src/ir/auth_logic/auth_logic_ast_visitor.h index b15eebbc8..856b48b67 100644 --- a/src/ir/auth_logic/auth_logic_ast_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_visitor.h @@ -31,7 +31,13 @@ class SaysAssertion; class Query; class Program; -template +enum AstNodeMutability : bool { + Mutable = false, + Immutable = true +}; + + +template class AuthLogicAstVisitor { public: virtual ~AuthLogicAstVisitor() {} @@ -50,4 +56,4 @@ class AuthLogicAstVisitor { } // namespace raksha::ir::auth_logic -#endif // SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ \ No newline at end of file +#endif // SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ From 55841419025f13cfb69a8298df005c95d1bf737f Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Wed, 10 Aug 2022 12:36:16 +0100 Subject: [PATCH 07/12] PR Fixes --- .../auth_logic_ast_traversing_visitor.h | 39 +++++++++---------- .../auth_logic_ast_traversing_visitor_test.cc | 2 +- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index 47e399b82..6f21d6eea 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -57,10 +57,9 @@ class AuthLogicAstTraversingVisitor // Should be over-ridden if the Result is not default constructable. virtual Result GetDefaultValue() { return DefaultValueGetter::Get(); } - // Used to accumulate child results from the node's children. - // Should discard or merge `child_result` into the `accumulator`. - virtual Result FoldResult(Result accumulator, Result child_result) { - return accumulator; + // Used to combine two `Result`s into one result while visiting a node + virtual Result CombineResult(Result left_result, Result right_result) { + return left_result; } // Invoked before all the children of `principal` are visited. virtual Result PreVisit(CopyConst& principal) { @@ -179,7 +178,7 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& attribute) final override { Result pre_visit_result = PreVisit(attribute); Result fold_result = - FoldResult(FoldResult(std::move(pre_visit_result), + CombineResult(CombineResult(std::move(pre_visit_result), attribute.principal().Accept(*this)), // TODO(#644 aferr): fix this to use predicate().Accept once // predicate has been refactored into ast.h @@ -190,7 +189,7 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& can_act_as) final override { Result pre_visit_result = PreVisit(can_act_as); Result fold_result = - FoldResult(FoldResult(std::move(pre_visit_result), + CombineResult(CombineResult(std::move(pre_visit_result), can_act_as.left_principal().Accept(*this)), can_act_as.right_principal().Accept(*this)); return PostVisit(can_act_as, std::move(fold_result)); @@ -208,7 +207,7 @@ class AuthLogicAstTraversingVisitor return VariantVisit(can_act_as); }}, base_fact.GetValue()); - Result fold_result = FoldResult(std::move(pre_visit_result), + Result fold_result = CombineResult(std::move(pre_visit_result), std::move(variant_visit_result)); return PostVisit(base_fact, std::move(fold_result)); } @@ -216,11 +215,11 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& fact) final override { Result pre_visit_result = PreVisit(fact); Result base_fact_result = - FoldResult(std::move(pre_visit_result), fact.base_fact().Accept(*this)); + CombineResult(std::move(pre_visit_result), fact.base_fact().Accept(*this)); Result fold_result = common::utils::fold( fact.delegation_chain(), std::move(base_fact_result), [this](Result acc, CopyConst principal) { - return FoldResult(std::move(acc), principal.Accept(*this)); + return CombineResult(std::move(acc), principal.Accept(*this)); }); return PostVisit(fact, std::move(fold_result)); } @@ -228,12 +227,12 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& conditional_assertion) final override { Result pre_visit_result = PreVisit(conditional_assertion); - Result lhs_result = FoldResult(std::move(pre_visit_result), + Result lhs_result = CombineResult(std::move(pre_visit_result), conditional_assertion.lhs().Accept(*this)); Result fold_result = common::utils::fold( conditional_assertion.rhs(), std::move(lhs_result), [this](Result acc, CopyConst base_fact) { - return FoldResult(std::move(acc), base_fact.Accept(*this)); + return CombineResult(std::move(acc), base_fact.Accept(*this)); }); return PostVisit(conditional_assertion, std::move(fold_result)); } @@ -247,7 +246,7 @@ class AuthLogicAstTraversingVisitor return VariantVisit(cond_assertion); }}, assertion.GetValue()); - Result fold_result = FoldResult(std::move(pre_visit_result), + Result fold_result = CombineResult(std::move(pre_visit_result), std::move(variant_visit_result)); return PostVisit(assertion, std::move(fold_result)); } @@ -255,20 +254,20 @@ class AuthLogicAstTraversingVisitor Result Visit( CopyConst& says_assertion) final override { Result pre_visit_result = PreVisit(says_assertion); - Result principal_result = FoldResult( + Result principal_result = CombineResult( std::move(pre_visit_result), says_assertion.principal().Accept(*this)); Result fold_result = common::utils::fold( says_assertion.assertions(), std::move(principal_result), [this](Result acc, CopyConst assertion) { - return FoldResult(std::move(acc), assertion.Accept(*this)); - }); + return CombineResult(std::move(acc), assertion.Accept(*this)); + }); return PostVisit(says_assertion, fold_result); } Result Visit(CopyConst& query) final override { Result pre_visit_result = PreVisit(query); - Result fold_result = FoldResult(std::move(pre_visit_result), - FoldResult(query.principal().Accept(*this), + Result fold_result = CombineResult(std::move(pre_visit_result), + CombineResult(query.principal().Accept(*this), query.fact().Accept(*this))); return PostVisit(query, fold_result); } @@ -281,17 +280,17 @@ class AuthLogicAstTraversingVisitor relation_declaration) { // TODO(#644 aferr) Fix this to accept once once relationDeclaration // has been refactored into ast.h - return FoldResult(std::move(acc), Visit(relation_declaration)); + return CombineResult(std::move(acc), Visit(relation_declaration)); }); Result says_assertions_result = common::utils::fold( program.says_assertions(), std::move(declarations_result), [this](Result acc, CopyConst says_assertion) { - return FoldResult(std::move(acc), says_assertion.Accept(*this)); + return CombineResult(std::move(acc), says_assertion.Accept(*this)); }); Result queries_result = common::utils::fold( program.queries(), std::move(says_assertions_result), [this](Result acc, CopyConst query) { - return FoldResult(std::move(acc), query.Accept(*this)); + return CombineResult(std::move(acc), query.Accept(*this)); }); return PostVisit(program, queries_result); } diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc index e7b5114fa..6c1945c05 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -33,7 +33,7 @@ class PrincipalNameCollectorVisitor absl::flat_hash_set GetDefaultValue() override { return {}; } - absl::flat_hash_set FoldResult( + absl::flat_hash_set CombineResult ( absl::flat_hash_set acc, absl::flat_hash_set child_result) override { acc.merge(std::move(child_result)); From f53e4145681633eeaf280bb423299ed8368e75c5 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Thu, 11 Aug 2022 11:24:10 +0100 Subject: [PATCH 08/12] FoldAccept --- .../auth_logic_ast_traversing_visitor.h | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index 6f21d6eea..be44b2435 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -216,11 +216,8 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(fact); Result base_fact_result = CombineResult(std::move(pre_visit_result), fact.base_fact().Accept(*this)); - Result fold_result = common::utils::fold( - fact.delegation_chain(), std::move(base_fact_result), - [this](Result acc, CopyConst principal) { - return CombineResult(std::move(acc), principal.Accept(*this)); - }); + Result fold_result = FoldAccept>( + fact.delegation_chain(), base_fact_result); return PostVisit(fact, std::move(fold_result)); } @@ -229,11 +226,8 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(conditional_assertion); Result lhs_result = CombineResult(std::move(pre_visit_result), conditional_assertion.lhs().Accept(*this)); - Result fold_result = common::utils::fold( - conditional_assertion.rhs(), std::move(lhs_result), - [this](Result acc, CopyConst base_fact) { - return CombineResult(std::move(acc), base_fact.Accept(*this)); - }); + Result fold_result = FoldAccept>( + conditional_assertion.rhs(), lhs_result); return PostVisit(conditional_assertion, std::move(fold_result)); } @@ -256,11 +250,8 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(says_assertion); Result principal_result = CombineResult( std::move(pre_visit_result), says_assertion.principal().Accept(*this)); - Result fold_result = common::utils::fold( - says_assertion.assertions(), std::move(principal_result), - [this](Result acc, CopyConst assertion) { - return CombineResult(std::move(acc), assertion.Accept(*this)); - }); + Result fold_result = FoldAccept> + (says_assertion.assertions(), principal_result); return PostVisit(says_assertion, fold_result); } @@ -282,16 +273,11 @@ class AuthLogicAstTraversingVisitor // has been refactored into ast.h return CombineResult(std::move(acc), Visit(relation_declaration)); }); - Result says_assertions_result = common::utils::fold( - program.says_assertions(), std::move(declarations_result), - [this](Result acc, CopyConst says_assertion) { - return CombineResult(std::move(acc), says_assertion.Accept(*this)); - }); - Result queries_result = common::utils::fold( - program.queries(), std::move(says_assertions_result), - [this](Result acc, CopyConst query) { - return CombineResult(std::move(acc), query.Accept(*this)); - }); + Result says_assertions_result = FoldAccept>(program.says_assertions(), + declarations_result); + Result queries_result = FoldAccept>( + program.queries(), says_assertions_result); return PostVisit(program, queries_result); } @@ -313,6 +299,15 @@ class AuthLogicAstTraversingVisitor return conditional_assertion.Accept(*this); } + private: + template + Result FoldAccept(Container container, Result initial) { + return common::utils::fold(container, std::move(initial), + [this](Result acc, CopyConst element) { + return CombineResult(std::move(acc), element.Accept(*this)); + }); + } + }; } // namespace raksha::ir::auth_logic From 89fa9cd5f44fcfd083879249d6796f78211bb288 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Thu, 11 Aug 2022 16:46:13 +0100 Subject: [PATCH 09/12] DebugPrint() --- src/ir/auth_logic/BUILD | 1 + src/ir/auth_logic/ast.h | 98 ++++++++++++ .../auth_logic_ast_traversing_visitor.h | 8 +- .../auth_logic_ast_traversing_visitor_test.cc | 150 +++++++++++++++++- src/ir/datalog/program.h | 29 ++++ 5 files changed, 280 insertions(+), 6 deletions(-) diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index 39296a8a7..0dcb2ba39 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -33,6 +33,7 @@ cc_library( "//src/common/utils:overloaded", "//src/common/utils:types", "//src/ir/datalog:program", + "//src/common/utils:map_iter", "@absl//absl/hash", "@absl//absl/strings:str_format", ], diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index bbbfa1368..171253c9a 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -27,6 +27,7 @@ #include "absl/hash/hash.h" #include "src/ir/auth_logic/auth_logic_ast_visitor.h" #include "src/ir/datalog/program.h" +#include "src/common/utils/map_iter.h" namespace raksha::ir::auth_logic { @@ -45,6 +46,12 @@ class Principal { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return name_; + } + private: std::string name_; }; @@ -68,6 +75,12 @@ class Attribute { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat(principal_.name(), predicate_.DebugPrint()); + } + private: Principal principal_; datalog::Predicate predicate_; @@ -93,6 +106,13 @@ class CanActAs { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat(left_principal_.DebugPrint(), " canActAs ", + right_principal_.DebugPrint()); + } + private: Principal left_principal_; Principal right_principal_; @@ -126,6 +146,14 @@ class BaseFact { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat("BaseFact(", + std::visit([](auto & obj) {return obj.DebugPrint();}, this->value_), + ")"); + } + private: BaseFactVariantType value_; }; @@ -154,6 +182,17 @@ class Fact { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + std::vector delegations; + for (Principal& delegatee: delegation_chain_) { + delegations.push_back(delegatee.DebugPrint()); + } + return absl::StrCat("deleg: { ", absl::StrJoin(delegations, ", "), " }", + base_fact_.DebugPrint()); + } + private: std::forward_list delegation_chain_; BaseFact base_fact_; @@ -179,6 +218,17 @@ class ConditionalAssertion { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + std::vector rhs_strings; + for (BaseFact& base_fact : rhs_) { + rhs_strings.push_back(base_fact.DebugPrint()); + } + return absl::StrCat(lhs_.DebugPrint(), ":-", + absl::StrJoin(rhs_strings, ", ")); + } + private: Fact lhs_; std::vector rhs_; @@ -206,6 +256,14 @@ class Assertion { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat("Assertion(", + std::visit([](auto & obj) {return obj.DebugPrint();}, this->value_), + ")"); + } + private: AssertionVariantType value_; }; @@ -228,6 +286,17 @@ class SaysAssertion { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + std::vector assertion_strings; + for (Assertion& assertion: assertions_) { + assertion_strings.push_back(assertion.DebugPrint()); + } + return absl::StrCat(principal_.DebugPrint(), "says {\n", + absl::StrJoin(assertion_strings, "\n"), "}"); + } + private: Principal principal_; std::vector assertions_; @@ -255,6 +324,13 @@ class Query { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat("Query(", name_, principal_.DebugPrint(), + fact_.DebugPrint(), ")"); + } + private: std::string name_; Principal principal_; @@ -292,6 +368,28 @@ class Program { return visitor.Visit(*this); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + std::vector relation_decl_strings; + for (datalog::RelationDeclaration& rel_decl: relation_declarations_) { + relation_decl_strings.push_back(rel_decl.DebugPrint()); + } + std::vector says_assertion_strings; + for (SaysAssertion& says_assertion: says_assertions_) { + says_assertion_strings.push_back(says_assertion.DebugPrint()); + } + std::vector query_strings; + for (Query& query: queries_) { + query_strings.push_back(query.DebugPrint()); + } + return absl::StrCat("Program(\n", + absl::StrJoin(relation_decl_strings, "\n"), + absl::StrJoin(says_assertion_strings, "\n"), + absl::StrJoin(query_strings, "\n"), + ")"); + } + private: std::vector relation_declarations_; std::vector says_assertions_; diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index be44b2435..9901ce890 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -214,11 +214,11 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& fact) final override { Result pre_visit_result = PreVisit(fact); + Result deleg_result = FoldAccept>( + fact.delegation_chain(), pre_visit_result); Result base_fact_result = - CombineResult(std::move(pre_visit_result), fact.base_fact().Accept(*this)); - Result fold_result = FoldAccept>( - fact.delegation_chain(), base_fact_result); - return PostVisit(fact, std::move(fold_result)); + CombineResult(std::move(deleg_result), fact.base_fact().Accept(*this)); + return PostVisit(fact, std::move(base_fact_result)); } Result Visit(CopyConst& conditional_assertion) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc index 6c1945c05..81f746cab 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -20,6 +20,7 @@ #include "src/common/testing/gtest.h" #include "src/ir/auth_logic/ast.h" #include "src/ir/datalog/program.h" +#include namespace raksha::ir::auth_logic { namespace { @@ -29,8 +30,6 @@ class PrincipalNameCollectorVisitor : public AuthLogicAstTraversingVisitor> { public: - PrincipalNameCollectorVisitor() {} - absl::flat_hash_set GetDefaultValue() override { return {}; } absl::flat_hash_set CombineResult ( @@ -75,5 +74,152 @@ TEST(AuthLogicAstTraversingVisitorTest, PrincipalNameCollectorTest) { EXPECT_EQ(result, expected); } +enum class TraversalType { kPre = 0x1, kPost = 0x2, kBoth = 0x3 }; +class TraversalOrderVisitor + : public AuthLogicAstTraversingVisitor { + public: + TraversalOrderVisitor(TraversalType traversal_type) + : pre_visits_(traversal_type == TraversalType::kPre || + traversal_type == TraversalType::kBoth), + post_visits_(traversal_type == TraversalType::kPost || + traversal_type == TraversalType::kBoth) {} + + Unit PreVisit(const Principal& prin) override { + if (pre_visits_) { + std::cout << "pre principal " << std::addressof(prin) << std::endl; + nodes_.push_back(std::addressof(prin)); + } + return Unit(); + } + Unit PostVisit(const Principal& prin, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(prin)); + return result; + } + + Unit PreVisit(const Attribute& attrib) override { + if (pre_visits_) nodes_.push_back(std::addressof(attrib)); + return Unit(); + } + Unit PostVisit(const Attribute& attrib, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(attrib)); + return result; + } + + Unit PreVisit(const CanActAs& canActAs) override { + if (pre_visits_) nodes_.push_back(std::addressof(canActAs)); + return Unit(); + } + Unit PostVisit(const CanActAs& canActAs, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(canActAs)); + return result; + } + + Unit PreVisit(const BaseFact& baseFact) override { + if (pre_visits_) { + std::cout << "pre baseFact " << std::addressof(baseFact) + << std::endl; + nodes_.push_back(std::addressof(baseFact)); + } + return Unit(); + } + Unit PostVisit(const BaseFact& baseFact, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(baseFact)); + return result; + } + + Unit PreVisit(const Fact& fact) override { + if (pre_visits_) { + std::cout << "pre Fact " << std::addressof(fact) + << std::endl; + nodes_.push_back(std::addressof(fact)); + } + return Unit(); + } + Unit PostVisit(const Fact& fact, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(fact)); + return result; + } + + Unit PreVisit(const ConditionalAssertion& condAssertion) override { + if (pre_visits_) nodes_.push_back(std::addressof(condAssertion)); + return Unit(); + } + Unit PostVisit(const ConditionalAssertion& condAssertion, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(condAssertion)); + return result; + } + + Unit PreVisit(const Assertion& assertion) override { + if (pre_visits_) nodes_.push_back(std::addressof(assertion)); + return Unit(); + } + Unit PostVisit(const Assertion& assertion, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(assertion)); + return result; + } + + Unit PreVisit(const SaysAssertion& saysAssertion) override { + if (pre_visits_) nodes_.push_back(std::addressof(saysAssertion)); + return Unit(); + } + Unit PostVisit(const SaysAssertion& saysAssertion, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(saysAssertion)); + return result; + } + + Unit PreVisit(const Query& query) override { + if (pre_visits_) nodes_.push_back(std::addressof(query)); + return Unit(); + } + Unit PostVisit(const Query& query, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(query)); + return result; + } + + Unit PreVisit(const Program& program) override { + if (pre_visits_) nodes_.push_back(std::addressof(program)); + return Unit(); + } + Unit PostVisit(const Program& program, Unit result) override { + if (post_visits_) nodes_.push_back(std::addressof(program)); + return result; + } + + const std::vector& nodes() const { return nodes_; } + + private: + bool pre_visits_; + bool post_visits_; + std::vector nodes_; +}; + +TEST(AuthLogicAstTraversingVisitorTest, SimpleTraversalTest) { + Principal prinA = Principal("PrincipalA"); + Principal prinB = Principal("PrincipalB"); + Principal prinC = Principal("PrincipalC"); + + datalog::Predicate pred1 = datalog::Predicate("pred1", {}, datalog::kPositive); + datalog::Predicate pred2 = datalog::Predicate("pred2", {}, datalog::kPositive); + + BaseFact baseFact1 = BaseFact(pred1); + BaseFact baseFact2 = BaseFact(pred2); + + Fact fact1 = Fact({}, baseFact1); + Fact fact2 = Fact({}, baseFact2); + + Assertion ast1 = Assertion(fact1); + + TraversalOrderVisitor preorder_visitor(TraversalType::kPre); + fact1.Accept(preorder_visitor); + EXPECT_THAT( + preorder_visitor.nodes(), + testing::ElementsAre( + std::addressof(fact1), + std::addressof(fact1.base_fact())) + ); + +} + + } // namespace } // namespace raksha::ir::auth_logic diff --git a/src/ir/datalog/program.h b/src/ir/datalog/program.h index e37d2a6e2..f55105f2e 100644 --- a/src/ir/datalog/program.h +++ b/src/ir/datalog/program.h @@ -66,6 +66,13 @@ class Predicate { return this->name() < otherPredicate.name(); } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat(sign_, name_, absl::StrJoin(args_, ", ")); + } + private: std::string name_; std::vector args_; @@ -83,6 +90,10 @@ class ArgumentType { : kind_(kind), name_(name) {} Kind kind() const { return kind_; } absl::string_view name() const { return name_; } + + std::string DebugPrint() { + return absl::StrCat(kind_, name_); + } private: Kind kind_; @@ -97,6 +108,12 @@ class Argument { absl::string_view argument_name() const { return argument_name_; } ArgumentType argument_type() const { return argument_type_; } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + return absl::StrCat(argument_name_, " : ", argument_type_.DebugPrint()); + } + private: std::string argument_name_; ArgumentType argument_type_; @@ -114,6 +131,18 @@ class RelationDeclaration { bool is_attribute() const { return is_attribute_; } const std::vector& arguments() const { return arguments_; } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() { + std::vector arg_strings; + for(Argument& arg: arguments_) { + arg_strings.push_back(arg.DebugPrint()); + } + return absl::StrCat(".decl ", relation_name_, is_attribute_, + absl::StrJoin(arg_strings, ", ")); + } + private: std::string relation_name_; bool is_attribute_; From 2e6d9f1e1c234f157362ea664610301ef44374ea Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Fri, 12 Aug 2022 12:17:58 +0100 Subject: [PATCH 10/12] More tests, formatting --- src/ir/auth_logic/ast.h | 98 +++++----- .../auth_logic_ast_traversing_visitor.h | 54 +++--- .../auth_logic_ast_traversing_visitor_test.cc | 168 ++++++++++++------ src/ir/datalog/program.h | 18 +- 4 files changed, 200 insertions(+), 138 deletions(-) diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index 171253c9a..56d24a999 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -25,9 +25,9 @@ #include #include "absl/hash/hash.h" +#include "src/common/utils/map_iter.h" #include "src/ir/auth_logic/auth_logic_ast_visitor.h" #include "src/ir/datalog/program.h" -#include "src/common/utils/map_iter.h" namespace raksha::ir::auth_logic { @@ -42,15 +42,14 @@ class Principal { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { - return name_; - } + std::string DebugPrint() const { return name_; } private: std::string name_; @@ -71,13 +70,14 @@ class Attribute { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { return absl::StrCat(principal_.name(), predicate_.DebugPrint()); } @@ -102,15 +102,16 @@ class CanActAs { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { return absl::StrCat(left_principal_.DebugPrint(), " canActAs ", - right_principal_.DebugPrint()); + right_principal_.DebugPrint()); } private: @@ -142,16 +143,18 @@ class BaseFact { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { - return absl::StrCat("BaseFact(", - std::visit([](auto & obj) {return obj.DebugPrint();}, this->value_), - ")"); + std::string DebugPrint() const { + return absl::StrCat( + "BaseFact(", + std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_), + ")"); } private: @@ -178,19 +181,20 @@ class Fact { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { std::vector delegations; - for (Principal& delegatee: delegation_chain_) { + for (const Principal& delegatee : delegation_chain_) { delegations.push_back(delegatee.DebugPrint()); } return absl::StrCat("deleg: { ", absl::StrJoin(delegations, ", "), " }", - base_fact_.DebugPrint()); + base_fact_.DebugPrint()); } private: @@ -214,19 +218,20 @@ class ConditionalAssertion { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { std::vector rhs_strings; - for (BaseFact& base_fact : rhs_) { + for (const BaseFact& base_fact : rhs_) { rhs_strings.push_back(base_fact.DebugPrint()); } return absl::StrCat(lhs_.DebugPrint(), ":-", - absl::StrJoin(rhs_strings, ", ")); + absl::StrJoin(rhs_strings, ", ")); } private: @@ -252,16 +257,18 @@ class Assertion { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { - return absl::StrCat("Assertion(", - std::visit([](auto & obj) {return obj.DebugPrint();}, this->value_), - ")"); + std::string DebugPrint() const { + return absl::StrCat( + "Assertion(", + std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_), + ")"); } private: @@ -282,19 +289,20 @@ class SaysAssertion { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { std::vector assertion_strings; - for (Assertion& assertion: assertions_) { + for (const Assertion& assertion : assertions_) { assertion_strings.push_back(assertion.DebugPrint()); } return absl::StrCat(principal_.DebugPrint(), "says {\n", - absl::StrJoin(assertion_strings, "\n"), "}"); + absl::StrJoin(assertion_strings, "\n"), "}"); } private: @@ -320,15 +328,16 @@ class Query { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { return absl::StrCat("Query(", name_, principal_.DebugPrint(), - fact_.DebugPrint(), ")"); + fact_.DebugPrint(), ")"); } private: @@ -364,30 +373,31 @@ class Program { } template - Result Accept(AuthLogicAstVisitor& visitor) const { + Result Accept( + AuthLogicAstVisitor& visitor) const { return visitor.Visit(*this); } // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { std::vector relation_decl_strings; - for (datalog::RelationDeclaration& rel_decl: relation_declarations_) { + for (const datalog::RelationDeclaration& rel_decl : + relation_declarations_) { relation_decl_strings.push_back(rel_decl.DebugPrint()); } std::vector says_assertion_strings; - for (SaysAssertion& says_assertion: says_assertions_) { + for (const SaysAssertion& says_assertion : says_assertions_) { says_assertion_strings.push_back(says_assertion.DebugPrint()); } std::vector query_strings; - for (Query& query: queries_) { + for (const Query& query : queries_) { query_strings.push_back(query.DebugPrint()); } return absl::StrCat("Program(\n", - absl::StrJoin(relation_decl_strings, "\n"), - absl::StrJoin(says_assertion_strings, "\n"), - absl::StrJoin(query_strings, "\n"), - ")"); + absl::StrJoin(relation_decl_strings, "\n"), + absl::StrJoin(says_assertion_strings, "\n"), + absl::StrJoin(query_strings, "\n"), ")"); } private: diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index 9901ce890..d651ea8c4 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -33,7 +33,8 @@ namespace raksha::ir::auth_logic { // A visitor that also traverses the children of a node and allows performing // different actions before (PreVisit) and after (PostVisit) the children are // visited. Override any of the `PreVisit` and `PostVisit` methods as needed. -template +template class AuthLogicAstTraversingVisitor : public AuthLogicAstVisitor { private: @@ -179,10 +180,10 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(attribute); Result fold_result = CombineResult(CombineResult(std::move(pre_visit_result), - attribute.principal().Accept(*this)), - // TODO(#644 aferr): fix this to use predicate().Accept once - // predicate has been refactored into ast.h - Visit(attribute.predicate())); + attribute.principal().Accept(*this)), + // TODO(#644 aferr): fix this to use predicate().Accept + // once predicate has been refactored into ast.h + Visit(attribute.predicate())); return PostVisit(attribute, std::move(fold_result)); } @@ -190,8 +191,8 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(can_act_as); Result fold_result = CombineResult(CombineResult(std::move(pre_visit_result), - can_act_as.left_principal().Accept(*this)), - can_act_as.right_principal().Accept(*this)); + can_act_as.left_principal().Accept(*this)), + can_act_as.right_principal().Accept(*this)); return PostVisit(can_act_as, std::move(fold_result)); } @@ -208,14 +209,14 @@ class AuthLogicAstTraversingVisitor }}, base_fact.GetValue()); Result fold_result = CombineResult(std::move(pre_visit_result), - std::move(variant_visit_result)); + std::move(variant_visit_result)); return PostVisit(base_fact, std::move(fold_result)); } Result Visit(CopyConst& fact) final override { Result pre_visit_result = PreVisit(fact); Result deleg_result = FoldAccept>( - fact.delegation_chain(), pre_visit_result); + fact.delegation_chain(), pre_visit_result); Result base_fact_result = CombineResult(std::move(deleg_result), fact.base_fact().Accept(*this)); return PostVisit(fact, std::move(base_fact_result)); @@ -224,8 +225,8 @@ class AuthLogicAstTraversingVisitor Result Visit(CopyConst& conditional_assertion) final override { Result pre_visit_result = PreVisit(conditional_assertion); - Result lhs_result = CombineResult(std::move(pre_visit_result), - conditional_assertion.lhs().Accept(*this)); + Result lhs_result = CombineResult( + std::move(pre_visit_result), conditional_assertion.lhs().Accept(*this)); Result fold_result = FoldAccept>( conditional_assertion.rhs(), lhs_result); return PostVisit(conditional_assertion, std::move(fold_result)); @@ -241,7 +242,7 @@ class AuthLogicAstTraversingVisitor }}, assertion.GetValue()); Result fold_result = CombineResult(std::move(pre_visit_result), - std::move(variant_visit_result)); + std::move(variant_visit_result)); return PostVisit(assertion, std::move(fold_result)); } @@ -250,16 +251,17 @@ class AuthLogicAstTraversingVisitor Result pre_visit_result = PreVisit(says_assertion); Result principal_result = CombineResult( std::move(pre_visit_result), says_assertion.principal().Accept(*this)); - Result fold_result = FoldAccept> - (says_assertion.assertions(), principal_result); + Result fold_result = FoldAccept>( + says_assertion.assertions(), principal_result); return PostVisit(says_assertion, fold_result); } Result Visit(CopyConst& query) final override { Result pre_visit_result = PreVisit(query); - Result fold_result = CombineResult(std::move(pre_visit_result), - CombineResult(query.principal().Accept(*this), - query.fact().Accept(*this))); + Result fold_result = + CombineResult(std::move(pre_visit_result), + CombineResult(query.principal().Accept(*this), + query.fact().Accept(*this))); return PostVisit(query, fold_result); } @@ -273,9 +275,9 @@ class AuthLogicAstTraversingVisitor // has been refactored into ast.h return CombineResult(std::move(acc), Visit(relation_declaration)); }); - Result says_assertions_result = FoldAccept>(program.says_assertions(), - declarations_result); + Result says_assertions_result = + FoldAccept>( + program.says_assertions(), declarations_result); Result queries_result = FoldAccept>( program.queries(), says_assertions_result); return PostVisit(program, queries_result); @@ -300,14 +302,14 @@ class AuthLogicAstTraversingVisitor } private: - template + template Result FoldAccept(Container container, Result initial) { - return common::utils::fold(container, std::move(initial), - [this](Result acc, CopyConst element) { - return CombineResult(std::move(acc), element.Accept(*this)); - }); + return common::utils::fold( + container, std::move(initial), + [this](Result acc, CopyConst element) { + return CombineResult(std::move(acc), element.Accept(*this)); + }); } - }; } // namespace raksha::ir::auth_logic diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc index 81f746cab..3d8a1992f 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -16,11 +16,12 @@ #include "src/ir/auth_logic/auth_logic_ast_traversing_visitor.h" +#include + #include "absl/container/flat_hash_set.h" #include "src/common/testing/gtest.h" #include "src/ir/auth_logic/ast.h" #include "src/ir/datalog/program.h" -#include namespace raksha::ir::auth_logic { namespace { @@ -32,7 +33,7 @@ class PrincipalNameCollectorVisitor public: absl::flat_hash_set GetDefaultValue() override { return {}; } - absl::flat_hash_set CombineResult ( + absl::flat_hash_set CombineResult( absl::flat_hash_set acc, absl::flat_hash_set child_result) override { acc.merge(std::move(child_result)); @@ -75,151 +76,204 @@ TEST(AuthLogicAstTraversingVisitorTest, PrincipalNameCollectorTest) { } enum class TraversalType { kPre = 0x1, kPost = 0x2, kBoth = 0x3 }; -class TraversalOrderVisitor - : public AuthLogicAstTraversingVisitor { +class TraversalOrderVisitor + : public AuthLogicAstTraversingVisitor { public: - TraversalOrderVisitor(TraversalType traversal_type) + TraversalOrderVisitor(TraversalType traversal_type) : pre_visits_(traversal_type == TraversalType::kPre || traversal_type == TraversalType::kBoth), post_visits_(traversal_type == TraversalType::kPost || traversal_type == TraversalType::kBoth) {} Unit PreVisit(const Principal& prin) override { - if (pre_visits_) { - std::cout << "pre principal " << std::addressof(prin) << std::endl; - nodes_.push_back(std::addressof(prin)); - } + if (pre_visits_) nodes_.push_back(prin.DebugPrint()); return Unit(); } Unit PostVisit(const Principal& prin, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(prin)); + if (post_visits_) nodes_.push_back(prin.DebugPrint()); return result; } Unit PreVisit(const Attribute& attrib) override { - if (pre_visits_) nodes_.push_back(std::addressof(attrib)); + if (pre_visits_) nodes_.push_back(attrib.DebugPrint()); return Unit(); } Unit PostVisit(const Attribute& attrib, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(attrib)); + if (post_visits_) nodes_.push_back(attrib.DebugPrint()); return result; } Unit PreVisit(const CanActAs& canActAs) override { - if (pre_visits_) nodes_.push_back(std::addressof(canActAs)); + if (pre_visits_) nodes_.push_back(canActAs.DebugPrint()); return Unit(); } Unit PostVisit(const CanActAs& canActAs, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(canActAs)); + if (post_visits_) nodes_.push_back(canActAs.DebugPrint()); return result; } Unit PreVisit(const BaseFact& baseFact) override { - if (pre_visits_) { - std::cout << "pre baseFact " << std::addressof(baseFact) - << std::endl; - nodes_.push_back(std::addressof(baseFact)); - } + if (pre_visits_) nodes_.push_back(baseFact.DebugPrint()); return Unit(); } Unit PostVisit(const BaseFact& baseFact, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(baseFact)); + if (post_visits_) nodes_.push_back(baseFact.DebugPrint()); return result; } Unit PreVisit(const Fact& fact) override { - if (pre_visits_) { - std::cout << "pre Fact " << std::addressof(fact) - << std::endl; - nodes_.push_back(std::addressof(fact)); - } + if (pre_visits_) nodes_.push_back(fact.DebugPrint()); return Unit(); } Unit PostVisit(const Fact& fact, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(fact)); + if (post_visits_) nodes_.push_back(fact.DebugPrint()); return result; } Unit PreVisit(const ConditionalAssertion& condAssertion) override { - if (pre_visits_) nodes_.push_back(std::addressof(condAssertion)); + if (pre_visits_) nodes_.push_back(condAssertion.DebugPrint()); return Unit(); } - Unit PostVisit(const ConditionalAssertion& condAssertion, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(condAssertion)); + Unit PostVisit(const ConditionalAssertion& condAssertion, + Unit result) override { + if (post_visits_) nodes_.push_back(condAssertion.DebugPrint()); return result; } Unit PreVisit(const Assertion& assertion) override { - if (pre_visits_) nodes_.push_back(std::addressof(assertion)); + if (pre_visits_) nodes_.push_back(assertion.DebugPrint()); return Unit(); } Unit PostVisit(const Assertion& assertion, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(assertion)); + if (post_visits_) nodes_.push_back(assertion.DebugPrint()); return result; } Unit PreVisit(const SaysAssertion& saysAssertion) override { - if (pre_visits_) nodes_.push_back(std::addressof(saysAssertion)); + if (pre_visits_) nodes_.push_back(saysAssertion.DebugPrint()); return Unit(); } Unit PostVisit(const SaysAssertion& saysAssertion, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(saysAssertion)); + if (post_visits_) nodes_.push_back(saysAssertion.DebugPrint()); return result; } Unit PreVisit(const Query& query) override { - if (pre_visits_) nodes_.push_back(std::addressof(query)); + if (pre_visits_) nodes_.push_back(query.DebugPrint()); return Unit(); } Unit PostVisit(const Query& query, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(query)); + if (post_visits_) nodes_.push_back(query.DebugPrint()); return result; } Unit PreVisit(const Program& program) override { - if (pre_visits_) nodes_.push_back(std::addressof(program)); + if (pre_visits_) nodes_.push_back(program.DebugPrint()); return Unit(); } Unit PostVisit(const Program& program, Unit result) override { - if (post_visits_) nodes_.push_back(std::addressof(program)); + if (post_visits_) nodes_.push_back(program.DebugPrint()); return result; } - const std::vector& nodes() const { return nodes_; } + const std::vector& nodes() const { return nodes_; } private: bool pre_visits_; bool post_visits_; - std::vector nodes_; + std::vector nodes_; }; TEST(AuthLogicAstTraversingVisitorTest, SimpleTraversalTest) { - Principal prinA = Principal("PrincipalA"); - Principal prinB = Principal("PrincipalB"); - Principal prinC = Principal("PrincipalC"); + Principal prinA("PrincipalA"); + Principal prinB("PrincipalB"); + Principal prinC("PrincipalC"); + + datalog::Predicate pred1("pred1", {}, datalog::kPositive); + datalog::Predicate pred2("pred2", {}, datalog::kPositive); - datalog::Predicate pred1 = datalog::Predicate("pred1", {}, datalog::kPositive); - datalog::Predicate pred2 = datalog::Predicate("pred2", {}, datalog::kPositive); + BaseFact baseFact1(pred1); + BaseFact baseFact2(pred2); - BaseFact baseFact1 = BaseFact(pred1); - BaseFact baseFact2 = BaseFact(pred2); - - Fact fact1 = Fact({}, baseFact1); - Fact fact2 = Fact({}, baseFact2); + Fact fact1({}, baseFact1); + Fact fact2({prinB}, baseFact2); - Assertion ast1 = Assertion(fact1); + Assertion assertion1(fact1); + Assertion assertion2(fact2); + + SaysAssertion saysAssertion1(prinA, {assertion1}); + SaysAssertion saysAssertion2(prinC, {assertion1, assertion2}); + + Query query1("query1", prinA, fact1); + Query query2("query2", prinB, fact2); + + Program program1({}, {saysAssertion1, saysAssertion2}, {query1, query2}); TraversalOrderVisitor preorder_visitor(TraversalType::kPre); - fact1.Accept(preorder_visitor); + program1.Accept(preorder_visitor); EXPECT_THAT( - preorder_visitor.nodes(), - testing::ElementsAre( - std::addressof(fact1), - std::addressof(fact1.base_fact())) - ); + preorder_visitor.nodes(), + testing::ElementsAre( + program1.DebugPrint(), saysAssertion1.DebugPrint(), + prinA.DebugPrint(), assertion1.DebugPrint(), fact1.DebugPrint(), + baseFact1.DebugPrint(), saysAssertion2.DebugPrint(), + prinC.DebugPrint(), assertion1.DebugPrint(), fact1.DebugPrint(), + baseFact1.DebugPrint(), assertion2.DebugPrint(), fact2.DebugPrint(), + prinB.DebugPrint(), baseFact2.DebugPrint(), query1.DebugPrint(), + prinA.DebugPrint(), fact1.DebugPrint(), baseFact1.DebugPrint(), + query2.DebugPrint(), prinB.DebugPrint(), fact2.DebugPrint(), + prinB.DebugPrint(), baseFact2.DebugPrint())); -} + TraversalOrderVisitor postorder_visitor(TraversalType::kPost); + program1.Accept(postorder_visitor); + EXPECT_THAT( + postorder_visitor.nodes(), + testing::ElementsAre( + prinA.DebugPrint(), baseFact1.DebugPrint(), fact1.DebugPrint(), + assertion1.DebugPrint(), saysAssertion1.DebugPrint(), + prinC.DebugPrint(), baseFact1.DebugPrint(), fact1.DebugPrint(), + assertion1.DebugPrint(), prinB.DebugPrint(), baseFact2.DebugPrint(), + fact2.DebugPrint(), assertion2.DebugPrint(), + saysAssertion2.DebugPrint(), prinA.DebugPrint(), + baseFact1.DebugPrint(), fact1.DebugPrint(), query1.DebugPrint(), + prinB.DebugPrint(), prinB.DebugPrint(), baseFact2.DebugPrint(), + fact2.DebugPrint(), query2.DebugPrint(), program1.DebugPrint())); + + // The bits of syntax not in program1 + Attribute attribute1(prinC, pred2); + CanActAs canActAs1(prinA, prinB); + BaseFact baseFact3(attribute1); + BaseFact baseFact4(canActAs1); + Fact fact3({}, baseFact3); + ConditionalAssertion conditionalAssertion1(fact3, {baseFact4}); + + TraversalOrderVisitor preorder_visitor2(TraversalType::kPre); + conditionalAssertion1.Accept(preorder_visitor2); + EXPECT_THAT( + preorder_visitor2.nodes(), + testing::ElementsAre(conditionalAssertion1.DebugPrint(), + fact3.DebugPrint(), baseFact3.DebugPrint(), + attribute1.DebugPrint(), prinC.DebugPrint(), + baseFact4.DebugPrint(), canActAs1.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint())); + TraversalOrderVisitor postorder_visitor2(TraversalType::kPost); + conditionalAssertion1.Accept(postorder_visitor2); + EXPECT_THAT( + postorder_visitor2.nodes(), + testing::ElementsAre(prinC.DebugPrint(), attribute1.DebugPrint(), + baseFact3.DebugPrint(), fact3.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint(), + canActAs1.DebugPrint(), baseFact4.DebugPrint(), + conditionalAssertion1.DebugPrint())); + + TraversalOrderVisitor both_order_visitor(TraversalType::kBoth); + canActAs1.Accept(both_order_visitor); + EXPECT_THAT(both_order_visitor.nodes(), + testing::ElementsAre(canActAs1.DebugPrint(), prinA.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint(), + prinB.DebugPrint(), canActAs1.DebugPrint())); +} } // namespace } // namespace raksha::ir::auth_logic diff --git a/src/ir/datalog/program.h b/src/ir/datalog/program.h index f55105f2e..5aad4c3dc 100644 --- a/src/ir/datalog/program.h +++ b/src/ir/datalog/program.h @@ -66,10 +66,9 @@ class Predicate { return this->name() < otherPredicate.name(); } - // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { return absl::StrCat(sign_, name_, absl::StrJoin(args_, ", ")); } @@ -90,10 +89,8 @@ class ArgumentType { : kind_(kind), name_(name) {} Kind kind() const { return kind_; } absl::string_view name() const { return name_; } - - std::string DebugPrint() { - return absl::StrCat(kind_, name_); - } + + std::string DebugPrint() const { return absl::StrCat(kind_, name_); } private: Kind kind_; @@ -110,7 +107,7 @@ class Argument { // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { return absl::StrCat(argument_name_, " : ", argument_type_.DebugPrint()); } @@ -131,16 +128,15 @@ class RelationDeclaration { bool is_attribute() const { return is_attribute_; } const std::vector& arguments() const { return arguments_; } - // A potentially ugly print of the state in this class // for debugging/testing only - std::string DebugPrint() { + std::string DebugPrint() const { std::vector arg_strings; - for(Argument& arg: arguments_) { + for (const Argument& arg : arguments_) { arg_strings.push_back(arg.DebugPrint()); } return absl::StrCat(".decl ", relation_name_, is_attribute_, - absl::StrJoin(arg_strings, ", ")); + absl::StrJoin(arg_strings, ", ")); } private: From 3ebce18f90de8de8729c9a926860d79302b57f51 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Fri, 12 Aug 2022 12:33:08 +0100 Subject: [PATCH 11/12] Bazel lint --- src/ir/auth_logic/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index 0dcb2ba39..35dc7668d 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -30,10 +30,10 @@ cc_library( visibility = ["//visibility:private"], deps = [ "//src/common/logging", + "//src/common/utils:map_iter", "//src/common/utils:overloaded", "//src/common/utils:types", "//src/ir/datalog:program", - "//src/common/utils:map_iter", "@absl//absl/hash", "@absl//absl/strings:str_format", ], From d3426cbdcacb7bc1d3b8ba86af0076d7c469adfc Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Thu, 18 Aug 2022 13:37:07 +0100 Subject: [PATCH 12/12] fix include --- src/ir/auth_logic/auth_logic_ast_traversing_visitor.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h index d651ea8c4..2a77ab89d 100644 --- a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -16,8 +16,7 @@ #ifndef SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ #define SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ -#include - +#include "src/common/logging/logging.h" #include "src/common/utils/fold.h" #include "src/common/utils/overloaded.h" #include "src/common/utils/types.h"