Skip to content

Commit

Permalink
simplify CI horovod
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 7, 2020
1 parent b00991e commit 6cd1a83
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
4 changes: 3 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
PROJECT_ROOT = os.path.dirname(TEST_ROOT)
TEMP_PATH = os.path.join(PROJECT_ROOT, 'test_temp')

# todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages
if PROJECT_ROOT not in os.getenv('PYTHONPATH', ""):
os.environ['PYTHONPATH'] = f'{PROJECT_ROOT}:{os.environ.get("PYTHONPATH", "")}'
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
os.environ['PYTHONPATH'] = f'{PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'

# generate a list of random seeds for each test
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
Expand Down
10 changes: 2 additions & 8 deletions tests/models/data/horovod/train_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
import os
import sys

PATH_HERE = os.path.abspath(os.path.dirname(__file__))
PATH_ROOT = os.path.abspath(os.path.join(PATH_HERE, '..', '..', '..', '..'))
sys.path.insert(0, os.path.abspath(PATH_ROOT))
# this is need as e.g. Conda do not uses `PYTHONPATH` env var as pip or/and virtualenv
sys.path += os.getenv('PYTHONPATH').split(':')

from pytorch_lightning import Trainer # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
Expand All @@ -34,11 +33,6 @@
else:
print('You requested to import Horovod which is missing or not supported for your OS.')


# Move project root to the front of the search path, as some imports may have reordered things
idx = sys.path.index(PATH_ROOT)
sys.path[0], sys.path[idx] = sys.path[idx], sys.path[0]

from tests.base import EvalModelTemplate # noqa: E402
from tests.base.develop_pipelines import run_prediction # noqa: E402
from tests.base.develop_utils import set_random_master_port, reset_seed # noqa: E402
Expand Down

0 comments on commit 6cd1a83

Please sign in to comment.