Skip to content

Commit

Permalink
add method, reasoning, and the iteration history, also for multi-para…
Browse files Browse the repository at this point in the history
…meters, we train one parameter at a time
  • Loading branch information
liyin2015 committed Jan 20, 2025
1 parent bbda118 commit 2f60315
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 415 deletions.
166 changes: 103 additions & 63 deletions adalflow/adalflow/optim/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,9 @@ def draw_interactive_html_graph(
from jinja2 import Template

output_file = "interactive_graph.html"
final_file = filepath + "_" + output_file if filepath else output_file
filepath = filepath or "output"
os.makedirs(filepath, exist_ok=True)
final_file = os.path.join(filepath, output_file)

net = Network(height="750px", width="100%", directed=True)

Expand All @@ -919,9 +921,11 @@ def draw_interactive_html_graph(
for node in nodes:
self.generate_node_html(node, output_dir=filepath)

node_id = node.id
node_show_name = node.name.replace(f"_{node_id}", "")
label = (
f"""<div style="max-height: 150px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; background: white; position: relative; font-family: Arial, sans-serif;">"""
f"<b>Name:</b> {node.name[0:10]}<br>"
f"<b>Name:</b> {node_show_name}<br>"
f"<b>Role:</b> {node.role_desc.capitalize()}<br>"
f"<b>Value:</b> {node.data}<br>"
f"<b>Data ID:</b> {node.data_id}<br>"
Expand All @@ -935,10 +939,10 @@ def draw_interactive_html_graph(

net.add_node(
n_id=node.id,
label=node.name[0:16],
label=node_show_name,
title=label,
color=node_colors.get(node.param_type, "gray"),
url=f"{filepath}/{node.name}.html",
url=f"./{node.name}.html", # Relative path
)
node_ids.add(node.id)

Expand Down Expand Up @@ -976,13 +980,59 @@ def draw_interactive_html_graph(
display: block;
margin-top: 10px;
}
/* Simple styling for the legend */
#legend {
margin-top: 20px;
font-family: Arial, sans-serif;
font-size: 14px;
}
.legend-item {
margin-bottom: 5px;
}
.legend-color-box {
display: inline-block;
width: 12px;
height: 12px;
margin-right: 5px;
border: 1px solid #000;
vertical-align: middle;
}
</style>
</head>
<body>
<div id="tooltip">
<div id="tooltip-content"></div>
<button onclick="document.getElementById('tooltip').style.display='none'">Close</button>
</div>
<!-- Legend Section -->
<div id="legend">
<strong>Legend:</strong>
<div class="legend-item">
<span class="legend-color-box" style="background-color: lightblue;"></span>PROMPT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: orange;"></span>DEMOS
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: gray;"></span>INPUT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: green;"></span>OUTPUT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: purple;"></span>GENERATOR_OUTPUT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: red;"></span>RETRIEVER_OUTPUT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: pink;"></span>LOSS_OUTPUT
</div>
<div class="legend-item">
<span class="legend-color-box" style="background-color: blue;"></span>SUM_OUTPUT
</div>
</div>
<!-- End Legend Section -->
<div id="mynetwork" style="height: {{ height }};"></div>
<script type="text/javascript">
var nodes = new vis.DataSet({{ nodes | safe }});
Expand Down Expand Up @@ -1013,6 +1063,37 @@ def draw_interactive_html_graph(

return {"graph_path": final_file}

@staticmethod
def wrap_and_escape(text, width=40):
r"""Wrap text to the specified width, considering HTML breaks, and escape special characters."""
try:
import textwrap
except ImportError as e:
raise ImportError(
"Please install textwrap using 'pip install textwrap' to use this feature"
) from e

def wrap_text(text, width):
"""Wrap text to the specified width, considering HTML breaks."""
lines = textwrap.wrap(
text, width, break_long_words=False, replace_whitespace=False
)
return "<br/>".join(lines)

if not isinstance(text, str):
text = str(text)
text = (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
.replace(
"\n", "<br/>"
) # Convert newlines to HTML line breaks if using HTML labels
)
return wrap_text(text, width)

def draw_graph(
self,
add_grads: bool = True,
Expand Down Expand Up @@ -1042,19 +1123,8 @@ def draw_graph(
) from e

assert rankdir in ["LR", "TB"]
try:
import textwrap
except ImportError as e:
raise ImportError(
"Please install textwrap using 'pip install textwrap' to use this feature"
) from e

root_path = get_adalflow_default_root_path()
# # prepare the log directory
# log_dir = os.path.join(root_path, "logs")

# # Set up TensorBoard logging
# writer = SummaryWriter(log_dir)

filename = f"trace_graph_{self.name}_id_{self.id}"
filepath = (
Expand All @@ -1065,29 +1135,6 @@ def draw_graph(
# final_path = f"{filepath}.{format}"
print(f"Saving graph to {filepath}.{format}")

def wrap_text(text, width):
"""Wrap text to the specified width, considering HTML breaks."""
lines = textwrap.wrap(
text, width, break_long_words=False, replace_whitespace=False
)
return "<br/>".join(lines)

def wrap_and_escape(text, width=40):
r"""Wrap text to the specified width, considering HTML breaks, and escape special characters."""
if not isinstance(text, str):
text = str(text)
text = (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
.replace(
"\n", "<br/>"
) # Convert newlines to HTML line breaks if using HTML labels
)
return wrap_text(text, width)

nodes, edges = self.trace_graph(self)
dot = Digraph(format=format, graph_attr={"rankdir": rankdir, "dpi": "300"})
node_names = set()
Expand All @@ -1096,32 +1143,32 @@ def wrap_and_escape(text, width=40):

node_label = (
f"<table border='0' cellborder='1' cellspacing='0'>"
f"<tr><td><b><font color='{label_color}'>Name: </font></b></td><td>{wrap_and_escape(n.id)}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Name: </font></b></td><td>{wrap_and_escape(n.name)}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Role: </font></b></td><td>{wrap_and_escape(n.role_desc.capitalize())}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Value: </font></b></td><td>{wrap_and_escape(n.data)}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Name: </font></b></td><td>{self.wrap_and_escape(n.id)}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Name: </font></b></td><td>{self.wrap_and_escape(n.name)}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Role: </font></b></td><td>{self.wrap_and_escape(n.role_desc.capitalize())}</td></tr>"
f"<tr><td><b><font color='{label_color}'>Value: </font></b></td><td>{self.wrap_and_escape(n.data)}</td></tr>"
)
if n.data_id is not None:
node_label += f"<tr><td><b><font color='{label_color}'>Data ID: </font></b></td><td>{wrap_and_escape(n.data_id)}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Data ID: </font></b></td><td>{self.wrap_and_escape(n.data_id)}</td></tr>"
if n.proposing:
node_label += f"<tr><td><b><font color='{label_color}'>Proposing</font></b></td><td>{{'Yes'}}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Previous Value: </font></b></td><td>{wrap_and_escape(n.previous_data)}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Previous Value: </font></b></td><td>{self.wrap_and_escape(n.previous_data)}</td></tr>"
if n.requires_opt:
node_label += f"<tr><td><b><font color='{label_color}'>Requires Optimization: </font ></b></td><td>{{'Yes'}}</td></tr>"
if n.param_type:
node_label += f"<tr><td><b><font color='{label_color}'>Type: </font></b></td><td>{wrap_and_escape(n.param_type.name)}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Type: </font></b></td><td>{self.wrap_and_escape(n.param_type.name)}</td></tr>"
if (
full_trace
and hasattr(n, "component_trace")
and n.component_trace.api_kwargs is not None
):
node_label += f"<tr><td><b><font color='{label_color}'> API kwargs: </font></b></td><td>{wrap_and_escape(str(n.component_trace.api_kwargs))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'> API kwargs: </font></b></td><td>{self.wrap_and_escape(str(n.component_trace.api_kwargs))}</td></tr>"

# show the score for intermediate nodes
if n.score is not None and len(n.predecessors) > 0:
node_label += f"<tr><td><b><font color='{label_color}'>Score: </font></b></td><td>{str(n.score)}</td></tr>"
if add_grads:
node_label += f"<tr><td><b><font color='{label_color}'>Gradients: </font></b></td><td>{wrap_and_escape(n.get_gradients_names())}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Gradients: </font></b></td><td>{self.wrap_and_escape(n.get_gradients_names())}</td></tr>"
# add a list of each gradient with short value
# combine the gradients and context
# combined_gradients_contexts = zip(
Expand All @@ -1132,22 +1179,22 @@ def wrap_and_escape(text, width=40):
gradient_context = g.context
log.info(f"Gradient context display: {gradient_context}")
log.info(f"data: {g.data}")
node_label += f"<tr><td><b><font color='{label_color}'>Gradient {g.name} Feedback: </font></b></td><td>{wrap_and_escape(g.data)}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Gradient {g.name} Feedback: </font></b></td><td>{self.wrap_and_escape(g.data)}</td></tr>"
# if gradient_context != "":
# node_label += f"<tr><td><b><font color='{label_color}'>Gradient {g.name} Context: </font></b></td><td>{wrap_and_escape(gradient_context)}</td></tr>"
# if g.prompt:
# node_label += f"<tr><td><b><font color='{label_color}'>Gradient {g.name} Prompt: </font></b></td><td>{wrap_and_escape(g.prompt)}</td></tr>"
if len(n._traces.values()) > 0:
node_label += f"<tr><td><b><font color='{label_color}'>Traces: keys: </font></b></td><td>{wrap_and_escape(str(n._traces.keys()))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Traces: values: </font></b></td><td>{wrap_and_escape(str(n._traces.values()))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Traces: keys: </font></b></td><td>{self.wrap_and_escape(str(n._traces.keys()))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Traces: values: </font></b></td><td>{self.wrap_and_escape(str(n._traces.values()))}</td></tr>"
if n.tgd_optimizer_trace is not None:
node_label += f"<tr><td><b><font color='{label_color}'>TGD Optimizer Trace: </font></b></td><td>{wrap_and_escape(str(n.tgd_optimizer_trace))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>TGD Optimizer Trace: </font></b></td><td>{self.wrap_and_escape(str(n.tgd_optimizer_trace))}</td></tr>"

# show component trace, id and name
if hasattr(n, "component_trace") and n.component_trace.id is not None:
node_label += f"<tr><td><b><font color='{label_color}'>Component Trace ID: </font></b></td><td>{wrap_and_escape(str(n.component_trace.id))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Component Trace ID: </font></b></td><td>{self.wrap_and_escape(str(n.component_trace.id))}</td></tr>"
if hasattr(n, "component_trace") and n.component_trace.name is not None:
node_label += f"<tr><td><b><font color='{label_color}'>Component Trace Name: </font></b></td><td>{wrap_and_escape(str(n.component_trace.name))}</td></tr>"
node_label += f"<tr><td><b><font color='{label_color}'>Component Trace Name: </font></b></td><td>{self.wrap_and_escape(str(n.component_trace.name))}</td></tr>"

node_label += "</table>"
# check if the name exists in dot
Expand Down Expand Up @@ -1232,19 +1279,12 @@ def draw_output_subgraph(
node_ids = set()

for node in nodes:
escaped_name = html.escape(node.name if node.name else "")
escaped_param_type = html.escape(
node.param_type.name if node.param_type else ""
)
escaped_value = html.escape(
node.get_short_value() if node.get_short_value() else ""
)

node_label = f"""
<table border="0" cellborder="1" cellspacing="0">
<tr><td><b>Name:</b></td><td>{escaped_name}</td></tr>
<tr><td><b>Type:</b></td><td>{escaped_param_type}</td></tr>
<tr><td><b>Value:</b></td><td>{escaped_value}</td></tr>"""
<tr><td><b>Name:</b></td><td>{self.wrap_and_escape(node.name)}</td></tr>
<tr><td><b>Type:</b></td><td>{self.wrap_and_escape(node.param_type.name)}</td></tr>
<tr><td><b>Value:</b></td><td>{self.wrap_and_escape(node.get_short_value())}</td></tr>"""
# add the component trace id and name
if hasattr(node, "component_trace") and node.component_trace.id is not None:
escaped_ct_id = html.escape(str(node.component_trace.id))
Expand Down
Loading

0 comments on commit 2f60315

Please sign in to comment.