From 2f603152a3c0de23b8a5c0725432a41c35984881 Mon Sep 17 00:00:00 2001 From: Li Yin Date: Sun, 19 Jan 2025 18:42:47 -0800 Subject: [PATCH] add method, reasoning, and the iteration history, also for multi-parameters, we train one parameter at a time --- adalflow/adalflow/optim/parameter.py | 166 ++++++---- .../adalflow/optim/text_grad/tgd_optimizer.py | 302 ++++-------------- adalflow/adalflow/optim/trainer/trainer.py | 6 + .../hotpot_qa/adal_exp/build_multi_hop_rag.py | 5 +- .../hotpot_qa/adal_exp/train_multi_hop_rag.py | 23 +- .../adal_exp/train_multi_hop_rag_cycle.py | 29 +- .../hotpot_qa/adal_exp/train_vanilla_rag.py | 13 +- use_cases/classification/train.py | 38 +-- .../bbh/object_count/train_new.py | 19 +- use_cases/text_grad_2.0_train.py | 10 +- 10 files changed, 196 insertions(+), 415 deletions(-) diff --git a/adalflow/adalflow/optim/parameter.py b/adalflow/adalflow/optim/parameter.py index 4dd72b4f..5a510083 100644 --- a/adalflow/adalflow/optim/parameter.py +++ b/adalflow/adalflow/optim/parameter.py @@ -899,7 +899,9 @@ def draw_interactive_html_graph( from jinja2 import Template output_file = "interactive_graph.html" - final_file = filepath + "_" + output_file if filepath else output_file + filepath = filepath or "output" + os.makedirs(filepath, exist_ok=True) + final_file = os.path.join(filepath, output_file) net = Network(height="750px", width="100%", directed=True) @@ -919,9 +921,11 @@ def draw_interactive_html_graph( for node in nodes: self.generate_node_html(node, output_dir=filepath) + node_id = node.id + node_show_name = node.name.replace(f"_{node_id}", "") label = ( f"""
""" - f"Name: {node.name[0:10]}
" + f"Name: {node_show_name}
" f"Role: {node.role_desc.capitalize()}
" f"Value: {node.data}
" f"Data ID: {node.data_id}
" @@ -935,10 +939,10 @@ def draw_interactive_html_graph( net.add_node( n_id=node.id, - label=node.name[0:16], + label=node_show_name, title=label, color=node_colors.get(node.param_type, "gray"), - url=f"{filepath}/{node.name}.html", + url=f"./{node.name}.html", # Relative path ) node_ids.add(node.id) @@ -976,6 +980,23 @@ def draw_interactive_html_graph( display: block; margin-top: 10px; } + /* Simple styling for the legend */ + #legend { + margin-top: 20px; + font-family: Arial, sans-serif; + font-size: 14px; + } + .legend-item { + margin-bottom: 5px; + } + .legend-color-box { + display: inline-block; + width: 12px; + height: 12px; + margin-right: 5px; + border: 1px solid #000; + vertical-align: middle; + } @@ -983,6 +1004,35 @@ def draw_interactive_html_graph(
+ +
+ Legend: +
+ PROMPT +
+
+ DEMOS +
+
+ INPUT +
+
+ OUTPUT +
+
+ GENERATOR_OUTPUT +
+
+ RETRIEVER_OUTPUT +
+
+ LOSS_OUTPUT +
+
+ SUM_OUTPUT +
+
+