Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vis improvement #134

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 105 additions & 27 deletions mdagent/tools/base_tools/analysis_tools/vis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import nbformat as nbf
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry

Expand Down Expand Up @@ -57,34 +58,44 @@ def run_molrender(self, cif_path: str) -> str:
f"saved as: mol_render_{self.cif_file_name}"
)

def create_notebook(self, cif_file: str) -> str:
def create_notebook(self, top_file: str, traj=None) -> str:
"""This is for plan B
tool, it will create
a notebook
with the code to
install nglview and
display the cif/pdb file."""
self.cif_file_name = os.path.basename(cif_file)
self.cif_file_name = os.path.basename(top_file)

# Create a new notebook
nb = nbf.v4.new_notebook()

# Code to install NGLview
install_code = "!pip install -q nglview"

disclaimer = (
"#Note: Is possible the agent misses the \n"
"#correct topology file, and/or trajectory file.\n"
"#Check if the files are correct beforehand.\n"
)
if traj:
# Code to import NGLview and display a file
install_code += "\n!pip install -q mdtraj"

import_code = (
"import nglview as nv\nimport mdtraj as md\n"
f"traj = md.load('{traj}',top='{top_file}')\n"
"view=nv.show_mdtraj(traj)\nview"
)
# Code to import NGLview and display a file
import_code = f"""
import nglview as nv
view = nv.show_file("{cif_file}")
view
"""
import_code = f"import nglview as nv\nview=nv.show_file('{top_file}')\nview"

# Create new code cells
install_cell = nbf.v4.new_code_cell(source=install_code)
disclaimer_cell = nbf.v4.new_markdown_cell(disclaimer)
import_cell = nbf.v4.new_code_cell(source=import_code)

# Add the cells
nb.cells.extend([install_cell, import_cell])
nb.cells.extend([install_cell, disclaimer_cell, import_cell])

# Write the notebook to a file
notebook_name = (
Expand All @@ -102,49 +113,116 @@ def create_notebook(self, cif_file: str) -> str:
return "Visualization Complete"


class visProteinSchema(BaseModel):
    """Input schema for the protein visualization tool.

    All file references are file ids (per the field descriptions), not
    filesystem paths.
    """

    # File id of the protein/topology file to visualize (required).
    # The keyword was previously misspelled "decription", so the
    # description was never attached to the field.
    topology_fileid: str = Field(
        description="The fileid of the protein file to visualize"
    )
    # File id of the trajectory; only relevant when type is 'movie'.
    # An explicit None default keeps the field optional regardless of how
    # the installed pydantic version treats Optional without a default.
    trajectory_fileid: Optional[str] = Field(
        None,
        description="The fileid of the trajectory"
        " file to visualize if type is 'movie'",
    )
    # Visualization mode: 'static' (default) or 'movie'.
    type: Optional[str] = Field(
        "static",
        description=(
            "The type of visualization to create. "
            "Options are 'static' (default) or 'movie'"
        ),
    )


class VisualizeProtein(BaseTool):
"""To get a png, you must install molrender
https://github.com/molstar/molrender/tree/master
Otherwise, you will get a notebook where you
can visualize the protein."""

name = "PDBVisualization"
description = """This tool will create
a visualization of a cif
file as a png file OR
it will create
a .ipynb file with the
visualization of the
file, depending on the
packages available.
If a notebook is created,
the user can open the
notebook and visualize the
system."""
args_schema = visProteinSchema
description = (
"This tool will create"
" a visualization of a protein"
" file as a png file OR"
" it will create"
" a .ipynb file with the"
" visualization of the"
" file, depending on the"
" packages available."
" If a notebook is created,"
" the user can open the"
" notebook and visualize the"
" system."
)
path_registry: Optional[PathRegistry]

def __init__(self, path_registry: Optional[PathRegistry]):
super().__init__()
self.path_registry = path_registry

def _run(self, cif_file_name: str) -> str:
def _run(self, **input):
"""use the tool."""
input = self.validate_input(input)
topology_id = input["topology_fileid"]
if not self.path_registry:
return "Failed. Error: Path registry is not set"
cif_path = self.path_registry.get_mapped_path(cif_file_name)
if not cif_path:
return f"Failed. File not found: {cif_file_name}"
topology_path = self.path_registry.get_mapped_path(topology_id)
type = input["type"]
if not self.path_registry:
return "Failed. Error: Path registry is not set"
top_path = self.path_registry.get_mapped_path(topology_path)
if not top_path:
return f"Failed. File not found: {topology_id}"
if type == "movie":
trajectory_id = input["trajectory_fileid"]
if not trajectory_id:
print("no trajectory fileid, using static visualization")
type = "static"
else:
traj_path = self.path_registry.get_mapped_path(trajectory_id)
vis = VisFunctions(self.path_registry)
try:
return "Succeeded. " + vis.run_molrender(cif_path)
if type == "static":
return "Succeeded" + vis.run_molrender(top_path)
if type == "movie":
return "Succeeded" + vis.create_notebook(top_path, traj=traj_path)
except (RuntimeError, FileNotFoundError) as e:
print(f"Error running molrender: {str(e)}. Using NGLView instead.")
try:
vis.create_notebook(cif_path)
vis.create_notebook(top_path)
return "Succeeded. Visualization created as notebook"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")

def validate_input(self, input):
    """Normalize raw tool input into a flat dict of visualization options.

    Unwraps optional ``"input"`` / ``"action_input"`` nesting, checks the
    file ids against the names listed by the path registry, and coerces an
    unknown ``type`` to ``"static"``.

    Args:
        input: Mapping holding ``topology_fileid`` and optionally
            ``trajectory_fileid`` and ``type``.

    Returns:
        dict with keys ``topology_fileid``, ``protein_fileid`` (same value,
        kept for callers using the earlier key name),
        ``trajectory_fileid``, ``type`` and ``error`` (a message string,
        or None when no problem was found).
    """
    input = input.get("input", input)
    input = input.get("action_input", input)
    errors = []

    registry = self.path_registry
    if registry:
        fileids = registry.list_path_names()
    else:
        # Without a registry no id can be resolved; report it, don't crash.
        fileids = []
        errors.append("Path registry is not set.")

    top_id = input.get("topology_fileid")
    if not top_id:
        errors.append("topology_fileid field is required.")
    elif top_id not in fileids:
        # This case was previously reported as a trajectory_fileid problem.
        errors.append("topology_fileid not in path registry.")

    trajectory_id = input.get("trajectory_fileid")
    if trajectory_id and trajectory_id not in fileids:
        errors.append("trajectory_fileid not in path registry.")

    # Named vis_type to avoid shadowing the builtin ``type``.
    vis_type = input.get("type", "static")
    if vis_type not in ("static", "movie"):
        vis_type = "static"

    return {
        "topology_fileid": top_id,
        "protein_fileid": top_id,
        "trajectory_fileid": trajectory_id,
        "type": vis_type,
        # Surface the collected problems instead of discarding them.
        "error": " ".join(errors) or None,
    }
Loading