check backend compatibility when astroNN is imported
henrysky committed Sep 5, 2024
1 parent afad962 commit c9e0fff
Showing 5 changed files with 44 additions and 41 deletions.
16 changes: 16 additions & 0 deletions src/astroNN/__init__.py
@@ -4,4 +4,20 @@

from importlib.metadata import version

import keras

version = __version__ = version("astroNN")
_KERAS_BACKEND = keras.backend.backend()

# check if the backend is tensorflow or pytorch
def check_backend():
    """
    Check if the backend is tensorflow or pytorch
    """
    if _KERAS_BACKEND != "tensorflow" and _KERAS_BACKEND != "torch":
        raise ImportError(
f"astroNN only supports PyTorch or Tensorflow backend, but your current backend is {keras.backend.backend()}"
)


check_backend()
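For context, Keras 3 selects its backend from the KERAS_BACKEND environment variable when keras is first imported, so the check added above is driven entirely by how the user configures their environment. A minimal usage sketch (not part of this commit, assuming astroNN is installed and Keras 3 is in use) of selecting a supported backend before importing the package:

import os

# KERAS_BACKEND is read when keras is first imported, so it must be set
# before astroNN (which imports keras) is imported.
os.environ["KERAS_BACKEND"] = "torch"  # or "tensorflow"

import astroNN  # check_backend() runs here; other backends raise ImportError

print(astroNN.__version__)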
16 changes: 4 additions & 12 deletions src/astroNN/config.py
@@ -1,22 +1,12 @@
import configparser
import os
import platform

import keras
import numpy as np
from astroNN import _KERAS_BACKEND
import importlib

backend_framework = importlib.import_module(_KERAS_BACKEND)
astroNN_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".astroNN")
_astroNN_MODEL_NAME = "model_weights.keras" # default astroNN model filename
_KERAS_BACKEND = keras.backend.backend()

supported_backend = ["tensorflow", "torch", "jax"]
if _KERAS_BACKEND not in supported_backend:
    raise ImportError(
        f"astroNN only support {supported_backend} backend, currently you have '{keras.backend.backend()}' as backend"
    )
else:
    backend_framework = importlib.import_module(_KERAS_BACKEND)


def config_path(flag=None):
@@ -249,3 +239,5 @@ def cpu_gpu_reader():
MULTIPROCESS_FLAG = multiprocessing_flag_reader()
ENVVAR_WARN_FLAG = envvar_warning_flag_reader()
CUSTOM_MODEL_PATH = custom_model_path_reader()

__all__ = [_KERAS_BACKEND, MAGIC_NUMBER, MULTIPROCESS_FLAG, ENVVAR_WARN_FLAG, CUSTOM_MODEL_PATH, backend_framework]
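The config.py change resolves the backend module in one line with importlib.import_module(_KERAS_BACKEND), relying on the fact that the backend name Keras reports doubles as an importable module name. A standalone sketch of that dynamic-import pattern (illustration only, not repository code):

import importlib

import keras

_KERAS_BACKEND = keras.backend.backend()  # e.g. "tensorflow", "torch" or "jax"

# The string returned by Keras matches the framework's module name, so it can
# be imported dynamically and used as a namespace for backend-specific calls.
backend_framework = importlib.import_module(_KERAS_BACKEND)
print(backend_framework.__name__)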
21 changes: 10 additions & 11 deletions src/astroNN/models/base_master_nn.py
@@ -16,7 +16,10 @@
from astroNN.config import _astroNN_MODEL_NAME, cpu_gpu_reader
from astroNN.shared.nn_tools import cpu_fallback
from astroNN.shared.nn_tools import folder_runnum

from astroNN.config import (
    _KERAS_BACKEND,
    backend_framework,
)
epsilon, plot_model = keras.backend.epsilon, keras.utils.plot_model


@@ -508,24 +511,20 @@ def jacobian(self, x=None, mean_output=False, mc_num=1, denormalize=False):

        start_time = time.time()

        if keras.backend.backend() == "tensorflow":
            import tensorflow as tf

            xtensor = tf.Variable(x_data)
        if _KERAS_BACKEND == "tensorflow":
            xtensor = backend_framework.Variable(x_data)

            with tf.GradientTape(watch_accessed_variables=False) as tape:
            with backend_framework.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(xtensor)
                temp = _model(xtensor)
                if isinstance(temp, dict):
                    temp = temp["output"]

            jacobian = tape.batch_jacobian(temp, xtensor)
        elif keras.backend.backend() == "torch":
            import torch

        elif _KERAS_BACKEND == "torch":
            # add new axis for vmap
            xtensor = torch.tensor(x_data, requires_grad=True)[:, None, ...]
            jacobian = torch.vmap(torch.func.jacrev(_model), randomness="different")(xtensor)
            xtensor = backend_framework.tensor(x_data, requires_grad=True)[:, None, ...]
            jacobian = backend_framework.vmap(backend_framework.func.jacrev(_model), randomness="different")(xtensor)
        else:
            raise ValueError("Only Tensorflow and PyTorch backend is supported")


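The PyTorch branch above computes per-sample Jacobians by vectorizing torch.func.jacrev with torch.vmap. A self-contained sketch of that pattern, using a toy linear model in place of the astroNN model (the model and shapes are assumptions for illustration only):

import torch

model = torch.nn.Linear(4, 3)          # stand-in for the trained network
x_data = torch.randn(8, 4)             # batch of 8 inputs with 4 features

# Add a singleton axis so each vmapped slice is a 1-sample batch, mirroring the diff.
xtensor = x_data.clone().requires_grad_(True)[:, None, ...]

# jacrev builds the Jacobian of the model w.r.t. its input; vmap maps it over the
# batch, with randomness="different" so stochastic layers would vary per sample.
jacobian = torch.vmap(torch.func.jacrev(model), randomness="different")(xtensor)
print(jacobian.shape)                  # torch.Size([8, 1, 3, 1, 4])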
26 changes: 11 additions & 15 deletions src/astroNN/models/base_vae.py
@@ -289,7 +289,7 @@ def custom_train_step(self, data):
        y = keras.ops.cast(y["output"], backend_framework.float32)

        # Run forward pass.
        if keras.backend.backend() == "tensorflow":
        if _KERAS_BACKEND == "tensorflow":
            with backend_framework.GradientTape() as tape:
                encoder_output = self.keras_encoder(x, training=True)
                if isinstance(encoder_output, dict):
@@ -313,7 +313,7 @@ def custom_train_step(self, data):
            self.keras_model.optimizer.apply_gradients(
                zip(gradients, self.keras_model.trainable_weights)
            )
        elif keras.backend.backend() == "torch":
        elif _KERAS_BACKEND == "torch":
            self.keras_model.zero_grad()
            encoder_output = self.keras_encoder(x, training=True)
            if isinstance(encoder_output, dict):
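For reference, the TensorFlow branch of custom_train_step above follows the standard GradientTape pattern: forward pass under the tape, compute the loss, then apply gradients through the optimizer. A generic sketch of that pattern with placeholder model, loss, and data (none of these are astroNN objects):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal((16, 4))
y = tf.random.normal((16, 1))

with tf.GradientTape() as tape:
    y_pred = model(x, training=True)      # forward pass recorded on the tape
    loss = loss_fn(y, y_pred)

gradients = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(gradients, model.trainable_weights))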
@@ -936,28 +936,24 @@ def jacobian_latent(self, x=None, mean_output=False, mc_num=1, denormalize=False

        start_time = time.time()


        if keras.backend.backend() == "tensorflow":
            import tensorflow as tf
        if _KERAS_BACKEND == "tensorflow":
            xtensor = backend_framework.Variable(x_data)


            xtensor = tf.Variable(x_data)

            with tf.GradientTape(watch_accessed_variables=False) as tape:
            with backend_framework.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(xtensor)
                temp = _model(xtensor)


            jacobian = tf.squeeze(tape.batch_jacobian(temp, xtensor))
        elif keras.backend.backend() == "torch":
            import torch

            xtensor = torch.tensor(x_data, requires_grad=True)
            jacobian = torch.autograd.functional.jacobian(_model, xtensor)
            jacobian = backend_framework.squeeze(tape.batch_jacobian(temp, xtensor))
        elif _KERAS_BACKEND == "torch":
            xtensor = backend_framework.tensor(x_data, requires_grad=True)
            jacobian = backend_framework.autograd.functional.jacobian(_model, xtensor)

        else:
            raise ValueError("Only Tensorflow and PyTorch backend is supported")


        jacobian = tf.squeeze(tape.batch_jacobian(temp, xtensor))
        jacobian = backend_framework.squeeze(tape.batch_jacobian(temp, xtensor))


        if mean_output is True:
            jacobian_master = tf.reduce_mean(jacobian, axis=0).numpy()
            jacobian_master = backend_framework.reduce_mean(jacobian, axis=0).numpy()

        else:
            jacobian_master = jacobian.numpy()

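The TensorFlow path in jacobian_latent above relies on GradientTape.batch_jacobian, which returns one Jacobian per batch element. A minimal standalone sketch (the toy layer and shapes are assumptions, not astroNN code):

import tensorflow as tf

layer = tf.keras.layers.Dense(3)       # stand-in for the encoder
x_data = tf.random.normal((8, 4))

xtensor = tf.Variable(x_data)
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(xtensor)                # only track gradients w.r.t. the input
    temp = layer(xtensor)

# batch_jacobian gives shape (batch, output_dim, input_dim); squeeze drops any size-1 axes.
jacobian = tf.squeeze(tape.batch_jacobian(temp, xtensor))
print(jacobian.shape)                  # (8, 3, 4)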
6 changes: 3 additions & 3 deletions src/astroNN/nn/layers.py
@@ -1,7 +1,7 @@
import math

import keras
from astroNN.config import backend_framework
from astroNN.config import _KERAS_BACKEND, backend_framework

class KLDivergenceLayer(keras.layers.Layer):
    """
@@ -337,12 +337,12 @@ def loop_fn(i):
            return self.layer(inputs)

        # vectorizing operation depends on backend
        if keras.backend.backend() == "torch":
        if _KERAS_BACKEND == "torch":
            outputs = backend_framework.vmap(
                loop_fn, randomness="different", in_dims=0
            )(self.arange_n)
        # TODO: tensorflow vectorized_map traced operation so there is no randomness which affects e.g., dropout
        # elif keras.backend.backend() == "tensorflow":
        # elif _KERAS_BACKEND == "tensorflow":
        # outputs = backend_framework.vectorized_map(loop_fn, self.arange_n)
        else:  # fallback to simple for loop
            outputs = [self.layer(inputs) for _ in range(self.n)]
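The layers.py change above vectorizes n stochastic forward passes with torch.vmap(..., randomness="different") so that each Monte Carlo sample gets its own random draw. A toy sketch of the same idea, with a random masking function standing in for self.layer (all names here are illustrative only):

import torch

inputs = torch.ones(2, 4)   # a small dummy batch
n = 10                      # number of Monte Carlo forward passes

def loop_fn(i):
    # stand-in for self.layer(inputs): one stochastic forward pass
    mask = (torch.rand_like(inputs) > 0.5).float()
    return inputs * mask

# randomness="different" lets each vmapped call draw a fresh random mask.
outputs = torch.vmap(loop_fn, randomness="different", in_dims=0)(torch.arange(n))
print(outputs.shape)        # torch.Size([10, 2, 4])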
