diff --git a/.github/workflows/paper.yml b/.github/workflows/paper.yml
index 32ee953..16b6737 100644
--- a/.github/workflows/paper.yml
+++ b/.github/workflows/paper.yml
@@ -13,10 +13,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
-      - name: Set up Python 3.8
+      - name: Set up Python 3.10
         uses: actions/setup-python@v2
         with:
-          python-version: "3.8"
+          python-version: "3.10"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3416cbc..32f933e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.2.3
+    rev: v4.5.0
     hooks:
-      - id: trailing-whitespace
       - id: check-yaml
       - id: end-of-file-fixer
+      - id: trailing-whitespace
       - id: mixed-line-ending
-  - repo: https://github.com/psf/black
+  - repo: https://github.com/psf/black-pre-commit-mirror
     rev: "23.11.0"
     hooks:
       - id: black
diff --git a/exmol/exmol.py b/exmol/exmol.py
index eabe072..a26a5de 100644
--- a/exmol/exmol.py
+++ b/exmol/exmol.py
@@ -27,9 +27,10 @@
 from rdkit.Chem.Draw import MolToImage as mol2img, DrawMorganBit  # type: ignore
 from rdkit.Chem import rdchem  # type: ignore
 from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity, TanimotoSimilarity  # type: ignore
-import langchain.llms as llms
-import langchain.prompts as prompts
+from openai import OpenAI
+
+client = OpenAI()
 
 from . import stoned
 from .plot_utils import _mol_images, _image_scatter, _bit2atoms
 from .data import *
