Skip to content

Commit

Permalink
add back pzflow demo
Browse files Browse the repository at this point in the history
  • Loading branch information
ztq1996 committed Aug 1, 2024
1 parent e55c667 commit f42e72a
Showing 1 changed file with 191 additions and 0 deletions.
191 changes: 191 additions & 0 deletions examples/estimation_examples/pzflow_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4",
"metadata": {},
"source": [
"PZFlow Informer and Estimator Demo\n",
"\n",
"Author: Tianqing Zhang\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "916a05ad",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import rail\n",
"from rail.core.data import TableHandle\n",
"from rail.core.stage import RailStage\n",
"import qp\n",
"import tables_io\n",
"\n",
"from rail.estimation.algos.pzflow_nf import PZFlowInformer, PZFlowEstimator\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8ef87d3",
"metadata": {},
"outputs": [],
"source": [
"DS = RailStage.data_store\n",
"DS.__class__.allow_overwrite = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f79c3a7b",
"metadata": {},
"outputs": [],
"source": [
"from rail.utils.path_utils import find_rail_file\n",
"trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n",
"testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n",
"training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n",
"test_data = DS.read_file(\"test_data\", TableHandle, testFile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "756d78a3",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry',output_mode = 'not_fiducial' )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0857e6bb-18eb-4f89-bc4b-29bed1ffa122",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1042a9f3",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# epoch = 200 gives a reasonable converged loss\n",
"pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 30, **pzflow_dict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ae0180b-b67f-4ac1-bf90-545c7c2fbb62",
"metadata": {},
"outputs": [],
"source": [
"pzflow_train.finalize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c407f45b",
"metadata": {},
"outputs": [],
"source": [
"# training of the pzflow\n",
"pzflow_train.inform(training_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "156b6e3d",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry')\n",
"\n",
"pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)\n",
"\n",
"pzflow_estimator._aliases = dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00911d60",
"metadata": {},
"outputs": [],
"source": [
"# estimate using the test data\n",
"estimate_results = pzflow_estimator.estimate(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cbdece3",
"metadata": {},
"outputs": [],
"source": [
"mode = estimate_results.data.ancil['zmode']\n",
"truth = np.array(training_data.data['photometry']['redshift'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba076bab-c5ab-4292-8de9-415e7b30af5c",
"metadata": {},
"outputs": [],
"source": [
"# visualize the prediction. \n",
"plt.figure(figsize = (8,8))\n",
"plt.scatter(truth, mode, s = 0.5)\n",
"plt.xlabel('True Redshift')\n",
"plt.ylabel('Mode of Estimated Redshift')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "rail_env",
"language": "python",
"name": "rail_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f42e72a

Please sign in to comment.