-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev guide #353
Conversation
…nd design choices of the code
…w we treat linear constraints with feasible direction method)
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #353 +/- ##
==========================================
- Coverage 94.26% 94.26% -0.01%
==========================================
Files 78 78
Lines 18176 18176
==========================================
- Hits 17134 17133 -1
- Misses 1042 1043 +1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few minor fixes
@@ -0,0 +1,227 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are going to mention this, then we should say "For example..., or we need the operation to be performed in a static manner rather than traced dynamically at run time." That is the main reason I would use numpy instead of JAX when developing.
Reply via ReviewNB
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, but can you use it inside jit function? I don't think so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do use numpy inside jitted and differentiable functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you google the "static jax" the relevant doc comes up
Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for example search for np.argsort in integrals/_interp_utils. It would be bad to use jax there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think that has some other drawbacks like multiple compilations etc. For every different value of the static input, the function has to compiled again. Anyway, I am making some changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the input for the numpy array there is not a jax array so it doesn't affect compilation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left some comments, would be nice to make those quick updates as you address Rory's
and thanks |
Resolves #228
Resolves #775
JAX
tricksdocs/dev_guide
folderto do (or not to do, do we still want these?):
I am not sure how much detail I should give, but here is the first version, I can elaborate depending on your feedback.