Source code for Generative Adversarial Bayesian Optimization (GABO) for Surrogate Objectives

Generative Adversarial Bayesian Optimization (GABO) for Surrogate Objectives

Offline model-based policy optimization seeks to optimize a learned surrogate objective function without querying the true oracle objective during optimization. However, inaccurate surrogate model predictions are frequently encountered in this setting. To address this limitation, we propose adaptive source critic regularization, which uses a Lipschitz-constrained source critic agent to constrain the optimization trajectory to regions where the surrogate performs well. We show that under certain assumptions for the continuous input space prior, we can dynamically adjust the strength of the source critic regularization. The resulting method consistently outperforms existing baselines on a number of different optimization tasks across a variety of domains. Our work provides a practical framework for offline policy optimization via source critic regularization.
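To make the idea of source critic regularization concrete, the following is a minimal sketch under simplifying assumptions, not the GABO implementation: a 1-D input, a hand-written 1-Lipschitz critic in place of a trained one, exhaustive search over a fixed candidate list in place of Bayesian optimization, and a fixed penalty weight `alpha` in place of the dynamically adjusted strength described above. All names (`critic_penalty`, `regularized_objective`, `select_candidate`) are hypothetical and do not correspond to functions in this repository.

```python
from statistics import fmean
from typing import Callable, Sequence

Fn = Callable[[float], float]


def critic_penalty(x: float, critic: Fn, reference: Sequence[float]) -> float:
    """Mean critic score on the offline reference data minus the critic score at x.

    With a 1-Lipschitz critic that scores the reference data highly, this gap
    grows as x moves away from the reference data.
    """
    return fmean(critic(r) for r in reference) - critic(x)


def regularized_objective(x: float, surrogate: Fn, critic: Fn,
                          reference: Sequence[float], alpha: float) -> float:
    """Surrogate prediction minus alpha times the critic gap (alpha is fixed here)."""
    return surrogate(x) - alpha * critic_penalty(x, critic, reference)


def select_candidate(candidates: Sequence[float], surrogate: Fn, critic: Fn,
                     reference: Sequence[float], alpha: float) -> float:
    """Pick the candidate with the highest regularized objective (exhaustive search)."""
    return max(candidates,
               key=lambda x: regularized_objective(x, surrogate, critic, reference, alpha))


# Toy 1-D setting: offline data lie in [0, 1]; the surrogate keeps increasing beyond the data.
reference = [0.0, 0.25, 0.5, 0.75, 1.0]
candidates = [0.0, 0.5, 1.0, 2.0, 5.0]


def surrogate(x: float) -> float:
    return x


def critic(x: float) -> float:
    # Negative distance to the interval [0, 1]; this function is 1-Lipschitz.
    return -max(0.0, x - 1.0, -x)


print(select_candidate(candidates, surrogate, critic, reference, alpha=0.0))  # 5.0
print(select_candidate(candidates, surrogate, critic, reference, alpha=2.0))  # 1.0
```

With `alpha = 0` the search trusts the surrogate and picks the extrapolated point 5.0; with `alpha = 2` the critic penalty outweighs the surrogate's linear growth outside [0, 1], so the search stays at the edge of the offline data.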

Installation

To install and run our code, first clone the gabo repository.

cd ~
git clone https://github.com/michael-s-yao/gabo
cd gabo

Next, create and activate a conda environment from the environment.yml file to set up the environment and install the relevant dependencies.

conda env create -f environment.yml
conda activate gabo

If you are running our codebase on a GPU, please also run the following command:

python -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
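To confirm that the CUDA-enabled wheels were picked up, a quick check such as the following sketch can be run inside the activated gabo environment. It relies only on standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`); the helper name `describe_torch_install` is hypothetical.

```python
import importlib.util


def describe_torch_install() -> dict:
    """Report the installed PyTorch build without failing if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "torch_version": torch.__version__,  # GPU wheels above carry a "+cu113" suffix
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),  # needs a visible GPU and compatible driver
    }


print(describe_torch_install())
```

For the cu113 wheels above, `torch_version` should end in `+cu113`; `cuda_available` is `True` only when a compatible driver and GPU are visible to the process.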

There is also a minor version conflict between the installed dependencies that requires one line in a package to be modified. More specifically, please navigate to the location of the installed transformers package (the exact path depends on where conda is installed on your system):

cd /home/usr/miniconda3/envs/gabo/lib/python3.8/site-packages/transformers
vim trainer_pt_utils.py

In the trainer_pt_utils.py source file, please change the line if version.parse(torch.__version__) <= version.parse("1.4.1"): to if version.parse(torch.__version__) <= version.parse("1.12.1"):. Next, please copy the smiles_vocab.txt file to the design_bench_data package directory:

cd ~/gabo
cp -p data/molecules/smiles_vocab.txt /home/usr/miniconda3/envs/gabo/lib/python3.8/site-packages/design_bench_data/
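As an alternative to editing the file in vim and hard-coding the /home/usr/miniconda3/... path, both manual steps above can be scripted. The sketch below locates installed packages through `importlib.util.find_spec` in whichever environment is active; the helper names are hypothetical, and it assumes that transformers and design_bench_data are importable from the activated gabo environment.

```python
import importlib.util
import shutil
from pathlib import Path

# The one-line change described above for transformers/trainer_pt_utils.py.
OLD_GUARD = 'if version.parse(torch.__version__) <= version.parse("1.4.1"):'
NEW_GUARD = 'if version.parse(torch.__version__) <= version.parse("1.12.1"):'


def package_dir(name: str) -> Path:
    """Directory of an installed package in the active environment, found without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{name} is not installed as a package in the active environment")
    return Path(list(spec.submodule_search_locations)[0])


def patch_version_guard(path: Path) -> bool:
    """Apply the one-line edit; return True if the file changed, False if it was already patched."""
    text = path.read_text()
    if NEW_GUARD in text:
        return False
    if OLD_GUARD not in text:
        raise ValueError(f"expected version check not found in {path}")
    # Replace only the first occurrence, matching the single line described above.
    path.write_text(text.replace(OLD_GUARD, NEW_GUARD, 1))
    return True


def copy_vocab(src: Path, dest_dir: Path) -> Path:
    """Copy src into dest_dir; shutil.copy2 keeps permissions and timestamps, similar to `cp -p`."""
    if not dest_dir.is_dir():
        raise NotADirectoryError(dest_dir)
    return Path(shutil.copy2(src, dest_dir))
```

From ~/gabo with the gabo environment activated, the two steps would be `patch_version_guard(package_dir("transformers") / "trainer_pt_utils.py")` and `copy_vocab(Path("data/molecules/smiles_vocab.txt"), package_dir("design_bench_data"))`. Note that `shutil.copy2`, unlike `cp -p`, does not preserve file ownership.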

If applicable, please also follow the directions here to download the datasets associated with design-bench. Finally, please initialize the submodules associated with the repository. After successful setup, you can run our code as

python mbo/run_gabo.py --help
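The submodule initialization step mentioned above can be performed with the standard git command `git submodule update --init --recursive`, run from the repository root. A minimal Python wrapper is sketched below; the helper name `init_submodules` is hypothetical, and the `run` parameter exists only so the subprocess call can be swapped out in tests.

```python
import subprocess
from pathlib import Path
from typing import Callable

# Standard git command that initializes and checks out all (including nested) submodules.
SUBMODULE_CMD = ["git", "submodule", "update", "--init", "--recursive"]


def init_submodules(repo_dir: Path, run: Callable = subprocess.run):
    """Run the submodule command in repo_dir; check=True raises on a non-zero git exit code."""
    return run(SUBMODULE_CMD, cwd=repo_dir, check=True)
```

For example, `init_submodules(Path.home() / "gabo")` is equivalent to running the git command by hand inside ~/gabo.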

To replicate our experiments, please refer to the scripts directory for relevant shell scripts.

Running BDI Baseline Experiments

Bidirectional learning for offline infinite-width model-based optimization (BDI) is a baseline method for offline model-based optimization (MBO) tasks that we compare against in our experiments. The codebase provided by Chen et al. (2022) depends on the jax and neural-tangents libraries, whose required versions are incompatible with our own dependencies specified in environment.yml. Therefore, if you are interested in replicating our BDI experimental results, please create a new conda environment separate from the one described above using

conda env create -f bdi_environment.yml
conda activate bdi

Of note, you may come across versioning issues with jaxlib, a required dependency for jax. The solution proposed in this GitHub issue worked in our hands.
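Before running BDI experiments, a short diagnostic can confirm which conda environment is active and which jax and jaxlib versions it provides. This is a sketch with hypothetical helper names; it assumes that `conda activate` sets the `CONDA_DEFAULT_ENV` variable (for named environments it normally holds the name, e.g. bdi). It does not encode which jax/jaxlib combination is compatible; see the GitHub issue referenced above for that.

```python
import os
from importlib import metadata
from typing import Mapping, Optional


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def bdi_env_report(environ: Mapping[str, str] = os.environ) -> dict:
    """Summarize the active conda environment and the jax/jaxlib versions it provides."""
    return {
        "conda_env": environ.get("CONDA_DEFAULT_ENV"),  # normally set by `conda activate`
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }


print(bdi_env_report())
```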

Contact

Questions and comments are welcome. Suggestions can be submitted through GitHub issues. Contact information is linked below.

Michael Yao

Osbert Bastani

Citation

If you found our work helpful for your research, please consider citing our paper:

@misc{yaom2024gabo,
  title={Generative Adversarial {Bayesian} Optimization for Surrogate Objectives},
  author={Yao, Michael S. and Zeng, Yimeng and Bastani, Hamsa and Gardner, Jacob and Gee, James C. and Bastani, Osbert},
  journal={arXiv Preprint},
  year={2024},
  eprint={2402.06532},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

License

This repository is MIT licensed (see LICENSE).
