Skip to content

Commit

Permalink
Check (#18)
Browse files Browse the repository at this point in the history
* add check macro

* cleanup
  • Loading branch information
jw3126 authored Feb 10, 2018
1 parent 0a637bc commit b7b42c3
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 68 deletions.
194 changes: 131 additions & 63 deletions src/ArgCheck.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,51 @@
__precompile__()
module ArgCheck
using Base.Meta
export @argcheck
export @argcheck, @check

abstract type AbstractCheckFlavor end
struct ArgCheckFlavor <: AbstractCheckFlavor end
struct CheckFlavor <: AbstractCheckFlavor end

abstract type AbstractCodeFlavor end
struct CallFlavor <: AbstractCodeFlavor end
struct ComparisonFlavor <: AbstractCodeFlavor end
struct FallbackFlavor <: AbstractCodeFlavor end

struct Checker
code
checkflavor::AbstractCheckFlavor
codeflavor::AbstractCodeFlavor
options
end

abstract type AbstractErrorInfo end
struct CallErrorInfo <: AbstractErrorInfo
code
checkflavor::AbstractCheckFlavor
argument_expressions
argument_values
options
end
struct ComparisonErrorInfo <: AbstractErrorInfo
code
checkflavor::AbstractCheckFlavor
argument_expressions
argument_values
options
end
struct FallbackErrorInfo <: AbstractErrorInfo
code
checkflavor::AbstractCheckFlavor
options
end

"""
@argcheck
Macro for checking invariants on function arguments.
It can be used as follows:
Check invariants on function arguments and produce
a nice exception message if they are violated.
Usage is as follows:
```Julia
function myfunction(k,n,A,B)
@argcheck k > n
Expand All @@ -16,15 +54,39 @@ function myfunction(k,n,A,B)
# doit
end
```
See also [`@check`](@ref).
"""
macro argcheck(ex, options...)
if isexpr(ex, :comparison)
argcheck_comparison(ex, options...)
check(ex, ArgCheckFlavor(), options...)
end

"""
@check
Check that a condition holds
and produce a nice exception message, if it does not.
Usage is as follows:
```Julia
@check k > n
@check size(A) == size(B) DimensionMismatch
@check det(A) < 0 DomainError()
```
See also [`@argcheck`](@ref).
"""
macro check(ex, options...)
check(ex, CheckFlavor(), options...)
end

function check(ex, checkflavor, options...)
codeflavor = if isexpr(ex, :comparison)
ComparisonFlavor()
elseif is_simple_call(ex)
argcheck_call(ex, options...)
CallFlavor()
else
argcheck_fallback(ex, options...)
FallbackFlavor()
end
checker = Checker(ex, checkflavor, codeflavor, options)
check(checker, codeflavor)
end

function is_simple_call(ex)
Expand All @@ -37,44 +99,36 @@ function is_simple_call(ex)
true
end

function argcheck_fallback(ex, options...)
quote
if $(esc(ex))
nothing
else
err = ArgCheck.build_error($(QuoteNode(ex)), $(esc.(options)...))
throw(err)
end
end
function check(c, ::FallbackFlavor)
info = Expr(:call, :FallbackErrorInfo,
QuoteNode(c.code),
c.checkflavor,
Expr(:tuple, esc.(c.options)...))

condition = esc(c.code)
expr_error_block(info, condition)
end

function argcheck_call(ex, options...)
function check(c, ::CallFlavor)
ex = c.code
variables = [gensym() for _ in 1:length(ex.args)]
assignments = map(variables, ex.args) do vi, exi
Expr(:(=), vi, esc(exi))
end
condition = Expr(:call, variables...)
values = :([$(variables...)])
err = Expr(:call,
:(ArgCheck.build_error_with_fancy_message),
QuoteNode(ex),
QuoteNode(ex.args),
values,
esc.(options)...
)
quote
$(assignments...)
if $condition
nothing
else
throw($err)
end
end
info = Expr(:call, :CallErrorInfo,
QuoteNode(c.code),
c.checkflavor,
QuoteNode(c.code.args),
values,
Expr(:tuple, esc.(c.options)...))
expr_error_block(info, condition, assignments...)
end

function argcheck_comparison(ex, options...)
exprs = ex.args[1:2:end]
ops = ex.args[2:2:end]
function check(c::Checker, ::ComparisonFlavor)
exprs = c.code.args[1:2:end]
ops = c.code.args[2:2:end]
variables = [gensym() for _ in 1:length(exprs)]
ret = []
rhs = exprs[1]
Expand All @@ -90,46 +144,60 @@ function argcheck_comparison(ex, options...)
assignment = Expr(:(=), vrhs, esc(rhs))
condition = Expr(:call, esc(op), vlhs, vrhs)
code = Expr(:call, op, lhs, rhs)
err = Expr(:call, :(ArgCheck.build_error_comparison),
QuoteNode(code), QuoteNode(lhs), QuoteNode(rhs),
vlhs, vrhs, esc.(options)...)
reti = quote
$assignment
if $condition
nothing
else
throw($err)
end
end
info = Expr(:call, :ComparisonErrorInfo,
QuoteNode(c.code),
c.checkflavor,
[QuoteNode(lhs), QuoteNode(rhs)],
Expr(:tuple, vlhs, vrhs),
Expr(:tuple, esc.(c.options)...))

reti = expr_error_block(info, condition, assignment)
append!(ret, reti.args)

end
Expr(:block, ret...)
end

function build_error(code, T::Type{<:Exception}, options...)
ret = T(options...)
warn("`@argcheck condition $T $(join(options, ' ')...)` is deprecated. Use `@argcheck condition $ret` instead")
ret
end
function build_error(code, msg::AbstractString)
ArgumentError(msg)
function expr_error_block(info, condition, preamble...)
reti = quote
$(preamble...)
if $condition
nothing
else
info = $info
err = build_error(info)
throw(err)
end
end
end
build_error(code, T::Type{<:Exception}=ArgumentError) = T("$code must hold.")
build_error(code, err::Exception) = err

build_error_comparison(code, lhs, rhs, vlhs, vrhs, options...) = build_error(code, options...)
@noinline function build_error_comparison(code, lhs, rhs, vlhs, vrhs, T::Type{<:Exception}=ArgumentError)
build_error_with_fancy_message(code, [lhs, rhs], [vlhs, vrhs], T)
default_exception_type(::ArgCheckFlavor) = ArgumentError
default_exception_type(::CheckFlavor) = ErrorException

function build_error(info)
build_error(info, info.checkflavor, info.options...)
end
function build_error(info, checkflavor, msg::AbstractString)
E = default_exception_type(checkflavor)
E(msg)
end

build_error_with_fancy_message(code, variables, values, options...) = build_error(code, options...)
@noinline function build_error_with_fancy_message(code, variables, values,
T::Type{<:Exception}=ArgumentError)
msg = fancy_error_message(code, variables, values)
function build_error(info, checkflavor, T::Type{<:Exception}=default_exception_type(checkflavor))
msg = error_message(info)
T(msg)
end
function build_error(info, checkflavor, err::Exception)
err
end

error_message(info::FallbackErrorInfo) = "$(info.code) must hold."
error_message(info::CallErrorInfo) = fancy_error_message(info)
error_message(info::ComparisonErrorInfo) = fancy_error_message(info)

function fancy_error_message(code, exprs, values)
function fancy_error_message(info)
code = info.code
exprs = info.argument_expressions
values = info.argument_values
lines = String[]
foreach(exprs, values) do ex, val
sex = string(ex)
Expand Down
11 changes: 6 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ end
@argcheck true "this should not happen"
end

@testset "deprecate" begin

# deprecate
err = @catch_exception_object @argcheck false MyExoticError 1 2
@test err === MyExoticError(1,2)
@testset "@check" begin
@check true
E = ErrorException
@test_throws E @check false
@test_throws E @check false "oh no"
@test_throws DimensionMismatch @check false DimensionMismatch
end

0 comments on commit b7b42c3

Please sign in to comment.