Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print the egglog program before the last schedule for multi-pass schedule in --run-mode egglog #675

Merged
merged 8 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use indexmap::IndexMap;
use interpreter::Value;
use schedule::rulesets;
use schema::TreeProgram;
use std::{cmp::min, fmt::Write, usize};
use std::{cmp::min, fmt::Write, i64};
use to_egglog::TreeToEgglog;

use crate::{
Expand Down Expand Up @@ -256,7 +256,10 @@ pub enum Schedule {
#[derive(Clone, Debug)]
pub struct EggccConfig {
pub schedule: Schedule,
pub stop_after_n_passes: usize,
/// Stop after this many passes.
/// If stop_after_n_passes is negative,
/// run [0 ... schedule.len() + stop_after_n_passes + 1] passes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems weird... I would expect -1 to run all but the last pass, but it looks like it runs all of them.
Perhaps you should make it i64::MAX by default and use [0 .. scheduler.len() + stop_after_n_passes ] in the neg case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code to reflect this, although it is a little less elegant now to handle the pass before i64::MAX

pub stop_after_n_passes: i64,
/// For debugging, disable extraction with linearity
/// and just return the first program found.
/// This produces unsound results but is useful for seeing the intermediate extracted result.
Expand All @@ -267,7 +270,7 @@ impl Default for EggccConfig {
fn default() -> Self {
Self {
schedule: Schedule::default(),
stop_after_n_passes: usize::MAX,
stop_after_n_passes: i64::MAX,
linearity: true,
}
}
Expand All @@ -285,15 +288,14 @@ pub fn optimize(
};
let mut res = program.clone();

for (schedule, i) in schedule_list
.iter()
.zip(0..eggcc_config.stop_after_n_passes)
{
let stop_after_n_passes = if eggcc_config.stop_after_n_passes < 0 {
schedule_list.len() as i64 + eggcc_config.stop_after_n_passes
} else {
eggcc_config.stop_after_n_passes
};
for (schedule, i) in schedule_list.iter().zip(0..stop_after_n_passes) {
let mut should_maintain_linearity = true;
if i == min(
eggcc_config.stop_after_n_passes - 1,
schedule_list.len() - 1,
) {
if i == min(stop_after_n_passes - 1, schedule_list.len() as i64 - 1) {
should_maintain_linearity = eggcc_config.linearity;
}

Expand Down
2 changes: 0 additions & 2 deletions dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ pub fn parallel_schedule() -> Vec<CompilerPass> {
)),
CompilerPass::InlineWithSchedule(format!(
"
;; HACK: when INLINE appears in this string
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

;; we perform inlining in this pass
(run-schedule
(saturate
{helpers}
Expand Down
14 changes: 11 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use clap::Parser;
use dag_in_context::{EggccConfig, Schedule};
use eggcc::util::{visualize, InterpMode, LLVMOptLevel, Run, RunMode, TestProgram};
use std::{ffi::OsStr, path::PathBuf};
use std::{ffi::OsStr, i64, path::PathBuf};

#[derive(Debug, Parser)]
struct Args {
Expand Down Expand Up @@ -49,8 +49,16 @@ struct Args {
/// For the eggcc schedule, choose between the sequential and parallel schedules.
#[clap(long)]
eggcc_schedule: Option<Schedule>,
/// Eggcc by default performs several passes.
/// This argument specifies how many passes to run (all passes by default).
/// If stop_after_n_passes is negative,
/// run [0 ... schedule.len() + stop_after_n_passes] passes.
///
/// This flag also works with `--run-mode egglog` mode,
/// where it prints the egglog program being processed by the last pass
/// this flag specifies.
#[clap(long)]
stop_after_n_passes: Option<usize>,
stop_after_n_passes: Option<i64>,

/// Turn off enforcement that the output program uses
/// memory linearly. This can give an idea of what
Expand Down Expand Up @@ -105,7 +113,7 @@ fn main() {
add_timing: args.add_timing,
eggcc_config: EggccConfig {
schedule: args.eggcc_schedule.unwrap_or(Schedule::default()),
stop_after_n_passes: args.stop_after_n_passes.unwrap_or(usize::MAX),
stop_after_n_passes: args.stop_after_n_passes.unwrap_or(i64::MAX),
linearity: !args.no_linearity,
},
};
Expand Down
39 changes: 24 additions & 15 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{EggCCError, Optimizer};
use bril_rs::Program;
use clap::ValueEnum;
use dag_in_context::dag2svg::tree_to_svg;
use dag_in_context::schedule::parallel_schedule;
use dag_in_context::schedule::{self, parallel_schedule};
use dag_in_context::{build_program, check_roundtrip_egraph, EggccConfig, Schedule};

use dag_in_context::schema::TreeProgram;
Expand Down Expand Up @@ -802,27 +802,36 @@ impl Run {
)
}
RunMode::Egglog => {
assert_eq!(self.eggcc_config.schedule, Schedule::Parallel, "Output egglog only works in parallel mode. Sequential mode does not use a single egraph");

let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?;
let (dag, mut cache) = rvsdg.to_dag_encoding(true);

let schedule_steps = parallel_schedule();
if schedule_steps.len() != 1 {
log::warn!("Parallel schedule had multiple steps! You may need to adjust the schedule to make eggcc tractable.");
}
// to deal with i64::MAX
let stop_after_n_passes = i64::min(
self.eggcc_config.stop_after_n_passes,
parallel_schedule().len() as i64,
);
let eggcc_config = EggccConfig {
yihozhang marked this conversation as resolved.
Show resolved Hide resolved
// stop before the last pass that user specified.
stop_after_n_passes: stop_after_n_passes - 1,
..self.eggcc_config.clone()
};
let optimized = dag_in_context::optimize(&dag, &mut cache, &eggcc_config)
yihozhang marked this conversation as resolved.
Show resolved Hide resolved
.map_err(EggCCError::EggLog)?;

let schedules = parallel_schedule();
let last_schedule_step = schedules.last().unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're still using the wrong pass here- didn't you want the stop_after_n_passes pass?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, fixed!


let inline_program = match last_schedule_step {
schedule::CompilerPass::Schedule(_) => None,
schedule::CompilerPass::InlineWithSchedule(_) => Some(&optimized),
};

// TODO make the egglog run mode use intermediate egglog files instead of sticking passes together
let egglog = build_program(
&dag,
Some(&dag),
&optimized,
inline_program,
&dag.fns(),
&mut cache,
&schedule_steps
.iter()
.map(|pass| pass.egglog_schedule().to_string())
.collect::<Vec<String>>()
.join("\n"),
last_schedule_step.egglog_schedule(),
);
(
vec![Visualization {
Expand Down
Loading