Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyMC/PyTensor Implementation of Pathfinder VI #387

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

aphc14
Copy link

@aphc14 aphc14 commented Oct 31, 2024

Another version to draft PR #386 which uses more of PyTensor's symbolic variables and compiling functions.

Questions for Review

  1. Which implementations should I continue for future improvements?
  2. Are there additional PyTensor optimisations we could leverage?

`fit_pathfinder`
- Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs.
- Changed the `num_samples` argument name to `num_draws` to avoid `TypeError` got multiple values for keyword argument 'num_samples'.
- Initial points are automatically set to jitter as jitter is required for pathfinder.

Extras
- New function 'get_jaxified_logp_ravel_inputs' to simplify previous code structure in fit_pathfinder.

Tests
- Added extra test for pathfinder to test pathfinder_info variables and pathfinder_idata  are consistent for a given random seed.
Add a new PyMC-based implementation of Pathfinder VI that uses PyTensor operations which provides support for both PyMC and BlackJAX backends in fit_pathfinder.
- Implemented  in  to support running multiple Pathfinder instances in parallel.
- Implemented  function in  for Pareto Smoothed Importance Resampling (PSIR).
- Moved relevant pathfinder files into the  directory.
- Updated tests to reflect changes in the Pathfinder implementation and added tests for new functionalities.
@aphc14
Copy link
Author

aphc14 commented Nov 4, 2024

Suppose the preferred approach is to stick with symbolic variables in PyTensor than the other non-symbolic approach in #386. In that case, I'd be happy to refactor the Multipath Pathfinder implementation in #386 to use symbolic variables and pytensor.function.

@aphc14 aphc14 force-pushed the pathfinder_w_pytensor_symbolic branch from 9bfc48c to ef2956f Compare November 7, 2024 18:04
@aphc14 aphc14 changed the title Pathfinder w pytensor symbolic PyMC/PyTensor Implementation of Pathfinder VI Nov 7, 2024
@aphc14
Copy link
Author

aphc14 commented Nov 7, 2024

This version runs much faster than #386, but the codes are messier due to the numerous pytensor symbolic variables created for the compiled pytensor functions (see the lines of code between def compute_logp and def single_pathfinder). Any suggestions for a cleaner setup would be appreciated

g: np.ndarray


class LBFGSHistoryManager:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaner to use a data class? Don't know.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, I agree. dataclass now added

Summaryh of changes:
- Remove multiprocessing code in favour of reusing compiled  for each path
-  takes only random_seed as argument for each path
- Compute graph significantly smaller by using pure pytensor op and symoblic variables
- Added LBFGSOp to compile with pytensor.function
- Cleaned up codes using pytensor variables
@aphc14 aphc14 marked this pull request as ready for review November 11, 2024 17:52
@aphc14 aphc14 marked this pull request as draft November 11, 2024 17:53
…and .

- Corrected the dimensions in comments for matrices Q and R in the  function.
- Uumerical stability in the  calculation by changing from  to .
@@ -31,11 +31,13 @@ def fit(method, **kwargs):
arviz.InferenceData
"""
if method == "pathfinder":
# TODO: Remove this once we have a pure PyMC implementation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR will provide that, no?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the latest commit addresses this

Fixed incorrect and inconsistent posterior approximations in the Pathfinder VI
algorithm by:

1. Adding missing parentheses in the phi calculation to ensure proper order
   of operations in matrix multiplications
2. Changing the sign in mu calculation from 'x +' to 'x -' to match Stan's
   implementation (which differs from the original paper)

The resulting changes now make the posterior approximations more reliable.
Implements both sparse and dense BFGS sampling approaches for Pathfinder VI:
- Adds bfgs_sample_dense for cases where 2*maxcor >= num_params.
- Moved existing  and  computations to bfgs_sample_sparse, making the sparse use cases more explicit.

Other changes:
- Sets default maxcor=5 instead of dynamic sizing based on parameters

Dense approximations are recommended when the target distribution has higher dependencies among the parameters.
Bigger changes:
- Made pmx.fit compatible with method='pathfinder'
- Remove JAX dependency when inference_backend='pymc' to support Windows users
- Improve runtime performance by setting trust_input=True for compiled functions

Minor changes:
- Change default num_paths from 1 to 4 for stable and reliable approximations
- Change LBFGS code using dataclasses
- Update tests to handle both PyMC and BlackJAX backends
- Add LBFGSInitFailed exception for failed LBFGS initialisation
- Skip failed paths in multipath_pathfinder and track number of failures
- Handle NaN values from Cholesky decompsition in bfgs_sample
- Add checks for numericl stabilty in matrix operations

Slight performance improvements:
- Set allow_gc=False in scan ops
- Use FAST_RUN mode consistently
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants