diff --git a/nerfstudio/scripts/train.py b/nerfstudio/scripts/train.py index fde0a5c37a..dd2300568a 100644 --- a/nerfstudio/scripts/train.py +++ b/nerfstudio/scripts/train.py @@ -62,6 +62,7 @@ from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion from nerfstudio.engine.trainer import TrainerConfig from nerfstudio.utils import comms, profiler +from nerfstudio.utils.available_devices import get_available_devices from nerfstudio.utils.rich_utils import CONSOLE DEFAULT_TIMEOUT = timedelta(minutes=30) @@ -226,6 +227,15 @@ def launch( def main(config: TrainerConfig) -> None: """Main function.""" + # Check if the specified device type is available + available_device_types = get_available_devices() + if config.machine.device_type not in available_device_types: + raise RuntimeError( + f"Specified device type '{config.machine.device_type}' is not available. " + f"Available device types: {available_device_types}. " + "Please specify a valid device type using the CLI option: --machine.device_type [cuda|mps|cpu]" + ) + if config.data: CONSOLE.log("Using --data alias for --data.pipeline.datamanager.data") config.pipeline.datamanager.data = config.data diff --git a/nerfstudio/utils/available_devices.py b/nerfstudio/utils/available_devices.py new file mode 100644 index 0000000000..d27aad6e14 --- /dev/null +++ b/nerfstudio/utils/available_devices.py @@ -0,0 +1,32 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal + +import torch + + +def get_available_devices() -> List[Literal["cpu", "cuda", "mps"]]: + """Determine the available devices on the machine + + Returns: + list: List of available device types + """ + available_devices: List[Literal["cpu", "cuda", "mps"]] = [] + if torch.cuda.is_available(): + available_devices.append("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + available_devices.append("mps") + available_devices.append("cpu") + return available_devices