Skip to content

Commit

Permalink
formatting with autopep8
Browse files Browse the repository at this point in the history
  • Loading branch information
BluemlJ committed Dec 16, 2024
1 parent efc6913 commit f873bf8
Show file tree
Hide file tree
Showing 188 changed files with 5,772 additions and 4,408 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
key: pickle-cache-${{ github.sha }}
restore-keys: |
pickle-cache-
- name: Install dependencies
run: |
Expand All @@ -49,7 +49,7 @@ jobs:
if [ ! -s changed_files.txt ]; then
echo "No Python files changed."
fi
- name: Run general tests for all changes
id: general_tests
continue-on-error: true
Expand All @@ -75,7 +75,7 @@ jobs:
done < changed_files.txt
# Remove the trailing comma
GAMES=${GAMES%,}
GAMES=${GAMES%,}
echo "GAMES=$GAMES" >> $GITHUB_ENV
- name: Download Pickle Files if Not Cached
Expand All @@ -99,3 +99,4 @@ jobs:
echo "One or more critical tests have failed. Failing workflow."
exit 1
2 changes: 1 addition & 1 deletion dataset_generation/ReadMe.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ODA: An object-centric dataset for Atari

## Using generate_dataset.py
`generate_dataset.py`is the main script for generating object-centric datasets for Atari games. The list of supported games can be found in the OCAtari ReadMe. Games which only supports vision mode should use the `generate_dataset_vision.py`script instead. All parameters and settings are identical and can be used with both scripts.
`generate_dataset.py`is the main script for generating object-centric datasets for Atari games. The list of supported games can be found in the OCAtari ReadMe. Games which only supports vision mode should use the `generate_dataset_vision.py`script instead. All parameters and settings are identical and can be used with both scripts.


### Requirements
Expand Down
2 changes: 1 addition & 1 deletion dataset_generation/datasets_on_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ do
done

# VISION ONLY
for game in Alien Asteroids BeamRider ChopperCommand DemonAttack FishingDerby Frostbite MontezumaRevenge Qbert Riverraid RoadRunner
for game in Alien Asteroids BeamRider ChopperCommand DemonAttack FishingDerby Frostbite MontezumaRevenge Qbert Riverraid RoadRunner
# for game in Breakout Pong
do
python3 generate_dataset_vision.py -g $game
Expand Down
19 changes: 12 additions & 7 deletions dataset_generation/generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
help="The frame interval (default 10)")
# parser.add_argument("-m", "--mode", choices=["vision", "ram"],
# default="ram", help="The frame interval")
parser.add_argument("-hud", "--hud", action="store_true", default=True, help="Detect HUD")
parser.add_argument("-dqn", "--dqn", action="store_true", default=True, help="Use DQN agent")
parser.add_argument("-hud", "--hud", action="store_true",
default=True, help="Detect HUD")
parser.add_argument("-dqn", "--dqn", action="store_true",
default=True, help="Use DQN agent")
opts = parser.parse_args()

# Init the environment
Expand All @@ -47,7 +49,7 @@
# Init an empty dataset
game_nr = 0
turn_nr = 0
dataset = {"INDEX": [], #"OBS": [],
dataset = {"INDEX": [], # "OBS": [],
"RAM": [], "VIS": [], "HUD": []}
frames = []
r_objs = []
Expand All @@ -68,9 +70,12 @@
r_objs.append(deepcopy(env.objects))
v_objs.append(deepcopy(env.objects_v))
# dataset["OBS"].append(obs.flatten().tolist())
dataset["VIS"].append([x for x in sorted(env.objects_v, key=lambda o: str(o))])
dataset["RAM"].append([x for x in sorted(env.objects, key=lambda o: str(o)) if x.hud == False])
dataset["HUD"].append([x for x in sorted(env.objects, key=lambda o: str(o)) if x.hud == True])
dataset["VIS"].append(
[x for x in sorted(env.objects_v, key=lambda o: str(o))])
dataset["RAM"].append(
[x for x in sorted(env.objects, key=lambda o: str(o)) if x.hud == False])
dataset["HUD"].append(
[x for x in sorted(env.objects, key=lambda o: str(o)) if x.hud == True])
turn_nr = turn_nr + 1

# if a game is terminated, restart with a new game and update turn and game counter
Expand Down Expand Up @@ -127,7 +132,7 @@

df = pd.DataFrame(dataset, columns=['INDEX', 'RAM', 'HUD', 'VIS'])
makedirs("data/datasets/", exist_ok=True)
prefix = f"{opts.game}_dqn" if opts.dqn else f"{opts.game}_random"
prefix = f"{opts.game}_dqn" if opts.dqn else f"{opts.game}_random"
df.to_csv(f"data/datasets/{prefix}.csv", index=False)
pickle.dump(v_objs, open(f"data/datasets/{prefix}_objects_v.pkl", "wb"))
pickle.dump(r_objs, open(f"data/datasets/{prefix}_objects_r.pkl", "wb"))
Expand Down
18 changes: 10 additions & 8 deletions dataset_generation/generate_dataset_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import random
import matplotlib.pyplot as plt
from os import path
#sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) # noqa
# sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) # noqa
from ocatari.core import OCAtari
from ocatari.vision.utils import mark_bb, make_darker
#from ocatari.vision.space_invaders import objects_colors
# from ocatari.vision.space_invaders import objects_colors
from ocatari.vision.pong import objects_colors
from ocatari.utils import load_agent, parser, make_deterministic
from copy import deepcopy
Expand All @@ -21,8 +21,10 @@
help="The frame interval (default 10)")
# parser.add_argument("-m", "--mode", choices=["vision", "ram"],
# default="ram", help="The frame interval")
parser.add_argument("-hud", "--hud", action="store_true", default=True, help="Detect HUD")
parser.add_argument("-dqn", "--dqn", action="store_true", default=True, help="Use DQN agent")
parser.add_argument("-hud", "--hud", action="store_true",
default=True, help="Detect HUD")
parser.add_argument("-dqn", "--dqn", action="store_true",
default=True, help="Use DQN agent")

opts = parser.parse_args()

Expand All @@ -45,13 +47,14 @@

obs, reward, terminated, truncated, info = env.step(action)

#if i % 1000 == 0:
# if i % 1000 == 0:
# print(f"{i} done")

dataset["INDEX"].append(f"{'%0.5d' %(game_nr)}_{'%0.5d' %(turn_nr)}")
dataset["OBS"].append(obs.flatten().tolist())
dataset["RAM"].append([])
dataset["VIS"].append([x for x in sorted(env.objects, key=lambda o: str(o))])
dataset["VIS"].append(
[x for x in sorted(env.objects, key=lambda o: str(o))])
dataset["HUD"].append([])
turn_nr = turn_nr+1

Expand Down Expand Up @@ -89,7 +92,6 @@
"""
env.close()

df = pd.DataFrame(dataset, columns = ['INDEX', 'OBS', 'RAM', 'HUD', 'VIS'])
df = pd.DataFrame(dataset, columns=['INDEX', 'OBS', 'RAM', 'HUD', 'VIS'])
df.to_csv(f"/data/datasets_v/{opts.game}.csv", index=False)
print(f"Finished {opts.game}")

Loading

0 comments on commit f873bf8

Please sign in to comment.