diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index 2f106ca2..c760d06d 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -216,7 +216,7 @@ def __init__( domain_dim (`int`) : dimension of the input dims (assumed to be the same for now) domain_names (`Iterable[str]`) : list of input domains - n_steps (`int`) : number of steps to update the query vector, where 0 steps is static attention. + n_steps (`int`) : number of steps to update the query vector """ super().__init__() self.head_size = head_size