Skip to content

Commit

Permalink
Add parse_macro_arguments to unify how we handle macro inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Dec 10, 2023
1 parent c151994 commit 3b8fe8b
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 66 deletions.
58 changes: 51 additions & 7 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,48 @@ function _extract_kw_args(args)
return flat_args, kw_args, requested_container
end

function _reorder_parameters(args)
if !Meta.isexpr(args[1], :parameters)
return args
end
args = collect(args)
p = popfirst!(args)
for arg in p.args
@assert arg.head == :kw
push!(args, Expr(:(=), arg.args[1], arg.args[2]))
end
return args
end

"""
parse_macro_arguments(error_fn::Function, args)
Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered
positional arguments and a dictionary mapping the keyword arguments.
This specially handles the distinction of `@foo(key = value)` and
`@foo(; key = value)` in macros.
Throws an error if mulitple keyword arguments are passed with the same name.
"""
function parse_macro_arguments(error_fn::Function, args)
pos_args, kw_args = Any[], Dict{Symbol,Any}()
for arg in _reorder_parameters(args)
if Meta.isexpr(arg, :(=), 2)
if haskey(kw_args, arg.args[1])
error_fn(

Check warning on line 71 in src/Containers/macro.jl

View check run for this annotation

Codecov / codecov/patch

src/Containers/macro.jl#L71

Added line #L71 was not covered by tests
"The keyword argument $(arg.args[1]) has been given " *
"mulitple times",
)
end
kw_args[arg.args[1]] = arg.args[2]
else
push!(pos_args, arg)
end
end
return pos_args, kw_args
end

"""
_explicit_oneto(index_set)
Expand Down Expand Up @@ -381,14 +423,16 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries:
[3, 3] = 6
```
"""
macro container(args...)
args, kw_args, requested_container = _extract_kw_args(args)
macro container(input_args...)
args, kw_args = parse_macro_arguments(error, input_args)
container = get(kw_args, :container, :Auto)
@assert length(args) == 2
@assert isempty(kw_args)
var, value = args
index_vars, indices = build_ref_sets(error, var)
code = container_code(index_vars, indices, esc(value), requested_container)
name = _get_name(var)
for key in keys(kw_args)
@assert key == :container
end
index_vars, indices = build_ref_sets(error, args[1])
code = container_code(index_vars, indices, esc(args[2]), container)
name = _get_name(args[1])
if name === nothing
return code
end
Expand Down
10 changes: 10 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,16 @@ function _add_kw_args(call, kw_args; exclude = Symbol[])
return
end

function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[])
for (key, value) in kwargs
if key in exclude
continue
end
push!(call.args, esc(Expr(:kw, key, value)))
end
return
end

