Skip to content

Commit

Permalink
arcarray
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart committed Oct 23, 2020
1 parent 0c3ab75 commit 01937fb
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions src/language/interpreter.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use super::language::{ComputeType, Language, PadType};
use egg::{Id, RecExpr};
use itertools::Itertools;
use ndarray::{s, Array, ArrayD, Dimension, IxDyn, Zip};
use ndarray::{s, Array, ArrayD, Dimension, IxDyn, Zip, ArcArray};
use num_traits::cast::AsPrimitive;
use num_traits::Pow;
use std::collections::hash_map::HashMap;
use std::iter::FromIterator;
use std::ops::Div;

pub enum Value<DataType> {
Tensor(ArrayD<DataType>),
Tensor(ArcArray<DataType, IxDyn>),
Access(Access<DataType>),
Usize(usize),
Shape(IxDyn),
Expand All @@ -20,7 +20,7 @@ pub enum Value<DataType> {
}

pub struct Access<DataType> {
pub tensor: ArrayD<DataType>,
pub tensor: ArcArray<DataType, IxDyn>,
pub access_axis: usize,
}

Expand Down Expand Up @@ -168,9 +168,7 @@ where
tensor: access
.tensor
.clone()
.into_owned()
.slice(slice_info.as_ref())
.into_owned(),
.slice_move(slice_info.as_ref()),
access_axis: access.access_axis,
})
}
Expand All @@ -191,23 +189,23 @@ where
assert_eq!(a.access_axis, b.access_axis);

Value::Access(Access {
tensor: ndarray::stack![ndarray::Axis(axis), a.tensor, b.tensor].into_dyn(),
tensor: ndarray::stack![ndarray::Axis(axis), a.tensor, b.tensor].into_dyn().into_shared(),
access_axis: a.access_axis,
})
}
&Language::AccessLiteral(id) => match interpret!(memo_map, &id) {
Value::Tensor(t) => Value::Access(Access {
tensor: t.clone().into_owned(),
tensor: t.clone(),
access_axis: 0,
}),
_ => panic!(),
},
&Language::Literal(id) => match interpret!(memo_map, &id) {
Value::Tensor(t) => Value::Tensor(t.clone().into_owned()),
Value::Tensor(t) => Value::Tensor(t.clone()),
_ => panic!(),
},
&Language::NotNanFloat64(v) => Value::Tensor(
ndarray::arr0(DataType::from_not_nan_float_64_literal(v.into())).into_dyn(),
ndarray::arr0(DataType::from_not_nan_float_64_literal(v.into())).into_dyn().into_shared(),
),
&Language::AccessFlatten(access_id) => {
let access = match interpret!(memo_map, &access_id) {
Expand Down Expand Up @@ -244,7 +242,7 @@ where
};

Value::Access(Access {
tensor: access.tensor.to_owned().permuted_axes(list.to_owned()),
tensor: access.tensor.clone().permuted_axes(list.to_owned()),
access_axis: access.access_axis,
})
}
Expand Down Expand Up @@ -276,10 +274,10 @@ where
Value::Access(Access {
tensor: access
.tensor
.to_owned()
.clone()
.broadcast(shape.to_owned())
.unwrap()
.to_owned(),
.to_owned().into_shared(),
access_axis: access.access_axis,
})
}
Expand Down Expand Up @@ -337,7 +335,7 @@ where
.unwrap();

Value::Access(Access {
tensor,
tensor: tensor.into_shared(),
access_axis,
})
}
Expand Down Expand Up @@ -424,7 +422,7 @@ where
.view(),
],
)
.unwrap(),
.unwrap().into_shared(),
access_axis: access.access_axis,
})
}
Expand Down Expand Up @@ -464,7 +462,7 @@ where
.iter()
.product::<usize>()
.as_(),
),
).into_shared(),
access_axis: access.access_axis,
}),
ComputeType::Softmax => {
Expand Down Expand Up @@ -493,7 +491,7 @@ where

Value::Access(Access {
access_axis: access.access_axis,
tensor: exps,
tensor: exps.into_shared(),
})
}
ComputeType::ElementwiseDiv => Value::Access(Access {
Expand All @@ -513,7 +511,7 @@ where
.axis_iter(ndarray::Axis(access.access_axis))
.next()
.expect("Cannot divide 0 arguments")
.into_owned(),
.to_owned().into_shared(),
|acc, t| acc / t,
),
}),
Expand Down Expand Up @@ -617,15 +615,15 @@ where

Value::Access(Access {
access_axis: reshaped.ndim(),
tensor: reshaped,
tensor: reshaped.into_shared(),
})
}
ComputeType::Negative => Value::Access(Access {
tensor: access.tensor.mapv(|v| v.neg()),
tensor: access.tensor.mapv(|v| v.neg()).into_shared(),
access_axis: access.access_axis,
}),
ComputeType::Sqrt => Value::Access(Access {
tensor: access.tensor.mapv(|v| v.sqrt()),
tensor: access.tensor.mapv(|v| v.sqrt()).into_shared(),
access_axis: access.access_axis,
}),
ComputeType::ReLU => Value::Access(Access {
Expand All @@ -635,14 +633,14 @@ where
} else {
DataType::zero()
}
}),
}).into_shared(),
access_axis: access.access_axis,
}),
ComputeType::ReduceSum => Value::Access(Access {
tensor: access
.tensor
.clone()
.into_shape(
.reshape(
access.tensor.shape()[..access.access_axis]
.iter()
.cloned()
Expand All @@ -655,15 +653,14 @@ where
.collect::<Vec<_>>()
.as_slice(),
)
.unwrap()
.sum_axis(ndarray::Axis(access.access_axis)),
.sum_axis(ndarray::Axis(access.access_axis)).into_shared(),
access_axis: access.access_axis,
}),
ComputeType::ReduceMax => Value::Access(Access {
tensor: access
.tensor
.clone()
.into_shape(
.reshape(
access.tensor.shape()[..access.access_axis]
.iter()
.cloned()
Expand All @@ -676,13 +673,12 @@ where
.collect::<Vec<_>>()
.as_slice(),
)
.unwrap()
.map_axis(ndarray::Axis(access.access_axis), |t| {
t.iter().fold(
DataType::min_value(),
|acc, v| if *v > acc { *v } else { acc },
)
}),
}).into_shared(),
access_axis: access.access_axis,
}),
}
Expand Down Expand Up @@ -767,7 +763,7 @@ where
.unwrap();

Value::Access(Access {
tensor: reshaped.into_dyn(),
tensor: reshaped.into_dyn().into_shared(),
access_axis: a0.access_axis + a1.access_axis,
})
}
Expand Down Expand Up @@ -875,7 +871,7 @@ where
}

Value::Access(Access {
tensor: result,
tensor: result.into_shared(),
// TODO(@gussmith23) Hardcoded
// This already bit me. I forgot to update it when I changed the
// access-windows semantics, and it took me a bit to find the
Expand Down Expand Up @@ -958,7 +954,7 @@ where
Language::Symbol(s) => Value::Tensor(
env.get(s.as_str())
.unwrap_or_else(|| panic!("Symbol {} not in environment", s))
.clone(),
.clone().into_shared(),
),
&Language::Usize(u) => Value::Usize(u),

Expand Down

0 comments on commit 01937fb

Please sign in to comment.