diff --git a/README.md b/README.md index 409132e..3a3b0bf 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,18 @@ # [JaxUtils](https://github.com/JaxGaussianProcesses/JaxUtils) +[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master) + `JaxUtils` provides utility functions for the [`JaxGaussianProcesses`]() ecosystem. # Contents -* [PyTree](#pytree) -* [Dataset](#dataset) - + +- [PyTree](#pytree) +- [Dataset](#dataset) + # PyTree + ## Overview + `jaxutils.PyTree` is a mixin class for [registering a python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). You would define your Python class as follows. ```python @@ -27,13 +32,15 @@ class Line(jaxutils.PyTree): def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None self.gradient = gradient self.intercept = intercept - + def y(self, x: Float[Array, "N"]) -> Float[Array, "N"] return x * self.gradient + self.intercept ``` # Dataset + ## Overview + `jaxutils.Dataset` is a datset abstraction. In future, we wish to extend this to a heterotopic and isotopic data abstraction. ## Example @@ -43,7 +50,7 @@ import jaxutils import jax.numpy as jnp # Inputs -X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) +X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) # Outputs y = jnp.array([[7.0], [8.0], [9.0]]) @@ -59,6 +66,7 @@ print(f'The output data is {D.y}') print(f'The data is supervised {D.is_supervised()}') print(f'The data is unsupervised {D.is_unsupervised()}') ``` + ``` The number of datapoints is 3 The input dimension is 2 diff --git a/jaxutils/__init__.py b/jaxutils/__init__.py index 9b002bd..3fa7a54 100644 --- a/jaxutils/__init__.py +++ b/jaxutils/__init__.py @@ -18,7 +18,7 @@ from .data import Dataset, verify_dataset from . import _version -__version__ = _version.get_versions()['version'] +__version__ = _version.get_versions()["version"] __authors__ = "Thomas Pinder, Daniel Dodd" __license__ = "MIT" __emails__ = "tompinder@live.co.uk, d.dodd1@lancaster.ac.uk" @@ -34,5 +34,4 @@ "PyTree", "Dataset", "verify_dataset", - ] - +] diff --git a/jaxutils/_version.py b/jaxutils/_version.py index dc8dfe2..4c681d4 100644 --- a/jaxutils/_version.py +++ b/jaxutils/_version.py @@ -44,7 +44,7 @@ def get_config(): cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "pep440" - cfg.tag_prefix = "" + cfg.tag_prefix = "v" cfg.parentdir_prefix = "None" cfg.versionfile_source = "jaxutils/_version.py" cfg.verbose = False diff --git a/setup.py b/setup.py index ea68e93..e6a1877 100644 --- a/setup.py +++ b/setup.py @@ -26,19 +26,6 @@ if os.environ["BUILD_JAXUTILS_NIGHTLY"] == "nightly": NAME += "-nightly" - from versioneer import get_versions as original_get_versions - - def get_versions(): - from datetime import datetime, timezone - - suffix = datetime.now(timezone.utc).strftime(r".dev%Y%m%d") - versions = original_get_versions() - versions["version"] = versions["version"].split("+")[0] + suffix - return versions - - versioneer.get_versions = get_versions - - REQUIRES = ["jax>=0.4.0", "jaxlib>=0.4.0", "jaxtyping"] EXTRAS = {