Skip to content

Commit

Permalink
Remove config from task manager and stop killing it
Browse files — browse the repository at this point in the history
  • Branch information:
gabegma committed Jan 27, 2023
1 parent e8849f3 commit 8e72034
Show file tree
Hide file tree
Showing 22 changed files with 174 additions and 127 deletions.
38 changes: 16 additions & 22 deletions azimuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from azimuth.startup import startup_tasks
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, ModuleOptions
from azimuth.utils.cluster import default_cluster
from azimuth.utils.conversion import JSONResponseIgnoreNan
from azimuth.utils.logs import set_logger_config
from azimuth.utils.project import load_dataset_split_managers_from_config, save_config
Expand Down Expand Up @@ -100,9 +99,7 @@ def start_app(config_path, debug=False) -> FastAPI:
if azimuth_config.dataset is None:
raise ValueError("No dataset has been specified in the config.")

local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)

run_startup_tasks(azimuth_config, local_cluster)
run_startup_tasks(azimuth_config)
assert_not_none(_task_manager).client.run(set_logger_config, level)

app = create_app()
Expand Down Expand Up @@ -228,25 +225,23 @@ def create_app() -> FastAPI:
return app


def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
"""Initialize DatasetSplitManagers and TaskManagers.
def initialize_managers_and_config(
azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
):
"""Initialize DatasetSplitManagers and Config.
Args:
azimuth_config: Configuration
cluster: Dask cluster to use.
azimuth_config: Config
cluster: Dask cluster to use, if different than default.
"""
global _task_manager, _dataset_split_managers, _azimuth_config
_azimuth_config = azimuth_config
if _task_manager is not None:
task_history = _task_manager.current_tasks
if _task_manager:
_task_manager.clear_worker_cache()
_task_manager.restart()
else:
task_history = {}

_task_manager = TaskManager(azimuth_config, cluster=cluster)

_task_manager.current_tasks = task_history
_task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)

_azimuth_config = azimuth_config
_dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)


Expand Down Expand Up @@ -283,15 +278,14 @@ def run_validation_module(pipeline_index=None):
task_manager.restart()


def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
"""Initialize managers, run validation and startup tasks.
Args:
azimuth_config: Config
cluster: Cluster
cluster: Dask cluster to use, if different than default.
"""
initialize_managers(azimuth_config, cluster)
initialize_managers_and_config(azimuth_config, cluster)

task_manager = assert_not_none(get_task_manager())
# Validate that everything is in order **before** the startup tasks.
Expand All @@ -303,5 +297,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
save_config(azimuth_config) # Save only after the validation modules ran successfully

global _startup_tasks, _ready_flag
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
_ready_flag = Event()
6 changes: 3 additions & 3 deletions azimuth/routers/v1/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azimuth.app import (
get_config,
get_task_manager,
initialize_managers,
initialize_managers_and_config,
require_editable_config,
run_startup_tasks,
)
Expand Down Expand Up @@ -77,11 +77,11 @@ def patch_config(
) -> AzimuthConfig:
try:
new_config = update_config(old_config=config, partial_config=partial_config)
run_startup_tasks(new_config, task_manager.cluster)
run_startup_tasks(new_config)
except Exception as e:
log.error("Rollback config update due to error", exc_info=e)
new_config = config
initialize_managers(new_config, task_manager.cluster)
initialize_managers_and_config(new_config)
if isinstance(e, AzimuthValidationError):
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
elif isinstance(e, ValidationError):
Expand Down
18 changes: 7 additions & 11 deletions azimuth/routers/v1/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,12 @@ def get_dataset_info(
get_dataset_split_manager_mapping
),
startup_tasks: Dict[str, Module] = Depends(get_startup_tasks),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
):
eval_dm = dataset_split_managers.get(DatasetSplitName.eval)
training_dm = dataset_split_managers.get(DatasetSplitName.train)
dm = assert_not_none(eval_dm or training_dm)

model_contract = task_manager.config.model_contract

return DatasetInfoResponse(
project_name=config.name,
class_names=dm.get_class_names(),
Expand All @@ -109,19 +106,16 @@ def get_dataset_info(
if training_dm is not None
else [],
startup_tasks={k: v.status() for k, v in startup_tasks.items()},
model_contract=model_contract,
prediction_available=predictions_available(task_manager.config),
perturbation_testing_available=perturbation_testing_available(task_manager.config),
model_contract=config.model_contract,
prediction_available=predictions_available(config),
perturbation_testing_available=perturbation_testing_available(config),
available_dataset_splits=AvailableDatasetSplits(
eval=eval_dm is not None, train=training_dm is not None
),
similarity_available=similarity_available(task_manager.config),
similarity_available=similarity_available(config),
postprocessing_editable=None
if config.pipelines is None
else [
postprocessing_editable(task_manager.config, idx)
for idx in range(len(config.pipelines))
],
else [postprocessing_editable(config, idx) for idx in range(len(config.pipelines))],
)


Expand Down Expand Up @@ -177,6 +171,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingMerged,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
Expand All @@ -192,6 +187,7 @@ def get_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0]
return PerturbationTestingSummary(
Expand Down
4 changes: 4 additions & 0 deletions azimuth/routers/v1/class_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_class_overlap_plot(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
Expand All @@ -97,6 +98,7 @@ def get_class_overlap_plot(
def get_class_overlap(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_dataset_split_manager_mapping
Expand All @@ -109,6 +111,7 @@ def get_class_overlap(
SupportedModule.ClassOverlap,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
last_update=-1,
)[0]
dataset_class_count = class_overlap_result.s_matrix.shape[0]
Expand All @@ -124,6 +127,7 @@ def get_class_overlap(
SupportedModule.ConfusionMatrix,
DatasetSplitName.eval,
task_manager=task_manager,
config=config,
mod_options=ModuleOptions(
pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
),
Expand Down
4 changes: 3 additions & 1 deletion azimuth/routers/v1/custom_utterances.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ def get_saliency(
utterances: List[str] = Query([], title="Utterances"),
pipeline_index: int = Depends(require_pipeline_index),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
) -> List[SaliencyResponse]:
task_result: List[SaliencyResponse] = get_custom_task_result(
SupportedMethod.Saliency,
task_manager=task_manager,
custom_query={task_manager.config.columns.text_input: utterances},
config=config,
custom_query={config.columns.text_input: utterances},
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)

Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/v1/dataset_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from fastapi import APIRouter, Depends

from azimuth.app import get_all_dataset_split_managers, get_task_manager
from azimuth.app import get_all_dataset_split_managers, get_config, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import DatasetSplitName, SupportedModule
Expand All @@ -28,6 +29,7 @@
)
def get_dataset_warnings(
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
get_all_dataset_split_managers
),
Expand All @@ -42,6 +44,7 @@ def get_dataset_warnings(
SupportedModule.DatasetWarnings,
dataset_split_name=DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
)[0]

Expand Down
15 changes: 6 additions & 9 deletions azimuth/routers/v1/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,17 @@ def get_export_perturbation_testing_summary(
SupportedModule.PerturbationTestingSummary,
DatasetSplitName.all,
task_manager=task_manager,
config=config,
last_update=last_update,
mod_options=ModuleOptions(pipeline_index=pipeline_index),
)[0].all_tests_summary

cfg = task_manager.config
df = pd.DataFrame.from_records([t.dict() for t in task_result])
df["example"] = df["example"].apply(lambda i: i["perturbedUtterance"])
file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())

filename = f"azimuth_export_behavioral_testing_summary_{cfg.name}_{file_label}.csv"

pt = pjoin(cfg.get_artifact_path(), filename)

filename = f"azimuth_export_behavioral_testing_summary_{config.name}_{file_label}.csv"
pt = pjoin(config.get_artifact_path(), filename)
df.to_csv(pt, index=False)

return FileResponse(path=pt, filename=filename)
Expand All @@ -135,15 +133,14 @@ def get_export_perturbed_set(
) -> FileResponse:
pipeline_index_not_null = assert_not_none(pipeline_index)
file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
cfg = task_manager.config

filename = f"azimuth_export_modified_set_{cfg.name}_{dataset_split_name}_{file_label}.json"
pt = pjoin(cfg.get_artifact_path(), filename)
filename = f"azimuth_export_modified_set_{config.name}_{dataset_split_name}_{file_label}.json"
pt = pjoin(config.get_artifact_path(), filename)

task_result: List[List[PerturbedUtteranceResult]] = get_standard_task_result(
SupportedModule.PerturbationTesting,
dataset_split_name,
task_manager,
config=config,
mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
)

Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/v1/model_performance/confidence_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -36,6 +37,7 @@ def get_confidence_histogram(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -50,6 +52,7 @@ def get_confidence_histogram(
task_name=SupportedModule.ConfidenceHistogram,
dataset_split_name=dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
5 changes: 4 additions & 1 deletion azimuth/routers/v1/model_performance/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.task_manager import TaskManager
from azimuth.types import (
Expand Down Expand Up @@ -36,6 +37,7 @@ def get_confusion_matrix(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -54,6 +56,7 @@ def get_confusion_matrix(
SupportedModule.ConfusionMatrix,
dataset_split_name,
task_manager=task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
8 changes: 7 additions & 1 deletion azimuth/routers/v1/model_performance/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from fastapi import APIRouter, Depends, Query

from azimuth.app import get_dataset_split_manager, get_task_manager
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
from azimuth.config import AzimuthConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.model_performance.metrics import MetricsModule
from azimuth.task_manager import TaskManager
Expand Down Expand Up @@ -44,6 +45,7 @@ def get_metrics(
dataset_split_name: DatasetSplitName,
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
Expand All @@ -58,6 +60,7 @@ def get_metrics(
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)
Expand All @@ -77,6 +80,7 @@ def get_metrics(
def get_metrics_per_filter(
dataset_split_name: DatasetSplitName,
task_manager: TaskManager = Depends(get_task_manager),
config: AzimuthConfig = Depends(get_config),
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
pipeline_index: int = Depends(require_pipeline_index),
) -> MetricsPerFilterAPIResponse:
Expand All @@ -85,6 +89,7 @@ def get_metrics_per_filter(
SupportedModule.MetricsPerFilter,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand All @@ -93,6 +98,7 @@ def get_metrics_per_filter(
SupportedModule.Metrics,
dataset_split_name,
task_manager,
config=config,
mod_options=mod_options,
last_update=dataset_split_manager.last_update,
)[0]
Expand Down
Loading

0 comments on commit 8e72034

Please sign in to comment.