From 616de3dfeef41149425266c4152cb59097232766 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Mon, 4 Dec 2023 10:11:39 -0500 Subject: [PATCH 1/2] Fix for MSA block deletion --- openfold/config.py | 8 ++++++++ openfold/data/input_pipeline.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/openfold/config.py b/openfold/config.py index dea58afa..eaca89b7 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -311,6 +311,11 @@ def model_config( "true_msa": [NUM_MSA_SEQ, NUM_RES], "use_clamped_fape": [], }, + "block_delete_msa": { + "msa_fraction_per_block": 0.3, + "randomize_num_blocks": False, + "num_blocks": 5, + }, "masked_msa": { "profile_prob": 0.1, "same_prob": 0.1, @@ -355,6 +360,7 @@ def model_config( "predict": { "fixed_size": True, "subsample_templates": False, # We want top templates. + "block_delete_msa": False, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 512, "max_extra_msa": 1024, @@ -368,6 +374,7 @@ def model_config( "eval": { "fixed_size": True, "subsample_templates": False, # We want top templates. + "block_delete_msa": False, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 128, "max_extra_msa": 1024, @@ -381,6 +388,7 @@ def model_config( "train": { "fixed_size": True, "subsample_templates": True, + "block_delete_msa": True, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 128, "max_extra_msa": 1024, diff --git a/openfold/data/input_pipeline.py b/openfold/data/input_pipeline.py index 651e549e..dee7aa0b 100644 --- a/openfold/data/input_pipeline.py +++ b/openfold/data/input_pipeline.py @@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): """Input pipeline data transformers that can be ensembled and averaged.""" transforms = [] + if mode_cfg.block_delete_msa: + transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa)) + if "max_distillation_msa_clusters" in mode_cfg: transforms.append( data_transforms.sample_msa_distillation( From b935639b98686998d6d78d1d81f1be431a32f71d Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Mon, 4 Dec 2023 13:02:13 -0500 Subject: [PATCH 2/2] Fix bugs in block deletion, disable for soloseq --- openfold/config.py | 2 ++ openfold/data/data_transforms.py | 29 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/openfold/config.py b/openfold/config.py index eaca89b7..4b1ea448 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -156,10 +156,12 @@ def model_config( elif name == "seqemb_initial_training": c.data.train.max_msa_clusters = 1 c.data.eval.max_msa_clusters = 1 + c.data.train.block_delete_msa = False c.data.train.max_distillation_msa_clusters = 1 elif name == "seqemb_finetuning": c.data.train.max_msa_clusters = 1 c.data.eval.max_msa_clusters = 1 + c.data.train.block_delete_msa = False c.data.train.max_distillation_msa_clusters = 1 c.data.train.crop_size = 384 c.loss.violation.weight = 1. diff --git a/openfold/data/data_transforms.py b/openfold/data/data_transforms.py index a3f27373..c2e20be1 100755 --- a/openfold/data/data_transforms.py +++ b/openfold/data/data_transforms.py @@ -253,28 +253,33 @@ def block_delete_msa(protein, config): * config.msa_fraction_per_block ).to(torch.int32) + if int(block_num_seq) == 0: + return protein + if config.randomize_num_blocks: - nb = torch.distributions.uniform.Uniform( - 0, config.num_blocks + 1 - ).sample() + nb = int(torch.randint( + low=0, + high=config.num_blocks + 1, + size=(1,), + device=protein["msa"].device, + )[0]) else: nb = config.num_blocks - del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb) - del_blocks = del_block_starts[:, None] + torch.range(block_num_seq) - del_blocks = torch.clip(del_blocks, 0, num_seq - 1) - del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0] + del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device) + del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq) + del_blocks = torch.clip(del_blocks, 1, num_seq - 1) + del_indices = torch.unique(torch.reshape(del_blocks, [-1])) # Make sure we keep the original sequence - combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None])) + combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long() uniques, counts = combined.unique(return_counts=True) - difference = uniques[counts == 1] - intersection = uniques[counts > 1] - keep_indices = torch.squeeze(difference, 0) + keep_indices = uniques[counts == 1] + assert int(keep_indices[0]) == 0 for k in MSA_FEATURE_NAMES: if k in protein: - protein[k] = torch.gather(protein[k], keep_indices) + protein[k] = torch.index_select(protein[k], 0, keep_indices) return protein