diff --git a/Project.toml b/Project.toml index 60dedd0..16337ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OpenPolicyAgent" uuid = "8f257efb-743c-4ebc-8197-d291a1f743b4" authors = ["JuliaHub Inc.", "Tanmay Mohapatra "] -version = "0.1.1" +version = "0.2.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/docs/make.jl b/docs/make.jl index 1b72e0b..eb9d225 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,6 +14,7 @@ makedocs( "Client" => "client.md", "Server" => "server.md", "Command Line" => "commandline.md", + "AST Walker" => "ast_walker.md", "Reference" => "reference.md", ], ) diff --git a/docs/src/ast_walker.md b/docs/src/ast_walker.md new file mode 100644 index 0000000..6f666c9 --- /dev/null +++ b/docs/src/ast_walker.md @@ -0,0 +1,54 @@ +# AST Walker + +OPA has a feature called partial evaluation which has several interesting applications. With partial evaluation, callers specify that certain inputs or pieces of data are unknown. OPA evaluates as much of the policy as possible without touching parts that depend on unknown values. The result of partial evaluation is a new policy that can be evaluated more efficiently than the original. The new policy is returned to the caller as an AST. + +The returned AST thus represents a strategy, rather than a result. It may be cached and reused. It may also be converted to other forms, e.g. a SQL query condition, or elastic search query. + +The `ASTWalker` module provides a framework to traverse the AST returned from a partial evaluattion. It specifies a `Visitor` interface that callers can implement to perform custom operations on the AST. The `ASTWalker` module also provides a default implementation of the `Visitor` interface that can be used to perform common operations on the AST. + +Included in the `ASTWalker` module are implementations of the `Visitor` interface that can be used to: +- Create a easy to use Julia representation of the AST. This is provided by the `ASTWalker.AST.ASTVisitor` type. +- Create a SQL query condition from the Julia representation of the AST. This is provided by the `ASTWalker.SQL.SQLVisitor` type. + +An example of how it can be used is shown below: +```julia +import OpenPolicyAgent: ASTWalker +import OpenPolicyAgent.ASTWalker: AST, SQL +import OpenPolicyAgent.ASTWalker.AST: ASTVisitor +import OpenPolicyAgent.ASTWalker.SQL: SQLVisitor, SQLCondition, UnconditionalInclude, UnconditionalExclude + +# invoke the partial evaluation endpoint +partial_query_schema = OpenPolicyAgent.Client.PartialQuerySchema(; ...) +response, _http_resp = OpenPolicyAgent.Client.post_compile( + compile_client; + partial_query_schema = partial_query_schema, +) + +# crete a Julia representation of the AST +ast = OpenPolicyAgent.ASTWalker.walk(ASTVisitor(), result) + +# Provide a mapping of schema names and table names that can be used to convert policy paths to SQL table names +const SCHEMA_MAP = Dict{String, String}( + "data" => "public", + "public" => "public", +) + +const TABLE_MAP = Dict{String, String}( + "reports" => "juliahub_reports", +) + +# create a SQL query condition from the AST +sqlvisitor = SQLVisitor(SCHEMA_MAP, TABLE_MAP) +sqlcondition = OpenPolicyAgent.ASTWalker.walk(sqlvisitor, ast) + +# sql condition should be a SQLCondition object +if isa(sqlcondition, UnconditionalExclude) + # all rows should be excluded +elseif isa(sqlcondition, UnconditionalInclude) + # all rows should be included +else + # `sqlcondition.sql` contains a string with the SQL query condition +end +``` + +More details of AST walker, and the included visitors can be found in the reference documentation. \ No newline at end of file diff --git a/docs/src/reference.md b/docs/src/reference.md index 4442a15..9d60592 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -92,3 +92,20 @@ OpenPolicyAgent.CLI.sign OpenPolicyAgent.CLI.test OpenPolicyAgent.CLI.bench ``` + +## AST Walker + +```@docs +OpenPolicyAgent.ASTWalker.Visitor +OpenPolicyAgent.ASTWalker.walk +OpenPolicyAgent.ASTWalker.before +OpenPolicyAgent.ASTWalker.visit +OpenPolicyAgent.ASTWalker.after +``` + +### Included Visitors + +```@docs +OpenPolicyAgent.ASTWalker.AST.ASTVisitor +OpenPolicyAgent.ASTWalker.SQL.SQLVisitor +``` diff --git a/src/OpenPolicyAgent.jl b/src/OpenPolicyAgent.jl index cd86287..cb9224e 100644 --- a/src/OpenPolicyAgent.jl +++ b/src/OpenPolicyAgent.jl @@ -3,5 +3,6 @@ module OpenPolicyAgent include("cli/cli.jl") include("client/src/Client.jl") include("server/server.jl") +include("utils/ast_walker.jl") end # module OpenPolicyAgent \ No newline at end of file diff --git a/src/utils/ast.jl b/src/utils/ast.jl new file mode 100644 index 0000000..e8c346b --- /dev/null +++ b/src/utils/ast.jl @@ -0,0 +1,274 @@ +""" +The AST module provides a julia based AST for OPA's partial compile result. +Provides `ASTVisitor` that implements `ASTWalker.Visitor` and can be used to walk the AST and convert it to a julia based AST. +""" +module AST + +import ..ASTWalker: Visitor, walk, before, visit, after + +abstract type OPATermType end +abstract type OPAComprehensionType <: OPATermType end + +struct OPAScalarValue <: OPATermType + value::Union{Nothing, String, Int64, Float64, Bool} +end + +struct OPAVar <: OPATermType + value::String +end + +struct OPATerm <: OPATermType + value::OPATermType +end + +struct OPARef <: OPATermType + value::Vector{OPATerm} +end + +struct OPAArray <: OPATermType + value::Vector{OPATerm} +end + +struct OPASet <: OPATermType + value::Vector{OPATerm} +end + +struct OPAObject <: OPATermType + value::Vector{Pair{OPATerm, OPATerm}} +end + +struct OPACall <: OPATermType + operator::OPATerm + operands::Vector{OPATerm} +end + +struct OPAExpr + index::Int64 + value::Union{OPATerm, Vector{OPATerm}} +end + +function is_call(expr::OPAExpr) + return isa(expr.value, Vector) +end + +function operator(expr::OPAExpr) + @assert is_call(expr) + return expr.value[1] +end + +function operands(expr::OPAExpr) + @assert is_call(expr) + return expr.value[2:end] +end + +struct OPAQuery + expressions::Vector{OPAExpr} +end + +struct QuerySet + queries::Vector{OPAQuery} +end + +struct OPAArrayComprehension <: OPAComprehensionType + term::OPATerm + body::OPAQuery +end + +struct OPASetComprehension <: OPAComprehensionType + term::OPATerm + body::OPAQuery +end + +struct OPAObjectComprehension <: OPAComprehensionType + key::OPATerm + value::OPATerm + body::OPAQuery +end + +const TERM_TYPE_MAP = Dict{String, DataType}( + "null" => OPAScalarValue, + "string" => OPAScalarValue, + "number" => OPAScalarValue, + "boolean" => OPAScalarValue, + "var" => OPAVar, + "ref" => OPARef, + "array" => OPAArray, + "set" => OPASet, + "object" => OPAObject, + "call" => OPACall, + "objectcomprehension" => OPAObjectComprehension, + "arraycomprehension" => OPAArrayComprehension, + "setcomprehension" => OPASetComprehension, +) + +""" +Visitor that converts a partial compile result to a julia based AST. +Must be used with `ASTWalker.walk`, providing the partial compile result as the `node` argument. + +Output: +- `QuerySet`: If the partial compile result contains queries, the output is a `QuerySet` containing the queries. +- `nothing`: If the partial compile result does not contain queries, the output is `nothing`. + +The output is returned from the `walk` method. +""" +struct ASTVisitor <: Visitor + state_stack::Vector{DataType} + result_stack::Vector{Any} + + function ASTVisitor() + return new(DataType[], Any[]) + end +end + +function before(visitor::ASTVisitor, node) + if isempty(visitor.state_stack) + push!(visitor.state_stack, QuerySet) + end + return nothing +end + +function visit(visitor::ASTVisitor, node) + T = visitor.state_stack[end] + _visit(visitor, T, node) + return nothing +end + +function after(visitor::ASTVisitor, node) + T = pop!(visitor.state_stack) + if !isempty(visitor.state_stack) + return nothing + end + @assert T === QuerySet + return pop!(visitor.result_stack) # either a QuerySet or nothing +end + +function _visit(visitor::ASTVisitor, ::Type{QuerySet}, node::Dict{String,Any}) + if haskey(node, "queries") && length(node["queries"]) > 0 + data = node["queries"] + N = length(data) + for idx in N:-1:1 + push!(visitor.state_stack, OPAQuery) + walk(visitor, data[idx]) + end + queryset = QuerySet([pop!(visitor.result_stack) for idx in 1:N]) + push!(visitor.result_stack, queryset) + else + push!(visitor.result_stack, nothing) + end +end + +function _visit(visitor::ASTVisitor, ::Type{OPAQuery}, node) + N = length(node) + for idx in N:-1:1 + push!(visitor.state_stack, OPAExpr) + walk(visitor, node[idx]) + end + query = OPAQuery([pop!(visitor.result_stack) for idx in 1:N]) + push!(visitor.result_stack, query) +end + +function _visit(visitor::ASTVisitor, ::Type{OPAExpr}, node) + index = node["index"] + terms = node["terms"] + + if isa(terms, Vector) + N = length(terms) + for idx in N:-1:1 + push!(visitor.state_stack, OPATerm) + walk(visitor, terms[idx]) + end + else + N = 1 + push!(visitor.state_stack, OPATerm) + walk(visitor, terms) + end + opaterms = OPATerm[pop!(visitor.result_stack) for idx in 1:N] + opaexpr = OPAExpr(index, opaterms) + push!(visitor.result_stack, opaexpr) +end + +function _visit(visitor::ASTVisitor, ::Type{OPATerm}, node) + termtype = node["type"] + T = TERM_TYPE_MAP[termtype] + if isa(T, DataType) + push!(visitor.state_stack, T) + data = node["value"] + walk(visitor, data) + else + error("Unknown term type: $termtype") + end + term = OPATerm(pop!(visitor.result_stack)) + push!(visitor.result_stack, term) +end + +_visit(visitor::ASTVisitor, ::Type{OPAVar}, value) = push!(visitor.result_stack, OPAVar(value)) +_visit(visitor::ASTVisitor, ::Type{OPAScalarValue}, value) = push!(visitor.result_stack, OPAScalarValue(value)) + +function _visit(visitor::ASTVisitor, ::Type{T}, data) where T <: Union{OPARef, OPAArray, OPASet} + N = length(data) + for idx in N:-1:1 + push!(visitor.state_stack, OPATerm) + walk(visitor, data[idx]) + end + opaterms = OPATerm[pop!(visitor.result_stack) for idx in 1:N] + result = T(opaterms) + push!(visitor.result_stack, result) +end + +function _visit(visitor::ASTVisitor, ::Type{OPAObject}, data) + N = length(data) + for idx in N:-1:1 + pair = data[idx] + push!(visitor.state_stack, OPATerm) + walk(visitor, pair[2]) + push!(visitor.state_stack, OPATerm) + walk(visitor, pair[1]) + end + termpairs = Pair{OPATerm, OPATerm}[Pair(pop!(visitor.result_stack), pop!(visitor.result_stack)) for idx in 1:N] + result = OPAObject(termpairs) + push!(visitor.result_stack, result) +end + +function _visit(visitor::ASTVisitor, ::Type{OPACall}, data) + N = length(data) + for idx in N:-1:1 + push!(visitor.state_stack, OPATerm) + walk(visitor, data[idx]) + end + operator = pop!(visitor.result_stack) + operands = OPATerm[pop!(visitor.result_stack) for idx in 1:(N-1)] + result = OPACall(operator, operands) + push!(visitor.result_stack, result) +end + +function _visit(visitor::ASTVisitor, ::Type{T}, data) where T <: Union{OPASetComprehension, OPAArrayComprehension} + push!(visitor.state_stack, OPATerm) + walk(visitor, data["term"]) + term = pop!(visitor.result_stack) + + push!(visitor.state_stack, OPAQuery) + walk(visitor, data["body"]) + body = pop!(visitor.result_stack) + + result = T(term, body) + push!(visitor.result_stack, result) +end + +function _visit(visitor::ASTVisitor, ::Type{OPAObjectComprehension}, data) + push!(visitor.state_stack, OPATerm) + walk(visitor, data["key"]) + key = pop!(visitor.result_stack) + + push!(visitor.state_stack, OPATerm) + walk(visitor, data["value"]) + value = pop!(visitor.result_stack) + + push!(visitor.state_stack, OPAQuery) + walk(visitor, data["body"]) + body = pop!(visitor.result_stack) + + result = OPAObjectComprehension(key, value, body) + push!(visitor.result_stack, result) +end + +end # module AST \ No newline at end of file diff --git a/src/utils/ast_walker.jl b/src/utils/ast_walker.jl new file mode 100644 index 0000000..214acaf --- /dev/null +++ b/src/utils/ast_walker.jl @@ -0,0 +1,56 @@ +module ASTWalker + + """ + Visitor + + Abstract type for AST visitors. + Visitors must implement the `before`, `visit` and `after` methods. + Visitors can keep state, the same visitor instance will be passed to all invocations of `before`, `visit` and `after` that happen while walking the AST. + """ + abstract type Visitor end + + """ + before(visitor, node) + + Called before visiting a node. The node that will be visited is passed as the second argument. + Any preparatory work that needs to be done before visiting the node can be done here. + Return value is ignored. + """ + before(visitor::Visitor, node) = error("Not implemented: before($(typeof(visitor)), $(typeof(node)))") + + """ + visit(visitor, node) + + Called when visiting a node. The node that is being visited is passed as the second argument. + The actual action to be performed when visiting a node must be implemented here. + The visit method must also call `walk` on the visitor to visit the children of the node. + The result must be stored in the visitor state. Return value is ignored. + """ + visit(visitor::Visitor, node) = error("Not implemented: visit($(typeof(visitor)), $(typeof(node)))") + + """ + after(visitor, node) + + Called after visiting a node. The node that was visited is passed as the second argument. + Any cleanup work that needs to be done after visiting the node can be done here. + This is the last method called when visiting a node. + Must return the result of visiting the node. + """ + after(visitor::Visitor, node) = error("Not implemented: after($(typeof(visitor)), $(typeof(node)))") + + """ + walk(visitor, node) + + Walks the AST rooted at `node` using the `visitor`. + Calls `before`, `visit` and `after` methods of the `visitor` in sequence while walking the tree. + """ + function walk(visitor::Visitor, node) + before(visitor, node) + visit(visitor, node) + return after(visitor, node) + end + + include("ast.jl") + include("sql.jl") + +end # end module \ No newline at end of file diff --git a/src/utils/sql.jl b/src/utils/sql.jl new file mode 100644 index 0000000..43b6c2d --- /dev/null +++ b/src/utils/sql.jl @@ -0,0 +1,177 @@ +module SQL + +import ..ASTWalker: Visitor, walk, before, visit, after +import ..ASTWalker: AST + +""" + AbstractSQLCondition + +Abstract type for SQL conditions. This is used to represent the result of translating an OPA query to SQL. +Conditions can be either `SQLCondition` or `UnconditionalInclude` or `UnconditionalExclude`. +""" +abstract type AbstractSQLCondition end + +""" + SQLCondition + +Represents a SQL condition. +Contains the SQL string that represents the condition. It can be appended to a SQL query using a `where` clause. +""" +struct SQLCondition <: AbstractSQLCondition + sql::String +end + +""" + UnconditionalInclude + +Represents an unconditional include condition. +Equivalent of a where clause with `true` condition. +""" +struct UnconditionalInclude <: AbstractSQLCondition end + +""" + UnconditionalExclude + +Represents an unconditional exclude condition. +Equivalent of a where clause with `false` condition. +""" +struct UnconditionalExclude <: AbstractSQLCondition end + +const SQL_OP_MAP = Dict{String,String}( + "eq" => "=", + "neq" => "!=", + "gt" => ">", + "gte" => ">=", + "lt" => "<", + "lte" => "<=", + "equal" => "=", + "internal.member_2" => "in", +) + +const VALID_SQL_OPS = Set(keys(SQL_OP_MAP)) + +""" + SQLVisitor + +Visitor that converts an OPA partial compile AST to a SQL condition. + +It requires two dictionaries to be passed in the constructor: +- `schema_map`: maps OPA package names to database schema names +- `table_map`: maps OPA rule names to database table names + +Input to the visitor must be a partial compile result from OPA already converted to a julia representation using `ASTWalker.AST.ASTVisitor`. +Walking the AST using this visitor will result in a SQL condition that can be appended to a SQL query using a `where` clause. +Output, that is returned from the `walk` method, is an `AbstractSQLCondition`. It can be one of: + +- `SQLCondition`: represents a SQL condition. Contains the SQL string that represents the condition that can be used in the query with a "where" clause. +- `UnconditionalInclude`: represents an unconditional include condition. Which means that the SQL query should return all rows. +- `UnconditionalExclude`: represents an unconditional exclude condition. Which means that the SQL query should not return any rows. +""" +struct SQLVisitor <: Visitor + schema_map::Dict{String, String} + table_map::Dict{String, String} + result_stack::Vector{Any} + + function SQLVisitor(schema_map::Dict{String, String}, table_map::Dict{String, String}) + return new(schema_map, table_map, Any[]) + end +end + +function strip_quote(var) + if startswith(var, "'") && endswith(var, "'") + return var[2:end-1] + else + return var + end +end + +before(::SQLVisitor, _node) = nothing +after(visitor::SQLVisitor, _node) = pop!(visitor.result_stack) + +function visit(visitor::SQLVisitor, ::Nothing) + push!(visitor.result_stack, UnconditionalExclude()) + return nothing +end + +function visit(visitor::SQLVisitor, queryset::AST.QuerySet) + filters = [walk(visitor, q) for q in queryset.queries] + if any(isempty, filters) + # if any of the filters are fully satisfied, then the query is fully satisfied + push!(visitor.result_stack, UnconditionalInclude()) + else + push!(visitor.result_stack, SQLCondition(join(filters, " or\n"))) + end + return nothing +end + +function visit(visitor::SQLVisitor, query::AST.OPAQuery) + filters = [walk(visitor, expr) for expr in query.expressions] + push!(visitor.result_stack, join(filters, " and ")) +end + +function visit(visitor::SQLVisitor, term::AST.OPATerm) + resp = walk(visitor, term.value) + push!(visitor.result_stack, resp) +end +function visit(visitor::SQLVisitor, var::AST.OPAVar) + push!(visitor.result_stack, var.value) + return nothing +end + +function visit(visitor::SQLVisitor, scaler::AST.OPAScalarValue) + value = scaler.value + result = (value === nothing) ? "null" : + (value === true) ? "true" : + (value === false) ? "false" : + isa(value, String) ? "'$value'" : + string(value) + push!(visitor.result_stack, result) + return nothing +end + +function visit(visitor::SQLVisitor, ref::AST.OPARef) + col_spec = ref.value + + is_db_column_ref = false + if length(col_spec) == 4 + selector = walk(visitor, col_spec[3]) + is_db_column_ref = startswith(selector, '$') + end + + if is_db_column_ref + # this is a database column reference + schema = visitor.schema_map[strip_quote(walk(visitor, col_spec[1]))] + table = visitor.table_map[strip_quote(walk(visitor, col_spec[2]))] + colname = strip_quote(walk(visitor, col_spec[4])) + push!(visitor.result_stack, join([schema, table, colname], ".")) + else + # this is a reference to a variable + push!(visitor.result_stack, join(strip_quote.([walk(visitor, cs) for cs in col_spec]), '.')) + end + return nothing +end + +function visit(visitor::SQLVisitor, arr::AST.OPAArray) + push!(visitor.result_stack, string("(", join([walk(visitor, v) for v in arr.value], ", "), ")")) + return nothing +end + +function visit(visitor::SQLVisitor, expr::AST.OPAExpr) + @assert AST.is_call(expr) + + op_spec = AST.operator(expr) + op_name = walk(visitor, op_spec) + if !(op_name in VALID_SQL_OPS) + error("Invalid SQL operator: $op_name") + end + + op_operands = AST.operands(expr) + @assert length(op_operands) == 2 + op_lhs = walk(visitor, op_operands[1]) + op_rhs = walk(visitor, op_operands[2]) + op = SQL_OP_MAP[op_name] + + push!(visitor.result_stack, join([op_lhs, op, op_rhs], " ")) +end + +end # module SQL \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c687ced..7e6435a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,217 +3,16 @@ using OpenAPI using JSON using HTTP using Test -import OpenPolicyAgent: CLI, Client +import OpenPolicyAgent: CLI, Client, ASTWalker +import OpenPolicyAgent.ASTWalker: AST, SQL +import OpenPolicyAgent.ASTWalker.AST: ASTVisitor +import OpenPolicyAgent.ASTWalker.SQL: SQLVisitor, SQLCondition, UnconditionalInclude, UnconditionalExclude +include("test_data.jl") +include("test_utils.jl") include("sql_translate.jl") import .OPASQL: translate -const opa_config_template = joinpath(@__DIR__, "conf", "config.yaml") -const bundle_root = joinpath(@__DIR__, "bundle_root") -const data_bundle_root = joinpath(bundle_root, "data_bundle") -const policies_bundle_root = joinpath(bundle_root, "policies_bundle") -const bundle_args = Dict( - :bundle => true, - :signing_alg => "HS512", - :signing_key => "secret", -) - -const EXAMPLE_POLICY = """package opa.examples - import data.servers - import data.networks - import data.ports - - public_servers[server] { - some k, m - server := servers[_] - server.ports[_] == ports[k].id - ports[k].networks[_] == networks[m].id - networks[m].public == true - } -""" - -const PARTIAL_COMPILE_CASES = [ - ( - policy = """package example - allow { - input.subject.clearance_level >= data.reports[_].clearance_level - }""", - query = "data.example.allow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "clearance_level" => 4 - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "4 >= public.juliahub_reports.clearance_level", - ), - ( - policy = """package example - - allow { - input.subject.group == "admin" - } - allow { - data.reports[_].public == true - } - allow { - input.subject.clearance_level >= data.reports[_].clearance_level - input.subject.id == data.reports[_].owner - } - """, - query = "data.example.allow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "clearance_level" => 4, - "id" => "bob", - "group" => "eng" - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "public.juliahub_reports.public = true or\n4 >= public.juliahub_reports.clearance_level and 'bob' = public.juliahub_reports.owner", - ), - ( - # always allowed if the policy is fully satisfied with the given input for any one condition - policy = """package example - - allow { - input.subject.group == "admin" - } - allow { - data.reports[_].public == true - } - allow { - input.subject.clearance_level >= data.reports[_].clearance_level - input.subject.id == data.reports[_].owner - } - """, - query = "data.example.allow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "clearance_level" => 4, - "id" => "sally", - "group" => "admin" - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "", - ), - ( - # always allowed if the policy with only one condition is fully satisfied with the given input - policy = """package example - default allow = false - allow { - input.subject.group == "admin" - } - """, - query = "data.example.allow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "id" => "sally", - "group" => "admin" - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "", - ), - ( - # not allowed if the required policy is not defined - policy = """package example - default allow = false - allow { - input.subject.group == "admin" - } - """, - query = "data.example.undefinedallow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "id" => "sally", - "group" => "admin" - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "false", - ), - ( - policy = """package example - import future.keywords.in - allow { - input.subject.group in ["admin", "superadmin"] - } - allow { - data.reports[_].category in ["public", "pinned"] - } - allow { - input.subject.clearance_level >= data.reports[_].clearance_level - } - """, - query = "data.example.allow == true", - input = Dict{String,Any}( - "subject" => Dict{String,Any}( - "clearance_level" => 4, - "group" => "eng", - ) - ), - options = Dict{String,Any}( - "disableInlining" => [] - ), - unknowns = ["data.reports"], - sql = "public.juliahub_reports.category in ('public', 'pinned') or\n4 >= public.juliahub_reports.clearance_level", - ), -] - -const EXAMPLE_QUERY = """input.servers[i].ports[_] = "p2"; input.servers[i].name = name""" -const EXAMPLE_QUERY_INPUT = Dict{String,Any}( - "servers" => [ - Dict{String,Any}( - "id" => "s1", - "name" => "app", - "ports" => ["p1", "p2", "p3"], - "protocols" => ["https", "ssh"] - ), - Dict{String,Any}( - "id" => "s4", - "name" => "dev", - "ports" => ["p1", "p2"], - "protocols" => ["http"] - ) - ] -) - -# Prepare the bundles -function prepare_bundle(bundle_location::String) - signed_bundle_file = joinpath(bundle_location, "data.tar.gz") - run(`rm -f $signed_bundle_file`) - CLI.build(OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => data_bundle_root)), - "."; - output=signed_bundle_file, - bundle_args... - ) - - signed_bundle_file = joinpath(bundle_location, "policies.tar.gz") - run(`rm -f $signed_bundle_file`) - CLI.build(OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => policies_bundle_root)), - "."; - output=signed_bundle_file, - bundle_args... - ) -end - # Check version and help output function test_version_help() iob_stdout = IOBuffer() @@ -237,61 +36,6 @@ function test_version_help() end end -function file_response(path) - open(path, "r") do io - return HTTP.Response(200, readavailable(io)) - end -end - -# Start a bundle server -function start_bundle_server(root_path) - # start a HTTP.jl server serving at root_path - # and serve only two files policies.tar.gz and data.tar.gz - server = HTTP.serve!("127.0.0.1", 8080) do req::HTTP.Request - @info("request", target=req.target, method=req.method) - if req.method == "GET" && req.target == "/data.tar.gz" - return file_response(joinpath(root_path, "data.tar.gz")) - elseif req.method == "GET" && req.target == "/policies.tar.gz" - return file_response(joinpath(root_path, "policies.tar.gz")) - else - return HTTP.Response(404) - end - end - - return server -end - -function start_opa_server(root_path; change_dir::Bool=true) - if change_dir - opa_server = OpenPolicyAgent.Server.MonitoredOPAServer( - joinpath(root_path, "config.yaml"); - stdout = joinpath(root_path, "server.stdout"), - stderr = joinpath(root_path, "server.stderr"), - cmdline = OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => root_path)), - ) - else - opa_server = OpenPolicyAgent.Server.MonitoredOPAServer( - joinpath(root_path, "config.yaml"); - stdout = joinpath(root_path, "server.stdout"), - stderr = joinpath(root_path, "server.stderr"), - ) - end - OpenPolicyAgent.Server.start!(opa_server) - return opa_server -end - -function policy_path() - policy_package = "policies/server/rest" - rule_name = "allowed" - return joinpath(policy_package, rule_name) -end - -function query_user(opa_client, username) - request_body = Dict{String,Any}("input" => Dict{String,Any}("name" => username)) - response, http_resp = OpenPolicyAgent.Client.get_document_with_path(opa_client, policy_path(), request_body; pretty=true, provenance=true, explain=true, metrics=true, instrument=true); - return response.result -end - function test_data_api(openapi_client) opa_client = OpenPolicyAgent.Client.DataApi(openapi_client) @@ -449,6 +193,18 @@ function test_compile_api(openapi_client) sql = translate(result) @test sql == partial_compile_case.sql + + ast = OpenPolicyAgent.ASTWalker.walk(ASTVisitor(), result) + sqlvisitor = SQLVisitor(SCHEMA_MAP, TABLE_MAP) + sqlcondition = OpenPolicyAgent.ASTWalker.walk(sqlvisitor, ast) + if partial_compile_case.sql == "false" + @test isa(sqlcondition, UnconditionalExclude) + elseif isempty(partial_compile_case.sql) + @test isa(sqlcondition, UnconditionalInclude) + else + @test isa(sqlcondition, SQLCondition) + @test sqlcondition.sql == partial_compile_case.sql + end finally # delete the test policy result, _http_resp = OpenPolicyAgent.Client.delete_policy_module(policy_client, "example"; pretty=true) diff --git a/test/test_data.jl b/test/test_data.jl new file mode 100644 index 0000000..f8094f6 --- /dev/null +++ b/test/test_data.jl @@ -0,0 +1,195 @@ +const opa_config_template = joinpath(@__DIR__, "conf", "config.yaml") +const bundle_root = joinpath(@__DIR__, "bundle_root") +const data_bundle_root = joinpath(bundle_root, "data_bundle") +const policies_bundle_root = joinpath(bundle_root, "policies_bundle") +const bundle_args = Dict( + :bundle => true, + :signing_alg => "HS512", + :signing_key => "secret", +) + +const EXAMPLE_POLICY = """package opa.examples + import data.servers + import data.networks + import data.ports + + public_servers[server] { + some k, m + server := servers[_] + server.ports[_] == ports[k].id + ports[k].networks[_] == networks[m].id + networks[m].public == true + } +""" + +const PARTIAL_COMPILE_CASES = [ + ( + policy = """package example + allow { + input.subject.clearance_level >= data.reports[_].clearance_level + }""", + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4 + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "4 >= public.juliahub_reports.clearance_level", + ), + ( + policy = """package example + + allow { + input.subject.group == "admin" + } + allow { + data.reports[_].public == true + } + allow { + input.subject.clearance_level >= data.reports[_].clearance_level + input.subject.id == data.reports[_].owner + } + """, + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4, + "id" => "bob", + "group" => "eng" + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "public.juliahub_reports.public = true or\n4 >= public.juliahub_reports.clearance_level and 'bob' = public.juliahub_reports.owner", + ), + ( + # always allowed if the policy is fully satisfied with the given input for any one condition + policy = """package example + + allow { + input.subject.group == "admin" + } + allow { + data.reports[_].public == true + } + allow { + input.subject.clearance_level >= data.reports[_].clearance_level + input.subject.id == data.reports[_].owner + } + """, + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4, + "id" => "sally", + "group" => "admin" + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "", + ), + ( + # always allowed if the policy with only one condition is fully satisfied with the given input + policy = """package example + default allow = false + allow { + input.subject.group == "admin" + } + """, + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "id" => "sally", + "group" => "admin" + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "", + ), + ( + # not allowed if the required policy is not defined + policy = """package example + default allow = false + allow { + input.subject.group == "admin" + } + """, + query = "data.example.undefinedallow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "id" => "sally", + "group" => "admin" + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "false", + ), + ( + policy = """package example + import future.keywords.in + allow { + input.subject.group in ["admin", "superadmin"] + } + allow { + data.reports[_].category in ["public", "pinned"] + } + allow { + input.subject.clearance_level >= data.reports[_].clearance_level + } + """, + query = "data.example.allow == true", + input = Dict{String,Any}( + "subject" => Dict{String,Any}( + "clearance_level" => 4, + "group" => "eng", + ) + ), + options = Dict{String,Any}( + "disableInlining" => [] + ), + unknowns = ["data.reports"], + sql = "public.juliahub_reports.category in ('public', 'pinned') or\n4 >= public.juliahub_reports.clearance_level", + ), +] + +const EXAMPLE_QUERY = """input.servers[i].ports[_] = "p2"; input.servers[i].name = name""" +const EXAMPLE_QUERY_INPUT = Dict{String,Any}( + "servers" => [ + Dict{String,Any}( + "id" => "s1", + "name" => "app", + "ports" => ["p1", "p2", "p3"], + "protocols" => ["https", "ssh"] + ), + Dict{String,Any}( + "id" => "s4", + "name" => "dev", + "ports" => ["p1", "p2"], + "protocols" => ["http"] + ) + ] +) + +const SCHEMA_MAP = Dict{String, String}( + "data" => "public", + "public" => "public", +) + +const TABLE_MAP = Dict{String, String}( + "reports" => "juliahub_reports", +) diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 0000000..e0d520b --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,74 @@ +# Prepare the bundles +function prepare_bundle(bundle_location::String) + signed_bundle_file = joinpath(bundle_location, "data.tar.gz") + run(`rm -f $signed_bundle_file`) + CLI.build(OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => data_bundle_root)), + "."; + output=signed_bundle_file, + bundle_args... + ) + + signed_bundle_file = joinpath(bundle_location, "policies.tar.gz") + run(`rm -f $signed_bundle_file`) + CLI.build(OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => policies_bundle_root)), + "."; + output=signed_bundle_file, + bundle_args... + ) +end + + +function file_response(path) + open(path, "r") do io + return HTTP.Response(200, readavailable(io)) + end +end + +# Start a bundle server +function start_bundle_server(root_path) + # start a HTTP.jl server serving at root_path + # and serve only two files policies.tar.gz and data.tar.gz + server = HTTP.serve!("127.0.0.1", 8080) do req::HTTP.Request + @info("request", target=req.target, method=req.method) + if req.method == "GET" && req.target == "/data.tar.gz" + return file_response(joinpath(root_path, "data.tar.gz")) + elseif req.method == "GET" && req.target == "/policies.tar.gz" + return file_response(joinpath(root_path, "policies.tar.gz")) + else + return HTTP.Response(404) + end + end + + return server +end + +function start_opa_server(root_path; change_dir::Bool=true) + if change_dir + opa_server = OpenPolicyAgent.Server.MonitoredOPAServer( + joinpath(root_path, "config.yaml"); + stdout = joinpath(root_path, "server.stdout"), + stderr = joinpath(root_path, "server.stderr"), + cmdline = OpenPolicyAgent.CLI.CommandLine(; cmdopts=Dict(:dir => root_path)), + ) + else + opa_server = OpenPolicyAgent.Server.MonitoredOPAServer( + joinpath(root_path, "config.yaml"); + stdout = joinpath(root_path, "server.stdout"), + stderr = joinpath(root_path, "server.stderr"), + ) + end + OpenPolicyAgent.Server.start!(opa_server) + return opa_server +end + +function policy_path() + policy_package = "policies/server/rest" + rule_name = "allowed" + return joinpath(policy_package, rule_name) +end + +function query_user(opa_client, username) + request_body = Dict{String,Any}("input" => Dict{String,Any}("name" => username)) + response, http_resp = OpenPolicyAgent.Client.get_document_with_path(opa_client, policy_path(), request_body; pretty=true, provenance=true, explain=true, metrics=true, instrument=true); + return response.result +end