Skip to content

Commit

Permalink
Merge branch 'master' into as/optional-inbounds
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Jan 9, 2025
2 parents f9c8337 + 113aec7 commit ae5841a
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 29 deletions.
5 changes: 3 additions & 2 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ function MTK.HomotopyContinuationProblem(
return prob
end

function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing;
fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...)
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
end
transformation = MTK.PolynomialTransformation(sys)
if transformation isa MTK.NotPolynomialError
return transformation
end
result = MTK.transform_system(sys, transformation)
result = MTK.transform_system(sys, transformation; fraction_cancel_fn)
if result isa MTK.NotPolynomialError
return result
end
Expand Down
2 changes: 1 addition & 1 deletion src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ function generate_var!(dict, a, b, varclass, mod;
vd isa Vector && (vd = first(vd))
vd[a] = Dict{Symbol, Any}()
var = if indices === nothing
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Any}, type})(iv)
first(@variables $a(iv)::type)
else
vd[a][:size] = Tuple(lastindex.(indices))
first(@variables $a(iv)[indices...]::type)
Expand Down
63 changes: 44 additions & 19 deletions src/systems/nonlinear/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a
`PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot
be transformed.
"""
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation)
function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation;
fraction_cancel_fn = simplify_fractions)
subrules = transformation.substitution_rules
dvs = unknowns(sys)
eqs = full_equations(sys)
Expand All @@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
return NotPolynomialError(
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)
end
num, den = handle_rational_polynomials(t, new_dvs)
num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn)
# make factors different elements, otherwise the nonzero factors artificially
# inflate the error of the zero factor.
if iscall(den) && operation(den) == *
Expand Down Expand Up @@ -492,43 +493,67 @@ $(TYPEDSIGNATURES)
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.
Keyword arguments:
- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns
a simplified symbolic quantity with common factors in the numerator and denominator are
cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to
`nothing` to improve performance on large polynomials at the cost of avoiding non-trivial
cancellation.
"""
function handle_rational_polynomials(x, wrt)
function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn)
n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn)
num, den = n1 * d2, d1 * n2
elseif (op == +) || (op == -)
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
if op == -
args[2] = -args[2]
end
for arg in args
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num = num * d + n * den
den *= d
end
elseif op == ^
base, pow = args
num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn)
num ^= pow
den ^= pow
elseif op == *
num = 1
den = 1
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn)
num *= n
den *= d
end
else
return x, 1
error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.")
end

if fraction_cancel_fn !== nothing
expr = fraction_cancel_fn(num / den)
if iscall(expr) && operation(expr) == /
num, den = arguments(expr)
else
num, den = expr, 1
end
end

# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
Expand Down
11 changes: 7 additions & 4 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,16 @@ end

function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
parammap = DiffEqBase.NullParameters();
check_length = true, use_homotopy_continuation = true, kwargs...) where {iip}
check_length = true, use_homotopy_continuation = false, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
end
prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...)
if prob isa HomotopyContinuationProblem
return prob
if use_homotopy_continuation
prob = safe_HomotopyContinuationProblem(
sys, u0map, parammap; check_length, kwargs...)
if prob isa HomotopyContinuationProblem
return prob
end
end
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
check_length, kwargs...)
Expand Down
33 changes: 31 additions & 2 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface
using SymbolicUtils
import ModelingToolkit as MTK
using LinearAlgebra
using Test
Expand Down Expand Up @@ -29,11 +30,13 @@ import HomotopyContinuation
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid)0.0 atol=1e-10

prob2 = NonlinearProblem(sys, u0)
prob2 = NonlinearProblem(sys, u0; use_homotopy_continuation = true)
@test prob2 isa HomotopyContinuationProblem
sol = solve(prob2; threading = false)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid)0.0 atol=1e-10

@test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem
end

struct Wrapper
Expand Down Expand Up @@ -217,7 +220,17 @@ end
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0)
@test_nowarn solve(prob; threading = false)
@test SciMLBase.successful_retcode(solve(prob; threading = false))
end

@testset "Rational function forced to common denominators" begin
@variables x = 1
@mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x])
prob = HomotopyContinuationProblem(sys, [])
@test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0)
sol = solve(prob; threading = false)
@test SciMLBase.successful_retcode(sol)
@test 1 / (1 + sol.u[1]) - sol.u[1]0.0 atol=1e-10
end
end

Expand All @@ -229,3 +242,19 @@ end
@test sol[x] 2.0
@test sol[y] sin(2.0)
end

@testset "`fraction_cancel_fn`" begin
@variables x = 1
@named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) /
(x - 4)^3])
sys = complete(sys)

@testset "`simplify_fractions`" begin
prob = HomotopyContinuationProblem(sys, [])
@test prob.denominator([0.0], parameter_values(prob)) [4.0]
end
@testset "`nothing`" begin
prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing)
@test sort(prob.denominator([0.0], parameter_values(prob))) [2.0, 4.0^3]
end
end
21 changes: 20 additions & 1 deletion test/model_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelingToolkit, Test
using ModelingToolkit, Symbolics, Test
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
get_systems, get_ps, getdefault, getname, readable_code,
scalarize, symtype, VariableDescription, RegularConnector,
Expand Down Expand Up @@ -990,3 +990,22 @@ struct CustomStruct end
@named sys = MyModel(p = CustomStruct())
@test ModelingToolkit.defaults(sys)[@nonamespace sys.p] == CustomStruct()
end

@testset "Variables are not callable symbolics" begin
@mtkmodel Example begin
@variables begin
x(t)
y(t)
end
@equations begin
x ~ y
end
end
@named ex = Example()
vars = Symbolics.get_variables(only(equations(ex)))
@test length(vars) == 2
for u in Symbolics.unwrap.(unknowns(ex))
@test !Symbolics.hasmetadata(u, Symbolics.CallWithParent)
@test any(isequal(u), vars)
end
end

0 comments on commit ae5841a

Please sign in to comment.