diff --git a/.github/workflows/cargo-test.yml b/.github/workflows/cargo-test.yml index 44a3764..42021b1 100644 --- a/.github/workflows/cargo-test.yml +++ b/.github/workflows/cargo-test.yml @@ -18,6 +18,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: '0' + ref: ${{ github.event.pull_request.head.ref }} - name: Cache Cargo dependencies uses: actions/cache@v2 with: @@ -52,7 +53,7 @@ jobs: cp ./target/debug/gossiphs ./gossiphs cd tree-sitter - time ../gossiphs relation + time ../gossiphs relation --symbol-csv test-symbol.csv cd .. cd typescript-eslint diff --git a/src/graph.rs b/src/graph.rs index 9e97807..4e1b0f8 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,5 +1,5 @@ use crate::extractor::Extractor; -use crate::symbol::{Symbol, SymbolGraph, SymbolKind}; +use crate::symbol::{DefRefPair, Symbol, SymbolGraph, SymbolKind}; use cupido::collector::config::Collect; use cupido::collector::config::{get_collector, Config}; use cupido::relation::graph::RelationGraph as CupidoRelationGraph; @@ -12,7 +12,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fs; use std::path::Path; use std::time::Instant; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; pub struct FileContext { pub path: String, @@ -491,36 +491,6 @@ impl Graph { contexts } - pub fn symbols_between_files(&self, src: &String, dst: &String) -> Vec { - if !self.files().contains(src) || !self.files().contains(dst) { - return Vec::new(); - } - - let mut related_symbols: Vec = vec![]; - - // other files -> this file - let definitions_in_file = self.symbol_graph.list_definitions(src); - let definition_count = definitions_in_file.len(); - - definitions_in_file.iter().for_each(|def| { - self.symbol_graph - .list_references_by_definition(&def.id()) - .iter() - .filter(|(each, _)| { - return each.file.eq(dst); - }) - .for_each(|(each_ref, weight)| { - let real_weight = std::cmp::max(weight / definition_count, 1); - related_symbols.push(RelatedSymbol { - symbol: each_ref.clone(), - weight: real_weight, - }) - }); - }); - - related_symbols - } - pub fn related_symbols(&self, symbol: &Symbol) -> HashMap { match symbol.kind { SymbolKind::DEF => self @@ -545,6 +515,13 @@ impl Graph { .collect(); FileMetadata { symbols } } + + pub fn pairs_between_files(&self, src_file: &String, dst_file: &String) -> Vec { + if !self.files().contains(src_file) || !self.files().contains(dst_file) { + return Vec::new(); + } + self.symbol_graph.pairs_between_files(src_file, dst_file) + } } #[derive(Serialize, Deserialize, Clone)] @@ -607,6 +584,7 @@ impl GraphConfig { #[cfg(test)] mod tests { use crate::graph::{Graph, GraphConfig}; + use crate::symbol::DefRefPair; use petgraph::visit::EdgeRef; use tracing::{debug, info}; @@ -688,16 +666,25 @@ mod tests { } #[test] - fn between_files() { + fn paths() { + tracing_subscriber::fmt::init(); let mut config = GraphConfig::default(); config.project_path = String::from("."); let g = Graph::from(config); - let symbols = g.symbols_between_files( - &String::from("src/rule.rs"), + let symbols: Vec = g.pairs_between_files( &String::from("src/extractor.rs"), + &String::from("src/graph.rs"), ); - symbols.iter().for_each(|item| { - info!("{:?}: {}", item.symbol, item.weight); + symbols.iter().for_each(|pair| { + info!( + "{} {} {} -> {} {} {}", + pair.src_symbol.file, + pair.src_symbol.name, + pair.src_symbol.range.start_point.row, + pair.dst_symbol.file, + pair.dst_symbol.name, + pair.dst_symbol.range.start_point.row + ); }); } } diff --git a/src/main.rs b/src/main.rs index 0cf530a..390cf7f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use gossiphs::graph::{Graph, GraphConfig, RelatedFileContext}; use gossiphs::server::{server_main, ServerConfig}; use inquire::Text; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs; use std::fs::File; use std::io::Write; @@ -104,6 +104,10 @@ struct RelationCommand { #[clap(long)] #[clap(default_value = "output.csv")] csv: String, + + #[clap(long)] + #[clap(default_value = "")] + symbol_csv: String, } #[derive(Parser, Debug)] @@ -245,7 +249,6 @@ fn handle_relation(relation_cmd: RelationCommand) { Ok(writer) => writer, Err(e) => panic!("Failed to create CSV writer: {}", e), }; - // Write the header row let mut header = vec!["".to_string()]; header.extend(files.clone()); @@ -253,23 +256,57 @@ fn handle_relation(relation_cmd: RelationCommand) { panic!("Failed to write CSV header: {}", e); } + let mut symbol_wtr_opts = None; + if !relation_cmd.symbol_csv.is_empty() { + let symbol_wtr_result = Writer::from_path(relation_cmd.symbol_csv); + symbol_wtr_opts = match symbol_wtr_result { + Ok(writer) => Some(writer), + Err(e) => panic!("Failed to create CSV writer: {}", e), + }; + let mut header = vec!["".to_string()]; + header.extend(files.clone()); + if let Some(symbol_wtr) = symbol_wtr_opts.as_mut() { + symbol_wtr + .write_record(&header) + .expect("Failed to write header to symbol_wtr"); + } + } + // Write each row for file in &files { let mut row = vec![file.clone()]; - let related_files = g.related_files(file); - let related_files_map: HashMap<_, _> = related_files + let mut pair_row = vec![file.clone()]; + let related_files_map: HashMap<_, _> = g + .related_files(file) .into_iter() .map(|rf| (rf.name, rf.score)) .collect(); for related_file in &files { - let score = related_files_map - .get(related_file) - .unwrap_or(&0) - .to_string(); - row.push(score); + let score = related_files_map.get(related_file).unwrap_or(&0); + row.push(score.to_string()); + + if symbol_wtr_opts.is_some() { + if score > &0 { + let pairs = g + .pairs_between_files(&file, &related_file) + .iter() + .map(|each| each.src_symbol.name.clone()) + .collect::>() + .into_iter() + .collect::>(); + pair_row.push(pairs.join("|")); + } else { + pair_row.push(String::new()); + } + } } wtr.write_record(&row).expect("Failed to write record"); + if let Some(symbol_wtr) = symbol_wtr_opts.as_mut() { + symbol_wtr + .write_record(&pair_row) + .expect("Failed to write pair_row to symbol_wtr"); + } } // Flush the writer to ensure all data is written @@ -665,8 +702,11 @@ fn diff_test() { #[test] fn relation_test() { + let mut config = CommonOptions::default(); + config.project_path = ".".parse().unwrap(); handle_relation(RelationCommand { - common_options: CommonOptions::default(), + common_options: config, csv: "ok.csv".to_string(), + symbol_csv: "ok1.csv".to_string(), }) } diff --git a/src/rule.rs b/src/rule.rs index e6c0a9f..620836e 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -16,9 +16,24 @@ pub fn get_rule(extractor_type: &Extractor) -> Rule { Extractor::Rust => Rule { import_grammar: r#" (identifier) @variable_name +(call_expression + function: (identifier) @function) +(call_expression + function: (field_expression + field: (field_identifier) @function.method)) +(call_expression + function: (scoped_identifier + "::" + name: (identifier) @function)) "#, export_grammar: r#" (function_item name: (identifier) @exported_symbol) +(function_signature_item name: (identifier) @exported_symbol) +(generic_function + function: (identifier) @exported_symbol) +(generic_function + function: (scoped_identifier + name: (identifier) @exported_symbol)) "#, }, diff --git a/src/symbol.rs b/src/symbol.rs index 05aa02c..c0011b0 100644 --- a/src/symbol.rs +++ b/src/symbol.rs @@ -1,3 +1,4 @@ +use petgraph::algo::all_simple_paths; use petgraph::graph::{NodeIndex, UnGraph}; use petgraph::prelude::EdgeRef; use serde::{Deserialize, Serialize}; @@ -248,4 +249,32 @@ impl SymbolGraph { let ref_index = self.symbol_mapping.get(symbol_id).unwrap(); self.neighbor_symbols(*ref_index) } + + pub fn pairs_between_files(&self, src_file: &String, dst_file: &String) -> Vec { + if let (Some(src_index), Some(dst_index)) = ( + self.file_mapping.get(src_file), + self.file_mapping.get(dst_file), + ) { + // file -> symbol -> symbol -> file + // so at most 2 + let pairs: Vec<_> = + all_simple_paths::, _>(&self.g, *src_index, *dst_index, 1, Some(2)) + .filter(|each| each.len() == 4) + .map(|each| DefRefPair { + src_symbol: self.g[each[1]].get_symbol().unwrap().clone(), + dst_symbol: self.g[each[2]].get_symbol().unwrap().clone(), + }) + .filter(|each| each.src_symbol.kind == SymbolKind::DEF) + .collect(); + return pairs; + } + + // fallback + vec![] + } +} + +pub struct DefRefPair { + pub src_symbol: Symbol, + pub dst_symbol: Symbol, }