Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visitor for authorization logic AST #643

Merged
merged 12 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/ir/auth_logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,35 @@ cc_library(
name = "ast",
hdrs = [
"ast.h",
Cypher1 marked this conversation as resolved.
Show resolved Hide resolved
"auth_logic_ast_visitor.h",
],
visibility = ["//visibility:private"],
deps = [
"//src/common/logging",
"//src/common/utils:map_iter",
"//src/common/utils:overloaded",
"//src/common/utils:types",
"//src/ir/datalog:program",
"@absl//absl/hash",
"@absl//absl/strings:str_format",
],
)

cc_library(
name = "auth_logic_ast_traversing_visitor",
hdrs = [
"auth_logic_ast_traversing_visitor.h",
"auth_logic_ast_visitor.h",
],
deps = [
":ast",
"//src/common/logging",
"//src/common/utils:fold",
"//src/common/utils:overloaded",
"//src/common/utils:types",
],
)

cc_library(
name = "lowering_ast_datalog",
srcs = ["lowering_ast_datalog.cc"],
Expand Down Expand Up @@ -94,6 +112,18 @@ cc_test(
],
)

cc_test(
name = "ast_visitor_test",
srcs = ["auth_logic_ast_traversing_visitor_test.cc"],
deps = [
":ast",
":auth_logic_ast_traversing_visitor",
"//src/common/testing:gtest",
"//src/ir/datalog:program",
"@absl//absl/container:flat_hash_set",
],
)

cc_library(
name = "ast_construction",
srcs = ["ast_construction.cc"],
Expand Down
209 changes: 209 additions & 0 deletions src/ir/auth_logic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <vector>

#include "absl/hash/hash.h"
#include "src/common/utils/map_iter.h"
#include "src/ir/auth_logic/auth_logic_ast_visitor.h"
#include "src/ir/datalog/program.h"

namespace raksha::ir::auth_logic {
Expand All @@ -34,6 +36,21 @@ class Principal {
explicit Principal(std::string name) : name_(std::move(name)) {}
const std::string& name() const { return name_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const { return name_; }

private:
std::string name_;
};
Expand All @@ -47,6 +64,23 @@ class Attribute {
const Principal& principal() const { return principal_; }
const datalog::Predicate& predicate() const { return predicate_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(principal_.name(), predicate_.DebugPrint());
}

private:
Principal principal_;
datalog::Predicate predicate_;
Expand All @@ -62,6 +96,24 @@ class CanActAs {
const Principal& left_principal() const { return left_principal_; }
const Principal& right_principal() const { return right_principal_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(left_principal_.DebugPrint(), " canActAs ",
right_principal_.DebugPrint());
}

private:
Principal left_principal_;
Principal right_principal_;
Expand All @@ -85,6 +137,26 @@ class BaseFact {
explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){};
const BaseFactVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(
"BaseFact(",
std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_),
")");
}

private:
BaseFactVariantType value_;
};
Expand All @@ -103,6 +175,28 @@ class Fact {

const BaseFact& base_fact() const { return base_fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
aferr marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::string> delegations;
for (const Principal& delegatee : delegation_chain_) {
delegations.push_back(delegatee.DebugPrint());
}
return absl::StrCat("deleg: { ", absl::StrJoin(delegations, ", "), " }",
base_fact_.DebugPrint());
}

private:
std::forward_list<Principal> delegation_chain_;
BaseFact base_fact_;
Expand All @@ -118,6 +212,28 @@ class ConditionalAssertion {
const Fact& lhs() const { return lhs_; }
const std::vector<BaseFact>& rhs() const { return rhs_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> rhs_strings;
for (const BaseFact& base_fact : rhs_) {
rhs_strings.push_back(base_fact.DebugPrint());
}
return absl::StrCat(lhs_.DebugPrint(), ":-",
absl::StrJoin(rhs_strings, ", "));
}

private:
Fact lhs_;
std::vector<BaseFact> rhs_;
Expand All @@ -135,6 +251,26 @@ class Assertion {
explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {}
const AssertionVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(
"Assertion(",
std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_),
")");
}

private:
AssertionVariantType value_;
};
Expand All @@ -147,6 +283,28 @@ class SaysAssertion {
const Principal& principal() const { return principal_; }
const std::vector<Assertion>& assertions() const { return assertions_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> assertion_strings;
for (const Assertion& assertion : assertions_) {
assertion_strings.push_back(assertion.DebugPrint());
}
return absl::StrCat(principal_.DebugPrint(), "says {\n",
absl::StrJoin(assertion_strings, "\n"), "}");
}

private:
Principal principal_;
std::vector<Assertion> assertions_;
Expand All @@ -164,6 +322,24 @@ class Query {
const Principal& principal() const { return principal_; }
const Fact& fact() const { return fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat("Query(", name_, principal_.DebugPrint(),
fact_.DebugPrint(), ")");
}

private:
std::string name_;
Principal principal_;
Expand Down Expand Up @@ -191,6 +367,39 @@ class Program {

const std::vector<Query>& queries() const { return queries_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> relation_decl_strings;
for (const datalog::RelationDeclaration& rel_decl :
relation_declarations_) {
relation_decl_strings.push_back(rel_decl.DebugPrint());
}
std::vector<std::string> says_assertion_strings;
for (const SaysAssertion& says_assertion : says_assertions_) {
says_assertion_strings.push_back(says_assertion.DebugPrint());
}
std::vector<std::string> query_strings;
for (const Query& query : queries_) {
query_strings.push_back(query.DebugPrint());
}
return absl::StrCat("Program(\n",
absl::StrJoin(relation_decl_strings, "\n"),
absl::StrJoin(says_assertion_strings, "\n"),
absl::StrJoin(query_strings, "\n"), ")");
}

private:
std::vector<datalog::RelationDeclaration> relation_declarations_;
std::vector<SaysAssertion> says_assertions_;
Expand Down
Loading