Skip to content

Commit

Permalink
Migrate GS usages to new module location (#3311)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/aepsych#623

Pull Request resolved: #3311

Removing the following old files from `ax/modelbridge/` in favor of the new `ax/generation_strategy` directory.

```
best_model_selector.py
dispatch_utils.py
external_generation_node.py
generation_node_input_constructors.py
generation_node.py
generation_strategy.py
model_spec.py
transition_criterion.py
```

Reviewed By: saitcakmak

Differential Revision: D68645075
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Feb 6, 2025
1 parent 0f8660a commit f58e30c
Show file tree
Hide file tree
Showing 66 changed files with 4,066 additions and 4,062 deletions.
2 changes: 1 addition & 1 deletion ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from pyre_extensions import assert_is_instance, none_throws

Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/healthcheck/tests/test_constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from ax.core.generator_run import GeneratorRun
from ax.core.outcome_constraint import OutcomeConstraint
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.logger import get_logger
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/arm_effects/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from pyre_extensions import assert_is_instance, none_throws

Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.cross_validation import cross_validate
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import express as px, graph_objects as go
from pyre_extensions import assert_is_instance, none_throws

Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/surface/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go
from pyre_extensions import none_throws

Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import express as px, graph_objects as go
from pyre_extensions import none_throws

Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.core.observation import ObservationFeatures
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.core.types import TParameterization
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy

from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.base import Base
from pyre_extensions import none_throws
Expand Down
4 changes: 2 additions & 2 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from typing import Any

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.generation_strategy.generation_node import GenerationStep
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from botorch.acquisition.acquisition import AcquisitionFunction
Expand Down
5 changes: 4 additions & 1 deletion ax/benchmark/methods/sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@


from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.generation_strategy.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.registry import Generators


Expand Down
9 changes: 6 additions & 3 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
from ax.modelbridge.external_generation_node import ExternalGenerationNode
from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
from ax.generation_strategy.generation_strategy import (
GenerationNode,
GenerationStrategy,
)
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.service.utils.scheduler_options import TrialType
from ax.storage.json_store.load import load_experiment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.generation_strategy.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
from ax.generation_strategy.transition_criterion import (
MinimumPreferenceOccurances,
MinTrials,
)
from ax.modelbridge.registry import Generators
from ax.modelbridge.transition_criterion import MinimumPreferenceOccurances, MinTrials
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from unittest.mock import Mock, patch

from ax.exceptions.core import UserInputError
from ax.modelbridge.best_model_selector import (
from ax.generation_strategy.best_model_selector import (
ReductionCriterion,
SingleDiagnosticBestModelSelector,
)
from ax.modelbridge.model_spec import GeneratorSpec
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
from ax.core.objective import MultiObjective
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.modelbridge.dispatch_utils import (
from ax.generation_strategy.dispatch_utils import (
_make_botorch_step,
calculate_num_initialization_trials,
choose_generation_strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ax.core.observation import ObservationFeatures
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.external_generation_node import ExternalGenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.random import RandomAdapter
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,26 @@
from ax.core.base_trial import TrialStatus
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.best_model_selector import (
from ax.generation_strategy.best_model_selector import (
ReductionCriterion,
SingleDiagnosticBestModelSelector,
)
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import (
from ax.generation_strategy.generation_node import (
GenerationNode,
GenerationStep,
MISSING_MODEL_SELECTOR_MESSAGE,
)
from ax.modelbridge.generation_node_input_constructors import (
from ax.generation_strategy.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec
from ax.generation_strategy.model_spec import (
FactoryFunctionGeneratorSpec,
GeneratorSpec,
)
from ax.generation_strategy.transition_criterion import MinTrials
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.registry import Generators
from ax.modelbridge.transition_criterion import MinTrials
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.exceptions.generation_strategy import AxGenerationException
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_node_input_constructors import (
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,24 @@
GenerationStrategyRepeatedPoints,
MaxParallelismReachedException,
)
from ax.modelbridge.best_model_selector import SingleDiagnosticBestModelSelector
from ax.modelbridge.discrete import DiscreteAdapter
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_node_input_constructors import (
from ax.generation_strategy.best_model_selector import SingleDiagnosticBestModelSelector
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.generation_node_input_constructors import (
InputConstructorPurpose,
NodeInputConstructors,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.generation_strategy.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import (
AutoTransitionAfterGen,
MaxGenerationParallelism,
MinTrials,
)
from ax.modelbridge.discrete import DiscreteAdapter
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.random import RandomAdapter
from ax.modelbridge.registry import (
_extract_model_state_after_gen,
Expand All @@ -48,11 +56,6 @@
MODEL_KEY_TO_MODEL_SETUP,
)
from ax.modelbridge.torch import TorchAdapter
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
MaxGenerationParallelism,
MinTrials,
)
from ax.models.random.sobol import SobolGenerator
from ax.utils.common.constants import Keys
from ax.utils.common.equality import same_elements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@

from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.generation_strategy.model_spec import (
FactoryFunctionGeneratorSpec,
GeneratorSpec,
)
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.exceptions.core import UserInputError
from ax.modelbridge.generation_strategy import (
from ax.generation_strategy.generation_strategy import (
GenerationNode,
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.modelbridge.transition_criterion import (
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import (
AutoTransitionAfterGen,
AuxiliaryExperimentCheck,
IsSingleObjective,
Expand All @@ -32,6 +31,7 @@
MinimumTrialsInStatus,
MinTrials,
)
from ax.modelbridge.registry import Generators
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down
5 changes: 4 additions & 1 deletion ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from ax.core.experiment import Experiment
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.generation_strategy.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
from ax.metrics.branin import BraninMetric
from ax.modelbridge.cross_validation import (
_predict_on_cross_validation_data,
_predict_on_training_data,
compute_model_fit_metrics_from_modelbridge,
get_fit_and_std_quality_and_generalization_dict,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions
Expand Down
2 changes: 1 addition & 1 deletion ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.preview.api.configs import (
ExperimentConfig,
GenerationStrategyConfig,
Expand Down
9 changes: 6 additions & 3 deletions ax/preview/modelbridge/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import torch
from ax.core.base_trial import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.generation_strategy.generation_strategy import (
GenerationNode,
GenerationStrategy,
)
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MinTrials
from ax.modelbridge.registry import Generators
from ax.modelbridge.transition_criterion import MinTrials
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig
from botorch.models.transforms.input import Normalize, Warp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.generation_strategy.transition_criterion import MinTrials
from ax.modelbridge.registry import Generators
from ax.modelbridge.transition_criterion import MinTrials
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig
from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy
Expand Down
3 changes: 2 additions & 1 deletion ax/runners/tests/test_torchx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
RangeParameter,
SearchSpace,
)

from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.metrics.torchx import TorchXMetric
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.runners.torchx import TorchXRunner
from ax.service.scheduler import FailureRateExceededError, Scheduler, SchedulerOptions
from ax.utils.common.constants import Keys
Expand Down
4 changes: 2 additions & 2 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@
UserInputError,
)
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from ax.global_stopping.strategies.improvement import constraint_satisfaction
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.prediction_utils import predict_by_features
from ax.plot.base import AxPlotConfig
from ax.plot.contour import plot_contour
Expand Down
Loading

0 comments on commit f58e30c

Please sign in to comment.