Skip to content

Commit

Permalink
Merge branch 'fix/code_structure'
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanfeiLin committed Apr 17, 2024
2 parents 8e0a98f + aadf672 commit f089224
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 46 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ __pycache__
public
.coverage
build/

outputs/
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "commonroad-search"]
path = commonroad-search
url = git@gitlab.lrz.de:tum-cps/commonroad-search.git
url = https://gitlab.lrz.de/tum-cps/commonroad-search.git
2 changes: 1 addition & 1 deletion commonroad-search
Submodule commonroad-search updated from 154ade to dd3ee6
File renamed without changes.
49 changes: 49 additions & 0 deletions drplanner/diagnostics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import math
import copy
import os

from commonroad.scenario.scenario import Scenario
from commonroad.planning.planning_problem import PlanningProblemSet

from drplanner.utils.config import DrPlannerConfiguration
from drplanner.prompter.prompter import Prompter


class DrPlannerBase:
def __init__(
self,
scenario: Scenario,
planning_problem_set: PlanningProblemSet,
config: DrPlannerConfiguration):

self.scenario = scenario
self.planning_problem_set = planning_problem_set
# otherwise the planning problem might be changed during the initialization of the planner
self.planning_problem = copy.deepcopy(
list(self.planning_problem_set.planning_problem_dict.values())[0]
)
self.config = config

self._visualize = self.config.visualize
self._save_solution = self.config.save_solution

self.THRESHOLD = config.cost_threshold
self.TOKEN_LIMIT = config.token_limit
self.ITERATION_MAX = config.iteration_max

# todo: load from solution file
self.desired_cost = self.config.desired_cost
self.initial_cost = math.inf
self.current_cost = None

self.token_count = 0
self.cost_list = []

self.dir_output = os.path.join(os.path.dirname(__file__), "../../outputs/solutions/")
os.makedirs(os.path.dirname(self.dir_output), exist_ok=True) # Ensure the directory exists

self.prompter = Prompter(
self.scenario, self.planning_problem, self.config.openai_api_key, self.config.gpt_version
)
self.prompter.LLM.temperature = self.config.temperature

Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
# make sure the SMP has been installed successfully
try:
import SMP

print("Installed SMP module is called.")
except ImportError as e:
import sys
import os

current_file_dir = os.path.dirname(os.path.abspath(__file__))
smp_path = os.path.join(current_file_dir, '../../commonroad-search/')
sys.path.append(smp_path)
Expand All @@ -44,34 +46,30 @@
)

from drplanner.utils.gpt import num_tokens_from_messages
from drplanner.utils.config import DiagnoserConfiguration
from drplanner.utils.config import DrPlannerConfiguration
from drplanner.diagnostics.base import DrPlannerBase
from drplanner.prompter.prompter import Prompter

import numpy as np


class DrPlanner:
class DrSearchPlanner(DrPlannerBase):
def __init__(
self,
scenario: Scenario,
planning_problem_set: PlanningProblemSet,
motion_primitives_id: str,
planner_id: str,
config: DiagnoserConfiguration,
self,
scenario: Scenario,
planning_problem_set: PlanningProblemSet,
config: DrPlannerConfiguration,
motion_primitives_id: str,
planner_id: str,
):
self.scenario = scenario
self.planning_problem_set = planning_problem_set

# otherwise the planning problem might be changed during the initialization of the planner
self.planning_problem = copy.deepcopy(
list(self.planning_problem_set.planning_problem_dict.values())[0]
)
super().__init__(scenario, planning_problem_set, config)

# initialize the motion primitives
self.motion_primitives_id = motion_primitives_id

# initialize the motion planner
self.planner_id = planner_id
# import the planner
planner_name = f"drplanner.planners.student_{self.planner_id}"
planner_module = importlib.import_module(planner_name)
automaton = ManeuverAutomaton.generate_automaton(motion_primitives_id)
Expand All @@ -81,36 +79,14 @@ def __init__(
self.scenario, self.planning_problem, automaton, DefaultPlotConfig
)

self._visualize = False
self._save_solution = True

# initialize the vehicle parameters and the cost function
self.cost_type = CostFunction.SM1
self.vehicle_type = VehicleType.BMW_320i
self.vehicle_model = VehicleModel.KS
self.cost_evaluator = CostFunctionEvaluator(
self.cost_type, VehicleType.BMW_320i
)

self.dir_output = "../../outputs/solutions/"

self.prompter = Prompter(
self.scenario, self.planning_problem, config.api_key, config.gpt_version
)
self.prompter.LLM.temperature = config.temperature

# todo: load from solution file
self.desired_cost = config.desired_cost
self.current_cost = None

self.token_count = 0

self.THRESHOLD = config.cost_threshold
self.TOKEN_LIMIT = config.token_limit
self.ITERATION_MAX = config.iteration_max

self.cost_list = []
self.initial_cost = math.inf

def diagnose_repair(self):
nr_iteration = 0
try:
Expand All @@ -123,9 +99,9 @@ def diagnose_repair(self):
result = None
self.initial_cost = self.current_cost
while (
abs(self.current_cost - self.desired_cost) > self.THRESHOLD
and self.token_count < self.TOKEN_LIMIT
and nr_iteration < self.ITERATION_MAX
abs(self.current_cost - self.desired_cost) > self.THRESHOLD
and self.token_count < self.TOKEN_LIMIT
and nr_iteration < self.ITERATION_MAX
):
print(
f"<{nr_iteration}>: total cost {self.current_cost} (desired: {self.desired_cost}),"
Expand Down Expand Up @@ -158,7 +134,7 @@ def diagnose_repair(self):
planned_trajectory = self.plan(nr_iteration)
# add feedback
prompt_planner += (
self.add_feedback(planned_trajectory, nr_iteration) + "\n"
self.add_feedback(planned_trajectory, nr_iteration) + "\n"
)
except Exception as e:
error_traceback = (
Expand Down Expand Up @@ -236,7 +212,7 @@ def repair(self, diagnosis_result: Union[str, None]):
self.motion_planner.frontier = PriorityQueue()

def describe(
self, planned_trajectory: Union[Trajectory, None]
self, planned_trajectory: Union[Trajectory, None]
) -> (str, PlanningProblemCostResult):
template = self.prompter.astar_template

Expand Down Expand Up @@ -305,7 +281,7 @@ def plan(self, nr_iter: int) -> Trajectory:
self.planning_problem.planning_problem_id
),
planning_problem_solution.trajectory,
output_path="..",
output_path=self.dir_output,
)
if self._save_solution:
# create solution object
Expand Down

0 comments on commit f089224

Please sign in to comment.