From 22609f7157a0b772d91060621d730695b642239c Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Tue, 25 Jun 2024 14:09:02 -0700 Subject: [PATCH] Aesthetic + small bug fixes to Vizier service PiperOrigin-RevId: 646600873 --- vizier/_src/service/vizier_service.py | 75 +++++++++++---------------- 1 file changed, 30 insertions(+), 45 deletions(-) diff --git a/vizier/_src/service/vizier_service.py b/vizier/_src/service/vizier_service.py index 0c6194512..d65fa122b 100644 --- a/vizier/_src/service/vizier_service.py +++ b/vizier/_src/service/vizier_service.py @@ -15,6 +15,7 @@ from __future__ import annotations """RPC functions implemented from vizier_service.proto.""" + import collections import datetime import threading @@ -24,7 +25,6 @@ import grpc import numpy as np import sqlalchemy as sqla - from vizier import pythia from vizier import pyvizier as vz from vizier._src.service import constants @@ -55,6 +55,10 @@ def _get_current_time() -> timestamp_pb2.Timestamp: return now +StudyResource = resources.StudyResource +TrialResource = resources.TrialResource + + # TODO: remove context = None # TODO: remove context = None class VizierServicer(vizier_service_pb2_grpc.VizierServiceServicer): @@ -196,7 +200,7 @@ def CreateStudy( study_id = study.display_name # Finally create study in database and return it. - study.name = resources.StudyResource(owner_id, study_id).name + study.name = StudyResource(owner_id, study_id).name self.datastore.create_study(study) return study @@ -214,8 +218,8 @@ def ListStudies( context: Optional[grpc.ServicerContext] = None, ) -> vizier_service_pb2.ListStudiesResponse: """Lists all the studies in a region for an associated project.""" - list_of_studies = self.datastore.list_studies(request.parent) - return vizier_service_pb2.ListStudiesResponse(studies=list_of_studies) + studies = self.datastore.list_studies(request.parent) + return vizier_service_pb2.ListStudiesResponse(studies=studies) def DeleteStudy( self, @@ -283,7 +287,7 @@ def SuggestTrials( ) grpc_util.handle_exception(e, context) - study_resource = resources.StudyResource.from_name(study_name) + study_resource = StudyResource.from_name(study_name) study_id = study_resource.study_id owner_id = study_resource.owner_id @@ -306,14 +310,12 @@ def SuggestTrials( start_time = _get_current_time() # Create a new Op if there aren't any active (not done) ops. try: - new_op_number = ( - self.datastore.max_suggestion_operation_number( - study_name, request.client_id - ) - + 1 + old_op_number = self.datastore.max_suggestion_operation_number( + study_name, request.client_id ) except custom_errors.NotFoundError: - new_op_number = 1 + old_op_number = 0 + new_op_number = old_op_number + 1 new_op_name = resources.SuggestionOperationResource( owner_id, study_id, request.client_id, new_op_number ).name @@ -441,9 +443,7 @@ def SuggestTrials( new_trial = new_trials.pop() trial_id = self.datastore.max_trial_id(request.parent) + 1 new_trial.id = str(trial_id) - new_trial.name = resources.TrialResource( - owner_id, study_id, trial_id - ).name + new_trial.name = TrialResource(owner_id, study_id, trial_id).name new_trial.state = study_pb2.Trial.State.ACTIVE new_trial.start_time.CopyFrom(start_time) new_trial.client_id = request.client_id @@ -455,14 +455,12 @@ def SuggestTrials( ).SerializeToString() # Store remaining trials as REQUESTED if Pythia over-delivered. - for remaining_trial in new_trials: + for remain_trial in new_trials: trial_id = self.datastore.max_trial_id(request.parent) + 1 - remaining_trial.id = str(trial_id) - remaining_trial.name = resources.TrialResource( - owner_id, study_id, trial_id - ).name - remaining_trial.state = study_pb2.Trial.State.REQUESTED - self.datastore.create_trial(new_trial) + remain_trial.id = str(trial_id) + remain_trial.name = TrialResource(owner_id, study_id, trial_id).name + remain_trial.state = study_pb2.Trial.State.REQUESTED + self.datastore.create_trial(remain_trial) output_op.done = True self.datastore.update_suggestion_operation(output_op) @@ -491,11 +489,8 @@ def CreateTrial( trial = request.trial with self._study_name_to_lock[request.parent]: trial.id = str(self.datastore.max_trial_id(request.parent) + 1) - trial.name = ( - resources.StudyResource.from_name(request.parent).trial_resource( - trial_id=trial.id - ) - ).name + study_resource = StudyResource.from_name(request.parent) + trial.name = (study_resource.trial_resource(trial.id)).name if trial.state != study_pb2.Trial.State.SUCCEEDED: trial.state = study_pb2.Trial.State.REQUESTED @@ -543,9 +538,7 @@ def AddTrialMeasurement( ImmutableStudyError: If study was already immutable. ImmutableTrialError: If the trial cannot be modified. """ - study_name = resources.TrialResource.from_name( - request.trial_name - ).study_resource.name + study_name = TrialResource.from_name(request.trial_name).study_resource.name if self._study_is_immutable(study_name): e = custom_errors.ImmutableStudyError( 'Study {} is immutable. Cannot add measurement.'.format(study_name) @@ -577,9 +570,7 @@ def CompleteTrial( context: Optional[grpc.ServicerContext] = None, ) -> study_pb2.Trial: """Marks a Trial as complete.""" - study_name = resources.TrialResource.from_name( - request.name - ).study_resource.name + study_name = TrialResource.from_name(request.name).study_resource.name if self._study_is_immutable(study_name): e = custom_errors.ImmutableStudyError( 'Study {} is immutable. Cannot complete trial.'.format(study_name) @@ -625,9 +616,7 @@ def DeleteTrial( context: Optional[grpc.ServicerContext] = None, ) -> empty_pb2.Empty: """Deletes a Trial.""" - study_name = resources.TrialResource.from_name( - request.name - ).study_resource.name + study_name = TrialResource.from_name(request.name).study_resource.name if self._study_is_immutable(study_name): e = custom_errors.ImmutableStudyError( 'Study {} is immutable. Cannot delete trial.'.format(study_name) @@ -679,7 +668,7 @@ def CheckTrialEarlyStoppingState( ImmutableStudyError: If study was already immutable. ImmutableTrialError: If the trial cannot be modified. """ - trial_resource = resources.TrialResource.from_name(request.trial_name) + trial_resource = TrialResource.from_name(request.trial_name) study_name = trial_resource.study_resource.name if self._study_is_immutable(study_name): e = custom_errors.ImmutableStudyError( @@ -841,9 +830,7 @@ def StopTrial( ImmutableStudyError: If study was already immutable. ImmutableTrialError: If the trial cannot be modified. """ - study_name = resources.TrialResource.from_name( - request.name - ).study_resource.name + study_name = TrialResource.from_name(request.name).study_resource.name if self._study_is_immutable(study_name): e = custom_errors.ImmutableStudyError( 'Study {} is immutable. Cannot stop trial.'.format(study_name) @@ -926,12 +913,10 @@ def ListOptimalTrials( # Find Pareto optimal trials. ys = np.array(considered_trial_objective_vectors) n = ys.shape[0] - dominated = np.asarray( - [ - [np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)] - for j in range(n) - ] - ) + dominated = np.asarray([ + [np.all(ys[i] <= ys[j]) & np.any(ys[j] > ys[i]) for i in range(n)] + for j in range(n) + ]) optimal_booleans = np.logical_not(np.any(dominated, axis=0)) optimal_trials = [] for i, boolean in enumerate(list(optimal_booleans)):