diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 00f246f120..539ca59ec0 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: [3.12] steps: - name: Checkout uses: actions/checkout@v1 @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: [3.12] shard: [0, 1, 2, 3, 4] env: SHARD: ${{ matrix.shard }} diff --git a/setup.py b/setup.py index 6834b256d1..9768d8f238 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,13 @@ else: TFDS_PACKAGE = 'tfds-nightly' +if release: + TF_PACKAGE = 'tensorflow >= 2.15' + KERAS_PACKAGE = 'tf-keras >= 2.15' +else: + TF_PACKAGE = 'tf-nightly' + KERAS_PACKAGE = 'tf-keras-nightly' + class BinaryDistribution(Distribution): """This class is needed in order to create OS specific wheels.""" @@ -91,6 +98,7 @@ def has_ext_modules(self): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Artificial Intelligence', @@ -101,6 +109,7 @@ def has_ext_modules(self): keywords='tensorflow probability statistics bayesian machine learning', extras_require={ # e.g. `pip install tfp-nightly[jax]` 'jax': ['jax', 'jaxlib'], + 'tf': [TF_PACKAGE, KERAS_PACKAGE], 'tfds': [TFDS_PACKAGE], } ) diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index f4742348ad..d81375f6f4 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -35,7 +35,7 @@ def _validate_tf_environment(package): inadequate. """ try: - import tensorflow.compat.v1 as tf + import tensorflow as tf except (ImportError, ModuleNotFoundError): # Print more informative error message, then reraise. print('\n\nFailed to import TensorFlow. Please note that TensorFlow is not ' @@ -51,7 +51,7 @@ def _validate_tf_environment(package): # # Update this whenever we need to depend on a newer TensorFlow release. # - required_tensorflow_version = '2.14' + required_tensorflow_version = '2.15' # required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport if (distutils.version.LooseVersion(tf.__version__) < @@ -74,6 +74,18 @@ def _validate_tf_environment(package): 'For more detail, see https://github.com/tensorflow/community/pull/287.' ) + if required_tensorflow_version[0] == '2': + try: + import tf_keras # pylint: disable=unused-import + except (ImportError, ModuleNotFoundError): + # Print more informative error message, then reraise. + print('\n\nFailed to import TF-Keras. Please note that TF-Keras is not ' + 'installed by default when you install TensorFlow Probability. ' + 'This is so that JAX-only users do not have to install TensorFlow ' + 'or TF-Keras. To use TensorFlow Probability with TensorFlow, ' + 'please install the tf-keras or tf-keras-nightly package.\n\n') + raise + # Declare these explicitly to appease pytype, which otherwise misses them, # presumably due to lazy loading.