Sodarace prompt mutation (#22)
* Merge with main; redo the changes: add prompt mutation components for the Sodaracer.

* Fix miscellaneous issues; make a small docstring change to match the rest of the codebase.

* Add `error_code`. Allow custom diff model class.

* Fix error-message parsing.

* Improve some variable names; add `Map.empty`; if the map is still empty, force `MAPElites` to call `.random()` instead of `.mutate()` (sketched below, after the file summary).

* Fix batch generation in `Sodarace`.

* Unify the API for the image task.

* Fix a careless mistake.

* Fix further careless mistakes.

* Fix typos.

* Clean up: use built-in types instead of `typing` aliases; add `__add__` to `__init__`; put all default seeds under `elm.environment`; clean up the API for `ELM` and the different environments, ...

* Error fixes.

* Reorder `image_init_args` so that it appears together with `sodarace_init_args`.

* Minor correction of a comment.

* Delete the old syntax-tree parsing file (it can easily be restored if the need arises); change the return of `._post_process` according to Genesis's remark.

* Minor fixes.

Co-authored-by: Herbie Bradley <[email protected]>
honglu2875 and herbiebradley authored Nov 25, 2022
1 parent bafaa1f commit 9df5ebf
Showing 17 changed files with 528 additions and 424 deletions.
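
For context on the `Map.empty` bullet in the commit message, the empty-map fallback might look roughly like the sketch below. The `Map` container and the `search_step` method here are illustrative assumptions, not the actual `elm.map_elites` implementation:

import numpy as np


class Map:
    """Illustrative niche grid; a cell holds at most one elite genome."""

    def __init__(self, dims: tuple):
        self.genomes = np.full(dims, None, dtype=object)

    @property
    def empty(self) -> bool:
        # True while no niche has been filled yet.
        return all(g is None for g in self.genomes.flatten())


class MAPElites:
    def __init__(self, env, genome_map: Map):
        self.env = env
        self.genome_map = genome_map

    def search_step(self):
        # If the map is still empty there is nothing to mutate,
        # so force a fresh random genome instead of calling .mutate().
        if self.genome_map.empty:
            return self.env.random()
        occupied = [g for g in self.genome_map.genomes.flatten() if g is not None]
        return self.env.mutate(occupied[np.random.randint(len(occupied))])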
16 changes: 8 additions & 8 deletions elm/benchmarks_tinygp.py
@@ -10,7 +10,7 @@
from copy import deepcopy
from random import randint, random
from statistics import mean
-from typing import Callable, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Optional, Union

from graphviz import Digraph, Source
from IPython.display import Image, display
@@ -144,13 +144,13 @@ def draw_tree(self, fname, footer):

    def compute_tree(self, b, arg_names=("b1", "b2", "b3", "b4")):
        """
-        Parameters:
+        Args:
            b: a list/tuple of inputs (b1, b2, b3, b4)
            arg_names: argument names.
        Returns:
            the evaluation at this node.
        """
-        if not isinstance(b, (List, Tuple)):
+        if not isinstance(b, (list, tuple)):
            raise TypeError(f"Input b must be a list or tuple. Got {type(b)} instead.")

        arg_dict = {name: value for name, value in zip(arg_names, b)}
@@ -237,7 +237,7 @@ def swap_node(
) -> bool:
    """
    Swap the name of a variable into another one (only applies to the first encounter in a DFS).
-    Parameters:
+    Args:
        tree: the GPTree node.
        tree_data: the variable name or the function.
        target_data: the target variable name or the function.
@@ -259,7 +259,7 @@
def eval_tree(tree: GPTree, dataset: Iterable) -> list:
    """
    Test the correctness of a GPTree against a dataset.
-    Parameters:
+    Args:
        tree: the tree to test against.
        dataset: (inputs, ground_truth)
    Returns:
@@ -287,11 +287,11 @@ def list_equal(l1, l2):


def mutate_compare(
-    tree: GPTree, num_mutation: int, dataset: Tuple
-) -> Tuple[float, float]:
+    tree: GPTree, num_mutation: int, dataset: tuple
+) -> tuple[float, float]:
    """
    Mutate (a copy of) the tree num_mutation times, and return the percentage of successful mutations.
-    Parameters:
+    Args:
        tree: the tree to mutate (will make a copy before mutation).
        num_mutation: number of times to mutate.
        dataset: the dataset to test against. Format: (input, ground_truth) where input and ground_truth are lists
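
The typing changes in this file follow PEP 585: from Python 3.9 onward the built-in `list` and `tuple` are themselves generic, so the `typing` aliases are redundant in annotations and were never needed for runtime `isinstance` checks. A minimal before/after sketch (the `_old`/`_new` wrappers are hypothetical and only mirror the signature change):

from typing import Tuple  # only needed for the old style


def mutate_compare_old(tree, num_mutation: int, dataset: Tuple) -> Tuple[float, float]:
    ...  # old annotations via typing aliases


def mutate_compare_new(tree, num_mutation: int, dataset: tuple) -> tuple[float, float]:
    ...  # new annotations via built-in generics (Python 3.9+)


# isinstance always wanted the concrete built-in classes:
assert isinstance((1.0, 2.0), (list, tuple))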
28 changes: 28 additions & 0 deletions elm/config/elm_image_cfg.yaml
@@ -0,0 +1,28 @@
model: codegen-350M-mono
checkpoints_dir: ./checkpoints/
cuda: True
gpus: 1
seed: 42
deterministic: False
fp16: False
top_p: 0.95
temp: 0.85
timeout: 5.0 # Seconds
gen_max_len: 1024
batch_size: 32
evo_init_steps: 10
evo_n_steps: 15
behavior_n_bins: 12
evo_history_length: 10
evaluation_steps: 1000
pad_token: 50256
env_name: "imageoptim"
run_name: ??? # Mandatory string argument that describes the run.

###################################################################################################
# Hydra config overrides:
hydra:
  run:
    dir: logs/${run_name}
  sweep:
    dir: logs/${run_name}
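
The `???` marker and the `${run_name}` interpolation above are OmegaConf conventions, which Hydra builds on. A rough illustration of how the file resolves when loaded directly (the run name here is an arbitrary placeholder):

from omegaconf import OmegaConf

cfg = OmegaConf.load("elm/config/elm_image_cfg.yaml")

# `run_name: ???` marks a mandatory field: reading it before it is set
# raises a MissingMandatoryValue error, so supply it first.
cfg.run_name = "image_demo"

print(cfg.model)          # codegen-350M-mono
print(cfg.batch_size)     # 32
print(cfg.hydra.run.dir)  # logs/image_demo, via ${run_name} interpolation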
6 changes: 3 additions & 3 deletions elm/config/elm_cfg.yaml → elm/config/elm_sodarace_cfg.yaml
@@ -9,9 +9,9 @@ top_p: 0.95
temp: 0.85
timeout: 5.0 # Seconds
gen_max_len: 1024
-batch_size: 1
-evo_init_steps: 1000
-evo_n_steps: 10000
+batch_size: 32
+evo_init_steps: 10
+evo_n_steps: 20
behavior_n_bins: 12
evo_history_length: 10
evaluation_steps: 1000
175 changes: 169 additions & 6 deletions elm/diff_model.py
@@ -1,7 +1,13 @@
import json
import os
import re
import shutil
-from typing import Dict
+from abc import ABC, abstractmethod

import numpy as np
import requests
import torch
from omegaconf import DictConfig, OmegaConf

from elm.codegen.codegen_utilities import model_setup, sample, set_seed, truncate
from elm.codegen.codex_execute import (
@@ -23,7 +29,7 @@ def reset_os_funcs(rmtree, rmdir, chdir):
def unsafe_execute(code_str: str, timeout: int = 5):
    if len(code_str) == 0 or "def " not in code_str:
        return 6  # No code found or no function found.
-    code_dct: Dict = {}
+    code_dct: dict = {}
    func_match = re.search(r"def (\w+)\s*\((.*?)\):", code_str)
    if func_match:
        func_name = func_match.group(1)
@@ -70,12 +76,169 @@ def unsafe_execute(code_str: str, timeout: int = 5):
        return 6  # Code fails to run - other error.


-class DiffModel:
-    def __init__(self, cfg) -> None:
-        self.cfg = cfg
class Model(ABC):
    @abstractmethod
    def generate_program(self, seed_str: str) -> dict:
        pass


class PromptMutationModel(Model):
    func_name: str  # the name of the function that we want to execute
    import_line: str  # the import lines we add to the code
    func_preamble: str  # the function definition, plus possibly a few initial lines, used to generate code
    return_line: str  # the return line we add to the end of the code

    def __init__(self, cfg, sandbox_server="http://localhost:5000") -> None:
        if isinstance(cfg, str):
            self.cfg = OmegaConf.load(cfg)
        elif isinstance(cfg, (dict, DictConfig)):
            self.cfg = DictConfig(cfg)
        else:
            raise ValueError

        set_seed(self.cfg.seed)
        # Use RNG to rotate random seeds during inference.
        self.rng = np.random.default_rng(seed=self.cfg.seed)
        self.sandbox_server = sandbox_server
        self.model, self.tokenizer = model_setup(self.cfg)

    def generate_prompt_str(
        self,
        seed: str,
        tokenizer=None,
        batch_size=None,
        append_return=True,
        without_trunc=True,
    ) -> list[str]:
        """
        Args:
            seed: the seed text.
            tokenizer: (Optional) assign only if you want to use a different tokenizer (default: None)
            batch_size: (Optional) override the batch size in config.
            append_return: (Optional) append a return line to the end of the code.
            without_trunc: (Optional) True if we don't apply the `truncate` function.
        Returns:
            a list of code(s) generated by the model.
        """
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        encoding = tokenizer(
            [seed + "\n\n" + self.func_preamble],
            truncation=True,
            padding=True,
            max_length=self.cfg.gen_max_len,
            return_tensors="pt",
        )

        cfg = OmegaConf.merge(
            self.cfg,
            {"batch_size": self.cfg.batch_size if batch_size is None else batch_size},
        )
        with torch.no_grad():
            completion = sample(cfg, self.model, self.tokenizer, encoding)
        # Reset random seed
        set_seed(int(self.rng.integers(0, 1e8)))

        if without_trunc:
            truncation = completion
        else:
            truncation = [
                truncate(code, print_num=float("inf"), only_local_scope=True)
                for code in completion
            ]

        truncation = [
            self.import_line + "\n" + self.func_preamble + "\n" + code
            for code in truncation
        ]

        if append_return:
            truncation = [code + "\n" + self.return_line for code in truncation]

        return truncation

    def generate_program(self, code: str) -> list[dict]:
        """
        Given a piece of code, perform prompt mutation, call the sandbox server to execute the generated code, and collect the results.
        Args:
            code: the full code string.
        Returns:
            a list of dicts, each containing the program string, the result object (or error message), and an `error_code`.
        """
        results = []
        for code in self.generate_prompt_str(code):
            resp = self._get_response(code, self.cfg.timeout)
            if resp.status_code == 200:
                return_dict = json.loads(resp.text)
                self._post_process(return_dict)
                error_code = "0"
            elif resp.status_code == 500:  # Internal server error
                try:
                    msg = json.loads(resp.text)
                    return_dict = {"program_str": code, "result_obj": msg["message"]}
                    error_code = msg["unsafe_execute_error_code"]
                except Exception as e:
                    return_dict = {"program_str": code, "result_obj": str(e)}
                    error_code = 6
            else:
                return_dict = {"program_str": code, "result_obj": resp.text}
                error_code = 6

            results.append({**return_dict, "error_code": error_code})

        return results

    @abstractmethod
    def _get_response(self, code: str, timeout: int) -> requests.models.Response:
        pass

    @abstractmethod
    def _post_process(self, response_dict: dict) -> dict:
        pass


class PromptMutationForSodarace(PromptMutationModel):
    func_name: str = "make_walker"
    import_line: str = "from .walker import walker_creator"
    func_preamble: str = f"def {func_name}():\n\twc = walker_creator()\n"
    return_line: str = "\treturn wc.get_walker()\n"

    def _get_response(self, code: str, timeout: int) -> requests.models.Response:
        return requests.post(
            f"{self.sandbox_server}/gen_racer",
            json={"code": code, "timeout": timeout},
            timeout=timeout,
        )

    def _post_process(self, response_dict: dict) -> dict:
        pass


class PromptMutationForImgTask(PromptMutationModel):
    func_name: str = "draw"
    import_line: str = "import math\nimport numpy as np"
    func_preamble: str = f'def {func_name}():\n\t"""Draw a yellow circle.\n\t"""\n\tpic = np.zeros((32, 32, 3))\n'
    return_line: str = ""

    def reset_shape(self, shape: tuple):
        self.func_preamble = f'def {self.func_name}():\n\t"""Draw a yellow circle.\n\t"""\n\tpic = np.zeros({shape})\n'

    def _get_response(self, code: str, timeout: int) -> requests.models.Response:
        return requests.post(
            f"{self.sandbox_server}/eval_imageoptim_func",
            json={"code": code, "func_name": self.func_name, "timeout": timeout},
            timeout=timeout,
        )

    def _post_process(self, response_dict: dict) -> dict:
        response_dict["result_obj"] = np.array(response_dict["result_obj"])
        return response_dict


# TODO: complete diff model (when it's available)
class DiffModel(Model):
    def __init__(self, cfg) -> None:
        raise NotImplementedError()

    def generate_prompt_str(self, seed, tokenizer):
        if self.cfg.env_name == "sodarace":
            encoding = tokenizer(
@@ -106,5 +269,5 @@ def generate_program(self, seed_str: str) -> dict:
        sodaracer_dict: dict = execution_result.to_dict()
        return {
            "program_str": seed_str,
-            "result_dict": sodaracer_dict,
+            "result_obj": sodaracer_dict,
        }
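
Tying the new classes together, usage might look like the sketch below. The seed program and the already-running sandbox server are assumptions about the surrounding setup; only the constructor arguments and `generate_program` come from this diff:

from elm.diff_model import PromptMutationForSodarace

# A hypothetical seed program; the repository ships real seeds under
# elm.environments.
SEED_CODE = """
def make_walker():
    wc = walker_creator()
    return wc.get_walker()
"""

model = PromptMutationForSodarace(
    "elm/config/elm_sodarace_cfg.yaml",      # cfg may be a path, a dict, or a DictConfig
    sandbox_server="http://localhost:5000",  # matches the constructor default
)

# Each result dict carries "program_str", "result_obj", and "error_code".
for result in model.generate_program(SEED_CODE):
    if result["error_code"] == "0":
        print("Sandbox accepted:", result["program_str"].splitlines()[0])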
39 changes: 25 additions & 14 deletions elm/elm.py
@@ -1,24 +1,35 @@
-from elm.diff_model import DiffModel
-from elm.environments import IMAGE_SEED, ImageOptim, Sodarace
-from elm.environments.sodaracer import SQUARE_SEED
+from elm.environments import ImageOptim, Sodarace, image_init_args, sodarace_init_args
from elm.map_elites import MAPElites

ENVS_DICT = {"sodarace": Sodarace, "imageoptim": ImageOptim}
ARG_DICT = {"sodarace": sodarace_init_args, "imageoptim": image_init_args}


class ELM:
-    def __init__(self, cfg) -> None:
+    def __init__(self, cfg, diff_model_cls=None, env_args: dict = None) -> None:
        """
        Args:
            cfg: the config (e.g. an OmegaConf object, which uses dots to access members).
            diff_model_cls: (Optional) the diff model class. Alternative models can be passed in for comparison.
            env_args: (Optional) the argument dict for the environment.
        """
        self.cfg = cfg
-        self.diff_model = DiffModel(self.cfg)
-        if self.cfg.env_name == "sodarace":
-            self.seed = SQUARE_SEED
-        elif self.cfg.env_name == "imageoptim":
-            self.seed = IMAGE_SEED
-        self.environment = ENVS_DICT[self.cfg.env_name](
-            seed=self.seed,
-            diff_model=self.diff_model,
-            eval_steps=self.cfg.evaluation_steps,
-        )

        # Get the defaults if `env_args` is not specified.
        if env_args is None:
            env_args = ARG_DICT[self.cfg.env_name]
        env_args["config"] = self.cfg  # Override the default environment config.

        # Override the diff model if `diff_model_cls` is specified.
        if diff_model_cls is not None:
            self.diff_model = diff_model_cls(self.cfg)
            env_args = {**env_args, "diff_model": self.diff_model}
        else:
            self.diff_model = None

        self.seed = env_args["seed"]

        self.environment = ENVS_DICT[self.cfg.env_name](**env_args)
        self.map_elites = MAPElites(
            self.environment,
            n_bins=self.cfg.behavior_n_bins,
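
With the reworked constructor, any class implementing the `Model` interface can be swapped in. A sketch of the two call styles (the config path and run name are illustrative assumptions):

from omegaconf import OmegaConf

from elm.diff_model import PromptMutationForSodarace
from elm.elm import ELM

cfg = OmegaConf.load("elm/config/elm_sodarace_cfg.yaml")
cfg.run_name = "sodarace_demo"

# Default: environment arguments come from ARG_DICT[cfg.env_name],
# and no diff model is constructed.
elm_default = ELM(cfg)

# Custom: pass a model class; ELM instantiates it with cfg and hands the
# instance to the environment as `diff_model`.
elm_custom = ELM(cfg, diff_model_cls=PromptMutationForSodarace)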
(Diffs for the remaining 12 changed files are not shown.)
