diff --git a/README.md b/README.md index 19ad082..ff9d97c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Sierra analyzer is a security toolkit for analyzing Sierra files. #### Decompile a Sierra file ``` -cargo run --bin sierra-decompiler +cargo run -- -f ```

@@ -27,25 +27,34 @@ cargo run --bin sierra-decompiler For a colourless output : ``` -cargo run --bin sierra-decompiler --no-color +cargo run -- -f --no-color ``` It it also possible to get a verbose output with more informations : ``` -cargo run --bin sierra-decompiler --verbose +cargo run -- -f --verbose +``` + +#### Analyze a remote contract + +Contracts can be fetched directly from Starknet (Mainnet & Sepolia) by specifying the contract class to analyze : + +``` +# Fetch & decompile a contract from starknet mainnet +cargo run -- --remote 0x07c43d18d37d66d7855dab8f21ebf9d554dd213c6307aacecaf2d595a53b3bbb + +# Fetch & decompile a contract from Sepolia network +cargo run -- --network sepolia --remote 0x068377a89d64c0b16dc97c66933777bf4e9b050652c4fde2c59c8c4d755a163b ``` #### Print the contract's Control-Flow Graph ``` -cargo run ./examples/sierra/fib_array.sierra --cfg +cargo run -- -f ./examples/sierra/fib_array.sierra --cfg # Output the Control-Flow Graph to a custom folder (default is ./output_cfg) -cargo run ./tests/sierra_files/fib_array.sierra --cfg --cfg-output ./test - -# Get the CFG of a specific function -cargo run ./examples/sierra/fib_unary.sierra --cfg --function 'examples::fib_unary::fib' +cargo run -- -f ./examples/sierra/fib_array.sierra --cfg --cfg-output ./test ```

@@ -55,13 +64,13 @@ cargo run ./examples/sierra/fib_unary.sierra --cfg --function 'examples::fib_una #### Print the contract's Callgraph ``` -cargo run ./examples/sierra/fib_array.sierra --callgraph +cargo run -- -f ./examples/sierra/fib_array.sierra --callgraph # Output the Callgraph to a custom folder (default is ./output_callgraph) -cargo run ./tests/sierra_files/fib_array.sierra --callgraph --callgraph-output ./test +cargo run -- -f ./examples/sierra/fib_array.sierra --callgraph --callgraph-output ./test # Get the Callgraph of a specific function -cargo run ./examples/sierra/fib_unary.sierra --callgraph --function 'examples::fib_unary::fib' +cargo run -- -f ./examples/sierra/fib_unary.sierra --callgraph --function 'examples::fib_unary::fib' ```

@@ -71,7 +80,7 @@ cargo run ./examples/sierra/fib_unary.sierra --callgraph --function 'examples::f #### Run the detectors ``` -cargo run ./examples/sierra/fib_array.sierra -d +cargo run -- -f ./examples/sierra/fib_array.sierra -d ```

