Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Finalizing the initial infrastructure of JaCe #18

Merged

Conversation

philip-paul-mueller
Copy link
Contributor

This PR introduces some missing pieces into mainline JaCe.

While the changes are very related to each other, the features this PR introduces were not split.
A short summary of what it introduces:

  • The Stages now support pytrees in input and output.
  • Inside the translators scalars and arrays are now distinguished.
  • Support for arrays that are not in continuous C order.
  • Type annotation in the wrapped objects (there are technical limitations, see note in src/jace/stages.py for more).
  • General changes in the organization of the code.
  • Possibility to globally control the optimization levels (currently not that useful).

However, this commit only introduces the state of the development branch regarding the basic infrastructure, i.e. (mostly) src/jace, but leaves out the translators that the development branch has and its tests, to keep the PR small.

Essentially it is a partially copy of `src/jace` from the development branch (027ae35) to this branch.
However, the translators were not copied, thus they are still WIP mode.
Furthermore, all changes to the test were made such that they pass, i.e. they are in WIP mode.
Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are many small issues (e.g. typos, symbol names and noqa comments not following the coding conventions) which create a lot of noise and thus make very hard to focus on the design and programmatic aspects of the code being reviewed. Please before asking for another round of review, make sure you had read again carefully our coding guidelines and all the linked documents there (e.g. Google Python style conventions).

Also, design and style decisions should be consistent across the codebase and right now there many inconsistent cases, for example:

  • similar classes are sometimes implemented as dataclasses and sometimes not for no clear reason
  • similar attributes in similar classes follow different naming patterns
  • similar functions follow different naming patteens
  • imported modules are aliased inconsistently in different modules
  • documentation for global module symbols is sometimes expressed as a docstring after the symbol definition and sometimes as a source comment with the special sphinx notation.

Next time, please do a first round of review to the PR yourself, to make sure all the small and style issues are fixed in advance.

@codecov-commenter
Copy link

codecov-commenter commented Jun 23, 2024

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

Attention: Patch coverage is 80.98160% with 62 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@1ebbcf3). Learn more about missing BASE report.

Files Patch % Lines
src/jace/translated_jaxpr_sdfg.py 74.64% 10 Missing and 8 partials ⚠️
src/jace/util/jax_helper.py 43.75% 8 Missing and 1 partial ⚠️
src/jace/translator/jaxpr_translator_builder.py 71.42% 5 Missing and 3 partials ⚠️
src/jace/translator/post_translation.py 90.32% 2 Missing and 4 partials ⚠️
src/jace/stages.py 91.93% 5 Missing ⚠️
src/jace/util/traits.py 37.50% 5 Missing ⚠️
src/jace/util/translation_cache.py 85.29% 4 Missing and 1 partial ⚠️
...translator/primitive_translators/alu_translator.py 55.55% 4 Missing ⚠️
src/jace/tracing.py 90.00% 1 Missing and 1 partial ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #18   +/-   ##
=======================================
  Coverage        ?   76.86%           
=======================================
  Files           ?       18           
  Lines           ?      899           
  Branches        ?      177           
=======================================
  Hits            ?      691           
  Misses          ?      142           
  Partials        ?       66           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

I also now made the assignement outside, it is just better to understand.
It seems that in newer JAX versions literals are passed either as arrays (with zero dimensions) or as scalars.
@egparedes egparedes changed the title feat: Finalizing the initail infrastructure of JaCe feat: Finalizing the initial infrastructure of JaCe Jun 25, 2024
Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second round of review. There are not important issues so this should be ready to be merged after addressing my comments.

Comment on lines 352 to 408
# <--------------------------- Compilation/Optimization options management

_JACELOWERED_ACTIVE_COMPILE_OPTIONS: CompilerOptions = optimization.DEFAULT_OPTIMIZATIONS.copy()
"""Global set of currently active compilation/optimization options.

The global set is initialized with `jace.optimization.DEFAULT_OPTIMIZATIONS`. It can be
managed through `update_active_compiler_options()` and accessed through
`get_active_compiler_options()`, however, it is advised that a user should use
`finalize_compilation_options()` for getting the final options that should be used
for optimization.
"""


def update_active_compiler_options(new_active_options: CompilerOptions) -> CompilerOptions:
"""
Updates the set of active compiler options.

Merges the options passed as `new_active_options` with the currently active
compiler options. This set is used by `JaCeLowered.compile()` to determine
which options should be used.
The function will return the set of options that was active before the call.

To obtain the set of currently active options use `get_active_compiler_options()`.

Todo:
Make a proper context manager.
"""
previous_active_options = _JACELOWERED_ACTIVE_COMPILE_OPTIONS.copy()
_JACELOWERED_ACTIVE_COMPILE_OPTIONS.update(new_active_options)
return previous_active_options


def get_active_compiler_options() -> CompilerOptions:
"""Returns the set of currently active compiler options."""
return _JACELOWERED_ACTIVE_COMPILE_OPTIONS.copy()


def finalize_compilation_options(compiler_options: CompilerOptions | None) -> CompilerOptions:
"""
Returns the final compilation options.

There are two different sources of optimization options. The first one is the global
set of currently active compiler options. The second one is the options that are
passed to this function, which takes precedence. Thus, the `compiler_options`
argument describes the difference from the currently active global options.

This function is used by `JaCeLowered` if it has to determine which options to use
for optimization, either for compiling the lowered SDFG or for computing the key.

Args:
compiler_options: The local compilation options.

See Also:
`get_active_compiler_options()` to inspect the set of currently active options
and `update_active_compiler_options()` to modify them.
"""
return get_active_compiler_options() | (compiler_options or {})
Copy link
Collaborator

@egparedes egparedes Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I don't want to start a philosophical discussion here but please remember that in Python:
- global variables are just attributes of a module object
- and thus all classes and functions defined in a module are also global variables
- and your code is giving direct read access to all of them because it's the only way to use them
- and since in pure Python there is no way to forbid access, your code is also giving write access to them and any user can replace the definitions of functions and classes in JaCe modules and access underscored variables and do anything
- and thus direct access to global variables is a common pattern in Python and it's just impossible to avoid in general, so context handlers are a much better way to deal with non-local state.

Once this context is clear for both of us, let's discuss the three functions handling compiler options defined here:

  • update_active_compiler_options() does not fix the actual problem of this approach which is that any user can change the default options at any point in an uncontrolled way. The proper way to do this is to use a context handler to make sure the defaults are changed only within a limited scope. If a context manager is not implemented, this function is 99.9% as dangerous as letting the user change the global variable directly herself, or even more, since this is now public API and may mislead users into thinking this is a safe and appropriate way to go, which it's not, so I think this function should be deleted.
  • get_active_compiler_options() and make_final_compilation_options() heavily overlap since passing None as an argument to make_final_compilation_options() returns exactly the same value as get_active_compiler_options(). If you really want a function to access a global dictionary and/or merge it with another dictionary, then just create a simple function like the following, which also avoids one unneeded copy of the dict:
def get_compiler_options(custom_options: CompilerOptions | None = None) -> CompilerOptions:
    return _JACELOWERED_ACTIVE_COMPILE_OPTIONS | (custom_options or {})

Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just a couple of optional suggestions.

out_names: tuple[str, ...] | None
start_state: dace.SDFGState
terminal_state: dace.SDFGState
jaxpr: jax_core.ClosedJaxpr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then please improve the description in the docstring explaining that it's mainly for debugging purposes, because the current description is not very informative.



@contextlib.contextmanager
def temporary_compiler_options(new_active_options: CompilerOptions) -> Generator[None, None, None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a comment but I really don't see the need for the temporary and active adjectives in the name of these functions, since it's obvious from the usage pattern that a context manager is for temporary changes and the only global function getting a set of compiler options is by default returning the currently active options, and otherwise would have another name (e.g. get_default_compiler_options())....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed them to set_compiler_options() and get_compiler_options().

@philip-paul-mueller philip-paul-mueller merged commit 19c89b0 into GridTools:main Jul 2, 2024
4 checks passed
@philip-paul-mueller philip-paul-mueller deleted the extended_infrastructure branch September 24, 2024 13:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants