diff --git a/.gitignore b/.gitignore
index fcc5213..cd52057 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,4 +137,9 @@ dmypy.json
.pyre/
-/wandb/
\ No newline at end of file
+/wandb/
+scripts/rgn2_models/
+scripts/sidechainnet_data
+scripts/wandb
+
+.idea/
diff --git a/rgn2_play.ipynb b/rgn2_play.ipynb
new file mode 100644
index 0000000..d29fe67
--- /dev/null
+++ b/rgn2_play.ipynb
@@ -0,0 +1,602 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "rgn2_play.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "authorship_tag": "ABX9TyMCCozHZ4MhRpJY40idz1IE",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9Y-jmJk5S5sG"
+ },
+ "source": [
+ "* How would coevolution implicitly affect the language model training?\n",
+ "* Would kNN-style unsupervised learning be useful for RGN2?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "id": "bisOhBV5FBxp",
+ "outputId": "9a869ace-a3ff-4b40-9d8d-46b80077c364"
+ },
+ "source": [
+ "import IPython\n",
+ "from google.colab import output\n",
+ "\n",
+ "display(IPython.display.Javascript('''\n",
+ " function ClickConnect(){\n",
+ " btn = document.querySelector(\"colab-connect-button\")\n",
+ " if (btn != null){\n",
+ " console.log(\"Click colab-connect-button\"); \n",
+ " btn.click() \n",
+ " }\n",
+ " \n",
+ " btn = document.getElementById('ok')\n",
+ " if (btn != null){\n",
+ " console.log(\"Click reconnect\"); \n",
+ " btn.click() \n",
+ " }\n",
+ " }\n",
+ " \n",
+ "setInterval(ClickConnect,60000)\n",
+ "'''))\n",
+ "\n",
+ "print(\"Done.\")"
+ ],
+ "execution_count": 35,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/javascript": [
+ "\n",
+ " function ClickConnect(){\n",
+ " btn = document.querySelector(\"colab-connect-button\")\n",
+ " if (btn != null){\n",
+ " console.log(\"Click colab-connect-button\"); \n",
+ " btn.click() \n",
+ " }\n",
+ " \n",
+ " btn = document.getElementById('ok')\n",
+ " if (btn != null){\n",
+ " console.log(\"Click reconnect\"); \n",
+ " btn.click() \n",
+ " }\n",
+ " }\n",
+ " \n",
+ "setInterval(ClickConnect,60000)\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Done.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "SO2-ZpV7SuFi"
+ },
+ "source": [
+ "!git clone https://github.com/hushuangwei/rgn2-replica.git"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qS7DR2x_9kV_"
+ },
+ "source": [
+ "# https://blog.csdn.net/NEUdeep/article/details/115724826\n",
+ "!export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning'"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "r4w9hFGdUZmc"
+ },
+ "source": [
+ "!pip install wandb sidechainnet einops proDy tqdm datasets transformers x-transformers pytorch-lightning fair-esm En-transformer pytorch3d invariant_point_attention"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "CYVrT9vEBthE",
+ "outputId": "80e40f7b-ad1c-44a5-eab9-95abb1049203"
+ },
+ "source": [
+ "%cd rgn2-replica/scripts/"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/rgn2-replica/scripts\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h6kBnPjeMxMw"
+ },
+ "source": [
+ "One can skip this Google Drive setup and let sidechainnet download the data directly, although that is slower."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rUVr-8p-at6T",
+ "outputId": "d11f57da-bae4-44a9-a6e8-6c0050e2a9f6"
+ },
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mounted at /content/drive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iKeXyuo6Z4-o"
+ },
+ "source": [
+ "!mkdir -p /content/rgn2-replica/scripts/sidechainnet_data/\n",
+ "!cp /content/drive/MyDrive/protein/sidechainnet_casp12_90.pkl /content/rgn2-replica/scripts/sidechainnet_data/"
+ ],
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "t0JBGC4vHYe-",
+ "outputId": "eb9e3478-925d-4402-fa82-4a136d5da534"
+ },
+ "source": [
+ "!wandb login"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter: \n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Y8z-qNqGTlVh"
+ },
+ "source": [
+ "!nohup python train_rgn2.py --device cuda:0 --wb_entity hushuangwei \\\n",
+ " --wb_proj rgn2_replica --run_name RGN2_ipa_1e-4 \\\n",
+ " --min_len_valid 50 --casp_version 12 --scn_thinning 90 \\\n",
+ " --min_len 0 --max_len 384 --input_dropout 0.1 --num_layers 6 \\\n",
+ " --bidirectional 1 --layer_type LSTM --act aconc --num_recycles_train 8 \\\n",
+ " --refiner_args \"{\\\"refiner_type\\\": \\\"IPA\\\"}\" \\\n",
+ " > RGN2X_vanillaLSTM_full_run_logs.txt 2>&1 &"
+ ],
+ "execution_count": 33,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Vh2vtuZ-Uc72",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "b9cb8a21-b6ca-47fb-c657-dd7b676d5242"
+ },
+ "source": [
+ "!tail -f RGN2X_vanillaLSTM_full_run_logs.txt"
+ ],
+ "execution_count": 32,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "wandb: loss 0.08731\n",
+ "wandb: rmsd 9.9361\n",
+ "wandb: torsion_loss 1.76916\n",
+ "wandb: viol_loss 0.00342\n",
+ "wandb: \n",
+ "wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n",
+ "wandb: Synced RGN2_ipa_1e-4: https://wandb.ai/hushuangwei/rgn2_replica/runs/372ib897\n",
+ "wandb: Find logs at: ./wandb/run-20211010_110523-372ib897/logs/debug.log\n",
+ "wandb: \n",
+ "\n",
+ "^C\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_tNr33PMAsiy"
+ },
+ "source": [
+ "!pkill -f train_rgn2.py"
+ ],
+ "execution_count": 27,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KGrjhdZB0crD",
+ "outputId": "5088c897-2ef1-48d8-cb01-87926a6a9d8d"
+ },
+ "source": [
+ "!ps aux | grep train_rgn2.py"
+ ],
+ "execution_count": 61,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "root 2285 0.0 0.0 39200 6252 ? S 12:47 0:00 /bin/bash -c ps aux | grep train_rgn2.py\n",
+ "root 2287 0.0 0.0 38572 5180 ? S 12:47 0:00 grep train_rgn2.py\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Qa2fwnskPqRD",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "71e4390e-08bd-4348-d82c-cb45cba91408"
+ },
+ "source": [
+ "!nvidia-smi"
+ ],
+ "execution_count": 53,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Tue Oct 5 08:13:04 2021 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 470.74 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n",
+ "| N/A 53C P0 103W / 300W | 11293MiB / 16160MiB | 59% Default |\n",
+ "| | | N/A |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| No running processes found |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jlgXayRhMa0v"
+ },
+ "source": [
+ "# The cells below are unfinished"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "UWV02qQUNvFO"
+ },
+ "source": [
+ "import sidechainnet as scn\n",
+ "import random \n",
+ "\n",
+ "import sys\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import py3Dmol\n",
+ "import esm\n",
+ "\n",
+ "from rgn2_replica import *\n",
+ "from rgn2_replica.rgn2 import *\n",
+ "from rgn2_replica.rgn2_utils import *\n",
+ "from rgn2_replica.rgn2_trainers import *"
+ ],
+ "execution_count": 56,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "o04uJC6CVZFi"
+ },
+ "source": [
+ "from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB\n",
+ "VOCAB = VOCAB()"
+ ],
+ "execution_count": 78,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "q-y5nsoxwZSM",
+ "outputId": "67a7536a-a23d-4f37-9492-43d1c6f7d7ae"
+ },
+ "source": [
+ "set_seed(42)\n",
+ "dataloaders = scn.load(casp_version=12, thinning=90, with_pytorch=\"dataloaders\", \n",
+ " batch_size=1, dynamic_batching=False)"
+ ],
+ "execution_count": 81,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp12_90.pkl.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "OfTlIEmgK1C4"
+ },
+ "source": [
+ "# if torch.cuda.is_available():\n",
+ "# device = torch.device(\"cuda\")\n",
+ "# else:\n",
+ "# device = torch.device(\"cpu\")\n",
+ "device = \"cpu\"\n",
+ "model = RGN2_IPA(embedding_dim=1284).to(device)"
+ ],
+ "execution_count": 82,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YkgieANeMKe_",
+ "outputId": "42073236-ac77-4284-9fcc-6d00d21fc002"
+ },
+ "source": [
+ "save_path = \"/content/rgn2-replica/scripts/rgn2_models/RGN2_ipa_1e-4@_32K.pt\"\n",
+ "model.load_state_dict(torch.load(save_path))\n",
+ "sum([p.numel() for p in model.parameters()])"
+ ],
+ "execution_count": 83,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "34216661"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 83
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "HgJDm6ohv0FO"
+ },
+ "source": [
+ "dataloaders[\"train\"].dataset\n",
+ "MIN_LEN_TEST = 70\n",
+ "MIN_LEN = 0\n",
+ "MAX_LEN = 512"
+ ],
+ "execution_count": 84,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "SjWn3Q3vOrtt"
+ },
+ "source": [
+ "embedder, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()\n",
+ "batch_converter = alphabet.get_batch_converter()"
+ ],
+ "execution_count": 85,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "wD28yPldv-xi"
+ },
+ "source": [
+ "embedder = embedder.to(device)"
+ ],
+ "execution_count": 86,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 398
+ },
+ "id": "0fICN-VLUwEq",
+ "outputId": "d91ecc41-92e2-4b99-d3b6-04aca1bb2b40"
+ },
+ "source": [
+ "### TEST\n",
+ "tic = time.time()\n",
+ "get_prot_test_ = mp_nerf.utils.get_prot( \n",
+ " dataloader_=dataloaders, \n",
+ " vocab_=VOCAB, # mp_nerf.utils.\n",
+ " min_len=MIN_LEN, max_len=MAX_LEN, \n",
+ " verbose=False, subset=\"test\"\n",
+ ")\n",
+ "# get num of unique, full-masked proteins\n",
+ "seqs = []\n",
+ "for i, prot_args in enumerate(dataloaders[\"test\"].dataset):\n",
+ " # (id, int_seq, mask, ... , str_seq)\n",
+ " length = len(prot_args[-1]) \n",
+ " if 0 < length < MAX_LEN and sum( prot_args[2] ) == length:\n",
+ " seqs.append( prot_args[-1] )\n",
+ "\n",
+ "metrics_stuff_test = predict(\n",
+ " get_prot_= get_prot_test_, \n",
+ " steps = len(set(seqs)), # 24 for MIN_LEN=70\n",
+ " model = model,\n",
+ " embedder = embedder, \n",
+ " return_preds = True,\n",
+ " log_every = 4,\n",
+ " accumulate_every = len(set(seqs)),\n",
+ " seed = 42, # 42\n",
+ " mode = \"fast_test\", # \"test\" # \"test\" is for AR, \"fast_test\" is for iterative\n",
+ " recycle_func = lambda x: 1, # 5 # 3 # 2 \n",
+ " wandbai = False,\n",
+ ")\n",
+ "preds_list_test, metrics_list_test, metrics_stats_test = metrics_stuff_test\n",
+ "print(\"\\n\", \"Test Results:\", sep=\"\")\n",
+ "for k,v in metrics_stats_test.items():\n",
+ " offset = \" \" * ( max(len(ki) for ki in metrics_stats_test.keys()) - len(k) )\n",
+ " print(k + offset, \":\", v)\n",
+ "print(\"\\n\")\n",
+ "print(\"Time taken: \", time.time()-tic, \"\\n\")"
+ ],
+ "execution_count": 90,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([3, 246, 2])\n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "TypeError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"fast_test\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# \"test\" # \"test\" is for AR, \"fast_test\" is for iterative\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mrecycle_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# 5 # 3 # 2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mwandbai\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m \u001b[0mpreds_list_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics_list_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics_stats_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetrics_stuff_test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/rgn2-replica/rgn2_replica/rgn2_trainers.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(get_prot_, steps, model, embedder, return_preds, accumulate_every, log_every, seed, wandbai, recycle_func, mode)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprots\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membedder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0membedder\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecycle_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrecycle_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m )\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# calculate metrics || calc loss terms || baselines for next-term: torsion=2, fape=0.95\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/rgn2-replica/rgn2_replica/rgn2_trainers.py\u001b[0m in \u001b[0;36mbatched_inference\u001b[0;34m(model, embedder, mode, device, recycle_func, config, *args)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# don't pass angles info - just 0 at start (sin=0, cos=1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mangles_input\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mangles_input\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m ], dim=-1)\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mTypeError\u001b[0m: expected Tensor as element 0 in argument 0, but got dict"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qovzXcFDPnvC"
+ },
+ "source": [
+ ""
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/rgn2_replica/embedders.py b/rgn2_replica/embedders.py
index c4443f6..295d0a6 100644
--- a/rgn2_replica/embedders.py
+++ b/rgn2_replica/embedders.py
@@ -31,8 +31,8 @@ def forward(self, aa_seq):
torch.Tensor (B, L) according to MP-NeRF encoding
"""
# format
- if isinstance(aa_seqs, torch.Tensor):
- aa_seq = ids_to_embed_input(to_cpu(aa_seqs).tolist())
+ if isinstance(aa_seq, torch.Tensor):
+ aa_seq = ids_to_embed_input(to_cpu(aa_seq).tolist())
with torch.no_grad():
tokenized_seq = self.tokenizer(aa_seq, context_length=len(aa_seq), return_mask=False)
diff --git a/rgn2_replica/mp_nerf/LICENSE b/rgn2_replica/mp_nerf/LICENSE
new file mode 100644
index 0000000..1132c9a
--- /dev/null
+++ b/rgn2_replica/mp_nerf/LICENSE
@@ -0,0 +1,421 @@
+
+Copyright (c) 2021, Eric Alcaide
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following license.
+ 2. Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ license in the documentation and or other materials provided
+ with the distribution.
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+
+Attribution-NonCommercial-NoDerivatives 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
+International Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-NoDerivatives 4.0 International Public
+License ("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ c. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ d. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ e. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ f. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ g. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ h. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ i. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ j. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ k. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce and reproduce, but not Share, Adapted Material
+ for NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material, You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ For the avoidance of doubt, You do not have permission under
+ this Public License to Share Adapted Material.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only and provided You do not Share Adapted Material;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/rgn2_replica/mp_nerf/__init__.py b/rgn2_replica/mp_nerf/__init__.py
new file mode 100644
index 0000000..2fd09c0
--- /dev/null
+++ b/rgn2_replica/mp_nerf/__init__.py
@@ -0,0 +1,5 @@
+# from rgn2_replica.mp_nerf import *
+# from rgn2_replica.mp_nerf import *
+import rgn2_replica.mp_nerf.utils
+import rgn2_replica.mp_nerf.proteins
+import rgn2_replica.mp_nerf.ml_utils
\ No newline at end of file
diff --git a/rgn2_replica/mp_nerf/kb_proteins.py b/rgn2_replica/mp_nerf/kb_proteins.py
new file mode 100644
index 0000000..6d5e17f
--- /dev/null
+++ b/rgn2_replica/mp_nerf/kb_proteins.py
@@ -0,0 +1,846 @@
+# Author: Eric Alcaide
+
+# A substantial part has been borrowed from
+# https://github.com/jonathanking/sidechainnet
+#
+# Here's the License for it:
+#
+# Copyright 2020 Jonathan King
+# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+# following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+# products derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+
+#########################
+### FROM SIDECHAINNET ###
+#########################
+
+# Per-residue side-chain build table, modified from sidechainnet by treating rigid groups in sidechains as rigid bodies (extra torsions removed)
+
+SC_BUILD_INFO = {  # keyed by one-letter residue code; 'G' and '_' list no side-chain atoms
+ 'A': {
+ 'angles-names': ['N-CA-CB'],  # the angles-*/bonds-*/torsion-* lists run parallel to 'atom-names' (one entry per side-chain atom)
+ 'angles-types': ['N -CX-CT'],
+ 'angles-vals': [1.9146261894377796],  # values look like radians (2.0943951023931953 = 2*pi/3 appears elsewhere in this table)
+ 'atom-names': ['CB'],
+ 'bonds-names': ['CA-CB'],
+ 'bonds-types': ['CX-CT'],
+ 'bonds-vals': [1.526],
+ 'torsion-names': ['C-N-CA-CB'],
+ 'torsion-types': ['C -N -CX-CT'],
+ 'torsion-vals': ['p'],  # a number, or a 'p'/'i' marker; NOTE(review): presumably sidechainnet's predicted/inferred flags -- confirm
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4]],  # triples of atom indices; NOTE(review): presumably the atoms defining each rigid frame -- confirm against the consumer
+ },
+
+ 'R': {
+ 'angles-names': [
+ 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-NE', 'CD-NE-CZ', 'NE-CZ-NH1',
+ 'NE-CZ-NH2'
+ ],
+ 'angles-types': [
+ 'N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-N2', 'C8-N2-CA', 'N2-CA-N2',
+ 'N2-CA-N2'
+ ],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.9408061282176945,
+ 2.150245638457014, 2.0943951023931953, 2.0943951023931953
+ ],
+ 'atom-names': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-NE', 'NE-CZ', 'CZ-NH1', 'CZ-NH2'],
+ 'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-N2', 'N2-CA', 'CA-N2', 'CA-N2'],
+ 'bonds-vals': [1.526, 1.526, 1.526, 1.463, 1.34, 1.34, 1.34],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-NE', 'CG-CD-NE-CZ',
+ 'CD-NE-CZ-NH1', 'CD-NE-CZ-NH2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-N2', 'C8-C8-N2-CA',
+ 'C8-N2-CA-N2', 'C8-N2-CA-N2'
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 'p', 'p', 0., 3.141592],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7], [6,7,8]],
+ },
+
+ 'N': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-ND2'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-C ', '2C-C -O ', '2C-C -N '],
+ 'angles-vals': [
+ 1.9146261894377796, 1.9390607989657, 2.101376419401173, 2.035053907825388
+ ],
+ 'atom-names': ['CB', 'CG', 'OD1', 'ND2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-ND2'],
+ 'bonds-types': ['CX-2C', '2C-C ', 'C -O ', 'C -N '],
+ 'bonds-vals': [1.526, 1.522, 1.229, 1.335],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-ND2'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-C ', 'CX-2C-C -O ', 'CX-2C-C -N '],
+ 'torsion-vals': ['p', 'p', 'p', 'i'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'D': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-OD2'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-CO', '2C-CO-O2', '2C-CO-O2'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.9390607989657, 2.0420352248333655, 2.0420352248333655
+ ],
+ 'atom-names': ['CB', 'CG', 'OD1', 'OD2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-OD2'],
+ 'bonds-types': ['CX-2C', '2C-CO', 'CO-O2', 'CO-O2'],
+ 'bonds-vals': [1.526, 1.522, 1.25, 1.25],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-OD2'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-CO', 'CX-2C-CO-O2', 'CX-2C-CO-O2'],
+ 'torsion-vals': ['p', 'p', 'p', 'i'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'C': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-SG'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-SH'],
+ 'angles-vals': [1.9146261894377796, 1.8954275676658419],
+ 'atom-names': ['CB', 'SG'],
+ 'bonds-names': ['CA-CB', 'CB-SG'],
+ 'bonds-types': ['CX-2C', '2C-SH'],
+ 'bonds-vals': [1.526, 1.81],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-SG'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-SH'],
+ 'torsion-vals': ['p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]],
+ },
+
+ 'Q': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-NE2'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-C ', '2C-C -O ', '2C-C -N '],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.101376419401173,
+ 2.035053907825388
+ ],
+ 'atom-names': ['CB', 'CG', 'CD', 'OE1', 'NE2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-NE2'],
+ 'bonds-types': ['CX-2C', '2C-2C', '2C-C ', 'C -O ', 'C -N '],
+ 'bonds-vals': [1.526, 1.526, 1.522, 1.229, 1.335],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-NE2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-C ', '2C-2C-C -O ', '2C-2C-C -N '
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 'p', 'i'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]],
+ },
+
+ 'E': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-OE2'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-CO', '2C-CO-O2', '2C-CO-O2'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.0420352248333655,
+ 2.0420352248333655
+ ],
+ 'atom-names': ['CB', 'CG', 'CD', 'OE1', 'OE2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-OE2'],
+ 'bonds-types': ['CX-2C', '2C-2C', '2C-CO', 'CO-O2', 'CO-O2'],
+ 'bonds-vals': [1.526, 1.526, 1.522, 1.25, 1.25],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-OE2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-CO', '2C-2C-CO-O2', '2C-2C-CO-O2'
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 'p', 'i'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]],
+ },
+
+ 'G': {
+ 'angles-names': [],
+ 'angles-types': [],
+ 'angles-vals': [],
+ 'atom-names': [],
+ 'bonds-names': [],
+ 'bonds-types': [],
+ 'bonds-vals': [],
+ 'torsion-names': [],
+ 'torsion-types': [],
+ 'torsion-vals': [],
+ 'rigid-frames-idxs': [[0,1,2]],
+ },
+
+ 'H': {
+ 'angles-names': [
+ 'N-CA-CB', 'CA-CB-CG', 'CB-CG-ND1', 'CG-ND1-CE1', 'ND1-CE1-NE2', 'CE1-NE2-CD2'
+ ],
+ 'angles-types': [
+ 'N -CX-CT', 'CX-CT-CC', 'CT-CC-NA', 'CC-NA-CR', 'NA-CR-NB', 'CR-NB-CV'
+ ],
+ 'angles-vals': [
+ 1.9146261894377796, 1.9739673840055867, 2.0943951023931953,
+ 1.8849555921538759, 1.8849555921538759, 1.8849555921538759
+ ],
+ 'atom-names': ['CB', 'CG', 'ND1', 'CE1', 'NE2', 'CD2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-ND1', 'ND1-CE1', 'CE1-NE2', 'NE2-CD2'],
+ 'bonds-types': ['CX-CT', 'CT-CC', 'CC-NA', 'NA-CR', 'CR-NB', 'NB-CV'],
+ 'bonds-vals': [1.526, 1.504, 1.385, 1.343, 1.335, 1.394],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-ND1', 'CB-CG-ND1-CE1', 'CG-ND1-CE1-NE2',
+ 'ND1-CE1-NE2-CD2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-CT', 'N -CX-CT-CC', 'CX-CT-CC-NA', 'CT-CC-NA-CR', 'CC-NA-CR-NB',
+ 'NA-CR-NB-CV'
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'I': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CB-CG1-CD1', 'CA-CB-CG2'],
+ 'angles-types': ['N -CX-3C', 'CX-3C-2C', '3C-2C-CT', 'CX-3C-CT'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
+ ],
+ 'atom-names': ['CB', 'CG1', 'CD1', 'CG2'],
+ 'bonds-names': ['CA-CB', 'CB-CG1', 'CG1-CD1', 'CB-CG2'],
+ 'bonds-types': ['CX-3C', '3C-2C', '2C-CT', '3C-CT'],
+ 'bonds-vals': [1.526, 1.526, 1.526, 1.526],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'CA-CB-CG1-CD1', 'N-CA-CB-CG2'],
+ 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-2C', 'CX-3C-2C-CT', 'N -CX-3C-CT'],
+ 'torsion-vals': ['p', 'p', 'p', -2.1315], # last one was 'p' in the original - but cg1-cg2 = "2.133"
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,7]],
+ },
+
+ 'L': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CB-CG-CD2'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-3C', '2C-3C-CT', '2C-3C-CT'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791
+ ],
+ 'atom-names': ['CB', 'CG', 'CD1', 'CD2'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD1', 'CG-CD2'],
+ 'bonds-types': ['CX-2C', '2C-3C', '3C-CT', '3C-CT'],
+ 'bonds-vals': [1.526, 1.526, 1.526, 1.526],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CA-CB-CG-CD2'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-3C', 'CX-2C-3C-CT', 'CX-2C-3C-CT'],
+ # the extra torsion is stored with its sign flipped because, in mask construction, the previous angle is summed.
+ 'torsion-vals': ['p', 'p', 'p', 2.1315], # last one was 'p' in the original - but cd1-cd2 = "-2.130"
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'K': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-CE', 'CD-CE-NZ'],
+ 'angles-types': ['N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-C8', 'C8-C8-N3'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791,
+ 1.9408061282176945
+ ],
+ 'atom-names': ['CB', 'CG', 'CD', 'CE', 'NZ'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-CE', 'CE-NZ'],
+ 'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-C8', 'C8-N3'],
+ 'bonds-vals': [1.526, 1.526, 1.526, 1.526, 1.471],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-CE', 'CG-CD-CE-NZ'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-C8', 'C8-C8-C8-N3'
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 'p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7], [6,7,8]],
+ },
+
+ 'M': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-SD', 'CG-SD-CE'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-S ', '2C-S -CT'],
+ 'angles-vals': [
+ 1.9146261894377796, 1.911135530933791, 2.0018926520374962, 1.726130630222392
+ ],
+ 'atom-names': ['CB', 'CG', 'SD', 'CE'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-SD', 'SD-CE'],
+ 'bonds-types': ['CX-2C', '2C-2C', '2C-S ', 'S -CT'],
+ 'bonds-vals': [1.526, 1.526, 1.81, 1.81],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-SD', 'CB-CG-SD-CE'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-S ', '2C-2C-S -CT'],
+ 'torsion-vals': ['p', 'p', 'p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]],
+ },
+
+ 'F': {
+ 'angles-names': [
+ 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-CE2',
+ 'CZ-CE2-CD2'
+ ],
+ 'angles-types': [
+ 'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CA',
+ 'CA-CA-CA'
+ ],
+ 'angles-vals': [
+ 1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
+ 2.0943951023931953, 2.0943951023931953, 2.0943951023931953, 2.0943951023931953
+ ],
+ 'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'CE2', 'CD2'],
+ 'bonds-names': [
+ 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-CE2', 'CE2-CD2'
+ ],
+ 'bonds-types': ['CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA'],
+ 'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.4, 1.4, 1.4],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
+ 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-CA',
+ 'CA-CA-CA-CA', 'CA-CA-CA-CA'
+ ],
+ 'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0, 0.0],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'P': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD'],
+ 'angles-types': ['N -CX-CT', 'CX-CT-CT', 'CT-CT-CT'],
+ 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
+ 'atom-names': ['CB', 'CG', 'CD'],
+ 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD'],
+ 'bonds-types': ['CX-CT', 'CT-CT', 'CT-CT'],
+ 'bonds-vals': [1.526, 1.526, 1.526],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD'],
+ 'torsion-types': ['C -N -CX-CT', 'N -CX-CT-CT', 'CX-CT-CT-CT'],
+ 'torsion-vals': ['p', 'p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'S': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-OG'],
+ 'angles-types': ['N -CX-2C', 'CX-2C-OH'],
+ 'angles-vals': [1.9146261894377796, 1.911135530933791],
+ 'atom-names': ['CB', 'OG'],
+ 'bonds-names': ['CA-CB', 'CB-OG'],
+ 'bonds-types': ['CX-2C', '2C-OH'],
+ 'bonds-vals': [1.526, 1.41],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG'],
+ 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-OH'],
+ 'torsion-vals': ['p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]],
+ },
+
+ 'T': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-OG1', 'CA-CB-CG2'],
+ 'angles-types': ['N -CX-3C', 'CX-3C-OH', 'CX-3C-CT'],
+ 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
+ 'atom-names': ['CB', 'OG1', 'CG2'],
+ 'bonds-names': ['CA-CB', 'CB-OG1', 'CB-CG2'],
+ 'bonds-types': ['CX-3C', '3C-OH', '3C-CT'],
+ 'bonds-vals': [1.526, 1.41, 1.526],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG1', 'N-CA-CB-CG2'],
+ 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-OH', 'N -CX-3C-CT'],
+ 'torsion-vals': ['p', 'p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]],
+ },
+
+ 'W': {
+ 'angles-names': [
+ 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-NE1', 'CD1-NE1-CE2',
+ 'NE1-CE2-CZ2', 'CE2-CZ2-CH2', 'CZ2-CH2-CZ3', 'CH2-CZ3-CE3', 'CZ3-CE3-CD2'
+ ],
+ 'angles-types': [
+ 'N -CX-CT', 'CX-CT-C*', 'CT-C*-CW', 'C*-CW-NA', 'CW-NA-CN', 'NA-CN-CA',
+ 'CN-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CB'
+ ],
+ 'angles-vals': [
+ 1.9146261894377796, 2.0176006153054447, 2.181661564992912, 1.8971728969178363,
+ 1.9477874452256716, 2.3177972466484698, 2.0943951023931953,
+ 2.0943951023931953, 2.0943951023931953, 2.0943951023931953
+ ],
+ 'atom-names': [
+ 'CB', 'CG', 'CD1', 'NE1', 'CE2', 'CZ2', 'CH2', 'CZ3', 'CE3', 'CD2'
+ ],
+ 'bonds-names': [
+ 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-NE1', 'NE1-CE2', 'CE2-CZ2', 'CZ2-CH2',
+ 'CH2-CZ3', 'CZ3-CE3', 'CE3-CD2'
+ ],
+ 'bonds-types': [
+ 'CX-CT', 'CT-C*', 'C*-CW', 'CW-NA', 'NA-CN', 'CN-CA', 'CA-CA', 'CA-CA',
+ 'CA-CA', 'CA-CB'
+ ],
+ 'bonds-vals': [1.526, 1.495, 1.352, 1.381, 1.38, 1.4, 1.4, 1.4, 1.4, 1.404],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-NE1', 'CG-CD1-NE1-CE2',
+ 'CD1-NE1-CE2-CZ2', 'NE1-CE2-CZ2-CH2', 'CE2-CZ2-CH2-CZ3', 'CZ2-CH2-CZ3-CE3',
+ 'CH2-CZ3-CE3-CD2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-CT', 'N -CX-CT-C*', 'CX-CT-C*-CW', 'CT-C*-CW-NA', 'C*-CW-NA-CN',
+ 'CW-NA-CN-CA', 'NA-CN-CA-CA', 'CN-CA-CA-CA', 'CA-CA-CA-CA', 'CA-CA-CA-CB'
+ ],
+ 'torsion-vals': [
+ 'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 3.141592653589793,
+ 0.0, 0.0, 0.0
+ ],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]]
+ },
+
+ 'Y': {
+ 'angles-names': [
+ 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-OH',
+ 'CE1-CZ-CE2', 'CZ-CE2-CD2'
+ ],
+ 'angles-types': [
+ 'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-C ', 'CA-C -OH',
+ 'CA-C -CA', 'C -CA-CA'
+ ],
+ 'angles-vals': [
+ 1.9146261894377796, 1.9896753472735358, 2.0943951023931953,
+ 2.0943951023931953, 2.0943951023931953, 2.0943951023931953,
+ 2.0943951023931953, 2.0943951023931953
+ ],
+ 'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'OH', 'CE2', 'CD2'],
+ 'bonds-names': [
+ 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-OH', 'CZ-CE2', 'CE2-CD2'
+ ],
+ 'bonds-types': [
+ 'CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-C ', 'C -OH', 'C -CA', 'CA-CA'
+ ],
+ 'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.409, 1.364, 1.409, 1.4],
+ 'torsion-names': [
+ 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ',
+ 'CD1-CE1-CZ-OH', 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2'
+ ],
+ 'torsion-types': [
+ 'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-C ',
+ 'CA-CA-C -OH', 'CA-CA-C -CA', 'CA-C -CA-CA'
+ ],
+ 'torsion-vals': [
+ 'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 0.0, 0.0
+ ],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]],
+ },
+
+ 'V': {
+ 'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CA-CB-CG2'],
+ 'angles-types': ['N -CX-3C', 'CX-3C-CT', 'CX-3C-CT'],
+ 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791],
+ 'atom-names': ['CB', 'CG1', 'CG2'],
+ 'bonds-names': ['CA-CB', 'CB-CG1', 'CB-CG2'],
+ 'bonds-types': ['CX-3C', '3C-CT', '3C-CT'],
+ 'bonds-vals': [1.526, 1.526, 1.526],
+ 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'N-CA-CB-CG2'],
+ 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-CT', 'N -CX-3C-CT'],
+ 'torsion-vals': ['p', 'p', 'p'],
+ 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]]
+ },
+
+ '_': {
+ 'angles-names': [],
+ 'angles-types': [],
+ 'angles-vals': [],
+ 'atom-names': [],
+ 'bonds-names': [],
+ 'bonds-types': [],
+ 'bonds-vals': [],
+ 'torsion-names': [],
+ 'torsion-types': [],
+ 'torsion-vals': [],
+ 'rigid-frames-idxs': [[]],
+ }
+}
+
+BB_BUILD_INFO = {  # backbone constants: bond lengths, bond angles and default torsions
+ "BONDLENS": {
+ # the active values follow crystal data from 1DPE_1_A, validated against other structures;
+ # the commented-out value on each line is the sidechainnet one
+ 'n-ca': 1.4664931, # 1.442,
+ 'ca-c': 1.524119, # 1.498,
+ 'c-n': 1.3289373, # 1.379,
+ 'c-o': 1.229, # From parm10.dat || huge variability according to structures
+ # we get 1.3389416 from 1DPE_1_A but also 1.2289 from 2F2H_d2f2hf1
+ 'c-oh': 1.364
+ },
+ # NOTE(review): "From parm10.dat, for OXT" -- presumably describes 'c-oh' above; confirm
+ # BONDANGS: for placing oxygens
+ "BONDANGS": {
+ 'ca-c-o': 2.0944, # Approximated to be 2pi / 3; parm10.dat says 2.0350539
+ 'ca-c-oh': 2.0944,
+ 'ca-c-n': 2.03,
+ 'n-ca-c': 1.94,
+ 'c-n-ca': 2.08,
+ },
+ # 'ca-c-oh' above is equal to 'ca-c-o', for OXT
+ "BONDTORSIONS": {
+ 'n-ca-c-n': -0.785398163, # psi (-45 deg, bimodal distribution, pick one)
+ 'c-n-ca-c': -1.3962634015954636, # phi (-80 deg, bimodal distribution, pick one)
+ 'ca-n-c-ca': 3.141592, # omega (180 deg - https://doi.org/10.1016/j.jmb.2005.01.065)
+ 'n-ca-c-o': -2.406 # oxygen
+ } # A simple approximation, not meant to be exact.
+}
+
+
+# Per-residue bond list as [i, j] atom-index pairs; indices follow the sidechainnet atom order (presumably N, CA, C, O first, then the side-chain atoms in SC_BUILD_INFO 'atom-names' order -- confirm)
+SCN_CONNECT = {
+ 'A': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4]]  # every residue except '_' starts with the pairs [0,1], [1,2], [2,3]
+ },
+ 'R': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8], [8,9], [8,10]]
+ },
+ 'N': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [5,7]]
+ },
+ 'D': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [5,7]]
+ },
+ 'C': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
+ },
+ 'Q': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [6,8]]
+ },
+ 'E': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [6,8]]
+ },
+ 'G': {
+ 'bonds': [[0,1], [1,2], [2,3]]
+ },
+ 'H': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8], [8,9], [5,9]]  # [5,9] closes the ring
+ },
+ 'I': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [4,7]]
+ },
+ 'L': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [5,7]]
+ },
+ 'K': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8]]
+ },
+ 'M': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7]]
+ },
+ 'F': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8], [8,9], [9,10], [5,10]]  # [5,10] closes the ring
+ },
+ 'P': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [0,6]]  # [0,6] ties the last side-chain atom back to atom 0, closing the ring
+ },
+ 'S': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]]
+ },
+ 'T': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
+ },
+ 'W': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8], [8,9], [9,10], [10,11], [11,12],
+ [12, 13], [5,13], [8,13]]  # [5,13] and [8,13] close the two fused rings
+ },
+ 'Y': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6],
+ [6,7], [7,8], [8,9], [8,10], [10,11], [5,11]]  # [8,9] is the branch off the ring; [5,11] closes it
+ },
+ 'V': {
+ 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]]
+ },
+ '_': {
+ 'bonds': []
+ }
+ }
+
+# from: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
+AMBIGUOUS = {  # per residue: pairs of side-chain atom names listed as ambiguous, plus each pair's positions in the sidechainnet atom order
+ "D": {"names": [["OD1", "OD2"]],
+ "indexs": [[6, 7]],
+ },
+ "E": {"names": [["OE1", "OE2"]],
+ "indexs": [[7, 8]],
+ },
+ "F": {"names": [["CD1", "CD2"], ["CE1", "CE2"]],
+ "indexs": [[6, 10], [7, 9]],
+ },
+ "Y": {"names": [["CD1", "CD2"], ["CE1", "CE2"]],
+ "indexs": [[6, 11], [7, 10]],  # OH sits at index 9 in Y (see SC_BUILD_INFO['Y']['atom-names']), so CE2/CD2 are 10/11, unlike F
+ },
+}
+
+
+# AA substitution matrix (BLOSUM62): each key maps to its own row; rows and columns both follow the order A R N D C Q E G H I L K M F P S T W Y V, plus a trailing padding '_' entry
+BLOSUM = {
+ "A" : [4.0, -1.0, -2.0, -2.0, 0.0, -1.0, -1.0, 0.0, -2.0, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, 1.0, 0.0, -3.0, -2.0, 0.0, 0.0],
+ "R" : [-1.0, 5.0, 0.0, -2.0, -3.0, 1.0, 0.0, -2.0, 0.0, -3.0, -2.0, 2.0, -1.0, -3.0, -2.0, -1.0, -1.0, -3.0, -2.0, -3.0, 0.0],
+ "N" : [-2.0, 0.0, 6.0, 1.0, -3.0, 0.0, 0.0, 0.0, 1.0, -3.0, -3.0, 0.0, -2.0, -3.0, -2.0, 1.0, 0.0, -4.0, -2.0, -3.0, 0.0],
+ "D" : [-2.0, -2.0, 1.0, 6.0, -3.0, 0.0, 2.0, -1.0, -1.0, -3.0, -4.0, -1.0, -3.0, -3.0, -1.0, 0.0, -1.0, -4.0, -3.0, -3.0, 0.0],
+ "C" : [0.0, -3.0, -3.0, -3.0, 9.0, -3.0, -4.0, -3.0, -3.0, -1.0, -1.0, -3.0, -1.0, -2.0, -3.0, -1.0, -1.0, -2.0, -2.0, -1.0, 0.0],
+ "Q" : [-1.0, 1.0, 0.0, 0.0, -3.0, 5.0, 2.0, -2.0, 0.0, -3.0, -2.0, 1.0, 0.0, -3.0, -1.0, 0.0, -1.0, -2.0, -1.0, -2.0, 0.0],
+ "E" : [-1.0, 0.0, 0.0, 2.0, -4.0, 2.0, 5.0, -2.0, 0.0, -3.0, -3.0, 1.0, -2.0, -3.0, -1.0, 0.0, -1.0, -3.0, -2.0, -2.0, 0.0],
+ "G" : [0.0, -2.0, 0.0, -1.0, -3.0, -2.0, -2.0, 6.0, -2.0, -4.0, -4.0, -2.0, -3.0, -3.0, -2.0, 0.0, -2.0, -2.0, -3.0, -3.0, 0.0],
+ "H" : [-2.0, 0.0, 1.0, -1.0, -3.0, 0.0, 0.0, -2.0, 8.0, -3.0, -3.0, -1.0, -2.0, -1.0, -2.0, -1.0, -2.0, -2.0, 2.0, -3.0, 0.0],
+ "I" : [-1.0, -3.0, -3.0, -3.0, -1.0, -3.0, -3.0, -4.0, -3.0, 4.0, 2.0, -3.0, 1.0, 0.0, -3.0, -2.0, -1.0, -3.0, -1.0, 3.0, 0.0],
+ "L" : [-1.0, -2.0, -3.0, -4.0, -1.0, -2.0, -3.0, -4.0, -3.0, 2.0, 4.0, -2.0, 2.0, 0.0, -3.0, -2.0, -1.0, -2.0, -1.0, 1.0, 0.0],
+ "K" : [-1.0, 2.0, 0.0, -1.0, -3.0, 1.0, 1.0, -2.0, -1.0, -3.0, -2.0, 5.0, -1.0, -3.0, -1.0, 0.0, -1.0, -3.0, -2.0, -2.0, 0.0],
+ "M" : [-1.0, -1.0, -2.0, -3.0, -1.0, 0.0, -2.0, -3.0, -2.0, 1.0, 2.0, -1.0, 5.0, 0.0, -2.0, -1.0, -1.0, -1.0, -1.0, 1.0, 0.0],
+ "F" : [-2.0, -3.0, -3.0, -3.0, -2.0, -3.0, -3.0, -3.0, -1.0, 0.0, 0.0, -3.0, 0.0, 6.0, -4.0, -2.0, -2.0, 1.0, 3.0, -1.0, 0.0],
+ "P" : [-1.0, -2.0, -2.0, -1.0, -3.0, -1.0, -1.0, -2.0, -2.0, -3.0, -3.0, -1.0, -2.0, -4.0, 7.0, -1.0, -1.0, -4.0, -3.0, -2.0, 0.0],
+ "S" : [1.0, -1.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, -2.0, -2.0, 0.0, -1.0, -2.0, -1.0, 4.0, 1.0, -3.0, -2.0, -2.0, 0.0],
+ "T" : [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, -1.0, -2.0, -2.0, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, 1.0, 5.0, -2.0, -2.0, 0.0, 0.0],
+ "W" : [-3.0, -3.0, -4.0, -4.0, -2.0, -2.0, -3.0, -2.0, -2.0, -3.0, -2.0, -3.0, -1.0, 1.0, -4.0, -3.0, -2.0, 11.0, 2.0, -3.0, 0.0],
+ "Y" : [-2.0, -2.0, -2.0, -3.0, -2.0, -1.0, -2.0, -3.0, 2.0, -1.0, -1.0, -2.0, -1.0, 3.0, -3.0, -2.0, -2.0, 2.0, 7.0, -1.0, 0.0],
+ "V" : [0.0, -3.0, -3.0, -3.0, -1.0, -2.0, -2.0, -3.0, -3.0, 3.0, 1.0, -2.0, 1.0, -1.0, -2.0, -2.0, 0.0, -3.0, -1.0, 4.0, 0.0],
+ "_" : [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
+}
+
+
+# modified manually to match the mode
+MP3SC_INFO = {
+ 'A': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.848366}
+ },
+ 'R': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.6976738},
+ 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.2},
+ 'CD': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -3.141592},
+ 'NE': {'bond_lens': 1.463, 'bond_angs': 1.9408059, 'bond_dihedral': -3.141592},
+ 'CZ': {'bond_lens': 1.34, 'bond_angs': 2.1502457, 'bond_dihedral': -3.141592},
+ 'NH1': {'bond_lens': 1.34, 'bond_angs': 2.094395, 'bond_dihedral': 0.},
+ 'NH2': {'bond_lens': 1.34, 'bond_angs': 2.094395, 'bond_dihedral': -3.141592}
+ },
+ 'N': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.8416245},
+ 'CG': {'bond_lens': 1.5219998, 'bond_angs': 1.9390607, 'bond_dihedral': -1.15},
+ 'OD1': {'bond_lens': 1.229, 'bond_angs': 2.101376, 'bond_dihedral': -1.}, # spread out w/ mean at -1
+ 'ND2': {'bond_lens': 1.3349999, 'bond_angs': 2.0350537, 'bond_dihedral': 2.14} # spread out with mean at -4
+ },
+ 'D': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146265, 'bond_dihedral': 2.7741134},
+ 'CG': {'bond_lens': 1.522, 'bond_angs': 1.9390608, 'bond_dihedral': -1.07},
+ 'OD1': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': -0.2678593},
+ 'OD2': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': 2.95}
+ },
+ 'C': {'CB': {'bond_lens': 1.5259998, 'bond_angs': 1.9146262, 'bond_dihedral': 2.553627},
+ 'SG': {'bond_lens': 1.8099997, 'bond_angs': 1.8954275, 'bond_dihedral': -1.07}
+ },
+ 'Q': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 2.7262106},
+ 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111353, 'bond_dihedral': -1.075},
+ 'CD': {'bond_lens': 1.5219998, 'bond_angs': 1.9390606, 'bond_dihedral': -3.141592},
+ 'OE1': {'bond_lens': 1.229, 'bond_angs': 2.101376, 'bond_dihedral': -1}, # bimodal at -1, +1
+ 'NE2': {'bond_lens': 1.3349998, 'bond_angs': 2.0350537, 'bond_dihedral': 2.14} # bimodal at -2, -4
+ },
+ 'E': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146267, 'bond_dihedral': 2.7813723},
+ 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.07}, # bimodal at -1.07, 3.14
+ 'CD': {'bond_lens': 1.5219998, 'bond_angs': 1.9390606, 'bond_dihedral': -3.0907722155200403},
+ 'OE1': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': 0.003740118}, # spread out btween -1,1
+ 'OE2': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': -3.1378527} # spread out btween -4.3, -2.14
+ },
+ 'G': {},
+ 'H': {'CB': {'bond_lens': 1.5259998, 'bond_angs': 1.9146264, 'bond_dihedral': 2.614421},
+ 'CG': {'bond_lens': 1.5039998, 'bond_angs': 1.9739674, 'bond_dihedral': -1.05},
+ 'ND1': {'bond_lens': 1.3850001, 'bond_angs': 2.094395, 'bond_dihedral': -1.41}, # bimodal at -1.4, 1.4
+ 'CE1': {'bond_lens': 1.3430002, 'bond_angs': 1.8849558, 'bond_dihedral': 3.14},
+ 'NE2': {'bond_lens': 1.335, 'bond_angs': 1.8849558, 'bond_dihedral': 0.0},
+ 'CD2': {'bond_lens': 1.3940002, 'bond_angs': 1.8849558, 'bond_dihedral': 0.0}
+ },
+ 'I': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146265, 'bond_dihedral': 2.5604365},
+ 'CG1': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -1.025},
+ 'CD1': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.0667439142810267},
+ 'CG2': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.1225884596454065}
+ },
+ 'L': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.711971},
+ 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.15},
+ 'CD1': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': 3.14},
+ 'CD2': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.05}
+ },
+ 'K': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146266, 'bond_dihedral': 2.7441595},
+ 'CG': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -1.15},
+ 'CD': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.09},
+ 'CE': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': 3.092959},
+ 'NZ': {'bond_lens': 1.4710001, 'bond_angs': 1.940806, 'bond_dihedral': 3.0515378}
+ },
+ 'M': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146264, 'bond_dihedral': 2.7051392},
+ 'CG': {'bond_lens': 1.526, 'bond_angs': 1.9111354, 'bond_dihedral': -1.1},
+ 'SD': {'bond_lens': 1.8099998, 'bond_angs': 2.001892, 'bond_dihedral': 3.1411812}, # bimodal at 0, 3.14
+ 'CE': {'bond_lens': 1.8099998, 'bond_angs': 1.7261307, 'bond_dihedral': -0.048235133} # trimodal at -1.41, 0, 1.41
+ },
+ 'F': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 2.545154},
+ 'CG': {'bond_lens': 1.5100001, 'bond_angs': 1.9896755, 'bond_dihedral': -1.2}, # bimodal at -1, 3.14
+ 'CD1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 1.41}, # bimodal -1.41, 1.41
+ 'CE1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592},
+ 'CZ': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'CE2': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'CD2': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}
+ },
+ 'P': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146266, 'bond_dihedral': 3.141592},
+ 'CG': {'bond_lens': 1.5260001, 'bond_angs': 1.9111352, 'bond_dihedral': -0.707}, # bimodal at -0.7, 0.7
+ 'CD': {'bond_lens': 1.5260001, 'bond_angs': 1.9111352, 'bond_dihedral': 0.85} # bimodal at -0.85, 0.85
+ },
+ 'S': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146266, 'bond_dihedral': 2.6017702},
+ 'OG': {'bond_lens': 1.41, 'bond_angs': 1.9111352, 'bond_dihedral': 1.1}
+ },
+ 'T': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146265, 'bond_dihedral': 2.55},
+ 'OG1': {'bond_lens': 1.4099998, 'bond_angs': 1.9111353, 'bond_dihedral': -1.07}, # bimodal at -1 and +1
+ 'CG2': {'bond_lens': 1.5260001, 'bond_angs': 1.9111353, 'bond_dihedral': -3.05} # bimodal at -1 and -3
+ },
+ 'W': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146266, 'bond_dihedral': 3.141592},
+ 'CG': {'bond_lens': 1.4950002, 'bond_angs': 2.0176008, 'bond_dihedral': -1.2},
+ 'CD1': {'bond_lens': 1.3520001, 'bond_angs': 2.1816616, 'bond_dihedral': 1.53},
+ 'NE1': {'bond_lens': 1.3810003, 'bond_angs': 1.8971729, 'bond_dihedral': 3.141592},
+ 'CE2': {'bond_lens': 1.3799998, 'bond_angs': 1.9477878, 'bond_dihedral': 0.0},
+ 'CZ2': {'bond_lens': 1.3999999, 'bond_angs': 2.317797, 'bond_dihedral': 3.141592},
+ 'CH2': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592},
+ 'CZ3': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'CE3': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'CD2': {'bond_lens': 1.404, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}
+ },
+ 'Y': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 3.1},
+ 'CG': {'bond_lens': 1.5100001, 'bond_angs': 1.9896754, 'bond_dihedral': -1.1},
+ 'CD1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 1.36},
+ 'CE1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592},
+ 'CZ': {'bond_lens': 1.4090003, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'OH': {'bond_lens': 1.3640002, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592},
+ 'CE2': {'bond_lens': 1.4090003, 'bond_angs': 2.094395, 'bond_dihedral': 0.0},
+ 'CD2': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}
+ },
+ 'V': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 2.55},
+ 'CG1': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': 3.141592},
+ 'CG2': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.1}
+ },
+
+ '_': {}
+}
+
+# experimentally checked distances (keys look like the bond separation between two atoms - TODO confirm)
+FF = {"MIN_DISTS": {1: 1.180,  # shortest =N or =O bond
+                    2: 2.138,  # N-N in histidine group
+                    3: 2.380},  # N-N in backbone (N-CA-C-N)
+      "MAX_DISTS": {n_bonds: 1.840 * n_bonds for n_bonds in range(1, 5 + 1)},  # 1.84 is longest -S bond found
+      }
+
+ATOM_TOKEN_IDS = {"", "N", "CA", "C", "O"}
+ATOM_TOKEN_IDS |= {atom_name
+                   for build in SC_BUILD_INFO.values()
+                   for atom_name in build["atom-names"]}
+# "" sorts first, so it is assigned token 0
+ATOM_TOKEN_IDS = {atom_name: token for token, atom_name in enumerate(sorted(ATOM_TOKEN_IDS))}
+
+#################
+##### DOERS #####
+#################
+
+def make_cloud_mask(aa):
+    """ Returns a (14,) array: 1. for the atoms present in `aa`, 0. for padding. """
+    cloud = np.zeros(14)
+    if aa == "_":
+        return cloud
+    cloud[:4 + len(SC_BUILD_INFO[aa]["atom-names"])] = True
+    return cloud
+
+def make_bond_mask(aa):
+ """ Gives the length of the bond originating each atom, as a (14,) array. Unset slots stay 0. """
+ mask = np.zeros(14)
+ # backbone: slots 0-3 take the c-n, n-ca, ca-c and c-o bond lengths
+ if aa != "_":
+ mask[0] = BB_BUILD_INFO["BONDLENS"]['c-n']
+ mask[1] = BB_BUILD_INFO["BONDLENS"]['n-ca']
+ mask[2] = BB_BUILD_INFO["BONDLENS"]['ca-c']
+ mask[3] = BB_BUILD_INFO["BONDLENS"]['c-o']
+ # sidechain (slots 4 onwards) - except padding token
+ if aa in SC_BUILD_INFO.keys():
+ for i,bond in enumerate(SC_BUILD_INFO[aa]['bonds-vals']):
+ mask[4+i] = bond
+ return mask
+
+def make_theta_mask(aa):
+ """ Gives the theta of the bond originating each atom, as a (14,) array. Unset slots stay 0. """
+ mask = np.zeros(14)
+ # backbone: slots 0-3
+ if aa != "_":
+ mask[0] = BB_BUILD_INFO["BONDANGS"]['ca-c-n'] # nitrogen
+ mask[1] = BB_BUILD_INFO["BONDANGS"]['c-n-ca'] # c_alpha
+ mask[2] = BB_BUILD_INFO["BONDANGS"]['n-ca-c'] # carbon
+ mask[3] = BB_BUILD_INFO["BONDANGS"]['ca-c-o'] # oxygen
+ # sidechain: slots 4 onwards, one per entry of 'angles-vals'
+ for i,theta in enumerate(SC_BUILD_INFO[aa]['angles-vals']):
+ mask[4+i] = theta
+ return mask
+
+def make_torsion_mask(aa, fill=False):
+ """ Gives the dihedral of the bond originating each atom, as a (14,) array. With fill=True the sidechain slots take the MP3SC_INFO 'bond_dihedral' values instead. """
+ mask = np.zeros(14)
+ if aa != "_":
+ # backbone
+ mask[0] = BB_BUILD_INFO["BONDTORSIONS"]['n-ca-c-n'] # psi
+ mask[1] = BB_BUILD_INFO["BONDTORSIONS"]['ca-n-c-ca'] # omega
+ mask[2] = BB_BUILD_INFO["BONDTORSIONS"]['c-n-ca-c'] # phi
+ mask[3] = BB_BUILD_INFO["BONDTORSIONS"]['n-ca-c-o'] # oxygen
+ # sidechain
+ for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-vals']):
+ if fill:
+ mask[4+i] = MP3SC_INFO[aa][ SC_BUILD_INFO[aa]["atom-names"][i] ]["bond_dihedral"]
+ else:
+ # https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L372
+ # 'p' -> nan, 'i' -> 999. 999 is an annotation -- change later || same for 555
+ mask[4+i] = np.nan if torsion == 'p' else 999 if torsion == "i" else torsion
+ return mask
+
+def make_idx_mask(aa):
+ """ Gives the idxs of the 3 previous points, as an (11, 3) array. Row 0 is the backbone (0, 1, 2); row i+1 belongs to the i-th entry of 'torsion-names'. """
+ mask = np.zeros((11, 3))
+ if aa != "_":
+ # backbone
+ mask[0, :] = np.arange(3)
+ # sidechain
+ mapper = {"N": 0, "CA": 1, "C":2, "CB": 4}
+ for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-names']):
+ # get all the atoms forming the dihedral
+ torsions = [x.rstrip(" ") for x in torsion.split("-")]
+ # for each atom but the last (NOTE: the inner loop reuses the name `torsion`)
+ for n, torsion in enumerate(torsions[:-1]):
+ # get the index of the atom in the coords array
+ loc = mapper[torsion] if torsion in mapper.keys() else 4 + SC_BUILD_INFO[aa]['atom-names'].index(torsion)
+ # set position to index
+ mask[i+1][n] = loc
+ return mask
+
+def make_atom_token_mask(aa):
+    """ Returns a (14,) array with the ATOM_TOKEN_IDS token of each atom of `aa`.
+        Slots beyond the residue's atoms (and the whole padding token "_") stay 0.
+    """
+    tokens = np.zeros(14)
+    if aa != "_":
+        for slot, name in enumerate(["N", "CA", "C", "O"] + SC_BUILD_INFO[aa]["atom-names"]):
+            tokens[slot] = ATOM_TOKEN_IDS[name]
+    return tokens
+
+
+###################
+##### GETTERS #####
+###################
+INDEX2AAS = "ACDEFGHIKLMNPQRSTVWY_"
+AAS2INDEX = dict(zip(INDEX2AAS, range(len(INDEX2AAS))))  # 1-letter code -> integer index
+SUPREME_INFO = {res: {"cloud_mask": make_cloud_mask(res),  # per-residue lookup tables, built once at import
+                      "bond_mask": make_bond_mask(res),
+                      "theta_mask": make_theta_mask(res),
+                      "torsion_mask": make_torsion_mask(res),
+                      "torsion_mask_filled": make_torsion_mask(res, fill=True),
+                      "idx_mask": make_idx_mask(res),
+                      "atom_token_mask": make_atom_token_mask(res),
+                      "rigid_idx_mask": SC_BUILD_INFO[res]['rigid-frames-idxs'],
+                      }
+                for res in INDEX2AAS}
+
diff --git a/rgn2_replica/mp_nerf/massive_pnerf.py b/rgn2_replica/mp_nerf/massive_pnerf.py
new file mode 100644
index 0000000..cf0d43d
--- /dev/null
+++ b/rgn2_replica/mp_nerf/massive_pnerf.py
@@ -0,0 +1,67 @@
+import time
+import numpy as np
+# diff ml
+import torch
+from einops import repeat
+
+
+def get_axis_matrix(a, b, c, norm=True):
+    """ Builds an orthogonal basis [e1, e2, e3] from three points.
+        Useful for constructing rotation matrices between planes
+        according to the first answer here:
+        https://math.stackexchange.com/questions/1876615/rotation-matrix-from-plane-a-to-b
+        Inputs:
+        * a: (batch, 3) or (3, ). point(s) of the plane
+        * b: (batch, 3) or (3, ). point(s) of the plane
+        * c: (batch, 3) or (3, ). point(s) of the plane
+        * norm: bool. whether to scale every basis vector to unit length
+        Outputs: (batch, 3, 3) or (3, 3). the basis vectors stacked as rows:
+        * e1_ = (c-b)
+        * e2_proto = (b-a)
+        * e3_ = e1_ ^ e2_proto
+        * e2_ = e3_ ^ e1_
+        Note: Gram-Schmidt would extend this to N dimensions,
+              but this is faster and more intuitive for 3D.
+    """
+    e1 = c - b
+    # e3 is perpendicular to both in-plane vectors
+    e3 = torch.cross(e1, b - a, dim=-1)
+    e2 = torch.cross(e3, e1, dim=-1)
+    frame = torch.stack([e1, e2, e3], dim=-2)
+    if not norm:
+        return frame
+    # every row is one basis vector -> divide it by its own length
+    return frame / torch.norm(frame, dim=-1, keepdim=True)
+
+
+
+def mp_nerf_torch(a, b, c, l, theta, chi):
+    """ Custom Natural extension of Reference Frame.
+        Inputs:
+        * a, b: (batch, 3) or (3,). point(s) of the plane, not connected to d
+        * c: (batch, 3) or (3,). point(s) of the plane, connected to d
+        * l: (batch,) or 0-dim tensor. length(s) of the c-d bond
+        * theta: (batch,) or (float). angle(s) between b-c-d
+        * chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes
+        Outputs: d (batch, 3) or (3,). the next point in the sequence, linked to c
+    """
+    # safety check
+    if not ( (-np.pi <= theta) * (theta <= np.pi) ).all().item():
+        raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}")
+    # calc vecs
+    ba = b-a
+    cb = c-b
+    # calc rotation matrix. based on plane normals and normalized
+    n_plane = torch.cross(ba, cb, dim=-1)
+    n_plane_ = torch.cross(n_plane, cb, dim=-1)
+    rotate = torch.stack([cb, n_plane_, n_plane], dim=-1)
+    rotate /= torch.norm(rotate, dim=-2, keepdim=True)
+    # calc proto point, rotate. add (-1 for sidechainnet convention)
+    # https://github.com/jonathanking/sidechainnet/issues/14
+    d = torch.stack([-torch.cos(theta),
+                     torch.sin(theta) * torch.cos(chi),
+                     torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1)
+    # extend base point, set length. drop only the trailing matmul dim so size-1 batch dims survive
+    return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze(-1)
+
+
diff --git a/rgn2_replica/mp_nerf/ml_utils.py b/rgn2_replica/mp_nerf/ml_utils.py
new file mode 100644
index 0000000..fb5cf10
--- /dev/null
+++ b/rgn2_replica/mp_nerf/ml_utils.py
@@ -0,0 +1,435 @@
+# Author: Eric Alcaide
+
+# module
+import torch
+# from rgn2_replica.mp_nerf.utils import *
+from rgn2_replica.mp_nerf.massive_pnerf import *
+from rgn2_replica.mp_nerf.kb_proteins import *
+from rgn2_replica.mp_nerf.proteins import *
+from einops import rearrange, repeat
+
+def scn_atom_embedd(seq_list):
+    """ Looks up the atom tokens of every residue in a batch of sequences.
+        Inputs:
+        * seq_list: list of FASTA sequences. All must share the same length
+                    (the per-sequence tensors are stacked).
+        Outputs: (batch, L, 14) long tensor of atom tokens
+    """
+    # the token lookup is a plain python loop (cpu), one sequence at a time
+    per_seq = [torch.tensor([SUPREME_INFO[res]["atom_token_mask"] for res in chain])
+               for chain in seq_list]
+    # stack into one batch and cast to integer tokens
+    return torch.stack(per_seq, dim=0).long()
+
+
+def chain2atoms(x, mask=None, c=3):
+    """ Expand from (L, other) to (L, C, other). If `mask` is given, returns the masked selection. """
+    expanded = repeat(x, 'l ... -> l c ...', c=c)
+    if mask is None:
+        return expanded
+    return expanded[mask]
+
+
+######################
+# from: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
+
+def rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_feats=None):
+ """ Corrects ambiguous atoms (due to 180 torsions - ambiguous sidechains).
+ Inputs:
+ * pred_coors: (batch, L, 14, 3) float. sidechainnet format (see mp_nerf.kb_proteins)
+ * true_coors: (batch, L, 14, 3) float. sidechainnet format (see mp_nerf.kb_proteins)
+ * seq_list: list of FASTA sequences
+ * cloud_mask: (batch, L, 14) bool. mask for present atoms
+ * pred_feats: (batch, L, 14, D) optional. atom-wise predicted features
+
+ Warning! A coordinate might be missing. TODO:
+ Outputs: pred_coors, pred_feats (the same input tensors, swapped in place)
+ """
+ aux_cloud_mask = cloud_mask.clone() # working copy - the caller's cloud_mask is not modified
+
+ for i,seq in enumerate(seq_list):
+ for aa, pairs in AMBIGUOUS.items():
+ # indexes of aas in chain - check coords are given for aa. NOTE(review): `.tolist()[0]` keeps only the first nonzero row - confirm intent
+ amb_idxs = np.array(pairs["indexs"]).flatten().tolist()
+ idxs = torch.tensor([
+ k for k,s in enumerate(seq) if s==aa and \
+ k in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist()[0] )
+ ]).long()
+ # check if any AAs matching
+ if idxs.shape[0] == 0:
+ continue
+ # get indexes of non-ambiguous (taken from the first matching residue only)
+ aux_cloud_mask[i, idxs, amb_idxs] = False
+ non_amb_idx = torch.nonzero(aux_cloud_mask[i, idxs[0]]).tolist()
+ for a, pair in enumerate(pairs["indexs"]):
+ # calc distances from the ambiguous pair to the non-ambiguous atoms
+ d_ij_pred = torch.cdist(pred_coors[ i, idxs, pair ], pred_coors[i, idxs, non_amb_idx], p=2) # 2, N
+ d_ij_true = torch.cdist(true_coors[ i, idxs, pair+pair[::-1] ], true_coors[i, idxs, non_amb_idx], p=2) # 2, 2N
+ # see if alternative is better (less distance)
+ idxs_to_change = ( (d_ij_pred - d_ij_true[2:]).sum(dim=-1) < (d_ij_pred - d_ij_true[:2]).sum(dim=-1) ).nonzero()
+ # change those: swap the two atoms of the pair (and their features, if given)
+ pred_coors[i, idxs[idxs_to_change], pair] = pred_coors[i, idxs[idxs_to_change], pair[::-1]]
+ if pred_feats is not None:
+ pred_feats[i, idxs[idxs_to_change], pair] = pred_feats[i, idxs[idxs_to_change], pair[::-1]]
+
+ return pred_coors, pred_feats
+
+
+def angle_to_point_in_circum(angles):
+    """ Maps angle(s) to (cos, sin) point(s) on the unit circumference.
+        Inputs:
+        * angles: tensor of (any) shape. A 0-dim input becomes (1,);
+                  a trailing dim of size 1 is dropped when ndim > 1.
+        Outputs: (any, 2)
+    """
+    n_dims = angles.dim()
+    if n_dims == 0:
+        angles = angles.unsqueeze(0)
+    elif n_dims > 1 and angles.shape[-1] == 1:
+        angles = angles[..., 0]
+    return torch.stack([angles.cos(), angles.sin()], dim=-1)
+
+def point_in_circum_to_angle(points):
+    """ Recovers angle(s) from point(s) in the circumference via atan2.
+        Inputs:
+        * points: (any, 2X). the first half of the last dim is used as the
+                  x (cos) argument of atan2, the second half as the y (sin) one.
+        Outputs: (any, X)
+    """
+    if points.dim() == 1:
+        points = points.unsqueeze(0)  # ensure first dim
+    half = points.shape[-1] // 2
+    cos_part, sin_part = points[..., :half], points[..., half:]
+    return torch.atan2(sin_part, cos_part)
+
+
+def torsion_angle_loss(pred_torsions=None, true_torsions=None,
+ pred_points=None, true_points=None,
+ alt_true_points=None, alt_true_torsions=None,
+ coeff=2., norm_coeff=1e-2, angle_mask=None):
+ """ Computes a loss on the angles as the cosine of the difference.
+ Equivalent to an L2 on the unit circle.
+ Due to angle periodicity, for angle inputs, calculate the
+ disparity on both sides.
+ Alternative truths should only be passed if no previous renaming was done.
+ Inputs:
+ * pred_torsions: ( (B), L, X ) float. Predicted torsion angles.(-pi, pi)
+ Same format as sidechainnet.
+ * true_torsions: ( (B), L, X ) true torsion angles. (-pi, pi)
+ * pred_points: ( (B), L, X, 2) float. Predicted points in circum.
+ * true_points: ( (B), L, X, 2) float. true points in circum.
+ * alt_true_torsions: ( (B), L, X ) alt true torsion angles. (-pi, pi)
+ * alt_true_points: ( (B), L, X, 2) float. alt true points in circum.
+ * coeff: float. weight coefficient (the squared distance is scaled by coeff/2)
+ * norm_coeff: float. coefficient for norm term. avoids big outputs.
+ * angle_mask: ((B), L, (X)) bool. Masks the non-existing angles. Defaults to all True.
+ Outputs: masked entries of (coeff/2) * squared L2 distance on the unit circle + norm_coeff * |1 - norm|
+ """
+ # convert to sin·cos rep if not available
+ if pred_torsions is not None and pred_points is None:
+ pred_points = angle_to_point_in_circum(pred_torsions)
+ if true_torsions is not None and true_points is None:
+ true_points = angle_to_point_in_circum(true_torsions)
+ if alt_true_torsions is not None and alt_true_points is None:
+ alt_true_points = angle_to_point_in_circum(alt_true_torsions)
+
+ # calc norm of the predicted points; penalize its deviation from 1
+ norm = torch.norm(pred_points, dim=-1)
+ angle_norm_loss = norm_coeff * (1-norm).abs()
+
+ # do L2 on unit circle (predictions are projected onto it first)
+ pred_points = pred_points / norm.unsqueeze(-1)
+ torsion_loss = torch.pow(pred_points - true_points, 2).sum(dim=-1)
+
+ if alt_true_points is not None:
+ torsion_loss = torch.minimum(
+ torsion_loss,
+ torch.pow(pred_points - alt_true_points, 2).sum(dim=-1)
+ )
+ if coeff != 2.:
+ torsion_loss *= coeff/2
+
+ if angle_mask is None:
+ angle_mask = torch.ones(*pred_points.shape[:-1], dtype=torch.bool)
+
+ return (torsion_loss + angle_norm_loss)[angle_mask]
+
+
+def fape_torch(pred_coords, true_coords, max_val=10., d_clamp=10., l_func=None,
+ partial=None, seq_list=None, rot_mats_g=None, max_points=10000):
+ """ Computes the Frame-Aligned Point Error. Scaled 0 <= FAPE <= 1
+ Even if computed only on C-alphas, all backbone atoms (N-CA-C)
+ must be passed to build the frames.
+ Inputs:
+ * pred_coords: (B, L, C, 3) or (B, (l c), 3) predicted coordinates.
+ * true_coords: (B, L, C, 3) or (B, (l c), 3) ground truth coordinates.
+ * max_val: float. the final loss is divided by this number
+ * d_clamp: float. the radius due to L1 usage
+ * l_func: function. allow for options other than l1 (consider dRMSD maybe)
+ * partial: str or None. "c_alpha" keeps C-alphas; any other string keeps the rigid frames' atoms.
+ * seq_list: list of strs (FASTA sequences). to calculate rigid bodies' indexs.
+ Indexed per chain whenever rot_mats_g is None.
+ * rot_mats_g: optional. List of n_seqs x (N_frames, 3, 3) rotation matrices.
+ * max_points: int. maximum points to rotate at once.
+ the higher, the more batching allowed.
+ Outputs: (B, N_atoms). Stacked per chain, so presumably every chain must keep the same number of atoms - TODO confirm
+ """
+ fape_store = []
+ if l_func is None:
+ l_func = lambda x,y,eps=1e-7,sup=d_clamp: (((x-y)**2).sum(dim=-1) + \
+ eps).sqrt().clamp(0, sup)
+ # for chain
+ for s in range(pred_coords.shape[0]):
+ fape_store.append(0)
+ cloud_mask = (torch.abs(true_coords[s]).sum(dim=-1) != 0)
+ # center both structures
+ pred_center = pred_coords[s] - pred_coords[s, cloud_mask].mean(dim=0, keepdim=True)
+ true_center = true_coords[s] - true_coords[s, cloud_mask].mean(dim=0, keepdim=True)
+ # flatten to (L*C, 3)
+ pred_center = rearrange(pred_center, 'l c d -> (l c) d')
+ true_center = rearrange(true_center, 'l c d -> (l c) d')
+ mask_center = rearrange(cloud_mask, 'l c -> (l c)')
+ # get frames and conversions - same scheme as in mp_nerf proteins' concat of monomers
+ if rot_mats_g is None:
+ rigid_idxs = scn_rigid_index_mask(seq_list[s], c_alpha=partial=="c_alpha")
+ true_frames = get_axis_matrix(*true_center[rigid_idxs], norm=True)
+ pred_frames = get_axis_matrix(*pred_center[rigid_idxs], norm=True)
+ rot_mats = torch.matmul(torch.transpose(pred_frames, -1, -2), true_frames).detach()
+ else:
+ rot_mats = rot_mats_g[s]
+
+ # restrict the loss to a subset of atoms
+ if partial is not None:
+ mask_center = torch.zeros_like(mask_center, dtype=torch.bool)
+ if partial == "c_alpha": # only keep c-alphas
+ mask_center[np.arange(0, pred_coords.shape[1]) * 14 + 1] = \
+ mask_center[np.arange(0, pred_coords.shape[1]) * 14 + 1] + True
+ else: # only keep backbone(+cb) frames' atoms. NOTE(review): rigid_idxs is only set when rot_mats_g is None
+ mask_center[rigid_idxs] = mask_center[rigid_idxs] + True
+
+ pred_center = pred_center[mask_center]
+ true_center = true_center[mask_center]
+
+ # return pred_center, true_center, mask_center, rot_mats
+ # measure errors - frames are processed in batches of `batch_size`
+ num = 0
+ batch_size = max(1, int( max_points // pred_center.shape[0] ) )
+
+ while num <= rot_mats.shape[0]: # NOTE(review): `<=` runs one extra, empty batch when num == n_frames; `<` looks intended - confirm
+ fape_store[s] = fape_store[s] + l_func(
+ pred_center @ rot_mats[num:num+batch_size], # (L_, D)
+ true_center # (L_, D)
+ ).sum(dim=0)
+
+ num += batch_size
+
+ fape_store[s] /= rot_mats.shape[0] # take mean
+
+ # stack the per-chain errors and scale by 1/max_val
+ return (1/max_val) * torch.stack(fape_store, dim=0)
+
+
+# custom
+
+def atom_selector(scn_seq, x, option=None, discard_absent=True):
+    """ Returns a selection of the atoms in a protein.
+        Inputs:
+        * scn_seq: (batch, len) sidechainnet format or list of strings
+        * x: (batch, (len * n_aa), dims) sidechainnet format
+        * option: one of [torch.tensor, 'backbone', 'backbone-with-cbeta', 'all',
+                  'backbone-with-oxygen', 'backbone-with-cbeta-and-oxygen']. Other strings print a warning.
+        * discard_absent: bool. Whether to discard the points for which
+                          there are no labels (bad recordings)
+        Outputs: x[mask] (the selected points) and mask (batch, (len * 14)) bool
+    """
+    # atom slots per residue: 0=N, 1=CA, 2=C, 3=O, 4=CB (sidechain starts at slot 4)
+    # get mask
+    present = []
+    for i, seq in enumerate(scn_seq):
+        pass_x = x[i] if discard_absent else None
+        if pass_x is None and isinstance(seq, torch.Tensor):
+            seq = "".join([INDEX2AAS[idx] for idx in seq.cpu().detach().tolist()])
+
+        present.append( scn_cloud_mask(seq, coords=pass_x) )
+
+    present = torch.stack(present, dim=0).bool()
+
+
+    # atom mask
+    if isinstance(option, str):
+        atom_mask = torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+        if "backbone" in option:
+            atom_mask[[0, 2]] = 1
+
+        if option == "backbone":
+            pass
+        elif option == 'backbone-with-oxygen':
+            atom_mask[3] = 1
+        elif option == 'backbone-with-cbeta':
+            atom_mask[4] = 1  # CB lives in slot 4 (slot 5 is the second sidechain atom)
+        elif option == 'backbone-with-cbeta-and-oxygen':
+            atom_mask[3] = 1
+            atom_mask[4] = 1  # CB lives in slot 4
+        elif option == 'all':
+            atom_mask[:] = 1
+        else:
+            print("Your string doesn't match any option.")
+
+    elif isinstance(option, torch.Tensor):
+        atom_mask = option
+    else:
+        raise ValueError('option needs to be a valid string or a mask tensor of shape (14,) ')
+
+    mask = rearrange(present * atom_mask.unsqueeze(0).unsqueeze(0).bool(), 'b l c -> b (l c)')
+    return x[mask], mask
+
+
+def noise_internals(seq, angles=None, coords=None, noise_scale=0.5, theta_scale=0.5, verbose=0):
+ """ Noises the internal coordinates -> dihedral and bond angles.
+ Inputs:
+ * seq: string. Sequence in FASTA format
+ * angles: (l, 12) sidechainnet angles tensor. Random normal (l, 12) if None.
+ * coords: (l, 14, 3) presumably (the original note said 13) - TODO confirm
+ * noise_scale: float. std of noise gaussian.
+ * theta_scale: float. multiplier for bond angles
+ Outputs:
+ * chain (l, c, d)
+ * cloud_mask (l, c)
+ """
+ assert angles is not None or coords is not None, \
+ "You must pass either angles or coordinates"
+ # get scaffolds
+ if angles is None:
+ angles = torch.randn(coords.shape[0], 12).to(coords.device)
+
+ scaffolds = build_scaffolds_from_scn_angles(seq, angles.clone())
+
+ if coords is not None:
+ scaffolds = modify_scaffolds_with_coords(scaffolds, coords)
+
+ # noise bond angles and dihedrals (dihedrals of everyone, angles only of BB)
+ if noise_scale > 0.:
+ if verbose:
+ print("noising", noise_scale)
+ # thetas (noise std is theta_scale * noise_scale. only for BB)
+ noised_bb = scaffolds["angles_mask"][0, :, :3].clone()
+ noised_bb += theta_scale*noise_scale * torch.randn_like(noised_bb)
+ # get noised values between [-pi, pi]
+ off_bounds = (noised_bb > 2*np.pi) + (noised_bb < -2*np.pi)
+ if off_bounds.sum().item() > 0:
+ noised_bb[off_bounds] = noised_bb[off_bounds] % (2*np.pi)
+
+ upper, lower = noised_bb > np.pi, noised_bb < -np.pi
+ if upper.sum().item() > 0:
+ noised_bb[upper] = - ( 2*np.pi - noised_bb[upper] ).clone()
+ if lower.sum().item() > 0:
+ noised_bb[lower] = 2*np.pi + noised_bb[lower].clone()
+ scaffolds["angles_mask"][0, :, :3] = noised_bb
+
+ # dihedrals (noise std is noise_scale)
+ noised_dihedrals = scaffolds["angles_mask"][1].clone()
+ noised_dihedrals += noise_scale * torch.randn_like(noised_dihedrals)
+ # get noised values between [-pi, pi]
+ off_bounds = (noised_dihedrals > 2*np.pi) + (noised_dihedrals < -2*np.pi)
+ if off_bounds.sum().item() > 0:
+ noised_dihedrals[off_bounds] = noised_dihedrals[off_bounds] % (2*np.pi)
+
+ upper, lower = noised_dihedrals > np.pi, noised_dihedrals < -np.pi
+ if upper.sum().item() > 0:
+ noised_dihedrals[upper] = - ( 2*np.pi - noised_dihedrals[upper] ).clone()
+ if lower.sum().item() > 0:
+ noised_dihedrals[lower] = 2*np.pi + noised_dihedrals[lower].clone()
+ scaffolds["angles_mask"][1] = noised_dihedrals
+
+ # reconstruct the structure from the (noised) scaffolds
+ return protein_fold(**scaffolds)
+
+
+def combine_noise(true_coords, seq=None, int_seq=None, angles=None,
+ NOISE_INTERNALS=1e-2, INTERNALS_SCN_SCALE=5.,
+ SIDECHAIN_RECONSTRUCT=True):
+ """ Combines noises. For internal noise, no points can be missing.
+ Inputs:
+ * true_coords: ((B), N, D)
+ * int_seq: (N,) torch long tensor of sidechainnet AA tokens
+ * seq: str of length N. FASTA AAs.
+ * angles: (N_aa, D_). optional. used for internal noising
+ * NOISE_INTERNALS: float. amount of noise for internal coordinates. INTERNALS_SCN_SCALE is passed on as theta_scale.
+ * SIDECHAIN_RECONSTRUCT: bool. whether to discard the sidechain and
+ rebuild by sampling from plausible distro.
+ Outputs: (B, N, D) coords and (B, N) boolean mask
+ """
+ # get seqs right: derive whichever of seq / int_seq was not passed
+ assert int_seq is not None or seq is not None, "Either int_seq or seq must be passed"
+ if int_seq is not None and seq is None:
+ seq = "".join([INDEX2AAS[x] for x in int_seq.cpu().detach().tolist()])
+ elif int_seq is None and seq is not None:
+ int_seq = torch.tensor([AAS2INDEX[x] for x in seq.upper()], device=true_coords.device)
+
+ cloud_mask_flat = (true_coords == 0.).sum(dim=-1) != true_coords.shape[-1]
+ naive_cloud_mask = scn_cloud_mask(seq).bool()
+
+ if NOISE_INTERNALS:
+ assert cloud_mask_flat.sum().item() == naive_cloud_mask.sum().item(), \
+ "atoms missing: {0}".format( naive_cloud_mask.sum().item() - \
+ cloud_mask_flat.sum().item() )
+ # expand to batch dim if needed
+ if len(true_coords.shape) < 3:
+ true_coords = true_coords.unsqueeze(0)
+ noised_coords = true_coords.clone()
+ coords_scn = rearrange(true_coords, 'b (l c) d -> b l c d', c=14)
+
+ ###### STEP 1: internals #########
+ if NOISE_INTERNALS:
+ # create noised and masked noised coords
+ noised_coords, cloud_mask = noise_internals(seq, angles = angles,
+ coords = coords_scn.squeeze(),
+ noise_scale = NOISE_INTERNALS,
+ theta_scale = INTERNALS_SCN_SCALE,
+ verbose = False)
+ masked_noised = noised_coords[naive_cloud_mask] # NOTE(review): not used afterwards
+ noised_coords = rearrange(noised_coords, 'l c d -> () (l c) d')
+
+ ###### STEP 2: build from backbone #########
+ if SIDECHAIN_RECONSTRUCT:
+ bb, mask = atom_selector(int_seq.unsqueeze(0), noised_coords, option="backbone", discard_absent=False) # only `mask` is used
+ scaffolds = build_scaffolds_from_scn_angles(seq, angles=None, device="cpu")
+ noised_coords[~mask] = 0.
+ noised_coords = rearrange(noised_coords, '() (l c) d -> l c d', c=14)
+ noised_coords, _ = sidechain_fold(wrapper = noised_coords.cpu(), **scaffolds, c_beta = False)
+ noised_coords = rearrange(noised_coords, 'l c d -> () (l c) d').to(true_coords.device)
+
+
+ return noised_coords, cloud_mask_flat
+
+
+
+if __name__ == "__main__":
+ import joblib
+ # imports of data (from mp_nerf.utils.get_prot). The path below is a placeholder.
+ prots = joblib.load("some_route_to_local_serialized_file_with_prots")
+
+ # set params (NOTE(review): `device` is not used below)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # unpack the last protein and test on it
+ seq, int_seq, true_coords, angles, padding_seq, mask, pid = prots[-1]
+
+ true_coords = true_coords.unsqueeze(0)
+
+ # check noised internals
+ coords_scn = rearrange(true_coords, 'b (l c) d -> b l c d', c=14)
+ cloud, cloud_mask = noise_internals(seq, angles=angles, coords=coords_scn[0], noise_scale=1.)
+ print("cloud.shape", cloud.shape)
+
+ # check combined noise: once passing `seq` only, once passing `int_seq` only
+ integral, mask = combine_noise(true_coords, seq=seq, int_seq = None, angles=None,
+ NOISE_INTERNALS=1e-2, SIDECHAIN_RECONSTRUCT=True)
+ print("integral.shape", integral.shape)
+
+ integral, mask = combine_noise(true_coords, seq=None, int_seq = int_seq, angles=None,
+ NOISE_INTERNALS=1e-2, SIDECHAIN_RECONSTRUCT=True)
+ print("integral.shape2", integral.shape)
+
+
+
diff --git a/rgn2_replica/mp_nerf/proteins.py b/rgn2_replica/mp_nerf/proteins.py
new file mode 100644
index 0000000..f5433de
--- /dev/null
+++ b/rgn2_replica/mp_nerf/proteins.py
@@ -0,0 +1,536 @@
+# science
+# diff / ml
+# module
+import torch.nn.functional as F
+from rgn2_replica.mp_nerf.utils import *
+from rgn2_replica.mp_nerf.ml_utils import *
+from rgn2_replica.mp_nerf.massive_pnerf import *
+from rgn2_replica.mp_nerf.kb_proteins import *
+from einops import rearrange, repeat
+
+
+def scn_cloud_mask(seq, coords=None, strict=False):
+ """ Gets the boolean mask atom positions (not all aas have same atoms).
+ Inputs:
+ * seq: (length) iterable of 1-letter aa codes of a protein
+ * coords: optional .(batch, lc, 3). sidechainnet coords.
+ returns the true mask (solves potential atoms that might not be provided)
+ * strict: bool. whether to discard the next points after a missing one
+ Outputs: (length, 14) boolean mask
+ """
+ if coords is not None:
+ start = (( rearrange(coords, 'b (l c) d -> b l c d', c=14) != 0 ).sum(dim=-1) != 0).float()
+ # if a point is 0, the following are 0s as well
+ if strict:
+ for b in range(start.shape[0]):
+ for pos in range(start.shape[1]):
+ for chain in range(start.shape[2]):
+ if start[b, pos, chain].item() == 0:
+ start[b, pos, chain:] *= 0
+ return start
+ return torch.tensor([SUPREME_INFO[aa]['cloud_mask'] for aa in seq])
+
+
+def scn_bond_mask(seq):
+ """ Inputs:
+ * seq: (length). iterable of 1-letter aa codes of a protein
+ Outputs: (L, 14) maps point to bond length
+ """
+ return torch.tensor([SUPREME_INFO[aa]['bond_mask'] for aa in seq])
+
+
+def scn_angle_mask(seq, angles=None, device=None):
+ """ Inputs:
+ * seq: (length). iterable of 1-letter aa codes of a protein
+ * angles: (length, 12). [phi, psi, omega, b_angle(n_ca_c), b_angle(ca_c_n), b_angle(c_n_ca), 6_scn_torsions]
+ Outputs: (L, 14) maps point to theta and dihedral.
+ first angle is theta, second is dihedral
+ """
+ device = angles.device if angles is not None else torch.device("cpu")
+ precise = angles.dtype if angles is not None else torch.get_default_dtype()
+ torsion_mask_use = "torsion_mask" if angles is not None else "torsion_mask_filled"
+ # get masks
+ theta_mask = torch.tensor([SUPREME_INFO[aa]['theta_mask'] for aa in seq], dtype=precise).to(device)
+ torsion_mask = torch.tensor([SUPREME_INFO[aa][torsion_mask_use] for aa in seq], dtype=precise).to(device)
+
+ # adapt general to specific angles if passed
+ if angles is not None:
+ # fill masks with angle values
+ theta_mask[:, 0] = angles[:, 4] # ca_c_n
+ theta_mask[1:, 1] = angles[:-1, 5] # c_n_ca
+ theta_mask[:, 2] = angles[:, 3] # n_ca_c
+ # backbone_torsions
+ torsion_mask[:, 0] = angles[:, 1] # n determined by psi of previous
+ torsion_mask[1:, 1] = angles[:-1, 2] # ca determined by omega of previous
+ torsion_mask[:, 2] = angles[:, 0] # c determined by phi
+ # https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L313
+ torsion_mask[:, 3] = angles[:, 1] - np.pi
+
+ # add torsions to sidechains - no need to modify indexes due to torsion modification
+ # since extra rigid bodies are in terminal positions in sidechain
+ to_fill = torsion_mask != torsion_mask # "p" fill with passed values
+ to_pick = torsion_mask == 999 # "i" infer from previous one
+ for i,aa in enumerate(seq):
+ # check if any is nan -> fill the holes
+ number = to_fill[i].long().sum()
+ torsion_mask[i, to_fill[i]] = angles[i, 6:6+number]
+
+ # pick previous value for inferred torsions
+ for j, val in enumerate(to_pick[i]):
+ if val:
+ torsion_mask[i, j] = torsion_mask[i, j-1] - np.pi # pick values from last one.
+
+ # special rigid bodies anomalies:
+ if aa == "I": # scn_torsion(CG1) - scn_torsion(CG2) = 2.13 (see KB)
+ torsion_mask[i, 7] += torsion_mask[i, 5]
+ elif aa == "L":
+ torsion_mask[i, 7] += torsion_mask[i, 6]
+
+
+ torsion_mask[-1, 3] += np.pi
+ return torch.stack([theta_mask, torsion_mask], dim=0)
+
+
+def scn_index_mask(seq):
+ """ Inputs:
+ * seq: (length). iterable of 1-letter aa codes of a protein
+ Outputs: (3, L, 11) maps point (except n-ca-c) to idxs of
+ previous 3 points in the coords array
+ """
+ idxs = torch.tensor([SUPREME_INFO[aa]['idx_mask'] for aa in seq])
+ return rearrange(idxs, 'l s d -> d l s')
+
+
+def scn_rigid_index_mask(seq, c_alpha=None):
+ """ Inputs:
+ * seq: (length). iterable of 1-letter aa codes of a protein
+ * c_alpha: part of the chain to compute frames on.
+ Outputs: (3, Length * Groups). indexes for 1st, 2nd and 3rd point
+ to construct frames for each group.
+ """
+ maxi = 1 if c_alpha else None
+
+ return torch.cat([torch.tensor(SUPREME_INFO[aa]['rigid_idx_mask'])[:maxi] if i==0 else \
+ torch.tensor(SUPREME_INFO[aa]['rigid_idx_mask'])[:maxi] + 14*i \
+ for i,aa in enumerate(seq)], dim=0).t()
+
+
+def build_scaffolds_from_scn_angles(seq, angles=None, coords=None, device="auto"):
+ """ Builds scaffolds for fast access to data
+ Inputs:
+ * seq: string of aas (1 letter code)
+ * angles: (L, 12) tensor containing the internal angles.
+ Distributed as follows (following sidechainnet convention):
+ * (L, 3) for torsion angles
+ * (L, 3) bond angles
+ * (L, 6) sidechain angles
+ * coords: (L, 3) sidechainnet coords. builds the mask with those instead
+ (better accuracy if modified residues present).
+ Outputs:
+ * cloud_mask: (L, 14 ) mask of points that should be converted to coords
+ * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
+ previous 3 points in the coords array
+ * angles_mask: (2, L, 14) maps point to theta and dihedral
+ * bond_mask: (L, 14) gives the length of the bond originating that atom
+ """
+ # auto infer device and precision
+ precise = angles.dtype if angles is not None else torch.get_default_dtype()
+ if device == "auto":
+ device = angles.device if angles is not None else device
+
+ if coords is not None:
+ cloud_mask = scn_cloud_mask(seq, coords=coords)
+ else:
+ cloud_mask = scn_cloud_mask(seq)
+
+ cloud_mask = cloud_mask.bool().to(device)
+
+ point_ref_mask = scn_index_mask(seq).long().to(device)
+
+ angles_mask = scn_angle_mask(seq, angles).to(device, precise)
+
+ bond_mask = scn_bond_mask(seq).to(device, precise)
+ # return all in a dict
+ return {"cloud_mask": cloud_mask,
+ "point_ref_mask": point_ref_mask,
+ "angles_mask": angles_mask,
+ "bond_mask": bond_mask }
+
+
+#############################
+####### ENCODERS ############
+#############################
+
+
+def modify_angles_mask_with_torsions(seq, angles_mask, torsions):
+ """ Modifies a torsion mask to include variable torsions.
+ Inputs:
+ * seq: (L,) str. FASTA sequence
+ * angles_mask: (2, L, 14) float tensor of (angles, torsions)
+ * torsions: (L, 4) float tensor (or (L, 5) if it includes torsion for cb)
+ Outputs: (2, L, 14) a new angles mask
+ """
+ c_beta = torsions.shape[-1] == 5 # whether c_beta torsion is passed as well
+ start = 4 if c_beta else 5
+ # get mask of to-fill values
+ torsion_mask = torch.tensor([SUPREME_INFO[aa]["torsion_mask"] for aa in seq]).to(torsions.device) # (L, 14)
+ torsion_mask = torsion_mask != torsion_mask # values that are nan need replace
+ # undesired outside of margins
+ torsion_mask[:, :start] = torsion_mask[:, start+torsions.shape[-1]:] = False
+
+ angles_mask[1, torsion_mask] = torsions[ torsion_mask[:, start:start+torsions.shape[-1]] ]
+ return angles_mask
+
+
+def modify_scaffolds_with_coords(scaffolds, coords):
+ """ Gets scaffolds and fills in the right data.
+ Inputs:
+ * scaffolds: dict. as returned by `build_scaffolds_from_scn_angles`
+ * coords: (L, 14, 3). sidechainnet tensor. same device as scaffolds
+ Outputs: corrected scaffolds
+ """
+
+
+ # calculate distances and update:
+ # N, CA, C
+ scaffolds["bond_mask"][1:, 0] = torch.norm(coords[1:, 0] - coords[:-1, 2], dim=-1) # N
+ scaffolds["bond_mask"][ :, 1] = torch.norm(coords[ :, 1] - coords[: , 0], dim=-1) # CA
+ scaffolds["bond_mask"][ :, 2] = torch.norm(coords[ :, 2] - coords[: , 1], dim=-1) # C
+ # O, CB, side chain
+ selector = np.arange(len(coords))
+ for i in range(3, 14):
+ # get indexes
+ idx_a, idx_b, idx_c = scaffolds["point_ref_mask"][:, :, i-3] # (3, L, 11) -> 3 * (L, 11)
+ # correct distances
+ scaffolds["bond_mask"][:, i] = torch.norm(coords[:, i] - coords[selector, idx_c], dim=-1)
+ # get angles
+ scaffolds["angles_mask"][0, :, i] = get_angle(coords[selector, idx_b],
+ coords[selector, idx_c],
+ coords[:, i])
+ # handle C-beta, where the C requested is from the previous aa
+ if i == 4:
+ # for 1st residue, use position of the second residue's N
+ first_next_n = coords[1, :1] # 1, 3
+ # the c requested is from the previous residue
+ main_c_prev_idxs = coords[selector[:-1], idx_a[1:]]# (L-1), 3
+ # concat
+ coords_a = torch.cat([first_next_n, main_c_prev_idxs])
+ else:
+ coords_a = coords[selector, idx_a]
+ # get dihedrals
+ scaffolds["angles_mask"][1, :, i] = get_dihedral(coords_a,
+ coords[selector, idx_b],
+ coords[selector, idx_c],
+ coords[:, i])
+ # correct angles and dihedrals for backbone
+ scaffolds["angles_mask"][0, :-1, 0] = get_angle(coords[:-1, 1], coords[:-1, 2], coords[1: , 0]) # ca_c_n
+ scaffolds["angles_mask"][0, 1:, 1] = get_angle(coords[:-1, 2], coords[1:, 0], coords[1: , 1]) # c_n_ca
+ scaffolds["angles_mask"][0, :, 2] = get_angle(coords[:, 0], coords[ :, 1], coords[ : , 2]) # n_ca_c
+
+ # N determined by previous psi = f(n, ca, c, n+1)
+ scaffolds["angles_mask"][1, :-1, 0] = get_dihedral(coords[:-1, 0], coords[:-1, 1], coords[:-1, 2], coords[1:, 0])
+ # CA determined by omega = f(ca, c, n+1, ca+1)
+ scaffolds["angles_mask"][1, 1:, 1] = get_dihedral(coords[:-1, 1], coords[:-1, 2], coords[1:, 0], coords[1:, 1])
+ # C determined by phi = f(c-1, n, ca, c)
+ scaffolds["angles_mask"][1, 1:, 2] = get_dihedral(coords[:-1, 2], coords[1:, 0], coords[1:, 1], coords[1:, 2])
+
+ return scaffolds
+
+
+##################################
+####### MAIN FUNCTION ############
+##################################
+
+
+def protein_fold(cloud_mask, point_ref_mask, angles_mask, bond_mask,
+ device=torch.device("cpu"), hybrid=False):
+ """ Calcs coords of a protein given its
+ sequence and internal angles.
+ Inputs:
+ * cloud_mask: (L, 14) mask of points that should be converted to coords
+ * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
+ previous 3 points in the coords array
+ * angles_mask: (2, L, 14) maps point to theta and dihedral
+ * bond_mask: (L, 14) gives the length of the bond originating that atom
+
+ Output: (L, 14, 3) and (L, 14) coordinates and cloud_mask
+ """
+ # automatic type (float, mixed, double) and size detection
+ precise = bond_mask.dtype
+ length = cloud_mask.shape[0]
+ # create coord wrapper
+ coords = torch.zeros(length, 14, 3, device=device, dtype=precise)
+
+ # do first AA
+ coords[0, 1] = coords[0, 0] + torch.tensor([1, 0, 0], device=device, dtype=precise) * BB_BUILD_INFO["BONDLENS"]["n-ca"]
+ coords[0, 2] = coords[0, 1] + torch.tensor([torch.cos(np.pi - angles_mask[0, 0, 2]),
+ torch.sin(np.pi - angles_mask[0, 0, 2]),
+ 0.], device=device, dtype=precise) * BB_BUILD_INFO["BONDLENS"]["ca-c"]
+
+ # starting positions (in the x,y plane) and normal vector [0,0,1]
+ init_a = repeat(torch.tensor([1., 0., 0.], device=device, dtype=precise), 'd -> l d', l=length)
+ init_b = repeat(torch.tensor([1., 1., 0.], device=device, dtype=precise), 'd -> l d', l=length)
+ # do N -> CA. don't do 1st since its done already
+ thetas, dihedrals = angles_mask[:, :, 1]
+ coords[1:, 1] = mp_nerf_torch(init_a,
+ init_b,
+ coords[:, 0],
+ bond_mask[:, 1],
+ thetas, dihedrals)[1:]
+ # do CA -> C. don't do 1st since its done already
+ thetas, dihedrals = angles_mask[:, :, 2]
+ coords[1:, 2] = mp_nerf_torch(init_b,
+ coords[:, 0],
+ coords[:, 1],
+ bond_mask[:, 2],
+ thetas, dihedrals)[1:]
+ # do C -> N
+ thetas, dihedrals = angles_mask[:, :, 0]
+ coords[:, 3] = mp_nerf_torch(coords[:, 0],
+ coords[:, 1],
+ coords[:, 2],
+ bond_mask[:, 0],
+ thetas, dihedrals)
+
+ #########
+ # sequential pass to join fragments
+ #########
+ # part of rotation mat corresponding to origin - 3 orthogonals
+ mat_origin = get_axis_matrix(init_a[0], init_b[0], coords[0, 0], norm=False)
+ # part of rotation mat corresponding to destins || a, b, c = CA, C, N+1
+ # (L-1) since the first is in the origin already
+ mat_destins = get_axis_matrix(coords[:-1, 1], coords[:-1, 2], coords[:-1, 3])
+
+ # get rotation matrices from origins
+ # https://math.stackexchange.com/questions/1876615/rotation-matrix-from-plane-a-to-b
+ rotations = torch.matmul(mat_origin.t(), mat_destins)
+ rotations /= torch.norm(rotations, dim=-1, keepdim=True)
+
+ # do rotation concatenation - do for loop in cpu always - faster
+ rotations = rotations.cpu() if coords.is_cuda and hybrid else rotations
+ for i in range(1, length-1):
+ rotations[i] = torch.matmul(rotations[i], rotations[i-1])
+ rotations = rotations.to(device) if coords.is_cuda and hybrid else rotations
+ # rotate all
+ coords[1:, :4] = torch.matmul(coords[1:, :4], rotations)
+ # offset each position by cumulative sum at that position
+ coords[1:, :4] += torch.cumsum(coords[:-1, 3], dim=0).unsqueeze(-2)
+
+
+ #########
+ # parallel sidechain - do the oxygen, c-beta and side chain
+ #########
+ for i in range(3,14):
+ level_mask = cloud_mask[:, i]
+ thetas, dihedrals = angles_mask[:, level_mask, i]
+ idx_a, idx_b, idx_c = point_ref_mask[:, level_mask, i-3]
+
+ # to place C-beta, we need the carbons from prev res - not available for the 1st res
+ if i == 4:
+ # the c requested is from the previous residue - offset boolean mask by one
+ # can't be done with slicing bc glycines are inside chain (dont have cb)
+ coords_a = coords[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3
+ # if first residue is not glycine,
+ # for 1st residue, use position of the second residue's N (1,3)
+ if level_mask[0].item():
+ coords_a[0] = coords[1, 1]
+ else:
+ coords_a = coords[level_mask, idx_a]
+
+ coords[level_mask, i] = mp_nerf_torch(coords_a,
+ coords[level_mask, idx_b],
+ coords[level_mask, idx_c],
+ bond_mask[level_mask, i],
+ thetas, dihedrals)
+
+ return coords, cloud_mask
+
+
+def sidechain_fold(wrapper, cloud_mask, point_ref_mask, angles_mask, bond_mask,
+ device=torch.device("cpu"), c_beta=False):
+ """ Calcs coords of a protein given its sequence and internal angles.
+ Inputs:
+ * wrapper: (L, 14, 3). coords container with backbone ([:, :3]) and optionally
+ c_beta ([:, 4])
+ * cloud_mask: (L, 14) mask of points that should be converted to coords
+ * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
+ previous 3 points in the coords array
+ * angles_mask: (2, L, 14) maps point to theta and dihedral
+ * bond_mask: (L, 14) gives the length of the bond originating that atom
+ * c_beta: whether to place cbeta
+
+ Output: (L, 14, 3) and (L, 14) coordinates and cloud_mask
+ """
+ precise = wrapper.dtype
+
+ # parallel sidechain - do the oxygen, c-beta and side chain
+ for i in range(3,14):
+ # skip cbeta unless c_beta is passed as a str (e.g. "backbone")
+ if i == 4 and not isinstance(c_beta, str):
+ continue
+ # prepare inputs
+ level_mask = cloud_mask[:, i]
+ thetas, dihedrals = angles_mask[:, level_mask, i]
+ idx_a, idx_b, idx_c = point_ref_mask[:, level_mask, i-3]
+
+ # to place C-beta, we need the carbons from prev res - not available for the 1st res
+ if i == 4:
+ # the c requested is from the previous residue - offset boolean mask by one
+ # can't be done with slicing bc glycines are inside chain (dont have cb)
+ coords_a = wrapper[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3
+ # if first residue is not glycine,
+ # for 1st residue, use position of the second residue's N (1,3)
+ if level_mask[0].item():
+ coords_a[0] = wrapper[1, 1]
+ else:
+ coords_a = wrapper[level_mask, idx_a]
+
+ wrapper[level_mask, i] = mp_nerf_torch(coords_a,
+ wrapper[level_mask, idx_b],
+ wrapper[level_mask, idx_c],
+ bond_mask[level_mask, i],
+ thetas, dihedrals)
+
+ return wrapper, cloud_mask
+
+
+##############################
+####### XTENSION ############
+##############################
+
+
+# inspired by: https://www.biorxiv.org/content/10.1101/2021.08.02.454840v1
+def ca_from_angles(angles, bond_len=3.80):
+ """ Builds a C-alpha trace from a set of 2 angles (theta, chi).
+ Inputs:
+ * angles: (B, L, 4): float tensor. (cos, sin) · (theta, chi)
+ angles in point-in-unit-circumference format.
+ Outputs: (B, L, 3) coords for c-alpha trace
+ """
+ device = angles.device
+ length = angles.shape[-2]
+ frames = [ torch.repeat_interleave(
+ torch.eye(3, device=device, dtype=torch.float).unsqueeze(0),
+ angles.shape[0],
+ dim=0
+ )]
+
+ rot_mats = torch.stack([
+ torch.stack([ angles[...,0] * angles[...,2], angles[...,0] * angles[...,3], -angles[...,1] ], dim=-1),
+ torch.stack([ -angles[...,3] , angles[...,2] , angles[...,0]*0. ], dim=-1),
+ torch.stack([ angles[...,1] * angles[...,2], angles[...,1] * angles[...,3], angles[...,0] ], dim=-1),
+ ], dim=-2) # (B, L, 3, 3)
+
+ # iterative update of frames - skip last frame.
+ for i in range(length-1):
+ frames.append( rot_mats[:, i] @ frames[i] ) # could do frames[-1] as well
+ frames = torch.stack(frames, dim=1) # (B, L, 3, 3)
+
+ ca_trace = bond_len * frames[..., -1, :].cumsum(dim=-2) # (B, L, 3)
+
+ return ca_trace, frames
+
+
+# inspired by: https://github.com/psipred/DMPfold2/blob/master/dmpfold/network.py#L139
+def ca_bb_fold(ca_trace):
+ """ Calcs a backbone given the coordinate trace of the CAs.
+ Inputs:
+ * ca_trace: (B, L, 3) float tensor with CA coordinates.
+ Outputs: (B, L, 14, 3) (-N-CA(-CB-...)-C(=O)-)
+ """
+ wrapper = torch.zeros(ca_trace.shape[0], ca_trace.shape[1]+2, 14, 3, device=ca_trace.device)
+ wrapper[:, 1:-1, 1] = ca_trace
+ # Place dummy extra Cα atoms on extremes to get the required vectors
+ vecs = ca_trace[ :, [0, 2, -1, -3] ] - ca_trace[ :, [1, 1, -2, -2] ] # (B, 4, 3)
+ wrapper[:, 0, 1] = ca_trace[:, 0] + 3.80 * F.normalize(torch.cross(vecs[:, 0], vecs[:, 1]), dim=-1)
+ wrapper[:, -1, 1] = ca_trace[:, -1] + 3.80 * F.normalize(torch.cross(vecs[:, 2], vecs[:, 3]), dim=-1)
+
+ # place N and C term
+ vec_ca_can = wrapper[:, :-2, 1] - wrapper[:, 1:-1, 1]
+ vec_ca_cac = wrapper[:, 2: , 1] - wrapper[:, 1:-1, 1]
+ mid_ca_can = (wrapper[:, 1:, 1] + wrapper[:, :-1, 1]) / 2
+ cross_vcan_vcac = F.normalize(torch.cross(vec_ca_can, vec_ca_cac, dim=-1), dim=-1)
+ wrapper[:, 1:-1, 0] = mid_ca_can[:, :-1] - vec_ca_can / 7.5 + cross_vcan_vcac / 3.33
+ # place all C but last, which is special
+ wrapper[:, 1:-2, 2] = (mid_ca_can[:, :-1] + vec_ca_can / 8 - cross_vcan_vcac / 2.5)[:, 1:]
+ wrapper[:, -2, 2] = mid_ca_can[:, -1, :] - vec_ca_cac[:, -1, :] / 8 + cross_vcan_vcac[:, -1, :] / 2.5
+
+ return wrapper[:, 1:-1]
+
+
+
+############################
+####### METRICS ############
+############################
+
+
+def get_protein_metrics(
+ true_coords,
+ pred_coords,
+ cloud_mask = None,
+ return_aligned = True,
+ detach = None
+ ):
+ """ Calculates many metrics for protein structure quality.
+ Aligns coordinates.
+ Inputs:
+ * true_coords: (B, L, 14, 3) unaligned coords (B = 1)
+ * pred_coords: (B, L, 14, 3) unaligned coords (B = 1)
+ * cloud_mask: (B, L, 14) bool. currently ignored: recomputed from nonzero true/pred coords
+ * return_aligned: bool. whether to return aligned structs.
+ * detach: bool. whether to detach inputs before compute. saves mem
+ Outputs: dict (k,v)
+ """
+ metric_dict = {
+ "rmsd": rmsd_torch,
+ "drmsd": drmsd_torch,
+ # not implemented yet
+ # "gdt_ts": partial(GDT, mode="TS"),
+ # "gdt_ha": partial(GDT, mode="HA"),
+ # "tmscore": tmscore_torch,
+ # "lddt": lddt_torch,
+ }
+
+ if detach:
+ true_coords = true_coords.detach()
+ pred_coords = pred_coords.detach()
+
+ # clone so originals are not modified
+ true_coords = true_coords.clone()
+ pred_coords = pred_coords.clone()
+ cloud_mask = pred_coords.abs().sum(dim=-1).bool() * \
+ true_coords.abs().sum(dim=-1).bool() # 1, L, 14
+ chain_mask = cloud_mask.sum(dim=-1).bool() # 1, L
+
+ true_aligned, pred_aligned = kabsch_torch(
+ pred_coords[cloud_mask].t(), true_coords[cloud_mask].t()
+ )
+ # no need to rebuild true coords since unaffected by kabsch
+ true_coords[cloud_mask] = true_aligned.t()
+ pred_coords[cloud_mask] = pred_aligned.t()
+
+ # compute metrics
+ outputs = {}
+ for k,f in metric_dict.items():
+ # special. works only on ca trace
+ if k == "tmscore":
+ ca_trace = true_coords[:, :, 1].transpose(-1, -2)
+ ca_pred_trace = pred_coords[:, :, 1].transpose(-1, -2)
+ outputs[k] = f(ca_trace, ca_pred_trace)
+ # special. works on full prot
+ elif k == "lddt":
+ outputs[k] = f(true_coords[:, chain_mask[0]], pred_coords[:, chain_mask[0]], cloud_mask=cloud_mask)
+ # special. needs batch dim
+ elif "gdt" in k:
+ outputs[k] = f(true_aligned[None, ...], pred_aligned[None, ...])
+ else:
+ outputs[k] = f(true_aligned, pred_aligned)
+
+ if return_aligned:
+ outputs.update({
+ "pred_align_wrap": pred_coords,
+ "true_align_wrap": true_coords,
+ })
+
+ return outputs
+
diff --git a/rgn2_replica/mp_nerf/utils.py b/rgn2_replica/mp_nerf/utils.py
new file mode 100644
index 0000000..7e26f02
--- /dev/null
+++ b/rgn2_replica/mp_nerf/utils.py
@@ -0,0 +1,224 @@
+# Author: Eric Alcaide
+
+import torch
+import numpy as np
+
+
+# random hacks
+
+# to_pi_minus_pi(4) = -2.28 # to_pi_minus_pi(-4) = 2.28 # rads to pi-(-pi)
+to_zero_two_pi = lambda x: ( x + (2*np.pi) * ( 1 + torch.floor_divide(x.abs(), 2*np.pi) ) ) % (2*np.pi)
+def to_pi_minus_pi(x):
+ zero_two_pi = to_zero_two_pi(x)
+ return torch.where(
+ zero_two_pi < np.pi, zero_two_pi, -(2*np.pi - zero_two_pi)
+ )
+
+@torch.jit.script
+def cdist(x,y):
+ """ robust cdist - drop-in for pytorch's.
+ Inputs:
+ * x, y: (B, N, D)
+ """
+ return torch.pow(
+ x.unsqueeze(-3) - y.unsqueeze(-2), 2
+ ).sum(dim=-1).clamp(min=1e-7).sqrt()
+
+# data utils
+def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150,
+ verbose=True, subset="train", xray_filter=False, full_mask=True):
+ """ Gets a protein from sidechainnet and returns
+ the right attrs for training.
+ Inputs:
+ * dataloader_: sidechainnet iterator over dataset
+ * vocab_: sidechainnet VOCAB class
+ * min_len: int. minimum sequence length
+ * max_len: int. maximum sequence length
+ * verbose: bool. verbosity level
+ * subset: str. which subset to load proteins from.
+ * xray_filter: bool. whether to return only xray structures.
+ * full_mask: bool or int. bool: whether to return seqs with unknown coords.
+ int: number of minimum label positions
+ Outputs: (cleaned, without padding)
+ (seq_str, int_seq, coords, angles, padding_seq, mask, pid)
+ """
+ if xray_filter:
+ raise NotImplementedError
+
+ while True:
+ for b,batch in enumerate(dataloader_[subset]):
+ for i in range(batch.int_seqs.shape[0]):
+ # skip too short
+ if batch.int_seqs[i].shape[0] < min_len:
+ continue
+
+ # strip padding - matching angles to string means
+ # only accepting prots with no missing residues (mask is 0)
+ padding_seq = (batch.int_seqs[i] == 20).sum().item()
+ padding_mask = -(batch.msks[i] - 1).sum().item() # find 0s
+
+ if (full_mask and padding_seq == padding_mask) or \
+ (full_mask is not True and batch.int_seqs[i].shape[0] - full_mask > 0):
+ # check for appropriate length
+ real_len = batch.int_seqs[i].shape[0] - padding_seq
+ if max_len >= real_len >= min_len:
+ # strip padding tokens
+ seq = batch.str_seqs[i] # seq is already unpadded - see README at scn repo
+ int_seq = batch.int_seqs[i][:-padding_seq or None]
+ angles = batch.angs[i][:-padding_seq or None]
+ mask = batch.msks[i][:-padding_seq or None]
+ coords = batch.crds[i][:-padding_seq*14 or None]
+
+ if verbose:
+ print("stopping at sequence of length", real_len)
+
+ yield seq, int_seq, coords, angles, padding_seq, mask, batch.pids[i]
+ else:
+ if verbose:
+ print("found a seq of length:", batch.int_seqs[i].shape,
+ "but oustide the threshold:", min_len, max_len)
+ else:
+ if verbose:
+ print("paddings not matching", padding_seq, padding_mask)
+ pass
+ return None
+
+
+######################
+## structural utils ##
+######################
+
+def get_dihedral(c1, c2, c3, c4):
+ """ Returns the dihedral angle in radians.
+ Will use atan2 formula from:
+ https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics
+ Inputs:
+ * c1: (batch, 3) or (3,)
+ * c2: (batch, 3) or (3,)
+ * c3: (batch, 3) or (3,)
+ * c4: (batch, 3) or (3,)
+ """
+ u1 = c2 - c1
+ u2 = c3 - c2
+ u3 = c4 - c3
+
+ return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) ,
+ ( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) )
+
+
+def get_cosine_angle(c1, c2, c3, eps=1e-7):
+ """ Returns the angle in radians. Uses cosine formula
+ Not all angles are possible all the time.
+ Inputs:
+ * c1: (batch, 3) or (3,)
+ * c2: (batch, 3) or (3,)
+ * c3: (batch, 3) or (3,)
+ """
+ u1 = c2 - c1
+ u2 = c3 - c2
+
+ return torch.acos( (u1*u2).sum(dim=-1) / (u1.norm(dim=-1)*u2.norm(dim=-1) + eps))
+
+
+def get_angle(c1, c2, c3):
+ """ Returns the angle in radians.
+ Inputs:
+ * c1: (batch, 3) or (3,)
+ * c2: (batch, 3) or (3,)
+ * c3: (batch, 3) or (3,)
+ """
+ u1 = c2 - c1
+ u2 = c3 - c2
+
+ # dont use acos since norms involved.
+ # better use atan2 formula: atan2(cross, dot) from here:
+ # https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html
+
+ # add a minus since we want the angle in reversed order - sidechainnet issues
+ return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1),
+ -(u1*u2).sum(dim=-1) )
+
+
+def kabsch_torch(X, Y):
+ """ Kabsch alignment of X into Y.
+ Assumes X,Y are both (D, N) - usually (3, N)
+ """
+ # center X and Y to the origin
+ X_ = X - X.mean(dim=-1, keepdim=True)
+ Y_ = Y - Y.mean(dim=-1, keepdim=True)
+ # calculate covariance matrix (for each prot in the batch)
+ C = torch.matmul(X_, Y_.t())
+ # Optimal rotation matrix via SVD - warning! W must be transposed
+ if int(torch.__version__.split(".")[1]) < 8:
+ V, S, W = torch.svd(C.detach())
+ W = W.t()
+ else:
+ V, S, W = torch.linalg.svd(C.detach())
+ # determinant sign for direction correction
+ d = (torch.det(V) * torch.det(W)) < 0.0
+ if d:
+ S[-1] = S[-1] * (-1)
+ V[:, -1] = V[:, -1] * (-1)
+ # Create Rotation matrix U
+ U = torch.matmul(V, W)
+ # calculate rotations
+ X_ = torch.matmul(X_.t(), U).t()
+ # return centered and aligned
+ return X_, Y_
+
+
+def rmsd_torch(X, Y):
+ """ Assumes x,y are both (batch, d, n) - usually (batch, 3, N). """
+ return torch.sqrt( torch.mean((X - Y)**2, axis=(-1, -2)) )
+
+
+def drmsd_torch(X, Y):
+ """ Assumes x,y are both (B x D x N). See below for wrapper. """
+ X_ = X.transpose(-1, -2)
+ Y_ = Y.transpose(-1, -2)
+ x_dist = cdist(X_, X_) # (B, N, N)
+ y_dist = cdist(Y_, Y_) # (B, N, N)
+
+ return torch.sqrt( torch.pow(x_dist-y_dist, 2).mean(dim=(-1, -2)).clamp(min=1e-7) )
+
+
+def ensure_chirality(coords_wrapper, use_backbone=True):
+ """ Ensures protein agrees with natural distribution
+ of chiral bonds (ramachandran plots).
+ Reflects ( (-1)*Z ) the ones that do not.
+ Inputs:
+ * coords_wrapper: (B, L, C, 3) float tensor. First 3 atoms
+ in C should be N-CA-C
+ * use_backbone: bool. whether to use the backbone (better, more robust)
+ if provided, or just use c-alphas.
+ Outputs: (B, L, C, 3)
+ """
+
+ # detach gradients for angle calculation - mirror selection
+ coords_wrapper_ = coords_wrapper.detach()
+ mask = coords_wrapper_.abs().sum(dim=(-1, -2)) != 0.
+
+ # if BB present: use bb dihedrals
+ if coords_wrapper[:, :, 0].abs().sum() != 0. and use_backbone:
+ # compute phis for every protein in the batch
+ phis = get_dihedral(
+ coords_wrapper_[:, :-1, 2], # C_{i-1}
+ coords_wrapper_[:, 1: , 0], # N_{i}
+ coords_wrapper_[:, 1: , 1], # CA_{i}
+ coords_wrapper_[:, 1: , 2], # C_{i}
+ )
+
+ # get proportion of positive phis
+ props = [(phis[i, mask[i, :-1]] > 0).float().mean() for i in range(mask.shape[0])]
+
+ # fix mirrors by (-1)*Z if more (+) than (-) phi angles
+ corrector = torch.tensor([ [1, 1, -1 if p > 0.5 else 1] # (B, 3)
+ for p in props ], dtype=coords_wrapper.dtype)
+
+ return coords_wrapper * corrector.to(coords_wrapper.device)[:, None, None, :]
+ else:
+ return coords_wrapper
+
+
+
+
diff --git a/rgn2_replica/rgn2.py b/rgn2_replica/rgn2.py
index e313092..19aa472 100644
--- a/rgn2_replica/rgn2.py
+++ b/rgn2_replica/rgn2.py
@@ -1,4 +1,4 @@
-# Author: Eric Alcaide ( @hypnopump )
+# Author: Eric Alcaide ( @hypnopump )
import os
import sys
from typing import Optional, Tuple, List
@@ -11,10 +11,11 @@
from x_transformers import XTransformer, Encoder
from einops import rearrange, repeat
# custom
-import mp_nerf
+from rgn2_replica import mp_nerf
from rgn2_replica.utils import *
# refiners
import en_transformer
+from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
import invariant_point_attention
@@ -22,6 +23,15 @@
#### USEFUL PIECES ####
#######################
+def exists(val):
+ return val is not None
+
+
+def init_zero_(layer):
+ torch.nn.init.constant_(layer.weight, 0.)
+ if exists(layer.bias):
+ torch.nn.init.constant_(layer.bias, 0.)
+
@torch.jit.script
def prediction_wrapper(x: torch.Tensor, pred: torch.Tensor):
""" Facilitates recycling. Inputs the original input + prediction
@@ -122,6 +132,112 @@ def pred_post_process(points_preds: torch.Tensor,
return points_preds, ca_trace_pred, frames_preds, wrapper_pred
+def rotations2angles(rotations: torch.Tensor):
+ # ref to eq.20 in paper: https://arxiv.org/pdf/1102.5658.pdf
+ # input: (B, L, 3, 3) = (B, L, [t, n, b])
+ length = rotations.shape[1]
+ points_preds = torch.zeros(*rotations.shape[:-2], 2, 2, device=rotations.device)
+ points_preds[:, 0, [0, 1], 0] = 1.
+ points_preds[:, 0, [0, 1], 1] = 0.
+ points_preds[:, 1, 1, 0] = 1.
+ points_preds[:, 1, 1, 1] = 0.
+
+ for i in range(length - 1):
+ # cos(theta) = t_{i+1} * t_i
+ points_preds[:, i+1, 0, 0] = \
+ torch.einsum('b i, b i -> b', rotations[:, i, :, 0], rotations[:, i+1, :, 0])
+ # sin(theta) = -n_{i+1} * t_i
+ points_preds[:, i+1, 0, 1] = \
+ -torch.einsum('b i, b i -> b', rotations[:, i, :, 0], rotations[:, i+1, :, 1])
+ if i > 0:
+ # cos(chi) = b_{i+1} * b_i
+ points_preds[:, i+1, 1, 0] = \
+ torch.einsum('b i, b i -> b', rotations[:, i, :, 2], rotations[:, i+1, :, 2])
+ # sin(chi) = -b_{i+1} * n_i
+ points_preds[:, i+1, 1, 1] = \
+ -torch.einsum('b i, b i -> b', rotations[:, i, :, 1], rotations[:, i+1, :, 2])
+
+ return points_preds
+
+
+def pred_post_process_ipa(coords_preds: torch.Tensor,
+ rotations: torch.Tensor,
+ seq_list: Optional[List] = None,
+ mask: Optional[torch.Tensor] = None,
+ model = None,
+ refine_args = {}):
+ """ Converts an angle-based output to structures.
+ Inputs:
+ * coords_preds: (B, L, 3)
+ * seq_list: (B,) list of str. FASTA sequences. Optional. build scns
+ * mask: (B, L) bool tensor.
+ * model: subclass of torch.nn.Module. prediction model w/ potential refiner
+ * refine_args: dict. arguments to pass to model for refinement
+ Outputs:
+ * points_preds: (B, L, 2, 2)
+ * ca_trace_pred: (B, L, 14, 3)
+ * frames_preds: (B, L, 3, 3)
+ * wrapper_pred: (B, L, 14, 3)
+ """
+ device = coords_preds.device
+ if mask is None:
+ mask = torch.ones(coords_preds.shape[:-2], dtype=torch.bool)
+ lengths = mask.sum(dim=-1).cpu().detach().tolist()
+
+ frames_preds = rotations # simply forward it
+
+ # restate first values to known ones (1st angle, 1st + 2nd dihedral)
+ points_preds = rotations2angles(rotations)
+
+ # rebuild ca trace with angles - norm vectors to ensure mod=1. - (B, L, 14, 3)
+ ca_trace_pred = torch.zeros(*coords_preds.shape[:2], 14, 3, device=device)
+ ca_trace_pred[:, :, 1] = coords_preds
+ # delete extra part and chirally reflect
+ ca_trace_pred_aux = torch.zeros_like(ca_trace_pred)
+ for i in range(coords_preds.shape[0]):
+ ca_trace_pred_aux[i, :lengths[i]] = ca_trace_pred_aux[i, :lengths[i]] + \
+ mp_nerf.utils.ensure_chirality(ca_trace_pred[i:i+1, :lengths[i]])
+ ca_trace_pred = ca_trace_pred_aux
+
+ # use model's refiner if available
+ if model is not None:
+ if model.refiner is not None:
+ for i in range(mask.shape[0]):
+ adj_mat = torch.from_numpy(
+ np.eye(lengths[i], k=1) + np.eye(lengths[i], k=1).T
+ ).bool().to(device).unsqueeze(0)
+
+ coors = ca_trace_pred[i:i+1, :mask[i].shape[-1], 1].clone()
+ coors = coors.detach() if model.refiner.refiner_detach else coors
+ feats, coors, r_iters = model.refiner(
+ feats=refine_args[model.refiner.feats_inputs][i:i+1, :lengths[i]], # embeddings
+ coors=coors,
+ adj_mat=adj_mat,
+ recycle=refine_args["recycle"],
+ inter_recycle=refine_args["inter_recycle"],
+ )
+ ca_trace_pred[i:i+1, :lengths[i], 1] = coors
+
+ # calc BB - can't do batched bc relies on extremes.
+ wrapper_pred = torch.zeros_like(ca_trace_pred)
+ for i in range(coords_preds.shape[0]):
+ wrapper_pred[i, :lengths[i]] = mp_nerf.proteins.ca_bb_fold(
+ ca_trace_pred[i:i+1, :lengths[i], 1]
+ )
+ if seq_list is not None:
+ # solve backbone steric clashes
+ wrapper_pred[i, :lengths[i]] = mp_nerf.ml_utils.backbone_forcefield(
+ coords=wrapper_pred[i, :lengths[i]], coeffs=[3, 5, 3, 1], lr=1e-2
+ )
+ # build sidechains
+ scaffolds = mp_nerf.proteins.build_scaffolds_from_scn_angles(seq=seq_list[i], device=device)
+ wrapper_pred[i, :lengths[i]], _ = mp_nerf.proteins.sidechain_fold(
+ wrapper_pred[i, :lengths[i]], **scaffolds, c_beta="backbone"
+ )
+
+ return points_preds, ca_trace_pred, frames_preds, wrapper_pred
+
+
class SqReLU(torch.jit.ScriptModule):
r""" Squared ReLU activation from https://arxiv.org/abs/2109.08668v1. """
@@ -441,6 +557,8 @@ def __init__(self, embedding_dim=1280, hidden=[512], mlp_hidden=[128, 4],
torch.nn.Linear(self.mlp_hidden[0], self.mlp_hidden[-1])
)
+ self.refiner = None # to be implemented
+
def forward(self, x, mask : Optional[torch.Tensor] = None,
recycle:int = 1, inter_recycle:bool = False):
@@ -858,3 +976,149 @@ def forward(self, **data_dict):
+class RGN2_IPA(torch.nn.Module):
+ def __init__(self, embedding_dim=1280, hidden=[512], mlp_hidden=[128, 4],
+ act="silu", structure_module_depth=8, predict_points=False, x_transformer_config={
+ "depth": 8,
+ "heads": 4,
+ "attn_dim_head": 64,
+ # "attn_num_mem_kv": 16, # 16 memory key / values
+ "use_scalenorm": True, # set to true to use for all layers
+ "ff_glu": True, # set to true to use for all feedforwards
+ "attn_collab_heads": True,
+ "attn_collab_compression": .3,
+ "cross_attend": False,
+ "gate_values": True, # gate aggregated values with the input
+ # "sandwich_coef": 6, # interleave attention and feedforwards with sandwich coefficient of 6
+ "rotary_pos_emb": True # turns on rotary positional embeddings
+ }
+ ):
+ """ IPA-based drop-in for RGN2-LSTM: features of size `embedding_dim` -> 3d points + frames.
+ Inputs:
+ * structure_module_depth: int. number of passes through the (shared) IPA block.
+ * mlp_hidden: list of ints.
+ """
+ super(RGN2_IPA, self).__init__()
+ act_types = {
+ "relu": torch.nn.ReLU,
+ "silu": torch.nn.SiLU,
+ }
+ # store params
+ self.embedding_dim = embedding_dim
+ self.hidden = hidden
+ self.mlp_hidden = mlp_hidden
+ self.structure_module_depth = structure_module_depth
+ self.predict_points = predict_points
+
+ # declare layers (to_latent / transformer / last_mlp are not used by this class's forward)
+ """ Declares an XTransformer model.
+ * No decoder, just predict embeddings
+ * project with a lst_mlp
+
+ """
+ self.to_latent = torch.nn.Linear(self.embedding_dim, self.hidden[0])
+ self.transformer = Encoder(
+ dim= self.hidden[-1],
+
+ **x_transformer_config
+ )
+ self.last_mlp = torch.nn.Sequential(
+ torch.nn.Linear(self.hidden[-1], self.mlp_hidden[0]),
+ act_types[act](),
+ torch.nn.Linear(self.mlp_hidden[0], self.mlp_hidden[-1])
+ )
+
+ """
+ IPA stuff
+ """
+ with torch_default_dtype(torch.float32):
+ self.ipa_block = invariant_point_attention.IPABlock(
+ dim=self.embedding_dim,
+ heads=4, #structure_module_heads,
+ require_pairwise_repr=False
+ )
+
+ self.to_quaternion_update = torch.nn.Linear(self.embedding_dim, 6)
+
+ init_zero_(self.ipa_block.attn.to_out)
+
+ self.to_points = torch.nn.Linear(self.embedding_dim, 3)
+
+
+ self.refiner = None # to be implemented
+
+
+ def forward(self, x, mask : Optional[torch.Tensor] = None,
+ recycle:int = 1, inter_recycle:bool = False):
+ """ Inputs:
+ * x (B, L, Emb_dim). NOTE: its last 4 features are zeroed in place.
+ Outputs: x_pred (B, L, 3), r_iters (B, recycle-1) placeholder, plus
+ rotations (B, L, 3, 3) and translations (B, L, 3) if not `predict_points`.
+ """
+ # same input for both rgn2-lstm and transformer, so mask angles (last 4 feats, see below)
+ r_iters = [] # todo: implement this
+ x_buffer = x.clone() if recycle > 1 else x # buffer for recycling (not read afterwards yet)
+ x[..., -4:] = 0.
+
+ b, n, device = *x.shape[:2], x.device
+
+ with torch_default_dtype(torch.float32):
+ quaternions = torch.tensor([1., 0., 0., 0.], device=device)
+ quaternions = repeat(quaternions, 'd -> b n d', b=b, n=n)
+ translations = torch.zeros((b, n, 3), device=device)
+
+ # go through the layers and apply invariant point attention and feedforward
+
+ for i in range(self.structure_module_depth):
+ is_last = i == (self.structure_module_depth - 1)
+
+ # the detach comes from
+ # https://github.com/deepmind/alphafold/blob/0bab1bf84d9d887aba5cfb6d09af1e8c3ecbc408/alphafold/model/folding.py#L383
+ rotations = quaternion_to_matrix(quaternions)
+
+ if not is_last:
+ rotations = rotations.detach()
+
+ x = self.ipa_block(
+ x,
+ mask=mask,
+ # pairwise_repr=pairwise_repr,
+ rotations=rotations,
+ translations=translations
+ )
+
+ # update quaternion and translation
+
+ quaternion_update, translation_update = self.to_quaternion_update(x).chunk(2, dim=-1)
+ quaternion_update = F.pad(quaternion_update, (1, 0), value=1.) # prepend a 1. -> 4 components
+
+ quaternions = quaternion_multiply(quaternions, quaternion_update)
+ translations = translations + torch.einsum('b n c, b n c r -> b n r', translation_update, rotations)
+
+ points_local = self.to_points(x)
+ rotations = quaternion_to_matrix(quaternions)
+ x_pred = torch.einsum('b n c, b n c d -> b n d', points_local, rotations) + translations
+
+ x_pred = x_pred.type(x.dtype).to(x.device)
+ # todo: support the inter_recycle option
+ r_iters = \
+ torch.empty(x.shape[0], recycle - 1, device=x.device) # (B, recycle-1), uninitialized placeholder
+
+ if not self.predict_points:
+ # todo: (default path) also returns the final rotations and translations
+ return x_pred, r_iters, rotations, translations
+
+ return x_pred, r_iters
+
+
+ def predict_fold(self, x, mask : Optional[torch.Tensor] = None,
+ recycle:int = 1, inter_recycle:bool = False):
+ """ Runs `forward` under `torch.no_grad()`, all positions at once (no AR prediction).
+ Same inputs / outputs as `forward`.
+ """
+ with torch.no_grad():
+ return self.forward(
+ x=x, mask=mask,
+ recycle=recycle, inter_recycle=inter_recycle
+ )
+
diff --git a/rgn2_replica/rgn2_trainers.py b/rgn2_replica/rgn2_trainers.py
index 2838ab3..eea8177 100644
--- a/rgn2_replica/rgn2_trainers.py
+++ b/rgn2_replica/rgn2_trainers.py
@@ -2,14 +2,7 @@
import time
import gc
-import random
-import numpy as np
-import torch
-from einops import rearrange, repeat
-from functools import partial
-
-import mp_nerf
from rgn2_replica.rgn2 import *
from rgn2_replica.utils import *
from rgn2_replica.rgn2_utils import *
@@ -47,7 +40,7 @@ def batched_inference(*args, model, embedder,
# create scaffolds
int_seq = torch.ones(batch_dim, max_seq_len, dtype=torch.long) * 20 # padding tok
# mask is true mask. long mask is for lstm
- mask, long_mask = torch.zeros(2, *int_seq.shape, dtype=torch.bool)
+ mask, long_mask = torch.zeros(2, *int_seq.shape, dtype=torch.bool, device=device)
true_coords = torch.zeros(int_seq.shape[0], int_seq.shape[1]*14, 3, device=device)
# fill scaffolds
for i,arg in enumerate(args):
@@ -59,14 +52,14 @@ def batched_inference(*args, model, embedder,
mask = mask.bool().to(device)
coords = rearrange(true_coords, 'b (l c) d -> b l c d', c=14)
ca_trace = coords[..., 1, :]
- coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace ) # beware extremes
+ # coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace ) # beware extremes
# calc angle labels
angles_label_ = torch.zeros(*ca_trace.shape[:-1], 2, dtype=torch.float, device=device)
angles_mask_ = torch.zeros_like(angles_label_).bool() # propagate mask to angles w/ missing points
for i, arg in enumerate(args):
length = arg[1].shape[-1]
- angles_label_[i, 1:length-1, 0] = mp_nerf.utils.get_cosine_angle(
+ angles_label_[i, 1:length-1, 0] = mp_nerf.utils.get_cosine_angle(
ca_trace[i, :length-2 , :],
ca_trace[i, 1:length-1, :],
ca_trace[i, 2:length , :],
@@ -87,6 +80,7 @@ def batched_inference(*args, model, embedder,
# later don't count them
# angles_label_[~angles_mask_] = 0.
angles_label_[angles_label_ != angles_label_] = 0.
+ print(angles_label_.shape)
points_label = mp_nerf.ml_utils.angle_to_point_in_circum(angles_label_) # (B, L, 2, 2)
# include angles of previous AA as input
@@ -116,21 +110,39 @@ def batched_inference(*args, model, embedder,
# PREDICT
if mode in ["train", "test", "fast_test"]:
# get angles
- preds, r_iters = model.forward(embedds, mask=long_mask,
- recycle=recycle_func(None)) # (B, L, 4)
- points_preds = rearrange(preds, '... (a d) -> ... a d', a=2) # (B, L, 2, 2)
-
- # POST-PROCESS
- points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process(
- points_preds, mask=long_mask, # long_mask == True for all seq_len
- # seq_list = None, # don't fold sidechain
- model=model, refine_args={
- "embedds": embedds,
- "int_seq": int_seq.to(device),
- "recycle": recycle_func(None),
- "inter_recycle": False,
- }
- )
+ refiner_type = config.refiner_args["refiner_type"]
+ if refiner_type == "En":
+ preds, r_iters = model.forward(embedds, mask=long_mask,
+ recycle=recycle_func(None)) # (B, L, 4)
+
+ points_preds = rearrange(preds, '... (a d) -> ... a d', a=2) # (B, L, 2, 2)
+
+ # POST-PROCESS
+ points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process(
+ points_preds, mask=long_mask, # long_mask == True for all seq_len
+ # seq_list = None, # don't fold sidechain
+ model=model, refine_args={
+ "embedds": embedds,
+ "int_seq": int_seq.to(device),
+ "recycle": recycle_func(None),
+ "inter_recycle": False,
+ }
+ )
+ elif refiner_type == "IPA": # IPA returns coords
+ preds, r_iters, rotations, translations = model.forward(embedds, mask=long_mask,
+ recycle=recycle_func(None)) # (B, L, 4)
+ points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process_ipa(
+ preds, rotations, mask=long_mask, # long_mask == True for all seq_len
+ # seq_list = None, # don't fold sidechain
+ model=model, refine_args={
+ "embedds": embedds,
+ "int_seq": int_seq.to(device),
+ "recycle": recycle_func(None),
+ "inter_recycle": False,
+ }
+ )
+ else:
+ raise NotImplementedError("refiner types besides En/IPA are not supported.")
# get frames (for labels) for for later fape
bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles(
@@ -188,7 +200,7 @@ def inference(*args, model, embedder,
long_mask = torch.ones_like(mask)
coords = rearrange(true_coords, '(l c) d -> () l c d', c=14).to(device)
ca_trace = coords[..., 1, :]
- coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace )
+ coords_rebuilt = mp_nerf.proteins.ca_bb_fold(ca_trace)
# mask for thetas and chis
angles_label_ = torch.zeros(*ca_trace.shape[:-1], 2, dtype=torch.float, device=device)
angles_mask_ = torch.zeros_like(angles_label_).bool()
@@ -203,12 +215,12 @@ def inference(*args, model, embedder,
ca_trace[..., 2:-1, :],
ca_trace[..., 3: , :],
)
- angles_mask_[..., 1:-1, 0] = (
- mask[i, :-2] * mask[i, 1:-1] * mask[i, 2:]
- )
- angles_mask_[i, 2:-1, 0] = (
- mask[i, :-3] * mask[i, 1:-2] * mask[i, 2:-1], mask[i, 3:]
- )
+ # angles_mask_[..., 1:-1, 0] = (
+ # mask[i, :-2] * mask[i, 1:-1] * mask[i, 2:]
+ # )
+ # angles_mask_[i, 2:-1, 0] = (
+ # mask[i, :-3] * mask[i, 1:-2] * mask[i, 2:-1], mask[i, 3:]
+ # )
# replace nan and (angles whose coords are not fully known) by 0.
# angles_label_[~angles_mask_] = 0.
angles_label_[angles_label_ != angles_label_] = 0.
@@ -249,7 +261,7 @@ def inference(*args, model, embedder,
)
# get frames for for later fape
- bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles(
+ bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles(
points_label.reshape(1, -1, 4) # (B, L, 2, 2) -> (B, L, 4)
)
@@ -333,7 +345,7 @@ def predict(get_prot_, steps, model, embedder, return_preds=True,
# violation loss btween calphas - L1
dist_mat = mp_nerf.utils.cdist(infer["wrapper_pred"][:, :, 1],
- infer["wrapper_pred"][:, :, 1],) # B, L, L
+ infer["wrapper_pred"][:, :, 1], ) # B, L, L
dist_mat[:, np.arange(dist_mat.shape[-1]), np.arange(dist_mat.shape[-1])] = \
dist_mat[:, np.arange(dist_mat.shape[-1]), np.arange(dist_mat.shape[-1])] + 5.
viol_loss = -(dist_mat - 3.78).clamp(min=-np.inf, max=0.)
@@ -443,7 +455,7 @@ def train(get_prot_, steps, model, embedder, optim, loss_f=None,
# violation loss btween calphas - L1
dist_mat = mp_nerf.utils.cdist(infer["wrapper_pred"][:, :, 1],
- infer["wrapper_pred"][:, :, 1],) # B, L, L
+ infer["wrapper_pred"][:, :, 1], ) # B, L, L
dist_mat = dist_mat + torch.eye(dist_mat.shape[-1]).unsqueeze(0).to(dist_mat)*5.
viol_loss = -(dist_mat - 3.78).clamp(min=-np.inf, max=0.).contiguous()
diff --git a/rgn2_replica/utils.py b/rgn2_replica/utils.py
index f96978b..fbb657f 100644
--- a/rgn2_replica/utils.py
+++ b/rgn2_replica/utils.py
@@ -3,6 +3,7 @@
import math
import torch
import numpy as np
+import contextlib
# random hacks - device utils for pyTorch - saves transfers
@@ -46,7 +47,19 @@ def set_seed(seed, verbose=False):
print("Seet seed to {0}".format(seed))
-
+@contextlib.contextmanager
+def torch_default_dtype(dtype):
+    """ Temporarily sets torch's default floating point dtype.
+        Inputs:
+        * dtype: torch floating point dtype used as default inside the `with` block.
+        The previous default is restored on exit, even if the body raises.
+    """
+    prev_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        # without the finally, an exception in the body would leak `dtype` as global default
+        torch.set_default_dtype(prev_dtype)
diff --git a/scripts/rgn2_predict_fold.py b/scripts/rgn2_predict_fold.py
index 9d5567f..e4e83d7 100644
--- a/scripts/rgn2_predict_fold.py
+++ b/scripts/rgn2_predict_fold.py
@@ -1,17 +1,13 @@
# Author: Eirc Alcaide (@hypnopump)
import os
import re
import json
import numpy as np
import torch
# process
import argparse
-import joblib
-from tqdm import tqdm
# custom
-import esm
import sidechainnet
-import mp_nerf
from rgn2_replica import *
from rgn2_replica.embedders import *
from rgn2_replica.rgn2_refine import *
@@ -118,7 +114,6 @@
# refine structs
if args.rosetta_refine:
from typing import Optional
- import pyrosetta
for i, seq in enumerate(seq_list):
# only refine
diff --git a/scripts/train_rgn2.py b/scripts/train_rgn2.py
index fbdce59..1ca6db3 100644
--- a/scripts/train_rgn2.py
+++ b/scripts/train_rgn2.py
@@ -1,24 +1,19 @@
import os
import json
import argparse
-import random
-import numpy as np
-import wandb
-import torch
-import esm
import sidechainnet
from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB
+import sys
+sys.path.append("..")
+
# IMPORTED ALSO IN LATER MODULES
VOCAB = VOCAB()
-import mp_nerf
from rgn2_replica.rgn2_trainers import *
from rgn2_replica.embedders import *
-from rgn2_replica import set_seed, RGN2_Naive
-
-
+from rgn2_replica import set_seed, RGN2_Naive, mp_nerf
def parse_arguments():
@@ -34,9 +29,9 @@ def parse_arguments():
# data params
parser.add_argument("--min_len", help="Min seq len, for train", type=int, default=0)
parser.add_argument("--min_len_valid", help="Min seq len, for valid", type=int, default=0)
- parser.add_argument("--max_len", help="Max seq len", type=int, default=512)
- parser.add_argument("--casp_version", help="SCN dataset version", type=int, default=12)
- parser.add_argument("--scn_thinning", help="SCN dataset thinning", type=int, default=90)
+ parser.add_argument("--max_len", help="Max seq len", type=int, default=128)#512)
+ parser.add_argument("--casp_version", help="SCN dataset version", type=int, default=7)
+ parser.add_argument("--scn_thinning", help="SCN dataset thinning", type=int, default=30)
parser.add_argument("--xray", help="only use xray structures", type=bool, default=0)
parser.add_argument("--frac_true_torsions", help="Provide right torsions for some prots", type=bool, default=0)
parser.add_argument("--full_mask", help="require full mask in proteins", type=bool, default=1)
@@ -53,7 +48,8 @@ def parse_arguments():
parser.add_argument("--num_recycles_train", type=int, default=3,
help="number of recycling iters. set to 1 to speed training.",)
# refiner params
- parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads, default={})
+ parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads,
+ default={"refiner_type": "En"})
parser.add_argument("--seed", help="Random seed", default=101)
return parser.parse_args()
@@ -145,17 +141,20 @@ def run_train_schedule(dataloaders, embedder, config, args):
embedder = embedder.to(device)
set_seed(config.seed)
- model = RGN2_Naive(layers=config.num_layers,
- emb_dim=config.emb_dim+4,
- hidden=config.hidden,
- bidirectional=config.bidirectional,
- mlp_hidden=config.mlp_hidden,
- act=config.act,
- layer_type=config.layer_type,
- input_dropout=config.input_dropout,
- angularize=config.angularize,
- refiner_args=config.refiner_args,
- ).to(device)
+ # model = RGN2_Naive(layers=config.num_layers,
+ # emb_dim=config.emb_dim+4,
+ # hidden=config.hidden,
+ # bidirectional=config.bidirectional,
+ # mlp_hidden=config.mlp_hidden,
+ # act=config.act,
+ # layer_type=config.layer_type,
+ # input_dropout=config.input_dropout,
+ # angularize=config.angularize,
+ # refiner_args=config.refiner_args,
+ # ).to(device)
+ model = RGN2_IPA(
+ embedding_dim=config.emb_dim+4,
+ ).to(device)
if args.resume_name is not None:
model.load_my_state_dict(torch.load(args.resume_name, map_location=device))
@@ -326,9 +325,10 @@ def get_training_schedule(args):
loss_f = " metrics['drmsd'].mean() / len(infer['seq']) "
# steps, ckpt, lr , bs , max_len, clip, loss_f
- return [[32000, 135 , 1e-3, 16 , args.max_len, None, loss_f, 42 , ],
- [64000, 135 , 1e-3, 32 , args.max_len, None, loss_f, 42 , ],
+ return [[32000, 135 , 1e-4, 16 , args.max_len, None, loss_f, 42 , ],
+ [64000, 135 , 1e-4, 32 , args.max_len, None, loss_f, 42 , ],
[32000, 135 , 1e-4, 32 , args.max_len, None, loss_f, 42 , ],]
+ # return [[32, 2, 1e-3, 16, args.max_len, None, loss_f, 42, ]]
if __name__ == '__main__':
diff --git a/setup.py b/setup.py
index f39dab7..9e5a113 100644
--- a/setup.py
+++ b/setup.py
@@ -22,14 +22,16 @@
'sidechainnet',
'proDy',
'tqdm',
- 'mp-nerf',
+ # 'mp-nerf',
'en-transformer>=0.5.0',
'datasets>=1.10',
'transformers>=4.2',
'x-transformers>=0.16.1',
'pytorch-lightning>=1.4',
'wandb',
- 'fair-esm>=0.4.0'
+ 'fair-esm>=0.4.0',
+ 'pytorch3d',
+ 'invariant_point_attention'
],
setup_requires=[
'pytest-runner',