diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index cb1c9ea86..5e1e61f3c 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -30,8 +30,7 @@ jobs:
- "part2"
- "part3"
julia-version:
- - "1.6"
- - "1.8"
+ - "1.10"
- "1"
os:
- ubuntu-latest
diff --git a/CITATION.md b/CITATION.md
index 517ad4804..7272f07a3 100644
--- a/CITATION.md
+++ b/CITATION.md
@@ -30,3 +30,17 @@ To cite symbolic distillation of neural networks, the following BibTeX entry can
primaryClass={cs.LG}
}
```
+
+To cite LaSR, please use the following BibTeX entry:
+
+```bibtex
+@misc{grayeli2024symbolicregressionlearnedconcept,
+ title={Symbolic Regression with a Learned Concept Library},
+ author={Arya Grayeli and Atharva Sehgal and Omar Costilla-Reyes and Miles Cranmer and Swarat Chaudhuri},
+ year={2024},
+ eprint={2409.09359},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2409.09359},
+}
+```
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index d9765a717..bc4879131 100644
--- a/Project.toml
+++ b/Project.toml
@@ -52,7 +52,7 @@ Distributed = "<0.0.1, 1"
DynamicExpressions = "0.19.3"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14, 1"
Enzyme = "0.12"
-JSON = "0.21.4"
+JSON = "0.21"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.10, 0.11"
@@ -64,14 +64,14 @@ Pkg = "<0.0.1, 1"
PrecompileTools = "1"
Printf = "<0.0.1, 1"
ProgressBars = "~1.4, ~1.5"
-PromptingTools = "0.54.0"
+PromptingTools = "0.53, 0.54"
Random = "<0.0.1, 1"
Reexport = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
TOML = "<0.0.1, 1"
-julia = "1.6"
+julia = "1.10"
[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
diff --git a/README.md b/README.md
index 144674c5e..6cb0bc4b3 100644
--- a/README.md
+++ b/README.md
@@ -1,30 +1,46 @@
-LaSR.jl accelerates the search for symbolic expressions using library learning.
+LibraryAugmentedSymbolicRegression.jl (LaSR.jl) accelerates the search for symbolic expressions using library learning.
| Latest release | Website | Forums | Paper |
| :---: | :---: | :---: | :---: |
-| [![version](https://juliahub.com/docs/LaSR/version.svg)](https://juliahub.com/ui/Packages/LaSR/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://trishullab.github.io/lasr-web/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/trishullab/LaSR.jl/discussions) | [![Paper](https://img.shields.io/badge/arXiv-????.?????-b31b1b)](https://atharvas.net/static/lasr.pdf) |
+| [![version](https://juliahub.com/docs/LibraryAugmentedSymbolicRegression/version.svg)](https://juliahub.com/ui/Packages/LibraryAugmentedSymbolicRegression/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://trishullab.github.io/lasr-web/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/discussions) | [![Paper](https://img.shields.io/badge/arXiv-2409.09359-b31b1b)](https://arxiv.org/abs/2409.09359) |
| Build status | Coverage |
| :---: | :---: |
-| [![CI](https://github.com/trishullab/LaSR.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/trishullab/LaSR.jl/badge.svg?branch=master)](https://coveralls.io/github/trishullab/LaSR.jl?branch=master) |
+| [![CI](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/trishullab/LibraryAugmentedSymbolicRegression.jl/badge.svg?branch=master)](https://coveralls.io/github/trishullab/LibraryAugmentedSymbolicRegression.jl?branch=master) |
LaSR is integrated with [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). Check out [PySR](https://github.com/MilesCranmer/PySR) for
a Python frontend.
-[Cite this software](https://arxiv.org/abs/????.?????)
+[Cite this software](https://arxiv.org/abs/2409.09359)
**Contents**:
+- [Benchmarking](#benchmarking)
- [Quickstart](#quickstart)
- [Organization](#organization)
- [LLM Utilities](#llm-utilities)
+## Benchmarking
+
+If you'd like to compare with LaSR, we've archived the code used in the paper in the `lasr-experiments` branch. Clone this repository and run:
+```bash
+$ git switch lasr-experiments
+```
+to switch to the branch, then follow the instructions in its README to reproduce our results. This branch contains the code for evaluating LaSR on the following:
+
+- [x] Feynman Equations dataset
+- [x] Synthetic equations dataset
+  - [x] along with the generation code
+- [x] Bigbench experiments
+  - [x] along with the evaluation code
+
+
## Quickstart
Install in Julia with:
@@ -34,7 +50,7 @@ using Pkg
Pkg.add("LibraryAugmentedSymbolicRegression")
```
-LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). The easiest way to use LaSR.jl
+LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). The easiest way to use LibraryAugmentedSymbolicRegression.jl
is with [MLJ](https://github.com/alan-turing-institute/MLJ.jl).
Let's see an example:
@@ -105,173 +121,39 @@ where here we choose to evaluate the second equation.
For fitting multiple outputs, one can use `MultitargetLaSRRegressor`
(and pass an array of indices to `idx` in `predict` for selecting specific equations).
-For a full list of options available to each regressor, see the [API page](https://astroautomata.com/LaSR.jl/dev/api/).
+For a full list of options available to each regressor, see the [API page](https://astroautomata.com/LibraryAugmentedSymbolicRegression.jl/dev/api/).
### LLM Options
LaSR uses PromptingTools.jl for zero-shot prompting. To change the prompting options, pass an `LLMOptions` object to the `LaSRRegressor` constructor. The options available are:
```julia
llm_options = LLMOptions(
- ...
+ active=true, # Whether to use LLM inference or not
+ weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2), # Probabilities of using LLM for mutation, crossover, and random generation
+ num_pareto_context=5, # Number of equations to sample from the Pareto frontier for summarization.
+ prompt_evol=true, # Whether to evolve natural language concepts through LLM calls.
+ prompt_concepts=true, # Whether to use natural language concepts in the search.
+    api_key="token-abc123", # API key for an OpenAI-compatible API server.
+ model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use.
+ api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to server.
+ http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), # Keyword arguments passed to HTTP requests.
+ llm_recorder_dir="lasr_runs/debug_0", # Directory to log LLM interactions.
+ llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts.
+ var_order=nothing, # Dict(variable_name => new_name).
+ idea_threshold=30 # Number of concepts to keep track of.
)
```
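+
+For example, the resulting `llm_options` can then be handed to the regressor alongside the usual search parameters. Below is a minimal sketch (assuming the constructor keyword is named `llm_options`, and that `LaSRRegressor` mirrors the MLJ workflow from the Quickstart; the data here is a toy example):
+
+```julia
+using LibraryAugmentedSymbolicRegression: LaSRRegressor
+using MLJ: machine, fit!, report
+
+# Toy data: y = 2 cos(x2) + x1^2 - 2
+x1, x2 = randn(100), randn(100)
+X = (; x1, x2)
+y = 2 .* cos.(x2) .+ x1 .^ 2 .- 2
+
+model = LaSRRegressor(;
+    niterations=40,
+    binary_operators=[+, -, *],
+    unary_operators=[cos],
+    llm_options=llm_options,  # assumed keyword name for the object constructed above
+)
+mach = machine(model, X, y)
+fit!(mach)
+report(mach)
+```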
-
## Organization
-LaSR.jl development is kept independent from the main codebase. However, to ensure LaSR can be used easily, it is integrated into SymbolicRegression.jl via the `ext/SymbolicRegressionLaSRExt` extension module. This, in turn, is loaded into PySR. This cartoon summarizes the interaction between the different packages:
-
-![LaSR.jl organization](https://raw.githubusercontent.com/trishullab/lasr-web/main/static/lasr-code-interactions.svg)
-
-## Code structure
-
-LaSR.jl is organized roughly as follows.
-Rounded rectangles indicate objects, and rectangles indicate functions.
-
-> (if you can't see this diagram being rendered, try pasting it into [mermaid-js.github.io/mermaid-live-editor](https://mermaid-js.github.io/mermaid-live-editor))
-
-```mermaid
-flowchart TB
- op([Options])
- d([Dataset])
- op --> ES
- d --> ES
- subgraph ES[equation_search]
- direction TB
- IP[sr_spawner]
- IP --> p1
- IP --> p2
- subgraph p1[Thread 1]
- direction LR
- pop1([Population])
- pop1 --> src[s_r_cycle]
- src --> opt[optimize_and_simplify_population]
- opt --> pop1
- end
- subgraph p2[Thread 2]
- direction LR
- pop2([Population])
- pop2 --> src2[s_r_cycle]
- src2 --> opt2[optimize_and_simplify_population]
- opt2 --> pop2
- end
- pop1 --> hof
- pop2 --> hof
- hof([HallOfFame])
- hof --> migration
- pop1 <-.-> migration
- pop2 <-.-> migration
- migration[migrate!]
- end
- ES --> output([HallOfFame])
-```
+LibraryAugmentedSymbolicRegression.jl development is kept independent from the main codebase. However, to ensure LaSR can be used easily, it is integrated into SymbolicRegression.jl via the [`ext/SymbolicRegressionLaSRExt`](https://www.example.com) extension module. This, in turn, is loaded into PySR. This cartoon summarizes the interaction between the different packages:
-The `HallOfFame` objects store the expressions with the lowest loss seen at each complexity.
-
-The dependency structure of the code itself is as follows:
-
-```mermaid
-stateDiagram-v2
- AdaptiveParsimony --> Mutate
- AdaptiveParsimony --> Population
- AdaptiveParsimony --> RegularizedEvolution
- AdaptiveParsimony --> SingleIteration
- AdaptiveParsimony --> LaSR
- CheckConstraints --> Mutate
- CheckConstraints --> LaSR
- Complexity --> CheckConstraints
- Complexity --> HallOfFame
- Complexity --> LossFunctions
- Complexity --> Mutate
- Complexity --> Population
- Complexity --> SearchUtils
- Complexity --> SingleIteration
- Complexity --> LaSR
- ConstantOptimization --> Mutate
- ConstantOptimization --> SingleIteration
- Core --> AdaptiveParsimony
- Core --> CheckConstraints
- Core --> Complexity
- Core --> ConstantOptimization
- Core --> HallOfFame
- Core --> InterfaceDynamicExpressions
- Core --> LossFunctions
- Core --> Migration
- Core --> Mutate
- Core --> MutationFunctions
- Core --> PopMember
- Core --> Population
- Core --> Recorder
- Core --> RegularizedEvolution
- Core --> SearchUtils
- Core --> SingleIteration
- Core --> LaSR
- Dataset --> Core
- HallOfFame --> SearchUtils
- HallOfFame --> SingleIteration
- HallOfFame --> LaSR
- InterfaceDynamicExpressions --> LossFunctions
- InterfaceDynamicExpressions --> LaSR
- LossFunctions --> ConstantOptimization
- LossFunctions --> HallOfFame
- LossFunctions --> Mutate
- LossFunctions --> PopMember
- LossFunctions --> Population
- LossFunctions --> LaSR
- Migration --> LaSR
- Mutate --> RegularizedEvolution
- MutationFunctions --> Mutate
- MutationFunctions --> Population
- MutationFunctions --> LaSR
- Operators --> Core
- Operators --> Options
- Options --> Core
- OptionsStruct --> Core
- OptionsStruct --> Options
- PopMember --> ConstantOptimization
- PopMember --> HallOfFame
- PopMember --> Migration
- PopMember --> Mutate
- PopMember --> Population
- PopMember --> RegularizedEvolution
- PopMember --> SingleIteration
- PopMember --> LaSR
- Population --> Migration
- Population --> RegularizedEvolution
- Population --> SearchUtils
- Population --> SingleIteration
- Population --> LaSR
- ProgramConstants --> Core
- ProgramConstants --> Dataset
- ProgressBars --> SearchUtils
- ProgressBars --> LaSR
- Recorder --> Mutate
- Recorder --> RegularizedEvolution
- Recorder --> SingleIteration
- Recorder --> LaSR
- RegularizedEvolution --> SingleIteration
- SearchUtils --> LaSR
- SingleIteration --> LaSR
- Utils --> CheckConstraints
- Utils --> ConstantOptimization
- Utils --> Options
- Utils --> PopMember
- Utils --> SingleIteration
- Utils --> LaSR
-```
+![LibraryAugmentedSymbolicRegression.jl organization](https://raw.githubusercontent.com/trishullab/lasr-web/main/static/lasr-code-interactions.svg)
-Bash command to generate dependency structure from `src` directory (requires `vim-stream`):
+> [!NOTE]
+> The `ext/SymbolicRegressionLaSRExt` module is not yet available in the released version of SymbolicRegression.jl. It will be available in the release `vX.X.X` of SymbolicRegression.jl.
-```bash
-echo 'stateDiagram-v2'
-IFS=$'\n'
-for f in *.jl; do
- for line in $(cat $f | grep -e 'import \.\.' -e 'import \.'); do
- echo $(echo $line | vims -s 'dwf:d$' -t '%s/^\.*//g' '%s/Module//g') $(basename "$f" .jl);
- done;
-done | vims -l 'f a--> ' | sort
-```
## Search options
-See https://astroautomata.com/LaSR.jl/stable/api/#Options
+Other than `LLMOptions`, we support the same search options as SymbolicRegression.jl. See https://astroautomata.com/SymbolicRegression.jl/stable/api/#Options
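+
+For instance, the lower-level `equation_search` entry point takes the same `Options` struct as SymbolicRegression.jl. A minimal sketch (passing `llm_options` through `Options` is an assumption, inferred from the fields this package reads):
+
+```julia
+using LibraryAugmentedSymbolicRegression: equation_search, Options, LLMOptions
+
+X = randn(Float32, 2, 100)  # features are rows, samples are columns
+y = 2 .* cos.(X[2, :]) .+ X[1, :] .^ 2 .- 2
+
+options = Options(;
+    binary_operators=[+, -, *],
+    unary_operators=[cos],
+    populations=20,
+    llm_options=LLMOptions(; active=false),  # assumed keyword; LLM calls disabled here
+)
+hall_of_fame = equation_search(X, y; niterations=40, options=options)
+```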
diff --git a/ext/SymbolicRegressionEnzymeExt.jl b/ext/LaSREnzymeExt.jl
similarity index 100%
rename from ext/SymbolicRegressionEnzymeExt.jl
rename to ext/LaSREnzymeExt.jl
diff --git a/ext/SymbolicRegressionJSON3Ext.jl b/ext/LaSRJSON3Ext.jl
similarity index 100%
rename from ext/SymbolicRegressionJSON3Ext.jl
rename to ext/LaSRJSON3Ext.jl
diff --git a/ext/SymbolicRegressionSymbolicUtilsExt.jl b/ext/LaSRSymbolicUtilsExt.jl
similarity index 90%
rename from ext/SymbolicRegressionSymbolicUtilsExt.jl
rename to ext/LaSRSymbolicUtilsExt.jl
index 5fbefeb9a..2a7f32cc7 100644
--- a/ext/SymbolicRegressionSymbolicUtilsExt.jl
+++ b/ext/LaSRSymbolicUtilsExt.jl
@@ -1,8 +1,10 @@
module LaSRSymbolicUtilsExt
using SymbolicUtils: Symbolic
-using LibraryAugmentedSymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options
-using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options
+using LibraryAugmentedSymbolicRegression:
+ AbstractExpressionNode, AbstractExpression, Node, Options
+using LibraryAugmentedSymbolicRegression.MLJInterfaceModule:
+ AbstractSRRegressor, get_options
using DynamicExpressions: get_tree, get_operators
import LibraryAugmentedSymbolicRegression: node_to_symbolic, symbolic_to_node
diff --git a/src/Configure.jl b/src/Configure.jl
index a256a1ee0..1440ba7ad 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -257,7 +257,9 @@ function test_module_on_workers(procs, options::Options, verbosity)
for proc in procs
push!(
futures,
- @spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE)
+ @spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(
+ 3, options, 5, TEST_TYPE
+ )
)
end
for future in futures
diff --git a/src/LLMFunctions.jl b/src/LLMFunctions.jl
index 67ead9a8d..2557af19e 100644
--- a/src/LLMFunctions.jl
+++ b/src/LLMFunctions.jl
@@ -21,7 +21,7 @@ using DynamicExpressions:
AbstractOperatorEnum
using Compat: Returns, @inline
using ..CoreModule: Options, DATA_TYPE, binopmap, unaopmap, LLMOptions
-using ..MutationFunctionsModule: gen_random_tree_fixed_size
+using ..MutationFunctionsModule: gen_random_tree_fixed_size, random_node_and_parent
using PromptingTools:
SystemMessage,
@@ -58,12 +58,13 @@ function convertDict(d)::NamedTuple
end
function get_vars(options::Options)::String
- variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
- if !isnothing(options.llm_options.var_order)
+ if !isnothing(options.llm_options) && !isnothing(options.llm_options.var_order)
variable_names = [
options.llm_options.var_order[key] for
key in sort(collect(keys(options.llm_options.var_order)))
]
+ else
+ variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
end
return join(variable_names, ", ")
end
@@ -105,6 +106,7 @@ function construct_prompt(
# if n_occurrences is less than |element_list|, add the missing elements after the last occurrence
if n_occurrences < length(element_list)
last_occurrence = findlast(x -> occursin(pattern, x), lines)
+ @assert last_occurrence !== nothing "No occurrences of the element_id_tag found in the user prompt."
for i in reverse((n_occurrences + 1):length(element_list))
new_line = replace(lines[last_occurrence], string(n_occurrences) => string(i))
insert!(lines, last_occurrence + 1, new_line)
@@ -544,6 +546,9 @@ function parse_msg_content(msg_content)
try
out = parse(content) # json parse
+ if out === nothing
+ return []
+ end
if out isa Dict
return [out[key] for key in keys(out)]
end
@@ -893,7 +898,7 @@ function llm_crossover_trees(
String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])),
options,
)
-
+
llm_recorder(options.llm_options, tree_to_expr(t, options), "crossover")
return t, tree2
@@ -934,10 +939,11 @@ function llm_crossover_trees(
)
end
- recording_str = tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options)
+ recording_str =
+ tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options)
llm_recorder(options.llm_options, recording_str, "crossover")
return cross_tree1, cross_tree2
end
-end
\ No newline at end of file
+end
diff --git a/src/LLMOptions.jl b/src/LLMOptions.jl
index 5f7a01432..1dc8aa3a9 100644
--- a/src/LLMOptions.jl
+++ b/src/LLMOptions.jl
@@ -46,6 +46,7 @@ this module serves as the entry point to define new options for the LLM inferenc
- `llm_recorder_dir::String`: File to save LLM logs to. Useful for debugging.
- `llm_context::AbstractString`: Context string for LLM.
- `var_order::Union{Dict,Nothing}`: Variable order for LLM. (default: nothing)
+- `idea_threshold::UInt32`: Number of concepts to keep track of. (default: 30)
"""
Base.@kwdef mutable struct LLMOptions
active::Bool = false
@@ -55,9 +56,7 @@ Base.@kwdef mutable struct LLMOptions
prompt_evol::Bool = false
api_key::String = ""
model::String = ""
- api_kwargs::Dict = Dict(
- "max_tokens" => 1000
- )
+ api_kwargs::Dict = Dict("max_tokens" => 1000)
http_kwargs::Dict = Dict("retries" => 3, "readtimeout" => 3600)
llm_recorder_dir::String = "lasr_runs/"
prompts_dir::String = "prompts/"
@@ -94,8 +93,6 @@ function validate_llm_options(options::LLMOptions)
end
end
-
-
# """Sample LLM mutation, given the weightings."""
# function sample_llm_mutation(w::LLMWeights)
# weights = convert(Vector, w)
@@ -104,8 +101,6 @@ end
end # module
-
-
# sample invocation following:
# python -m experiments.main --use_llm --use_prompt_evol --model "meta-llama/Meta-Llama-3-8B-Instruct" --api_key "vllm_api.key" --model_url "http://localhost:11440/v1" --exp_idx 0 --dataset_path FeynmanEquations.csv --start_idx 0
# options = LLMOptions(
@@ -122,4 +117,4 @@ end # module
# llm_context="",
# var_order=nothing,
# idea_threshold=30
-# )
\ No newline at end of file
+# )
diff --git a/src/LibraryAugmentedSymbolicRegression.jl b/src/LibraryAugmentedSymbolicRegression.jl
index c91d6717a..27e99b5bb 100644
--- a/src/LibraryAugmentedSymbolicRegression.jl
+++ b/src/LibraryAugmentedSymbolicRegression.jl
@@ -762,7 +762,7 @@ function _initialize_search!(
ropt::RuntimeOptions,
options::Options,
saved_state,
- idea_database_all,
+ idea_database_all::Vector{Vector{String}},
) where {T,L,N}
nout = length(datasets)
@@ -922,7 +922,9 @@ function _main_search_loop!(
window_size=options.populations * 2 * nout,
)
n_iterations = 0
- llm_recorder(options.llm_options, string(div(n_iterations, options.populations)), "n_iterations")
+ llm_recorder(
+ options.llm_options, string(div(n_iterations, options.populations)), "n_iterations"
+ )
worst_members = Vector{PopMember}()
while sum(state.cycles_remaining) > 0
kappa += 1
@@ -1135,7 +1137,9 @@ function _main_search_loop!(
end
################################################################
end
- llm_recorder(options.llm_options, string(div(n_iterations, options.populations)), "n_iterations")
+ llm_recorder(
+ options.llm_options, string(div(n_iterations, options.populations)), "n_iterations"
+ )
return nothing
end
function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options)
diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl
index 04204d0d9..c86c3da40 100644
--- a/src/MLJInterface.jl
+++ b/src/MLJInterface.jl
@@ -455,7 +455,9 @@ end
function get_equation_strings_for(::LaSRRegressor, trees, options, variable_names)
return (t -> string_tree(t, options; variable_names=variable_names)).(trees)
end
-function get_equation_strings_for(::MultitargetLaSRRegressor, trees, options, variable_names)
+function get_equation_strings_for(
+ ::MultitargetLaSRRegressor, trees, options, variable_names
+)
return [
(t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees
]
diff --git a/src/Mutate.jl b/src/Mutate.jl
index f377174f2..496b8a337 100644
--- a/src/Mutate.jl
+++ b/src/Mutate.jl
@@ -1,5 +1,6 @@
module MutateModule
+using DispatchDoctor: @unstable
using DynamicExpressions:
AbstractExpressionNode,
AbstractExpression,
@@ -94,7 +95,7 @@ end
# Go through one simulated options.annealing mutation cycle
# exp(-delta/T) defines probability of accepting a change
-function next_generation(
+@unstable function next_generation(
dataset::D,
member::P,
temperature,
@@ -430,7 +431,7 @@ function next_generation(
end
"""Generate a generation via crossover of two members."""
-function crossover_generation(
+@unstable function crossover_generation(
member1::P,
member2::P,
dataset::D,
@@ -485,11 +486,17 @@ function crossover_generation(
check_constraints(child_tree2, options, curmaxsize, afterSize2)
if successful_crossover
- recorder_str = tree_to_expr(child_tree1, options) * " && " * tree_to_expr(child_tree2, options)
+ recorder_str =
+ tree_to_expr(child_tree1, options) *
+ " && " *
+ tree_to_expr(child_tree2, options)
llm_recorder(options.llm_options, recorder_str, "crossover")
llm_skip = true
else
- recorder_str = tree_to_expr(child_tree1, options) * " && " * tree_to_expr(child_tree2, options)
+ recorder_str =
+ tree_to_expr(child_tree1, options) *
+ " && " *
+ tree_to_expr(child_tree2, options)
llm_recorder(options.llm_options, recorder_str, "crossover|failed")
child_tree1, child_tree2 = crossover_trees(tree1, tree2)
end
diff --git a/src/Population.jl b/src/Population.jl
index 547c8b81e..75d0b75c2 100644
--- a/src/Population.jl
+++ b/src/Population.jl
@@ -26,7 +26,13 @@ function Population(pop::Vector{<:PopMember})
return Population(pop, size(pop, 1))
end
-function gen_random_tree_pop(nlength, options, nfeatures, T, idea_database)
+@unstable function gen_random_tree_pop(
+ nlength::Int,
+ options::Options,
+ nfeatures::Int,
+ ::Type{T},
+ idea_database::Union{Vector{String},Nothing},
+) where {T<:DATA_TYPE}
if options.llm_options.active && (rand() < options.llm_options.weights.llm_gen_random)
gen_llm_random_tree(nlength, options, nfeatures, T, idea_database)
else
@@ -37,18 +43,18 @@ end
"""
Population(dataset::Dataset{T,L};
population_size, nlength::Int=3, options::Options,
- nfeatures::Int)
+ nfeatures::Int, idea_database::Vector{String})
-Create random population and score them on the dataset.
+Create a random population using the LLM and RNG, and score it on the dataset.
"""
-function Population(
+@unstable function Population(
dataset::Dataset{T,L};
options::Options,
population_size=nothing,
nlength::Int=3,
nfeatures::Int,
npop=nothing,
- idea_database=nothing,
+ idea_database::Union{Vector{String},Nothing}=nothing,
) where {T,L}
@assert (population_size !== nothing) ⊻ (npop !== nothing)
population_size = if npop === nothing
diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl
index 913df0d89..b2a08c990 100644
--- a/src/RegularizedEvolution.jl
+++ b/src/RegularizedEvolution.jl
@@ -93,7 +93,11 @@ function reg_evol_cycle(
allstar2 = best_of_sample(pop, running_search_statistics, options)
baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation(
- allstar1, allstar2, dataset, curmaxsize, options;
+ allstar1,
+ allstar2,
+ dataset,
+ curmaxsize,
+ options;
dominating=dominating,
idea_database=idea_database,
)
diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl
index d8f5bd382..ff57e0eeb 100644
--- a/src/SearchUtils.jl
+++ b/src/SearchUtils.jl
@@ -125,7 +125,7 @@ macro sr_spawner(expr, kws...)
end |> esc
end
-function init_dummy_pops(
+@unstable function init_dummy_pops(
npops::Int, datasets::Vector{D}, options::Options
) where {T,L,D<:Dataset{T,L}}
prototype = Population(
diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl
index ce420e5b0..ae1b3dad2 100644
--- a/src/SingleIteration.jl
+++ b/src/SingleIteration.jl
@@ -53,7 +53,7 @@ function s_r_cycle(
curmaxsize,
running_search_statistics,
options,
- record,
+ record;
dominating=dominating,
idea_database=idea_database,
)
diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl
deleted file mode 100644
index d08d49780..000000000
--- a/src/SymbolicRegression.jl
+++ /dev/null
@@ -1,1249 +0,0 @@
-module LibraryAugmentedSymbolicRegression
-
-# Types
-export Population,
- PopMember,
- HallOfFame,
- Options,
- Dataset,
- MutationWeights,
- LLMWeights,
- LLMOptions,
- Node,
- GraphNode,
- ParametricNode,
- Expression,
- ParametricExpression,
- StructuredExpression,
- NodeSampler,
- AbstractExpression,
- AbstractExpressionNode,
- LaSRRegressor,
- MultitargetLaSRRegressor,
- LOSS_TYPE,
- DATA_TYPE,
-
- #Functions:
- equation_search,
- s_r_cycle,
- calculate_pareto_frontier,
- count_nodes,
- compute_complexity,
- @parse_expression,
- parse_expression,
- print_tree,
- string_tree,
- eval_tree_array,
- eval_diff_tree_array,
- eval_grad_tree_array,
- differentiable_eval_tree_array,
- set_node!,
- copy_node,
- node_to_symbolic,
- node_type,
- symbolic_to_node,
- simplify_tree!,
- tree_mapreduce,
- combine_operators,
- gen_random_tree,
- gen_random_tree_fixed_size,
- @extend_operators,
- get_tree,
- get_contents,
- get_metadata,
-
- #Operators
- plus,
- sub,
- mult,
- square,
- cube,
- pow,
- safe_pow,
- safe_log,
- safe_log2,
- safe_log10,
- safe_log1p,
- safe_acosh,
- safe_sqrt,
- neg,
- greater,
- cond,
- relu,
- logical_or,
- logical_and,
-
- # special operators
- gamma,
- erf,
- erfc,
- atanh_clip
-
-using Distributed
-using Printf: @printf, @sprintf
-using PackageExtensionCompat: @require_extensions
-using Pkg: Pkg
-using TOML: parsefile
-using Random: seed!, shuffle!
-using Reexport
-using DynamicExpressions:
- Node,
- GraphNode,
- ParametricNode,
- Expression,
- ParametricExpression,
- StructuredExpression,
- NodeSampler,
- AbstractExpression,
- AbstractExpressionNode,
- @parse_expression,
- parse_expression,
- copy_node,
- set_node!,
- string_tree,
- print_tree,
- count_nodes,
- get_constants,
- get_scalar_constants,
- set_constants!,
- set_scalar_constants!,
- index_constants,
- NodeIndex,
- eval_tree_array,
- differentiable_eval_tree_array,
- eval_diff_tree_array,
- eval_grad_tree_array,
- node_to_symbolic,
- symbolic_to_node,
- combine_operators,
- simplify_tree!,
- tree_mapreduce,
- set_default_variable_names!,
- node_type,
- get_tree,
- get_contents,
- get_metadata
-using DynamicExpressions: with_type_parameters
-@reexport using LossFunctions:
- MarginLoss,
- DistanceLoss,
- SupervisedLoss,
- ZeroOneLoss,
- LogitMarginLoss,
- PerceptronLoss,
- HingeLoss,
- L1HingeLoss,
- L2HingeLoss,
- SmoothedL1HingeLoss,
- ModifiedHuberLoss,
- L2MarginLoss,
- ExpLoss,
- SigmoidLoss,
- DWDMarginLoss,
- LPDistLoss,
- L1DistLoss,
- L2DistLoss,
- PeriodicLoss,
- HuberLoss,
- EpsilonInsLoss,
- L1EpsilonInsLoss,
- L2EpsilonInsLoss,
- LogitDistLoss,
- QuantileLoss,
- LogCoshLoss
-
-# https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15
-const PACKAGE_VERSION = try
- root = pkgdir(@__MODULE__)
- if root == String
- let project = parsefile(joinpath(root, "Project.toml"))
- VersionNumber(project["version"])
- end
- else
- VersionNumber(0, 0, 0)
- end
-catch
- VersionNumber(0, 0, 0)
-end
-
-function deprecate_varmap(variable_names, varMap, func_name)
- if varMap !== nothing
- Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name)
- @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`"
- variable_names = varMap
- end
- return variable_names
-end
-
-using DispatchDoctor: @stable
-
-@stable default_mode = "disable" begin
- include("Utils.jl")
- include("InterfaceDynamicQuantities.jl")
- include("Core.jl")
- include("InterfaceDynamicExpressions.jl")
- include("Recorder.jl")
- include("Complexity.jl")
- include("DimensionalAnalysis.jl")
- include("CheckConstraints.jl")
- include("AdaptiveParsimony.jl")
- include("MutationFunctions.jl")
- include("LLMFunctions.jl")
- include("LossFunctions.jl")
- include("PopMember.jl")
- include("ConstantOptimization.jl")
- include("Population.jl")
- include("HallOfFame.jl")
- include("Mutate.jl")
- include("RegularizedEvolution.jl")
- include("SingleIteration.jl")
- include("ProgressBars.jl")
- include("Migration.jl")
- include("SearchUtils.jl")
- include("ExpressionBuilder.jl")
-end
-
-using .CoreModule:
- MAX_DEGREE,
- BATCH_DIM,
- FEATURE_DIM,
- DATA_TYPE,
- LOSS_TYPE,
- RecordType,
- Dataset,
- Options,
- MutationWeights,
- LLMOptions,
- LLMWeights,
- plus,
- sub,
- mult,
- square,
- cube,
- pow,
- safe_pow,
- safe_log,
- safe_log2,
- safe_log10,
- safe_log1p,
- safe_sqrt,
- safe_acosh,
- neg,
- greater,
- cond,
- relu,
- logical_or,
- logical_and,
- gamma,
- erf,
- erfc,
- atanh_clip,
- create_expression
-using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore
-using .ComplexityModule: compute_complexity
-using .CheckConstraintsModule: check_constraints
-using .AdaptiveParsimonyModule:
- RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies!
-using .MutationFunctionsModule:
- gen_random_tree,
- gen_random_tree_fixed_size,
- random_node,
- random_node_and_parent,
- crossover_trees
-using .LLMFunctionsModule: update_idea_database
-
-using .InterfaceDynamicExpressionsModule: @extend_operators
-using .LossFunctionsModule: eval_loss, score_func, update_baseline_loss!
-using .PopMemberModule: PopMember, reset_birth!
-using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample
-using .HallOfFameModule:
- HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve
-using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population
-using .ProgressBarsModule: WrappedProgressBar
-using .RecorderModule: @recorder, find_iteration_from_record
-using .MigrationModule: migrate!
-using .SearchUtilsModule:
- SearchState,
- RuntimeOptions,
- WorkerAssignments,
- DefaultWorkerOutputType,
- assign_next_worker!,
- get_worker_output_type,
- extract_from_worker,
- @sr_spawner,
- StdinReader,
- watch_stream,
- close_reader!,
- check_for_user_quit,
- check_for_loss_threshold,
- check_for_timeout,
- check_max_evals,
- ResourceMonitor,
- record_channel_state!,
- estimate_work_fraction,
- update_progress_bar!,
- print_search_state,
- init_dummy_pops,
- load_saved_hall_of_fame,
- load_saved_population,
- construct_datasets,
- save_to_file,
- get_cur_maxsize,
- update_hall_of_fame!
-using .ExpressionBuilderModule: embed_metadata, strip_metadata
-
-@stable default_mode = "disable" begin
- include("deprecates.jl")
- include("Configure.jl")
-end
-
-"""
- equation_search(X, y[; kws...])
-
-Perform a distributed equation search for functions `f_i` which
-describe the mapping `f_i(X[:, j]) ≈ y[i, j]`. Options are
-configured using LibraryAugmentedSymbolicRegression.Options(...),
-which should be passed as a keyword argument to options.
-One can turn off parallelism with `numprocs=0`,
-which is useful for debugging and profiling.
-
-# Arguments
-- `X::AbstractMatrix{T}`: The input dataset to predict `y` from.
- The first dimension is features, the second dimension is rows.
-- `y::Union{AbstractMatrix{T}, AbstractVector{T}}`: The values to predict. The first dimension
- is the output feature to predict with each equation, and the
- second dimension is rows.
-- `niterations::Int=10`: The number of iterations to perform the search.
- More iterations will improve the results.
-- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally
- weight the loss for each `y` by this value (same shape as `y`).
-- `options::Options=Options()`: The options for the search, such as
- which operators to use, evolution hyperparameters, etc.
-- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names
- of each feature in `X`, which will be used during printing of equations.
-- `display_variable_names::Union{Vector{String}, Nothing}=variable_names`: Names
- to use when printing expressions during the search, but not when saving
- to an equation file.
-- `y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing`: The
- names of each output feature in `y`, which will be used during printing
- of equations.
-- `parallelism=:multithreading`: What parallelism mode to use.
- The options are `:multithreading`, `:multiprocessing`, and `:serial`.
- By default, multithreading will be used. Multithreading uses less memory,
- but multiprocessing can handle multi-node compute. If using `:multithreading`
- mode, the number of threads available to julia are used. If using
- `:multiprocessing`, `numprocs` processes will be created dynamically if
- `procs` is unset. If you have already allocated processes, pass them
- to the `procs` argument and they will be used.
- You may also pass a string instead of a symbol, like `"multithreading"`.
-- `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
- if you want `equation_search` to set this up automatically. By default
- this will be `4`, but can be any number (you should pick a number <=
- the number of cores available).
-- `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
- a distributed run manually with `procs = addprocs()` and `@everywhere`,
- pass the `procs` to this keyword argument.
-- `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
- (`parallelism=:multithreading`), and are not passing `procs` manually,
- then they will be allocated dynamically using `addprocs`. However,
- you may also pass a custom function to use instead of `addprocs`.
- This function should take a single positional argument,
- which is the number of processes to use, as well as the `lazy` keyword argument.
- For example, if set up on a slurm cluster, you could pass
- `addprocs_function = addprocs_slurm`, which will set up slurm processes.
-- `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
- flag on Julia processes, recommending garbage collection once a process
- is close to the recommended size. This is important for long-running distributed
- jobs where each process has an independent memory, and can help avoid
- out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
-- `runtests::Bool=true`: Whether to run (quick) tests before starting the
- search, to see if there will be any problems during the equation search
- related to the host environment.
-- `saved_state=nothing`: If you have already
- run `equation_search` and want to resume it, pass the state here.
- To get this to work, you need to have set return_state=true,
- which will cause `equation_search` to return the state. The second
- element of the state is the regular return value with the hall of fame.
- Note that you cannot change the operators or dataset, but most other options
- should be changeable.
-- `return_state::Union{Bool, Nothing}=nothing`: Whether to return the
- state of the search for warm starts. By default this is false.
-- `loss_type::Type=Nothing`: If you would like to use a different type
- for the loss than for the data you passed, specify the type here.
- Note that if you pass complex data `::Complex{L}`, then the loss
- type will automatically be set to `L`.
-- `verbosity`: Whether to print debugging statements or not.
-- `progress`: Whether to use a progress bar output. Only available for
- single target output.
-- `X_units::Union{AbstractVector,Nothing}=nothing`: The units of the dataset,
- to be used for dimensional constraints. For example, if `X_units=["kg", "m"]`,
- then the first feature will have units of kilograms, and the second will
- have units of meters.
-- `y_units=nothing`: The units of the output, to be used for dimensional constraints.
- If `y` is a matrix, then this can be a vector of units, in which case
- each element corresponds to each output feature.
-
-# Returns
-- `hallOfFame::HallOfFame`: The best equations seen during the search.
- hallOfFame.members gives an array of `PopMember` objects, which
- have their tree (equation) stored in `.tree`. Their score (loss)
- is given in `.score`. The array of `PopMember` objects
- is enumerated by size from `1` to `options.maxsize`.
-"""
-function equation_search(
- X::AbstractMatrix{T},
- y::AbstractMatrix{T};
- niterations::Int=10,
- weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
- options::Options=Options(),
- variable_names::Union{AbstractVector{String},Nothing}=nothing,
- display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
- y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing,
- parallelism=:multithreading,
- numprocs::Union{Int,Nothing}=nothing,
- procs::Union{Vector{Int},Nothing}=nothing,
- addprocs_function::Union{Function,Nothing}=nothing,
- heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
- runtests::Bool=true,
- saved_state=nothing,
- return_state::Union{Bool,Nothing,Val}=nothing,
- loss_type::Type{L}=Nothing,
- verbosity::Union{Integer,Nothing}=nothing,
- progress::Union{Bool,Nothing}=nothing,
- X_units::Union{AbstractVector,Nothing}=nothing,
- y_units=nothing,
- extra::NamedTuple=NamedTuple(),
- v_dim_out::Val{DIM_OUT}=Val(nothing),
- # Deprecated:
- multithreaded=nothing,
- varMap=nothing,
-) where {T<:DATA_TYPE,L,DIM_OUT}
- if multithreaded !== nothing
- error(
- "`multithreaded` is deprecated. Use the `parallelism` argument instead. " *
- "Choose one of :multithreaded, :multiprocessing, or :serial.",
- )
- end
- variable_names = deprecate_varmap(variable_names, varMap, :equation_search)
-
- if weights !== nothing
- @assert length(weights) == length(y)
- weights = reshape(weights, size(y))
- end
-
- datasets = construct_datasets(
- X,
- y,
- weights,
- variable_names,
- display_variable_names,
- y_variable_names,
- X_units,
- y_units,
- extra,
- L,
- )
-
- return equation_search(
- datasets;
- niterations=niterations,
- options=options,
- parallelism=parallelism,
- numprocs=numprocs,
- procs=procs,
- addprocs_function=addprocs_function,
- heap_size_hint_in_bytes=heap_size_hint_in_bytes,
- runtests=runtests,
- saved_state=saved_state,
- return_state=return_state,
- verbosity=verbosity,
- progress=progress,
- v_dim_out=Val(DIM_OUT),
- )
-end
-
-function equation_search(
- X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw...
-) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
- U = promote_type(T1, T2)
- return equation_search(
- convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw...
- )
-end
-
-function equation_search(
- X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw...
-) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
- return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1))
-end
-
-function equation_search(dataset::Dataset; kws...)
- return equation_search([dataset]; kws..., v_dim_out=Val(1))
-end
-
-function equation_search(
- datasets::Vector{D};
- niterations::Int=10,
- options::Options=Options(),
- parallelism=:multithreading,
- numprocs::Union{Int,Nothing}=nothing,
- procs::Union{Vector{Int},Nothing}=nothing,
- addprocs_function::Union{Function,Nothing}=nothing,
- heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
- runtests::Bool=true,
- saved_state=nothing,
- return_state::Union{Bool,Nothing,Val}=nothing,
- verbosity::Union{Int,Nothing}=nothing,
- progress::Union{Bool,Nothing}=nothing,
- v_dim_out::Val{DIM_OUT}=Val(nothing),
-) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
- concurrency = if parallelism in (:multithreading, "multithreading")
- :multithreading
- elseif parallelism in (:multiprocessing, "multiprocessing")
- :multiprocessing
- elseif parallelism in (:serial, "serial")
- :serial
- else
- error(
- "Invalid parallelism mode: $parallelism. " *
- "You must choose one of :multithreading, :multiprocessing, or :serial.",
- )
- :serial
- end
- not_distributed = concurrency in (:multithreading, :serial)
- not_distributed &&
- procs !== nothing &&
- error(
- "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
- )
- not_distributed &&
- numprocs !== nothing &&
- error(
- "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
- )
-
- _return_state = if return_state isa Val
- first(typeof(return_state).parameters)
- else
- if options.return_state === Val(nothing)
- return_state === nothing ? false : return_state
- else
- @assert(
- return_state === nothing,
- "You cannot set `return_state` in both the `Options` and in the passed arguments."
- )
- first(typeof(options.return_state).parameters)
- end
- end
-
- dim_out = if DIM_OUT === nothing
- length(datasets) > 1 ? 2 : 1
- else
- DIM_OUT
- end
- _numprocs::Int = if numprocs === nothing
- if procs === nothing
- 4
- else
- length(procs)
- end
- else
- if procs === nothing
- numprocs
- else
- @assert length(procs) == numprocs
- numprocs
- end
- end
-
- _verbosity = if verbosity === nothing && options.verbosity === nothing
- 1
- elseif verbosity === nothing && options.verbosity !== nothing
- options.verbosity
- elseif verbosity !== nothing && options.verbosity === nothing
- verbosity
- else
- error(
- "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.",
- )
- 1
- end
- _progress::Bool = if progress === nothing && options.progress === nothing
- (_verbosity > 0) && length(datasets) == 1
- elseif progress === nothing && options.progress !== nothing
- options.progress
- elseif progress !== nothing && options.progress === nothing
- progress
- else
- error(
- "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.",
- )
- false
- end
-
- _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function
-
- exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing
- heap_size_hint_in_megabytes = floor(
- Int, (
- if heap_size_hint_in_bytes === nothing
- (Sys.free_memory() / _numprocs)
- else
- heap_size_hint_in_bytes
- end
- ) / 1024^2
- )
- _verbosity > 0 &&
- heap_size_hint_in_bytes === nothing &&
- @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter."
-
- `--heap-size=$(heap_size_hint_in_megabytes)M`
- else
- ``
- end
-
- # Underscores here mean that we have mutated the variable
- return _equation_search(
- datasets,
- RuntimeOptions(;
- niterations=niterations,
- total_cycles=options.populations * niterations,
- numprocs=_numprocs,
- init_procs=procs,
- addprocs_function=_addprocs_function,
- exeflags=exeflags,
- runtests=runtests,
- verbosity=_verbosity,
- progress=_progress,
- parallelism=Val(concurrency),
- dim_out=Val(dim_out),
- return_state=Val(_return_state),
- ),
- options,
- saved_state,
- )
-end
-
-@noinline function _equation_search(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state
-) where {D<:Dataset}
- # PROMPT EVOLUTION
- idea_database_all = [Vector{String}() for j in 1:length(datasets)]
-
- _validate_options(datasets, ropt, options)
- state = _create_workers(datasets, ropt, options)
- _initialize_search!(state, datasets, ropt, options, saved_state, idea_database_all)
- _warmup_search!(state, datasets, ropt, options, idea_database_all)
- _main_search_loop!(state, datasets, ropt, options, idea_database_all)
- _tear_down!(state, ropt, options)
- return _format_output(state, datasets, ropt, options)
-end
-
-function _validate_options(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options
-) where {T,L,D<:Dataset{T,L}}
- example_dataset = first(datasets)
- nout = length(datasets)
- @assert nout >= 1
- @assert (nout == 1 || ropt.dim_out == 2)
- @assert options.populations >= 1
- if ropt.progress
- @assert(nout == 1, "You cannot display a progress bar for multi-output searches.")
- @assert(ropt.verbosity > 0, "You cannot display a progress bar with `verbosity=0`.")
- end
- if options.node_type <: GraphNode && ropt.verbosity > 0
- @warn "The `GraphNode` interface and mutation operators are experimental and will change in future versions."
- end
- if ropt.runtests
- test_option_configuration(ropt.parallelism, datasets, options, ropt.verbosity)
- test_dataset_configuration(example_dataset, options, ropt.verbosity)
- end
- for dataset in datasets
- update_baseline_loss!(dataset, options)
- end
- if options.define_helper_functions
- set_default_variable_names!(first(datasets).variable_names)
- end
- if options.seed !== nothing
- seed!(options.seed)
- end
- return nothing
-end
-@stable default_mode = "disable" function _create_workers(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options
-) where {T,L,D<:Dataset{T,L}}
- stdin_reader = watch_stream(stdin)
-
- record = RecordType()
- @recorder record["options"] = "$(options)"
-
- nout = length(datasets)
- example_dataset = first(datasets)
- example_ex = create_expression(zero(T), options, example_dataset)
- NT = typeof(example_ex)
- PopType = Population{T,L,NT}
- HallOfFameType = HallOfFame{T,L,NT}
- WorkerOutputType = get_worker_output_type(
- Val(ropt.parallelism), PopType, HallOfFameType
- )
- ChannelType = ropt.parallelism == :multiprocessing ? RemoteChannel : Channel
-
- # Pointers to populations on each worker:
- worker_output = Vector{WorkerOutputType}[WorkerOutputType[] for j in 1:nout]
- # Initialize storage for workers
- tasks = [Task[] for j in 1:nout]
- # Set up a channel to send finished populations back to head node
- channels = [[ChannelType(1) for i in 1:(options.populations)] for j in 1:nout]
- (procs, we_created_procs) = if ropt.parallelism == :multiprocessing
- configure_workers(;
- procs=ropt.init_procs,
- ropt.numprocs,
- ropt.addprocs_function,
- options,
- project_path=splitdir(Pkg.project().path)[1],
- file=@__FILE__,
- ropt.exeflags,
- ropt.verbosity,
- example_dataset,
- ropt.runtests,
- )
- else
- Int[], false
- end
- # Get the next worker process to give a job:
- worker_assignment = WorkerAssignments()
- # Randomly order which order to check populations:
- # This is done so that we do work on all nout equally.
- task_order = [(j, i) for j in 1:nout for i in 1:(options.populations)]
- shuffle!(task_order)
-
- # Persistent storage of last-saved population for final return:
- last_pops = init_dummy_pops(options.populations, datasets, options)
- # Best 10 members from each population for migration:
- best_sub_pops = init_dummy_pops(options.populations, datasets, options)
- # TODO: Should really be one per population too.
- all_running_search_statistics = [
- RunningSearchStatistics(; options=options) for j in 1:nout
- ]
- # Records the number of evaluations:
- # Real numbers indicate use of batching.
- num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout]
-
- halls_of_fame = Vector{HallOfFameType}(undef, nout)
-
- cycles_remaining = [ropt.total_cycles for j in 1:nout]
- cur_maxsizes = [
- get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j])
- for j in 1:nout
- ]
-
- return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(;
- procs=procs,
- we_created_procs=we_created_procs,
- worker_output=worker_output,
- tasks=tasks,
- channels=channels,
- worker_assignment=worker_assignment,
- task_order=task_order,
- halls_of_fame=halls_of_fame,
- last_pops=last_pops,
- best_sub_pops=best_sub_pops,
- all_running_search_statistics=all_running_search_statistics,
- num_evals=num_evals,
- cycles_remaining=cycles_remaining,
- cur_maxsizes=cur_maxsizes,
- stdin_reader=stdin_reader,
- record=Ref(record),
- )
-end
-function _initialize_search!(
- state::SearchState{T,L,N},
- datasets,
- ropt::RuntimeOptions,
- options::Options,
- saved_state,
- idea_database_all,
-) where {T,L,N}
- nout = length(datasets)
-
- init_hall_of_fame = load_saved_hall_of_fame(saved_state)
- if init_hall_of_fame === nothing
- for j in 1:nout
- state.halls_of_fame[j] = HallOfFame(options, datasets[j])
- end
- else
- # Recompute losses for the hall of fame, in
- # case the dataset changed:
- for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame)
- hof = strip_metadata(init_hall_of_fame[j], options, datasets[j])
- for member in hof.members[hof.exists]
- score, result_loss = score_func(datasets[j], member, options)
- member.score = score
- member.loss = result_loss
- end
- state.halls_of_fame[j] = hof
- end
- end
-
- for j in 1:nout, i in 1:(options.populations)
- worker_idx = assign_next_worker!(
- state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs
- )
- saved_pop = load_saved_population(saved_state; out=j, pop=i)
- new_pop =
- if saved_pop !== nothing && length(saved_pop.members) == options.population_size
- _saved_pop = strip_metadata(saved_pop, options, datasets[j])
- ## Update losses:
- for member in _saved_pop.members
- score, result_loss = score_func(datasets[j], member, options)
- member.score = score
- member.loss = result_loss
- end
- copy_pop = copy(_saved_pop)
- @sr_spawner(
- begin
- (copy_pop, HallOfFame(options, datasets[j]), RecordType(), 0.0)
- end,
- parallelism = ropt.parallelism,
- worker_idx = worker_idx
- )
- else
- if saved_pop !== nothing && ropt.verbosity > 0
- @warn "Recreating population (output=$(j), population=$(i)), as the saved one doesn't have the correct number of members."
- end
- @sr_spawner(
- begin
- (
- Population(
- datasets[j];
- population_size=options.population_size,
- nlength=3,
- options=options,
- nfeatures=datasets[j].nfeatures,
- idea_database=idea_database_all[j],
- ),
- HallOfFame(options, datasets[j]),
- RecordType(),
- Float64(options.population_size),
- )
- end,
- parallelism = ropt.parallelism,
- worker_idx = worker_idx
- )
- # This involves population_size evaluations, on the full dataset:
- end
- push!(state.worker_output[j], new_pop)
- end
- return nothing
-end
-function _warmup_search!(
- state::SearchState{T,L,N},
- datasets,
- ropt::RuntimeOptions,
- options::Options,
- idea_database_all,
-) where {T,L,N}
- nout = length(datasets)
- for j in 1:nout, i in 1:(options.populations)
- dataset = datasets[j]
- running_search_statistics = state.all_running_search_statistics[j]
- cur_maxsize = state.cur_maxsizes[j]
- @recorder state.record[]["out$(j)_pop$(i)"] = RecordType()
- worker_idx = assign_next_worker!(
- state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs
- )
-
- # TODO - why is this needed??
- # Multi-threaded doesn't like to fetch within a new task:
- c_rss = deepcopy(running_search_statistics)
- last_pop = state.worker_output[j][i]
- updated_pop = @sr_spawner(
- begin
- in_pop = first(
- extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N})
- )
- _dispatch_s_r_cycle(
- in_pop,
- dataset,
- options;
- pop=i,
- out=j,
- iteration=0,
- ropt.verbosity,
- cur_maxsize,
- running_search_statistics=c_rss,
- idea_database=idea_database_all[j],
- )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}}
- end,
- parallelism = ropt.parallelism,
- worker_idx = worker_idx
- )
- state.worker_output[j][i] = updated_pop
- end
- return nothing
-end
-function _main_search_loop!(
- state::SearchState{T,L,N},
- datasets,
- ropt::RuntimeOptions,
- options::Options,
- idea_database_all,
-) where {T,L,N}
- ropt.verbosity > 0 && @info "Started!"
- nout = length(datasets)
- start_time = time()
- if ropt.progress
- #TODO: need to iterate this on the max cycles remaining!
- sum_cycle_remaining = sum(state.cycles_remaining)
- progress_bar = WrappedProgressBar(
- 1:sum_cycle_remaining; width=options.terminal_width
- )
- end
- last_print_time = time()
- last_speed_recording_time = time()
- num_evals_last = sum(sum, state.num_evals)
- num_evals_since_last = sum(sum, state.num_evals) - num_evals_last # i.e., start at 0
- print_every_n_seconds = 5
- equation_speed = Float32[]
-
- if ropt.parallelism in (:multiprocessing, :multithreading)
- for j in 1:nout, i in 1:(options.populations)
- # Start listening for each population to finish:
- t = @async put!(state.channels[j][i], fetch(state.worker_output[j][i]))
- push!(state.tasks[j], t)
- end
- end
- kappa = 0
- resource_monitor = ResourceMonitor(;
- # Storing n times as many monitoring intervals as populations seems like it will
- # help get accurate resource estimates:
- max_recordings=options.populations * 100 * nout,
- start_reporting_at=options.populations * 3 * nout,
- window_size=options.populations * 2 * nout,
- )
- n_iterations = 0
- if options.llm_options.active
- open(options.llm_options.llm_recorder_dir * "n_iterations.txt", "a") do file
- write(file, "- " * string(div(n_iterations, options.populations)) * "\n")
- end
- end
- worst_members = Vector{PopMember}()
- while sum(state.cycles_remaining) > 0
- kappa += 1
- if kappa > options.populations * nout
- kappa = 1
- end
- # nout, populations:
- j, i = state.task_order[kappa]
- idea_database = idea_database_all[j]
-
- # Check if error on population:
- if ropt.parallelism in (:multiprocessing, :multithreading)
- if istaskfailed(state.tasks[j][i])
- fetch(state.tasks[j][i])
- error("Task failed for population")
- end
- end
- # Non-blocking check if a population is ready:
- population_ready = if ropt.parallelism in (:multiprocessing, :multithreading)
- # TODO: Implement type assertions based on parallelism.
- isready(state.channels[j][i])
- else
- true
- end
- record_channel_state!(resource_monitor, population_ready)
-
- # Don't start more if this output has finished its cycles:
- # TODO - this might skip extra cycles?
- population_ready &= (state.cycles_remaining[j] > 0)
- if population_ready
- if n_iterations % options.populations == 0
- worst_members = Vector{PopMember}()
- end
- n_iterations += 1
- # Take the fetch operation from the channel since its ready
- (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in
- (
- :multiprocessing, :multithreading
- )
- take!(
- state.channels[j][i]
- )
- else
- state.worker_output[j][i]
- end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}}
- state.last_pops[j][i] = copy(cur_pop)
- state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn)
- @recorder state.record[] = recursive_merge(state.record[], cur_record)
- state.num_evals[j][i] += cur_num_evals
- dataset = datasets[j]
- cur_maxsize = state.cur_maxsizes[j]
-
- worst_member = nothing
- for member in cur_pop.members
- if worst_member == nothing || worst_member.loss < member.loss
- worst_member = member
- end
- size = compute_complexity(member, options)
- update_frequencies!(state.all_running_search_statistics[j]; size)
- end
- if worst_member != nothing && worst_member.loss > 100 # if the worst of population is good then thats still good to keep
- push!(worst_members, worst_member)
- end
- #! format: off
- update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options)
- update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options)
- #! format: on
-
- # Dominating pareto curve - must be better than all simpler equations
- dominating = calculate_pareto_frontier(state.halls_of_fame[j])
- if options.llm_options.active &&
- options.llm_options.prompt_evol &&
- (n_iterations % options.populations == 0)
- update_idea_database(idea_database, dominating, worst_members, options)
- end
-
- if options.save_to_file
- save_to_file(dominating, nout, j, dataset, options)
- end
- ###################################################################
- # Migration #######################################################
- if options.migration
- best_of_each = Population([
- member for pop in state.best_sub_pops[j] for member in pop.members
- ])
- migrate!(
- best_of_each.members => cur_pop, options; frac=options.fraction_replaced
- )
- end
- if options.hof_migration && length(dominating) > 0
- migrate!(dominating => cur_pop, options; frac=options.fraction_replaced_hof)
- end
- ###################################################################
-
- state.cycles_remaining[j] -= 1
- if state.cycles_remaining[j] == 0
- break
- end
- worker_idx = assign_next_worker!(
- state.worker_assignment;
- out=j,
- pop=i,
- parallelism=ropt.parallelism,
- state.procs,
- )
- iteration = if options.use_recorder
- key = "out$(j)_pop$(i)"
- find_iteration_from_record(key, state.record[]) + 1
- else
- 0
- end
-
- c_rss = deepcopy(state.all_running_search_statistics[j])
- in_pop = copy(cur_pop::Population{T,L,N})
- state.worker_output[j][i] = @sr_spawner(
- begin
- _dispatch_s_r_cycle(
- in_pop,
- dataset,
- options;
- pop=i,
- out=j,
- iteration,
- ropt.verbosity,
- cur_maxsize,
- running_search_statistics=c_rss,
- dominating=dominating,
- idea_database=idea_database,
- )
- end,
- parallelism = ropt.parallelism,
- worker_idx = worker_idx
- )
- if ropt.parallelism in (:multiprocessing, :multithreading)
- state.tasks[j][i] = @async put!(
- state.channels[j][i], fetch(state.worker_output[j][i])
- )
- end
-
- state.cur_maxsizes[j] = get_cur_maxsize(;
- options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j]
- )
- move_window!(state.all_running_search_statistics[j])
- if ropt.progress
- head_node_occupation = estimate_work_fraction(resource_monitor)
- update_progress_bar!(
- progress_bar,
- only(state.halls_of_fame),
- only(datasets),
- options,
- equation_speed,
- head_node_occupation,
- ropt.parallelism,
- )
- end
- end
- yield()
-
- ################################################################
- ## Search statistics
- elapsed_since_speed_recording = time() - last_speed_recording_time
- if elapsed_since_speed_recording > 1.0
- num_evals_since_last, num_evals_last = let s = sum(sum, state.num_evals)
- s - num_evals_last, s
- end
- current_speed = num_evals_since_last / elapsed_since_speed_recording
- push!(equation_speed, current_speed)
- average_over_m_measurements = 20 # 20 second running average
- if length(equation_speed) > average_over_m_measurements
- deleteat!(equation_speed, 1)
- end
- last_speed_recording_time = time()
- end
- ################################################################
-
- ################################################################
- ## Printing code
- elapsed = time() - last_print_time
- # Update if time has passed
- if elapsed > print_every_n_seconds
- if ropt.verbosity > 0 && !ropt.progress && length(equation_speed) > 0
-
- head_node_occupation = estimate_work_fraction(resource_monitor)
- print_search_state(
- state.halls_of_fame,
- datasets;
- options,
- equation_speed,
- ropt.total_cycles,
- state.cycles_remaining,
- head_node_occupation,
- parallelism=ropt.parallelism,
- width=options.terminal_width,
- )
- end
- last_print_time = time()
- end
- ################################################################
-
- ################################################################
- ## Early stopping code
- if any((
- check_for_loss_threshold(state.halls_of_fame, options),
- check_for_user_quit(state.stdin_reader),
- check_for_timeout(start_time, options),
- check_max_evals(state.num_evals, options),
- ))
- break
- end
- ################################################################
- end
- if options.llm_options.active
- open(joinpath(options.llm_options.llm_recorder_dir, "n_iterations.txt"), "a") do file
- write(file, "- " * string(div(n_iterations, options.populations)) * "\n")
- end
- end
- return nothing
-end
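
For readers skimming the removed loop above: each population's result travels from a worker back to the head node through a one-slot channel, with an `@async` task bridging `fetch` and `put!` so the head loop can poll `isready` without blocking. A minimal, self-contained sketch of that pattern; the names here are illustrative stand-ins, not the package's API:

```julia
# Sketch of the worker -> channel -> head-node handoff in the loop above.
# `@spawn` stands in for `@sr_spawner`; the Int result stands in for the
# (population, hall_of_fame, record, num_evals) tuple.
using Base.Threads: @spawn

function demo_handoff()
    channel = Channel{Int}(1)
    worker_output = @spawn begin
        sleep(0.01)  # pretend to run one s_r_cycle
        42
    end
    @async put!(channel, fetch(worker_output))  # bridge task fills the channel
    while !isready(channel)
        yield()  # head node stays responsive while the worker runs
    end
    return take!(channel)
end

@assert demo_handoff() == 42
```
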
-function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options)
- close_reader!(state.stdin_reader)
- # Safely close all processes or threads
- if ropt.parallelism == :multiprocessing
- state.we_created_procs && rmprocs(state.procs)
- elseif ropt.parallelism == :multithreading
- nout = length(state.worker_output)
- for j in 1:nout, i in eachindex(state.worker_output[j])
- wait(state.worker_output[j][i])
- end
- end
- @recorder json3_write(state.record[], options.recorder_file)
- return nothing
-end
-function _format_output(
- state::SearchState, datasets, ropt::RuntimeOptions, options::Options
-)
- nout = length(datasets)
- out_hof = if ropt.dim_out == 1
- embed_metadata(only(state.halls_of_fame), options, only(datasets))
- else
- map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout)
- end
- if ropt.return_state
- return (
- map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout),
- out_hof,
- )
- else
- return out_hof
- end
-end
-
-@stable default_mode = "disable" function _dispatch_s_r_cycle(
- in_pop::Population{T,L,N},
- dataset::Dataset,
- options::Options;
- pop::Int,
- out::Int,
- iteration::Int,
- verbosity,
- cur_maxsize::Int,
- running_search_statistics,
- dominating=nothing,
- idea_database=nothing,
-) where {T,L,N}
- record = RecordType()
- @recorder record["out$(out)_pop$(pop)"] = RecordType(
- "iteration$(iteration)" => record_population(in_pop, options)
- )
- num_evals = 0.0
- normalize_frequencies!(running_search_statistics)
- out_pop, best_seen, evals_from_cycle = s_r_cycle(
- dataset,
- in_pop,
- options.ncycles_per_iteration,
- cur_maxsize,
- running_search_statistics;
- verbosity=verbosity,
- options=options,
- record=record,
- dominating=dominating,
- idea_database=idea_database,
- )
- num_evals += evals_from_cycle
- out_pop, evals_from_optimize = optimize_and_simplify_population(
- dataset, out_pop, options, cur_maxsize, record
- )
- num_evals += evals_from_optimize
- if options.batching
- for i_member in 1:(options.maxsize + MAX_DEGREE)
- score, result_loss = score_func(dataset, best_seen.members[i_member], options)
- best_seen.members[i_member].score = score
- best_seen.members[i_member].loss = result_loss
- num_evals += 1
- end
- end
- return (out_pop, best_seen, record, num_evals)
-end
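
One detail worth flagging in `_dispatch_s_r_cycle`: when `options.batching` is enabled, members were scored on random minibatches during the cycle, so the final block rescores the best-seen members before returning. A toy sketch of why that rescoring matters; the loss function and data below are hypothetical stand-ins for the package's `score_func` and `Dataset`:

```julia
# Minibatch losses are noisy estimates of the full-data loss, so values
# stored during evolution cannot be reported as-is.
using Random: MersenneTwister

rng = MersenneTwister(0)
y = randn(rng, 1_000)
pred = y .+ 0.1 .* randn(rng, 1_000)  # stand-in for a candidate's predictions

mse(a, b) = sum(abs2, a .- b) / length(a)

batch = rand(rng, 1:1_000, 32)
batch_loss = mse(pred[batch], y[batch])  # what the search saw while batching
full_loss = mse(pred, y)                 # what should be reported at the end
```
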
-
-include("MLJInterface.jl")
-using .MLJInterfaceModule: LaSRRegressor, MultitargetLaSRRegressor
-
-function __init__()
- @require_extensions
-end
-
-# Hack to get static analysis to work from within tests:
-@ignore include("../test/runtests.jl")
-
-# TODO: Hack to force ConstructionBase version
-using ConstructionBase: ConstructionBase as _
-
-include("precompile.jl")
-redirect_stdout(devnull) do
- redirect_stderr(devnull) do
- do_precompilation(Val(:precompile))
- end
-end
-
-end #module SR
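
The tail of the removed module runs its precompilation workload with output silenced. The `redirect_stdout(devnull) do ... end` pattern is standard Julia (available on the Julia 1.10 floor this PR sets); a standalone sketch, with a hypothetical noisy function in place of `do_precompilation`:

```julia
# Run a noisy warm-up workload while discarding everything it prints.
# `noisy_warmup` is a hypothetical stand-in for `do_precompilation`.
noisy_warmup() = (println(stdout, "warming up..."); println(stderr, "noise"); nothing)

redirect_stdout(devnull) do
    redirect_stderr(devnull) do
        noisy_warmup()  # side effects still happen; output goes to devnull
    end
end
```
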
diff --git a/test/runtests.jl b/test/runtests.jl
index 52a941def..0a484650d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -178,4 +178,4 @@ end
@testitem "LLM Integration tests" tags = [:part3, :llm] begin
include("test_lasr_integration.jl")
-end
\ No newline at end of file
+end
diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl
index 78cfddbef..e9d2a0a89 100644
--- a/test/test_expression_derivatives.jl
+++ b/test/test_expression_derivatives.jl
@@ -37,7 +37,8 @@ end
@testitem "Test derivatives during optimization" tags = [:part1] begin
using LibraryAugmentedSymbolicRegression
- using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: Evaluator, GradEvaluator
+ using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule:
+ Evaluator, GradEvaluator
using DynamicExpressions
using Zygote: Zygote
using Random: MersenneTwister
diff --git a/test/test_graph_nodes.jl b/test/test_graph_nodes.jl
index 82a612d60..4f3ebb2fa 100644
--- a/test/test_graph_nodes.jl
+++ b/test/test_graph_nodes.jl
@@ -59,7 +59,8 @@ end
@testitem "GraphNode break connection mutation" tags = [:part1] begin
using LibraryAugmentedSymbolicRegression
- using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: break_random_connection!
+ using LibraryAugmentedSymbolicRegression.MutationFunctionsModule:
+ break_random_connection!
using Random: MersenneTwister
options = Options(;
@@ -92,7 +93,8 @@ end
@testitem "GraphNode form connection mutation" tags = [:part1] begin
using LibraryAugmentedSymbolicRegression
- using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: form_random_connection!
+ using LibraryAugmentedSymbolicRegression.MutationFunctionsModule:
+ form_random_connection!
using Random: MersenneTwister
options = Options(;
diff --git a/test/test_lasr_integration.jl b/test/test_lasr_integration.jl
index 0274bc103..f40cfbc52 100644
--- a/test/test_lasr_integration.jl
+++ b/test/test_lasr_integration.jl
@@ -1,13 +1,13 @@
using LibraryAugmentedSymbolicRegression: LLMOptions, Options
# test that we can partially specify LLMOptions
-op1 = LLMOptions(active=false)
+op1 = LLMOptions(; active=false)
@test op1.active == false
# test that we can fully specify LLMOptions
-op2 = LLMOptions(
+op2 = LLMOptions(;
active=true,
- weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2),
+ weights=LLMWeights(; llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2),
num_pareto_context=5,
prompt_evol=true,
prompt_concepts=true,
@@ -18,12 +18,14 @@ op2 = LLMOptions(
llm_recorder_dir="test/",
llm_context="test",
var_order=nothing,
- idea_threshold=30
+ idea_threshold=30,
)
@test op2.active == true
# test that we can pass LLMOptions to Options
-llm_opt = LLMOptions(active=false)
-op = Options(; optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16), llm_options=llm_opt)
+llm_opt = LLMOptions(; active=false)
+op = Options(;
+ optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16), llm_options=llm_opt
+)
@test isa(op.llm_options, LLMOptions)
println("Passed.")
diff --git a/test/test_nested_constraints.jl b/test/test_nested_constraints.jl
index 59e89863a..c4527b3ed 100644
--- a/test/test_nested_constraints.jl
+++ b/test/test_nested_constraints.jl
@@ -34,21 +34,33 @@ tree = cos(exp(Node("x1")) + exp(exp(Node("x1") + exp(exp(exp(Node("x1")))))))
x1 = Node("x1")
options = create_options(nothing)
tree = cos(cos(x1)) + cos(x1) + exp(cos(x1))
-@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
options = create_options([cos => [cos => 0]])
-@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
options = create_options([cos => [cos => 1]])
-@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
options = create_options([cos => [exp => 0]])
-@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
options = create_options([exp => [cos => 0]])
-@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
options = create_options([(+) => [(+) => 0]])
-@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options)
+@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(
+ tree, options
+)
println("Passed.")
diff --git a/test/test_operators.jl b/test/test_operators.jl
index 1d3bf614e..384d4c668 100644
--- a/test/test_operators.jl
+++ b/test/test_operators.jl
@@ -79,7 +79,9 @@ end
],
)
for T in types_to_test
- @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(T, options)
+ @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(
+ T, options
+ )
end
end
@@ -90,7 +92,9 @@ end
unary_operators=[square, cube, log, log2, log10, log1p, sqrt, acosh, neg],
)
for T in types_to_test
- @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(T, options)
+ @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(
+ T, options
+ )
end
end
@@ -115,7 +119,9 @@ end
@test_throws "returned an output of type" LibraryAugmentedSymbolicRegression.assert_operators_well_defined(
Float64, options
)
- @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(Float32, options)
+ @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(
+ Float32, options
+ )
end
@testset "Turbo mode should be the same" begin
diff --git a/test/test_prob_pick_first.jl b/test/test_prob_pick_first.jl
index dbd5cc1cc..9b21b7a06 100644
--- a/test/test_prob_pick_first.jl
+++ b/test/test_prob_pick_first.jl
@@ -38,8 +38,9 @@ for reverse in [false, true]
options=options
)
best_pop_member = [
- LibraryAugmentedSymbolicRegression.best_of_sample(pop, dummy_running_stats, options).score for
- j in 1:100
+ LibraryAugmentedSymbolicRegression.best_of_sample(
+ pop, dummy_running_stats, options
+ ).score for j in 1:100
]
mean_value = sum(best_pop_member) / length(best_pop_member)
diff --git a/test/test_units.jl b/test/test_units.jl
index c7f85130e..6046f8d21 100644
--- a/test/test_units.jl
+++ b/test/test_units.jl
@@ -1,7 +1,8 @@
@testitem "Dimensional analysis" tags = [:part3] begin
using LibraryAugmentedSymbolicRegression
using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: get_units
- using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+ using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule:
+ violates_dimensional_constraints
using DynamicQuantities
using DynamicQuantities: DEFAULT_DIM_BASE_TYPE
@@ -102,7 +103,8 @@ end
@testitem "Search with dimensional constraints" tags = [:part3] begin
using LibraryAugmentedSymbolicRegression
- using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+ using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule:
+ violates_dimensional_constraints
using Random: MersenneTwister
rng = MersenneTwister(0)
@@ -392,7 +394,8 @@ end
@testitem "Dimensionless constants" tags = [:part3] begin
using LibraryAugmentedSymbolicRegression
- using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints
+ using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule:
+ violates_dimensional_constraints
using DynamicQuantities
include("utils.jl")
@@ -435,9 +438,11 @@ end
@testitem "Miscellaneous tests of unit interface" tags = [:part3] begin
using LibraryAugmentedSymbolicRegression
using DynamicQuantities
- using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity
+ using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule:
+ @maybe_return_call, WildcardQuantity
using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: unwrap_units_single
- using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: get_dimensions_type
+ using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule:
+ get_dimensions_type
using MLJModelInterface: MLJModelInterface as MMI
function test_return_call(op::Function, w...)