Skip to content

Commit

Permalink
Fix Range declutter
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 12, 2024
1 parent b1c73c4 commit bb1e4c5
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions core/src/ops/array/range.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::ops::cast::Cast;
use tract_num_traits::AsPrimitive;
use tract_num_traits::Zero;

Expand Down Expand Up @@ -90,18 +91,25 @@ impl Range {

impl TypedOp for Range {
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let Some(succ) = model.single_succ(node.id)? else { return Ok(None) };
let Some(slice) = succ.op_as::<Slice>() else { return Ok(None) };
if slice.start.is_zero() && slice.end.is_one() {
let mut patch = TypedModelPatch::default();
let wire = patch.tap_model(model, node.inputs[0])?;
let mut wire = patch.tap_model(model, node.inputs[0])?;
if model.outlet_fact(node.inputs[0])?.datum_type.is_tdim() {
wire = patch.wire_node(
format!("{}.cast-tdim", node.name),
Cast { to: DatumType::I64 },
&[wire],
)?[0];
}
let wire = patch.wire_node(&node.name, AxisOp::Add(0), &[wire])?;
patch.shunt_outside(model, succ.id.into(), wire[0])?;
return Ok(Some(patch))
return Ok(Some(patch));
}
Ok(None)
}
Expand Down

0 comments on commit bb1e4c5

Please sign in to comment.