Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nnx.Module.perturb #4515

Merged
merged 1 commit into from
Jan 31, 2025
Merged

Add nnx.Module.perturb #4515

merged 1 commit into from
Jan 31, 2025

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Jan 30, 2025

Add perturb to the nnx.Module, same as Linen. Also added a nnx.Perturbations variable class.
Covered with tests.

Note:

  • Diff from Linen behavior: Since we no longer have mutable collection check, this will create new perturbation variables even if none is passed into the Module.
  • Perturbation variables requires a sample input to be created, as its shape depends on the input shape.

@@ -747,6 +747,38 @@ class Intermediate(Variable[A]):
pass


class Perturbations(Variable[A]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it singular like the other types? (e.g. Perturbation)

@@ -747,6 +747,38 @@ class Intermediate(Variable[A]):
pass


class Perturbations(Variable[A]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we should inherit from Intermediate for all temp Variables.

Suggested change
class Perturbations(Variable[A]):
class Perturbations(Intermediate[A]):

@IvyZX IvyZX force-pushed the perturb branch 3 times, most recently from 913bdf6 to b61e796 Compare January 30, 2025 21:34
@copybara-service copybara-service bot merged commit f71ce6c into google:main Jan 31, 2025
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants