Skip to content

Commit

Permalink
fix CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
markub3327 committed Sep 8, 2023
1 parent 24daeea commit 2e9e789
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 61 deletions.
8 changes: 1 addition & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ Simply import the package and create the environment with the `make` function.
Take a look at the sample code below:

```python
import time
import flappy_bird_gymnasium
import gymnasium
env = gymnasium.make("FlappyBird-v0")
env = gymnasium.make("FlappyBird-v0", render_mode="human")

obs, _ = env.reset()
while True:
Expand All @@ -76,11 +75,6 @@ while True:
# Processing:
obs, reward, terminated, _, info = env.step(action)

# Rendering the game:
# (remove this two lines during training)
env.render()
time.sleep(1 / 30) # FPS

# Checking if the player is still alive
if terminated:
break
Expand Down
6 changes: 3 additions & 3 deletions flappy_bird_gymnasium/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import pygame

import flappy_bird_gymnasium
from tests.test_simple_env_dqn import play as dqn_agent_env
from tests.test_simple_env_human import play as human_agent_env
from tests.test_simple_env_random import play as random_agent_env
from tests.test_dqn import play as dqn_agent_env
from tests.test_human import play as human_agent_env
from tests.test_random import play as random_agent_env


def _get_args():
Expand Down
6 changes: 6 additions & 0 deletions flappy_bird_gymnasium/envs/flappy_bird_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def step(
done = not alive
info = {"score": self._game.score}

if self.render_mode == "human":
self.render()

return obs, reward, done, False, info

def reset(self, seed=None, options=None):
Expand All @@ -198,6 +201,9 @@ def reset(self, seed=None, options=None):
if self._renderer is not None:
self._renderer.game = self._game

if self.render_mode == "human":
self.render()

info = {"score": self._game.score}
return self._get_observation(), info

Expand Down
24 changes: 3 additions & 21 deletions tests/test_simple_env_dqn.py → tests/test_dqn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import time

import gymnasium
import numpy as np
import pygame
import tensorflow as tf

import flappy_bird_gymnasium
Expand Down Expand Up @@ -48,8 +45,8 @@ def get_action(self, state):
return tf.math.argmax(q_value, axis=-1)[0]


def play(epoch=10, audio_on=True, render=True):
env = gymnasium.make("FlappyBird-v0", audio_on=audio_on)
def play(epoch=10, audio_on=True, render_mode="human"):
env = gymnasium.make("FlappyBird-v0", audio_on=audio_on, render_mode=render_mode)

# init models
q_model = DuelingDQN(env.action_space.n)
Expand All @@ -58,38 +55,23 @@ def play(epoch=10, audio_on=True, render=True):

# run
for _ in range(epoch):
clock = pygame.time.Clock()
score = 0

state, _ = env.reset(seed=123)
state = np.expand_dims(state, axis=0)
while True:
if render:
env.render()

# Getting action
action = q_model.get_action(state)
action = np.array(action, copy=False, dtype=env.env.action_space.dtype)

if render:
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()

# Processing action
next_state, reward, done, _, info = env.step(action)

state = np.expand_dims(next_state, axis=0)
score += reward
print(f"Obs: {state}\n" f"Action: {action}\n" f"Score: {score}\n")

if render:
clock.tick(30)

if done:
if render:
env.render()
time.sleep(0.6)
break

env.close()
Expand All @@ -99,7 +81,7 @@ def play(epoch=10, audio_on=True, render=True):


def test_play():
play(epoch=1, audio_on=False, render=False)
play(epoch=1, audio_on=False, render_mode=None)


if __name__ == "__main__":
Expand Down
11 changes: 1 addition & 10 deletions tests/test_simple_env_human.py → tests/test_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,19 @@
human player.
"""

import time

import gymnasium
import pygame

import flappy_bird_gymnasium


def play():
env = gymnasium.make("FlappyBird-v0")
env = gymnasium.make("FlappyBird-v0", audio_on=True, render_mode="human")

clock = pygame.time.Clock()
score = 0

obs = env.reset()
while True:
env.render()

# Getting action:
action = 0
for event in pygame.event.get():
Expand All @@ -60,11 +55,7 @@ def play():
score += reward
print(f"Obs: {obs}\n" f"Action: {action}\n" f"Score: {score}\n")

clock.tick(15)

if done:
env.render()
time.sleep(0.6)
break

env.close()
Expand Down
23 changes: 3 additions & 20 deletions tests/test_simple_env_random.py → tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,44 +26,27 @@
random agent.
"""

import time

import gymnasium
import numpy as np
import pygame

import flappy_bird_gymnasium


def play(audio_on=True, render=True):
env = gymnasium.make("FlappyBird-v0", audio_on=audio_on)
def play(audio_on=True, render_mode="human"):
env = gymnasium.make("FlappyBird-v0", audio_on=audio_on, render_mode=render_mode)
score = 0
obs = env.reset(seed=123)
while True:
if render:
env.render()

# Getting random action:
action = env.action_space.sample()

if render:
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()

# Processing:
obs, reward, done, _, info = env.step(action)

score += reward
print(f"Obs: {obs}\n" f"Score: {score}\n")

if render:
time.sleep(1 / 30)

if done:
if render:
env.render()
time.sleep(0.5)
break

env.close()
Expand All @@ -73,7 +56,7 @@ def play(audio_on=True, render=True):


def test_play():
play(audio_on=False, render=False)
play(audio_on=False, render_mode=None)


if __name__ == "__main__":
Expand Down

0 comments on commit 2e9e789

Please sign in to comment.