"""
_add_positional_args(call, args)::Nothing
Expand Down
2 changes: 1 addition & 1 deletion src/macros/@NL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ macro NLobjective(model, sense, x)
function error_fn(str...)
return _macro_error(:NLobjective, (model, sense, x), __source__, str...)
end
sense_expr = _moi_sense(error_fn, sense)
sense_expr = _parse_moi_sense(error_fn, sense)
esc_model = esc(model)
parsing_code, expr = _parse_nonlinear_expression(esc_model, x)
code = quote
Expand Down
44 changes: 26 additions & 18 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ user syntax: `@constraint(model, ref[...], expr, my_arg, kwargs...)`.
"""
macro constraint(input_args...)
error_fn(str...) = _macro_error(:constraint, input_args, __source__, str...)
args, kwargs, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if length(args) < 2 && !isempty(kwargs)
error_fn(
"No constraint expression detected. If you are trying to " *
Expand All @@ -102,13 +102,12 @@ macro constraint(input_args...)
# [1:2] | Expr | :vect
# [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat
# a constraint expression | Expr | :call or :comparison
c, x = if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
c, x = nothing, y
if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat))
if length(extra) == 0
error_fn("No constraint expression was given.")
end
y, popfirst!(extra)
else
nothing, y
c, x = y, popfirst!(extra)
end
if length(extra) > 1
error_fn("Cannot specify more than 1 additional positional argument.")
Expand All @@ -122,20 +121,29 @@ macro constraint(input_args...)
end
is_vectorized, parse_code, build_call = parse_constraint(error_fn, x)
_add_positional_args(build_call, extra)
_add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name])
base_name = _get_kwarg_value(
error_fn,
kwargs,
:base_name;
default = string(something(Containers._get_name(c), "")),
)
set_name_flag = _get_kwarg_value(
error_fn,
kwargs,
:set_string_name;
default = :(set_string_names_on_creation($model)),
_add_keyword_args(
build_call,
kwargs;
exclude = [:base_name, :container, :set_string_name],
)
name_expr = :($set_name_flag ? $(_name_call(base_name, index_vars)) : "")
# ; base_name
default_base_name = string(something(Containers._get_name(c), ""))
base_name = get(kwargs, :base_name, default_base_name)
if base_name isa Expr
base_name = esc(base_name)
end
# ; container
# There is no need to escape this one.
container = get(kwargs, :container, :Auto)
# ; set_string_name
name_expr = _name_call(base_name, index_vars)
if name_expr != ""
# We use args[1] here instead of `model` because `model` is already
# escaped.
default = Expr(:call, set_string_names_on_creation, args[1])
set_string_name = esc(get(kwargs, :set_string_name, default))
name_expr = :($set_string_name ? $name_expr : "")
end
code = if is_vectorized
quote
$parse_code
Expand Down
17 changes: 11 additions & 6 deletions src/macros/@expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,17 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3))
"""
macro expression(input_args...)
error_fn(str...) = _macro_error(:expression, input_args, __source__, str...)
args, kw_args, container = Containers._extract_kw_args(input_args)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if !(2 <= length(args) <= 3)
error_fn("needs at least two arguments.")
elseif !isempty(kw_args)
error_fn("unrecognized keyword argument")
error_fn("expected 2 or 3 positional arguments, got $(length(args)).")
elseif Meta.isexpr(args[2], :block)
error_fn("Invalid syntax. Did you mean to use `@expressions`?")
elseif !isempty(kwargs)
for key in keys(kwargs)
if key != :container
error_fn("unsupported keyword argument `$key`.")
end
end
end
name_expr = length(args) == 3 ? args[2] : nothing
index_vars, indices = Containers.build_ref_sets(error_fn, name_expr)
Expand All @@ -70,14 +74,15 @@ macro expression(input_args...)
"different name for the index.",
)
end
expr_var, build_code = _rewrite_expression(args[end])
model = esc(args[1])
expr, build_code = _rewrite_expression(args[end])
code = quote
$build_code
# Don't leak a `_MA.Zero` if the expression is an empty summation, or
# other structure that returns `_MA.Zero()`.
_replace_zero($model, $expr_var)
_replace_zero($model, $expr)
end
container = get(kwargs, :container, :Auto)
return _finalize_macro(
model,
Containers.container_code(index_vars, indices, code, container),
Expand Down
44 changes: 16 additions & 28 deletions src/macros/@objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,19 @@ julia> @objective(model, sense, x^2 - 2x + 1)
x² - 2 x + 1
```
"""
macro objective(model, args...)
function error_fn(str...)
return _macro_error(:objective, (model, args...), __source__, str...)
macro objective(input_args...)
error_fn(str...) = _macro_error(:objective, input_args, __source__, str...)
args, kwargs = Containers.parse_macro_arguments(error_fn, input_args)
if length(args) != 3
error_fn("expected 3 positional arguments, got $(length(args)).")
elseif !isempty(kwargs)
for key in keys(kwargs)
error_fn("unsupported keyword argument `$key`.")
end

Check warning on line 54 in src/macros/@objective.jl

View check run for this annotation

Codecov / codecov/patch

src/macros/@objective.jl#L54

Added line #L54 was not covered by tests
end
if length(args) != 2
error_fn(
"needs three arguments: model, objective sense (Max or Min), and an expression.",
)
end
esc_model = esc(model)
sense = _moi_sense(error_fn, args[1])
expr, parse_code = _rewrite_expression(args[2])
esc_model = esc(args[1])
sense = _parse_moi_sense(error_fn, args[2])
expr, parse_code = _rewrite_expression(args[3])
code = quote
$parse_code
# Don't leak a `_MA.Zero` if the objective expression is an empty
Expand All @@ -66,33 +67,20 @@ macro objective(model, args...)
return _finalize_macro(esc_model, code, __source__)
end

