-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #52 from ur-whitelab/local-rxn
Replace RXN4Chem: Running retrosynthesis and reaction prediction locally.
- Loading branch information
Showing
16 changed files
with
346 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
|
||
# Tools of organic chemistry | ||
|
||
A Docker container was prepared for each tool; each container exposes an API for requests. | ||
|
||
> docker run -d -p 8052:5000 doncamilom/rxnpred:latest | ||
Where 5000 is the fixed container port, and 8052 is the host port to be exposed. | ||
|
||
A request in curl can look like this | ||
|
||
> curl -X POST -H "Content-Type: application/json" -d '{"smiles": "O=C(OC(C)(C)C)c1ccc(C(=O)Nc2ccc(Cl)cc2)cc1"}' http://localhost:8052/api/v1/run | ||
Or in Python | ||
|
||
```python | ||
|
||
import json | ||
import requests | ||
|
||
def reaction_predict(reactants): | ||
response = requests.post( | ||
"http://localhost:8052/api/v1/run", | ||
headers={"Content-Type": "application/json"}, | ||
data=json.dumps({"smiles": reactants}) | ||
) | ||
return response.json()['product'][0] | ||
|
||
product = reaction_predict('CCOCCCCO.CC(=O)Cl') | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
FROM python:3.9 | ||
RUN pip install aizynthfinder[all] flask | ||
COPY files/ . | ||
COPY . . | ||
|
||
EXPOSE 5000 | ||
ENTRYPOINT ["python"] | ||
|
||
CMD ["app.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import json | ||
import subprocess | ||
|
||
from flask import Flask, jsonify, request | ||
|
||
app = Flask(__name__) | ||
|
||
@app.route('/api/v1/run', methods=['POST'])
def rxnfp():
    """Run ``aizynthcli`` on the posted molecule and return the route trees.

    The request body must be JSON holding the target SMILES under
    ``"target"``. ``"smiles"`` is accepted as an alias because that is the
    key the other service in this project uses for its input.

    Returns:
        The parsed contents of ``trees.json`` as a JSON response on success,
        a JSON error with status 400 when no usable SMILES was posted, or a
        JSON error with status 500 when ``aizynthcli`` exits non-zero.
    """
    # silent=True yields None instead of raising on a missing/invalid body.
    data = request.get_json(silent=True) or {}
    target = data.get("target") or data.get("smiles")
    if not isinstance(target, str) or not target.strip():
        return jsonify(
            {"status": "ERROR", "message": "Expected a SMILES string under 'target'."}
        ), 400

    # Argument list (no shell), so the SMILES is passed through verbatim.
    command = ["aizynthcli", "--config", "config.yml", "--smiles", target]

    print(command)
    try:
        result = subprocess.run(
            command, check=True, capture_output=True, text=True
        )
    except subprocess.CalledProcessError as err:
        # Surface the CLI failure to the caller instead of an opaque HTML 500.
        print(err.stderr)
        return jsonify(
            {"status": "ERROR", "message": "aizynthcli failed.", "stderr": err.stderr}
        ), 500
    print(result)

    # Read output trees.json
    with open("trees.json", "r") as trees_file:
        tree = json.load(trees_file)

    # jsonify serialises the loaded JSON whatever its top-level type is.
    return jsonify(tree)
|
||
|
||
if __name__ == "__main__": | ||
app.run(host="0.0.0.0", port=5000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
expansion: | ||
uspto: | ||
- files/uspto_model.onnx | ||
- files/uspto_templates.csv.gz | ||
ringbreaker: | ||
- files/uspto_ringbreaker_model.onnx | ||
- files/uspto_ringbreaker_templates.csv.gz | ||
filter: | ||
uspto: files/uspto_filter_model.onnx | ||
stock: | ||
zinc: files/zinc_stock.hdf5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Download all the important files for running aizynthfinder using | ||
|
||
``` | ||
download_public_data . | ||
``` | ||
|
||
This command is provided by the aizynthfinder package, which can be installed with | ||
|
||
``` | ||
pip install aizynthfinder[all] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
FROM python:3.10 | ||
WORKDIR /app | ||
|
||
RUN pip install rdkit-pypi==2022.3.1 | ||
RUN pip install OpenNMT-py==2.2.0 "numpy<2.0.0" | ||
|
||
COPY . . | ||
COPY input.txt . | ||
COPY models/ . | ||
CMD ["python", "app.py"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import re | ||
import subprocess | ||
from flask import Flask, request, jsonify | ||
from rdkit import Chem | ||
|
||
app = Flask(__name__) | ||
|
||
|
||
SMI_REGEX_PATTERN = r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])" | ||
|
||
def canonicalize_smiles(smiles, verbose=False):
    """Round-trip ``smiles`` through RDKit and return the re-written SMILES.

    Args:
        smiles: SMILES string to parse.
        verbose: When true, print a message if the string cannot be parsed.

    Returns:
        The SMILES produced by ``Chem.MolToSmiles`` for the parsed molecule,
        or an empty string when ``Chem.MolFromSmiles`` returns ``None``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Unparseable input is reported (optionally) and mapped to ''.
        if verbose:
            print(f'{smiles} is invalid.')
        return ''
    return Chem.MolToSmiles(parsed)
|
||
def smiles_tokenizer(smiles):
    """Canonicalize ``smiles`` and return its tokens joined by single spaces.

    Tokens are the successive matches of ``SMI_REGEX_PATTERN`` in the
    canonicalized string.
    """
    canonical = canonicalize_smiles(smiles)
    pieces = re.findall(SMI_REGEX_PATTERN, canonical)
    return ' '.join(pieces)
|
||
|
||
@app.route('/api/v1/run', methods=['POST'])
def f():
    """Predict reaction products for the posted reactant SMILES.

    The request body must be JSON with the reactants under ``"smiles"``.
    The SMILES is tokenized, written to ``input.txt``, translated with
    ``onmt_translate`` and the predictions are read back from ``output.txt``.

    Returns:
        ``{'status': 'SUCCESS', 'product': [...]}`` with one entry per
        non-empty output line (spaces between tokens removed), or
        ``{'status': 'ERROR', 'product': None}`` when the input is unusable
        (status 400) or the translation step fails.
    """
    request_data = request.get_json(silent=True) or {}
    reactants = request_data.get('smiles')
    if not isinstance(reactants, str):
        return jsonify({'status': 'ERROR', 'product': None}), 400

    # Tokenize smiles; the tokenizer yields '' for SMILES it cannot parse,
    # in which case there is nothing meaningful to feed the model.
    tokenized = smiles_tokenizer(reactants)
    if not tokenized:
        return jsonify({'status': 'ERROR', 'product': None}), 400

    model_path = 'models/USPTO480k_model_step_400000.pt'

    src_path = 'input.txt'
    output_path = 'output.txt'
    n_best = 5
    beam_size = 10
    max_length = 300
    batch_size = 1

    try:
        # Write the tokenized input where the translator expects it.
        with open(src_path, 'w') as src_file:
            src_file.write(tokenized)

        # Argument list instead of a shell string: no shell is involved.
        cmd = [
            "onmt_translate",
            "-model", model_path,
            "--src", src_path,
            "--output", output_path,
            "--n_best", str(n_best),
            "--beam_size", str(beam_size),
            "--max_length", str(max_length),
            "--batch_size", str(batch_size),
        ]
        subprocess.check_call(cmd)

        # Read produced output: one prediction per line, tokens separated by
        # spaces. Blank lines (e.g. after the final newline) are skipped.
        with open(output_path, 'r') as out_file:
            prods = [
                line.replace(' ', '')
                for line in out_file.read().split('\n')
                if line.strip()
            ]

        return jsonify({'status': 'SUCCESS', 'product': prods})

    except Exception as err:  # request boundary: report, don't crash
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate while keeping the JSON error contract for callers.
        print(f'Reaction prediction failed: {err!r}')
        return jsonify({'status': 'ERROR', 'product': None})
|
||
if __name__ == '__main__':
    # Run the Flask app on all interfaces. Debug mode is deliberately off:
    # with debug=True the interactive debugger would be reachable by anyone
    # who can connect to this port and allows arbitrary code execution.
    app.run(host='0.0.0.0', port=5000)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Download model from https://drive.google.com/uc?id=1ywJCJHunoPTB5wr6KdZ8aLv7tMFMBHNy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Self-hosted reaction tools. Retrosynthesis, reaction forward prediction.""" | ||
