diff --git a/core/src/ops/array/range.rs b/core/src/ops/array/range.rs index 5d9cae5327..6a705ede93 100644 --- a/core/src/ops/array/range.rs +++ b/core/src/ops/array/range.rs @@ -1,3 +1,4 @@ +use crate::ops::cast::Cast; use tract_num_traits::AsPrimitive; use tract_num_traits::Zero; @@ -59,14 +60,14 @@ impl Range { values: &SymbolValues, ) -> TractResult { if start.datum_type() == TDim::datum_type() { + let start = start.to_scalar::()?.eval(values).to_i64()?; + let step = step.to_scalar::()?.eval(values).to_i64()?; let len = { - let start = start.to_scalar::()?.eval(values).to_i64()?; let end = end.to_scalar::()?.eval(values).to_i64()?; - let step = step.to_scalar::()?.eval(values).to_i64()?; #[allow(clippy::cast_abs_to_unsigned)] ((end - start).abs() as usize).divceil(step.abs() as usize) }; - Self::make_t::(start, step, len) + Self::make_t::(&tensor0(start), &tensor0(step), len) } else { let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())( self, start, end, step @@ -90,18 +91,25 @@ impl Range { impl TypedOp for Range { fn declutter( - &self, - model: &TypedModel, - node: &TypedNode, - ) -> TractResult> { + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { let Some(succ) = model.single_succ(node.id)? else { return Ok(None) }; let Some(slice) = succ.op_as::() else { return Ok(None) }; if slice.start.is_zero() && slice.end.is_one() { let mut patch = TypedModelPatch::default(); - let wire = patch.tap_model(model, node.inputs[0])?; + let mut wire = patch.tap_model(model, node.inputs[0])?; + if model.outlet_fact(node.inputs[0])?.datum_type.is_tdim() { + wire = patch.wire_node( + format!("{}.cast-tdim", node.name), + Cast { to: DatumType::I64 }, + &[wire], + )?[0]; + } let wire = patch.wire_node(&node.name, AxisOp::Add(0), &[wire])?; patch.shunt_outside(model, succ.id.into(), wire[0])?; - return Ok(Some(patch)) + return Ok(Some(patch)); } Ok(None) } @@ -116,22 +124,23 @@ impl TypedOp for Range { ensure!(end.shape.volume().is_one()); ensure!(step.shape.volume().is_one()); if let (Some(start), Some(end), Some(step)) = (&start.konst, &end.konst, &step.konst) { - let len = if start.datum_type() == TDim::datum_type() { + if start.datum_type() == TDim::datum_type() { let start = start.to_scalar::()?; let end = end.to_scalar::()?; let step = step.cast_to_scalar::()?; - if step < 0 { + let len = if step < 0 { (start.clone() - end).divceil(-step as usize) } else { (end.clone() - start).divceil(step as usize) - } + }; + Ok(tvec!(DatumType::I64.fact([len]))) } else { - dispatch_numbers!(Self::len_for_numbers(start.datum_type())( + let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())( self, start, end, step ))? - .to_dim() - }; - Ok(tvec!(start.datum_type().fact([len]))) + .to_dim(); + Ok(tvec!(start.datum_type().fact([len]))) + } } else { Ok(tvec!(start.datum_type.fact(&[self.len.clone()]))) }