Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an access-concatenate-varargs operator #86

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
14 changes: 12 additions & 2 deletions src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,11 @@ pub fn find_vars(expr: &Expr, id: Id) -> Vec<String> {
find_vars_recursive_helper(set, expr, id);
}
// Box<[Id]>
Language::RelayOperatorCall(ids) | Language::List(ids) | Language::Shape(ids) => {
Language::RelayOperatorCall(ids)
| Language::AccessConcatenateVarargs(ids)
| Language::AccessPairVarargs(ids)
| Language::List(ids)
| Language::Shape(ids) => {
for id in ids.iter() {
find_vars_recursive_helper(set, expr, *id);
}
Expand Down Expand Up @@ -411,7 +415,11 @@ pub fn generate_worklist_for_codegen(expr: &Expr, id: Id) -> Vec<Id> {
}
}
// Box<[Id]>
Language::RelayOperatorCall(ids) | Language::Shape(ids) | Language::List(ids) => {
Language::RelayOperatorCall(ids)
| Language::Shape(ids)
| Language::AccessPairVarargs(ids)
| Language::List(ids)
| Language::AccessConcatenateVarargs(ids) => {
for id in ids.iter() {
helper(worklist, expr, *id);
}
Expand Down Expand Up @@ -1730,6 +1738,8 @@ if (i{i} < {dim_len}) {{
| Language::RelayOperator(_) => None,

&Language::Literal(_)
| &Language::AccessConcatenateVarargs(_)
| &Language::AccessPairVarargs(_)
| &Language::SystolicArrayConv2dIm2colNchwOihwWithBlocking(_)
| &Language::SystolicArrayConv2dIm2colNhwcHwioWithBlocking(_)
| &Language::SystolicArrayConv2dNchwOihwWithBlocking(_)
Expand Down
5 changes: 5 additions & 0 deletions src/extraction/ilp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type EGraph = egg::EGraph<Language, MyAnalysis>;

pub fn filter_by_enode_type(enode: &Language, _eclass_id: Id, _egraph: &EGraph) -> bool {
if match enode {
Language::AccessConcatenateVarargs(_) => todo!(),
Language::AccessPairVarargs(_) => todo!(),

// Things we should never see.
Language::CartesianProduct(_)
Expand Down Expand Up @@ -109,6 +111,9 @@ pub fn filter_obviously_less_preferable_nodes(
) -> bool {
fn is_obviously_extractable(enode: &Language) -> bool {
match enode {
Language::AccessConcatenateVarargs(_) => todo!(),
Language::AccessPairVarargs(_) => todo!(),

// Things we should never see.
Language::CartesianProduct(_)
| Language::MapDotProduct(_)
Expand Down
4 changes: 4 additions & 0 deletions src/extraction/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ impl egg::CostFunction<Language> for MonolithicCostFunction<'_> {
| Language::Slice(_)
| Language::Concatenate(_) => panic!(),

Language::AccessConcatenateVarargs(_) => todo!(),
Language::AccessPairVarargs(_) => todo!(),
Language::SystolicArrayConv2dIm2colNchwOihwWithBlocking(_) => todo!(),
Language::SystolicArrayConv2dIm2colNhwcHwioWithBlocking(_) => todo!(),
Language::SystolicArrayConv2dNchwOihwWithBlocking(_) => todo!(),
Expand Down Expand Up @@ -165,6 +167,8 @@ impl CostFunction<Language> for SimpleCostFunction {
{
use crate::language::Language::*;
let base_cost = match enode {
Language::AccessConcatenateVarargs(_) => todo!(),
Language::AccessPairVarargs(_) => todo!(),
Language::RelayOperator(_) => todo!(),
Language::RelayOperatorCall(_) => todo!(),
Language::RelayActivationLayout(_) => todo!(),
Expand Down
2 changes: 2 additions & 0 deletions src/language/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ where
usize: num_traits::cast::AsPrimitive<DataType>,
{
match &expr.as_ref()[index] {
&Language::AccessConcatenateVarargs(_) => todo!(),
&Language::AccessPairVarargs(_) => todo!(),
&Language::SystolicArrayConv2dIm2colNchwOihwWithBlocking(_) => todo!(),
&Language::SystolicArrayConv2dIm2colNhwcHwioWithBlocking(_) => todo!(),
&Language::SystolicArrayConv2dNchwOihwWithBlocking(_) => todo!(),
Expand Down
227 changes: 227 additions & 0 deletions src/language/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,21 @@ define_language! {
// access.item_shape.
"access-concatenate" = AccessConcatenate([Id; 3]),

// (access-concatenate-varargs <a0> ... <an> <axis (usize)>)
// Concatenate accesses <a0> ... <an> along <axis>.
// Output access pattern is accessed at the same dimension as <a0>.
// All dimensions other than <axis> must match.
"access-concatenate-varargs" = AccessConcatenateVarargs(Box<[Id]>),

// (access-pair <a0> <a1>)
// Simply pair every item of a0 with every item of a1.
"access-pair" = AccessPair([Id; 2]),

// (access-pair-varargs <a0> ... <an>)
// Pair/tuple every item of a0...an. Adds a new tuple dimension as the
// first item dimension.
"access-pair-varargs" = AccessPairVarargs(Box<[Id]>),

// (access-shift-right <a0>)
// Shifts a dimension from shape to item shape.
"access-shift-right" = AccessShiftRight(Id),
Expand Down Expand Up @@ -1127,6 +1138,102 @@ impl egg::Analysis<Language> for MyAnalysis {
fn make(egraph: &EGraph<Language, Self>, enode: &Language) -> Self::Data {
use Language::*;
match enode {
AccessPairVarargs(ids) => {
let accesses = ids
.iter()
.map(|id| match &egraph[*id].data {
MyAnalysisData::AccessPattern(a0) => a0,
_ => panic!(),
})
.collect::<Vec<_>>();

assert!(accesses.len() >= 1);

assert!(accesses.iter().all(|a| a.shape == accesses[0].shape));
assert!(accesses
.iter()
.all(|a| a.item_shape == accesses[0].item_shape));

MyAnalysisData::AccessPattern(AccessPatternData {
// TODO(@gussmith23) Implement zero regions
// It's harmless (I think) if `zero_regions` defaults to
// empty, but for it to be useful, we need to implement it
// for each operator.
zero_regions: {
if accesses.iter().any(|a| !a.zero_regions.is_empty()) {
debug!(
"Throwing away zero region analysis data on line {}",
std::line!()
);
}
HashMap::default()
},
shape: accesses[0].shape.clone(),
item_shape: IxDyn(
std::iter::once(accesses.len())
.chain(accesses[0].item_shape.as_array_view().iter().cloned())
.collect::<Vec<_>>()
.as_slice(),
),
})
}
AccessConcatenateVarargs(ids) => {
debug_assert!(ids.len() > 1);

let access_patterns = ids[0..ids.len() - 1]
.iter()
.map(|id| match &egraph[*id].data {
MyAnalysisData::AccessPattern(a) => a,
_ => panic!(),
})
.collect::<Vec<_>>();

let axis = MyAnalysis::get_usize(ids[ids.len() - 1], egraph);

// Check that all dims other than `axis` are equal.
debug_assert!(access_patterns
.iter()
.fold((None, None), |(last_result, last_pattern), this_pattern| {
if let Some(false) = last_result {
return (Some(false), Some(this_pattern));
}
if let None = last_result {
return (Some(true), Some(this_pattern));
}

let last_pattern = last_pattern.unwrap();

assert_eq!(
last_pattern.shape.ndim() + last_pattern.item_shape.ndim(),
this_pattern.shape.ndim() + this_pattern.item_shape.ndim()
);

for i in 0..(last_pattern.shape.ndim() + last_pattern.item_shape.ndim()) {
if i != axis {
assert_eq!(
last_pattern[i], this_pattern[i],
"Access patterns should have the same shape, besides `axis`"
);
}
}

(Some(true), Some(this_pattern))
})
.0
.unwrap());

let mut out_access_pattern = access_patterns[0].clone();
for access_pattern in access_patterns[1..].iter() {
out_access_pattern[axis] = out_access_pattern[axis] + access_pattern[axis];
}

if !out_access_pattern.zero_regions.is_empty() {
debug!("Zero regions being thrown away");
}
out_access_pattern.zero_regions = HashMap::default();

MyAnalysisData::AccessPattern(out_access_pattern)
}
&SystolicArrayConv2dIm2colNhwcHwioWithBlocking(
[rows_id, cols_id, weights_id, data_id, kh_id, kw_id, stride_h_id, stride_w_id],
) => {
Expand Down Expand Up @@ -5206,4 +5313,124 @@ mod tests {
_ => panic!(),
}
}
#[test]
fn access_concatenate_varargs_0() {
let mut map = HashMap::new();
map.insert("t0".to_string(), vec![1, 33, 44, 78]);
map.insert("t1".to_string(), vec![1, 33, 2, 78]);
map.insert("t2".to_string(), vec![1, 33, 52, 78]);
let program = "
(access-concatenate-varargs
(access (access-tensor t0) 1)
(access (access-tensor t1) 1)
(access (access-tensor t2) 1)
2
)"
.parse()
.unwrap();
let mut egraph =
egg::EGraph::<Language, MyAnalysis>::new(MyAnalysis { name_to_shape: map });
let id = egraph.add_expr(&program);
match &egraph[id].data {
MyAnalysisData::AccessPattern(a) => {
assert_eq!(a.shape, IxDyn(&[1]));
assert_eq!(a.item_shape, IxDyn(&[33, 44 + 2 + 52, 78]));
assert!(a.zero_regions.is_empty());
}
_ => panic!(),
}
}

#[test]
#[should_panic(expected = "Access patterns should have the same shape, besides `axis`")]
fn access_concatenate_varargs_1() {
let mut map = HashMap::new();
map.insert("t0".to_string(), vec![1, 33, 44, 78]);
map.insert("t1".to_string(), vec![1, 33, 2, 78]);
map.insert("t2".to_string(), vec![1, 32, 52, 78]);
let program = "
(access-concatenate-varargs
(access (access-tensor t0) 1)
(access (access-tensor t1) 1)
(access (access-tensor t2) 1)
2
)"
.parse()
.unwrap();
let mut egraph =
egg::EGraph::<Language, MyAnalysis>::new(MyAnalysis { name_to_shape: map });
let id = egraph.add_expr(&program);
match &egraph[id].data {
MyAnalysisData::AccessPattern(a) => {
assert_eq!(a.shape, IxDyn(&[1]));
assert_eq!(a.item_shape, IxDyn(&[33, 44 + 2 + 52, 78]));
assert!(a.zero_regions.is_empty());
}
_ => panic!(),
}
}

#[test]
fn access_concatenate_varargs_2() {
let mut map = HashMap::new();
map.insert("t0".to_string(), vec![1, 33, 44, 78]);
map.insert("t1".to_string(), vec![1, 33, 2, 78]);
map.insert("t2".to_string(), vec![1, 33, 52, 78]);
let program = "
(access-concatenate-varargs
(access (access-tensor t0) 2)
(access (access-tensor t1) 1)
(access (access-tensor t2) 1)
2
)"
.parse()
.unwrap();
let mut egraph =
egg::EGraph::<Language, MyAnalysis>::new(MyAnalysis { name_to_shape: map });
let id = egraph.add_expr(&program);
match &egraph[id].data {
MyAnalysisData::AccessPattern(a) => {
assert_eq!(a.shape, IxDyn(&[1, 33]));
assert_eq!(a.item_shape, IxDyn(&[44 + 2 + 52, 78]));
assert!(a.zero_regions.is_empty());
}
_ => panic!(),
}
}

#[test]
fn access_pair_varargs() {
let program = "
(access-pair-varargs
(access (access-tensor t-32-32) 1) (access (access-tensor t-32-32) 1)
(access (access-tensor t-32-32) 1) (access (access-tensor t-32-32) 1)
)
"
.parse()
.unwrap();
let mut egraph = egg::EGraph::<Language, MyAnalysis>::new(MyAnalysis::default());
let id = egraph.add_expr(&program);
match &egraph[id].data {
MyAnalysisData::AccessPattern(a) => {
assert_eq!(a.shape, IxDyn(&[32]));
assert_eq!(a.item_shape, IxDyn(&[4, 32]));
}
_ => panic!(),
}
}

#[should_panic = "assertion failed: accesses.iter().all(|a| a.shape == accesses[0].shape)"]
#[test]
fn access_pair_varargs_panic() {
let program = "
(access-pair-varargs
(access (access-tensor t-32-32) 1) (access (access-tensor t-32-32) 1)
(access (access-tensor t-32-32) 0) (access (access-tensor t-32-32) 1)
)
"
.parse()
.unwrap();
let mut egraph = egg::EGraph::<Language, MyAnalysis>::new(MyAnalysis::default());
egraph.add_expr(&program);
}
}
Loading