From b7b42c3aaece8ddc8f50f11239fc9411463d5537 Mon Sep 17 00:00:00 2001 From: Jan Weidner Date: Sat, 10 Feb 2018 06:47:40 +0100 Subject: [PATCH] Check (#18) * add check macro * cleanup --- src/ArgCheck.jl | 194 ++++++++++++++++++++++++++++++++--------------- test/runtests.jl | 11 +-- 2 files changed, 137 insertions(+), 68 deletions(-) diff --git a/src/ArgCheck.jl b/src/ArgCheck.jl index ace6a97..b8f9b61 100644 --- a/src/ArgCheck.jl +++ b/src/ArgCheck.jl @@ -1,13 +1,51 @@ __precompile__() module ArgCheck using Base.Meta -export @argcheck +export @argcheck, @check + +abstract type AbstractCheckFlavor end +struct ArgCheckFlavor <: AbstractCheckFlavor end +struct CheckFlavor <: AbstractCheckFlavor end + +abstract type AbstractCodeFlavor end +struct CallFlavor <: AbstractCodeFlavor end +struct ComparisonFlavor <: AbstractCodeFlavor end +struct FallbackFlavor <: AbstractCodeFlavor end + +struct Checker + code + checkflavor::AbstractCheckFlavor + codeflavor::AbstractCodeFlavor + options +end + +abstract type AbstractErrorInfo end +struct CallErrorInfo <: AbstractErrorInfo + code + checkflavor::AbstractCheckFlavor + argument_expressions + argument_values + options +end +struct ComparisonErrorInfo <: AbstractErrorInfo + code + checkflavor::AbstractCheckFlavor + argument_expressions + argument_values + options +end +struct FallbackErrorInfo <: AbstractErrorInfo + code + checkflavor::AbstractCheckFlavor + options +end """ @argcheck -Macro for checking invariants on function arguments. -It can be used as follows: +Check invariants on function arguments and produce +a nice exception message if they are violated. +Usage is as follows: ```Julia function myfunction(k,n,A,B) @argcheck k > n @@ -16,15 +54,39 @@ function myfunction(k,n,A,B) # doit end ``` +See also [`@check`](@ref). """ macro argcheck(ex, options...) - if isexpr(ex, :comparison) - argcheck_comparison(ex, options...) + check(ex, ArgCheckFlavor(), options...) +end + +""" + @check + +Check that a condition holds +and produce a nice exception message, if it does not. +Usage is as follows: +```Julia +@check k > n +@check size(A) == size(B) DimensionMismatch +@check det(A) < 0 DomainError() +``` +See also [`@argcheck`](@ref). +""" +macro check(ex, options...) + check(ex, CheckFlavor(), options...) +end + +function check(ex, checkflavor, options...) + codeflavor = if isexpr(ex, :comparison) + ComparisonFlavor() elseif is_simple_call(ex) - argcheck_call(ex, options...) + CallFlavor() else - argcheck_fallback(ex, options...) + FallbackFlavor() end + checker = Checker(ex, checkflavor, codeflavor, options) + check(checker, codeflavor) end function is_simple_call(ex) @@ -37,44 +99,36 @@ function is_simple_call(ex) true end -function argcheck_fallback(ex, options...) - quote - if $(esc(ex)) - nothing - else - err = ArgCheck.build_error($(QuoteNode(ex)), $(esc.(options)...)) - throw(err) - end - end +function check(c, ::FallbackFlavor) + info = Expr(:call, :FallbackErrorInfo, + QuoteNode(c.code), + c.checkflavor, + Expr(:tuple, esc.(c.options)...)) + + condition = esc(c.code) + expr_error_block(info, condition) end -function argcheck_call(ex, options...) +function check(c, ::CallFlavor) + ex = c.code variables = [gensym() for _ in 1:length(ex.args)] assignments = map(variables, ex.args) do vi, exi Expr(:(=), vi, esc(exi)) end condition = Expr(:call, variables...) values = :([$(variables...)]) - err = Expr(:call, - :(ArgCheck.build_error_with_fancy_message), - QuoteNode(ex), - QuoteNode(ex.args), - values, - esc.(options)... - ) - quote - $(assignments...) - if $condition - nothing - else - throw($err) - end - end + info = Expr(:call, :CallErrorInfo, + QuoteNode(c.code), + c.checkflavor, + QuoteNode(c.code.args), + values, + Expr(:tuple, esc.(c.options)...)) + expr_error_block(info, condition, assignments...) end -function argcheck_comparison(ex, options...) - exprs = ex.args[1:2:end] - ops = ex.args[2:2:end] +function check(c::Checker, ::ComparisonFlavor) + exprs = c.code.args[1:2:end] + ops = c.code.args[2:2:end] variables = [gensym() for _ in 1:length(exprs)] ret = [] rhs = exprs[1] @@ -90,46 +144,60 @@ function argcheck_comparison(ex, options...) assignment = Expr(:(=), vrhs, esc(rhs)) condition = Expr(:call, esc(op), vlhs, vrhs) code = Expr(:call, op, lhs, rhs) - err = Expr(:call, :(ArgCheck.build_error_comparison), - QuoteNode(code), QuoteNode(lhs), QuoteNode(rhs), - vlhs, vrhs, esc.(options)...) - reti = quote - $assignment - if $condition - nothing - else - throw($err) - end - end + info = Expr(:call, :ComparisonErrorInfo, + QuoteNode(c.code), + c.checkflavor, + [QuoteNode(lhs), QuoteNode(rhs)], + Expr(:tuple, vlhs, vrhs), + Expr(:tuple, esc.(c.options)...)) + + reti = expr_error_block(info, condition, assignment) append!(ret, reti.args) + end Expr(:block, ret...) end -function build_error(code, T::Type{<:Exception}, options...) - ret = T(options...) - warn("`@argcheck condition $T $(join(options, ' ')...)` is deprecated. Use `@argcheck condition $ret` instead") - ret -end -function build_error(code, msg::AbstractString) - ArgumentError(msg) +function expr_error_block(info, condition, preamble...) + reti = quote + $(preamble...) + if $condition + nothing + else + info = $info + err = build_error(info) + throw(err) + end + end end -build_error(code, T::Type{<:Exception}=ArgumentError) = T("$code must hold.") -build_error(code, err::Exception) = err -build_error_comparison(code, lhs, rhs, vlhs, vrhs, options...) = build_error(code, options...) -@noinline function build_error_comparison(code, lhs, rhs, vlhs, vrhs, T::Type{<:Exception}=ArgumentError) - build_error_with_fancy_message(code, [lhs, rhs], [vlhs, vrhs], T) +default_exception_type(::ArgCheckFlavor) = ArgumentError +default_exception_type(::CheckFlavor) = ErrorException + +function build_error(info) + build_error(info, info.checkflavor, info.options...) +end +function build_error(info, checkflavor, msg::AbstractString) + E = default_exception_type(checkflavor) + E(msg) end -build_error_with_fancy_message(code, variables, values, options...) = build_error(code, options...) -@noinline function build_error_with_fancy_message(code, variables, values, - T::Type{<:Exception}=ArgumentError) - msg = fancy_error_message(code, variables, values) +function build_error(info, checkflavor, T::Type{<:Exception}=default_exception_type(checkflavor)) + msg = error_message(info) T(msg) end +function build_error(info, checkflavor, err::Exception) + err +end + +error_message(info::FallbackErrorInfo) = "$(info.code) must hold." +error_message(info::CallErrorInfo) = fancy_error_message(info) +error_message(info::ComparisonErrorInfo) = fancy_error_message(info) -function fancy_error_message(code, exprs, values) +function fancy_error_message(info) + code = info.code + exprs = info.argument_expressions + values = info.argument_values lines = String[] foreach(exprs, values) do ex, val sex = string(ex) diff --git a/test/runtests.jl b/test/runtests.jl index 58a55ed..2f95c83 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,9 +170,10 @@ end @argcheck true "this should not happen" end -@testset "deprecate" begin - - # deprecate - err = @catch_exception_object @argcheck false MyExoticError 1 2 - @test err === MyExoticError(1,2) +@testset "@check" begin + @check true + E = ErrorException + @test_throws E @check false + @test_throws E @check false "oh no" + @test_throws DimensionMismatch @check false DimensionMismatch end