Skip to content

Commit

Permalink
framework for partial compile result processing
Browse files Browse the repository at this point in the history
Added a AST walker for partial compile result, and a visitor implementation that can translate it to SQL conditions.
A more formal implementation of what was done in #9.
  • Loading branch information
tanmaykm committed Nov 15, 2023
1 parent 9b2b55f commit 9c4f3c7
Show file tree
Hide file tree
Showing 11 changed files with 853 additions and 263 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OpenPolicyAgent"
uuid = "8f257efb-743c-4ebc-8197-d291a1f743b4"
authors = ["JuliaHub Inc.", "Tanmay Mohapatra <[email protected]>"]
version = "0.1.1"
version = "0.2.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ makedocs(
"Client" => "client.md",
"Server" => "server.md",
"Command Line" => "commandline.md",
"AST Walker" => "ast_walker.md",
"Reference" => "reference.md",
],
)
Expand Down
54 changes: 54 additions & 0 deletions docs/src/ast_walker.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# AST Walker

OPA has a feature called partial evaluation which has several interesting applications. With partial evaluation, callers specify that certain inputs or pieces of data are unknown. OPA evaluates as much of the policy as possible without touching parts that depend on unknown values. The result of partial evaluation is a new policy that can be evaluated more efficiently than the original. The new policy is returned to the caller as an AST.

The returned AST thus represents a strategy, rather than a result. It may be cached and reused. It may also be converted to other forms, e.g. a SQL query condition, or elastic search query.

The `ASTWalker` module provides a framework to traverse the AST returned from a partial evaluattion. It specifies a `Visitor` interface that callers can implement to perform custom operations on the AST. The `ASTWalker` module also provides a default implementation of the `Visitor` interface that can be used to perform common operations on the AST.

Included in the `ASTWalker` module are implementations of the `Visitor` interface that can be used to:
- Create a easy to use Julia representation of the AST. This is provided by the `ASTWalker.AST.ASTVisitor` type.
- Create a SQL query condition from the Julia representation of the AST. This is provided by the `ASTWalker.SQL.SQLVisitor` type.

An example of how it can be used is shown below:
```julia
import OpenPolicyAgent: ASTWalker
import OpenPolicyAgent.ASTWalker: AST, SQL
import OpenPolicyAgent.ASTWalker.AST: ASTVisitor
import OpenPolicyAgent.ASTWalker.SQL: SQLVisitor, SQLCondition, UnconditionalInclude, UnconditionalExclude

# invoke the partial evaluation endpoint
partial_query_schema = OpenPolicyAgent.Client.PartialQuerySchema(; ...)
response, _http_resp = OpenPolicyAgent.Client.post_compile(
compile_client;
partial_query_schema = partial_query_schema,
)

# crete a Julia representation of the AST
ast = OpenPolicyAgent.ASTWalker.walk(ASTVisitor(), result)

# Provide a mapping of schema names and table names that can be used to convert policy paths to SQL table names
const SCHEMA_MAP = Dict{String, String}(
"data" => "public",
"public" => "public",
)

const TABLE_MAP = Dict{String, String}(
"reports" => "juliahub_reports",
)

# create a SQL query condition from the AST
sqlvisitor = SQLVisitor(SCHEMA_MAP, TABLE_MAP)
sqlcondition = OpenPolicyAgent.ASTWalker.walk(sqlvisitor, ast)

# sql condition should be a SQLCondition object
if isa(sqlcondition, UnconditionalExclude)
# all rows should be excluded
elseif isa(sqlcondition, UnconditionalInclude)
# all rows should be included
else
# `sqlcondition.sql` contains a string with the SQL query condition
end
```

