diff --git a/src/super_gradients/common/environment/package_utils.py b/src/super_gradients/common/environment/package_utils.py
new file mode 100644
index 0000000000..4737b1eaed
--- /dev/null
+++ b/src/super_gradients/common/environment/package_utils.py
@@ -0,0 +1,7 @@
+import pkg_resources
+from typing import Dict
+
+
+def get_installed_packages() -> Dict[str, str]:
+    """Map all the installed packages to their version."""
+    return {package.key.lower(): package.version for package in pkg_resources.working_set}
diff --git a/src/super_gradients/sanity_check/env_sanity_check.py b/src/super_gradients/sanity_check/env_sanity_check.py
index e66ea30951..2f001893e3 100644
--- a/src/super_gradients/sanity_check/env_sanity_check.py
+++ b/src/super_gradients/sanity_check/env_sanity_check.py
@@ -8,6 +8,7 @@
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.environment.ddp_utils import is_main_process
+from super_gradients.common.environment.package_utils import get_installed_packages
 
 logger = get_logger(__name__, "DEBUG")
 
@@ -79,7 +80,7 @@ def check_packages():
     """
     test_name = "installed packages"
 
-    installed_packages = {package.key.lower(): package.version for package in pkg_resources.working_set}
+    installed_packages = get_installed_packages()
 
     requirements = get_requirements(use_pro_requirements="deci-platform-client" in installed_packages)
     if requirements is None:
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 5f224838f6..97ecea6f9c 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -11,7 +11,7 @@
 import torch.nn
 import torchmetrics
 from omegaconf import DictConfig, OmegaConf
-from piptools.scripts.sync import _get_installed_distributions
+
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.utils.data import DataLoader, SequentialSampler
@@ -40,6 +40,7 @@
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
+from super_gradients.common.environment.package_utils import get_installed_packages
 
 from super_gradients.training import utils as core_utils, models, dataloaders
 from super_gradients.training.datasets.samplers import RepeatAugSampler
@@ -1875,8 +1876,7 @@ def _get_hyper_param_config(self):
         }
         # ADD INSTALLED PACKAGE LIST + THEIR VERSIONS
         if self.training_params.log_installed_packages:
-            pkg_list = list(map(lambda pkg: str(pkg), _get_installed_distributions()))
-            additional_log_items["installed_packages"] = pkg_list
+            additional_log_items["installed_packages"] = get_installed_packages()
 
         dataset_params = {
             "train_dataset_params": self.train_loader.dataset.dataset_params if hasattr(self.train_loader.dataset, "dataset_params") else None,