From 1cc0756e3dc27d1e5d41d47b3d545c7843f66f19 Mon Sep 17 00:00:00 2001 From: Tin Lai Date: Sat, 2 Oct 2021 15:21:48 +1000 Subject: [PATCH] Fix #4 and add corresponding tests Signed-off-by: Tin Lai --- env.py | 5 +++ tests/test_env.py | 83 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/test_env.py diff --git a/env.py b/env.py index 3c204ab..9acd460 100644 --- a/env.py +++ b/env.py @@ -125,6 +125,11 @@ def __getattr__(self, attr): """ return object.__getattribute__(self.visualiser, attr) + @property + def sampler(self): + """Pass through attribute access to sampler.""" + return self.args.sampler + @staticmethod def radian_dist(p1: np.ndarray, p2: np.ndarray): """Return the (possibly wrapped) distance between two vector of angles in diff --git a/tests/test_env.py b/tests/test_env.py new file mode 100644 index 0000000..3cff429 --- /dev/null +++ b/tests/test_env.py @@ -0,0 +1,83 @@ +from unittest import TestCase + +import numpy as np + +import env +import visualiser +from planners.basePlanner import Planner +from samplers.baseSampler import Sampler +from main import generate_args +from tests.common_vars import DummyPlannerClass +from utils import planner_registry + + +class TestGenerateArgs(TestCase): + def test_missing_argunment(self): + with self.assertRaises(TypeError): + generate_args() + with self.assertRaises(TypeError): + generate_args(planner_id="rrt") + with self.assertRaises(TypeError): + generate_args(map_fname="maps/4d.png") + + # should not raise error + generate_args(planner_id="rrt", map_fname="maps/4d.png") + + # test error if the planner id has not been registered yet + with self.assertRaises(ValueError): + generate_args(planner_id="my_planner_id", map_fname="maps/4d.png") + + # test that after the planner id is registered, it will work. + planner_registry.register_planner( + planner_id="my_planner_id", + planner_class=DummyPlannerClass, + sampler_id="random", + ) + generate_args(planner_id="my_planner_id", map_fname="maps/4d.png") + + def test_actual_planning(self): + visualiser.VisualiserSwitcher.choose_visualiser("base") + args = generate_args( + planner_id="rrt", + map_fname="maps/test.png", + start_pt=np.array([25, 123]), + goal_pt=np.array([225, 42]), + ) + args.no_display = True + + e = env.Env(args, fixed_seed=0) + ori_method = e.planner.run_once + + # prepare an exception to escape from the planning loop + class PlanningSuccess(Exception): + pass + + def planner_run_once_with_side_effect(*args, **kwargs): + # pass through to planner + ori_method(*args, **kwargs) + if e.planner.c_max < float("inf"): + raise PlanningSuccess() + + # patch the planner run_once such that it will terminates as soon as the + # planning problem is finished. + e.planner.run_once = planner_run_once_with_side_effect + with self.assertRaises(PlanningSuccess): + e.run() + + def test_get_attribute(self): + visualiser.VisualiserSwitcher.choose_visualiser("pygame") + args = generate_args( + planner_id="rrt", + map_fname="maps/test.png", + start_pt=np.array([25, 123]), + goal_pt=np.array([225, 42]), + ) + args.no_display = True + + e = env.Env(args, fixed_seed=0) + + # test get planner + assert isinstance(e.planner, Planner) + + # test get sampler + assert isinstance(e.sampler, Sampler)