Skip to content

Commit

Permalink
chore(docs): document how to load trained model (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans authored Jan 7, 2025
1 parent 41a270f commit 604f014
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions docs/source/notebooks/sampling_paths.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"metadata": {},
"outputs": [],
"source": [
"import urllib.request\n",
"from collections.abc import Iterator\n",
"from functools import partial\n",
"from typing import Any\n",
Expand All @@ -86,7 +87,7 @@
"import optax\n",
"from beartype import beartype as typechecker\n",
"from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, PyTree, jaxtyped\n",
"from tqdm.notebook import trange\n",
"from tqdm.notebook import tqdm, trange\n",
"\n",
"from differt.geometry import TriangleMesh\n",
"from differt.plotting import reuse, set_defaults\n",
Expand Down Expand Up @@ -4581,7 +4582,45 @@
" scan_fun,\n",
" init=None,\n",
" xs=jax.random.split(key, num_path_candidates),\n",
" )[1]"
" )[1]\n",
"\n",
" @classmethod\n",
" def load(cls, url: str, **kwargs: Any) -> \"Self\":\n",
" \"\"\"Load a model from a remote URL.\n",
"\n",
" This is a convenient method to load a model from pre-trained\n",
" weights uploaded on the internet.\n",
"\n",
" Args:\n",
" url: The url to download the weights from.\n",
"\n",
" E.g., provide \"https://github.com/jeertmans/DiffeRT/releases/download/icmlcn2025/model_1.eqx\"\n",
" to download the model weights for first order reflection.\n",
" kwargs: Keyword arguments passed to initialize the model.\n",
"\n",
" Shape arguments must be the same as those used when training,\n",
" other arguments (such as the random key) are not important, and\n",
" dummy values can be provided.\n",
" \"\"\"\n",
" with tqdm(\n",
" unit=\"B\",\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" desc=f\"Retrieving model from: {url}\",\n",
" miniters=1,\n",
" ) as bar:\n",
"\n",
" def reporthook(\n",
" block_count: int, block_size: int, total_size: int\n",
" ) -> None:\n",
" bar.total = total_size\n",
" bar.update(block_count * block_size - bar.n)\n",
"\n",
" model = eqx.filter_eval_shape(cls, **kwargs)\n",
" filename, _ = urllib.request.urlretrieve(url, reporthook=reporthook)\n",
"\n",
" with open(filename, \"rb\") as f:\n",
" return eqx.tree_deserialise_leaves(f, model)"
]
},
{
Expand Down Expand Up @@ -4957,7 +4996,14 @@
" optim,\n",
" order=order,\n",
" key=key_train_model,\n",
")"
")\n",
"\n",
"# TIP: model can be loaded from pre-trained weights with\n",
"# trained_model = Model.load(\n",
"# \"https://github.com/jeertmans/DiffeRT/releases/download/icmlcn2025/model_1.eqx\",\n",
"# key=jax.random.key(0),\n",
"# sample_quads=sample_quads\n",
"# )"
]
},
{
Expand Down

0 comments on commit 604f014

Please sign in to comment.