From f75fc133bed4e77b6ad7dcce74f8cf2fde86861a Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 6 Oct 2024 13:39:44 -0700 Subject: [PATCH] Added experimental support for MathOpt, and unified expression building across linear, constraint, routing, and MathOpt --- CHANGELOG.md | 4 + ext/or-tools/constraint.cpp | 35 ++++----- ext/or-tools/math_opt.cpp | 66 ++++++++++++++-- ext/or-tools/routing.cpp | 2 +- lib/or-tools.rb | 27 +++---- lib/or_tools/bool_var.rb | 9 --- lib/or_tools/comparison.rb | 17 ++--- lib/or_tools/comparison_operators.rb | 9 --- lib/or_tools/constant.rb | 29 +++---- lib/or_tools/cp_model.rb | 16 ++-- lib/or_tools/expression.rb | 92 +++++++++++++++++++++++ lib/or_tools/int_var.rb | 5 -- lib/or_tools/linear_constraint.rb | 43 ----------- lib/or_tools/linear_expr.rb | 108 --------------------------- lib/or_tools/math_opt/model.rb | 43 +++++++++++ lib/or_tools/math_opt/variable.rb | 15 ++++ lib/or_tools/mp_variable.rb | 13 ---- lib/or_tools/product.rb | 38 ++++++++++ lib/or_tools/product_cst.rb | 35 --------- lib/or_tools/sat_int_var.rb | 29 ------- lib/or_tools/sat_linear_expr.rb | 59 --------------- lib/or_tools/solver.rb | 34 ++++++--- lib/or_tools/utils.rb | 107 ++++++++++++++++++++++++++ lib/or_tools/variable.rb | 29 +++++++ test/expression_test.rb | 6 +- test/linear_test.rb | 13 ++-- test/math_opt_test.rb | 21 +++--- 27 files changed, 488 insertions(+), 416 deletions(-) delete mode 100644 lib/or_tools/bool_var.rb delete mode 100644 lib/or_tools/comparison_operators.rb create mode 100644 lib/or_tools/expression.rb delete mode 100644 lib/or_tools/int_var.rb delete mode 100644 lib/or_tools/linear_constraint.rb delete mode 100644 lib/or_tools/linear_expr.rb create mode 100644 lib/or_tools/math_opt/model.rb create mode 100644 lib/or_tools/math_opt/variable.rb delete mode 100644 lib/or_tools/mp_variable.rb create mode 100644 lib/or_tools/product.rb delete mode 100644 lib/or_tools/product_cst.rb delete mode 100644 lib/or_tools/sat_int_var.rb delete mode 100644 lib/or_tools/sat_linear_expr.rb create mode 100644 lib/or_tools/utils.rb create mode 100644 lib/or_tools/variable.rb diff --git a/CHANGELOG.md b/CHANGELOG.md index fd894ee..bae4a37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.14.0 (unreleased) + +- Added experimental support for `MathOpt` + ## 0.13.1 (2024-10-05) - Added binary installation for Debian 12 diff --git a/ext/or-tools/constraint.cpp b/ext/or-tools/constraint.cpp index d015efc..9d89498 100644 --- a/ext/or-tools/constraint.cpp +++ b/ext/or-tools/constraint.cpp @@ -44,32 +44,23 @@ namespace Rice::detail public: LinearExpr convert(VALUE v) { - Object x(v); LinearExpr expr; - if (x.respond_to("to_i")) { - expr = From_Ruby().convert(x.call("to_i").value()); - } else if (x.respond_to("vars")) { - Array vars = x.call("vars"); - for (const auto& v : vars) { - // TODO clean up - auto cvar = (Array) v; - Object var = cvar[0]; - auto coeff = From_Ruby().convert(cvar[1].value()); + Rice::Object utils = Rice::define_module("ORTools").const_get("Utils"); - if (var.is_a(rb_cBoolVar)) { - expr += From_Ruby().convert(var.value()) * coeff; - } else if (var.is_a(rb_cInteger)) { - expr += From_Ruby().convert(var.value()) * coeff; - } else { - expr += From_Ruby().convert(var.value()) * coeff; - } - } - } else { - if (x.is_a(rb_cBoolVar)) { - expr = From_Ruby().convert(x.value()); + Object x(v); + Rice::Hash coeffs = utils.call("index_expression", x); + + for (const auto& entry : coeffs) { + Object var = entry.key; + auto coeff = From_Ruby().convert(entry.value.value()); + + if (var.is_nil()) { + expr += coeff; + } else if (var.is_a(rb_cBoolVar)) { + expr += From_Ruby().convert(var.value()) * coeff; } else { - expr = From_Ruby().convert(x.value()); + expr += From_Ruby().convert(var.value()) * coeff; } } diff --git a/ext/or-tools/math_opt.cpp b/ext/or-tools/math_opt.cpp index ace71e1..a686743 100644 --- a/ext/or-tools/math_opt.cpp +++ b/ext/or-tools/math_opt.cpp @@ -5,21 +5,77 @@ #include "ext.h" -using operations_research::math_opt::BoundedLinearExpression; +using operations_research::math_opt::LinearConstraint; using operations_research::math_opt::Model; +using operations_research::math_opt::Solve; +using operations_research::math_opt::SolveArguments; +using operations_research::math_opt::SolveResult; +using operations_research::math_opt::SolverType; using operations_research::math_opt::Variable; void init_math_opt(Rice::Module& m) { auto mathopt = Rice::define_module_under(m, "MathOpt"); - Rice::define_class_under(mathopt, "Variable"); + Rice::define_class_under(mathopt, "Variable") + .define_method("id", &Variable::id) + .define_method( + "_eql?", + [](Variable& self, Variable &other) { + return (bool) (self == other); + }); + + Rice::define_class_under(mathopt, "LinearConstraint"); + + Rice::define_class_under(mathopt, "SolveResult") + .define_method( + "objective_value", + [](SolveResult& self) { + return self.objective_value(); + }) + .define_method( + "variable_values", + [](SolveResult& self) { + Rice::Hash map; + for (auto& [k, v] : self.variable_values()) { + map[k] = v; + } + return map; + }); Rice::define_class_under(mathopt, "Model") .define_constructor(Rice::Constructor()) .define_method("add_variable", &Model::AddContinuousVariable) .define_method( - "add_linear_constraint", - [](Model& self, const BoundedLinearExpression& bounded_expr, const std::string& name) { - self.AddLinearConstraint(bounded_expr, name); + "_add_linear_constraint", + [](Model& self) { + return self.AddLinearConstraint(); + }) + .define_method( + "_set_upper_bound", + [](Model& self, LinearConstraint constraint, double upper_bound) { + self.set_upper_bound(constraint, upper_bound); + }) + .define_method("_set_coefficient", &Model::set_coefficient) + .define_method( + "_set_objective_coefficient", + [](Model& self, Variable variable, double value) { + self.set_objective_coefficient(variable, value); + }) + .define_method("_clear_objective", &Model::clear_objective) + .define_method( + "_set_objective_offset", + [](Model& self, double value) { + self.set_objective_offset(value); + }) + .define_method( + "_set_maximize", + [](Model& self) { + self.set_maximize(); + }) + .define_method( + "_solve", + [](Model& self) { + SolveArguments args; + return *Solve(self, SolverType::kGlop, args); }); } diff --git a/ext/or-tools/routing.cpp b/ext/or-tools/routing.cpp index e1f5d4c..e30cbab 100644 --- a/ext/or-tools/routing.cpp +++ b/ext/or-tools/routing.cpp @@ -247,7 +247,7 @@ void init_routing(Rice::Module& m) { if (o.respond_to("left")) { operations_research::IntExpr* left(Rice::detail::From_Ruby().convert(o.call("left"))); operations_research::IntExpr* right(Rice::detail::From_Ruby().convert(o.call("right"))); - auto op = o.call("operator").to_s().str(); + auto op = o.call("op").to_s().str(); if (op == "==") { constraint = self.MakeEquality(left, right); } else if (op == "<=") { diff --git a/lib/or-tools.rb b/lib/or-tools.rb index e1415a2..2f1585a 100644 --- a/lib/or-tools.rb +++ b/lib/or-tools.rb @@ -2,38 +2,34 @@ require "or_tools/ext" # expressions +require_relative "or_tools/expression" require_relative "or_tools/comparison" -require_relative "or_tools/comparison_operators" +require_relative "or_tools/constant" +require_relative "or_tools/product" +require_relative "or_tools/variable" # bin packing require_relative "or_tools/knapsack_solver" # constraint -require_relative "or_tools/bool_var" require_relative "or_tools/cp_model" require_relative "or_tools/cp_solver" require_relative "or_tools/cp_solver_solution_callback" -require_relative "or_tools/sat_int_var" -require_relative "or_tools/sat_linear_expr" +require_relative "or_tools/objective_solution_printer" +require_relative "or_tools/var_array_solution_printer" +require_relative "or_tools/var_array_and_objective_solution_printer" # linear -require_relative "or_tools/linear_expr" -require_relative "or_tools/constant" -require_relative "or_tools/mp_variable" -require_relative "or_tools/linear_constraint" -require_relative "or_tools/product_cst" require_relative "or_tools/solver" +# math opt +require_relative "or_tools/math_opt/model" +require_relative "or_tools/math_opt/variable" + # routing -require_relative "or_tools/int_var" require_relative "or_tools/routing_index_manager" require_relative "or_tools/routing_model" -# solution printers -require_relative "or_tools/objective_solution_printer" -require_relative "or_tools/var_array_solution_printer" -require_relative "or_tools/var_array_and_objective_solution_printer" - # higher level interfaces require_relative "or_tools/basic_scheduler" require_relative "or_tools/seating" @@ -41,6 +37,7 @@ require_relative "or_tools/tsp" # modules +require_relative "or_tools/utils" require_relative "or_tools/version" module ORTools diff --git a/lib/or_tools/bool_var.rb b/lib/or_tools/bool_var.rb deleted file mode 100644 index b148239..0000000 --- a/lib/or_tools/bool_var.rb +++ /dev/null @@ -1,9 +0,0 @@ -module ORTools - class BoolVar - include ComparisonOperators - - def *(other) - SatLinearExpr.new([[self, other]]) - end - end -end diff --git a/lib/or_tools/comparison.rb b/lib/or_tools/comparison.rb index 45d30bc..62178a2 100644 --- a/lib/or_tools/comparison.rb +++ b/lib/or_tools/comparison.rb @@ -1,19 +1,16 @@ module ORTools class Comparison - attr_reader :operator, :left, :right + attr_reader :left, :op, :right - def initialize(operator, left, right) - @operator = operator - @left = left - @right = right - end - - def to_s - "#{left} #{operator} #{right}" + def initialize(left, op, right) + @left = Expression.to_expression(left) + @op = op + @right = Expression.to_expression(right) end def inspect - "#<#{self.class.name} #{to_s}>" + "#{@left.inspect} #{@op} #{@right.inspect}" end + alias_method :to_s, :inspect end end diff --git a/lib/or_tools/comparison_operators.rb b/lib/or_tools/comparison_operators.rb deleted file mode 100644 index f47128a..0000000 --- a/lib/or_tools/comparison_operators.rb +++ /dev/null @@ -1,9 +0,0 @@ -module ORTools - module ComparisonOperators - ["==", "!=", ">", ">=", "<", "<="].each do |operator| - define_method(operator) do |other| - Comparison.new(operator, self, other) - end - end - end -end diff --git a/lib/or_tools/constant.rb b/lib/or_tools/constant.rb index fbeebe0..1feed34 100644 --- a/lib/or_tools/constant.rb +++ b/lib/or_tools/constant.rb @@ -1,23 +1,26 @@ module ORTools - class Constant < LinearExpr - def initialize(val) - @val = val + class Constant < Expression + attr_reader :value + + def initialize(value) + @value = value end - def to_s - @val.to_s + # simplify Ruby sum + def +(other) + @value == 0 ? other : super end - def add_self_to_coeff_map_or_stack(coeffs, multiplier, stack) - coeffs[OFFSET_KEY] += @val * multiplier + def inspect + @value.to_s end - end - class FakeMPVariableRepresentingTheConstantOffset - def solution_value - 1 + def -@ + Constant.new(-value) end - end - OFFSET_KEY = FakeMPVariableRepresentingTheConstantOffset.new + def vars + @vars ||= [] + end + end end diff --git a/lib/or_tools/cp_model.rb b/lib/or_tools/cp_model.rb index 10c52c5..ca642a8 100644 --- a/lib/or_tools/cp_model.rb +++ b/lib/or_tools/cp_model.rb @@ -4,18 +4,18 @@ def add(comparison) case comparison when Comparison method_name = - case comparison.operator - when "==" + case comparison.op + when :== :add_equality - when "!=" + when :!= :add_not_equal - when ">" + when :> :add_greater_than - when ">=" + when :>= :add_greater_or_equal - when "<" + when :< :add_less_than - when "<=" + when :<= :add_less_or_equal else raise ArgumentError, "Unknown operator: #{comparison.operator}" @@ -32,7 +32,7 @@ def add(comparison) end def sum(arr) - arr.sum(SatLinearExpr.new) + Expression.new(arr) end def inspect diff --git a/lib/or_tools/expression.rb b/lib/or_tools/expression.rb new file mode 100644 index 0000000..0c97046 --- /dev/null +++ b/lib/or_tools/expression.rb @@ -0,0 +1,92 @@ +module ORTools + module ExpressionMethods + attr_reader :parts + + def +(other) + Expression.new((parts || [self]) + [Expression.to_expression(other)]) + end + + def -(other) + Expression.new((parts || [self]) + [-Expression.to_expression(other)]) + end + + def -@ + -1 * self + end + + def *(other) + Expression.new([Product.new(self, Expression.to_expression(other))]) + end + + def >(other) + Comparison.new(self, :>, other) + end + + def <(other) + Comparison.new(self, :<, other) + end + + def >=(other) + Comparison.new(self, :>=, other) + end + + def <=(other) + Comparison.new(self, :<=, other) + end + + def ==(other) + Comparison.new(self, :==, other) + end + + def !=(other) + Comparison.new(self, :!=, other) + end + + def inspect + @parts.reject { |v| v.is_a?(Constant) && v.value == 0 }.map(&:inspect).join(" + ").gsub(" + -", " - ") + end + + def to_s + inspect + end + + # keep order + def coerce(other) + if other.is_a?(Numeric) + [Constant.new(other), self] + else + raise TypeError, "#{self.class} can't be coerced into #{other.class}" + end + end + + def value + values = parts.map(&:value) + return nil if values.any?(&:nil?) + + values.sum + end + + def vars + @vars ||= @parts.flat_map(&:vars) + end + end + + class Expression + include ExpressionMethods + + def initialize(parts = []) + @parts = parts + end + + # private + def self.to_expression(other) + if other.is_a?(Numeric) + Constant.new(other) + elsif other.is_a?(Variable) || other.is_a?(Expression) + other + else + raise TypeError, "can't cast #{other.class.name} to Expression" + end + end + end +end diff --git a/lib/or_tools/int_var.rb b/lib/or_tools/int_var.rb deleted file mode 100644 index 850e188..0000000 --- a/lib/or_tools/int_var.rb +++ /dev/null @@ -1,5 +0,0 @@ -module ORTools - class IntVar - include ComparisonOperators - end -end diff --git a/lib/or_tools/linear_constraint.rb b/lib/or_tools/linear_constraint.rb deleted file mode 100644 index fadf5d9..0000000 --- a/lib/or_tools/linear_constraint.rb +++ /dev/null @@ -1,43 +0,0 @@ -module ORTools - class LinearConstraint - attr_reader :expr, :lb, :ub - - def initialize(expr, lb, ub) - @expr = expr - @lb = lb - @ub = ub - end - - def to_s - if @lb > -Float::INFINITY && @ub < Float::INFINITY - if @lb == @ub - "#{@expr} == #{@lb}" - else - "#{@lb} <= #{@expr} <= #{@ub}" - end - elsif @lb > -Float::INFINITY - "#{@expr} >= #{@lb}" - elsif @ub < Float::INFINITY - "#{@expr} <= #{@ub}" - else - "Trivial inequality (always true)" - end - end - - def inspect - "#<#{self.class.name} #{to_s}>" - end - - def extract - coeffs = @expr.coeffs - constant = coeffs.delete(OFFSET_KEY) || 0.0 - if @lb > -Float::INFINITY - lb = @lb - constant - end - if @ub < Float::INFINITY - ub = @ub - constant - end - [coeffs, lb, ub] - end - end -end diff --git a/lib/or_tools/linear_expr.rb b/lib/or_tools/linear_expr.rb deleted file mode 100644 index c7282b2..0000000 --- a/lib/or_tools/linear_expr.rb +++ /dev/null @@ -1,108 +0,0 @@ -module ORTools - module LinearExprMethods - def solution_value - coeffs.sum { |var, coeff| var.solution_value * coeff } - end - - def coeffs - coeffs = Hash.new(0.0) - stack = [[1.0, self]] - while stack.any? - current_multiplier, current_expression = stack.pop - - current_expression.add_self_to_coeff_map_or_stack(coeffs, current_multiplier, stack) - end - coeffs - end - - def +(expr) - LinearExpr.new([self, expr]) - end - - def -(expr) - LinearExpr.new([self, -expr]) - end - - def *(other) - if is_a?(Constant) - ProductCst.new(other, @val) - else - ProductCst.new(self, other) - end - end - - def /(cst) - ProductCst.new(self, 1.0 / other) - end - - def -@ - ProductCst.new(self, -1) - end - - def ==(arg) - if arg.is_a?(Numeric) - LinearConstraint.new(self, arg, arg) - else - LinearConstraint.new(self - arg, 0.0, 0.0) - end - end - - def >=(arg) - if arg.is_a?(Numeric) - LinearConstraint.new(self, arg, Float::INFINITY) - else - LinearConstraint.new(self - arg, 0.0, Float::INFINITY) - end - end - - def <=(arg) - if arg.is_a?(Numeric) - LinearConstraint.new(self, -Float::INFINITY, arg) - else - LinearConstraint.new(self - arg, -Float::INFINITY, 0.0) - end - end - - def inspect - "#<#{self.class.name} #{to_s}>" - end - - def coerce(other) - if other.is_a?(Numeric) - [Constant.new(other), self] - else - raise TypeError, "#{self.class} can't be coerced into #{other.class}" - end - end - end - - class LinearExpr - include LinearExprMethods - - attr_reader :array - - def initialize(array = []) - @array = array.map { |v| cast_to_lin_exp(v) } - end - - def add_self_to_coeff_map_or_stack(coeffs, multiplier, stack) - @array.reverse_each do |arg| - stack << [multiplier, arg] - end - end - - def cast_to_lin_exp(v) - v.is_a?(Numeric) ? Constant.new(v) : v - end - - def to_s - if @array.empty? - "(empty)" - else - "#{@array.map(&:to_s).reject { |v| v == "0" }.join(" + ")}".gsub(" + -", " - ") - end - end - end - - SumArray = LinearExpr -end diff --git a/lib/or_tools/math_opt/model.rb b/lib/or_tools/math_opt/model.rb new file mode 100644 index 0000000..6d93903 --- /dev/null +++ b/lib/or_tools/math_opt/model.rb @@ -0,0 +1,43 @@ +module ORTools + module MathOpt + class Model + def add_linear_constraint(expr) + left, op, const = Utils.index_constraint(expr) + + constraint = _add_linear_constraint + left.each do |var, c| + _set_coefficient(constraint, var, c) + end + case op + when :<= + _set_upper_bound(constraint, const) + else + raise "todo: #{op}" + end + nil + end + + def maximize(objective) + set_objective(objective) + _set_maximize + end + + def solve + _solve + end + + private + + def set_objective(objective) + objective = Expression.to_expression(objective) + coeffs = Utils.index_expression(objective) + offset = coeffs.delete(nil) + + objective.set_offset(offset) if offset + coeffs.each do |var, c| + _set_objective_coefficient(var, c) + end + end + end + end +end diff --git a/lib/or_tools/math_opt/variable.rb b/lib/or_tools/math_opt/variable.rb new file mode 100644 index 0000000..20c6bf5 --- /dev/null +++ b/lib/or_tools/math_opt/variable.rb @@ -0,0 +1,15 @@ +module ORTools + module MathOpt + class Variable + include ORTools::Variable + + def eql?(other) + other.is_a?(self.class) && _eql?(other) + end + + def hash + id.hash + end + end + end +end diff --git a/lib/or_tools/mp_variable.rb b/lib/or_tools/mp_variable.rb deleted file mode 100644 index e8572f7..0000000 --- a/lib/or_tools/mp_variable.rb +++ /dev/null @@ -1,13 +0,0 @@ -module ORTools - class MPVariable - include LinearExprMethods - - def add_self_to_coeff_map_or_stack(coeffs, multiplier, stack) - coeffs[self] += multiplier - end - - def to_s - name - end - end -end diff --git a/lib/or_tools/product.rb b/lib/or_tools/product.rb new file mode 100644 index 0000000..a2393ea --- /dev/null +++ b/lib/or_tools/product.rb @@ -0,0 +1,38 @@ +module ORTools + class Product < Expression + attr_reader :left, :right + + def initialize(left, right) + @left = left + @right = right + end + + def inspect + if @left.is_a?(Constant) && @right.is_a?(Variable) && left.value == -1 + "-#{inspect_part(@right)}" + else + "#{inspect_part(@left)} * #{inspect_part(@right)}" + end + end + + def value + return nil if left.value.nil? || right.value.nil? + + left.value * right.value + end + + def vars + @vars ||= (@left.vars + @right.vars).uniq + end + + private + + def inspect_part(var) + if var.instance_of?(Expression) + "(#{var.inspect})" + else + var.inspect + end + end + end +end diff --git a/lib/or_tools/product_cst.rb b/lib/or_tools/product_cst.rb deleted file mode 100644 index 48e1d94..0000000 --- a/lib/or_tools/product_cst.rb +++ /dev/null @@ -1,35 +0,0 @@ -module ORTools - class ProductCst < LinearExpr - attr_reader :expr, :coef - - def initialize(expr, coef) - @expr = cast_to_lin_exp(expr) - # TODO improve message - raise TypeError, "expected numeric" unless coef.is_a?(Numeric) - @coef = coef - end - - def to_s - if @coef == -1 - "-#{@expr}" - else - expr = @expr.to_s - if expr.include?("+") || expr.include?("-") - expr = "(#{expr})" - end - "#{@coef} * #{expr}" - end - end - - def add_self_to_coeff_map_or_stack(coeffs, multiplier, stack) - current_multiplier = multiplier * @coef - if current_multiplier - stack << [current_multiplier, @expr] - end - end - - def cast_to_lin_exp(v) - v.is_a?(Numeric) ? Constant.new(v) : v - end - end -end diff --git a/lib/or_tools/sat_int_var.rb b/lib/or_tools/sat_int_var.rb deleted file mode 100644 index 3e6f5e3..0000000 --- a/lib/or_tools/sat_int_var.rb +++ /dev/null @@ -1,29 +0,0 @@ -module ORTools - class SatIntVar - include ComparisonOperators - - def *(other) - SatLinearExpr.new([[self, other]]) - end - - def +(other) - SatLinearExpr.new([[self, 1], [other, 1]]) - end - - def -(other) - SatLinearExpr.new([[self, 1], [-other, 1]]) - end - - def -@ - SatLinearExpr.new([[self, -1]]) - end - - def to_s - name - end - - def inspect - "#<#{self.class.name} #{to_s}>" - end - end -end diff --git a/lib/or_tools/sat_linear_expr.rb b/lib/or_tools/sat_linear_expr.rb deleted file mode 100644 index 154bc15..0000000 --- a/lib/or_tools/sat_linear_expr.rb +++ /dev/null @@ -1,59 +0,0 @@ -module ORTools - class SatLinearExpr - include ComparisonOperators - - attr_reader :vars - - def initialize(vars = []) - @vars = vars - end - - def +(other) - add(other, 1) - end - - def -(other) - add(other, -1) - end - - def *(other) - if vars.size == 1 - self.class.new([[vars[0][0], vars[0][1] * other]]) - else - raise ArgumentError, "Multiplication not allowed here" - end - end - - def to_s - vars.map do |v| - k = v[0] - k = k.respond_to?(:name) ? k.name : k.to_s - if v[1] == 1 - k - else - "#{k} * #{v[1]}" - end - end.join(" + ").sub(" + -", " - ") - end - - def inspect - "#<#{self.class.name} #{to_s}>" - end - - private - - def add(other, sign) - other_vars = - case other - when SatLinearExpr - other.vars - when BoolVar, SatIntVar, Integer - [[other, 1]] - else - raise ArgumentError, "Unsupported type: #{other.class.name}" - end - - self.class.new(vars + other_vars.map { |a, b| [a, sign * b] }) - end - end -end diff --git a/lib/or_tools/solver.rb b/lib/or_tools/solver.rb index 4ad38b2..296da91 100644 --- a/lib/or_tools/solver.rb +++ b/lib/or_tools/solver.rb @@ -1,17 +1,31 @@ module ORTools class Solver def sum(arr) - LinearExpr.new(arr) + Expression.new(arr) end def add(expr) - coeffs, lb, ub = expr.extract + left, op, const = Utils.index_constraint(expr) + + case op + when :<= + lb = -infinity + ub = const + when :>= + lb = const + ub = infinity + when :== + lb = const + ub = const + else + raise "todo: #{op}" + end - constraint = self.constraint(lb || -infinity, ub || infinity) - coeffs.each do |v, c| - constraint.set_coefficient(v, c.to_f) + constraint = constraint(lb, ub) + left.each do |var, c| + constraint.set_coefficient(var, c) end - constraint + nil end def maximize(expr) @@ -27,13 +41,13 @@ def minimize(expr) private def set_objective(expr) - coeffs = expr.coeffs - offset = coeffs.delete(OFFSET_KEY) + coeffs = Utils.index_expression(expr, check_linear: true) + offset = coeffs.delete(nil) objective.clear objective.set_offset(offset) if offset - coeffs.each do |v, c| - objective.set_coefficient(v, c) + coeffs.each do |var, c| + objective.set_coefficient(var, c) end end diff --git a/lib/or_tools/utils.rb b/lib/or_tools/utils.rb new file mode 100644 index 0000000..2e847d4 --- /dev/null +++ b/lib/or_tools/utils.rb @@ -0,0 +1,107 @@ +module ORTools + module Utils + def self.index_constraint(constraint) + raise ArgumentError, "Expected Comparison" unless constraint.is_a?(Comparison) + + left = index_expression(constraint.left, check_linear: true) + right = index_expression(constraint.right, check_linear: true) + + const = right.delete(nil).to_f - left.delete(nil).to_f + right.each do |k, v| + left[k] -= v + end + + [left, constraint.op, const] + end + + def self.index_expression(expression, check_linear: false) + vars = Hash.new(0) + case expression + when Numeric + vars[nil] += expression + when Constant + vars[nil] += expression.value + when Variable + vars[expression] += 1 + when Product + if check_linear && expression.left.vars.any? && expression.right.vars.any? + raise ArgumentError, "Nonlinear" + end + vars = index_product(expression.left, expression.right) + when Expression + expression.parts.each do |part| + index_expression(part, check_linear: check_linear).each do |k, v| + vars[k] += v + end + end + else + raise TypeError, "Unsupported type" + end + vars + end + + def self.index_product(left, right) + # normalize + types = [Constant, Variable, Product, Expression] + if types.index { |t| left.is_a?(t) } > types.index { |t| right.is_a?(t) } + left, right = right, left + end + + vars = Hash.new(0) + case left + when Constant + vars = index_expression(right) + vars.transform_values! { |v| v * left.value } + when Variable + case right + when Variable + vars[quad_key(left, right)] = 1 + when Product + index_expression(right).each do |k, v| + case k + when Array + raise Error, "Non-quadratic" + when Variable + vars[quad_key(left, k)] = v + else # nil + raise "Bug?" + end + end + else + right.parts.each do |part| + index_product(left, part).each do |k, v| + vars[k] += v + end + end + end + when Product + index_expression(left).each do |lk, lv| + index_expression(right).each do |rk, rv| + if lk.is_a?(Variable) && rk.is_a?(Variable) + vars[quad_key(lk, rk)] = lv * rv + else + raise "todo" + end + end + end + else # Expression + left.parts.each do |lp| + right.parts.each do |rp| + index_product(lp, rp).each do |k, v| + vars[k] += v + end + end + end + end + vars + end + + def self.quad_key(left, right) + if left.object_id <= right.object_id + [left, right] + else + [right, left] + end + end + end +end diff --git a/lib/or_tools/variable.rb b/lib/or_tools/variable.rb new file mode 100644 index 0000000..9c3a934 --- /dev/null +++ b/lib/or_tools/variable.rb @@ -0,0 +1,29 @@ +module ORTools + module Variable + include ExpressionMethods + + def inspect + name + end + + def vars + @vars ||= [self] + end + end + + class MPVariable + include Variable + end + + class SatIntVar + include Variable + end + + class BoolVar + include Variable + end + + class IntVar + include Variable + end +end diff --git a/test/expression_test.rb b/test/expression_test.rb index b6d1c8c..3ace1b1 100644 --- a/test/expression_test.rb +++ b/test/expression_test.rb @@ -126,8 +126,8 @@ def test_inspect y = model.new_int_var(0, 1, "y") z = model.new_int_var(0, 1, "z") - assert_equal "#", x.inspect - assert_equal "#", (x + y).inspect - assert_equal "#", (x + y == z).inspect + assert_equal "x", x.inspect + assert_equal "x + y", (x + y).inspect + assert_equal "x + y == z", (x + y == z).inspect end end diff --git a/test/linear_test.rb b/test/linear_test.rb index a777f10..350ef0b 100644 --- a/test/linear_test.rb +++ b/test/linear_test.rb @@ -32,10 +32,10 @@ def test_type_error solver = ORTools::Solver.new("LinearProgrammingExample", :glop) x = solver.num_var(0, solver.infinity, "x") - error = assert_raises(TypeError) do - x * x + error = assert_raises(ArgumentError) do + solver.maximize(x * x) end - assert_equal "expected numeric", error.message + assert_equal "Nonlinear", error.message end def test_to_s @@ -63,10 +63,9 @@ def test_inspect solver = ORTools::Solver.new("GLOP") x = solver.num_var(0, solver.infinity, "x") - assert_equal "#", x.inspect - assert_equal "#", (x + 1).inspect - assert_equal "#", ORTools::LinearExpr.new.inspect - assert_equal "#", (x + 1 == 1).inspect + assert_equal "x", x.inspect + assert_equal "x + 1", (x + 1).inspect + assert_equal "x + 1 == 1", (x + 1 == 1).inspect end def test_offset diff --git a/test/math_opt_test.rb b/test/math_opt_test.rb index 8db4dcf..ddec27b 100644 --- a/test/math_opt_test.rb +++ b/test/math_opt_test.rb @@ -6,22 +6,19 @@ def test_basic model = ORTools::MathOpt::Model.new("getting_started_lp") x = model.add_variable(-1.0, 1.5, "x") y = model.add_variable(0.0, 1.0, "y") - - skip "todo" - model.add_linear_constraint(x + y <= 1.5) model.maximize(x + 2 * y) - params = ORTools::MathOpt::SolveParameters.new(enable_output: true) + result = model.solve - result = mathopt.solve(model, :glop, params) - if result.termination.reason != :optimal - raise RuntimeError, "model failed to solve: #{result.termination}" - end + puts "Objective value: #{result.objective_value}" + puts "x: #{result.variable_values[x]}" + puts "y: #{result.variable_values[y]}" - puts "MathOpt solve succeeded" - puts "Objective value:", result.objective_value - puts "x:", result.variable_values[x] - puts "y:", result.variable_values[y] + assert_output <<~EOS + Objective value: 2.5 + x: 0.5 + y: 1.0 + EOS end end