diff --git a/mdagent/tools/base_tools/analysis_tools/vis_tools.py b/mdagent/tools/base_tools/analysis_tools/vis_tools.py
index 3223c1ac..7de53ddc 100644
--- a/mdagent/tools/base_tools/analysis_tools/vis_tools.py
+++ b/mdagent/tools/base_tools/analysis_tools/vis_tools.py
@@ -5,6 +5,7 @@
 
 import nbformat as nbf
 from langchain.tools import BaseTool
+from pydantic import BaseModel, Field
 
 from mdagent.utils import PathRegistry
 
@@ -57,34 +58,50 @@ def run_molrender(self, cif_path: str) -> str:
                 f"saved as: mol_render_{self.cif_file_name}"
             )
 
-    def create_notebook(self, cif_file: str) -> str:
+    def create_notebook(self, top_file: str, traj=None) -> str:
         """This is for plan B
         tool, it will create
         a notebook
         with the code to
         install nglview and
         display the cif/pdb file."""
-        self.cif_file_name = os.path.basename(cif_file)
+        self.cif_file_name = os.path.basename(top_file)
 
         # Create a new notebook
         nb = nbf.v4.new_notebook()
 
         # Code to install NGLview
         install_code = "!pip install -q nglview"
-
-        # Code to import NGLview and display a file
-        import_code = f"""
-import nglview as nv
-view = nv.show_file("{cif_file}")
-view
-"""
+        disclaimer = (
+            "#Note: Is possible the agent misses the \n"
+            "#correct topology file, and/or trajectory file.\n"
+            "#Check if the files are correct beforehand.\n"
+        )
+        if traj:
+            # A trajectory additionally needs mdtraj installed in the notebook
+            install_code += "\n!pip install -q mdtraj"
+
+            # Code to load the trajectory with mdtraj and display it (!r quotes the paths)
+            import_code = (
+                "import nglview as nv\nimport mdtraj as md\n"
+                f"traj = md.load({traj!r},top={top_file!r})\n"
+                "view=nv.show_mdtraj(traj)\nview"
+            )
+        else:
+            # Code to import NGLview and display a file
+            import_code = (
+                "import nglview as nv\n"
+                f"view=nv.show_file({top_file!r})\n"
+                "view"
+            )
 
         # Create new code cells
         install_cell = nbf.v4.new_code_cell(source=install_code)
+        disclaimer_cell = nbf.v4.new_markdown_cell(disclaimer)
         import_cell = nbf.v4.new_code_cell(source=import_code)
 
         # Add the cells
