diff --git a/src/Containers/macro.jl b/src/Containers/macro.jl index b69d710ccce..95c7d04fe24 100644 --- a/src/Containers/macro.jl +++ b/src/Containers/macro.jl @@ -15,28 +15,46 @@ function _get_name(c::Expr) return error("Expression $c cannot be used as a name.") end +function _reorder_parameters(args) + if !Meta.isexpr(args[1], :parameters) + return args + end + args = collect(args) + p = popfirst!(args) + for arg in p.args + @assert arg.head == :kw + push!(args, Expr(:(=), arg.args[1], arg.args[2])) + end + return args +end + """ - _extract_kw_args(args) + parse_macro_arguments(error_fn::Function, args) -Process the arguments to a macro, separating out the keyword arguments. +Returns a `Tuple{Vector{Any},Dict{Symbol,Any}}` containing the ordered +positional arguments and a dictionary mapping the keyword arguments. -Return a tuple of (flat_arguments, keyword arguments, and requested_container), -where `requested_container` is a symbol to be passed to `container_code`. +This specially handles the distinction of `@foo(key = value)` and +`@foo(; key = value)` in macros. + +Throws an error if mulitple keyword arguments are passed with the same name. """ -function _extract_kw_args(args) - flat_args, kw_args, requested_container = Any[], Any[], :Auto - for arg in args - if Meta.isexpr(arg, :(=)) - if arg.args[1] == :container - requested_container = arg.args[2] - else - push!(kw_args, arg) +function parse_macro_arguments(error_fn::Function, args) + pos_args, kw_args = Any[], Dict{Symbol,Any}() + for arg in _reorder_parameters(args) + if Meta.isexpr(arg, :(=), 2) + if haskey(kw_args, arg.args[1]) + error_fn( + "the keyword argument `$(arg.args[1])` was given " * + "multiple times.", + ) end + kw_args[arg.args[1]] = arg.args[2] else - push!(flat_args, arg) + push!(pos_args, arg) end end - return flat_args, kw_args, requested_container + return pos_args, kw_args end """ @@ -381,14 +399,16 @@ SparseAxisArray{Int64, 2, Tuple{Int64, Int64}} with 6 entries: [3, 3] = 6 ``` """ -macro container(args...) - args, kw_args, requested_container = _extract_kw_args(args) +macro container(input_args...) + args, kw_args = parse_macro_arguments(error, input_args) + container = get(kw_args, :container, :Auto) @assert length(args) == 2 - @assert isempty(kw_args) - var, value = args - index_vars, indices = build_ref_sets(error, var) - code = container_code(index_vars, indices, esc(value), requested_container) - name = _get_name(var) + for key in keys(kw_args) + @assert key == :container + end + index_vars, indices = build_ref_sets(error, args[1]) + code = container_code(index_vars, indices, esc(args[2]), container) + name = _get_name(args[1]) if name === nothing return code end diff --git a/src/macros.jl b/src/macros.jl index 4cf29256c2b..e0e68c4503e 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -323,45 +323,24 @@ function model_convert( return model_convert.(model, x) end -""" - _add_kw_args(call, kw_args) - -Add the keyword arguments `kw_args` to the function call expression `call`, -escaping the expressions. The elements of `kw_args` should be expressions of the -form `:(key = value)`. The `kw_args` vector can be extracted from the arguments -of a macro with [`Containers._extract_kw_args`](@ref). - -## Example - -```jldoctest -julia> call = :(f(1, a=2)) -:(f(1, a = 2)) - -julia> JuMP._add_kw_args(call, [:(b=3), :(c=4)]) - -julia> call -:(f(1, a = 2, \$(Expr(:escape, :(\$(Expr(:kw, :b, 3))))), \$(Expr(:escape, :(\$(Expr(:kw, :c, 4))))))) -``` -""" -function _add_kw_args(call, kw_args; exclude = Symbol[]) - for kw in kw_args - @assert Meta.isexpr(kw, :(=)) - if kw.args[1] in exclude +function _add_keyword_args(call::Expr, kwargs::Dict; exclude = Symbol[]) + for (key, value) in kwargs + if key in exclude continue end - push!(call.args, esc(Expr(:kw, kw.args...))) + push!(call.args, esc(Expr(:kw, key, value))) end return end """ - _add_positional_args(call, args)::Nothing + _add_positional_args(call::Expr, args::Vector{Any})::Nothing Add the positional arguments `args` to the function call expression `call`, -escaping each argument expression. The elements of `args` should be ones that -were extracted via [`Containers._extract_kw_args`](@ref) and had appropriate -arguments filtered out (e.g., the model argument). This is able to incorporate -additional positional arguments to `call`s that already have keyword arguments. +escaping each argument expression. + +This function is able to incorporate additional positional arguments to `call`s +that already have keyword arguments. ## Example @@ -375,7 +354,7 @@ julia> call :(f(1, $(Expr(:escape, :x)), a = 2)) ``` """ -function _add_positional_args(call, args) +function _add_positional_args(call::Expr, args::Vector) call_args = call.args if Meta.isexpr(call, :.) # call is broadcasted @@ -392,19 +371,6 @@ function _add_positional_args(call, args) return end -function _reorder_parameters(args) - if !Meta.isexpr(args[1], :parameters) - return args - end - args = collect(args) - p = popfirst!(args) - for arg in p.args - @assert arg.head == :kw - push!(args, Expr(:(=), arg.args[1], arg.args[2])) - end - return args -end - _valid_model(::AbstractModel, ::Any) = nothing function _valid_model(m::M, name) where {M} @@ -654,33 +620,6 @@ function _wrap_let(model, code) return code end -function _get_kwarg_value( - error_fn, - kwargs, - key::Symbol; - default = nothing, - escape::Bool = true, -) - index, count = 0, 0 - for (i, kwarg) in enumerate(kwargs) - if kwarg.args[1] == key - count += 1 - index = i - end - end - if count == 0 - return default - elseif count == 1 - if escape - return esc(kwargs[index].args[2]) - else - return kwargs[index].args[2] - end - else - error_fn("`$key` keyword argument was given $count times.") - end -end - include("macros/@objective.jl") include("macros/@expression.jl") include("macros/@constraint.jl") diff --git a/src/macros/@NL.jl b/src/macros/@NL.jl index 64048d24845..060b6560830 100644 --- a/src/macros/@NL.jl +++ b/src/macros/@NL.jl @@ -223,6 +223,30 @@ function _parse_generator_expression(code, x, operators) return y end +""" + _extract_kw_args(args) + +Process the arguments to a macro, separating out the keyword arguments. + +Return a tuple of (flat_arguments, keyword arguments, and requested_container), +where `requested_container` is a symbol to be passed to `container_code`. +""" +function _extract_kw_args(args) + flat_args, kw_args, requested_container = Any[], Any[], :Auto + for arg in args + if Meta.isexpr(arg, :(=)) + if arg.args[1] == :container + requested_container = arg.args[2] + else + push!(kw_args, arg) + end + else + push!(flat_args, arg) + end + end + return flat_args, kw_args, requested_container +end + ### ### @NLobjective(s) ### @@ -252,7 +276,7 @@ macro NLobjective(model, sense, x) function error_fn(str...) return _macro_error(:NLobjective, (model, sense, x), __source__, str...) end - sense_expr = _moi_sense(error_fn, sense) + sense_expr = _parse_moi_sense(error_fn, sense) esc_model = esc(model) parsing_code, expr = _parse_nonlinear_expression(esc_model, x) code = quote @@ -299,7 +323,7 @@ macro NLconstraint(m, x, args...) # Two formats: # - @NLconstraint(m, a*x <= 5) # - @NLconstraint(m, myref[a=1:5], sin(x^a) <= 5) - extra, kw_args, requested_container = Containers._extract_kw_args(args) + extra, kw_args, requested_container = _extract_kw_args(args) if length(extra) > 1 || length(kw_args) > 0 error_fn("too many arguments.") end @@ -413,7 +437,7 @@ subexpression[5]: log(1.0 + (exp(subexpression[2]) + exp(subexpression[3]))) """ macro NLexpression(args...) error_fn(str...) = _macro_error(:NLexpression, args, __source__, str...) - args, kw_args, requested_container = Containers._extract_kw_args(args) + args, kw_args, requested_container = _extract_kw_args(args) if length(args) <= 1 error_fn( "To few arguments ($(length(args))); must pass the model and nonlinear expression as arguments.", @@ -577,7 +601,7 @@ macro NLparameter(model, args...) function error_fn(str...) return _macro_error(:NLparameter, (model, args...), __source__, str...) end - pos_args, kw_args, requested_container = Containers._extract_kw_args(args) + pos_args, kw_args, requested_container = _extract_kw_args(args) value = missing for arg in kw_args if arg.args[1] == :value diff --git a/src/macros/@constraint.jl b/src/macros/@constraint.jl index 0aa6cbfd6b9..07559097ae5 100644 --- a/src/macros/@constraint.jl +++ b/src/macros/@constraint.jl @@ -60,7 +60,7 @@ Other keyword arguments may be supported by JuMP extensions. """ macro constraint(input_args...) error_fn(str...) = _macro_error(:constraint, input_args, __source__, str...) - args, kwargs, container = Containers._extract_kw_args(input_args) + args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) if length(args) < 2 && !isempty(kwargs) error_fn( "No constraint expression detected. If you are trying to " * @@ -82,13 +82,12 @@ macro constraint(input_args...) # [1:2] | Expr | :vect # [i = 1:2, j = 1:2; i + j >= 3] | Expr | :vcat # a constraint expression | Expr | :call or :comparison - c, x = if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat)) + c, x = nothing, y + if y isa Symbol || Meta.isexpr(y, (:vect, :vcat, :ref, :typed_vcat)) if length(extra) == 0 error_fn("No constraint expression was given.") end - y, popfirst!(extra) - else - nothing, y + c, x = y, popfirst!(extra) end if length(extra) > 1 error_fn("Cannot specify more than 1 additional positional argument.") @@ -102,20 +101,30 @@ macro constraint(input_args...) end is_vectorized, parse_code, build_call = parse_constraint(error_fn, x) _add_positional_args(build_call, extra) - _add_kw_args(build_call, kwargs; exclude = [:base_name, :set_string_name]) - base_name = _get_kwarg_value( - error_fn, - kwargs, - :base_name; - default = string(something(Containers._get_name(c), "")), - ) - set_name_flag = _get_kwarg_value( - error_fn, - kwargs, - :set_string_name; - default = :(set_string_names_on_creation($model)), + _add_keyword_args( + build_call, + kwargs; + exclude = [:base_name, :container, :set_string_name], ) - name_expr = :($set_name_flag ? $(_name_call(base_name, index_vars)) : "") + # ; base_name + default_base_name = string(something(Containers._get_name(c), "")) + base_name = get(kwargs, :base_name, default_base_name) + if base_name isa Expr + base_name = esc(base_name) + end + # ; container + # There is no need to escape this one. + container = get(kwargs, :container, :Auto) + # ; set_string_name + name_expr = _name_call(base_name, index_vars) + if name_expr != "" + set_string_name = if haskey(kwargs, :set_string_name) + esc(kwargs[:set_string_name]) + else + :(set_string_names_on_creation($model)) + end + name_expr = :($set_string_name ? $name_expr : "") + end code = if is_vectorized quote $parse_code diff --git a/src/macros/@expression.jl b/src/macros/@expression.jl index 3e1bd484075..1cdc3ad1e9d 100644 --- a/src/macros/@expression.jl +++ b/src/macros/@expression.jl @@ -65,13 +65,17 @@ julia> expr = @expression(model, [i in 1:3], i * sum(x[j] for j in 1:3)) """ macro expression(input_args...) error_fn(str...) = _macro_error(:expression, input_args, __source__, str...) - args, kw_args, container = Containers._extract_kw_args(input_args) + args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) if !(2 <= length(args) <= 3) - error_fn("needs at least two arguments.") - elseif !isempty(kw_args) - error_fn("unrecognized keyword argument") + error_fn("expected 2 or 3 positional arguments, got $(length(args)).") elseif Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@expressions`?") + elseif !isempty(kwargs) + for key in keys(kwargs) + if key != :container + error_fn("unsupported keyword argument `$key`.") + end + end end name_expr = length(args) == 3 ? args[2] : nothing index_vars, indices = Containers.build_ref_sets(error_fn, name_expr) @@ -81,14 +85,15 @@ macro expression(input_args...) "different name for the index.", ) end - expr_var, build_code = _rewrite_expression(args[end]) model = esc(args[1]) + expr, build_code = _rewrite_expression(args[end]) code = quote $build_code # Don't leak a `_MA.Zero` if the expression is an empty summation, or # other structure that returns `_MA.Zero()`. - _replace_zero($model, $expr_var) + _replace_zero($model, $expr) end + container = get(kwargs, :container, :Auto) return _finalize_macro( model, Containers.container_code(index_vars, indices, code, container), diff --git a/src/macros/@objective.jl b/src/macros/@objective.jl index 074c532bffc..4f35db77f55 100644 --- a/src/macros/@objective.jl +++ b/src/macros/@objective.jl @@ -51,18 +51,19 @@ julia> @objective(model, sense, x^2 - 2x + 1) x² - 2 x + 1 ``` """ -macro objective(model, args...) - function error_fn(str...) - return _macro_error(:objective, (model, args...), __source__, str...) +macro objective(input_args...) + error_fn(str...) = _macro_error(:objective, input_args, __source__, str...) + args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) + if length(args) != 3 + error_fn("expected 3 positional arguments, got $(length(args)).") + elseif !isempty(kwargs) + for key in keys(kwargs) + error_fn("unsupported keyword argument `$key`.") + end end - if length(args) != 2 - error_fn( - "needs three arguments: model, objective sense (Max or Min), and an expression.", - ) - end - esc_model = esc(model) - sense = _moi_sense(error_fn, args[1]) - expr, parse_code = _rewrite_expression(args[2]) + esc_model = esc(args[1]) + sense = _parse_moi_sense(error_fn, args[2]) + expr, parse_code = _rewrite_expression(args[3]) code = quote $parse_code # Don't leak a `_MA.Zero` if the objective expression is an empty @@ -74,33 +75,20 @@ macro objective(model, args...) return _finalize_macro(esc_model, code, __source__) end -""" - _moi_sense(error_fn::Function, sense) - -Return an expression whose value is an `MOI.OptimizationSense` corresponding -to `sense`. - -Sense is either the symbol `:Min` or `:Max`, corresponding respectively to -`MIN_SENSE` and `MAX_SENSE` or it is another expression, which should be the -name of a variable or expression whose value is an `MOI.OptimizationSense`. - -In the last case, the expression throws an error using the `error_fn` -function in case the value is not an `MOI.OptimizationSense`. -""" -function _moi_sense(error_fn::Function, sense) +function _parse_moi_sense(error_fn::Function, sense) if sense == :Min return MIN_SENSE elseif sense == :Max return MAX_SENSE end - return :(_throw_error_for_invalid_sense($error_fn, $(esc(sense)))) + return :(_moi_sense($error_fn, $(esc(sense)))) end -function _throw_error_for_invalid_sense(error_fn::Function, sense) +function _moi_sense(error_fn::Function, sense) return error_fn( "unexpected sense `$sense`. The sense must be an " * "`::MOI.OptimizatonSense`, or the symbol `:Min` or `:Max`.", ) end -_throw_error_for_invalid_sense(::Function, sense::MOI.OptimizationSense) = sense +_moi_sense(::Function, sense::MOI.OptimizationSense) = sense diff --git a/src/macros/@variable.jl b/src/macros/@variable.jl index e0b452bda58..1e6dbc88197 100644 --- a/src/macros/@variable.jl +++ b/src/macros/@variable.jl @@ -141,18 +141,13 @@ julia> @variable(model, z[i=1:3], set_string_name = false) """ macro variable(input_args...) error_fn(str...) = _macro_error(:variable, input_args, __source__, str...) - # We need to re-order the parameters here to account for cases like - # `@variable(model; integer = true)`, since Julia handles kwargs by placing - # them first(!) in the list of arguments. - args = _reorder_parameters(input_args) - model = esc(args[1]) + args, kwargs = Containers.parse_macro_arguments(error_fn, input_args) if length(args) >= 2 && Meta.isexpr(args[2], :block) error_fn("Invalid syntax. Did you mean to use `@variables`?") end - pos_args, kw_args, container = Containers._extract_kw_args(args[2:end]) - # if there is only a single non-keyword argument, this is an anonymous - # variable spec and the one non-kwarg is the model - x = isempty(pos_args) ? nothing : popfirst!(pos_args) + model_sym = popfirst!(args) + model = esc(model_sym) + x = isempty(args) ? nothing : popfirst!(args) if x == :Int error_fn( "Ambiguous variable name $x detected. To specify an anonymous " * @@ -170,10 +165,8 @@ macro variable(input_args...) "matrix use `@variable(model, [1:n, 1:n], PSD)` instead.", ) end - info_kwargs = [ - (kw.args[1], _esc_non_constant(kw.args[2])) for - kw in kw_args if kw.args[1] in _INFO_KWARGS - ] + info_kwargs = + [(k, _esc_non_constant(v)) for (k, v) in kwargs if k in _INFO_KWARGS] info_expr = _VariableInfoExpr(; info_kwargs...) # There are four cases to consider: # x | type of x | x.head @@ -199,15 +192,15 @@ macro variable(input_args...) error_fn("Expected $var to be a variable name") end index_vars, indices = Containers.build_ref_sets(error_fn, var) - if args[1] in index_vars + if model_sym in index_vars error_fn( - "Index $(args[1]) is the same symbol as the model. Use a " * + "Index $model_sym is the same symbol as the model. Use a " * "different name for the index.", ) end # Handle special keyword arguments # ; set - set_kw = _get_kwarg_value(error_fn, kw_args, :set) + set_kw = get(kwargs, :set, nothing) if set_kw !== nothing if set !== nothing error_fn( @@ -218,27 +211,31 @@ macro variable(input_args...) set = set_kw end # ; base_name - base_name = _get_kwarg_value( - error_fn, - kw_args, - :base_name; - default = string(something(Containers._get_name(var), "")), - ) + default_base_name = string(something(Containers._get_name(var), "")) + base_name = get(kwargs, :base_name, default_base_name) + if base_name isa Expr + base_name = esc(base_name) + end + # ; container + # There is no need to escape this one. + container = get(kwargs, :container, :Auto) # ; set_string_name - set_string_name_kw = _get_kwarg_value( - error_fn, - kw_args, - :set_string_name; - default = :(set_string_names_on_creation($model)), - ) + name_expr = _name_call(base_name, index_vars) + if name_expr != "" + set_string_name = if haskey(kwargs, :set_string_name) + esc(kwargs[:set_string_name]) + else + :(set_string_names_on_creation($model)) + end + name_expr = :($set_string_name ? $name_expr : "") + end # ; variable_type - variable_type_kw = - _get_kwarg_value(error_fn, kw_args, :variable_type; escape = false) + variable_type_kw = get(kwargs, :variable_type, nothing) if variable_type_kw !== nothing - push!(pos_args, variable_type_kw) + push!(args, variable_type_kw) end # Handle positional arguments - for ex in pos_args + for ex in args if ex == :Int _set_integer_or_error(error_fn, info_expr) elseif ex == :Bin @@ -269,12 +266,17 @@ macro variable(input_args...) set = HermitianMatrixSpace() end end - filter!(ex -> !(ex in (:Int, :Bin, :PSD, :Symmetric, :Hermitian)), pos_args) + filter!(ex -> !(ex in (:Int, :Bin, :PSD, :Symmetric, :Hermitian)), args) build_code = :(build_variable($error_fn, $(_constructor_expr(info_expr)))) - _add_positional_args(build_code, pos_args) - explicit_kwargs = [:base_name, :variable_type, :set, :set_string_name] - _add_kw_args(build_code, kw_args; exclude = [_INFO_KWARGS; explicit_kwargs]) - name_code = _name_call(base_name, index_vars) + _add_positional_args(build_code, args) + _add_keyword_args( + build_code, + kwargs; + exclude = vcat( + _INFO_KWARGS, + [:base_name, :container, :variable_type, :set, :set_string_name], + ), + ) code = if set === nothing # This is for calls like: # @variable(model, x) @@ -283,8 +285,8 @@ macro variable(input_args...) index_vars, indices, quote - name = $set_string_name_kw ? $name_code : "" - add_variable($model, model_convert($model, $build_code), name) + variable = model_convert($model, $build_code) + add_variable($model, variable, $name_expr) end, container, ) @@ -296,8 +298,7 @@ macro variable(input_args...) indices, quote build = build_variable($error_fn, $build_code, $set) - name = $set_string_name_kw ? $name_code : "" - add_variable($model, model_convert($model, build), name) + add_variable($model, model_convert($model, build), $name_expr) end, container, ) @@ -313,16 +314,15 @@ macro variable(input_args...) build_code, container, ) - name_code = Containers.container_code( + name_expr = Containers.container_code( index_vars, indices, - name_code, + name_expr, container, ) quote build = build_variable($error_fn, $build_code, $set) - name = $set_string_name_kw ? $name_code : "" - add_variable($model, model_convert($model, build), name) + add_variable($model, model_convert($model, build), $name_expr) end end return _finalize_macro( diff --git a/test/test_macros.jl b/test/test_macros.jl index 1c05b9e5f3d..d40ba967165 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -471,8 +471,8 @@ function test_extension_constraint_naming( @test name(cref) == "c1" cref = @constraint(model, c2, x == 0, base_name = "cat") @test name(cref) == "cat" - crefs = @constraint(model, [1:2], x == 0, base_name = "cat") - @test name.(crefs) == ["cat[1]", "cat[2]"] + crefs = @constraint(model, [i in 1:2], x == 0, base_name = "cat_$i") + @test name.(crefs) == ["cat_1[1]", "cat_2[2]"] @test_macro_throws ErrorException @constraint(model, c3[1:2]) @test_macro_throws ErrorException @constraint(model, "c"[1:2], x == 0) return @@ -1693,8 +1693,7 @@ function test_objective_not_enough_arguments() model = Model() @test_macro_throws( ErrorException( - "In `@objective(model, Min)`: needs three arguments: model, " * - "objective sense (Max or Min), and an expression.", + "In `@objective(model, Min)`: expected 3 positional arguments, got 2.", ), @objective(model, Min), ) @@ -1705,7 +1704,7 @@ function test_expression_not_enough_arguments() model = Model() @test_macro_throws( ErrorException( - "In `@expression(model)`: needs at least two arguments.", + "In `@expression(model)`: expected 2 or 3 positional arguments, got 1.", ), @expression(model), ) @@ -1717,13 +1716,25 @@ function test_expression_keyword_arguments() @variable(model, x) @test_macro_throws( ErrorException( - "In `@expression(model, x, foo = 1)`: unrecognized keyword argument", + "In `@expression(model, x, foo = 1)`: unsupported keyword argument `foo`.", ), @expression(model, x, foo = 1), ) return end +function test_objective_keyword_arguments() + model = Model() + @variable(model, x) + @test_macro_throws( + ErrorException( + "In `@objective(model, Min, x, foo = 1)`: unsupported keyword argument `foo`.", + ), + @objective(model, Min, x, foo = 1), + ) + return +end + function test_build_constraint_invalid() model = Model() @variable(model, x) @@ -2157,4 +2168,14 @@ function test_bad_objective_sense() return end +function test_expression_container_kwarg() + model = Model() + @variable(model, x) + @expression(model, ex1[i in 1:2], i * x, container = DenseAxisArray) + @test ex1 isa Containers.DenseAxisArray + @expression(model, ex2[i in 1:2], i * x; container = DenseAxisArray) + @test ex2 isa Containers.DenseAxisArray + return +end + end # module diff --git a/test/test_variable.jl b/test/test_variable.jl index 22c8d951144..431494bcc87 100644 --- a/test/test_variable.jl +++ b/test/test_variable.jl @@ -624,7 +624,7 @@ function test_extension_variables_constrained_on_creation_errors( @test_macro_throws( ErrorException( "In `@variable(model, x[1:2], set = SecondOrderCone(), set = PSDCone())`: " * - "`set` keyword argument was given 2 times.", + "the keyword argument `set` was given multiple times.", ), @variable(model, x[1:2], set = SecondOrderCone(), set = PSDCone()), )