Ensure simple program works #24

Merged · 9 commits · Nov 8, 2024
114 changes: 111 additions & 3 deletions README.md
@@ -60,6 +60,9 @@ LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/Miles

For example, we can modify the `example.jl` from the SymbolicRegression.jl documentation to use LaSR as follows:

> [!NOTE]
> LaSR searches for the LLM query prompts in a directory called `prompts/` at the location where you start Julia. You can download and extract the `prompts.zip` archive from [here](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/raw/refs/heads/master/prompts.zip) to the desired location (a download sketch follows the example below). If you wish to use a different location, you can pass a different `prompts_dir` argument to the `LLMOptions` object.

```julia
import LibraryAugmentedSymbolicRegression: LaSRRegressor, LLMOptions, LLMWeights
import MLJ: machine, fit!, predict, report
@@ -80,17 +83,22 @@ model = LaSRRegressor(
    llm_options=LLMOptions(
        active=true,
        weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
-       promtp_evol=true,
+       prompt_evol=true,
        prompt_concepts=true,
        api_key="token-abc123",
        prompts_dir="prompts/",
        llm_recorder_dir="lasr_runs/debug_0/",
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        api_kwargs=Dict("url" => "http://localhost:11440/v1"),
-       var_order=Dict("a" => "angle", "b" => "bias")
+       var_order=Dict("a" => "angle", "b" => "bias"),
+       llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias.",
    )
)
mach = machine(model, X, y)

# ensure ./prompts/ exists. If not, download and extract the prompts.zip file from the repository.
fit!(mach)
# open ./lasr_runs/debug_0/llm_calls.txt to see the LLM interactions.
report(mach)
predict(mach, X)
```
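
As the note above mentions, LaSR expects a `prompts/` directory at the location where Julia is started. Below is a minimal sketch of fetching and extracting `prompts.zip` from within Julia; it assumes a Unix-like system with `unzip` on the PATH, and any other way of extracting the archive works just as well.

```julia
# Minimal sketch: download prompts.zip into the current working directory
# and extract it, so that ./prompts/ exists where Julia was started.
using Downloads

url = "https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl/raw/refs/heads/master/prompts.zip"
zip_path = Downloads.download(url, "prompts.zip")

# Assumes the `unzip` binary is available; use your preferred extraction tool otherwise.
run(`unzip -o $zip_path`)
```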
@@ -113,10 +121,12 @@ llm_options = LLMOptions(
```julia
    model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use.
    api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to server.
    http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), # Keyword arguments passed to HTTP requests.
-   llm_recorder_dir="lasr_runs/debug_0", # Directory to log LLM interactions.
+   prompts_dir="prompts/", # Directory to look for zero shot prompts to the LLM.
+   llm_recorder_dir="lasr_runs/debug_0/", # Directory to log LLM interactions.
    llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts.
    var_order=nothing, # Dict(variable_name => new_name).
    idea_threshold=30 # Number of concepts to keep track of.
+   is_parametric=false, # This is a special flag to allow sampling parametric equations from LaSR. This won't be needed for most users.
)
```
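
Since `llm_recorder_dir` is where the LLM interactions are logged, a quick way to inspect a finished run from the REPL is to read the log back in. A small sketch, assuming the `llm_calls.txt` file name mentioned in the example comment above:

```julia
# Print the recorded LLM interactions from the run configured above.
# The llm_calls.txt name is taken from the earlier example comment; adjust if yours differs.
log_path = joinpath("lasr_runs/debug_0", "llm_calls.txt")
isfile(log_path) && println(read(log_path, String))
```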

Expand All @@ -135,3 +145,101 @@ LibraryAugmentedSymbolicRegression.jl development is kept independent from the m

> [!NOTE]
> The `ext/SymbolicRegressionLaSRExt` module is not yet available in the released version of SymbolicRegression.jl. It will be available in the release `vX.X.X` of SymbolicRegression.jl.



## Running with Ollama

LaSR can be paired with any LLM server that is compatible with OpenAI's API. Ollama is a free and open-source LLM server geared towards running LLMs on commodity laptops. You can download and set up Ollama from [here](https://ollama.com/download). After installing it, run:

```bash
$ ollama help
Large language model runner

Usage:
ollama [flags]
ollama [command]

Available Commands:
serve Start ollama
create Create a model from a Modelfile
show Show information for a model
run Run a model
stop Stop a running model
pull Pull a model from a registry
push Push a model to a registry
list List models
ps List running models
cp Copy a model
rm Remove a model
help Help about any command

Flags:
-h, --help help for ollama
-v, --version Show version information

Use "ollama [command] --help" for more information about a command.
$ ollama pull llama3.1
# This downloads a 4GB-ish file that contains the Llama3.1 8B model.
# Ollama, by default, runs on port 11434 of your local machine. Let's try a debug query to make sure we can connect to Ollama.
$ curl http://localhost:11434/v1/models
{"object":"list","data":[{"id":"llama3.1:latest","object":"model","created":1730973855,"owned_by":"library"},{"id":"mistral:latest","object":"model","created":1697556753,"owned_by":"library"},{"id":"wizard-math:latest","object":"model","created":1697556753,"owned_by":"library"},{"id":"codellama:latest","object":"model","created":1693414395,"owned_by":"library"},{"id":"nous-hermes-llama2:latest","object":"model","created":1691000950,"owned_by":"library"}]}

$ curl http://localhost:11434/v1/completions -H "Content-Type: application/json" -H "Authorization: Bearer token-abc123" -d '{
"model": "llama3.1:latest",
"prompt": "Once upon a time,",
"max_tokens": 50,
"temperature": 0.7
}'

{"id":"cmpl-626","object":"text_completion","created":1730977391,"model":"llama3.1:latest","system_fingerprint":"fp_ollama","choices":[{"text":"...in a far-off kingdom, hidden behind a veil of sparkling mist and whispering leaves, there existed a magical realm unlike any other.","index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":15,"completion_tokens":29,"total_tokens":44}}
```
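
The same sanity check can be done from Julia itself. A small sketch, assuming the HTTP.jl package is installed in your environment; it mirrors the `curl http://localhost:11434/v1/models` call above:

```julia
# Query the local Ollama server's OpenAI-compatible models endpoint.
using HTTP

resp = HTTP.get("http://localhost:11434/v1/models")
println(String(resp.body))  # should list llama3.1:latest among the available models
```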

Now, we can run the simple example in Julia with `model` set to `llama3.1:latest` and the API URL set to `http://localhost:11434/v1`:

```julia
import LibraryAugmentedSymbolicRegression: LaSRRegressor, LLMOptions, LLMWeights
import MLJ: machine, fit!, predict, report

# Dataset with two named features:
X = (a = rand(500), b = rand(500))

# and one target:
y = @. 2 * cos(X.a * 23.5) - X.b ^ 2

# with some noise:
y = y .+ randn(500) .* 1e-3

model = LaSRRegressor(
    niterations=50,
    binary_operators=[+, -, *],
    unary_operators=[cos],
    llm_options=LLMOptions(
        active=true,
        weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
        prompt_evol=true,
        prompt_concepts=true,
        api_key="token-abc123",
        prompts_dir="prompts/",
        llm_recorder_dir="lasr_runs/debug_0/",
        model="llama3.1:latest",
        api_kwargs=Dict("url" => "http://127.0.0.1:11434/v1"),
        var_order=Dict("a" => "angle", "b" => "bias"),
        llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias."
    )
)

mach = machine(model, X, y)
fit!(mach)
# julia> fit!(mach)
# [ Info: Training machine(LaSRRegressor(binary_operators = Function[+, -, *], …), …).
# ┌ Warning: You are using multithreading mode, but only one thread is available. Try starting julia with `--threads=auto`.
# └ @ LibraryAugmentedSymbolicRegression ~/Desktop/projects/004_scientific_discovery/LibraryAugmentedSymbolicRegression.jl/src/Configure.jl:55
# [ Info: Tokens: 476 in 22.4 seconds
# [ Info: Started!
# [ Info: Tokens: 542 in 49.2 seconds
# [ Info: Tokens: 556 in 51.1 seconds
# [ Info: Tokens: 573 in 53.2 seconds
report(mach)
predict(mach, X)
```
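
Once the fit finishes, the discovered expressions can be pulled out of the report. A small sketch, assuming LaSR exposes the same report fields as SymbolicRegression.jl's MLJ interface (`equation_strings` and `best_idx` are assumed names; adjust if your report differs):

```julia
# Inspect the best expression found during the search.
r = report(mach)
println(r.equation_strings[r.best_idx])
```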
Binary file added prompts.zip
Binary file not shown.
11 changes: 6 additions & 5 deletions src/LLMFunctions.jl
@@ -180,7 +180,7 @@ function gen_llm_random_tree(
http_kwargs=convertDict(options.llm_options.http_kwargs),
)
catch e
llm_recorder(options.llm_options, "None", "gen_random|failed")
llm_recorder(options.llm_options, "None " * string(e), "gen_random|failed")
return gen_random_tree_fixed_size(node_count, options, nfeatures, T)
end
llm_recorder(options.llm_options, string(msg.content), "llm_output|gen_random")
@@ -499,7 +499,7 @@ function prompt_evol(idea_database, options::Options)
http_kwargs=convertDict(options.llm_options.http_kwargs),
)
catch e
llm_recorder(options.llm_options, "None", "ideas|failed")
llm_recorder(options.llm_options, "None " * string(e), "ideas|failed")
return nothing
end
llm_recorder(options.llm_options, string(msg.content), "llm_output|ideas")
@@ -648,7 +648,7 @@ function update_idea_database(idea_database, dominating, worst_members, options:
http_kwargs=convertDict(options.llm_options.http_kwargs),
)
catch e
llm_recorder(options.llm_options, "None", "ideas|failed")
llm_recorder(options.llm_options, "None " * string(e), "ideas|failed")
return nothing
end

@@ -754,7 +754,8 @@ function llm_mutate_op(
http_kwargs=convertDict(options.llm_options.http_kwargs),
)
catch e
llm_recorder(options.llm_options, "None", "mutate|failed")
llm_recorder(options.llm_options, "None " * string(e), "mutate|failed")
# log error in llm_recorder
return tree
end

@@ -869,7 +870,7 @@ function llm_crossover_trees(
http_kwargs=convertDict(options.llm_options.http_kwargs),
)
catch e
llm_recorder(options.llm_options, "None", "crossover|failed")
llm_recorder(options.llm_options, "None " * string(e), "crossover|failed")
return tree1, tree2
end

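For readers skimming the diff: each change in `src/LLMFunctions.jl` above follows the same pattern, where a failed LLM call now records the exception text before falling back to a non-LLM operation. A schematic sketch of that pattern is below; the `with_llm_fallback` helper is illustrative and not part of the package, and `llm_recorder` is the package's logger shown in the diff.

```julia
# Schematic of the error-handling pattern introduced in this diff (illustrative only).
function with_llm_fallback(llm_call, fallback, options, tag)
    try
        return llm_call()
    catch e
        # Record the exception text so failed LLM calls can be debugged from the log.
        llm_recorder(options.llm_options, "None " * string(e), tag * "|failed")
        return fallback()
    end
end
```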
2 changes: 1 addition & 1 deletion src/LLMOptions.jl
@@ -88,7 +88,7 @@ function validate_llm_options(options::LLMOptions)
throw(ArgumentError("api_kwargs must have a 'url' key."))
end
if !isdir(options.prompts_dir)
throw(ArgumentError("prompts_dir must be a valid directory."))
throw(ArgumentError("options.prompts_dir not found."))
end
end
end