|
||
import abc | ||
import ast | ||
import re | ||
from time import sleep | ||
from typing import Optional | ||
|
||
import requests | ||
|
||
import json | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.schema import HumanMessage | ||
from langchain.tools import BaseTool | ||
|
||
from chemcrow.utils import is_smiles | ||
|
||
__all__ = ["RXNPredictLocal", "RXNRetrosynthesisLocal"] | ||
|
||
|
||
class RXNPredictLocal(BaseTool):
    """Predict reaction products by querying a locally hosted service."""

    name = "ReactionPredict"
    description = (
        "Predict the outcome of a chemical reaction. "
        "Takes as input the SMILES of the reactants separated by a dot '.', "
        "returns SMILES of the products."
    )

    def _run(self, reactants: str) -> str:
        """Validate the input and return the predicted product.

        Args:
            reactants: SMILES of the reactants, separated by ``.``.

        Returns:
            The predicted product, ``"Incorrect input."`` when ``reactants``
            fails the ``is_smiles`` check, or ``"Error in prediction."`` when
            the request fails.
        """
        if not is_smiles(reactants):
            return "Incorrect input."

        return self.predict_reaction(reactants)

    def predict_reaction(self, reactants: str) -> str:
        """POST the reactants to the prediction service and return the top hit.

        Returns:
            The first entry of the response's ``product`` field, or
            ``"Error in prediction."`` if the request fails or the response
            does not have the expected shape.
        """
        try:
            response = requests.post(
                "http://localhost:8051/api/v1/run",
                headers={"Content-Type": "application/json"},
                data=json.dumps({"smiles": reactants}),
            )
            return response.json()['product'][0]
        except (
            requests.RequestException,  # connection problems, bad responses
            ValueError,                 # body is not valid JSON
            KeyError,                   # no 'product' field
            IndexError,                 # empty product list
            TypeError,                  # 'product' is None / not indexable
        ):
            return "Error in prediction."
