Edits to opt savings and opt invest (#154)
* misc

* misc

* misc

* misc
jstac authored Mar 15, 2024
1 parent 91efc04 commit ce4ef38
Showing 7 changed files with 313 additions and 107 deletions.
13 changes: 7 additions & 6 deletions lectures/_static/lecture_specific/hpi.py
@@ -1,12 +1,13 @@
-# Implements HPI-Howard policy iteration routine
-
-def policy_iteration(model, maxiter=250):
-    constants, sizes, arrays = model
+def howard_policy_iteration(model, maxiter=250):
+    """
+    Implements Howard policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
     σ = jnp.zeros(sizes, dtype=int)
     i, error = 0, 1.0
     while error > 0 and i < maxiter:
-        v_σ = get_value(σ, constants, sizes, arrays)
-        σ_new = get_greedy(v_σ, constants, sizes, arrays)
+        v_σ = get_value(σ, params, sizes, arrays)
+        σ_new = get_greedy(v_σ, params, sizes, arrays)
         error = jnp.max(jnp.abs(σ_new - σ))
         σ = σ_new
         i = i + 1
13 changes: 7 additions & 6 deletions lectures/_static/lecture_specific/opi.py
@@ -1,13 +1,14 @@
-# Implements the OPI-Optimal policy Iteration routine
-
 def optimistic_policy_iteration(model, tol=1e-5, m=10):
-    constants, sizes, arrays = model
+    """
+    Implements optimistic policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
     v = jnp.zeros(sizes)
     error = tol + 1
     while error > tol:
         last_v = v
-        σ = get_greedy(v, constants, sizes, arrays)
+        σ = get_greedy(v, params, sizes, arrays)
         for _ in range(m):
-            v = T_σ(v, σ, constants, sizes, arrays)
+            v = T_σ(v, σ, params, sizes, arrays)
         error = jnp.max(jnp.abs(v - last_v))
-    return get_greedy(v, constants, sizes, arrays)
+    return get_greedy(v, params, sizes, arrays)
6 changes: 3 additions & 3 deletions lectures/_static/lecture_specific/successive_approx.py
@@ -1,6 +1,6 @@
 def successive_approx_jax(T,                     # Operator (callable)
                           x_0,                   # Initial condition
-                          tolerance=1e-6,        # Error tolerance
+                          tol=1e-6,              # Error tolerance
                           max_iter=10_000):      # Max iteration bound
     def body_fun(k_x_err):
         k, x, error = k_x_err
@@ -10,9 +10,9 @@ def body_fun(k_x_err):

     def cond_fun(k_x_err):
         k, x, error = k_x_err
-        return jnp.logical_and(error > tolerance, k < max_iter)
+        return jnp.logical_and(error > tol, k < max_iter)
 
-    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
     return x
 
 successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
14 changes: 8 additions & 6 deletions lectures/_static/lecture_specific/vfi.py
@@ -1,8 +1,10 @@
-# Implements VFI-Value Function iteration
-
-def value_iteration(model, tol=1e-5):
-    constants, sizes, arrays = model
+def value_function_iteration(model, tol=1e-5):
+    """
+    Implements value function iteration.
+    """
+    params, sizes, arrays = model
     vz = jnp.zeros(sizes)
-    _T = lambda v: T(v, constants, sizes, arrays)
+    _T = lambda v: T(v, params, sizes, arrays)
     v_star = successive_approx_jax(_T, vz, tolerance=tol)
-    return get_greedy(v_star, constants, sizes, arrays)
+    return get_greedy(v_star, params, sizes, arrays)

116 changes: 107 additions & 9 deletions lectures/opt_invest.md
@@ -294,30 +294,128 @@ get_value = jax.jit(get_value, static_argnums=(2,))
We use successive approximation for VFI.

```{code-cell} ipython3
:load: _static/lecture_specific/successive_approx.py
def successive_approx_jax(T,                     # Operator (callable)
                          x_0,                   # Initial condition
                          tol=1e-6,              # Error tolerance
                          max_iter=10_000):      # Max iteration bound
    def body_fun(k_x_err):
        k, x, error = k_x_err
        x_new = T(x)
        error = jnp.max(jnp.abs(x_new - x))
        return k + 1, x_new, error

    def cond_fun(k_x_err):
        k, x, error = k_x_err
        return jnp.logical_and(error > tol, k < max_iter)

    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
    return x

successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
```
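As a purely illustrative aside (not part of the lecture source, and assuming `jax` and `jax.numpy as jnp` are imported as in the lecture), the routine can be exercised on a toy contraction whose fixed point is known in closed form:

```python
# Illustrative only: iterate the contraction v ↦ 0.5 * v + 1, whose fixed point is 2.
f = lambda v: 0.5 * v + 1.0
v_fixed = successive_approx_jax(f, jnp.zeros(3))
print(v_fixed)  # ≈ [2. 2. 2.]
```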

For OPI we'll add a compiled routine that computes $T_σ^m v$.

```{code-cell} ipython3
def iterate_policy_operator(σ, v, m, params, sizes, arrays):

    def update(i, v):
        v = T_σ(v, σ, params, sizes, arrays)
        return v

    v = jax.lax.fori_loop(0, m, update, v)
    return v

iterate_policy_operator = jax.jit(iterate_policy_operator,
                                  static_argnums=(4,))
```
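To clarify what `jax.lax.fori_loop` does in the routine above, here is a minimal sketch (illustrative only, not part of the lecture source): it applies the body function a fixed number of times, threading the carried value through each step.

```python
# Illustrative only: fori_loop(lower, upper, body, init) runs body(i, val)
# for i = lower, ..., upper - 1, carrying val along; this computes 0+1+2+3+4.
total = jax.lax.fori_loop(0, 5, lambda i, acc: acc + i, 0)
print(total)  # 10
```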

Finally, we introduce the solvers that implement VFI, HPI and OPI.

```{code-cell} ipython3
:load: _static/lecture_specific/vfi.py
def value_function_iteration(model, tol=1e-5):
    """
    Implements value function iteration.
    """
    params, sizes, arrays = model
    vz = jnp.zeros(sizes)
    _T = lambda v: T(v, params, sizes, arrays)
    v_star = successive_approx_jax(_T, vz, tol=tol)
    return get_greedy(v_star, params, sizes, arrays)
```

For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution.
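As a reminder of the `lax.while_loop` interface (a minimal sketch for illustration, not part of the lecture source): it repeatedly applies a body function to a carried state while a condition on that state holds.

```python
# Illustrative only: while_loop(cond, body, init) repeats body(state)
# while cond(state) is True; here we halve x until it drops below 1.
print(jax.lax.while_loop(lambda x: x > 1.0, lambda x: x / 2.0, 10.0))  # 0.625
```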


```{code-cell} ipython3
:load: _static/lecture_specific/hpi.py
def opi_loop(params, sizes, arrays, m, tol, max_iter):
    """
    Implements optimistic policy iteration (see dp.quantecon.org) with
    step size m.
    """
    v_init = jnp.zeros(sizes)

    def condition_function(inputs):
        i, v, error = inputs
        return jnp.logical_and(error > tol, i < max_iter)

    def update(inputs):
        i, v, error = inputs
        last_v = v
        σ = get_greedy(v, params, sizes, arrays)
        v = iterate_policy_operator(σ, v, m, params, sizes, arrays)
        error = jnp.max(jnp.abs(v - last_v))
        i += 1
        return i, v, error

    num_iter, v, error = jax.lax.while_loop(condition_function,
                                            update,
                                            (0, v_init, tol + 1))
    return get_greedy(v, params, sizes, arrays)

opi_loop = jax.jit(opi_loop, static_argnums=(1,))
```

Here's a friendly interface to OPI

```{code-cell} ipython3
:load: _static/lecture_specific/opi.py
def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
    params, sizes, arrays = model
    σ_star = opi_loop(params, sizes, arrays, m, tol, max_iter)
    return σ_star
```
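For instance (an illustrative call only, assuming `create_investment_model` and the solvers above are defined earlier in the lecture; the variable names here are hypothetical), OPI could be invoked as follows:

```python
# Illustrative only: solve the model with OPI using m = 100 policy-evaluation steps.
model = create_investment_model()
σ_star_opi = optimistic_policy_iteration(model, m=100)
```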

Here's HPI


```{code-cell} ipython3
def howard_policy_iteration(model, maxiter=250):
    """
    Implements Howard policy iteration (see dp.quantecon.org)
    """
    params, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_σ = get_value(σ, params, sizes, arrays)
        σ_new = get_greedy(v_σ, params, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        print(f"Concluded loop {i} with error {error}.")
    return σ
```


```{code-cell} ipython3
:tags: [hide-output]
model = create_investment_model()
print("Starting HPI.")
qe.tic()
out = policy_iteration(model)
out = howard_policy_iteration(model)
elapsed = qe.toc()
print(out)
print(f"HPI completed in {elapsed} seconds.")
@@ -328,7 +426,7 @@ print(f"HPI completed in {elapsed} seconds.")
print("Starting VFI.")
qe.tic()
out = value_iteration(model)
out = value_function_iteration(model)
elapsed = qe.toc()
print(out)
print(f"VFI completed in {elapsed} seconds.")
@@ -356,7 +454,7 @@ y_grid, z_grid, Q = arrays
```

```{code-cell} ipython3
σ_star = policy_iteration(model)
σ_star = howard_policy_iteration(model)
fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(y_grid, y_grid, "k--", label="45")
@@ -376,15 +474,15 @@ m_vals = range(5, 600, 40)
model = create_investment_model()
print("Running Howard policy iteration.")
qe.tic()
σ_pi = policy_iteration(model)
σ_pi = howard_policy_iteration(model)
pi_time = qe.toc()
```

```{code-cell} ipython3
print(f"PI completed in {pi_time} seconds.")
print("Running value function iteration.")
qe.tic()
σ_vfi = value_iteration(model, tol=1e-5)
σ_vfi = value_function_iteration(model, tol=1e-5)
vfi_time = qe.toc()
print(f"VFI completed in {vfi_time} seconds.")
```