From 95b8dc246b7b3490ab19352870d39d32a5afc6ab Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Mon, 14 Nov 2022 22:47:00 +0100 Subject: [PATCH 01/10] Added Dockerfile of maurorigo --- docker/maurorigo/Dockerfile | 18 ++++++++++++++++++ docker/maurorigo/test.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 docker/maurorigo/Dockerfile create mode 100644 docker/maurorigo/test.py diff --git a/docker/maurorigo/Dockerfile b/docker/maurorigo/Dockerfile new file mode 100644 index 0000000..cbba484 --- /dev/null +++ b/docker/maurorigo/Dockerfile @@ -0,0 +1,18 @@ +FROM ubuntu:focal + +RUN apt-get update \ + && apt-get install -y vim \ + && apt-get install -y python3 \ + && apt-get install -y python3-pip \ + && pip install numpy \ + && pip install --upgrade "jax[cpu]" +RUN echo 'echo "This image runs Ubuntu and has vim, python, pip, numpy and jax installed."' > /usr/local/bin/start.sh +RUN echo 'echo "For a simple test, run \"python3 test.py\"."' >> /usr/local/bin/start.sh +RUN chmod +x /usr/local/bin/start.sh + +WORKDIR /home + +# No need to do this, could also do as the starting script, but just for testing this feature +COPY test.py test.py + +ENTRYPOINT /usr/local/bin/start.sh && /bin/bash \ No newline at end of file diff --git a/docker/maurorigo/test.py b/docker/maurorigo/test.py new file mode 100644 index 0000000..80aea38 --- /dev/null +++ b/docker/maurorigo/test.py @@ -0,0 +1,14 @@ +import numpy as np +import jax +jax.config.update('jax_platform_name', 'cpu') + +print("Test code using numpy and jax with cpu.") + +arr = np.arange(10) + +arrnp = np.random.permutation(arr) +print(f"Permuted with numpy: {arrnp}.") + +key = jax.random.PRNGKey(13) +arrjnp = jax.random.permutation(key, arr) +print(f"Permuted with jax: {arrjnp}.") From 4f545b4504310e01f9a368a258c57c6ad20266c5 Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Wed, 16 Nov 2022 14:01:15 +0100 Subject: [PATCH 02/10] Updated folder layout --- docker/maurorigo/Dockerfile | 5 +---- docker/maurorigo/test.py | 14 -------------- 2 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 docker/maurorigo/test.py diff --git a/docker/maurorigo/Dockerfile b/docker/maurorigo/Dockerfile index cbba484..5164078 100644 --- a/docker/maurorigo/Dockerfile +++ b/docker/maurorigo/Dockerfile @@ -12,7 +12,4 @@ RUN chmod +x /usr/local/bin/start.sh WORKDIR /home -# No need to do this, could also do as the starting script, but just for testing this feature -COPY test.py test.py - -ENTRYPOINT /usr/local/bin/start.sh && /bin/bash \ No newline at end of file +ENTRYPOINT /usr/local/bin/start.sh && /bin/bash diff --git a/docker/maurorigo/test.py b/docker/maurorigo/test.py deleted file mode 100644 index 80aea38..0000000 --- a/docker/maurorigo/test.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np -import jax -jax.config.update('jax_platform_name', 'cpu') - -print("Test code using numpy and jax with cpu.") - -arr = np.arange(10) - -arrnp = np.random.permutation(arr) -print(f"Permuted with numpy: {arrnp}.") - -key = jax.random.PRNGKey(13) -arrjnp = jax.random.permutation(key, arr) -print(f"Permuted with jax: {arrjnp}.") From 742e117ebe18d1afeb26ed3dcda0c6696dbae93b Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Wed, 16 Nov 2022 14:02:12 +0100 Subject: [PATCH 03/10] Updated folder layout 3 --- tests/rigo/test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/rigo/test.py diff --git a/tests/rigo/test.py b/tests/rigo/test.py new file mode 100644 index 0000000..80aea38 --- /dev/null +++ b/tests/rigo/test.py @@ -0,0 +1,14 @@ +import numpy as np +import jax +jax.config.update('jax_platform_name', 'cpu') + +print("Test code using numpy and jax with cpu.") + +arr = np.arange(10) + +arrnp = np.random.permutation(arr) +print(f"Permuted with numpy: {arrnp}.") + +key = jax.random.PRNGKey(13) +arrjnp = jax.random.permutation(key, arr) +print(f"Permuted with jax: {arrjnp}.") From e94ef167ea1429b66edac66be2b00ebb694bea5b Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Wed, 16 Nov 2022 14:08:14 +0100 Subject: [PATCH 04/10] Added actions for Rigo. --- .github/workflows/rigo.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 .github/workflows/rigo.yml diff --git a/.github/workflows/rigo.yml b/.github/workflows/rigo.yml new file mode 100644 index 0000000..d47c503 --- /dev/null +++ b/.github/workflows/rigo.yml @@ -0,0 +1,16 @@ +name: Actions on Rigo +on: [push] +jobs: + test-docker-image: + runs-on: ubuntu-latest + container: ghcr.io/maurorigo/jaxubuntu:latest + + steps: + - name: Checkout the repo + uses: actions/checkout@v3 + - name: Check that we have vim installed + run: test -f /bin/vim && echo "Vim is installed" + - name: Run tests in tests/heltai directory + run: | + cd tests/rigo + python3 test.py From 8aa2d4793b29c414faed3533d7b2a61f95d2d725 Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Wed, 16 Nov 2022 14:28:22 +0100 Subject: [PATCH 05/10] Test file updated. --- tests/rigo/ciao.m | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/rigo/ciao.m diff --git a/tests/rigo/ciao.m b/tests/rigo/ciao.m new file mode 100644 index 0000000..80493a5 --- /dev/null +++ b/tests/rigo/ciao.m @@ -0,0 +1 @@ +Ciao From f94c5c5eeffd5843e84f64fb6773717927880694 Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Wed, 16 Nov 2022 14:43:25 +0100 Subject: [PATCH 06/10] Test file 2 updated. --- ciao.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 ciao.md diff --git a/ciao.md b/ciao.md new file mode 100644 index 0000000..80493a5 --- /dev/null +++ b/ciao.md @@ -0,0 +1 @@ +Ciao From 35c50403e1350ed5a7f2806b1ca666f6a53cf2df Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Thu, 17 Nov 2022 10:39:31 +0100 Subject: [PATCH 07/10] Reverted to original branch and updated test file. --- .github/workflows/rigo.yml | 16 ---------------- ciao.md | 1 - docker/maurorigo/Dockerfile | 2 ++ {tests/rigo => docker/maurorigo}/test.py | 1 + tests/heltai/interpolation.py | 9 --------- tests/heltai/test_interpolation.py | 10 ---------- tests/rigo/ciao.m | 1 - 7 files changed, 3 insertions(+), 37 deletions(-) delete mode 100644 .github/workflows/rigo.yml delete mode 100644 ciao.md rename {tests/rigo => docker/maurorigo}/test.py (81%) delete mode 100644 tests/heltai/interpolation.py delete mode 100644 tests/heltai/test_interpolation.py delete mode 100644 tests/rigo/ciao.m diff --git a/.github/workflows/rigo.yml b/.github/workflows/rigo.yml deleted file mode 100644 index d47c503..0000000 --- a/.github/workflows/rigo.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Actions on Rigo -on: [push] -jobs: - test-docker-image: - runs-on: ubuntu-latest - container: ghcr.io/maurorigo/jaxubuntu:latest - - steps: - - name: Checkout the repo - uses: actions/checkout@v3 - - name: Check that we have vim installed - run: test -f /bin/vim && echo "Vim is installed" - - name: Run tests in tests/heltai directory - run: | - cd tests/rigo - python3 test.py diff --git a/ciao.md b/ciao.md deleted file mode 100644 index 80493a5..0000000 --- a/ciao.md +++ /dev/null @@ -1 +0,0 @@ -Ciao diff --git a/docker/maurorigo/Dockerfile b/docker/maurorigo/Dockerfile index 5164078..0e9e1e7 100644 --- a/docker/maurorigo/Dockerfile +++ b/docker/maurorigo/Dockerfile @@ -12,4 +12,6 @@ RUN chmod +x /usr/local/bin/start.sh WORKDIR /home +COPY test.py test.py + ENTRYPOINT /usr/local/bin/start.sh && /bin/bash diff --git a/tests/rigo/test.py b/docker/maurorigo/test.py similarity index 81% rename from tests/rigo/test.py rename to docker/maurorigo/test.py index 80aea38..88603d6 100644 --- a/tests/rigo/test.py +++ b/docker/maurorigo/test.py @@ -11,4 +11,5 @@ key = jax.random.PRNGKey(13) arrjnp = jax.random.permutation(key, arr) +assert len(arrjnp) == 10, "Something is wrong with the jax installation." print(f"Permuted with jax: {arrjnp}.") diff --git a/tests/heltai/interpolation.py b/tests/heltai/interpolation.py deleted file mode 100644 index 1ffe613..0000000 --- a/tests/heltai/interpolation.py +++ /dev/null @@ -1,9 +0,0 @@ - -from numpy import * - -def lagrange(x, i, X): - """ - Returns the ith Lagrange basis function, evaluated at x, - generated by the interpolation points X - """ - return prod([(x-X[j])/(X[i]-X[j]) for j in range(len(X)) if i != j], axis=0) \ No newline at end of file diff --git a/tests/heltai/test_interpolation.py b/tests/heltai/test_interpolation.py deleted file mode 100644 index ab74a8b..0000000 --- a/tests/heltai/test_interpolation.py +++ /dev/null @@ -1,10 +0,0 @@ -from interpolation import * - -def test_cronecker(): - X = linspace(0,1,5) - for i in range(len(X)): - assert lagrange(X[i], i, X) == 1 - for j in range(len(X)): - if i != j: - assert lagrange(X[j], i, X) == 0 - diff --git a/tests/rigo/ciao.m b/tests/rigo/ciao.m deleted file mode 100644 index 80493a5..0000000 --- a/tests/rigo/ciao.m +++ /dev/null @@ -1 +0,0 @@ -Ciao From 1362339a8a28681c3aee7338c8a81f981843657b Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Thu, 17 Nov 2022 12:52:54 +0100 Subject: [PATCH 08/10] Fixed deletion of tests folder. --- tests/heltai/interpolation.py | 9 +++++++++ tests/heltai/test_interpolation.py | 10 ++++++++++ tests/rigo/ciao.m | 1 + tests/rigo/test.py | 14 ++++++++++++++ 4 files changed, 34 insertions(+) create mode 100644 tests/heltai/interpolation.py create mode 100644 tests/heltai/test_interpolation.py create mode 100644 tests/rigo/ciao.m create mode 100644 tests/rigo/test.py diff --git a/tests/heltai/interpolation.py b/tests/heltai/interpolation.py new file mode 100644 index 0000000..1ffe613 --- /dev/null +++ b/tests/heltai/interpolation.py @@ -0,0 +1,9 @@ + +from numpy import * + +def lagrange(x, i, X): + """ + Returns the ith Lagrange basis function, evaluated at x, + generated by the interpolation points X + """ + return prod([(x-X[j])/(X[i]-X[j]) for j in range(len(X)) if i != j], axis=0) \ No newline at end of file diff --git a/tests/heltai/test_interpolation.py b/tests/heltai/test_interpolation.py new file mode 100644 index 0000000..ab74a8b --- /dev/null +++ b/tests/heltai/test_interpolation.py @@ -0,0 +1,10 @@ +from interpolation import * + +def test_cronecker(): + X = linspace(0,1,5) + for i in range(len(X)): + assert lagrange(X[i], i, X) == 1 + for j in range(len(X)): + if i != j: + assert lagrange(X[j], i, X) == 0 + diff --git a/tests/rigo/ciao.m b/tests/rigo/ciao.m new file mode 100644 index 0000000..80493a5 --- /dev/null +++ b/tests/rigo/ciao.m @@ -0,0 +1 @@ +Ciao diff --git a/tests/rigo/test.py b/tests/rigo/test.py new file mode 100644 index 0000000..80aea38 --- /dev/null +++ b/tests/rigo/test.py @@ -0,0 +1,14 @@ +import numpy as np +import jax +jax.config.update('jax_platform_name', 'cpu') + +print("Test code using numpy and jax with cpu.") + +arr = np.arange(10) + +arrnp = np.random.permutation(arr) +print(f"Permuted with numpy: {arrnp}.") + +key = jax.random.PRNGKey(13) +arrjnp = jax.random.permutation(key, arr) +print(f"Permuted with jax: {arrjnp}.") From 30c866da1029a25a1f6086c16ea6bfc59ad846d2 Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Thu, 17 Nov 2022 12:56:15 +0100 Subject: [PATCH 09/10] Removed unnecessary files. --- tests/rigo/ciao.m | 1 - tests/rigo/test.py | 14 -------------- 2 files changed, 15 deletions(-) delete mode 100644 tests/rigo/ciao.m delete mode 100644 tests/rigo/test.py diff --git a/tests/rigo/ciao.m b/tests/rigo/ciao.m deleted file mode 100644 index 80493a5..0000000 --- a/tests/rigo/ciao.m +++ /dev/null @@ -1 +0,0 @@ -Ciao diff --git a/tests/rigo/test.py b/tests/rigo/test.py deleted file mode 100644 index 80aea38..0000000 --- a/tests/rigo/test.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np -import jax -jax.config.update('jax_platform_name', 'cpu') - -print("Test code using numpy and jax with cpu.") - -arr = np.arange(10) - -arrnp = np.random.permutation(arr) -print(f"Permuted with numpy: {arrnp}.") - -key = jax.random.PRNGKey(13) -arrjnp = jax.random.permutation(key, arr) -print(f"Permuted with jax: {arrjnp}.") From 8fea2819ab3ec906b877576dbc4a1c47b840fc14 Mon Sep 17 00:00:00 2001 From: "WhO.opie" Date: Tue, 29 Nov 2022 17:25:12 +0100 Subject: [PATCH 10/10] Added actions and test files. --- .github/workflows/rigo.yml | 26 +++ tests/rigo/MCtest.py | 362 +++++++++++++++++++++++++++++++++++++ tests/rigo/test.py | 16 ++ 3 files changed, 404 insertions(+) create mode 100644 .github/workflows/rigo.yml create mode 100644 tests/rigo/MCtest.py create mode 100644 tests/rigo/test.py diff --git a/.github/workflows/rigo.yml b/.github/workflows/rigo.yml new file mode 100644 index 0000000..e73087f --- /dev/null +++ b/.github/workflows/rigo.yml @@ -0,0 +1,26 @@ +name: Actions on Rigo +on: [push] +jobs: + test-docker-image: + runs-on: ubuntu-latest + container: ghcr.io/maurorigo/jaxubuntu:latest + + steps: + - name: Checkout the repo + uses: actions/checkout@v3 + - name: Test for vim + run: test -f /bin/vim && echo "Vim is installed." + - name: Test for python + run: python3 --version + - name: Test for numpy + run: python3 -c "import numpy" && echo "numpy is installed." + - name: Test for jax + run: python3 -c "import jax" && echo "jax is installed." + - name: Run short test in tests/rigo directory + run: | + cd tests/rigo + python3 test.py + - name: Run long test in tests/rigo directory + run: | + cd tests/rigo + python3 MCtest.py \ No newline at end of file diff --git a/tests/rigo/MCtest.py b/tests/rigo/MCtest.py new file mode 100644 index 0000000..cb08519 --- /dev/null +++ b/tests/rigo/MCtest.py @@ -0,0 +1,362 @@ +import numpy as np +import jax +import jax.numpy as jnp +from jax import random, grad, jit, vmap +from functools import partial +from jax import lax + +from jax.tree_util import tree_flatten, tree_map +from jax.flatten_util import ravel_pytree +from jax.example_libraries.stax import (serial, Dense, Tanh) + +import time + +jax.config.update('jax_platform_name', 'cpu') + +# DEFINE THE NEURAL NET CLASS +class Wavefunction(object): + def __init__(self, key, nstates, ndense): + self.key = key + self.nstates = nstates + self.activation = Tanh + self.ndense = ndense + self.alpha = 8 + + def build(self): + self.psi_a_init, self.psi_a_apply = serial( + Dense(self.ndense), self.activation, + Dense(self.ndense), self.activation, + Dense(self.ndense), self.activation, + Dense(1), + ) + self.key, key_input = jax.random.split(self.key) + in_shape = (-1, self.nstates) + psi_a_shape, psi_a_params = self.psi_a_init(key_input, in_shape) + self.num_psi_a_params = len(psi_a_params) + + self.psi_p_init, self.psi_p_apply = serial( + Dense(self.ndense), self.activation, + Dense(self.ndense), self.activation, + Dense(self.ndense), self.activation, + Dense(1), + ) + self.key, key_input = jax.random.split(self.key) + psi_p_shape, psi_p_params = self.psi_p_init(key_input, in_shape) + self.num_psi_p_params = len(psi_p_params) + + net_params = psi_a_params + psi_p_params + + net_params = tree_map(self.update_cast, net_params) + flat_net_params = self.flatten_params(net_params) + num_flat_params = flat_net_params.shape[0] + + return net_params, num_flat_params + + # Calculates wf + @partial(jit, static_argnums=(0,)) + def psi(self, params, ni): + num_offset_params = 0 + psi_a_params = params[num_offset_params : num_offset_params + self.num_psi_a_params] + num_offset_params = num_offset_params + self.num_psi_a_params + psi_p_params = params[num_offset_params : num_offset_params + self.num_psi_p_params] + + psiout = jnp.exp(self.alpha * jnp.tanh(self.psi_a_apply(psi_a_params, ni)/self.alpha)) * jnp.tanh(self.psi_p_apply(psi_p_params, ni)) + #psiout = jnp.exp(self.psi_a_apply(psi_a_params, ni)) * jnp.tanh(self.psi_p_apply(psi_p_params, ni)) + + return jnp.reshape(psiout, ()) + + # Batched version + @partial(jit, static_argnums=(0,)) + def vmap_psi(self, params, ni_batched): + return vmap(self.psi, in_axes=(None, 0))(params, ni_batched) + + @partial(jit, static_argnums=(0,)) + def flatten_params(self, parameters): + flatten_parameters, self.unravel = ravel_pytree(parameters) + return flatten_parameters + + @partial(jit, static_argnums=(0,)) + def unflatten_params(self, flatten_parameters): + unflatten_parameters = self.unravel(flatten_parameters) + return unflatten_parameters + + @partial(jit, static_argnums=(0,)) + def update_cast(self, params): + return params.astype(jnp.float64) + + +# DEFINE THE HAMILTONIAN CLASS +class Hamiltonian(object): + def __init__(self, npart, nstates, dvec, gmat, wavefunction): + self.npart = npart + self.nstates = nstates + self.wavefunction = wavefunction + + # Initialize 1-body and 2-body potentials + self.dvec = dvec + self.gmat = gmat + + @partial(jit, static_argnums=(0,)) + def pot1body(self, ni): + return 2*jnp.dot(ni, self.dvec) + + @partial(jit, static_argnums=(0,)) + def vmap_1body(self, ni_batched): + return vmap(self.pot1body, in_axes=0)(ni_batched) + + # 2 body potential + @partial(jit, static_argnums=(0,)) + def pot2body(self, params, ni): + + cN = self.wavefunction.psi(params, ni) + + def qbody(j, carry): + aivec, p, ni1, cumul = carry + q = aivec[j] + ni2 = ni1.at[q].add(1) + + cNp = self.wavefunction.psi(params, ni2) + cumul += cNp*self.gmat[p, q] + return aivec, p, ni1, cumul + + def pbody(i, carry): + v, ivec, aivec, ni = carry + p = ivec[i] + ni1 = ni.at[p].add(-1) + aivec, p, _, pterm2 = lax.fori_loop(0, self.nstates-self.npart, qbody, (aivec, p, ni1, 0)) # Sum over all unoccupied spins + v += pterm2 + return v, ivec, aivec, ni + + ivec = jnp.nonzero(ni, size=self.npart)[0] # Array of occupied sites + nip = 1 - ni + aivec = jnp.nonzero(nip, size=self.nstates-self.npart)[0] # Array of unoccupied sites + + v, ivec, aivec, ni = lax.fori_loop(0, self.npart, pbody, (0, ivec, aivec, ni)) + + return v/cN + jnp.dot(ni, jnp.diag(self.gmat)) + + @partial(jit, static_argnums=(0,)) + def vmap_2body(self, params, ni_batched): + return vmap(self.pot2body, in_axes=(None, 0))(params, ni_batched) + + @partial(jit, static_argnums=(0,)) + def energy(self, params, ni_batched): + e1 = self.vmap_1body(ni_batched) + e2 = self.vmap_2body(params, ni_batched) + en = e1 + e2 + #return ke, pe, en + return en + +# DEFINE THE MONTE CARLO CLASS +class Metropolis(object): + def __init__(self, npart, nstates, nwalk, neq, nav, nac, nvoid, wavefunction): + + self.npart = npart + self.nstates = nstates + self.nwalk = nwalk + self.neq = neq + self.nav = nav + self.nac = nac + self.nvoid = nvoid + self.wavefunction = wavefunction + + # Number of ordered pairs that can be formed starting from ntot elements + self.npair = int((self.nstates) * ((self.nstates) - 1) / 2) + k = 0 + ipnp = np.empty(self.npair, dtype=int) + jpnp = np.empty(self.npair, dtype=int) + for i in range(0, self.nstates-1): + for j in range(i+1, self.nstates): + ipnp[k] = i + jpnp[k] = j + k+=1 + + self.ip = jnp.array(ipnp) + self.jp = jnp.array(jpnp) + + # Initializes Fock state (two spins) + @partial(jit, static_argnums=(0,)) + def nocc_init(self, key): + key, key_input = random.split(key) + ni = jnp.zeros((self.nwalk, self.nstates)) + ni = ni.at[:, 0:self.npart].set(1) + ni = random.permutation(key_input, ni, 1, independent=True) + return key, ni + + # Exchanges occupation number of two states + @partial(jit, static_argnums=(0,)) + def nocc_exch(self, ni_o, k): + # Spin up exchange + ip = self.ip[k] + jp = self.jp[k] + ni_n = ni_o.at[ip].set(ni_o[jp]) + ni_n = ni_n.at[jp].set(ni_o[ip]) + return ni_n + + @partial(jit, static_argnums=(0,)) + def nocc_prop(self, ni_o_batched, k_batched): + return vmap(self.nocc_exch, in_axes=(0, 0))(ni_o_batched, k_batched) + + # Void steps + @partial(jit, static_argnums=(0,)) + def step_void(self, key, ni_o, acc, params): + # Generate all random numbers + key, key_input = random.split(key) + k = jax.random.randint(key_input, shape = [self.nvoid, self.nwalk], minval = 0, maxval = self.npair) + key, key_input = random.split(key) + unifp = jax.random.uniform(key_input, shape = [self.nvoid, self.nwalk]) + def step(i, carry): + ni_o, wf_o, acc = carry + ni_n = self.nocc_prop(ni_o, k[i, :]) + wf_n = self.wavefunction.vmap_psi(params, ni_n) + prob = ( wf_n / wf_o )**2 + accept = jnp.greater_equal(prob, unifp[i, :]) + ni_o = jnp.where(accept.reshape([self.nwalk,1]), ni_n, ni_o) + wf_o = jnp.where(accept, wf_n, wf_o) + acc += jnp.mean(accept.astype('float32')) + return ni_o, wf_o, acc + + # Initialize log of wavefunction in order to calculate it only once in each loop + wf_o = self.wavefunction.vmap_psi(params, ni_o) + ni_o, foo, acc = lax.fori_loop(0, self.nvoid, step, (ni_o, wf_o, acc)) + return key, ni_o, acc + + @partial(jit, static_argnums=(0,)) + def initialize(self, key, nin, params): + key, ni_o = self.nocc_init(key) + # Initialization steps + def initialization(i, carry): + key, ni_o, acc = carry + key, ni_o, acc = self.step_void(key, ni_o, acc, params) + return key, ni_o, acc + acc = 0 + key, ni_o, acc = lax.fori_loop(0, nin, initialization, (key, ni_o, acc)) + return key, ni_o + + @partial(jit, static_argnums=(0,)) + def walk(self, key, params, ni_o): + # Equilibrium steps + def equilibration(i, carry): + key, ni_o, acc = carry + key, ni_o, acc = self.step_void(key, ni_o, acc, params) + return key, ni_o, acc + acc = 0 + key, ni_o, acc = lax.fori_loop(0, self.neq, equilibration, (key, ni_o, acc)) + + ni_stored = jnp.empty((self.nav + self.nac, self.nwalk, self.nstates)) + # Average steps + def average(i, carry): + key, ni_o, acc, ni_stored = carry + key, ni_o, acc = self.step_void(key, ni_o, acc, params) + ni_stored = ni_stored.at[i, :, :].set(ni_o) + return key, ni_o, acc, ni_stored + acc = 0 + key, ni_o, acc, ni_stored = lax.fori_loop(0, self.nav+self.nac, average, (key, ni_o, acc, ni_stored)) + acc /= (self.nav+self.nac) * self.nvoid + + return key, acc, ni_stored + + +# DEFINE THE ESTIMATOR CLASS +class Estimator(object): + def reset(self): + self.energy_blk = 0 + self.energy2_blk = 0 + self.weight_blk = 0 + + def addval(self, energy): + self.energy_blk += energy + self.energy2_blk += energy**2 + self.weight_blk += 1 + + def average(self): + self.energy = self.energy_blk / self.weight_blk + self.energy2 = self.energy2_blk / self.weight_blk + + error = jnp.sqrt((self.energy2 - self.energy**2) / (self.weight_blk-1)) + return error + + +# RUN CODE +print("Monte Carlo for the nuclear pairing model.") +from jax.config import config +config.update("jax_enable_x64", True) + +npart = 5 # Number of pairs +nstates = 10 # Number of energy levels +nin = 30 # Equilibration steps for first initialization of ni +neq = 10 # Equilibration steps +nav = 40 # Averaging steps +nac = 4 # Check steps +nvoid = 200 # Void steps between energy calculations +nwalk = 800 # Quantum Monte Carlo configurations + +ndense = 10 + +seed_net = 73 +seed_walk = 103 + +# Initialize the network with one batch dimension, ndim, and npart +key = random.PRNGKey(seed_net) +wavefunction = Wavefunction(key, nstates, ndense) +params, nparams = wavefunction.build() +print("Number of parameters of the neural net = ", nparams) + +# Initialize Metropolis sampler +metropolis = Metropolis(npart, nstates, nwalk, neq, nav, nac, nvoid, wavefunction) + +# Initialize Hamiltonian +gconst = -.6 +pvec = jnp.arange(nstates) +gmat = jnp.zeros((nstates, nstates)) +for i in range(nstates): + for j in range(nstates): + gmat = gmat.at[i, j].set(gconst) +hamiltonian = Hamiltonian(npart, nstates, pvec, gmat, wavefunction) + +# Initialize Estimator +estimator = Estimator() + +print("Classes initialized. Initializing MC states...") + +# Store the last walker +ni_o = jnp.zeros(shape=[nwalk, nstates]) +# Initialize ni for the first time +key = random.PRNGKey(seed_walk) +key, ni_o = metropolis.initialize(key, nin, params) + +print("MC states initialized. Performing MC walk...") + +# Metropolis energy calculation +twlk_i = time.time() + +key, acc, ni_stored = metropolis.walk(key, params, ni_o) +ni_stored.block_until_ready() +if ni_stored.shape[0] == nav+nac and ni_stored.shape[1] == nwalk and ni_stored.shape[2] == nstates: + print("Stored states matrix has the right dimension.") + print(f"{ni_stored.shape[0]}, {ni_stored.shape[1]}, {ni_stored.shape[2]} vs {nav+nac}, {nwalk}, {nstates}.") +else: + print(f"{ni_stored.shape[0]}, {ni_stored.shape[1]}, {ni_stored.shape[2]} vs {nav+nac}, {nwalk}, {nstates}.") + raise Exception("Stored states matrix doesn't have the right dimension.") + + +twlk_f = time.time() +print(f"Walk stored, elapsed time: {(twlk_f - twlk_i):.2f}s. Computing MC energy...") +estimator.reset() +energy_stored = jnp.zeros(shape=[nav, nwalk]) + +twlk_i = time.time() +for i in range(nav): + energy = hamiltonian.energy(params, ni_stored[i, :, :]) + energy_stored = energy_stored.at[i,:].set(energy) + estimator.addval(jnp.mean(energy)) +energy.block_until_ready() +twlk_f = time.time() +print(f"Energy computed, elapsed time: {(twlk_f - twlk_i):.2f}s.") + +error = estimator.average() +energy = estimator.energy + +estimator.reset() + +print(f"Energy = {energy:.6f}, err = {error:.6f}.") diff --git a/tests/rigo/test.py b/tests/rigo/test.py new file mode 100644 index 0000000..a10bf3c --- /dev/null +++ b/tests/rigo/test.py @@ -0,0 +1,16 @@ +import numpy as np +import jax +jax.config.update('jax_platform_name', 'cpu') + +print("Test code using numpy and jax with cpu.") + +arr = np.arange(10) + +arrnp = np.random.permutation(arr) +print(f"Permuted with numpy: {arrnp}.") + +key = jax.random.PRNGKey(13) +arrjnp = jax.random.permutation(key, arr) +assert len(arrjnp) == 10, "Something is wrong with the jax installation." +print("Jax installation does not have issues.") +print(f"Permuted with jax: {arrjnp}.")