diff --git a/src/prog_server/models/prog_server.py b/src/prog_server/models/prog_server.py index 70dee13..655421a 100644 --- a/src/prog_server/models/prog_server.py +++ b/src/prog_server/models/prog_server.py @@ -19,7 +19,7 @@ class ProgServer(): def __init__(self): self.process = None - def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, **kwargs) -> None: + def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, predictors={}, state_estimators={}, **kwargs) -> None: """Run the server (blocking) Keyword Args: @@ -27,7 +27,8 @@ def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, pred port (int, optional): Server port. Defaults to 8555. debug (bool, optional): If the server is started in debug mode models (dict[str, PrognosticsModel]): a dictionary of extra models to consider. The key is the name used to identify it. - predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it. + predictors (dict[str, predictors.Predictor]): a dictionary of extra predictors to consider. The key is the name used to identify it. + state_estimators (dict[str, state_estimators.StateEstimator]): a dictionary of extra estimators to consider. The key is the name used to identify it. """ if not isinstance(models, dict): raise TypeError("Extra models (`model` arg in prog_server.run() or start()) must be in a dictionary in the form `name: model_name`") @@ -39,6 +40,11 @@ def run(self, host=DEFAULT_HOST, port=DEFAULT_PORT, debug=False, models={}, pred session.extra_predictors.update(predictors) + if not isinstance(state_estimators, dict): + raise TypeError("Custom Estimator (`state_estimators` arg in prog_server.run() or start()) must be in a dictionary in the form `name: state_est_name`") + + session.extra_estimators.update(state_estimators) + self.host = host self.port = port self.process = app.run(host=host, port=port, debug=debug) diff --git a/src/prog_server/models/session.py b/src/prog_server/models/session.py index 2396f6d..4ed6ad5 100644 --- a/src/prog_server/models/session.py +++ b/src/prog_server/models/session.py @@ -13,7 +13,7 @@ extra_models = {} extra_predictors = {} - +extra_estimators = {} class Session(): def __init__(self, session_id, @@ -116,23 +116,41 @@ def __init__(self, session_id, # Otherwise, will have to be initialized later # Check state estimator and predictor data try: - getattr(state_estimators, state_est_name) + if self.state_est_name not in extra_estimators: + getattr(state_estimators, state_est_name) except AttributeError: abort(400, f"Invalid state estimator name {state_est_name}") def __initialize(self, x0, predict_queue=True): app.logger.debug("Initializing...") - state_est_class = getattr(state_estimators, self.state_est_name) + #Estimator + try: + if self.state_est_name in extra_estimators: + state_est_class = extra_estimators[self.state_est_name] + else: + state_est_class = getattr(state_estimators, self.state_est_name) + except AttributeError: + abort(400, f"Invalid state estimator name {self.state_est_name}") app.logger.debug(f"Creating State Estimator of type {self.state_est_name}") + if isinstance(x0, str): - x0 = json.loads(x0) + x0 = json.loads(x0) #loads the initial state if set(self.model.states) != set(list(x0.keys())): abort(400, f"Initial state must have every state in model. states. Had {list(x0.keys())}, needed {self.model.states}") - - try: - self.state_est = state_est_class(self.model, x0, **self.state_est_cfg) - except Exception as e: - abort(400, f"Could not instantiate state estimator with input: {e}") + + if isinstance(state_est_class, type) and issubclass(state_est_class, state_estimators.StateEstimator): + try: + self.state_est = state_est_class(self.model, x0, **self.state_est_cfg) + except Exception as e: + abort(400, f"Could not instantiate state estimator with input: {e}") + elif isinstance(state_est_class, state_estimators.StateEstimator): + # state_est_class is an instance of state_estimators.StateEstimator - use the object instead + # This happens for user state estimators that are added to the server at startup. + self.state_est = deepcopy(state_est_class) + # Apply any configuration changes, overriding estimator config + self.state_est.parameters.update(self.state_est_cfg) + else: + abort(400, f"Invalid state estimator type {type(self.state_est_name)} for estimator {self.state_est_name}. For custom classes, the state estimator must be mentioned with quotes in the est argument") self.initialized = True if predict_queue: diff --git a/tests/integration.py b/tests/integration.py index fec7183..602a730 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -8,6 +8,11 @@ from progpy.predictors import MonteCarlo from progpy.uncertain_data import MultivariateNormalDist from progpy.models import ThrownObject +from progpy.models.thrown_object import LinearThrownObject +from progpy.state_estimators import KalmanFilter +from progpy.state_estimators import ParticleFilter +from progpy.uncertain_data import MultivariateNormalDist +from progpy.uncertain_data import UnweightedSamples class IntegrationTest(unittest.TestCase): @@ -307,6 +312,53 @@ def test_custom_predictors(self): prog_server.stop() prog_server.start() + def test_custom_estimators(self): + # Restart server with model + prog_server.stop() + + #define the custom model + ball = LinearThrownObject(thrower_height=1.5, throwing_speed=20) + x_guess = ball.StateContainer({'x': 1.75, 'v': 35}) + kf = KalmanFilter(ball, x_guess) + with self.assertRaises(Exception): + # state_estimators not a dictionary + prog_server.start(models ={'ball':ball}, port=9883, state_estimators=[20]) + prog_server.start(models ={'ball':ball}, port=9883, state_estimators={'kf':kf}) + ball_session = prog_client.Session('ball', port=9883, state_est='kf') + + # time step (s) + dt = 0.01 + x = ball.initialize() + # Initial input + u = ball.InputContainer({}) + + # Iterate forward 1 second and compare + x = ball.next_state(x, u, 1) + ball_session.send_data(time=1, x=x['x']) + + t, x_s = ball_session.get_state() + + # To check if the output state is multivariate normal distribution + self.assertIsInstance(x_s, MultivariateNormalDist) + + # Setup Particle Filter + pf = ParticleFilter(ball, x_guess) + prog_server.stop() + prog_server.start(models ={'ball':ball}, port=9883, state_estimators={'pf': pf, 'kf':kf}) + ball_session = prog_client.Session('ball', port=9883, state_est='pf') + + # Iterate forward 1 second and compare + x = ball.next_state(x, u, 1) + ball_session.send_data(time=1, x=x['x']) + + t, x_s = ball_session.get_state() + + # Ensure that PF output is unweighted samples + self.assertIsInstance(x_s, UnweightedSamples) + + prog_server.stop() + prog_server.start() + # This allows the module to be executed directly def run_tests():