"""
_moi_sense(error_fn::Function, sense)
Return an expression whose value is an `MOI.OptimizationSense` corresponding
to `sense`.
Sense is either the symbol `:Min` or `:Max`, corresponding respectively to
`MIN_SENSE` and `MAX_SENSE` or it is another expression, which should be the
name of a variable or expression whose value is an `MOI.OptimizationSense`.
In the last case, the expression throws an error using the `error_fn`
function in case the value is not an `MOI.OptimizationSense`.
"""
function _moi_sense(error_fn::Function, sense)
function _parse_moi_sense(error_fn::Function, sense)
if sense == :Min
return MIN_SENSE
elseif sense == :Max
return MAX_SENSE
end
return :(_throw_error_for_invalid_sense($error_fn, $(esc(sense))))
return :(_moi_sense($error_fn, $(esc(sense))))
end

function _throw_error_for_invalid_sense(error_fn::Function, sense)
function _moi_sense(error_fn::Function, sense)
return error_fn(
"unexpected sense `$sense`. The sense must be an " *
"`::MOI.OptimizatonSense`, or the symbol `:Min` or `:Max`.",
)
end

_throw_error_for_invalid_sense(::Function, sense::MOI.OptimizationSense) = sense
_moi_sense(::Function, sense::MOI.OptimizationSense) = sense
33 changes: 27 additions & 6 deletions test/test_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ function test_extension_constraint_naming(
@test name(cref) == "c1"
cref = @constraint(model, c2, x == 0, base_name = "cat")
@test name(cref) == "cat"
crefs = @constraint(model, [1:2], x == 0, base_name = "cat")
@test name.(crefs) == ["cat[1]", "cat[2]"]
crefs = @constraint(model, [i in 1:2], x == 0, base_name = "cat_$i")
@test name.(crefs) == ["cat_1[1]", "cat_2[2]"]
@test_macro_throws ErrorException @constraint(model, c3[1:2])
@test_macro_throws ErrorException @constraint(model, "c"[1:2], x == 0)
return
Expand Down Expand Up @@ -1693,8 +1693,7 @@ function test_objective_not_enough_arguments()
model = Model()
@test_macro_throws(
ErrorException(
"In `@objective(model, Min)`: needs three arguments: model, " *
"objective sense (Max or Min), and an expression.",
"In `@objective(model, Min)`: expected 3 positional arguments, got 2.",
),
@objective(model, Min),
)
Expand All @@ -1705,7 +1704,7 @@ function test_expression_not_enough_arguments()
model = Model()
@test_macro_throws(
ErrorException(
"In `@expression(model)`: needs at least two arguments.",
"In `@expression(model)`: expected 2 or 3 positional arguments, got 1.",
),
@expression(model),
)
Expand All @@ -1717,13 +1716,25 @@ function test_expression_keyword_arguments()
@variable(model, x)
@test_macro_throws(
ErrorException(
"In `@expression(model, x, foo = 1)`: unrecognized keyword argument",
"In `@expression(model, x, foo = 1)`: unsupported keyword argument `foo`.",
),
@expression(model, x, foo = 1),
)
return
end

function test_objective_keyword_arguments()
model = Model()
@variable(model, x)
@test_macro_throws(
ErrorException(
"In `@objective(model, Min, x, foo = 1)`: unsupported keyword argument `foo`.",
),
@objective(model, Min, x, foo = 1),
)
return
end

function test_build_constraint_invalid()
model = Model()
@variable(model, x)
Expand Down Expand Up @@ -2157,4 +2168,14 @@ function test_bad_objective_sense()
return
end

function test_expression_container_kwarg()
model = Model()
@variable(model, x)
@expression(model, ex1[i in 1:2], i * x, container = DenseAxisArray)
@test ex1 isa Containers.DenseAxisArray
@expression(model, ex2[i in 1:2], i * x; container = DenseAxisArray)
@test ex2 isa Containers.DenseAxisArray
return
end

end # module

0 comments on commit 3b8fe8b

Please sign in to comment.