|
||
|
||
class RXNRetrosynthesisLocal(BaseTool):
    """Propose a synthesis route via a local retrosynthesis service.

    The route returned by the service is reduced to a list of reaction steps
    and then turned into a textual procedure by an OpenAI chat model.
    """

    name = "ReactionRetrosynthesis"
    description = (
        "Obtain the synthetic route to a chemical compound. "
        "Takes as input the SMILES of the product, returns recipe."
    )
    # Key passed to ChatOpenAI when summarising the route.
    openai_api_key: str = ""

    def _run(self, reactants: str) -> str:
        """Run retrosynthesis for a target molecule and describe the route.

        Args:
            reactants: SMILES of the target product (the name is kept for
                interface compatibility).

        Returns:
            The generated procedure, ``"Incorrect input."`` when the input
            fails the ``is_smiles`` check, or ``"Error in retrosynthesis."``
            when the service request fails or returns no route.
        """
        # Check that input is smiles
        if not is_smiles(reactants):
            return "Incorrect input."

        try:
            paths = self.retrosynthesis(reactants)
            # Only the first route returned by the service is used.
            first_path = paths[0]
        except (
            requests.RequestException,  # connection problems, bad responses
            ValueError,                 # body is not valid JSON
            KeyError,                   # service answered with a JSON object
            IndexError,                 # no routes found
            TypeError,                  # response is not indexable
        ):
            return "Error in retrosynthesis."

        return self.get_action_sequence(first_path)

    def retrosynthesis(self, reactants: str) -> list:
        """POST the target SMILES to the retrosynthesis service.

        Returns:
            The parsed JSON response; ``_run`` indexes it as a sequence of
            routes.
        """
        # The bundled retrosynthesis server reads the molecule from "target";
        # "smiles" is sent too so servers that read that key keep working.
        payload = {"target": reactants, "smiles": reactants}
        response = requests.post(
            "http://localhost:8052/api/v1/run",
            headers={"Content-Type": "application/json"},
            data=json.dumps(payload),
        )
        return response.json()

    def get_action_sequence(self, path):
        """Turn one route tree into a natural-language procedure."""
        json_actions = self._preproc_actions(path)
        return self._summary_gpt(json_actions)

    def _preproc_actions(self, path):
        """Flatten a route tree into a list of reaction steps.

        Walks ``path`` depth-first. Every node whose ``metadata`` contains
        ``mapped_reaction_smiles`` contributes one dict; the part before
        ``>>`` is stored as ``"products"`` and the part after it as
        ``"reactants"``. Nested nodes are reached through ``children``.
        """
        def _clean_actions(node):
            mapped = node.get('metadata', {}).get('mapped_reaction_smiles') \
                if 'metadata' in node else None
            if mapped is not None:
                sides = mapped.split(">>")
                yield {"reactants": sides[1], "products": sides[0]}
            for child in node.get('children', []) if 'children' in node else []:
                yield from _clean_actions(child)

        return list(_clean_actions(path))

    def _summary_gpt(self, json: list) -> str:
        """Ask the chat model to describe the reaction steps as a recipe.

        Args:
            json: Reaction steps as produced by ``_preproc_actions``. The
                parameter name shadows the ``json`` module inside this method
                only; it is kept unchanged for interface compatibility.

        Returns:
            The text content of the model's reply.
        """
        llm = ChatOpenAI(  # type: ignore
            temperature=0.05,
            model_name="gpt-3.5-turbo-16k",
            request_timeout=2000,
            max_tokens=2000,
            openai_api_key=self.openai_api_key,
        )
        prompt = (
            "Here is a chemical synthesis described as a json.\nYour task is "
            "to describe the synthesis, as if you were giving instructions for "
            "a recipe. Use only the substances, quantities, temperatures and "
            "in general any action mentioned in the json file. This is your "
            "only source of information, do not make up anything else. Also, "
            "add 15mL of DCM as a solvent in the first step. If you ever need "
            'to refer to the json file, refer to it as "(by) the tool". '
            "However avoid references to it. \nFor this task, give as many "
            f"details as possible.\n {str(json)}"
        )
        return llm([HumanMessage(content=prompt)]).content
Oops, something went wrong.