More details of AST walker, and the included visitors can be found in the reference documentation.
17 changes: 17 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ OpenPolicyAgent.CLI.sign
OpenPolicyAgent.CLI.test
OpenPolicyAgent.CLI.bench
```

## AST Walker

```@docs
OpenPolicyAgent.ASTWalker.Vistor
OpenPolicyAgent.ASTWalker.walk
OpenPolicyAgent.ASTWalker.before
OpenPolicyAgent.ASTWalker.visit
OpenPolicyAgent.ASTWalker.after
```

### Included Visitors

```@docs
OpenPolicyAgent.ASTWalker.AST.ASTVisitor
OpenPolicyAgent.ASTWalker.SQL.SQLVisitor
```
1 change: 1 addition & 0 deletions src/OpenPolicyAgent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ module OpenPolicyAgent
include("cli/cli.jl")
include("client/src/Client.jl")
include("server/server.jl")
include("utils/ast_walker.jl")

end # module OpenPolicyAgent
274 changes: 274 additions & 0 deletions src/utils/ast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
"""
The AST module provides a julia based AST for OPA's partial compile result.
Provides `ASTVisitor` that implements `ASTWalker.Visitor` and can be used to walk the AST and convert it to a julia based AST.
"""
module AST

import ..ASTWalker: Visitor, walk, before, visit, after

abstract type OPATermType end
abstract type OPAComprehensionType <: OPATermType end

struct OPAScalarValue <: OPATermType
value::Union{Nothing, String, Int64, Float64, Bool}
end

struct OPAVar <: OPATermType
value::String
end

struct OPATerm <: OPATermType
value::OPATermType
end

struct OPARef <: OPATermType
value::Vector{OPATerm}
end

struct OPAArray <: OPATermType
value::Vector{OPATerm}
end

struct OPASet <: OPATermType
value::Vector{OPATerm}
end

struct OPAObject <: OPATermType
value::Vector{Pair{OPATerm, OPATerm}}
end

struct OPACall <: OPATermType
operator::OPATerm
operands::Vector{OPATerm}
end

struct OPAExpr
index::Int64
value::Union{OPATerm, Vector{OPATerm}}
end

function is_call(expr::OPAExpr)
return isa(expr.value, Vector)
end

function operator(expr::OPAExpr)
@assert is_call(expr)
return expr.value[1]
end

function operands(expr::OPAExpr)
@assert is_call(expr)
return expr.value[2:end]
end

struct OPAQuery
expressions::Vector{OPAExpr}
end

struct QuerySet
queries::Vector{OPAQuery}
end

struct OPAArrayComprehension <: OPAComprehensionType
term::OPATerm
body::OPAQuery
end

struct OPASetComprehension <: OPAComprehensionType
term::OPATerm
body::OPAQuery
end

struct OPAObjectComprehension <: OPAComprehensionType
key::OPATerm
value::OPATerm
body::OPAQuery
end

const TERM_TYPE_MAP = Dict{String, DataType}(
"null" => OPAScalarValue,
"string" => OPAScalarValue,
"number" => OPAScalarValue,
"boolean" => OPAScalarValue,
"var" => OPAVar,
"ref" => OPARef,
"array" => OPAArray,
"set" => OPASet,
"object" => OPAObject,
"call" => OPACall,
"objectcomprehension" => OPAObjectComprehension,
"arraycomprehension" => OPAArrayComprehension,
"setcomprehension" => OPASetComprehension,
)

"""
Visitor that converts a partial compile result to a julia based AST.
Must be used with `ASTWalker.walk`, providing the partial compile result as the `node` argument.
Output:
- `QuerySet`: If the partial compile result contains queries, the output is a `QuerySet` containing the queries.
- `nothing`: If the partial compile result does not contain queries, the output is `nothing`.
The output is returned from the `walk` method.
"""
struct ASTVisitor <: Visitor
state_stack::Vector{DataType}
result_stack::Vector{Any}

function ASTVisitor()
return new(DataType[], Any[])
end
end

function before(visitor::ASTVisitor, node)
if isempty(visitor.state_stack)
push!(visitor.state_stack, QuerySet)
end
return nothing
end

function visit(visitor::ASTVisitor, node)
T = visitor.state_stack[end]
_visit(visitor, T, node)
return nothing
end

function after(visitor::ASTVisitor, node)
T = pop!(visitor.state_stack)
if !isempty(visitor.state_stack)
return nothing
end
@assert T === QuerySet
return pop!(visitor.result_stack) # either a QuerySet or nothing
end

function _visit(visitor::ASTVisitor, ::Type{QuerySet}, node::Dict{String,Any})
if haskey(node, "queries") && length(node["queries"]) > 0
data = node["queries"]
N = length(data)
for idx in N:-1:1
push!(visitor.state_stack, OPAQuery)
walk(visitor, data[idx])
end
queryset = QuerySet([pop!(visitor.result_stack) for idx in 1:N])
push!(visitor.result_stack, queryset)
else
push!(visitor.result_stack, nothing)
end
end

function _visit(visitor::ASTVisitor, ::Type{OPAQuery}, node)
N = length(node)
for idx in N:-1:1
push!(visitor.state_stack, OPAExpr)
walk(visitor, node[idx])
end
query = OPAQuery([pop!(visitor.result_stack) for idx in 1:N])
push!(visitor.result_stack, query)
end

function _visit(visitor::ASTVisitor, ::Type{OPAExpr}, node)
index = node["index"]
terms = node["terms"]

if isa(terms, Vector)
N = length(terms)
for idx in N:-1:1
push!(visitor.state_stack, OPATerm)
walk(visitor, terms[idx])
end
else
N = 1
push!(visitor.state_stack, OPATerm)
walk(visitor, terms)
end
opaterms = OPATerm[pop!(visitor.result_stack) for idx in 1:N]
opaexpr = OPAExpr(index, opaterms)
push!(visitor.result_stack, opaexpr)
end

function _visit(visitor::ASTVisitor, ::Type{OPATerm}, node)
termtype = node["type"]
T = TERM_TYPE_MAP[termtype]
if isa(T, DataType)
push!(visitor.state_stack, T)
data = node["value"]
walk(visitor, data)
else
error("Unknown term type: $termtype")
end
term = OPATerm(pop!(visitor.result_stack))
push!(visitor.result_stack, term)
end

_visit(visitor::ASTVisitor, ::Type{OPAVar}, value) = push!(visitor.result_stack, OPAVar(value))
_visit(visitor::ASTVisitor, ::Type{OPAScalarValue}, value) = push!(visitor.result_stack, OPAScalarValue(value))

function _visit(visitor::ASTVisitor, ::Type{T}, data) where T <: Union{OPARef, OPAArray, OPASet}
N = length(data)
for idx in N:-1:1
push!(visitor.state_stack, OPATerm)
walk(visitor, data[idx])
end
opaterms = OPATerm[pop!(visitor.result_stack) for idx in 1:N]
result = T(opaterms)
push!(visitor.result_stack, result)
end

function _visit(visitor::ASTVisitor, ::Type{OPAObject}, data)
N = length(data)
for idx in N:-1:1
pair = data[idx]
push!(visitor.state_stack, OPATerm)
walk(visitor, pair[2])
push!(visitor.state_stack, OPATerm)
walk(visitor, pair[1])
end
termpairs = Pair{OPATerm, OPATerm}[Pair(pop!(visitor.result_stack), pop!(visitor.result_stack)) for idx in 1:N]
result = OPAObject(termpairs)
push!(visitor.result_stack, result)
end

function _visit(visitor::ASTVisitor, ::Type{OPACall}, data)
N = length(data)
for idx in N:-1:1
push!(visitor.state_stack, OPATerm)
walk(visitor, data[idx])
end
operator = pop!(visitor.result_stack)
operands = OPATerm[pop!(visitor.result_stack) for idx in 1:(N-1)]
result = OPACall(operator, operands)
push!(visitor.result_stack, result)
end

function _visit(visitor::ASTVisitor, ::Type{T}, data) where T <: Union{OPASetComprehension, OPAArrayComprehension}
push!(visitor.state_stack, OPATerm)
walk(visitor, data["term"])
term = pop!(visitor.result_stack)

push!(visitor.state_stack, OPAQuery)
walk(visitor, data["body"])
body = pop!(visitor.result_stack)

result = T(term, body)
push!(visitor.result_stack, result)
end

function _visit(visitor::ASTVisitor, ::Type{OPAObjectComprehension}, data)
push!(visitor.state_stack, OPATerm)
walk(visitor, data["key"])
key = pop!(visitor.result_stack)

push!(visitor.state_stack, OPATerm)
walk(visitor, data["value"])
value = pop!(visitor.result_stack)

push!(visitor.state_stack, OPAQuery)
walk(visitor, data["body"])
body = pop!(visitor.result_stack)

result = OPAObjectComprehension(key, value, body)
push!(visitor.result_stack, result)
end

end # module AST
Loading

0 comments on commit 9c4f3c7

Please sign in to comment.