Commit

Merge pull request #192 from yonatank93/update_uq_test
Refactor UQ tests
mjwen authored Oct 8, 2024
2 parents 8fcf4f2 + 21f632c commit f2a4c90
Showing 7 changed files with 549 additions and 198 deletions.
1 change: 1 addition & 0 deletions docs/source/tutorials.rst
@@ -15,5 +15,6 @@ Tutorials
tutorials/nn_SiC
tutorials/parameter_transform
tutorials/uq_mcmc
tutorials/uq_bootstrap
tutorials/lennard_jones
tutorials/linear_regression
269 changes: 269 additions & 0 deletions docs/source/tutorials/uq_bootstrap.ipynb
@@ -0,0 +1,269 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "656f7afc",
"metadata": {},
"source": [
"# Bootstrapping\n",
"\n",
"In this example, we demonstrate how to perform uncertainty quantification (UQ) using\n",
"bootstrap method. We use a Stillinger-Weber (SW) potential for silicon that is archived\n",
"in OpenKIM_.\n",
"\n",
"For simplicity, we only set the energy-scaling parameters, i.e., ``A`` and ``lambda`` as\n",
"the tunable parameters. These parameters will be calibrated to energies and forces of a\n",
"small dataset, consisting of 4 compressed and stretched configurations of diamond silicon\n",
"structure."
]
},
{
"cell_type": "markdown",
"id": "98b590d7",
"metadata": {},
"source": [
"To start, let's first install the SW model::\n",
"\n",
"$ kim-api-collections-management install user SW_StillingerWeber_1985_Si__MO_405512056662_006\n",
"\n",
".. seealso::\n",
" This installs the model and its driver into the ``User Collection``. See\n",
" :ref:`install_model` for more information about installing KIM models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e617de91",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:37:31.393276Z",
"start_time": "2024-10-07T12:37:29.472146Z"
}
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from kliff.calculators import Calculator\n",
"from kliff.dataset import Dataset\n",
"from kliff.loss import Loss\n",
"from kliff.models import KIMModel\n",
"from kliff.uq.bootstrap import BootstrapEmpiricalModel\n",
"from kliff.utils import download_dataset\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "57f71678",
"metadata": {},
"source": [
"Before running bootstrap, we need to define a loss function and train the model. More\n",
"detail information about this step can be found in :ref:`tut_kim_sw` and :ref:`tut_params_transform`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3aa1d13",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:37:59.004347Z",
"start_time": "2024-10-07T12:37:58.979490Z"
}
},
"outputs": [],
"source": [
"# Create the model\n",
"model = KIMModel(model_name=\"SW_StillingerWeber_1985_Si__MO_405512056662_006\")\n",
"\n",
"# Set the tunable parameters and the initial guess\n",
"opt_params = {\"A\": [[\"default\"]], \"lambda\": [[\"default\"]]}\n",
"\n",
"model.set_opt_params(**opt_params)\n",
"model.echo_opt_params()\n",
"\n",
"# Get the dataset\n",
"dataset_path = download_dataset(dataset_name=\"Si_training_set_4_configs\")\n",
"# Read the dataset\n",
"tset = Dataset(dataset_path)\n",
"configs = tset.get_configs()\n",
"\n",
"# Create calculator\n",
"calc = Calculator(model)\n",
"# Only use the forces data\n",
"ca = calc.create(configs, use_energy=False, use_forces=True)\n",
"\n",
"# Instantiate the loss function\n",
"residual_data = {\"normalize_by_natoms\": False}\n",
"loss = Loss(calc, residual_data=residual_data)"
]
},
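{
"cell_type": "markdown",
"id": "a0b1c2d3",
"metadata": {},
"source": [
"The text above mentions training the model. A minimal sketch of that step is shown\n",
"below, assuming the ``\"lm\"`` (least-squares) method is available through\n",
"``Loss.minimize``; the step is optional here, since ``BS.run`` below takes an explicit\n",
"initial guess."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4c5d6e7",
"metadata": {},
"outputs": [],
"source": [
"# Optional: train the model on the original (non-bootstrapped) dataset.\n",
"# Sketch only; \"lm\" maps to a least-squares minimizer, and any other supported\n",
"# method can be used instead.\n",
"loss.minimize(method=\"lm\")\n",
"model.echo_opt_params()"
]
},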
{
"cell_type": "markdown",
"id": "39a95904",
"metadata": {},
"source": [
"To perform UQ by bootstrapping, the general workflow starts by instantiating :class:`~kliff.uq.bootstrap.BootstrapEmpiricalModel`, or :class:`~kliff.uq.bootstrap.BootstrapNeuralNetworkModel` if using a neural network\n",
"potential."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "966614ab",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:38:38.479190Z",
"start_time": "2024-10-07T12:38:38.475357Z"
}
},
"outputs": [],
"source": [
"# Instantiate bootstrap class object\n",
"BS = BootstrapEmpiricalModel(loss, seed=1717)"
]
},
{
"cell_type": "markdown",
"id": "b8be6029",
"metadata": {},
"source": [
"Then, we generate some bootstrap compute arguments. This is equivalent to generating\n",
"bootstrap data. Typically, we just need to specify how many bootstrap data samples to\n",
"generate. Additionally, if we call ``generate_bootstrap_compute_arguments`` multiple\n",
"times, the new generated data samples will be appended to the previously generated data\n",
"samples. This is also the behavior if we read the data samples from the previously\n",
"exported file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e660eb87",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:39:14.455217Z",
"start_time": "2024-10-07T12:39:14.442511Z"
}
},
"outputs": [],
"source": [
"# Generate bootstrap compute arguments\n",
"BS.generate_bootstrap_compute_arguments(100)"
]
},
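{
"cell_type": "markdown",
"id": "c8d9e0f1",
"metadata": {},
"source": [
"The generated bootstrap compute arguments can also be exported to a file and read back\n",
"in a later session. A minimal sketch is given below; it assumes the class provides\n",
"``save_bootstrap_compute_arguments`` and ``load_bootstrap_compute_arguments`` methods,\n",
"and the file name is only illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2e3f4a5",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: export the bootstrap compute arguments to a file, then read them back.\n",
"# As noted above, loading appends to any samples already stored in the instance.\n",
"BS.save_bootstrap_compute_arguments(\"bootstrap_configs.json\")\n",
"# BS.load_bootstrap_compute_arguments(\"bootstrap_configs.json\")"
]
},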
{
"cell_type": "markdown",
"id": "898350eb",
"metadata": {},
"source": [
"Finally, we will iterate over these bootstrap data samples and train the potential\n",
"using each data sample. The resulting optimal parameters from each data sample give a\n",
"single sample of parameters. By iterating over all data samples, then we will get an\n",
"ensemble of parameters.\n",
"\n",
"Note that the mapping from the bootstrap dataset to the parameters involve optimization.\n",
"We suggest to use the same mapping, i.e., the same optimizer setting, in each iteration.\n",
"This includes using the same set of initial parameter guess. In the case when the loss\n",
"function has multiple local minima, we don't want the parameter ensemble to be biased\n",
"on the results of the other optimizations. For neural network model, we need to reset\n",
"the initial parameter value, which is done internally.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d347a576",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:39:53.510993Z",
"start_time": "2024-10-07T12:39:48.359289Z"
}
},
"outputs": [],
"source": [
"# Run bootstrap\n",
"min_kwargs = dict(method=\"lm\") # Optimizer setting\n",
"initial_guess = calc.get_opt_params() # Initial guess in the optimization\n",
"BS.run(min_kwargs=min_kwargs, initial_guess=initial_guess)"
]
},
{
"cell_type": "markdown",
"id": "e2526a32",
"metadata": {},
"source": [
"The resulting parameter ensemble can be accessed in `BS.samples` as a `np.ndarray`.\n",
"Then, we can plot the distribution of the parameters, as an example, or propagate the\n",
"error to the target quantities we want to study."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e33a7732",
"metadata": {
"ExecuteTime": {
"end_time": "2024-10-07T12:40:23.927758Z",
"start_time": "2024-10-07T12:40:23.759710Z"
}
},
"outputs": [],
"source": [
"# Plot the distribution of the parameters\n",
"plt.figure()\n",
"plt.plot(*(BS.samples.T), \".\", alpha=0.5)\n",
"param_names = list(opt_params.keys())\n",
"plt.xlabel(param_names[0])\n",
"plt.ylabel(param_names[1])\n",
"plt.show()"
]
},
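{
"cell_type": "markdown",
"id": "f6a7b8c9",
"metadata": {},
"source": [
"As a minimal sketch of summarizing the ensemble, we can compute the mean and standard\n",
"deviation of each parameter directly from ``BS.samples`` with NumPy; the variable names\n",
"below are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9b8c7d6",
"metadata": {},
"outputs": [],
"source": [
"# Summarize the parameter ensemble: mean, standard deviation, and covariance.\n",
"param_mean = np.mean(BS.samples, axis=0)\n",
"param_std = np.std(BS.samples, axis=0)\n",
"param_cov = np.cov(BS.samples.T)\n",
"for name, mu, sigma in zip(param_names, param_mean, param_std):\n",
"    print(f\"{name}: {mu:.4f} +/- {sigma:.4f}\")\n",
"print(\"Covariance matrix:\", param_cov)"
]
},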
{
"cell_type": "markdown",
"id": "fe68cf9b",
"metadata": {},
"source": [
".. _OpenKIM: https://openkim.org"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": false,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
40 changes: 40 additions & 0 deletions tests/uq/conftest.py
@@ -0,0 +1,40 @@
from pathlib import Path

import pytest

from kliff.dataset import Dataset
from kliff.models import KIMModel


@pytest.fixture(scope="session")
def uq_test_dir():
# Directory of uq test files
return Path(__file__).resolve().parent


@pytest.fixture(scope="session")
def uq_test_data_dir():
# Directory of uq test data
return Path(__file__).resolve().parents[1] / "test_data/configs/Si_4"


@pytest.fixture(scope="session")
def uq_test_configs(uq_test_data_dir):
# Load test configs
data = Dataset(uq_test_data_dir)
return data.get_configs()


@pytest.fixture(scope="session")
def uq_kim_model():
# Load a KIM model
modelname = "SW_StillingerWeber_1985_Si__MO_405512056662_006"
model = KIMModel(modelname)
model.set_opt_params(A=[["default"]])
return model


@pytest.fixture(scope="session")
def uq_nn_orig_state_filename(uq_test_dir):
"""Return the original state filename for the NN model."""
return uq_test_dir / "orig_model.pkl"