Skip to content

Commit

Permalink
Support splitting a single item into multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
missingdays committed Sep 24, 2024
1 parent 8402807 commit 4c26560
Show file tree
Hide file tree
Showing 14 changed files with 1,239 additions and 3 deletions.
35 changes: 32 additions & 3 deletions pilota-build/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use traits::CodegenBackend;

use self::workspace::Workspace;
use crate::rir::NodeKind;
use crate::{
db::RirDatabase,
dedup::def_id_equal,
Expand Down Expand Up @@ -447,8 +448,12 @@ where
ws.write_crates()
}

pub fn write_items(&self, stream: &mut String, items: impl Iterator<Item = CodegenItem>)
where
pub fn write_items(
&self,
stream: &mut String,
items: impl Iterator<Item = CodegenItem>,
base_dir: &Path,
) where
B: Send,
{
let mods = items.into_group_map_by(|CodegenItem { def_id, .. }| {
Expand All @@ -474,7 +479,29 @@ where
let _enter = span.enter();
let mut dup = AHashMap::default();
for def_id in def_ids.iter() {
this.write_item(&mut stream, *def_id, &mut dup)
if this.split {
let mut item_stream = String::new();
let node = this.db.node(def_id.def_id).unwrap();
let file_name = format!("{}.rs", node.name());
this.write_item(&mut item_stream, *def_id, &mut dup);

let full_path = base_dir.join(file_name.clone());
std::fs::create_dir_all(base_dir).unwrap();
let mut file =
std::io::BufWriter::new(std::fs::File::create(full_path.clone()).unwrap());
file.write_all(item_stream.as_bytes()).unwrap();
file.flush().unwrap();
fmt_file(full_path);

let base_dir_local_path = base_dir.iter().last().unwrap().to_str().unwrap();

stream.push_str(
format!("\ninclude!(\"{}/{}\");\n", base_dir_local_path, file_name)
.as_str(),
);
} else {
this.write_item(&mut stream, *def_id, &mut dup)
}
}
});

Expand Down Expand Up @@ -515,10 +542,12 @@ where
}

pub fn write_file(self, ns_name: Symbol, file_name: impl AsRef<Path>) {
let base_dir = file_name.as_ref().parent().unwrap();
let mut stream = String::default();
self.write_items(
&mut stream,
self.codegen_items.iter().map(|def_id| (*def_id).into()),
base_dir.join(ns_name.to_string()).as_path(),
);

stream = format! {r#"pub mod {ns_name} {{
Expand Down
1 change: 1 addition & 0 deletions pilota-build/src/codegen/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ where
def_id,
kind: super::CodegenKind::RePub,
})),
base_dir.as_ref(),
);
if let Some(main_mod_path) = info.main_mod_path {
gen_rs_stream.push_str(&format!(
Expand Down
13 changes: 13 additions & 0 deletions pilota-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub struct Builder<MkB, P> {
parser: P,
plugins: Vec<Box<dyn Plugin>>,
ignore_unused: bool,
split: bool,
touches: Vec<(std::path::PathBuf, Vec<String>)>,
change_case: bool,
keep_unknown_fields: Vec<std::path::PathBuf>,
Expand All @@ -103,6 +104,7 @@ impl Builder<MkThriftBackend, ThriftParser> {
dedups: Vec::default(),
special_namings: Vec::default(),
common_crate_name: "common".into(),
split: false,
}
}
}
Expand All @@ -124,6 +126,7 @@ impl Builder<MkProtobufBackend, ProtobufParser> {
dedups: Vec::default(),
special_namings: Vec::default(),
common_crate_name: "common".into(),
split: false,
}
}
}
Expand Down Expand Up @@ -152,6 +155,7 @@ impl<MkB, P> Builder<MkB, P> {
dedups: self.dedups,
special_namings: self.special_namings,
common_crate_name: self.common_crate_name,
split: self.split,
}
}

Expand All @@ -161,6 +165,11 @@ impl<MkB, P> Builder<MkB, P> {
self
}