@@ -90,4 +99,5 @@ Examples can be found [here](/lib/examples/). - [x] Control-Flow Graph - [x] Call Graph - [X] Informational & Security detectors +- [x] Fetching contracts from Starknet - [ ] Symbolic execution diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 95d653f..6c4b028 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -13,6 +13,10 @@ hex = "0.4.3" lazy_static = "1.4.0" num-bigint = "0.4.4" regex = "1.10.4" +reqwest = { version = "0.12.4", features = ["json"] } +serde = "1.0.198" +serde_json = "1.0.116" +tokio = { version="^1.37.0", features = ["full"] } [dev-dependencies] serde_json = "1.0.116" diff --git a/lib/examples/parse_starknet_contract.rs b/lib/examples/parse_starknet_contract.rs new file mode 100644 index 0000000..19850db --- /dev/null +++ b/lib/examples/parse_starknet_contract.rs @@ -0,0 +1,40 @@ +use sierra_analyzer_lib::provider::NetworkConfig; +use sierra_analyzer_lib::provider::RpcClient; +use sierra_analyzer_lib::sierra_program::SierraProgram; + +use cairo_lang_starknet_classes::contract_class::ContractClass; + +use tokio; + +#[tokio::main] +async fn main() { + let client = RpcClient::new(NetworkConfig::MAINNET_API_URL); + let contract_class = "0x01c0bb51e2ce73dc007601a1e7725453627254016c28f118251a71bbb0507fcb"; + match client.get_class(contract_class).await { + Ok(response) => { + // Convert RpcClient response to JSON content + let content = response.to_json(); + + // Deserialize JSON into a ContractClass + let program_string = serde_json::from_str::(&content) + .ok() + .and_then(|prog| prog.extract_sierra_program().ok()) + .map_or_else(|| content.clone(), |prog_sierra| prog_sierra.to_string()); + let program = SierraProgram::new(program_string); + + // Don't use the verbose output + let verbose_output = false; + + // Decompile the Sierra program + let mut decompiler = program.decompiler(verbose_output); + + // Print the decompiled program with use_color=true parameter + // You can disable colored output by passing use_color=false + let use_color = true; + println!("{}", decompiler.decompile(use_color)); + } + Err(e) => { + eprintln!("Error calling RPC: {}", e); + } + } +} diff --git a/lib/src/decompiler/cfg.rs b/lib/src/decompiler/cfg.rs index 8c1f053..0f6e771 100644 --- a/lib/src/decompiler/cfg.rs +++ b/lib/src/decompiler/cfg.rs @@ -76,6 +76,11 @@ impl<'a> ControlFlowGraph { /// Generates the CFG basic blocks pub fn generate_basic_blocks(&mut self) { + // Check if there are no statements and return early + if self.statements.is_empty() { + return; + } + // Retrieve basic blocks delimitations let (basic_blocks_starts, basic_blocks_ends) = self.get_basic_blocks_delimitations(); @@ -107,7 +112,7 @@ impl<'a> ControlFlowGraph { } // Handle conditional branches - if let Some(conditional_branch) = statement.as_conditional_branch() { + if let Some(conditional_branch) = statement.as_conditional_branch(vec![]) { if let Some(edge_2_offset) = conditional_branch.edge_2_offset { // Conditional branch with 2 edges (JNZ) current_basic_block.edges.push(Edge { diff --git a/lib/src/decompiler/decompiler.rs b/lib/src/decompiler/decompiler.rs index ffe196a..3a69409 100644 --- a/lib/src/decompiler/decompiler.rs +++ b/lib/src/decompiler/decompiler.rs @@ -14,6 +14,7 @@ use crate::decompiler::function::Function; use crate::decompiler::function::SierraStatement; use crate::decompiler::libfuncs_patterns::{IS_ZERO_REGEX, USER_DEFINED_FUNCTION_REGEX}; use crate::parse_element_name; +use crate::parse_element_name_with_fallback; use crate::sierra_program::SierraProgram; /// A struct that represents a decompiler for a Sierra program @@ -28,6 +29,10 @@ pub struct Decompiler<'a> { printed_blocks: Vec, /// The function we are currently working on current_function: Option>, + /// Names of all declared types (in order) + declared_types_names: Vec, + /// Names of all declared libfuncs (in order) + declared_libfuncs_names: Vec, /// Enable / disable the verbose output /// Some statements are not included in the regular output to improve the readability verbose: bool, @@ -41,6 +46,8 @@ impl<'a> Decompiler<'a> { indentation: 1, printed_blocks: Vec::new(), current_function: None, + declared_types_names: Vec::new(), + declared_libfuncs_names: Vec::new(), verbose: verbose, } } @@ -76,7 +83,7 @@ impl<'a> Decompiler<'a> { } /// Decompiles the type declarations - fn decompile_types(&self) -> String { + fn decompile_types(&mut self) -> String { self.sierra_program .program() .type_declarations @@ -87,7 +94,7 @@ impl<'a> Decompiler<'a> { } /// Decompiles the libfunc declarations - fn decompile_libfuncs(&self) -> String { + fn decompile_libfuncs(&mut self) -> String { self.sierra_program .program() .libfunc_declarations @@ -126,7 +133,7 @@ impl<'a> Decompiler<'a> { } /// Decompiles a single type declaration - fn decompile_type(&self, type_declaration: &TypeDeclaration) -> String { + fn decompile_type(&mut self, type_declaration: &TypeDeclaration) -> String { // Get the debug name of the type's ID let id = format!( "{}", @@ -135,10 +142,9 @@ impl<'a> Decompiler<'a> { .debug_name .as_ref() .unwrap_or(&"".into()) - ) - .yellow(); + ); - // Get the long ID of the type, which consists of the generic ID and any generic arguments + // Get the long ID of the type let long_id = &type_declaration.long_id; let generic_id = long_id.generic_id.to_string(); @@ -152,8 +158,14 @@ impl<'a> Decompiler<'a> { generic_id.clone() }; - // Retrieve the declared type information for the type, if it exists - // We don't use it in the decompiler output because it might not be readable enough + // Conditionally format id and long_id_repr + let (id_colored, long_id_repr_colored) = if id.is_empty() { + (id.yellow(), long_id_repr.yellow().to_string()) + } else { + (id.white(), long_id_repr.clone()) + }; + + // Retrieve declared type information let _declared_type_info_str = type_declaration.declared_type_info.as_ref().map_or_else( String::new, |declared_type_info| { @@ -168,18 +180,33 @@ impl<'a> Decompiler<'a> { }, ); - // Conditionally append long_id_repr in parentheses if it is different from id - let type_definition = if *long_id_repr != *id { - format!("type {} ({})", id, long_id_repr) - } else { - format!("type {}", id) + // Construct the type definition string + // If the id is not empty, format the type definition with the id and optionally the long ID representation + let type_definition = if !id.is_empty() { + let id_string = id.clone().to_string(); + self.declared_types_names.push(id_string.clone()); + format!( + "type {}{}", + id.yellow(), + if long_id_repr_colored != id_colored.to_string() { + format!(" ({})", long_id_repr_colored) + } else { + "".to_string() + } + ) + } + // If the id is empty, format the type definition with only the long ID representation + else { + let long_id_repr_string = long_id_repr.clone().to_string(); + self.declared_types_names.push(long_id_repr_string.clone()); + format!("type {}{}", long_id_repr_colored, "") }; type_definition } - /// Decompile an single libfunc - fn decompile_libfunc(&self, libfunc_declaration: &LibfuncDeclaration) -> String { + /// Decompiles an individual libfunc declaration + fn decompile_libfunc(&mut self, libfunc_declaration: &LibfuncDeclaration) -> String { // Get the debug name of the libfunc's ID let id = format!( "{}", @@ -188,10 +215,25 @@ impl<'a> Decompiler<'a> { .debug_name .as_ref() .unwrap_or(&"".into()) - ) - .blue(); + ); + + // Get the long ID of the libfunc + let long_id = &libfunc_declaration.long_id; + + // Parse kgeneric arguments + let _arguments = self.parse_arguments(&libfunc_declaration.long_id.generic_args); + + // Construct the libfunc definition string + let libfunc_definition = if id.is_empty() { + long_id.to_string() // Use long_id if id is empty + } else { + id.to_string() + }; + + self.declared_libfuncs_names + .push(libfunc_definition.clone()); // Push non-colored version to declared_libfuncs_names - format!("libfunc {}", id) + format!("libfunc {}", libfunc_definition.blue()) } /// Decompiles the functions prototypes @@ -217,8 +259,8 @@ impl<'a> Decompiler<'a> { &self, function_declaration: &GenFunction, ) -> String { - // Get the debug name of the function's ID and format it in bold - let id = format!("{}", function_declaration.id.debug_name.as_ref().unwrap()).bold(); + // Parse the function name + let id = format!("{}", parse_element_name!(function_declaration.id)).bold(); // Get the function signature, which consists of the parameter types and return types let signature = &function_declaration.signature; @@ -226,15 +268,13 @@ impl<'a> Decompiler<'a> { .param_types .iter() .map(|param_type| { - param_type - .debug_name - .as_ref() - .unwrap_or(&format!("[{}]", param_type.id).into()) - .to_string() + // We use `parse_element_name_with_fallback` and not `parse_element_name` because + // we try to match the type id with it's corresponding name if it's a remote contract + parse_element_name_with_fallback!(param_type, self.declared_types_names) }) .collect(); - // Create a list of strings representing the function parameters, + // Create a list of strings representing the function parameters // with each string formatted as ": " let param_strings: Vec = param_types .iter() @@ -262,7 +302,8 @@ impl<'a> Decompiler<'a> { let ret_type_string = if let Some(debug_name) = &ret_type.debug_name { debug_name.to_string() } else { - format!("[{}]", ret_type.id) + // Replace id with the corresponding type name + format!("[{}]", self.declared_types_names[ret_type.id as usize]) }; let ret_type_colored = ret_type_string.purple(); // Color ret_type_string in purple ret_type_colored.to_string() @@ -482,7 +523,11 @@ impl<'a> Decompiler<'a> { // Append each statement to the string block for statement in &block.statements { // If condition - if let Some(conditional_branch) = statement.as_conditional_branch() { + if let Some(conditional_branch) = + // We pass it the declared libfunc names to allow the method to reconstruct function calls + // For remote contracts + statement.as_conditional_branch(self.declared_libfuncs_names.clone()) + { if block.edges.len() == 2 { let function_name = &conditional_branch.function; let function_arguments = conditional_branch.parameters.join(", "); @@ -494,7 +539,11 @@ impl<'a> Decompiler<'a> { } } // Unconditional jump - else if let Some(_unconditional_branch) = statement.as_conditional_branch() { + else if let Some(_unconditional_branch) = + // We pass it the declared libfunc names to allow the method to reconstruct function calls + // For remote contracts + statement.as_conditional_branch(self.declared_libfuncs_names.clone()) + { // Handle unconditional branch logic todo!() } @@ -502,7 +551,12 @@ impl<'a> Decompiler<'a> { else { // Add the formatted statements to the block // Some statements are only included in the verbose output - if let Some(formatted_statement) = statement.formatted_statement(self.verbose) { + // + // We pass it the declared libfunc names to allow the method to reconstruct function calls + // For remote contracts + if let Some(formatted_statement) = statement + .formatted_statement(self.verbose, self.declared_libfuncs_names.clone()) + { decompiled_basic_block += &format!("{}{}\n", indentation, formatted_statement); } } diff --git a/lib/src/decompiler/function.rs b/lib/src/decompiler/function.rs index b278e46..a7a2547 100644 --- a/lib/src/decompiler/function.rs +++ b/lib/src/decompiler/function.rs @@ -16,6 +16,7 @@ use crate::decompiler::libfuncs_patterns::{ use crate::decompiler::utils::decode_hex_bigint; use crate::extract_parameters; use crate::parse_element_name; +use crate::parse_element_name_with_fallback; /// A struct representing a statement #[derive(Debug, Clone)] @@ -50,7 +51,11 @@ impl SierraStatement { /// Formats the statement as a string /// We try to format them in a way that is as similar as possible to the Cairo syntax - pub fn formatted_statement(&self, verbose: bool) -> Option { + pub fn formatted_statement( + &self, + verbose: bool, + declared_libfuncs_names: Vec, + ) -> Option { match &self.statement { // Return statements GenStatement::Return(vars) => { @@ -67,7 +72,14 @@ impl SierraStatement { } // Invocation statements GenStatement::Invocation(invocation) => { - let libfunc_id = parse_element_name!(invocation.libfunc_id); + // Try to get the debug name of the libfunc_id + // We use `parse_element_name_with_fallback` and not `parse_element_name` because + // we try to match the libfunc id with it's corresponding name if it's a remote contract + let libfunc_id = parse_element_name_with_fallback!( + invocation.libfunc_id, + declared_libfuncs_names + ); + if !Self::is_function_allowed(&libfunc_id, verbose) { return None; // Skip formatting if function is not allowed } @@ -283,14 +295,31 @@ impl SierraStatement { } /// Returns a reference to this statement as a conditional branch if it is one - pub fn as_conditional_branch(&self) -> Option { + pub fn as_conditional_branch( + &self, + declared_libfuncs_names: Vec, + ) -> Option { if self.is_conditional_branch { if let GenStatement::Invocation(invocation) = &self.statement { // Statement let statement = self.statement.clone(); // Function name - let libfunc_id_str = parse_element_name!(invocation.libfunc_id); + let libfunc_id_str = invocation + .libfunc_id + .debug_name + .as_ref() + .map(|name| name.to_string()) + // If the debug name is not present, try to get the name from declared_libfuncs_names + .or_else(|| { + declared_libfuncs_names + .get(invocation.libfunc_id.id as usize) + .map(|name| name.to_string()) + // If neither the debug name nor the name from declared_libfuncs_names is present, + // format the id as a string + .or_else(|| Some(format!("[{}]", invocation.libfunc_id.id))) + }) + .unwrap(); // Parameters let parameters = extract_parameters!(invocation.args); diff --git a/lib/src/decompiler/macros.rs b/lib/src/decompiler/macros.rs index 6a9053c..4230d75 100644 --- a/lib/src/decompiler/macros.rs +++ b/lib/src/decompiler/macros.rs @@ -2,15 +2,35 @@ /// using the debug_name if present or falling back to the ID field #[macro_export] macro_rules! parse_element_name { - ($libfunc_id:expr) => { - if let Some(debug_name) = &$libfunc_id.debug_name { + ($element_id:expr) => { + if let Some(debug_name) = &$element_id.debug_name { debug_name.to_string() } else { - $libfunc_id.id.to_string() + $element_id.id.to_string() } }; } +/// Macro to parse the debug name or get the name from a provided map, +/// or fallback to the ID. This is used to match the element ID with +/// its corresponding libfunc or type name +#[macro_export] +macro_rules! parse_element_name_with_fallback { + ($element_id:expr, $fallback_map:expr) => { + $element_id + .debug_name + .as_ref() + .map(|name| name.to_string()) + .or_else(|| { + $fallback_map + .get($element_id.id as usize) + .map(|name| name.to_string()) + .or_else(|| Some(format!("[{}]", $element_id.id))) + }) + .unwrap() + }; +} + /// Macro to extract parameters from the args field of a GenInvocation object. /// It converts each parameter into a String, using the debug_name if available, /// otherwise using the ID field diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 9e6894a..e95f382 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -2,4 +2,5 @@ mod config; mod decompiler; pub mod detectors; pub mod graph; +pub mod provider; pub mod sierra_program; diff --git a/lib/src/provider.rs b/lib/src/provider.rs new file mode 100644 index 0000000..45b615d --- /dev/null +++ b/lib/src/provider.rs @@ -0,0 +1,109 @@ +use reqwest::{Client, Error}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Configuration for endpoints +/// TODO : Add other endpoints in case of failure +pub struct NetworkConfig; + +impl NetworkConfig { + /// URL for the mainnet API + pub const MAINNET_API_URL: &'static str = + "https://starknet-mainnet.public.blastapi.io/rpc/v0_6"; + + /// URL for the Sepolia API + pub const SEPOLIA_API_URL: &'static str = + "https://starknet-sepolia.public.blastapi.io/rpc/v0_6"; +} + +/// Struct representing an RPC client +pub struct RpcClient<'a> { + /// The client used for sending requests + client: Client, + /// The address of the RPC node + node_address: &'a str, +} + +impl<'a> RpcClient<'a> { + pub fn new(node_address: &'a str) -> Self { + RpcClient { + client: Client::new(), + node_address, + } + } + + /// Sends a starknet_getClass request to the RPC node. + pub async fn get_class(&self, contract_class: &str) -> Result { + let url = format!("{}", self.node_address); + let request_body = serde_json::json!({ + "id": 1, + "jsonrpc": "2.0", + "method": "starknet_getClass", + "params": ["pending", contract_class], + }); + + let response = self + .client + .post(&url) + .json(&request_body) + .send() + .await? + .json::() + .await?; + + Ok(response) + } +} + +/// The expected response structure from the starknet_getClass RPC call +#[derive(Deserialize, Serialize, Debug)] +pub struct RpcResponse { + pub result: Value, +} + +impl RpcResponse { + /// Returns the response JSON + pub fn to_json(&self) -> String { + // Serialize the RpcResponse into a JSON string + let json_string = serde_json::to_string_pretty(&self.result) + .unwrap_or_else(|e| format!("Error serializing JSON: {}", e)); + + // Parse the JSON string into a serde_json::Value + let value: Value = serde_json::from_str(&json_string) + .unwrap_or_else(|e| panic!("Error parsing JSON: {}", e)); + + // Clean the ABI field + let clean_abi = value["abi"] + .as_str() + .unwrap_or_else(|| panic!("Missing ABI field")) + .replace(r#"\""#, "") // Remove escaped quotes + .trim_matches('"') // Trim surrounding quotes + .to_string(); + + // Extract other fields + let sierra_program = &value["sierra_program"].to_string(); + let sierra_program_debug_info = &value["sierra_program_debug_info"].to_string(); + let contract_class_version = &value["contract_class_version"].to_string(); + let entry_points_by_type = &value["entry_points_by_type"].to_string(); + + // Construct the formatted JSON string + let sierra_json = format!( + r#"{{ + "abi": {}, + "sierra_program": {}, + "sierra_program_debug_info": {}, + "contract_class_version": {}, + "entry_points_by_type": {} +}}"#, + clean_abi, + sierra_program, + sierra_program_debug_info, + contract_class_version, + entry_points_by_type + ) + .trim() // Trim leading and trailing whitespaces + .to_string(); + + sierra_json + } +} diff --git a/sierra-decompiler/Cargo.toml b/sierra-decompiler/Cargo.toml index f03e0a0..12543d3 100644 --- a/sierra-decompiler/Cargo.toml +++ b/sierra-decompiler/Cargo.toml @@ -11,3 +11,4 @@ cairo-lang-starknet-classes = "~2.6.3" clap = { version = "4.0.0-rc.1", features = [ "derive" ] } serde_json = "1.0.116" sierra-analyzer-lib = { path = "../lib" } +tokio = "1.37.0" diff --git a/sierra-decompiler/src/main.rs b/sierra-decompiler/src/main.rs index 9a4b9dc..c6dc99a 100644 --- a/sierra-decompiler/src/main.rs +++ b/sierra-decompiler/src/main.rs @@ -1,22 +1,27 @@ use std::fs; use std::fs::File; use std::io::Read; +use std::path::PathBuf; use clap::Parser; use serde_json; -use std::path::PathBuf; +use tokio; use cairo_lang_starknet_classes::contract_class::ContractClass; use sierra_analyzer_lib::detectors::get_detectors; use sierra_analyzer_lib::graph::graph::save_svg_graph_to_file; +use sierra_analyzer_lib::provider::NetworkConfig; +use sierra_analyzer_lib::provider::RpcClient; use sierra_analyzer_lib::sierra_program; +use sierra_analyzer_lib::sierra_program::SierraProgram; /// Decompile a Sierra program #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { /// Sierra program file - sierra_file: PathBuf, + #[clap(short = 'f', long)] + sierra_file: Option, /// Function name to only get one function for CFG & Callgraph #[clap(long)] @@ -49,29 +54,82 @@ struct Args { /// Run the detectors #[clap(short, long)] detectors: bool, + + /// Remote contract class address + #[clap(long, default_value = "")] + remote: String, + + /// Network type (Mainnet & Sepolia are supported) + #[clap(long, default_value = "mainnet")] + network: String, } -fn main() { +#[tokio::main] +async fn main() { let args = Args::parse(); - // Read input file - let mut file = File::open(&args.sierra_file).expect("Failed to open file"); - let mut content = String::new(); - file.read_to_string(&mut content) - .expect("Failed to read file"); + if args.remote.is_empty() && args.sierra_file.is_none() { + eprintln!("Error: Either remote or Sierra file must be provided"); + return; + } - // Deserialize JSON into a ContractClass, or use the content directly if that fails - let program_string = serde_json::from_str::(&content) - .ok() - .and_then(|prog| prog.extract_sierra_program().ok()) - .map_or_else(|| content.clone(), |prog_sierra| prog_sierra.to_string()); - let program = sierra_program::SierraProgram::new(program_string); + // Define program and sierra_file before the if statement + let program: SierraProgram; + let mut sierra_file: Option = None; + + // Analyze a contract deployed on Starknet + if !args.remote.is_empty() { + // Define the client based on the network parameter + let client = match args.network.as_str() { + "mainnet" => RpcClient::new(NetworkConfig::MAINNET_API_URL), + "sepolia" => RpcClient::new(NetworkConfig::SEPOLIA_API_URL), + _ => { + eprintln!("Error: Unsupported network type '{}'", args.network); + return; + } + }; + + // Fetch contract class from the RPC Node + match client.get_class(&args.remote).await { + Ok(response) => { + // Convert RpcClient response to JSON content + let content = response.to_json(); + + // Deserialize JSON into a ContractClass + let program_string = serde_json::from_str::(&content) + .ok() + .and_then(|prog| prog.extract_sierra_program().ok()) + .map_or_else(|| content.clone(), |prog_sierra| prog_sierra.to_string()); + program = SierraProgram::new(program_string); + } + Err(e) => { + eprintln!("Error calling RPC: {}", e); + // Stop the program if there is an error in the RPC response + return; + } + } + } + // Analyze a local file + else { + sierra_file = args.sierra_file; + let mut file = File::open(sierra_file.as_ref().unwrap()).expect("Failed to open file"); + let mut content = String::new(); + file.read_to_string(&mut content) + .expect("Failed to read file"); + + // Deserialize JSON into a ContractClass, or use the content directly if that fails + let program_string = serde_json::from_str::(&content) + .ok() + .and_then(|prog| prog.extract_sierra_program().ok()) + .map_or_else(|| content.clone(), |prog_sierra| prog_sierra.to_string()); + program = sierra_program::SierraProgram::new(program_string); + } // Color output by default and if CFG or Callgraph is not enabled to avoid bugs in the SVG output let colored_output = !args.no_color ^ (args.cfg | args.callgraph); + // Now you can use program and sierra_file outside the if and else blocks let mut decompiler = program.decompiler(args.verbose); - let decompiled_code = decompiler.decompile(colored_output); // Filter functions if a specific function name is given @@ -79,14 +137,20 @@ fn main() { decompiler.filter_functions(function_name); } - if args.cfg { - // Determine the full path for the output file - let file_stem = args - .sierra_file + // Determine the file stem based on the remote address or the sierra_file + let file_stem = if !args.remote.is_empty() { + args.remote.clone() + } else { + sierra_file + .as_ref() + .unwrap() .file_stem() .unwrap_or_default() .to_string_lossy() - .to_string(); + .to_string() + }; + + if args.cfg { let svg_filename = format!("{}_cfg.svg", file_stem); let full_path = args.cfg_output.join(svg_filename); @@ -105,13 +169,6 @@ fn main() { save_svg_graph_to_file(full_path.to_str().unwrap(), cfg_graph) .expect("Failed to save CFG to SVG"); } else if args.callgraph { - // Determine the full path for the output file - let file_stem = args - .sierra_file - .file_stem() - .unwrap_or_default() - .to_string_lossy() - .to_string(); let svg_filename = format!("{}_callgraph.svg", file_stem); let full_path = args.callgraph_output.join(svg_filename);