-        nb.cells.extend([install_cell, import_cell])
+        nb.cells.extend([install_cell, disclaimer_cell, import_cell])
 
         # Write the notebook to a file
         notebook_name = (
@@ -102,6 +119,27 @@ def create_notebook(self, cif_file: str) -> str:
         return "Visualization Complete"
 
 
+class visProteinSchema(BaseModel):
+    """Arguments accepted by the PDBVisualization tool."""
+
+    topology_fileid: str = Field(
+        description="The fileid of the protein file to visualize"
+    )
+    # Explicit None default; pydantic v2 would otherwise require this field
+    trajectory_fileid: Optional[str] = Field(
+        None,
+        description="The fileid of the trajectory"
+        " file to visualize if type is 'movie'",
+    )
+    type: Optional[str] = Field(
+        "static",
+        description=(
+            "The type of visualization to create. "
+            "Options are 'static' (default) or 'movie'"
+        ),
+    )
+
+
 class VisualizeProtein(BaseTool):
     """To get a png, you must install molrender
     https://github.com/molstar/molrender/tree/master
@@ -109,38 +147,60 @@ class VisualizeProtein(BaseTool):
     can visualize the protein."""
 
     name = "PDBVisualization"
-    description = """This tool will create
-    a visualization of a cif
-    file as a png file OR
-    it will create
-    a .ipynb file with the
-    visualization of the
-    file, depending on the
-    packages available.
-    If a notebook is created,
-    the user can open the
-    notebook and visualize the
-    system."""
+    args_schema = visProteinSchema
+    description = (
+        "This tool will create"
+        " a visualization of a protein"
+        " file as a png file OR"
+        " it will create"
+        " a .ipynb file with the"
+        " visualization of the"
+        " file, depending on the"
+        " packages available."
+        " If a notebook is created,"
+        " the user can open the"
+        " notebook and visualize the"
+        " system."
+    )
     path_registry: Optional[PathRegistry]
 
     def __init__(self, path_registry: Optional[PathRegistry]):
         super().__init__()
         self.path_registry = path_registry
 
-    def _run(self, cif_file_name: str) -> str:
-        """use the tool."""
+    def _run(self, **input):
+        """Visualize a registered structure, or a trajectory for type 'movie'.
+
+        Returns a status string that starts with "Succeeded" or "Failed".
+        """
         if not self.path_registry:
             return "Failed. Error: Path registry is not set"
-        cif_path = self.path_registry.get_mapped_path(cif_file_name)
-        if not cif_path:
-            return f"Failed. File not found: {cif_file_name}"
+        input = self.validate_input(input)
+        if input["error"]:
+            return f"Failed. Error: {input['error']}"
+        topology_id = input["topology_fileid"]
+        top_path = self.path_registry.get_mapped_path(topology_id)
+        if not top_path:
+            return f"Failed. File not found: {topology_id}"
+        # Not named `type`: that would shadow the builtin used in the handler below
+        vis_type = input["type"]
+        traj_path = None
+        if vis_type == "movie":
+            trajectory_id = input["trajectory_fileid"]
+            if trajectory_id:
+                traj_path = self.path_registry.get_mapped_path(trajectory_id)
+            if not traj_path:
+                print("no trajectory fileid, using static visualization")
+                vis_type = "static"
         vis = VisFunctions(self.path_registry)
         try:
-            return "Succeeded. " + vis.run_molrender(cif_path)
+            if vis_type == "movie":
+                return "Succeeded. " + vis.create_notebook(top_path, traj=traj_path)
+            return "Succeeded. " + vis.run_molrender(top_path)
         except (RuntimeError, FileNotFoundError) as e:
             print(f"Error running molrender: {str(e)}. Using NGLView instead.")
             try:
-                vis.create_notebook(cif_path)
+                vis.create_notebook(top_path)
                 return "Succeeded. Visualization created as notebook"
             except Exception as e:
                 return f"Failed. {type(e).__name__}: {e}"
@@ -148,3 +208,37 @@ def _run(self, cif_file_name: str) -> str:
     async def _arun(self, query: str) -> str:
         """Use the tool asynchronously."""
         raise NotImplementedError("custom_search does not support async")
+
+    def validate_input(self, input):
+        """Normalize the tool arguments and check them against the registry.
+
+        Returns a dict with keys topology_fileid, trajectory_fileid, type
+        and error; error is None when every check passed.
+        """
+        # The arguments may arrive wrapped under "input" and/or "action_input"
+        input = input.get("input", input)
+        input = input.get("action_input", input)
+        problems = []
+        fileids = self.path_registry.list_path_names()
+
+        top_id = input.get("topology_fileid")
+        if not top_id:
+            problems.append("topology_fileid field is required.")
+        elif top_id not in fileids:
+            problems.append("topology_fileid not in path registry")
+
+        trajectory_id = input.get("trajectory_fileid", None)
+        if trajectory_id and trajectory_id not in fileids:
+            problems.append("trajectory_fileid not in path registry")
+
+        # Unknown visualization types fall back to the static default
+        vis_type = input.get("type", "static")
+        if vis_type not in ["static", "movie"]:
+            vis_type = "static"
+
+        return {
+            "topology_fileid": top_id,
+            "trajectory_fileid": trajectory_id,
+            "type": vis_type,
+            "error": " ".join(problems) or None,
+        }