pub fn with_split(mut self) -> Self {
self.split = true;
self
}

pub fn change_case(mut self, change_case: bool) -> Self {
self.change_case = change_case;
self
Expand Down Expand Up @@ -266,6 +275,7 @@ where
dedups: Vec<FastStr>,
special_namings: Vec<FastStr>,
common_crate_name: FastStr,
split: bool,
) -> Context {
let mut db = RootDatabase::default();
parser.inputs(services.iter().map(|s| &s.path));
Expand Down Expand Up @@ -341,6 +351,7 @@ where
dedups,
special_namings,
common_crate_name,
split,
)
}

Expand All @@ -359,6 +370,7 @@ where
self.dedups,
self.special_namings,
self.common_crate_name,
self.split,
);

cx.exec_plugin(BoxedPlugin);
Expand Down Expand Up @@ -441,6 +453,7 @@ where
self.dedups,
self.special_namings,
self.common_crate_name,
self.split,
);

std::thread::scope(|_scope| {
Expand Down
4 changes: 4 additions & 0 deletions pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub struct Context {
pub(crate) codegen_items: Arc<[DefId]>,
pub(crate) path_resolver: Arc<dyn PathResolver>,
pub(crate) mode: Arc<Mode>,
pub(crate) split: bool,
pub(crate) keep_unknown_fields: Arc<FxHashSet<DefId>>,
pub location_map: Arc<FxHashMap<DefId, DefLocation>>,
pub entry_map: Arc<HashMap<DefLocation, Vec<(DefId, DefLocation)>>>,
Expand All @@ -86,6 +87,7 @@ impl Clone for Context {
codegen_items: self.codegen_items.clone(),
path_resolver: self.path_resolver.clone(),
mode: self.mode.clone(),
split: self.split,
services: self.services.clone(),
keep_unknown_fields: self.keep_unknown_fields.clone(),
location_map: self.location_map.clone(),
Expand Down Expand Up @@ -327,6 +329,7 @@ impl ContextBuilder {
dedups: Vec<FastStr>,
special_namings: Vec<FastStr>,
common_crate_name: FastStr,
split: bool,
) -> Context {
SPECIAL_NAMINGS.get_or_init(|| special_namings);
let mut cx = Context {
Expand All @@ -341,6 +344,7 @@ impl ContextBuilder {
Mode::SingleFile { .. } => Arc::new(DefaultPathResolver),
},
mode: Arc::new(self.mode),
split,
keep_unknown_fields: Arc::new(self.keep_unknown_fields),
location_map: Arc::new(self.location_map),
entry_map: Arc::new(self.entry_map),
Expand Down
102 changes: 102 additions & 0 deletions pilota-build/src/test/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![cfg(test)]

use std::fs;
use std::path::Path;

use tempfile::tempdir;
Expand All @@ -19,6 +20,36 @@ fn diff_file(old: impl AsRef<Path>, new: impl AsRef<Path>) {
}
}

fn diff_dir(old: impl AsRef<Path>, new: impl AsRef<Path>) {
let old_files: Vec<_> = fs::read_dir(old.as_ref())
.unwrap()
.map(|res| res.unwrap().path())
.collect();
let new_files: Vec<_> = fs::read_dir(new.as_ref())
.unwrap()
.map(|res| res.unwrap().path())
.collect();

if old_files.len() != new_files.len() {
panic!(
"Number of files are different between {} and {}: {} vs {}",
old.as_ref().to_str().unwrap(),
new.as_ref().to_str().unwrap(),
old_files.len(),
new_files.len()
);
}

for old_file in old_files {
let file_name = old_file.file_name().unwrap();
let corresponding_new_file = new.as_ref().join(file_name);
if !corresponding_new_file.exists() {
panic!("File {:?} does not exist in the new directory", file_name);
}
diff_file(old_file, corresponding_new_file);
}
}

fn test_protobuf(source: impl AsRef<Path>, target: impl AsRef<Path>) {
test_with_builder(source, target, |source, target| {
crate::Builder::protobuf()
Expand Down Expand Up @@ -55,6 +86,35 @@ fn test_with_builder<F: FnOnce(&Path, &Path)>(
}
}

fn test_with_split_builder<F: FnOnce(&Path, &Path)>(
source: impl AsRef<Path>,
target: impl AsRef<Path>,
gen_dir: impl AsRef<Path>,
f: F,
) {
if std::env::var("UPDATE_TEST_DATA").as_deref() == Ok("1") {
f(source.as_ref(), target.as_ref());
} else {
let dir = tempdir().unwrap();
let path = dir.path().join(
target
.as_ref()
.file_name()
.and_then(|s| s.to_str())
.unwrap(),
);
let mut base_dir_tmp = path.clone();
base_dir_tmp.pop();
base_dir_tmp.push(path.file_stem().unwrap());
println!("{path:?}");

f(source.as_ref(), &path);
diff_file(target, path);

diff_dir(gen_dir, base_dir_tmp);
}
}

fn test_thrift(source: impl AsRef<Path>, target: impl AsRef<Path>) {
test_with_builder(source, target, |source, target| {
crate::Builder::thrift()
Expand All @@ -66,6 +126,22 @@ fn test_thrift(source: impl AsRef<Path>, target: impl AsRef<Path>) {
});
}

fn test_thrift_with_split(
source: impl AsRef<Path>,
target: impl AsRef<Path>,
gen_dir: impl AsRef<Path>,
) {
test_with_split_builder(source, target, gen_dir, |source, target| {
crate::Builder::thrift()
.ignore_unused(false)
.with_split()
.compile_with_config(
vec![IdlService::from_path(source.to_owned())],
crate::Output::File(target.into()),
)
});
}

fn test_plugin_thrift(source: impl AsRef<Path>, target: impl AsRef<Path>) {
test_with_builder(source, target, |source, target| {
crate::Builder::thrift()
Expand Down Expand Up @@ -111,6 +187,32 @@ fn test_thrift_gen() {
});
}

#[test]
fn test_thrift_gen_with_split() {
let test_data_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("test_data")
.join("thrift_with_split");

test_data_dir.read_dir().unwrap().for_each(|f| {
let f = f.unwrap();

let path = f.path();

if let Some(ext) = path.extension() {
if ext == "thrift" {
let mut rs_path = path.clone();
rs_path.set_extension("rs");

let mut gen_dir = path.clone();
gen_dir.pop();
gen_dir.push(rs_path.file_stem().unwrap());

test_thrift_with_split(path, rs_path, gen_dir.as_path());
}
}
});
}

#[test]
fn test_protobuf_gen() {
let test_data_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
Expand Down
20 changes: 20 additions & 0 deletions pilota-build/test_data/thrift_with_split/wrapper_arc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pub mod wrapper_arc {
#![allow(warnings, clippy::all)]

pub mod wrapper_arc {

include!("wrapper_arc/A.rs");

include!("wrapper_arc/TestService.rs");

include!("wrapper_arc/TestServiceTestResultRecv.rs");

include!("wrapper_arc/TestServiceTestArgsRecv.rs");

include!("wrapper_arc/TestServiceTestResultSend.rs");

include!("wrapper_arc/TEST.rs");

include!("wrapper_arc/TestServiceTestArgsSend.rs");
}
}
13 changes: 13 additions & 0 deletions pilota-build/test_data/thrift_with_split/wrapper_arc.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct A {

}

struct TEST {
1: required string ID,
2: required list<list<A>> Name2(pilota.rust_wrapper_arc="true"),
3: required map<i32, list<A>> Name3(pilota.rust_wrapper_arc="true"),
}

service TestService {
TEST(pilota.rust_wrapper_arc="true") test(1: TEST req(pilota.rust_wrapper_arc="true"));
}
Loading

0 comments on commit 4c26560

Please sign in to comment.