Skip to content

Commit

Permalink
Merge pull request #96 from teddykoker/eqx-static-field
Browse files Browse the repository at this point in the history
Add static fields to e3nn_jax.equinox.Linear
  • Loading branch information
ameya98 authored Jan 23, 2025
2 parents 6f6319d + 3f093b8 commit e81640d
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions e3nn_jax/_src/linear_equinox.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,23 @@ class Linear(eqx.Module):
(5,)
"""

irreps_out: e3nn.Irreps
irreps_in: e3nn.Irreps
channel_out: int
channel_in: int
gradient_normalization: Optional[Union[float, str]]
path_normalization: Optional[Union[float, str]]
biases: bool
num_indexed_weights: Optional[int]
weights_per_channel: bool
force_irreps_out: bool
weights_dim: Optional[int]
linear_type: str
irreps_out: e3nn.Irreps = eqx.field(static=True)
irreps_in: e3nn.Irreps = eqx.field(static=True)
channel_out: int = eqx.field(static=True)
channel_in: int = eqx.field(static=True)
gradient_normalization: Optional[Union[float, str]] = eqx.field(static=True)
path_normalization: Optional[Union[float, str]] = eqx.field(static=True)
biases: bool = eqx.field(static=True)
num_indexed_weights: Optional[int] = eqx.field(static=True)
weights_per_channel: bool = eqx.field(static=True)
force_irreps_out: bool = eqx.field(static=True)
weights_dim: Optional[int] = eqx.field(static=True)
linear_type: str = eqx.field(static=True)

# These are used internally.
_linear: FunctionalLinear
_linear: FunctionalLinear = eqx.field(static=True)
_weights: Dict[str, jax.Array]
_input_dtype: jnp.dtype
_input_dtype: jnp.dtype = eqx.field(static=True)

def __init__(
self,
Expand Down

0 comments on commit e81640d

Please sign in to comment.