update documentation #8

Merged (17 commits) on Sep 17, 2024
3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
@@ -30,8 +30,7 @@ jobs:
- "part2"
- "part3"
julia-version:
- "1.6"
- "1.8"
- "1.10"
- "1"
os:
- ubuntu-latest
14 changes: 14 additions & 0 deletions CITATION.md
@@ -30,3 +30,17 @@ To cite symbolic distillation of neural networks, the following BibTeX entry can
primaryClass={cs.LG}
}
```

To cite LaSR, please use the following BibTeX entry:

```bibtex
@misc{grayeli2024symbolicregressionlearnedconcept,
title={Symbolic Regression with a Learned Concept Library},
author={Arya Grayeli and Atharva Sehgal and Omar Costilla-Reyes and Miles Cranmer and Swarat Chaudhuri},
year={2024},
eprint={2409.09359},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2409.09359},
}
```
6 changes: 3 additions & 3 deletions Project.toml
@@ -52,7 +52,7 @@ Distributed = "<0.0.1, 1"
DynamicExpressions = "0.19.3"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14, 1"
Enzyme = "0.12"
JSON = "0.21.4"
JSON = "0.21"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.10, 0.11"
@@ -64,14 +64,14 @@ Pkg = "<0.0.1, 1"
PrecompileTools = "1"
Printf = "<0.0.1, 1"
ProgressBars = "~1.4, ~1.5"
PromptingTools = "0.54.0"
PromptingTools = "0.53, 0.54"
Random = "<0.0.1, 1"
Reexport = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
TOML = "<0.0.1, 1"
julia = "1.6"
julia = "1.10"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
198 changes: 40 additions & 158 deletions README.md
@@ -1,30 +1,46 @@
<!-- prettier-ignore-start -->
<div align="center">

LaSR.jl accelerates the search for symbolic expressions using library learning.
LibraryAugmentedSymbolicRegression.jl (LaSR.jl) accelerates the search for symbolic expressions using library learning.

| Latest release | Website | Forums | Paper |
| :---: | :---: | :---: | :---: |
| [![version](https://juliahub.com/docs/LaSR/version.svg)](https://juliahub.com/ui/Packages/LaSR/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://trishullab.github.io/lasr-web/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/trishullab/LaSR.jl/discussions) | [![Paper](https://img.shields.io/badge/arXiv-????.?????-b31b1b)](https://atharvas.net/static/lasr.pdf) |
| [![version](https://juliahub.com/docs/LibraryAugmentedSymbolicRegression/version.svg)](https://juliahub.com/ui/Packages/LibraryAugmentedSymbolicRegression/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://trishullab.github.io/lasr-web/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/discussions) | [![Paper](https://img.shields.io/badge/arXiv-2409.09359-b31b1b)](https://arxiv.org/abs/2409.09359) |

| Build status | Coverage |
| :---: | :---: |
| [![CI](https://github.com/trishullab/LaSR.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/trishullab/LaSR.jl/badge.svg?branch=master)](https://coveralls.io/github/trishullab/LaSR.jl?branch=master) |
| [![CI](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/trishullab/LibraryAugmentedSymbolicRegression.jl/badge.svg?branch=master)](https://coveralls.io/github/trishullab/LibraryAugmentedSymbolicRegression.jl?branch=master) |

LaSR is integrated with [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). Check out [PySR](https://github.com/MilesCranmer/PySR) for
a Python frontend.

[Cite this software](https://arxiv.org/abs/????.?????)
[Cite this software](https://arxiv.org/abs/2409.09359)

</div>
<!-- prettier-ignore-end -->

**Contents**:

- [Benchmarking](#benchmarking)
- [Quickstart](#quickstart)
- [Organization](#organization)
- [LLM Utilities](#llm-utilities)

## Benchmarking

If you'd like to compare with LaSR, we've archived the code used in the paper in the `lasr-experiments` branch. Clone this repository and run:
```bash
$ git switch lasr-experiments
```
to switch to the branch, then follow the instructions in its README to reproduce our results. This branch contains the code for evaluating LaSR on the

- [x] Feynman Equations dataset
- [x] Synthetic equations dataset
  - [x] and generation code
- [x] Bigbench experiments
  - [x] and evaluation code


## Quickstart

Install in Julia with:
@@ -34,7 +50,7 @@ using Pkg
Pkg.add("LibraryAugmentedSymbolicRegression")
```

LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). The easiest way to use LaSR.jl
LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). The easiest way to use LibraryAugmentedSymbolicRegression.jl
is with [MLJ](https://github.com/alan-turing-institute/MLJ.jl).
Let's see an example:

@@ -105,173 +121,39 @@ where here we choose to evaluate the second equation.

For fitting multiple outputs, one can use `MultitargetLaSRRegressor`
(and pass an array of indices to `idx` in `predict` for selecting specific equations).
For a full list of options available to each regressor, see the [API page](https://astroautomata.com/LaSR.jl/dev/api/).
For a full list of options available to each regressor, see the [API page](https://astroautomata.com/LibraryAugmentedSymbolicRegression.jl/dev/api/).
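Building on the MLJ workflow described above, selecting a specific equation at prediction time might look like the following sketch (the `LaSRRegressor` name and the `idx` keyword come from the surrounding text; the dataset, operators, and iteration count are illustrative assumptions):

```julia
using MLJ  # assumes MLJ and LibraryAugmentedSymbolicRegression are installed

# Illustrative data: two features, one target
X = (x1=rand(100), x2=rand(100))
y = @. 2.0 * cos(X.x1 * 23.5) - X.x2^2

model = LaSRRegressor(;
    niterations=40,
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
mach = machine(model, X, y)
fit!(mach)

report(mach).equations          # Pareto front of discovered equations
predict(mach, X)                # predict with the best equation
predict(mach, (data=X, idx=2))  # evaluate the second equation instead
```

The same pattern should carry over to `MultitargetLaSRRegressor`, with an array of indices passed as `idx`, one per output.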

### LLM Options

LaSR uses PromptingTools.jl for zero-shot prompting. To change the prompting behavior, pass an `LLMOptions` object to the `LaSRRegressor` constructor. The available options are:
```julia
llm_options = LLMOptions(
...
active=true, # Whether to use LLM inference or not
weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2), # Probabilities of using LLM for mutation, crossover, and random generation
num_pareto_context=5, # Number of equations to sample from the Pareto frontier for summarization.
prompt_evol=true, # Whether to evolve natural language concepts through LLM calls.
prompt_concepts=true, # Whether to use natural language concepts in the search.
api_key="token-abc123", # API key to OpenAI API compatible server.
model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use.
api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to server.
http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), # Keyword arguments passed to HTTP requests.
llm_recorder_dir="lasr_runs/debug_0", # Directory to log LLM interactions.
llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts.
var_order=nothing, # Dict(variable_name => new_name).
idea_threshold=30 # Number of concepts to keep track of.
)
```
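Putting the options above together, wiring an LLM-backed search might look like this sketch (assumes a local OpenAI-compatible server as in the options above; passing the struct via an `llm_options` keyword on the regressor is an assumption based on the constructor description):

```julia
using LibraryAugmentedSymbolicRegression: LaSRRegressor, LLMOptions, LLMWeights
using MLJ

# Only a subset of the options shown above; the rest keep their defaults.
llm_options = LLMOptions(;
    active=true,
    weights=LLMWeights(; llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2),
    api_key="token-abc123",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    api_kwargs=Dict("url" => "http://localhost:11440/v1"),
)

model = LaSRRegressor(;
    niterations=40,
    binary_operators=[+, -, *],
    llm_options=llm_options,  # hypothetical keyword; see the constructor docs
)
X = (x1=rand(100), x2=rand(100))
y = @. X.x1 * X.x2
mach = machine(model, X, y)
fit!(mach)  # with `active=true`, LLM calls are issued during the search
```

With `active=false`, the search should fall back to purely symbolic mutations, matching SymbolicRegression.jl's behavior.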


## Organization

LaSR.jl development is kept independent from the main codebase. However, to ensure LaSR can be used easily, it is integrated into SymbolicRegression.jl via the `ext/SymbolicRegressionLaSRExt` extension module. This, in turn, is loaded into PySR. This cartoon summarizes the interaction between the different packages:

![LaSR.jl organization](https://raw.githubusercontent.com/trishullab/lasr-web/main/static/lasr-code-interactions.svg)

## Code structure

LaSR.jl is organized roughly as follows.
Rounded rectangles indicate objects, and rectangles indicate functions.

> (if you can't see this diagram being rendered, try pasting it into [mermaid-js.github.io/mermaid-live-editor](https://mermaid-js.github.io/mermaid-live-editor))

```mermaid
flowchart TB
op([Options])
d([Dataset])
op --> ES
d --> ES
subgraph ES[equation_search]
direction TB
IP[sr_spawner]
IP --> p1
IP --> p2
subgraph p1[Thread 1]
direction LR
pop1([Population])
pop1 --> src[s_r_cycle]
src --> opt[optimize_and_simplify_population]
opt --> pop1
end
subgraph p2[Thread 2]
direction LR
pop2([Population])
pop2 --> src2[s_r_cycle]
src2 --> opt2[optimize_and_simplify_population]
opt2 --> pop2
end
pop1 --> hof
pop2 --> hof
hof([HallOfFame])
hof --> migration
pop1 <-.-> migration
pop2 <-.-> migration
migration[migrate!]
end
ES --> output([HallOfFame])
```
LibraryAugmentedSymbolicRegression.jl development is kept independent from the main codebase. However, to ensure LaSR can be used easily, it is integrated into SymbolicRegression.jl via the [`ext/SymbolicRegressionLaSRExt`](https://www.example.com) extension module. This, in turn, is loaded into PySR. This cartoon summarizes the interaction between the different packages:

The `HallOfFame` objects store the expressions with the lowest loss seen at each complexity.

The dependency structure of the code itself is as follows:

```mermaid
stateDiagram-v2
AdaptiveParsimony --> Mutate
AdaptiveParsimony --> Population
AdaptiveParsimony --> RegularizedEvolution
AdaptiveParsimony --> SingleIteration
AdaptiveParsimony --> LaSR
CheckConstraints --> Mutate
CheckConstraints --> LaSR
Complexity --> CheckConstraints
Complexity --> HallOfFame
Complexity --> LossFunctions
Complexity --> Mutate
Complexity --> Population
Complexity --> SearchUtils
Complexity --> SingleIteration
Complexity --> LaSR
ConstantOptimization --> Mutate
ConstantOptimization --> SingleIteration
Core --> AdaptiveParsimony
Core --> CheckConstraints
Core --> Complexity
Core --> ConstantOptimization
Core --> HallOfFame
Core --> InterfaceDynamicExpressions
Core --> LossFunctions
Core --> Migration
Core --> Mutate
Core --> MutationFunctions
Core --> PopMember
Core --> Population
Core --> Recorder
Core --> RegularizedEvolution
Core --> SearchUtils
Core --> SingleIteration
Core --> LaSR
Dataset --> Core
HallOfFame --> SearchUtils
HallOfFame --> SingleIteration
HallOfFame --> LaSR
InterfaceDynamicExpressions --> LossFunctions
InterfaceDynamicExpressions --> LaSR
LossFunctions --> ConstantOptimization
LossFunctions --> HallOfFame
LossFunctions --> Mutate
LossFunctions --> PopMember
LossFunctions --> Population
LossFunctions --> LaSR
Migration --> LaSR
Mutate --> RegularizedEvolution
MutationFunctions --> Mutate
MutationFunctions --> Population
MutationFunctions --> LaSR
Operators --> Core
Operators --> Options
Options --> Core
OptionsStruct --> Core
OptionsStruct --> Options
PopMember --> ConstantOptimization
PopMember --> HallOfFame
PopMember --> Migration
PopMember --> Mutate
PopMember --> Population
PopMember --> RegularizedEvolution
PopMember --> SingleIteration
PopMember --> LaSR
Population --> Migration
Population --> RegularizedEvolution
Population --> SearchUtils
Population --> SingleIteration
Population --> LaSR
ProgramConstants --> Core
ProgramConstants --> Dataset
ProgressBars --> SearchUtils
ProgressBars --> LaSR
Recorder --> Mutate
Recorder --> RegularizedEvolution
Recorder --> SingleIteration
Recorder --> LaSR
RegularizedEvolution --> SingleIteration
SearchUtils --> LaSR
SingleIteration --> LaSR
Utils --> CheckConstraints
Utils --> ConstantOptimization
Utils --> Options
Utils --> PopMember
Utils --> SingleIteration
Utils --> LaSR
```
![LibraryAugmentedSymbolicRegression.jl organization](https://raw.githubusercontent.com/trishullab/lasr-web/main/static/lasr-code-interactions.svg)

Bash command to generate dependency structure from `src` directory (requires `vim-stream`):
> [!NOTE]
> The `ext/SymbolicRegressionLaSRExt` module is not yet available in the released version of SymbolicRegression.jl. It will be available in the release `vX.X.X` of SymbolicRegression.jl.

```bash
echo 'stateDiagram-v2'
IFS=$'\n'
for f in *.jl; do
for line in $(cat $f | grep -e 'import \.\.' -e 'import \.'); do
echo $(echo $line | vims -s 'dwf:d$' -t '%s/^\.*//g' '%s/Module//g') $(basename "$f" .jl);
done;
done | vims -l 'f a--> ' | sort
```

## Search options

See https://astroautomata.com/LaSR.jl/stable/api/#Options
Other than `LLMOptions`, we have the same search options as SymbolicRegression.jl. See https://astroautomata.com/SymbolicRegression.jl/stable/api/#Options
File renamed without changes.
File renamed without changes.
@@ -1,8 +1,10 @@
module LaSRSymbolicUtilsExt

using SymbolicUtils: Symbolic
using LibraryAugmentedSymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options
using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options
using LibraryAugmentedSymbolicRegression:
AbstractExpressionNode, AbstractExpression, Node, Options
using LibraryAugmentedSymbolicRegression.MLJInterfaceModule:
AbstractSRRegressor, get_options
using DynamicExpressions: get_tree, get_operators

import LibraryAugmentedSymbolicRegression: node_to_symbolic, symbolic_to_node
4 changes: 3 additions & 1 deletion src/Configure.jl
@@ -257,7 +257,9 @@ function test_module_on_workers(procs, options::Options, verbosity)
for proc in procs
push!(
futures,
@spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE)
@spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(
3, options, 5, TEST_TYPE
)
)
end
for future in futures
18 changes: 12 additions & 6 deletions src/LLMFunctions.jl
@@ -21,7 +21,7 @@ using DynamicExpressions:
AbstractOperatorEnum
using Compat: Returns, @inline
using ..CoreModule: Options, DATA_TYPE, binopmap, unaopmap, LLMOptions
using ..MutationFunctionsModule: gen_random_tree_fixed_size
using ..MutationFunctionsModule: gen_random_tree_fixed_size, random_node_and_parent

using PromptingTools:
SystemMessage,
@@ -58,12 +58,13 @@ function convertDict(d)::NamedTuple
end

function get_vars(options::Options)::String
variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
if !isnothing(options.llm_options.var_order)
if !isnothing(options.llm_options) && !isnothing(options.llm_options.var_order)
variable_names = [
options.llm_options.var_order[key] for
key in sort(collect(keys(options.llm_options.var_order)))
]
else
variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
end
return join(variable_names, ", ")
end
@@ -105,6 +106,7 @@ function construct_prompt(
# if n_occurrences is less than |element_list|, add the missing elements after the last occurrence
if n_occurrences < length(element_list)
last_occurrence = findlast(x -> occursin(pattern, x), lines)
@assert last_occurrence !== nothing "No occurrences of the element_id_tag found in the user prompt."
for i in reverse((n_occurrences + 1):length(element_list))
new_line = replace(lines[last_occurrence], string(n_occurrences) => string(i))
insert!(lines, last_occurrence + 1, new_line)
@@ -544,6 +546,9 @@ function parse_msg_content(msg_content)

try
out = parse(content) # json parse
if out === nothing
return []
end
if out isa Dict
return [out[key] for key in keys(out)]
end
@@ -893,7 +898,7 @@ function llm_crossover_trees(
String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])),
options,
)

llm_recorder(options.llm_options, tree_to_expr(t, options), "crossover")

return t, tree2
@@ -934,10 +939,11 @@
)
end

recording_str = tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options)
recording_str =
tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options)
llm_recorder(options.llm_options, recording_str, "crossover")

return cross_tree1, cross_tree2
end

end
end