Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool Router Keys now use Authors. Dependency Checks. Tools w/Assets #801

Merged
merged 7 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl JobPromptGenerator {
match tool_content.get("function") {
Some(function) => {
let tool_router_key = function["tool_router_key"].as_str().unwrap_or("");
if tool_router_key == "local:::rust_toolkit:::shinkai_sqlite_query_executor" {
if tool_router_key == "local:::__official_shinkai:::shinkai_sqlite_query_executor" {
let job_id_clone = job_id.clone();
if let Ok(current_tables) = get_current_tables(job_id_clone).await {
if !current_tables.is_empty() {
Expand Down Expand Up @@ -154,4 +154,4 @@ impl JobPromptGenerator {
}
}

// TODO: add tests
// TODO: add tests
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ impl SheetRustFunctions {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::shinkai_sheet_ui_create_new_column_with_values".to_string(),
"local:::__official_shinkai:::shinkai_sheet_ui_create_new_column_with_values".to_string(),
);

// Add the tool definition for update_column_with_values
Expand All @@ -546,7 +546,7 @@ impl SheetRustFunctions {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::shinkai_sheet_ui_update_column_with_values".to_string(),
"local:::__official_shinkai:::shinkai_sheet_ui_update_column_with_values".to_string(),
);

// Add the tool definition for replace_value_at_position
Expand All @@ -562,7 +562,7 @@ impl SheetRustFunctions {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::shinkai_sheet_ui_replace_value_at_position".to_string(),
"local:::__official_shinkai:::shinkai_sheet_ui_replace_value_at_position".to_string(),
);

// Add the tool definition for create_new_columns_with_csv
Expand All @@ -576,7 +576,7 @@ impl SheetRustFunctions {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::shinkai_sheet_ui_create_new_columns_with_csv".to_string(),
"local:::__official_shinkai:::shinkai_sheet_ui_create_new_columns_with_csv".to_string(),
);

// Add the tool definition for get_table
Expand All @@ -586,7 +586,7 @@ impl SheetRustFunctions {
Parameters::new(),
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::shinkai_sheet_ui_get_table".to_string(),
"local:::__official_shinkai:::shinkai_sheet_ui_get_table".to_string(),
);

vec![
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod tests {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::concat_strings".to_string(),
"local:::__official_shinkai:::concat_strings".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

Expand Down Expand Up @@ -90,7 +90,7 @@ mod tests {
functions: Some(vec![FunctionDetails {
name: "concat_strings".to_string(),
description: "Concatenates 2 to 4 strings.".to_string(),
tool_router_key: Some("local:::rust_toolkit:::concat_strings".to_string()),
tool_router_key: Some("local:::__official_shinkai:::concat_strings".to_string()),
parameters: FunctionParameters {
type_: "object".to_string(),
properties: serde_json::json!({
Expand Down Expand Up @@ -149,7 +149,7 @@ mod tests {
},
ToolOutputArg::empty(),
None,
"local:::rust_toolkit:::concat_strings".to_string(),
"local:::__official_shinkai:::concat_strings".to_string(),
);
let shinkai_tool = ShinkaiTool::Rust(tool, true);

Expand Down Expand Up @@ -205,7 +205,7 @@ mod tests {
functions: Some(vec![FunctionDetails {
name: "concat_strings".to_string(),
description: "Concatenates 2 to 4 strings.".to_string(),
tool_router_key: Some("local:::rust_toolkit:::concat_strings".to_string()),
tool_router_key: Some("local:::__official_shinkai:::concat_strings".to_string()),
parameters: FunctionParameters {
type_: "object".to_string(),
properties: serde_json::json!({
Expand Down
122 changes: 64 additions & 58 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall
use crate::llm_provider::job_manager::JobManager;
use crate::network::Node;
use crate::tools::tool_definitions::definition_generation::{generate_tool_definitions, get_rust_tools};
use crate::tools::tool_execution::execution_custom::execute_custom_tool;
use crate::tools::tool_execution::execution_custom::try_to_execute_rust_tool;
use crate::tools::tool_execution::execution_header_generator::{check_tool, generate_execution_environment};
use crate::utils::environment::fetch_node_environment;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -100,31 +100,28 @@ impl ToolRouter {
}

if is_empty {
if let Err(e) = self.add_testing_network_tools().await {
eprintln!("Error adding testing network tools: {}", e);
}
if let Err(e) = self.add_rust_tools().await {
eprintln!("Error adding rust tools: {}", e);
}
if let Err(e) = self.add_static_prompts(&generator).await {
eprintln!("Error adding static prompts: {}", e);
}
} else if !has_any_js_tools {
if let Err(e) = self.add_testing_network_tools().await {
eprintln!("Error adding testing network tools: {}", e);
}
} else if !has_any_js_tools {
if let Err(e) = self.add_rust_tools().await {
eprintln!("Error adding rust tools: {}", e);
}
if let Err(e) = self.add_testing_network_tools().await {
eprintln!("Error adding testing network tools: {}", e);
}
}

Ok(())
}

pub async fn force_reinstall_all(&self, generator: &Box<dyn EmbeddingGenerator>) -> Result<(), ToolError> {
if let Err(e) = self.add_testing_network_tools().await {
eprintln!("Error adding testing network tools: {}", e);
}
if let Err(e) = self.add_rust_tools().await {
eprintln!("Error adding rust tools: {}", e);
}
Expand All @@ -134,6 +131,9 @@ impl ToolRouter {
if let Err(e) = Self::import_tools_from_directory(self.sqlite_manager.clone()).await {
eprintln!("Error importing tools from directory: {}", e);
}
if let Err(e) = self.add_testing_network_tools().await {
eprintln!("Error adding testing network tools: {}", e);
}
Ok(())
}

Expand Down Expand Up @@ -261,7 +261,7 @@ impl ToolRouter {
.map_err(|e| ToolError::DatabaseError(e.to_string()))
}

async fn add_rust_tools(&self) -> Result<(), ToolError> {
pub async fn add_rust_tools(&self) -> Result<(), ToolError> {
let rust_tools = get_rust_tools();
for tool in rust_tools {
let rust_tool = RustTool::new(
Expand Down Expand Up @@ -296,10 +296,10 @@ impl ToolRouter {
// Manually create NetworkTool
let network_tool = NetworkTool {
name: "network__echo".to_string(),
toolkit_name: "shinkai-tool-echo".to_string(),
description: "Echoes the input message".to_string(),
version: "0.1".to_string(),
provider: ShinkaiName::new("@@agent_provider.arb-sep-shinkai".to_string()).unwrap(),
author: "@@official.shinkai".to_string(),
usage_type: usage_type.clone(),
activated: true,
config: vec![],
Expand Down Expand Up @@ -329,10 +329,10 @@ impl ToolRouter {
// Manually create another NetworkTool
let youtube_tool = NetworkTool {
name: "youtube_transcript_with_timestamps".to_string(),
toolkit_name: "shinkai-tool-youtube-transcript".to_string(),
description: "Takes a YouTube link and summarizes the content by creating multiple sections with a summary and a timestamp.".to_string(),
version: "0.1".to_string(),
provider: ShinkaiName::new("@@agent_provider.arb-sep-shinkai".to_string()).unwrap(),
author: "@@official.shinkai".to_string(),
usage_type: usage_type.clone(),
activated: true,
config: vec![],
Expand Down Expand Up @@ -554,21 +554,23 @@ impl ToolRouter {
&python_tool.oauth,
)?;

let result = python_tool.run(
envs,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id.clone(),
tool_id.clone(),
node_name,
false,
None,
Some(all_files),
).await?;
let result = python_tool
.run(
envs,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id.clone(),
tool_id.clone(),
node_name,
false,
None,
Some(all_files),
)
.await?;
let result_str = serde_json::to_string(&result)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
return Ok(ToolCallFunctionResponse {
Expand Down Expand Up @@ -612,7 +614,7 @@ impl ToolRouter {
&None,
)?;

let result = execute_custom_tool(
let result = try_to_execute_rust_tool(
&shinkai_tool.tool_router_key().to_string_without_version().clone(),
function_args,
tool_id,
Expand Down Expand Up @@ -679,21 +681,23 @@ impl ToolRouter {
&deno_tool.oauth,
)?;

let result = deno_tool.run(
envs,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id,
tool_id.clone(),
node_name,
false,
Some(tool_id),
Some(all_files),
).await?;
let result = deno_tool
.run(
envs,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id,
tool_id.clone(),
node_name,
false,
Some(tool_id),
Some(all_files),
)
.await?;

let result_str = serde_json::to_string(&result)
.map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;
Expand Down Expand Up @@ -1034,22 +1038,24 @@ impl ToolRouter {
&oauth,
)?;

let result = js_tool.run(
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id,
tool_id.clone(),
// TODO Is this correct?
requester_node_name,
true,
Some(tool_id),
None,
).await?;
let result = js_tool
.run(
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
function_args,
function_config_vec,
node_storage_path,
app_id,
tool_id.clone(),
// TODO Is this correct?
requester_node_name,
true,
Some(tool_id),
None,
)
.await?;
let result_str =
serde_json::to_string(&result).map_err(|e| LLMProviderError::FunctionExecutionError(e.to_string()))?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,9 @@ impl MyAgentOfferingsManager {
// TODO: avoid the expects
let network_tool = NetworkTool::new(
tool_header.name,
tool_header.toolkit_name,
tool_header.description,
tool_header.version,
provider.node_name.clone(),
provider,
tool_header.usage_type.expect("Usage type is required"),
true, // Assuming the tool is activated by default
Expand Down
12 changes: 10 additions & 2 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,12 @@ impl Node {
.await;
});
}
NodeCommand::InternalCheckRustToolsInstallation { res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::internal_check_rust_tools_installation(db_clone, res).await;
});
}
//
// V2 API
//
Expand Down Expand Up @@ -1566,7 +1572,8 @@ impl Node {
let db_clone = Arc::clone(&self.db);
let identity_manager_clone = self.identity_manager.clone();
tokio::spawn(async move {
let _ = Node::v2_api_search_files_by_name(db_clone, identity_manager_clone, name, bearer, res).await;
let _ =
Node::v2_api_search_files_by_name(db_clone, identity_manager_clone, name, bearer, res).await;
});
}
NodeCommand::V2ApiVecFSRetrieveVectorResource { bearer, path, res } => {
Expand Down Expand Up @@ -1860,8 +1867,9 @@ impl Node {
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_add_shinkai_tool(db_clone, bearer, shinkai_tool, res).await;
let _ = Node::v2_api_add_shinkai_tool(db_clone, bearer, node_env, shinkai_tool, res).await;
});
}
NodeCommand::V2ApiGetShinkaiTool { bearer, payload, res } => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::llm_provider::job_manager::JobManager;
use crate::managers::identity_manager::IdentityManagerTrait;
use crate::managers::tool_router::ToolRouter;
use crate::managers::IdentityManager;
use crate::network::network_manager::network_handlers::{ping_pong, PingPong};
use crate::network::node::ProxyConnectionInfo;
Expand Down Expand Up @@ -722,5 +723,14 @@ impl Node {

Ok(())
}

pub async fn internal_check_rust_tools_installation(
db: Arc<SqliteManager>,
res: Sender<Result<bool, String>>,
) -> Result<(), Error> {
let result = db.has_rust_tools().map_err(|e| e.to_string());
res.send(result).await.unwrap();
Ok(())
}
}

Loading