Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Depends on #606] Add state edge passthrough for terminating loops #607

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub fn prologue() -> String {
include_str!("utility/expr_size.egg"),
include_str!("utility/drop_at.egg"),
include_str!("interval_analysis.egg"),
include_str!("loop_iteration_analysis.egg"),
include_str!("optimizations/switch_rewrites.egg"),
include_str!("optimizations/select.egg"),
include_str!("optimizations/peepholes.egg"),
Expand Down
77 changes: 77 additions & 0 deletions dag_in_context/src/loop_iteration_analysis.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
;; Analysis to get the number of iterations of a loop
(ruleset loop-iters-analysis)

;; inputs, outputs -> number of iterations
;; The minimum possible guess is 1 because of do-while loops
(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new)))

;; Marks loops that we know will terminate
(relation TerminatingLoop (Expr Expr))

;; by default, guess that all loops run 1000 times
(rule ((DoWhile inputs outputs))
((set (LoopNumItersGuess inputs outputs) 1000))
:ruleset loop-iters-analysis)

;; For a loop that is false, its num iters is 1
(rule
((= loop (DoWhile inputs outputs))
(= (Const (Bool false) ty ctx) (Get outputs 0)))
((set (LoopNumItersGuess inputs outputs) 1)
(TerminatingLoop inputs outputs))
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated before checking pred
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by some constant each loop
;; TODO: how to handle the invariant case?
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while next_counter less than end_constant
(= pred (Bop (LessThan) next_counter
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment))
(TerminatingLoop inputs outputs)
)
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i)
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1))
(TerminatingLoop inputs outputs)
)
:ruleset loop-iters-analysis)

69 changes: 1 addition & 68 deletions dag_in_context/src/optimizations/loop_unroll.egg
Original file line number Diff line number Diff line change
@@ -1,74 +1,7 @@
;; Some simple simplifications of loops
;; Depends on loop iteration analysis
(ruleset loop-unroll)
(ruleset loop-peel)
(ruleset loop-iters-analysis)

;; inputs, outputs -> number of iterations
;; The minimum possible guess is 1 because of do-while loops
;; TODO: dead loop deletion can turn loops with a false condition to a body
(function LoopNumItersGuess (Expr Expr) i64 :merge (max 1 (min old new)))

;; by default, guess that all loops run 1000 times
(rule ((DoWhile inputs outputs))
((set (LoopNumItersGuess inputs outputs) 1000))
:ruleset loop-iters-analysis)

;; For a loop that is false, its num iters is 1
(rule
((= loop (DoWhile inputs outputs))
(= (Const (Bool false) ty ctx) (Get outputs 0)))
((set (LoopNumItersGuess inputs outputs) 1))
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated before checking pred
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by some constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while next_counter less than end_constant
(= pred (Bop (LessThan) next_counter
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (/ (- end_constant start_const) increment))
)
:ruleset loop-iters-analysis)

;; Figure out number of iterations for a loop with constant bounds and initial value
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
(= body-arg (Get (Arg _ty _ctx) counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) body-arg
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) body-arg
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
)
(
(set (LoopNumItersGuess inputs outputs) (+ (/ (- end_constant start_const) increment) 1))
)
:ruleset loop-iters-analysis)

;; loop peeling rule
;; Only peel loops that we know iterate < 3 times
Expand Down
30 changes: 29 additions & 1 deletion dag_in_context/src/optimizations/passthrough.egg
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
;; Relies on loop iteration analysis
(ruleset passthrough)


;; Pass through thetas
;; Pass through thetas: pure case
(rule ((= lhs (Get loop i))
(= loop (DoWhile inputs pred-outputs))
(= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i))
Expand All @@ -13,6 +14,33 @@
((union lhs (Get inputs i)))
:ruleset passthrough)

; ;; Pass through thetas: state edge case
(rule ((= loop (DoWhile inputs pred-outputs))
(= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i))
;; It is OK to pass through state edges as long as the loop terminates
(TerminatingLoop inputs pred-outputs))
(
;; To maintain the linearity invariant, we must remove the state edge
;; from the loop.
(let new-inputs (TupleRemoveAt inputs i))
(let removed-outputs (TupleRemoveAt pred-outputs (+ i 1)))
(let new-outputs (DropAt (TmpCtx) i removed-outputs))

(let projected-old-loop (TupleRemoveAt loop i))
(let new-loop (DoWhile new-inputs new-outputs))
(union new-loop projected-old-loop)

;; Resolve the temporary context
(union (TmpCtx) (InLoop new-inputs new-outputs))
(delete (TmpCtx))

;; State edge can be gotten without the loop now
(union (Get loop i) (Get inputs i))

;; Subsume the loop later
(ToSubsumeLoop inputs pred-outputs))
:ruleset passthrough)

;; Pass through switch arguments
(rule ((= lhs (Get switch i))
(= switch (Switch pred inputs branches))
Expand Down
4 changes: 4 additions & 0 deletions dag_in_context/src/utility/util.egg
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@
((subsume (If a b c d)))
:ruleset subsume-after-helpers)

(relation ToSubsumeLoop (Expr Expr))
(rule ((ToSubsumeLoop in p-out))
((subsume (DoWhile in p-out)))
:ruleset subsume-after-helpers)

15 changes: 15 additions & 0 deletions tests/passing/small/dead_loop_deletion.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@main: int {
i: int = const 1;
forty: int = const 40;
one: int = const 1;

.loop_body:
i: int = add i one;
cond: bool = lt i forty;
br cond .loop_body .loop_end;

.loop_end:
j: int = const 2;

ret j;
}
14 changes: 14 additions & 0 deletions tests/passing/small/loop_state_pass_through.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ARGS: 5
@main(input: int) {
one: int = const 1;
i: int = const 1;
jmp .loop;
.loop:
max: int = const 10;
cond: bool = lt i max;
i: int = add i one;
br cond .loop .exit;
.exit:
res: int = add i input;
print res;
}
Loading