Skip to content

Commit

Permalink
Fix bug in printing of FSTRIPS instances. Closes #69
Browse files Browse the repository at this point in the history
  • Loading branch information
gfrances committed Aug 2, 2019
1 parent b85371a commit e145bb0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
27 changes: 20 additions & 7 deletions src/tarski/io/fstrips.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from collections import defaultdict
from typing import Optional, List

from .common import load_tpl
from ..model import ExtensionalFunctionDefinition
from ..syntax import Tautology, Contradiction, Atom, CompoundTerm, CompoundFormula, QuantifiedFormula, \
Term, Variable, Constant, Formula
Term, Variable, Constant, Formula, symref
from ..syntax.sorts import parent, Interval, ancestors

from ._fstrips.common import tarski_to_pddl_type, get_requirements_string
Expand Down Expand Up @@ -150,12 +151,18 @@ def __init__(self, problem):
self.problem = problem
self.lang = problem.language

def write(self, domain_filename, instance_filename, domain_constants=None):
def write(self, domain_filename, instance_filename, domain_constants: Optional[List[Constant]] = None):
domain_constants = domain_constants or []
self.write_domain(domain_filename, domain_constants)
self.write_instance(instance_filename, domain_constants)

def print_domain(self, constant_objects=None):
def print_domain(self, constant_objects: Optional[List[Constant]] = None):
""" Generate the PDDL string representation that would correspond to the domain.pddl file of the current
planning problem.
The parameter `constant_objects` is used to determine which of the PDDL objects are printed as "PDDL domain
constants", and which as "PDDL instance objects", which is something that cannot be determined from the problem
information alone. If `constant_objects` is None, all objects are considered instance objects.
"""
tpl = load_tpl("fstrips_domain.tpl")
content = tpl.format(
header_info="",
Expand All @@ -166,20 +173,26 @@ def print_domain(self, constant_objects=None):
predicates=self.get_predicates(),
actions=self.get_actions(),
derived=self.get_derived_predicates(),
constants=print_objects(constant_objects if constant_objects else set()),
constants=print_objects(constant_objects if constant_objects else []),
)
return content

def write_domain(self, filename, constant_objects):
with open(filename, 'w') as file:
file.write(self.print_domain(constant_objects))

def print_instance(self, constant_objects=None):
def print_instance(self, constant_objects: Optional[List[Constant]] = None):
""" Generate the PDDL string representation that would correspond to the instance.pddl file of the current
planning problem.
The parameter `constant_objects` is used to determine which of the PDDL objects are printed as "PDDL domain
constants", and which as "PDDL instance objects", which is something that cannot be determined from the problem
information alone. If `constant_objects` is None, all objects are considered instance objects.
"""
tpl = load_tpl("fstrips_instance.tpl")

# Only objects which are not declared in the domain file need to be printed in the instance file
constant_objs_set = set(constant_objects) if constant_objects else set()
instance_objects = [c for c in self.problem.language.constants() if c not in constant_objs_set]
constants = {symref(c) for c in constant_objects} if constant_objects else set()
instance_objects = [c for c in self.problem.language.constants() if symref(c) not in constants]

content = tpl.format(
header_info="",
Expand Down
36 changes: 25 additions & 11 deletions tests/io/test_fstrips_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@ def write_problem(problem):

def get_bw_elements():
problem = generate_small_fstrips_bw_problem()
lang = problem.language
# writer = FstripsWriter(problem)

loc = lang.get_function("loc")
clear = lang.get_predicate("clear")
b1 = lang.get_constant("b1")
table = lang.get_constant("table")
loc, clear, b1, table = problem.language.get("loc", "clear", "b1", "table")
return problem, loc, clear, b1, table


Expand Down Expand Up @@ -60,10 +54,6 @@ def test_effect_writing():
assert s6 == "(forall (?b - block) (when (clear ?b) (assign (loc ?b) table)))"


# def test_atom_writing():
# pass


def test_objects_writing():
problem, _, _, _, _ = get_bw_elements()
assert print_objects(problem.language.constants()) == "b1 b2 b3 b4 - block\n table - place"
Expand All @@ -82,6 +72,30 @@ def test_gridworld_writing():
write_problem(problem)


def test_blocksworld_writing():
problem, _, _, _, _ = get_bw_elements()
write_problem(problem)


def test_blocksworld_writing_with_different_constants():
problem, loc, clear, b1, table = get_bw_elements()
writer = FstripsWriter(problem)
instance_model_string = writer.print_instance()
domain_model_string = writer.print_domain()
assert "b1" not in domain_model_string
assert "b1 b2 b3 b4 - block\n table - place" in instance_model_string

constant_objects = [b1, table]
instance_model_string = writer.print_instance(constant_objects=constant_objects)
domain_model_string = writer.print_domain(constant_objects=constant_objects)

assert "b1 - block" in domain_model_string
assert "table - place" in domain_model_string
assert """(:objects
b2 b3 b4 - block
)""" in instance_model_string


def test_action_costs_numeric_fluents_requirements():
problem = parcprinter.create_small_task()
writer = FstripsWriter(problem)
Expand Down

0 comments on commit e145bb0

Please sign in to comment.