diff --git a/CHANGELOG.md b/CHANGELOG.md index 82be87f95f..c1e8a720e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ## Breaking changes +- Remove deprecated function `pybamm_install_jax` ([#4362](https://github.com/pybamm-team/PyBaMM/pull/4362)) - Removed legacy python-IDAKLU solver. ([#4326](https://github.com/pybamm-team/PyBaMM/pull/4326)) # [v24.5](https://github.com/pybamm-team/PyBaMM/tree/v24.5) - 2024-07-26 diff --git a/docs/source/api/util.rst b/docs/source/api/util.rst index f187cfbabb..7496b59554 100644 --- a/docs/source/api/util.rst +++ b/docs/source/api/util.rst @@ -16,8 +16,6 @@ Utility functions .. autofunction:: pybamm.load -.. autofunction:: pybamm.install_jax - .. autofunction:: pybamm.have_jax .. autofunction:: pybamm.is_jax_compatible diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index 7e69afa839..97171b53b7 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -99,7 +99,7 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. pip install "pybamm[jax]" -The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) +The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. .. _optional-iree-mlir-support: diff --git a/docs/source/user_guide/installation/windows.rst b/docs/source/user_guide/installation/windows.rst index 02d9f8dd29..44dc79a7d3 100644 --- a/docs/source/user_guide/installation/windows.rst +++ b/docs/source/user_guide/installation/windows.rst @@ -75,7 +75,7 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. pip install "pybamm[jax]" -The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) +The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. Uninstall PyBaMM ---------------- diff --git a/pyproject.toml b/pyproject.toml index 7ab3d5f573..e4cce3eccd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,9 +134,6 @@ all = [ "pybamm[examples,plot,cite,bpx,tqdm]", ] -[project.scripts] -pybamm_install_jax = "pybamm.util:install_jax" - [project.entry-points."pybamm_parameter_sets"] Sulzer2019 = "pybamm.input.parameters.lead_acid.Sulzer2019:get_parameter_values" Ai2020 = "pybamm.input.parameters.lithium_ion.Ai2020:get_parameter_values" diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index a371fdbc03..75f5f4f160 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -16,7 +16,6 @@ from .util import ( get_parameters_filepath, have_jax, - install_jax, import_optional_dependency, is_jax_compatible, get_git_commit_info, diff --git a/src/pybamm/util.py b/src/pybamm/util.py index ee1431ecb1..fd94eb88f4 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -4,7 +4,6 @@ # The code in this file is adapted from Pints # (see https://github.com/pints-team/pints) # -import argparse import importlib.util import importlib.metadata import numbers @@ -12,9 +11,7 @@ import pathlib import pickle import subprocess -import sys import timeit -from platform import system import difflib from warnings import warn @@ -314,60 +311,6 @@ def is_constant_and_can_evaluate(symbol): return False -def install_jax(arguments=None): # pragma: no cover - """ - Install compatible versions of jax, jaxlib. - - Command Line Interface:: - - $ pybamm_install_jax - - | optional arguments: - | -h, --help show help message - | -f, --force force install compatible versions of jax and jaxlib - """ - parser = argparse.ArgumentParser(description="Install jax and jaxlib") - parser.add_argument( - "-f", - "--force", - action="store_true", - help="force install compatible versions of" - f" jax ({JAX_VERSION}) and jaxlib ({JAXLIB_VERSION})", - ) - - args = parser.parse_args(arguments) - - if system() == "Windows": - raise NotImplementedError("Jax is not available on Windows") - - # Raise an error if jax and jaxlib are already installed, but incompatible - # and --force is not set - elif importlib.util.find_spec("jax") is not None: - if not args.force and not is_jax_compatible(): - raise ValueError( - "Jax is already installed but the installed version of jax or jaxlib is" - " not supported by PyBaMM. \nYou can force install compatible versions" - f" of jax ({JAX_VERSION}) and jaxlib ({JAXLIB_VERSION}) using the" - " following command: \npybamm_install_jax --force" - ) - - msg = ( - "pybamm_install_jax is deprecated," - " use 'pip install pybamm[jax]' to install jax & jaxlib" - ) - warn(msg, DeprecationWarning, stacklevel=2) - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - f"jax>={JAX_VERSION}", - f"jaxlib>={JAXLIB_VERSION}", - ] - ) - - # https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports def import_optional_dependency(module_name, attribute=None): err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."