Skip to content

Commit

Permalink
Add a function to check if square A is needed
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent b4ea93f commit 6c1a43d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
6 changes: 2 additions & 4 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ import PrecompileTools
for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

# precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
# PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
# DON'T MERGE
precompile_algs = ()
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)

for alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand Down
4 changes: 1 addition & 3 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob

# Use QR if the user did not specify a linear solver
if alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
alg.linsolve isa FastQRFactorization
if !needs_square_A(alg.linsolve) && !(u isa Number)
linsolve_with_JᵀJ = Val(false)

Check warning on line 86 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 88 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L88

Added line #L88 was not covered by tests
Expand Down
4 changes: 1 addition & 3 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)

# Use QR if the user did not specify a linear solver
if (alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
alg.linsolve isa FastQRFactorization) && !(u isa Number)
if !needs_square_A(alg.linsolve) && !(u isa Number)
linsolve_with_JᵀJ = Val(false)

Check warning on line 168 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L167-L168

Added lines #L167 - L168 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 170 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L170

Added line #L170 was not covered by tests
Expand Down
23 changes: 23 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,26 @@ function _try_factorize_and_check_singular!(linsolve, X)
return _issingular(X), false

Check warning on line 256 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L256

Added line #L256 was not covered by tests
end
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

Check warning on line 258 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L258

Added line #L258 was not covered by tests

# Needs Square Matrix
"""
needs_square_A(alg)
Returns `true` if the algorithm requires a square matrix.
"""
needs_square_A(::Nothing) = false
function needs_square_A(alg)
try
A = [1.0 2.0;

Check warning on line 269 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L266-L269

Added lines #L266 - L269 were not covered by tests
3.0 4.0;
5.0 6.0]
b = ones(Float64, 3)
solve(LinearProblem(A, b), alg)
return false

Check warning on line 274 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L272-L274

Added lines #L272 - L274 were not covered by tests
catch err
return true

Check warning on line 276 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L276

Added line #L276 was not covered by tests
end
end
for alg in (:QRFactorization, :FastQRFactorization)
@eval needs_square_A(::$(alg)) = false

Check warning on line 280 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L280

Added line #L280 was not covered by tests
end

0 comments on commit 6c1a43d

Please sign in to comment.