Skip to content

Commit

Permalink
support workspace user gen code outside
Browse files Browse the repository at this point in the history
  • Loading branch information
Millione committed Oct 7, 2023
1 parent 9f3823b commit 76df1d9
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pilota-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pilota-build"
version = "0.8.5"
version = "0.8.6"
edition = "2021"
description = "Compile thrift and protobuf idl into rust code at compile-time."
documentation = "https://docs.rs/pilota-build"
Expand Down
18 changes: 16 additions & 2 deletions pilota-build/src/codegen/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct CrateInfo {
workspace_deps: Vec<FastStr>,
items: Vec<DefId>,
re_pubs: Vec<DefId>,
user_gen: Option<String>,
}

impl<B> Workspace<B>
Expand Down Expand Up @@ -147,6 +148,7 @@ where
.sorted()
.dedup()
.collect_vec(),
user_gen: this.cx().plugin_gen.get(k).map(|v| v.value().clone()),
},
)
})?;
Expand Down Expand Up @@ -275,8 +277,10 @@ where

stream.push_str("#![feature(impl_trait_in_assoc_type)]\n");

let mut out_stream = String::default();

self.cg.write_items(
&mut stream,
&mut out_stream,
info.items
.iter()
.map(|def_id| CodegenItem::from(*def_id))
Expand All @@ -287,15 +291,25 @@ where
);

if let Some(main_mod_path) = info.main_mod_path {
stream.push_str(&format!("pub use {}::*;", main_mod_path.join("::")));
out_stream.push_str(&format!("pub use {}::*;", main_mod_path.join("::")));
}

stream.push_str("include!(\"gen.rs\");\n");
if let Some(user_gen) = info.user_gen {
stream.push_str(&user_gen);
}

let out_stream = out_stream.lines().map(|s| s.trim_end()).join("\n");
let stream = stream.lines().map(|s| s.trim_end()).join("\n");

let src_file = base_dir.as_ref().join(&*info.name).join("src/lib.rs");
let out_file = base_dir.as_ref().join(&*info.name).join("src/gen.rs");

std::fs::write(&src_file, stream)?;
std::fs::write(&out_file, out_stream)?;

fmt_file(src_file);
fmt_file(out_file);

Ok(())
}
Expand Down
23 changes: 20 additions & 3 deletions pilota-build/src/middle/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{ops::Deref, path::PathBuf, sync::Arc};
use std::{collections::HashMap, ops::Deref, path::PathBuf, sync::Arc};

use anyhow::Context as _;
use dashmap::DashMap;
Expand Down Expand Up @@ -31,7 +31,7 @@ pub struct CrateId {
}

#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone)]
pub(crate) enum DefLocation {
pub enum DefLocation {
Fixed(CrateId, ItemPath),
Dynamic,
}
Expand Down Expand Up @@ -65,6 +65,9 @@ pub struct Context {
pub(crate) path_resolver: Arc<dyn PathResolver>,
pub(crate) mode: Arc<Mode>,
pub(crate) keep_unknown_fields: FxHashSet<DefId>,
pub location_map: FxHashMap<DefId, DefLocation>,
pub entry_map: HashMap<DefLocation, Vec<(DefId, DefLocation)>>,
pub plugin_gen: DashMap<DefLocation, String>,
}

impl Clone for Context {
Expand All @@ -79,6 +82,9 @@ impl Clone for Context {
mode: self.mode.clone(),
services: self.services.clone(),
keep_unknown_fields: self.keep_unknown_fields.clone(),
location_map: self.location_map.clone(),
entry_map: self.entry_map.clone(),
plugin_gen: self.plugin_gen.clone(),
}
}
}
Expand All @@ -89,6 +95,8 @@ pub(crate) struct ContextBuilder {
input_items: Vec<DefId>,
mode: Mode,
keep_unknown_fields: FxHashSet<DefId>,
pub location_map: FxHashMap<DefId, DefLocation>,
entry_map: HashMap<DefLocation, Vec<(DefId, DefLocation)>>,
}

impl ContextBuilder {
Expand All @@ -99,6 +107,8 @@ impl ContextBuilder {
input_items,
codegen_items: Default::default(),
keep_unknown_fields: Default::default(),
location_map: Default::default(),
entry_map: Default::default(),
}
}
pub(crate) fn collect(&mut self, mode: CollectMode) {
Expand Down Expand Up @@ -159,7 +169,11 @@ impl ContextBuilder {
}
if matches!(self.mode, Mode::Workspace(_)) {
let location_map = self.workspace_collect_def_ids(&self.codegen_items);

self.location_map = location_map.clone();
self.entry_map = location_map
.clone()
.into_iter()
.into_group_map_by(|item| item.1.clone());
if let Mode::Workspace(info) = &mut self.mode {
info.location_map = location_map
}
Expand Down Expand Up @@ -389,6 +403,9 @@ impl ContextBuilder {
},
mode: Arc::new(self.mode),
keep_unknown_fields: self.keep_unknown_fields,
location_map: self.location_map,
entry_map: self.entry_map,
plugin_gen: Default::default(),
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions pilota-build/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
};

mod serde;
mod workspace;

pub use self::serde::SerdePlugin;

Expand Down Expand Up @@ -303,6 +304,10 @@ impl<T> Plugin for Box<T>
where
T: Plugin + ?Sized,
{
fn on_codegen_uint(&mut self, cx: &Context, items: &[DefId]) {
self.deref_mut().on_codegen_uint(cx, items)
}

fn on_item(&mut self, cx: &Context, def_id: DefId, item: Arc<Item>) {
self.deref_mut().on_item(cx, def_id, item)
}
Expand Down
45 changes: 45 additions & 0 deletions pilota-build/src/plugin/workspace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::{
db::RirDatabase,
middle::context::tls::CUR_ITEM,
rir::{Item, NodeKind},
Plugin,
};

#[derive(Clone, Copy)]
pub struct _WorkspacePlugin;

impl Plugin for _WorkspacePlugin {
fn on_codegen_uint(&mut self, cx: &crate::Context, _items: &[crate::DefId]) {
cx.entry_map.iter().for_each(|(k, v)| {
cx.plugin_gen.insert(k.clone(), "".to_string());
v.iter().for_each(|(def_id, _)| {
CUR_ITEM.set(def_id, || {
let node = cx.node(*def_id).unwrap();

match &node.kind {
NodeKind::Item(item) => self.on_item(cx, *def_id, item.clone()),
_ => {}
}
});
})
});
}

fn on_item(
&mut self,
cx: &crate::Context,
def_id: crate::DefId,
item: std::sync::Arc<crate::rir::Item>,
) {
match &*item {
Item::Service(s) => {
if let Some(loc) = cx.location_map.get(&def_id) {
if let Some(mut gen) = cx.plugin_gen.get_mut(loc) {
gen.push_str(&format!("pub struct {};", s.name.sym));
}
};
}
_ => {}
}
}
}

0 comments on commit 76df1d9

Please sign in to comment.