Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
Drop prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Dec 20, 2022
1 parent 40905ce commit 0da327b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 22 deletions.
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# [JaxUtils](https://github.com/JaxGaussianProcesses/JaxUtils)

[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master)

`JaxUtils` provides utility functions for the [`JaxGaussianProcesses`]() ecosystem.</h2>

# Contents
* [PyTree](#pytree)
* [Dataset](#dataset)


- [PyTree](#pytree)
- [Dataset](#dataset)

# PyTree

## Overview

`jaxutils.PyTree` is a mixin class for [registering a python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). You would define your Python class as follows.

```python
Expand All @@ -27,13 +32,15 @@ class Line(jaxutils.PyTree):
def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None
self.gradient = gradient
self.intercept = intercept

def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]
return x * self.gradient + self.intercept
```

# Dataset

## Overview

`jaxutils.Dataset` is a datset abstraction. In future, we wish to extend this to a heterotopic and isotopic data abstraction.

## Example
Expand All @@ -43,7 +50,7 @@ import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])
Expand All @@ -59,6 +66,7 @@ print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
```

```
The number of datapoints is 3
The input dimension is 2
Expand Down
5 changes: 2 additions & 3 deletions jaxutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .data import Dataset, verify_dataset
from . import _version

__version__ = _version.get_versions()['version']
__version__ = _version.get_versions()["version"]
__authors__ = "Thomas Pinder, Daniel Dodd"
__license__ = "MIT"
__emails__ = "[email protected], [email protected]"
Expand All @@ -34,5 +34,4 @@
"PyTree",
"Dataset",
"verify_dataset",
]

]
2 changes: 1 addition & 1 deletion jaxutils/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_config():
cfg = VersioneerConfig()
cfg.VCS = "git"
cfg.style = "pep440"
cfg.tag_prefix = ""
cfg.tag_prefix = "v"
cfg.parentdir_prefix = "None"
cfg.versionfile_source = "jaxutils/_version.py"
cfg.verbose = False
Expand Down
13 changes: 0 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,6 @@
if os.environ["BUILD_JAXUTILS_NIGHTLY"] == "nightly":
NAME += "-nightly"

from versioneer import get_versions as original_get_versions

def get_versions():
from datetime import datetime, timezone

suffix = datetime.now(timezone.utc).strftime(r".dev%Y%m%d")
versions = original_get_versions()
versions["version"] = versions["version"].split("+")[0] + suffix
return versions

versioneer.get_versions = get_versions


REQUIRES = ["jax>=0.4.0", "jaxlib>=0.4.0", "jaxtyping"]

EXTRAS = {
Expand Down

0 comments on commit 0da327b

Please sign in to comment.