From 5c6103fb496cd2991243a492e6a41c3c524f1834 Mon Sep 17 00:00:00 2001 From: zdcao121 Date: Tue, 28 May 2024 15:02:26 +0800 Subject: [PATCH] add w_mask to control the sampling --- README.md | 4 ++-- scripts/awl2struct.py | 2 +- src/elements.py | 17 +++++++++++++++++ src/main.py | 39 ++++++++++++++++++++++++++++++++++++--- src/sample.py | 22 +++++++++++++++++----- src/wyckoff.py | 2 +- 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 904dbad..0a4ad75 100644 --- a/README.md +++ b/README.md @@ -48,8 +48,8 @@ We only consider symmetry inequivalent atoms. The remaining atoms are restored b **Notebooks**: The quickest way to get started with _CrystalFormer_ is our notebooks in the Google Colab and Bohrium (Chinese version) platforms: -- CrystalFormer Quickstart [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IMQV6OQgIGORE8FmSTmZuC5KgQwGCnDx?usp=sharing) [![Open In Bohrium](https://cdn.dp.tech/bohrium/web/static/images/open-in-bohrium.svg)](https://nb.bohrium.dp.tech/detail/68177247598): GUI notebook demonstrating the conditional generation of crystalline materials with _CrystalFormer_. - +- CrystalFormer Quickstart [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IMQV6OQgIGORE8FmSTmZuC5KgQwGCnDx?usp=sharing) [![Open In Bohrium](https://cdn.dp.tech/bohrium/web/static/images/open-in-bohrium.svg)](https://nb.bohrium.dp.tech/detail/68177247598): GUI notebook demonstrating the conditional generation of crystalline materials with _CrystalFormer_; +- CrystalFormer Application [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QdkELaQXAHR1zEu2fcdfgabuoP61_wbU?usp=sharing): Generating stable crystals with a given structure prototype. This workflow can be applied to tasks that are dominated by element substitution. ## Installation diff --git a/scripts/awl2struct.py b/scripts/awl2struct.py index 4e382d1..ead153a 100644 --- a/scripts/awl2struct.py +++ b/scripts/awl2struct.py @@ -41,7 +41,7 @@ def symmetrize_atoms(g, w, x): #https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115 def dist_to_op0x(coord): diff = np.dot(symops[g-1, w, 0], np.array([*coord, 1])) - coord - diff -= np.floor(diff) + diff -= np.rint(diff) return np.sum(diff**2) # loc = np.argmin(jax.vmap(dist_to_op0x)(coords)) loc = np.argmin([dist_to_op0x(coord) for coord in coords]) diff --git a/src/elements.py b/src/elements.py index 81ef020..8c99a44 100644 --- a/src/elements.py +++ b/src/elements.py @@ -26,6 +26,17 @@ element_dict = {value: index for index, value in enumerate(element_list)} +# radioactive elements +radioactive_elements = [ 'Tc', 'Pm', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', + 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', + 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'] +radioactive_elements_dict = {e: element_dict[e] for e in radioactive_elements} + +# noble gas elements +noble_gas = ['He', 'Ne', 'Ar', 'Kr', 'Xe', 'Rn', 'Og'] +noble_gas_dict = {e: element_dict[e] for e in noble_gas} + + if __name__=="__main__": print (len(element_list)) print (element_dict["H"]) @@ -38,5 +49,11 @@ aw_mask = [1] + [1 if ((i-1)%(atom_types-1)+1 in idx) else 0 for i in range(1, aw_types)] # 1 for possible elements print (idx ) print (aw_mask) + print(radioactive_elements_dict) + print(noble_gas_dict) + atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, atom_types)] + print('sampling structure formed by non-radioactive elements and non-noble gas') + print(atom_mask) + diff --git a/src/main.py b/src/main.py index c423327..a000c1a 100644 --- a/src/main.py +++ b/src/main.py @@ -7,7 +7,7 @@ import multiprocessing import math -from utils import GLXYZAW_from_file, GLXA_to_csv +from utils import GLXYZAW_from_file, GLXA_to_csv, letter_to_number from elements import element_dict, element_list from transformer import make_transformer from train import train @@ -59,10 +59,14 @@ group.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicites including 0') group = parser.add_argument_group('sampling parameters') +group.add_argument('--seed', type=int, default=None, help='random seed to sample') group.add_argument('--spacegroup', type=int, help='The space group id to be sampled (1-230)') +group.add_argument('--wyckoff', type=str, default=None, nargs='+', help='The Wyckoff positions to be sampled, e.g. a, b') group.add_argument('--elements', type=str, default=None, nargs='+', help='name of the chemical elemenets, e.g. Bi, Ti, O') +group.add_argument('--remove_radioactive', action='store_true', help='remove radioactive elements and noble gas') group.add_argument('--top_p', type=float, default=1.0, help='1.0 means un-modified logits, smaller value of p give give less diverse samples') group.add_argument('--temperature', type=float, default=1.0, help='temperature used for sampling') +group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type') group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io') group.add_argument('--num_samples', type=int, default=1000, help='number of test samples') group.add_argument('--use_foriloop', action='store_true', help='use lax.fori_loop in sampling') @@ -94,7 +98,28 @@ print ('sampling structure formed by these elements:', args.elements) print (atom_mask) else: - atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling + if args.remove_radioactive: + from elements import radioactive_elements_dict, noble_gas_dict + # remove radioactive elements and noble gas + atom_mask = [1] + [1 if i not in radioactive_elements_dict.values() and i not in noble_gas_dict.values() else 0 for i in range(1, args.atom_types)] + atom_mask = jnp.array(atom_mask) + print('sampling structure formed by non-radioactive elements and non-noble gas') + print(atom_mask) + + else: + atom_mask = jnp.zeros((args.atom_types), dtype=int) # we will do nothing to a_logit in sampling + print(f'there is total {jnp.sum(atom_mask)-1} elements') + + if args.wyckoff is not None: + idx = [letter_to_number(w) for w in args.wyckoff] + # padding 0 until the length is args.n_max + w_mask = idx + [0]*(args.n_max -len(idx)) + # w_mask = [1 if w in idx else 0 for w in range(1, args.wyck_types+1)] + w_mask = jnp.array(w_mask, dtype=int) + print ('sampling structure formed by these Wyckoff positions:', args.wyckoff) + print (w_mask) + else: + w_mask = None ################### Model ############################# params, transformer = make_transformer(key, args.Nf, args.Kx, args.Kl, args.n_max, @@ -188,6 +213,14 @@ jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice #FYI, the error was [Compiling module extracted] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. + if args.seed is not None: + key = jax.random.PRNGKey(args.seed) # reset key for sampling if seed is provided + + if args.T1 is not None: + T1 = args.T1 + else: + T1 = args.temperature + num_batches = math.ceil(args.num_samples / args.batchsize) name, extension = args.output_filename.rsplit('.', 1) filename = os.path.join(output_path, @@ -197,7 +230,7 @@ end_idx = min(start_idx + args.batchsize, args.num_samples) n_sample = end_idx - start_idx key, subkey = jax.random.split(key) - XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, atom_mask, args.top_p, args.temperature, args.use_foriloop) + XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, args.spacegroup, w_mask, atom_mask, args.top_p, args.temperature, T1, args.use_foriloop) print ("XYZ:\n", XYZ) # fractional coordinate print ("A:\n", A) # element type print ("W:\n", W) # Wyckoff positions diff --git a/src/sample.py b/src/sample.py index e9b3c90..44f4b8a 100644 --- a/src/sample.py +++ b/src/sample.py @@ -59,8 +59,8 @@ def sample_x(key, h_x, Kx, top_p, temperature, batchsize): x = (x+ jnp.pi)/(2.0*jnp.pi) # wrap into [0, 1] return key, x -@partial(jax.jit, static_argnums=(1, 3, 4, 5, 6, 7, 8, 9, 11, 13)) -def sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, atom_mask, top_p, temperature, use_foriloop): +@partial(jax.jit, static_argnums=(1, 3, 4, 5, 6, 7, 8, 9, 12, 14, 15)) +def sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, w_mask, atom_mask, top_p, temperature, T1, use_foriloop): if use_foriloop: @@ -72,16 +72,22 @@ def body_fn(i, state): w_logit = w_logit[:, :wyck_types] key, subkey = jax.random.split(key) + if w_mask is not None: + w_logit = w_logit.at[:, w_mask[i]].set(w_logit[:, w_mask[i]] + 1e10) w = sample_top_p(subkey, w_logit, top_p, temperature) W = W.at[:, i].set(w) - + # (2) A h_al = inference(transformer, params, g, W, A, X, Y, Z)[:, 5*i+1] # (batchsize, output_size) a_logit = h_al[:, :atom_types] key, subkey = jax.random.split(key) a_logit = a_logit + jnp.where(atom_mask, 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp) - a = sample_top_p(subkey, a_logit, top_p, temperature) + _temp = jax.lax.cond(i==0, + true_fun=lambda x: jnp.array(T1, dtype=float), + false_fun=lambda x: temperature, + operand=None) + a = sample_top_p(subkey, a_logit, top_p, _temp) # use T1 for the first atom type A = A.at[:, i].set(a) lattice_params = h_al[:, atom_types:atom_types+Kl+2*6*Kl] @@ -154,6 +160,8 @@ def body_fn(i, state): w_logit = w_logit[:, :wyck_types] key, subkey = jax.random.split(key) + if w_mask is not None: + w_logit = w_logit.at[:, w_mask[i]].set(w_logit[:, w_mask[i]] + 1e10) w = sample_top_p(subkey, w_logit, top_p, temperature) W = jnp.concatenate([W, w[:, None]], axis=1) @@ -170,7 +178,11 @@ def body_fn(i, state): key, subkey = jax.random.split(key) a_logit = a_logit + jnp.where(atom_mask, 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp) - a = sample_top_p(subkey, a_logit, top_p, temperature) + _temp = jax.lax.cond(i==0, + true_fun=lambda x: jnp.array(T1, dtype=float), + false_fun=lambda x: temperature, + operand=None) + a = sample_top_p(subkey, a_logit, top_p, _temp) # use T1 for the first atom type A = jnp.concatenate([A, a[:, None]], axis=1) lattice_params = h_al[:, atom_types:atom_types+Kl+2*6*Kl] diff --git a/src/wyckoff.py b/src/wyckoff.py index c480918..b4b43a4 100644 --- a/src/wyckoff.py +++ b/src/wyckoff.py @@ -115,7 +115,7 @@ def symmetrize_atoms(g, w, x): #https://github.com/qzhu2017/PyXtal/blob/82e7d0eac1965c2713179eeda26a60cace06afc8/pyxtal/wyckoff_site.py#L115 def dist_to_op0x(coord): diff = jnp.dot(symops[g-1, w, 0], jnp.array([*coord, 1])) - coord - diff -= jnp.floor(diff) + diff -= jnp.rint(diff) return jnp.sum(diff**2) loc = jnp.argmin(jax.vmap(dist_to_op0x)(coords)) x = coords[loc].reshape(3,)