Skip to content

Commit

Permalink
formatting the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BluemlJ committed Dec 16, 2024
1 parent 449622f commit b409b73
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 54 deletions.
21 changes: 13 additions & 8 deletions tests/0_general/test_easystarters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import numpy as np
from ocatari.core import OCAtari


def test_invalid_game_name():
    """
    Test that creating an environment with an invalid game name fails.

    "AEL/Pong" is not a valid environment id, so constructing OCAtari
    with it is expected to raise.
    """
    with pytest.raises(Exception):
        # Only the constructor call matters; the instance is never used,
        # so it is not bound to a name.
        OCAtari(env_name="AEL/Pong", mode="ram", obs_mode="obj")



@pytest.mark.parametrize(
"cbs, obs_mode",
[
Expand All @@ -25,22 +27,22 @@ def test_multiple_stacks(cbs, obs_mode):
"""
Test that the expected buffer stacks are created for each combination of create_buffer_stacks and obs_mode.
"""
env = OCAtari(env_name="ALE/Pong-v5", mode="ram", create_buffer_stacks=cbs, obs_mode=obs_mode)
env = OCAtari(env_name="ALE/Pong-v5", mode="ram",
create_buffer_stacks=cbs, obs_mode=obs_mode)
env.reset()

if "dqn" in cbs or obs_mode == "dqn":
assert env.create_dqn_stack
else:
assert env.create_dqn_stack == False
assert env.create_dqn_stack == False
if "obj" in cbs or obs_mode == "obj":
assert env.create_ns_stack
else:
assert env.create_ns_stack == False
if "ori" in cbs or obs_mode == "ori":
assert env.create_rgb_stack
assert env.create_rgb_stack
else:
assert env.create_rgb_stack == False


action = env.action_space.sample() # pick random action
obs, reward, truncated, terminated, info = env.step(action)
Expand All @@ -58,6 +60,7 @@ def test_multiple_stacks(cbs, obs_mode):
else:
assert env._state_buffer_rgb is None


def test_clone_restore_state():
"""
Test cloning and restoring the state of the environment.
Expand All @@ -69,9 +72,11 @@ def test_clone_restore_state():
env.step(0) # Take a step to change the state
env._restore_state(initial_state) # Restore the previous state
restored_state = env._clone_state()
assert np.array_equal(initial_state, restored_state), "Restored state should match the initial state."
assert np.array_equal(
initial_state, restored_state), "Restored state should match the initial state."
env.close()


def test_step_with_invalid_action():
"""
Test stepping through the environment with an invalid action.
Expand All @@ -80,4 +85,4 @@ def test_step_with_invalid_action():
env.reset()
with pytest.raises(Exception):
env.step(999) # Invalid action should raise an exception
env.close()
env.close()
21 changes: 15 additions & 6 deletions tests/0_general/test_obj_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_object_state_size(env_name, mode):
assert obs.shape == env._env.observation_space.shape
env.close()


def test_object_extraction_properties():
"""
Test properties of objects extracted from RAM and Vision modes.
Expand All @@ -33,11 +34,15 @@ def test_object_extraction_properties():
env.reset()
env.detect_objects() # Use both RAM and vision-based detection
for obj in env.objects:
assert hasattr(obj, 'category'), "Extracted object should have a category attribute."
assert hasattr(obj, 'xy'), "Extracted object should have an xy attribute (position)."
assert hasattr(obj, 'wh'), "Extracted object should have a wh attribute (width and height)."
assert hasattr(
obj, 'category'), "Extracted object should have a category attribute."
assert hasattr(
obj, 'xy'), "Extracted object should have an xy attribute (position)."
assert hasattr(
obj, 'wh'), "Extracted object should have a wh attribute (width and height)."
env.close()


def test_object_extraction_count():
"""
Test the count of objects extracted in RAM and Vision modes.
Expand All @@ -61,9 +66,11 @@ def test_object_extraction_consistency():
env.step(0)
env.detect_objects()
new_objects = env.objects.copy()
assert len(initial_objects) == len(new_objects), "The number of detected objects should remain consistent between steps."
assert len(initial_objects) == len(
new_objects), "The number of detected objects should remain consistent between steps."
env.close()


def test_object_extraction_category_types():
"""
Test that object categories are extracted correctly and are non-empty strings.
Expand All @@ -72,6 +79,8 @@ def test_object_extraction_category_types():
env.reset()
env.detect_objects()
for obj in env.objects:
assert isinstance(obj.category, str), "Object category should be a string."
assert len(obj.category) > 0, "Object category should not be an empty string."
assert isinstance(
obj.category, str), "Object category should be a string."
assert len(
obj.category) > 0, "Object category should not be an empty string."
env.close()
15 changes: 10 additions & 5 deletions tests/0_general/test_ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_set_get_ram():
assert modified_ram[5] == 10, "RAM state at position 5 should be updated to 10."
env.close()


def test_ram_extraction_values():
"""
Test the values extracted from RAM to ensure they are within the expected range.
Expand All @@ -25,7 +26,8 @@ def test_ram_extraction_values():
ram = env.get_ram()
assert ram is not None, "RAM state should not be None."
assert len(ram) > 0, "RAM state should contain elements."
assert all(0 <= value <= 255 for value in ram), "All RAM values should be in the range 0-255."
assert all(
0 <= value <= 255 for value in ram), "All RAM values should be in the range 0-255."
env.close()


Expand All @@ -39,7 +41,8 @@ def test_ram_extraction_specific_addresses():
assert ram is not None, "RAM state should not be None."
assert len(ram) > 10, "RAM state should contain enough elements."
specific_value = ram[10]
assert isinstance(specific_value, np.uint8), "RAM value at specific address should be an integer."
assert isinstance(
specific_value, np.uint8), "RAM value at specific address should be an integer."
assert 0 <= specific_value <= 255, "RAM value at specific address should be in the range 0-255."
env.close()

Expand All @@ -53,7 +56,8 @@ def test_ram_state_changes():
initial_ram = env.get_ram()
env.step(0) # Take a step in the environment
new_ram = env.get_ram()
assert not np.array_equal(initial_ram, new_ram), "RAM state should change after taking a step."
assert not np.array_equal(
initial_ram, new_ram), "RAM state should change after taking a step."
env.close()


Expand All @@ -67,5 +71,6 @@ def test_ram_state_reset():
env.step(0) # Take a step to change the state
env.reset() # Reset the environment
reset_ram = env.get_ram()
assert np.array_equal(initial_ram, reset_ram), "RAM state should be reset to initial values."
env.close()
assert np.array_equal(
initial_ram, reset_ram), "RAM state should be reset to initial values."
env.close()
18 changes: 11 additions & 7 deletions tests/1_games/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
GAMES = [f"ALE/{g}-v5" for g in os.getenv("GAMES").split()]
else:
GAMES = [f"ALE/{g}-v5" for g in AVAILABLE_GAMES]
#GAMES = ["ALE/Freeway-v5"]
# GAMES = ["ALE/Freeway-v5"]

PICKLE_PATH = f"pickle_files"

MODES = ["ram", "vision"]
OBS_MODES = ["ori", "dqn"]
FRAMESKIPS = [1, 4]


def get_states(game_name):
path = f"{PICKLE_PATH}/{game_name}"
if os.path.exists(path):
Expand All @@ -42,10 +43,10 @@ def load_pickle_state(env, game_name, state_nr):
raise FileNotFoundError(f"Pickle file {pickle_file_path} not found.")



@pytest.mark.parametrize("env_name, mode, obs_mode, frameskip, state_nr", [(game, mode, obs_mode, frameskip, state_nr) for game in GAMES for mode in MODES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -59,7 +60,8 @@ def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):

obs1 = obs
env.close()
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -84,7 +86,8 @@ def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):

@pytest.mark.parametrize("env_name, obs_mode, frameskip, state_nr", [(game, obs_mode, frameskip, state_nr) for game in GAMES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr):
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -98,7 +101,8 @@ def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr)

obs1 = obs
env.close()
env = OCAtari(env_name, hud=False, mode="vision", render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode="vision", render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -118,4 +122,4 @@ def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr)

# Compute the difference
assert np.allclose(obs1, obs2, rtol=2)
env.close()
env.close()
23 changes: 15 additions & 8 deletions tests/1_games/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
GAMES = [f"ALE/{g}-v5" for g in os.getenv("GAMES").split()]
else:
GAMES = [f"ALE/{g}-v5" for g in AVAILABLE_GAMES]
#GAMES = ["ALE/Freeway-v5"]
# GAMES = ["ALE/Freeway-v5"]

PICKLE_PATH = f"pickle_files"

MODES = ["ram", "vision"]
OBS_MODES = ["ori", "dqn"]
FRAMESKIPS = [1, 4]


def get_states(game_name):
path = f"{PICKLE_PATH}/{game_name}"
if os.path.exists(path):
Expand Down Expand Up @@ -73,30 +74,36 @@ def test_environment_step(env_name, mode, obs_mode, frameskip, state_nr):
"""
Test stepping through the environment.
"""
env = OCAtari(env_name=env_name, mode=mode, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name=env_name, mode=mode,
obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)
env.reset()
obs, reward, truncated, terminated, info = env.step(0) # Execute a random action (e.g., action 0)
obs, reward, truncated, terminated, info = env.step(
0) # Execute a random action (e.g., action 0)
assert obs is not None, "Observation should not be None after taking a step."
assert isinstance(reward, (int, float)), "Reward should be a number."
assert isinstance(truncated, bool), "Truncated should be a boolean value."
assert isinstance(terminated, bool), "Terminated should be a boolean value."
assert isinstance(
terminated, bool), "Terminated should be a boolean value."
assert isinstance(info, dict), "Info should be a dictionary."
env.close()


@pytest.mark.parametrize("env_name, mode, obs_mode, frameskip, state_nr", [(game, mode, obs_mode, frameskip, state_nr) for game in GAMES for mode in MODES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_environment_step(env_name, mode, obs_mode, frameskip, state_nr):
"""
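Test the values returned by a single environment step.
NOTE: this function reuses the name test_environment_step from the test above and therefore shadows that definition.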
"""
env = OCAtari(env_name=env_name, mode=mode, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name=env_name, mode=mode,
obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)
env.reset()
obs, reward, truncated, terminated, info = env.step(0) # Execute a random action (e.g., action 0)
obs, reward, truncated, terminated, info = env.step(
0) # Execute a random action (e.g., action 0)
assert obs is not None, "Observation should not be None after taking a step."
assert isinstance(reward, (int, float)), "Reward should be a number."
assert isinstance(truncated, bool), "Truncated should be a boolean value."
assert isinstance(terminated, bool), "Terminated should be a boolean value."
assert isinstance(
terminated, bool), "Terminated should be a boolean value."
assert isinstance(info, dict), "Info should be a dictionary."
env.close()

30 changes: 20 additions & 10 deletions tests/1_games/test_obj_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,28 @@
if os.getenv("GAMES") != None:
GAMES = [f"ALE/{g}-v5" for g in os.getenv("GAMES").split()]
else:
GAMES = ["ALE/Freeway-v5"]
GAMES = ["ALE/Amidar-v5", "ALE/Asterix-v5", "ALE/Asteroids-v5", "ALE/BankHeist-v5", "ALE/Berzerk-v5", "ALE/Bowling-v5", "ALE/Breakout-v5", "ALE/DonkeyKong-v5", "ALE/FishingDerby-v5", "ALE/Freeway-v5", "ALE/Frogger-v5",
"ALE/Frostbite-v5", "ALE/Gopher-v5", "ALE/IceHockey-v5", "ALE/Kangaroo-v5", "ALE/MontezumaRevenge-v5", "ALE/MsPacman-v5", "ALE/Pong-v5", "ALE/Seaquest-v5", "ALE/Skiing-v5", "ALE/SpaceInvaders-v5", "ALE/Tennis-v5"]

MODES = ["ram", "vision"]
OBS_MODES = ["obj"]
FRAMESKIPS = [1, 4]

PICKLE_PATH = "pickle_files"


def get_states(game_name):
    """
    List the saved-state pickle files available for a game.

    Looks in the directory PICKLE_PATH/<game_name> and returns the names of
    all entries that start with "state_" and end with ".pkl", in the order
    reported by os.listdir. If the directory does not exist, a list holding
    a single empty string is returned instead of an empty list.
    """
    state_dir = f"{PICKLE_PATH}/{game_name}"

    # Missing directory: return the placeholder entry.
    # NOTE(review): presumably this keeps the parametrized tests at one
    # case per game even without pickles -- confirm against the callers.
    if not os.path.exists(state_dir):
        return [""]

    state_files = []
    for entry in os.listdir(state_dir):
        if entry.startswith("state_") and entry.endswith(".pkl"):
            state_files.append(entry)
    return state_files



def load_pickle_state(env, game_name, state_nr):
"""
Load the state from the pickle file for the given game.
"""
"""
pickle_file_path = os.path.join(PICKLE_PATH, game_name, f"{state_nr}.pkl")
if os.path.exists(pickle_file_path):
with open(pickle_file_path, "rb") as f:
Expand All @@ -42,7 +45,8 @@ def load_pickle_state(env, game_name, state_nr):

@pytest.mark.parametrize("env_name, mode, obs_mode, frameskip, state_nr", [(game, mode, obs_mode, frameskip, state_nr) for game in GAMES for mode in MODES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -58,7 +62,8 @@ def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):
env.close()

# Reinitialize the environment and load the same pickle state
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode=mode, render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -79,7 +84,8 @@ def test_seeding(env_name, mode, obs_mode, frameskip, state_nr):

@pytest.mark.parametrize("env_name, obs_mode, frameskip, state_nr", [(game, obs_mode, frameskip, state_nr) for game in GAMES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr):
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -95,7 +101,8 @@ def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr)
env.close()

# Reinitialize the environment in a different mode and load the same pickle state
env = OCAtari(env_name, hud=False, mode="vision", render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode="vision", render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)

env.action_space.seed(42)
Expand All @@ -113,13 +120,16 @@ def test_outputsimilarity_between_modes(env_name, obs_mode, frameskip, state_nr)
assert np.allclose(obs1, obs2, rtol=2)
env.close()


@pytest.mark.parametrize("env_name, obs_mode, frameskip, state_nr", [(game, obs_mode, frameskip, state_nr) for game in GAMES for obs_mode in OBS_MODES for frameskip in FRAMESKIPS for state_nr in get_states(game.split("/")[1].split("-")[0])])
def test_objects_slots(env_name, obs_mode, frameskip, state_nr):
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array", render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
env = OCAtari(env_name, hud=False, mode="ram", render_mode="rgb_array",
render_oc_overlay=False, obs_mode=obs_mode, frameskip=frameskip)
load_pickle_state(env, env.game_name, state_nr)
env.reset()
object_dict = env.max_objects_per_cat
object_list = [key for key, count in object_dict.items() for _ in range(count)]
object_list = [key for key, count in object_dict.items()
for _ in range(count)]

for _ in range(100):
action = env.action_space.sample() # pick random action
Expand All @@ -131,4 +141,4 @@ def test_objects_slots(env_name, obs_mode, frameskip, state_nr):
for type, object in zip(object_list, current_objects):
assert not object or object.category == type

env.close()
env.close()
Loading

0 comments on commit b409b73

Please sign in to comment.