From d15be1cdfe8c232c0ce695c59236d225348b382c Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 15 Oct 2020 21:23:25 -0700 Subject: [PATCH] memo --- src/language/interpreter.rs | 334 ++++++++++++++++++++++-------------- 1 file changed, 205 insertions(+), 129 deletions(-) diff --git a/src/language/interpreter.rs b/src/language/interpreter.rs index fd5832c806..d43500582f 100644 --- a/src/language/interpreter.rs +++ b/src/language/interpreter.rs @@ -63,6 +63,46 @@ where + std::ops::Mul + std::ops::Div + std::ops::Neg + + std::ops::Sub + + std::iter::Sum + + num_traits::identities::One + + num_traits::identities::Zero + + std::cmp::PartialOrd + + num_traits::Bounded + + Exp + + Sqrt + + FromNotNanFloat64Literal + + ndarray::ScalarOperand, + usize: num_traits::cast::AsPrimitive, +{ + let mut memo_map: std::collections::HashMap> = HashMap::default(); + + for i in 0..index { + let val = interpret_impl(expr, i, env, &mut memo_map); + memo_map.insert(egg::Id::from(i), val); + } + + interpret_impl(expr, index, env, &mut memo_map) +} + +macro_rules! interpret { + ($memo_map: expr, $id: expr) => { + $memo_map.get($id).expect("Child expression should have already been interpreted! Do you have a loop in your program?") + } +} + +fn interpret_impl( + expr: &RecExpr, + index: usize, + env: &Environment, + memo_map: &mut std::collections::HashMap>, +) -> Value +where + DataType: Copy + + std::ops::Mul + + std::ops::Div + + std::ops::Neg + + std::ops::Sub + std::iter::Sum + num_traits::identities::One + num_traits::identities::Zero @@ -76,14 +116,15 @@ where { match &expr.as_ref()[index] { &Language::AccessShape([shape_id, item_shape_id]) => { - let shape = match interpret(expr, shape_id.into(), env) { + let shape = match interpret!(memo_map, &shape_id) { Value::Shape(s) => s, _ => panic!(), }; - let item_shape = match interpret(expr, item_shape_id.into(), env) { + let item_shape = match interpret!(memo_map, &item_shape_id) { Value::Shape(s) => s, _ => panic!(), }; + Value::AccessShape( IxDyn( shape @@ -98,20 +139,20 @@ where ) } &Language::AccessSlice([access_id, axis_id, low_id, high_id]) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let axis = match interpret(expr, axis_id.into(), env) { - Value::Usize(u) => u, + let axis = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, _ => panic!(), }; - let low = match interpret(expr, low_id.into(), env) { - Value::Usize(u) => u, + let low = match interpret!(memo_map, &low_id) { + Value::Usize(u) => *u, _ => panic!(), }; - let high = match interpret(expr, high_id.into(), env) { - Value::Usize(u) => u, + let high = match interpret!(memo_map, &high_id) { + Value::Usize(u) => *u, _ => panic!(), }; @@ -121,25 +162,28 @@ where .collect(); slice_info[axis] = ndarray::SliceOrIndex::from(low..high); let slice_info = ndarray::SliceInfo::new(slice_info).unwrap(); - access.tensor = access - .tensor - .into_owned() - .slice(slice_info.as_ref()) - .into_owned(); - Value::Access(access) + Value::Access(Access { + tensor: access + .tensor + .clone() + .into_owned() + .slice(slice_info.as_ref()) + .into_owned(), + access_axis: access.access_axis, + }) } &Language::AccessConcatenate([a_id, b_id, axis_id]) => { - let a = match interpret(expr, a_id.into(), env) { + let a = match interpret!(memo_map, &a_id) { Value::Access(a) => a, _ => panic!(), }; - let b = match interpret(expr, b_id.into(), env) { + let b = match interpret!(memo_map, &b_id) { Value::Access(a) => a, _ => panic!(), }; - let axis = match interpret(expr, axis_id.into(), env) { - Value::Usize(u) => u, + let axis = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, _ => panic!(), }; @@ -150,22 +194,22 @@ where access_axis: a.access_axis, }) } - &Language::AccessLiteral(id) => match interpret(expr, id.into(), env) { + &Language::AccessLiteral(id) => match interpret!(memo_map, &id) { Value::Tensor(t) => Value::Access(Access { - tensor: t, + tensor: t.clone().into_owned(), access_axis: 0, }), _ => panic!(), }, - &Language::Literal(id) => match interpret(expr, id.into(), env) { - t @ Value::Tensor(_) => t, + &Language::Literal(id) => match interpret!(memo_map, &id) { + Value::Tensor(t) => Value::Tensor(t.clone().into_owned()), _ => panic!(), }, &Language::NotNanFloat64(v) => Value::Tensor( ndarray::arr0(DataType::from_not_nan_float_64_literal(v.into())).into_dyn(), ), &Language::AccessFlatten(access_id) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; @@ -183,37 +227,40 @@ where ] }; - access.tensor = access.tensor.into_shape(shape).unwrap().into_dyn(); - - Value::Access(access) + Value::Access(Access { + tensor: access.tensor.clone().into_shape(shape).unwrap().into_dyn(), + access_axis: access.access_axis, + }) } &Language::AccessTranspose([access_id, list_id]) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let list = match interpret(expr, list_id.into(), env) { + let list = match interpret!(memo_map, &list_id) { Value::List(l) => l, _ => panic!(), }; - access.tensor = access.tensor.permuted_axes(list); - Value::Access(access) + Value::Access(Access { + tensor: access.tensor.to_owned().permuted_axes(list.to_owned()), + access_axis: access.access_axis, + }) } Language::List(list) => Value::List( list.iter() - .map(|id: &Id| match interpret(expr, (*id).into(), env) { - Value::Usize(u) => u, + .map(|id: &Id| match interpret!(memo_map, id) { + Value::Usize(u) => *u, _ => panic!(), }) .collect::>(), ), &Language::AccessBroadcast([access_id, shape_id]) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let shape = match interpret(expr, shape_id.into(), env) { + let shape = match interpret!(memo_map, &shape_id) { Value::AccessShape(s, _) => s, _ => panic!("Expected access shape as second argument to access-broadcast"), }; @@ -225,34 +272,40 @@ where assert!(*broadcast_from_dim == 1 || broadcast_from_dim == broadcast_to_dim); } - access.tensor = access.tensor.broadcast(shape).unwrap().to_owned(); - - Value::Access(access) + Value::Access(Access { + tensor: access + .tensor + .to_owned() + .broadcast(shape.to_owned()) + .unwrap() + .to_owned(), + access_axis: access.access_axis, + }) } &Language::AccessInsertAxis([access_id, axis_id]) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let axis = match interpret(expr, axis_id.into(), env) { - Value::Usize(u) => u, + let axis = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, _ => panic!(), }; assert!(axis <= access.tensor.ndim()); - access.tensor = access.tensor.insert_axis(ndarray::Axis(axis)); - if axis <= access.access_axis { - access.access_axis += 1; + let mut access_axis = access.access_axis; + if axis <= access_axis { + access_axis += 1; } - Value::Access(access) + Value::Access(Access { + tensor: access.tensor.clone().insert_axis(ndarray::Axis(axis)), + access_axis: access_axis, + }) } &Language::AccessPair([a0_id, a1_id]) => { - let (a0, a1) = match ( - interpret(expr, a0_id.into(), env), - interpret(expr, a1_id.into(), env), - ) { + let (a0, a1) = match (interpret!(memo_map, &a0_id), interpret!(memo_map, &a1_id)) { (Value::Access(a0), Value::Access(a1)) => (a0, a1), _ => panic!("Expected both arguments to access-pair to be accesses"), }; @@ -270,8 +323,14 @@ where let tensor = ndarray::stack( ndarray::Axis(access_axis), &[ - a0.tensor.insert_axis(ndarray::Axis(access_axis)).view(), - a1.tensor.insert_axis(ndarray::Axis(access_axis)).view(), + a0.tensor + .clone() + .insert_axis(ndarray::Axis(access_axis)) + .view(), + a1.tensor + .clone() + .insert_axis(ndarray::Axis(access_axis)) + .view(), ], ) .unwrap(); @@ -282,12 +341,12 @@ where }) } &Language::AccessSqueeze([access_id, axis_id]) => { - let mut access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let axis = match interpret(expr, axis_id.into(), env) { - Value::Usize(u) => u, + let axis = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, _ => panic!(), }; @@ -297,33 +356,39 @@ where "Cannot squeeze an axis which is not equal to 1" ); - access.tensor = access.tensor.index_axis_move(ndarray::Axis(axis), 0); - if axis < access.access_axis { - access.access_axis -= 1; + let mut access_axis = access.access_axis; + if axis < access_axis { + access_axis -= 1; } - Value::Access(access) + Value::Access(Access { + tensor: access + .tensor + .clone() + .index_axis_move(ndarray::Axis(axis), 0), + access_axis: access_axis, + }) } Language::PadType(t) => Value::PadType(*t), &Language::AccessPad([access_id, pad_type_id, axis_id, pad_before_id, pad_after_id]) => { - let access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let pad_type = match interpret(expr, pad_type_id.into(), env) { + let pad_type = match interpret!(memo_map, &pad_type_id) { Value::PadType(t) => t, _ => panic!(), }; - let axis = match interpret(expr, axis_id.into(), env) { - Value::Usize(u) => u, + let axis = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, _ => panic!(), }; - let pad_before = match interpret(expr, pad_before_id.into(), env) { - Value::Usize(u) => u, + let pad_before = match interpret!(memo_map, &pad_before_id) { + Value::Usize(u) => *u, _ => panic!(), }; - let pad_after = match interpret(expr, pad_after_id.into(), env) { - Value::Usize(u) => u, + let pad_after = match interpret!(memo_map, &pad_after_id) { + Value::Usize(u) => *u, _ => panic!(), }; @@ -364,11 +429,11 @@ where } Language::ComputeType(t) => Value::ComputeType(t.clone()), &Language::Compute([compute_type_id, access_id]) => { - let compute_type = match interpret(expr, compute_type_id.into(), env) { + let compute_type = match interpret!(memo_map, &compute_type_id) { Value::ComputeType(t) => t, _ => panic!(), }; - let access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; @@ -409,7 +474,15 @@ where ); let shape = access.tensor.shape(); - let mut exps = ndarray::Zip::from(&access.tensor).apply_collect(|v| v.exp()); + // shift by max value to prevent integer overflow + let max = access + .tensor + .iter() + .max_by(|x, y| x.partial_cmp(y).unwrap()) + .unwrap() + .clone(); + let mut exps = + ndarray::Zip::from(&access.tensor).apply_collect(|v| (*v - max).exp()); let denominators = exps .sum_axis(ndarray::Axis(access.tensor.ndim() - 1)) .insert_axis(ndarray::Axis(access.tensor.ndim() - 1)); @@ -614,10 +687,7 @@ where } } &Language::AccessCartesianProduct([a0_id, a1_id]) => { - let (a0, a1) = match ( - interpret(expr, a0_id.into(), env), - interpret(expr, a1_id.into(), env), - ) { + let (a0, a1) = match (interpret!(memo_map, &a0_id), interpret!(memo_map, &a1_id)) { (Value::Access(a0), Value::Access(a1)) => (a0, a1), _ => panic!(), }; @@ -701,33 +771,33 @@ where }) } &Language::Access([access_id, dim_id]) => { - let access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let dim = match interpret(expr, dim_id.into(), env) { - Value::Usize(u) => u, + let dim = match interpret!(memo_map, &dim_id) { + Value::Usize(u) => *u, _ => panic!(), }; assert!(dim <= access.tensor.ndim()); Value::Access(Access { - tensor: access.tensor, + tensor: access.tensor.clone(), // TODO(@gussmith) Settle on vocab: "axis" or "dimension"? access_axis: dim, }) } &Language::AccessWindows([access_id, filters_shape_id, stride_shape_id]) => { - let access = match interpret(expr, access_id.into(), env) { + let access = match interpret!(memo_map, &access_id) { Value::Access(a) => a, _ => panic!(), }; - let filters_shape = match interpret(expr, filters_shape_id.into(), env) { + let filters_shape = match interpret!(memo_map, &filters_shape_id) { Value::Shape(s) => s, _ => panic!(), }; - let stride_shape = match interpret(expr, stride_shape_id.into(), env) { + let stride_shape = match interpret!(memo_map, &stride_shape_id) { Value::Shape(s) => s, _ => panic!(), }; @@ -815,64 +885,70 @@ where } Language::Shape(list) => Value::Shape(IxDyn( list.iter() - .map(|id: &Id| match interpret(expr, (*id).into(), env) { - Value::Usize(u) => u, + .map(|id: &Id| match interpret!(memo_map, id) { + Value::Usize(u) => *u, _ => panic!(), }) .collect::>() .as_slice(), )), - &Language::SliceShape([shape_id, slice_axis_id]) => match ( - interpret(expr, shape_id.into(), env), - interpret(expr, slice_axis_id.into(), env), - ) { - (Value::Shape(s), Value::Usize(u)) => { - Value::Shape(IxDyn(s.as_array_view().slice(s![u..]).to_slice().unwrap())) - } - _ => panic!(), - }, - &Language::ShapeInsertAxis([shape_id, axis_id]) => match ( - interpret(expr, shape_id.into(), env), - interpret(expr, axis_id.into(), env), - ) { - (Value::Shape(s), Value::Usize(u)) => { - assert!(u <= s.ndim()); - Value::Shape(IxDyn( - s.slice()[..u] - .iter() - .chain(std::iter::once(&1)) - .chain(s.slice()[u..].iter()) - .cloned() - .collect::>() - .as_slice(), - )) - } - _ => panic!(), - }, - &Language::ShapeRemoveAxis([shape_id, axis_id]) => match ( - interpret(expr, shape_id.into(), env), - interpret(expr, axis_id.into(), env), - ) { - (Value::Shape(s), Value::Usize(u)) => { - assert!(u < s.ndim(), "Invalid axis in shape-remove-axis"); - Value::Shape(IxDyn( - s.slice()[..u] - .iter() - .chain(s.slice()[u + 1..].iter()) - .cloned() - .collect::>() - .as_slice(), - )) - } - _ => panic!(), - }, - &Language::ShapeOf([tensor_id]) => match interpret(expr, tensor_id.into(), env) { + &Language::SliceShape([shape_id, slice_axis_id]) => { + let s = match interpret!(memo_map, &shape_id) { + Value::Shape(s) => s, + _ => panic!(), + }; + let u = match interpret!(memo_map, &slice_axis_id) { + Value::Usize(u) => *u, + _ => panic!(), + }; + Value::Shape(IxDyn(s.as_array_view().slice(s![u..]).to_slice().unwrap())) + } + &Language::ShapeInsertAxis([shape_id, axis_id]) => { + let s = match interpret!(memo_map, &shape_id) { + Value::Shape(s) => s, + _ => panic!(), + }; + let u = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, + _ => panic!(), + }; + assert!(u <= s.ndim()); + Value::Shape(IxDyn( + s.slice()[..u] + .iter() + .chain(std::iter::once(&1)) + .chain(s.slice()[u..].iter()) + .cloned() + .collect::>() + .as_slice(), + )) + } + &Language::ShapeRemoveAxis([shape_id, axis_id]) => { + let s = match interpret!(memo_map, &shape_id) { + Value::Shape(s) => s, + _ => panic!(), + }; + let u = match interpret!(memo_map, &axis_id) { + Value::Usize(u) => *u, + _ => panic!(), + }; + assert!(u < s.ndim(), "Invalid axis in shape-remove-axis"); + Value::Shape(IxDyn( + s.slice()[..u] + .iter() + .chain(s.slice()[u + 1..].iter()) + .cloned() + .collect::>() + .as_slice(), + )) + } + &Language::ShapeOf([tensor_id]) => match interpret!(memo_map, &tensor_id) { Value::Tensor(t) => Value::Shape(IxDyn(t.shape())), _ => panic!(), }, - &Language::AccessTensor(tensor_id) => match interpret(expr, tensor_id.into(), env) { + &Language::AccessTensor(tensor_id) => match interpret!(memo_map, &tensor_id) { Value::Tensor(t) => Value::Access(Access { - tensor: t, + tensor: t.clone(), // TODO(@gussmith) Arbitrarily picked default access axis access_axis: 0, }),