Skip to content

Commit

Permalink
[core] Introduce get_compatible_keys for ImageURIPlugin (#49612)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

Refactor the code so it can be overwritten in the ImageURIPlugin

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [ ] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(
  • Loading branch information
pcmoritz authored Jan 7, 2025
1 parent 60caa93 commit 0452613
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/ray/_private/runtime_env/agent/runtime_env_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray._private.ray_logging import setup_component_logger
from ray._private.runtime_env.conda import CondaPlugin
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.default_impl import get_image_uri_plugin
from ray._private.runtime_env.default_impl import get_image_uri_plugin_cls
from ray._private.runtime_env.java_jars import JavaJarsPlugin
from ray._private.runtime_env.image_uri import ContainerPlugin
from ray._private.runtime_env.pip import PipPlugin
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(
# and unify with nsight and other profilers.
self._nsight_plugin = NsightPlugin(self._runtime_env_dir)
self._mpi_plugin = MPIPlugin()
self._image_uri_plugin = get_image_uri_plugin(temp_dir)
self._image_uri_plugin = get_image_uri_plugin_cls()(temp_dir)

# TODO(architkulkarni): "base plugins" and third-party plugins should all go
# through the same code path. We should never need to refer to
Expand Down
4 changes: 2 additions & 2 deletions python/ray/_private/runtime_env/default_impl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ray._private.runtime_env.image_uri import ImageURIPlugin


def get_image_uri_plugin(ray_tmp_dir: str):
return ImageURIPlugin(ray_tmp_dir)
def get_image_uri_plugin_cls():
return ImageURIPlugin


def get_protocols_provider():
Expand Down
4 changes: 4 additions & 0 deletions python/ray/_private/runtime_env/image_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class ImageURIPlugin(RuntimeEnvPlugin):

name = "image_uri"

@staticmethod
def get_compatible_keys():
return {"image_uri", "config", "env_vars"}

def __init__(self, ray_tmp_dir: str):
self._ray_tmp_dir = ray_tmp_dir

Expand Down
6 changes: 5 additions & 1 deletion python/ray/runtime_env/runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ray
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
from ray._private.runtime_env.conda import get_uri as get_conda_uri
from ray._private.runtime_env.default_impl import get_image_uri_plugin_cls
from ray._private.runtime_env.pip import get_uri as get_pip_uri
from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
from ray._private.runtime_env.uv import get_uri as get_uv_uri
Expand Down Expand Up @@ -389,7 +390,10 @@ def __init__(
)

if self.get("image_uri"):
invalid_keys = set(runtime_env.keys()) - {"image_uri", "config", "env_vars"}
image_uri_plugin_cls = get_image_uri_plugin_cls()
invalid_keys = (
set(runtime_env.keys()) - image_uri_plugin_cls.get_compatible_keys()
)
if len(invalid_keys):
raise ValueError(
"The 'image_uri' field currently cannot be used "
Expand Down

0 comments on commit 0452613

Please sign in to comment.