@@ -392,6 +393,7 @@ def _check_alphabet_consistency(
     alphabet_symbols = _alphabet_to_elements(set(alphabet_symbols))
     # find all elements in smiles (Upper alpha or upper alpha followed by lower alpha)
     smiles_symbols = set(re.findall(r"[A-Z][a-z]?", smiles))
+
     if check and not smiles_symbols.issubset(alphabet_symbols):
         # show which symbols are not in alphabet
         raise ValueError(
@@ -1410,7 +1412,7 @@ def merge_text_explains(
 def text_explain_generate(
     text_explanations: List[Tuple[str, float]],
     property_name: str,
-    llm: Optional[llms.BaseLLM] = None,
+    llm_model: str = "gpt-4",
     single: bool = True,
 ) -> str:
     """Insert text explanations into template, and generate explanation.
@@ -1430,14 +1432,22 @@ def text_explain_generate(
             for x in text_explanations
         ]
     )
-    prompt_template = prompts.PromptTemplate(
-        input_variables=["property", "text"],
-        template=_single_prompt if single else _multi_prompt,
-    )
+
+    prompt_template = _single_prompt if single else _multi_prompt
     prompt = prompt_template.format(property=property_name, text=text)
-    if llm is None:
-        llm = llms.OpenAI(temperature=0.05)
-    return llm(prompt)
+
+    messages = [
+        {
+            "role": "system",
+            "content": "Your goal is to explain which molecular features are important to its properties based on the given text.",
+        },
+        {"role": "user", "content": prompt},
+    ]
+    response = client.chat.completions.create(
+        model=llm_model, messages=messages, temperature=0.05
+    )
+
+    return response.choices[0].message.content
 
 
 def text_explain(
diff --git a/exmol/stoned/stoned.py b/exmol/stoned/stoned.py
index 27198ab..de4ea88 100644
--- a/exmol/stoned/stoned.py
+++ b/exmol/stoned/stoned.py
@@ -214,6 +214,7 @@
 from rdkit import Chem  # type: ignore
 from rdkit.Chem import MolFromSmiles as smi2mol  # type: ignore
 from rdkit.Chem import MolToSmiles as mol2smi  # type: ignore
+from rdkit.Chem import MolToRandomSmilesVect  # type: ignore
 from rdkit.Chem import AllChem  # type: ignore
 from rdkit.DataStructs.cDataStructs import TanimotoSimilarity  # type: ignore
 
@@ -237,9 +238,10 @@ def randomize_smiles(mol):
     if not mol:
         return None
 
-    return mol2smi(
-        mol, canonical=False, doRandom=True, isomericSmiles=True, kekuleSmiles=True
-    )
+    # return mol2smi(
+    #     mol, canonical=False, doRandom=True, isomericSmiles=True, kekuleSmiles=True
+    # )
+    return MolToRandomSmilesVect(mol, 1, isomericSmiles=True, kekuleSmiles=True, randomSeed=np.random.randint(0,100))[0]
 
 
 def largest_mol(smiles):
diff --git a/paper2_LIME/BBB-RNN.ipynb b/paper2_LIME/BBB-RNN.ipynb
new file mode 100644
index 0000000..70b5a76
--- /dev/null
+++ b/paper2_LIME/BBB-RNN.ipynb
@@ -0,0 +1,678 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# LIME paper: Recurrent Neural Network for Blood brain barrier permeation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Import packages and set up RNN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "from matplotlib.patches import Rectangle, FancyBboxPatch\n",
+    "from matplotlib.offsetbox import AnnotationBbox\n",
+    "import seaborn as sns\n",
+    "import textwrap\n",
+    "import skunk\n",
+    "import matplotlib as mpl\n",
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "import selfies as sf\n",
+    "import exmol\n",
+    "from dataclasses import dataclass\n",
+    "from rdkit.Chem.Draw import rdDepictor, MolsToGridImage\n",
+    "from rdkit.Chem import MolFromSmiles, MACCSkeys\n",
+    "\n",
+    "rdDepictor.SetPreferCoordGen(True)\n",
+    "import matplotlib.font_manager as font_manager\n",
+    "import urllib.request\n",
+    "\n",
+    "urllib.request.urlretrieve(\n",
+    "    \"https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf\",\n",
+    "    \"IBMPlexMono-Regular.ttf\",\n",
+    ")\n",
+    "fe = font_manager.FontEntry(fname=\"IBMPlexMono-Regular.ttf\", name=\"plexmono\")\n",
+    "font_manager.fontManager.ttflist.append(fe)\n",
+    "plt.rcParams.update(\n",
+    "    {\n",
+    "        \"axes.facecolor\": \"#f5f4e9\",\n",
+    "        \"grid.color\": \"#AAAAAA\",\n",
+    "        \"axes.edgecolor\": \"#333333\",\n",
+    "        
\"figure.facecolor\": \"#FFFFFF\",\n", + " \"axes.grid\": False,\n", + " \"axes.prop_cycle\": plt.cycler(\"color\", plt.cm.Dark2.colors),\n", + " \"font.family\": fe.name,\n", + " \"figure.figsize\": (3.5, 3.5 / 1.2),\n", + " \"ytick.left\": True,\n", + " \"xtick.bottom\": True,\n", + " }\n", + ")\n", + "\n", + "color_cycle = [\"#F06060\", \"#1BBC9B\", \"#F06060\", \"#5C4B51\", \"#F3B562\", \"#6e5687\"]\n", + "mpl.rcParams[\"axes.prop_cycle\"] = mpl.cycler(color=color_cycle)\n", + "mpl.rcParams[\"font.size\"] = 10\n", + "bbb_data = pd.read_csv(\"../paper1_CFs/BBBP.csv\")\n", + "# features_start_at = list(bbb_data.columns).index(\"MolWt\")\n", + "np.random.seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# scramble them\n", + "bbb_data = bbb_data.sample(frac=1, random_state=0).reset_index(drop=True)\n", + "bbb_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from rdkit.Chem import MolToSmiles\n", + "\n", + "\n", + "def _randomize_smiles(mol, isomericSmiles=True):\n", + " return MolToSmiles(\n", + " mol,\n", + " canonical=False,\n", + " doRandom=True,\n", + " isomericSmiles=isomericSmiles,\n", + " kekuleSmiles=random.random() < 0.5,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "smiles = list(bbb_data[\"smiles\"])\n", + "permeabilities = list(bbb_data[\"p_np\"])\n", + "\n", + "aug_data = 0\n", + "\n", + "aug_smiles = []\n", + "aug_perm = []\n", + "for sml, sol in zip(smiles, permeabilities):\n", + " new_smls = []\n", + " new_smls.append(sml)\n", + " aug_perm.append(sol)\n", + " for _ in range(aug_data):\n", + " try:\n", + " new_sml = _randomize_smiles(MolFromSmiles(sml))\n", + " # print(new_sml)\n", + " if new_sml not in new_smls:\n", + " new_smls.append(new_sml)\n", + " aug_perm.append(sol)\n", + " except:\n", + " continue\n", + " aug_smiles.extend(new_smls)\n", + "\n", + "aug_df_bbb = pd.DataFrame(data={\"smiles\": aug_smiles, \"p_np\": aug_perm})\n", + "\n", + "print(f\"The dataset was augmented from {len(bbb_data)} to {len(aug_df_bbb)}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selfies_list = []\n", + "for i, s in enumerate(aug_df_bbb.smiles):\n", + " try:\n", + " selfies_list.append(sf.encoder(exmol.sanitize_smiles(s)[1]))\n", + " except (sf.EncoderError, TypeError):\n", + " selfies_list.append(None)\n", + " bbb_data.smiles[i] = exmol.sanitize_smiles(s)[1]\n", + "len(selfies_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "basic = set(exmol.get_basic_alphabet())\n", + "data_vocab = set(\n", + " sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])\n", + ")\n", + "vocab = ['[nop]']\n", + "vocab.extend(list(data_vocab.union(basic)))\n", + "vocab_stoi = {o: i for o, i in zip(vocab, range(len(vocab)))}\n", + "\n", + "\n", + "def selfies2ints(s):\n", + " result = []\n", + " for token in sf.split_selfies(s):\n", + " if token == '.':\n", + " continue # ?\n", + " if token in vocab_stoi:\n", + " result.append(vocab_stoi[token])\n", + " else:\n", + " print(token)\n", + " result.append(np.nan)\n", + " # print('Warning')\n", + " return result\n", + "\n", + "\n", + "def ints2selfies(v):\n", + " return \"\".join([vocab[i] for i in v])\n", + "\n", + "\n", + "# test them out\n", 
+    "s = selfies_list[0]\n",
+    "print('selfies:', s)\n",
+    "v = selfies2ints(s)\n",
+    "print('selfies2ints:', v)\n",
+    "so = ints2selfies(v)\n",
+    "print('ints2selfies:', so)\n",
+    "assert so == s.replace(\n",
+    "    '.', ''\n",
+    ")  # make sure '.' is removed from Selfies string during assertion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# creating an object\n",
+    "@dataclass\n",
+    "class Config:\n",
+    "    vocab_size: int\n",
+    "    example_number: int\n",
+    "    batch_size: int\n",
+    "    buffer_size: int\n",
+    "    embedding_dim: int\n",
+    "    rnn_units: int\n",
+    "    hidden_dim: int\n",
+    "    drop_rate: float\n",
+    "\n",
+    "\n",
+    "config = Config(\n",
+    "    vocab_size=len(vocab),\n",
+    "    example_number=len(selfies_list),\n",
+    "    batch_size=128,\n",
+    "    buffer_size=10000,\n",
+    "    embedding_dim=64,\n",
+    "    hidden_dim=32,\n",
+    "    rnn_units=64,\n",
+    "    drop_rate=0.20,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# now get sequences\n",
+    "encoded = [selfies2ints(s) for s in selfies_list if s is not None]\n",
+    "padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding=\"post\")\n",
+    "\n",
+    "permeabilities = aug_df_bbb.p_np.values[[bool(s) for s in selfies_list]]\n",
+    "\n",
+    "# Should be shuffled from the beginning, so no worries\n",
+    "N = len(padded_seqs)\n",
+    "split = int(0.1 * N)\n",
+    "\n",
+    "# Now build dataset\n",
+    "test_data = tf.data.Dataset.from_tensor_slices(\n",
+    "    (padded_seqs[:split], permeabilities[:split])\n",
+    ").batch(config.batch_size)\n",
+    "\n",
+    "nontest = tf.data.Dataset.from_tensor_slices(\n",
+    "    (\n",
+    "        padded_seqs[split:],\n",
+    "        permeabilities[split:],\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "val_data, train_data = nontest.take(split).batch(config.batch_size), nontest.skip(\n",
+    "    split\n",
+    ").shuffle(config.buffer_size).batch(config.batch_size).prefetch(\n",
+    "    tf.data.experimental.AUTOTUNE\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = tf.keras.Sequential()\n",
+    "\n",
+    "# make embedding and indicate that 0 should be treated as padding mask\n",
+    "model.add(\n",
+    "    tf.keras.layers.Embedding(\n",
+    "        input_dim=config.vocab_size, output_dim=config.embedding_dim, mask_zero=True\n",
+    "    )\n",
+    ")\n",
+    "model.add(tf.keras.layers.Dropout(config.drop_rate))\n",
+    "# RNN layer\n",
+    "model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(config.rnn_units)))\n",
+    "model.add(tf.keras.layers.Dropout(config.drop_rate))\n",
+    "# a dense hidden layer\n",
+    "model.add(tf.keras.layers.Dense(config.hidden_dim, activation=\"relu\"))\n",
+    "model.add(tf.keras.layers.Dropout(config.drop_rate))\n",
+    "# binary classification from logits, so no output activation\n",
+    "model.add(tf.keras.layers.Dense(1))\n",
+    "\n",
+    "model.summary()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.compile(\n",
+    "    tf.optimizers.Adam(1e-3),\n",
+    "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
+    "    metrics=[\"accuracy\"],\n",
+    ")\n",
+    "# verbose=0 silences output, to get progress bar set verbose=1\n",
+    "result = model.fit(train_data, validation_data=val_data, epochs=100, verbose=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.save(\"bbbp-rnn\")\n",
+    "# model = tf.keras.models.load_model('bbbp-rnn')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# plot test data\n",
+    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))\n",
+    "ax1.plot(result.history[\"loss\"], label=\"training\")\n",
+    "ax1.plot(result.history[\"val_loss\"], label=\"validation\")\n",
+    "ax1.legend()\n",
+    "ax1.set_xlabel(\"Epoch\")\n",
+    "ax1.set_ylabel(\"Loss\")\n",
+    "\n",
+    "ax2.plot(result.history[\"accuracy\"], label=\"training\")\n",
+    "ax2.plot(result.history[\"val_accuracy\"], label=\"validation\")\n",
+    "ax2.legend()\n",
+    "ax2.set_xlabel(\"Epoch\")\n",
+    "ax2.set_ylabel(\"Accuracy\")\n",
+    "fig.tight_layout()\n",
+    "fig.savefig(\"bbbp-rnn-loss-acc.png\", dpi=180)\n",
+    "fig.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.metrics import roc_curve\n",
+    "from sklearn.metrics import auc\n",
+    "\n",
+    "prediction = []\n",
+    "test_y = []\n",
+    "\n",
+    "for x, y in test_data:\n",
+    "    prediction.extend(model(x).numpy().flatten())\n",
+    "    test_y.extend(y.numpy().flatten())\n",
+    "\n",
+    "prediction = np.array(prediction).flatten()\n",
+    "test_y = np.array(test_y)\n",
+    "\n",
+    "fpr_keras, tpr_keras, thresholds_keras = roc_curve(test_y, prediction)\n",
+    "auc_keras = auc(fpr_keras, tpr_keras)\n",
+    "\n",
+    "plt.figure(figsize=(5, 3.5), dpi=100)\n",
+    "plt.plot(fpr_keras, tpr_keras, label=\"AUC = {:.3f}\".format(auc_keras))\n",
+    "plt.plot([0, 1], [0, 1], linestyle=\"--\")\n",
+    "plt.xlabel(\"False Positive Rate\")\n",
+    "plt.ylabel(\"True Positive Rate\")\n",
+    "plt.legend()\n",
+    "plt.savefig(\"bbbp-rnn-roc.png\", dpi=300)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## LIME explanations\n",
+    "\n",
+    "In the following example, we find out which descriptors influence the predicted blood-brain barrier permeation of a molecule. We create a perturbed chemical space around that molecule using the `stoned` method and then use `lime` to find out which descriptors affect the permeation prediction for that molecule.
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Wrapper function for RNN, to use in STONED" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Predictor function is used as input to sample_space function\n", + "def predictor_function(smile_list, selfies):\n", + " if len(selfies) < 1:\n", + " selfies = [sf.encoder(s) for s in smile_list]\n", + " encoded = [selfies2ints(s) for s in selfies]\n", + " # check for nans\n", + " valid = [1.0 if sum(e) > 0 else np.nan for e in encoded]\n", + " encoded = [np.nan_to_num(e, nan=0) for e in encoded]\n", + " padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding=\"post\")\n", + " labels = np.reshape(model.predict(padded_seqs), (-1))\n", + " return labels" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Other ploting utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def space_fit_plot(space, beta, mode=\"regression\"):\n", + " fkw = {\"figsize\": (10, 4)}\n", + " font = {\"family\": \"normal\", \"weight\": \"normal\", \"size\": 16}\n", + "\n", + " fig = plt.figure(figsize=(10, 5))\n", + " mpl.rc(\"axes\", titlesize=12)\n", + " mpl.rc(\"font\", size=16)\n", + " ax_dict = fig.subplot_mosaic(\"AABBB\")\n", + "\n", + " # Plot space by fit\n", + " svg = exmol.plot_utils.plot_space_by_fit(\n", + " space,\n", + " [space[0]],\n", + " figure_kwargs=fkw,\n", + " mol_size=(200, 200),\n", + " offset=1,\n", + " ax=ax_dict[\"B\"],\n", + " beta=beta,\n", + " )\n", + " # Compute y_wls\n", + " w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in space])\n", + " non_zero = w > 10 ** (-6)\n", + " w = w[non_zero]\n", + " N = w.shape[0]\n", + "\n", + " ys = np.array([e.yhat for e in space])[non_zero].reshape(N).astype(float)\n", + " x_mat = np.array([list(e.descriptors.descriptors) for e in space])[\n", + " non_zero\n", + " ].reshape(N, -1)\n", + " y_wls = x_mat @ beta\n", + " y_wls += np.mean(ys)\n", + "\n", + " lower = np.min(ys)\n", + " higher = np.max(ys)\n", + "\n", + " # set transparency using w\n", + " norm = plt.Normalize(min(w), max(w))\n", + " cmap = plt.cm.Oranges(w)\n", + " cmap[:, -1] = w\n", + "\n", + " def weighted_mean(x, w):\n", + " return np.sum(x * w) / np.sum(w)\n", + "\n", + " def weighted_cov(x, y, w):\n", + " return np.sum(\n", + " w * (x - weighted_mean(x, w)) * (y - weighted_mean(y, w))\n", + " ) / np.sum(w)\n", + "\n", + " def weighted_correlation(x, y, w):\n", + " return weighted_cov(x, y, w) / np.sqrt(\n", + " weighted_cov(x, x, w) * weighted_cov(y, y, w)\n", + " )\n", + "\n", + " corr = weighted_correlation(ys, y_wls, w)\n", + "\n", + " if mode == \"regression\":\n", + " ax_dict[\"A\"].plot(\n", + " np.linspace(lower, higher, 100),\n", + " np.linspace(lower, higher, 100),\n", + " \"--\",\n", + " linewidth=2,\n", + " )\n", + " sc = ax_dict[\"A\"].scatter(ys, y_wls, s=50, marker=\".\", c=cmap, cmap=cmap)\n", + " ax_dict[\"A\"].text(\n", + " max(ys) - 10,\n", + " min(ys) + 1,\n", + " f\"weighted \\ncorrelation = {corr:.3f}\",\n", + " fontsize=10,\n", + " )\n", + " ax_dict[\"A\"].set_xlabel(r\"$\\hat{y}$\")\n", + " ax_dict[\"A\"].set_ylabel(r\"$g$\")\n", + " ax_dict[\"A\"].set_title(\"Weighted Least Squares Fit\")\n", + " ax_dict[\"A\"].set_xlim(lower, higher)\n", + " ax_dict[\"A\"].set_ylim(lower, higher)\n", + " ax_dict[\"A\"].set_aspect(1.0 / ax_dict[\"A\"].get_data_ratio(), adjustable=\"box\")\n", + " 
sm = plt.cm.ScalarMappable(cmap=plt.cm.Oranges, norm=norm)\n", + " cbar = plt.colorbar(sm, orientation=\"horizontal\", pad=0.15, ax=ax_dict[\"A\"])\n", + " cbar.set_label(\"Chemical similarity\")\n", + " plt.tight_layout()\n", + " plt.savefig(\"weighted_fit.svg\", dpi=300, bbox_inches=\"tight\", transparent=False)\n", + " if mode == \"classification\":\n", + " fpr_keras, tpr_keras, thresholds_keras = roc_curve(ys, y_wls)\n", + " auc_keras = auc(fpr_keras, tpr_keras)\n", + " ax_dict[\"A\"].plot(fpr_keras, tpr_keras, label=\"AUC = {:.3f}\".format(auc_keras))\n", + " ax_dict[\"A\"].plot([0, 1], [0, 1], linestyle=\"--\")\n", + " ax_dict[\"A\"].set_xlabel(\"False Positive Rate\")\n", + " ax_dict[\"A\"].set_ylabel(\"True Positive Rate\")\n", + " ax_dict[\"A\"].set_aspect(1.0 / ax_dict[\"A\"].get_data_ratio(), adjustable=\"box\")\n", + " ax_dict[\"A\"].legend()\n", + " plt.tight_layout()\n", + " plt.savefig(\"space_fit.svg\", dpi=300)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Descriptor explanations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# smi = 'Cc1onc(-c2ccccc2Cl)c1C(=O)NC1C(=O)N2C1SC(C)(C)C2C(=O)O'\n", + "# smi = soldata.SMILES[1400]\n", + "ibuprofen = \"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\" # ibuprofen is known to cross BBB\n", + "nicotine = bbb_data.smiles[bbb_data.name == \"nicotine\"].values[0] # treat anxiety\n", + "caffeine = bbb_data.smiles[bbb_data.name == \"caffeine\"].values[0]\n", + "chlorpromazine = \"CN(CCCN1c2ccccc2Sc2c1cc(Cl)cc2)C\"\n", + "gleevec = \"Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C\"\n", + "# mol = MolFromSmiles(smi)\n", + "# from rdkit.Chem.Draw import MolToFile\n", + "\n", + "# MolToFile(mol, 'mol_paper.svg')\n", + "predictor_function([ibuprofen, nicotine, caffeine, chlorpromazine, gleevec], [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "diazepam = \"CN1C(=O)CN=C(C2=C1C=CC(=C2)Cl)C3=CC=CC=C3\"\n", + "predictor_function([diazepam], [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure SMILES doesn't contain multiple fragments\n", + "# smi = soldata.SMILES[1400]\n", + "# print(smi)\n", + "alprozolam = \"CC1=NN=C2N1C3=C(C=C(C=C3)Cl)C(=NC2)C4=CC=CC=C4\"\n", + "stoned_kwargs = {\n", + " \"num_samples\": 5000,\n", + " \"alphabet\": exmol.get_basic_alphabet(),\n", + " \"max_mutations\": 1,\n", + "}\n", + "space = exmol.sample_space(alprozolam, predictor_function, stoned_kwargs=stoned_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter space\n", + "from synspace.reos import REOS\n", + "\n", + "reos = REOS()\n", + "\n", + "filtered_space = []\n", + "for e in space:\n", + " if reos.process_mol(MolFromSmiles(e.smiles)) == (\"ok\", \"ok\"):\n", + " filtered_space.append(e)\n", + "\n", + "len(filtered_space)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, SVG\n", + "\n", + "desc_type = [\"Classic\", \"ECFP\", \"MACCS\"]\n", + "fkw = {\"figsize\": (6, 4)}\n", + "for d in desc_type:\n", + " beta = exmol.lime_explain(filtered_space, descriptor_type=d)\n", + " if d == \"Classic\":\n", + " exmol.plot_descriptors(filtered_space, output_file=f\"alprozolam_{d}.svg\")\n", + " else:\n", + " svg = 
exmol.plot_descriptors(\n",
+    "            filtered_space, output_file=f\"alprozolam_{d}.svg\", return_svg=True\n",
+    "        )\n",
+    "        plt.close()\n",
+    "        skunk.display(svg)\n",
+    "    space_fit_plot(filtered_space, beta, mode=\"regression\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "exmol.lime_explain(filtered_space, \"ECFP\")\n",
+    "_ = exmol.plot_utils.similarity_map_using_tstats(filtered_space[0], return_svg=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "display(SVG(_))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "OPEN_AI_KEY = os.environ.get(\"OPENAI_API_KEY\")\n",
+    "exmol.lime_explain(filtered_space, \"ecfp\")\n",
+    "s1_ecfp = exmol.text_explain(filtered_space, \"ecfp\")\n",
+    "explanation = exmol.text_explain_generate(s1_ecfp, \"blood-brain barrier permeation\")\n",
+    "print(explanation)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "abc1ef2aae668f29add333aedc207234808b19831866b8480f007a054a2482dc"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/paper2_LIME/RF-lime.ipynb b/paper2_LIME/RF-lime.ipynb
index e14a2fa..a8b4b6b 100644
--- a/paper2_LIME/RF-lime.ipynb
+++ b/paper2_LIME/RF-lime.ipynb
@@ -30,7 +30,7 @@
     "import numpy as np\n",
     "import mordred, mordred.descriptors\n",
     "from mordred import HydrogenBond, Polarizability\n",
-    "from mordred import SLogP, AcidBase, BertzCT, Aromatic, BondCount, AtomCount\n",
+    "from mordred import SLogP, AcidBase, Aromatic, BondCount, AtomCount\n",
     "from mordred import Calculator\n",
     "\n",
     "import exmol as exmol\n",
@@ -38,7 +38,6 @@
     "import os\n",
     "from sklearn.ensemble import RandomForestRegressor\n",
     "from sklearn.model_selection import train_test_split\n",
-    "from sklearn.metrics import roc_auc_score, plot_roc_curve\n",
     "\n",
     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
     "rdDepictor.SetPreferCoordGen(True)\n",
@@ -50,6 +49,9 @@ "soldata = pd.read_csv(\n",
     "    \"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv\"\n",
     ")\n",
+    "# drop SMILES containing 'P'\n",
+    "soldata = soldata[soldata[\"SMILES\"].str.contains(\"P\") == False]\n",
+    "\n",
     "features_start_at = list(soldata.columns).index(\"MolWt\")"
    ]
   },
@@ -97,7 +99,8 @@
    "outputs": [],
    "source": [
     "raw_features = np.array(raw_features)\n",
-    "labels = soldata[\"Solubility\"]"
+    "labels = soldata[\"Solubility\"]\n",
+    "print(len(labels)==len(molecules))"
    ]
   },
@@ -197,7 +200,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "smi = soldata.SMILES[1500]\n",
+    "smi = soldata.SMILES[150]\n",
     "stoned_kwargs = {\n",
     "    \"num_samples\": 2000,\n",
     "    \"alphabet\": exmol.get_basic_alphabet(),\n",
diff --git a/paper2_LIME/Solubility-RNN.ipynb b/paper2_LIME/Solubility-RNN.ipynb
index 450e235..0ae6f4a 100644
--- a/paper2_LIME/Solubility-RNN.ipynb
+++ b/paper2_LIME/Solubility-RNN.ipynb
@@ -22,10 +22,6 @@
    "source": [
     "import pandas as pd\n",
     "import matplotlib.pyplot as plt\n",
-    "from matplotlib.patches import Rectangle, FancyBboxPatch\n",
-    "from matplotlib.offsetbox import AnnotationBbox\n",
-    "import seaborn as sns\n",
-    "import skunk\n",
     "import matplotlib as mpl\n",
     "import numpy as np\n",
     "import tensorflow as tf\n",
@@ -33,10 +29,8 @@
     "import exmol\n",
     "from dataclasses import dataclass\n",
     "from rdkit.Chem.Draw import rdDepictor, MolsToGridImage\n",
-    "from rdkit.Chem import MolFromSmiles, MACCSkeys\n",
+    "from rdkit.Chem import MolFromSmiles\n",
     "import random\n",
-    "\n",
-    "\n",
     "rdDepictor.SetPreferCoordGen(True)\n",
     "import matplotlib.pyplot as plt\n",
     "import matplotlib.font_manager as font_manager\n",
@@ -66,6 +60,7 @@
     "soldata = pd.read_csv(\n",
     "    \"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv\"\n",
     ")\n",
+    "\n",
     "features_start_at = list(soldata.columns).index(\"MolWt\")\n",
     "np.random.seed(0)\n",
     "random.seed(0)"
@@ -204,7 +199,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# creating an object\n",
     "@dataclass\n",
     "class Config:\n",
     "    vocab_size: int\n",
@@ -509,6 +503,25 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Text explanations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "exmol.lime_explain(filtered_space, \"ecfp\")\n",
+    "s1_ecfp = exmol.text_explain(filtered_space, \"ecfp\")\n",
+    "explanation = exmol.text_explain_generate(s1_ecfp, \"aqueous solubility\")\n",
+    "print(explanation)"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {},
diff --git a/setup.py b/setup.py
index 2333f61..8c6655b 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,6 @@
         "skunk >= 0.4.0",
         "importlib-resources",
         "synspace",
-        "langchain==0.0.343",
     ],
     test_suite="tests",
     long_description=long_description,
diff --git a/tests/test_exmol.py b/tests/test_exmol.py
index 9c9bf23..ae64bd2 100644
--- a/tests/test_exmol.py
+++ b/tests/test_exmol.py
@@ -458,7 +458,7 @@ def model(s, se):
 
     s = exmol.text_explain(samples, "MACCS")
     assert len(s) > 0, "No explanation generated"
-    e = exmol.text_explain_generate(s, "soluble in water")
+    e = exmol.text_explain_generate([s], "soluble in water")
 
     samples1 = exmol.sample_space("c1cc(C(=O)O)c(OC(=O)C)cc1", model, batched=False)
     s = exmol.text_explain(samples1, "ECFP")