Skip to content

Commit

Permalink
Improve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos authored and kali committed Dec 11, 2024
1 parent 4f43dfc commit e78057f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions metal/src/ops/scaled_masked_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::MetalContext;
use derive_new::new;
use tract_core::internal::*;

/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1)
/// Only input of rank of 3 is supported and softmax axis = 2
/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2)
/// Only input of rank of 3 is supported
#[derive(Clone, Debug, new, Hash)]
pub struct MetalScaledMaskedSoftmax {
pub scale: Arc<Tensor>,
Expand Down
6 changes: 3 additions & 3 deletions metal/src/rewrite_rules/scaled_masked_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use tract_core::ops::binary::BinMiniOp;
use tract_core::ops::math::{Add, Mul};
use tract_core::ops::nn::{Softmax, SoftmaxExp};

/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=1)
/// Only input of rank of 3 is supported with softmax axis = 2
/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2)
/// Only input of rank of 3 is supported.
#[derive(Clone, Debug, Hash)]
pub struct BasicScaledMaskedSoftmax {
pub scale: Arc<Tensor>,
Expand Down Expand Up @@ -56,7 +56,7 @@ impl TypedOp for BasicScaledMaskedSoftmax {
as_op!();
}

/// Search pattern => A = SOFTMAX(A * SCALE + MASK)
/// Search pattern => A = SOFTMAX(A * SCALE + MASK, AXIS=2)
pub fn as_scaled_masked_softmax_rule(
_ctx: &(),
model: &TypedModel,
Expand Down

0 comments on commit e78057f

Please sign in to comment.