-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
59 lines (46 loc) · 1.32 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from evaluation.eval_policy import evaluate_policy
# from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import (
SubprocVecEnv,
VecFrameStack,
VecTransposeImage,
VecVideoRecorder,
DummyVecEnv,
)
from wrappers.viewer import Viewer
from wrappers.hotwheels import HotWheelsWrapper
import numpy as np
from utils import HotWheelsStates, make_retro
def make_env(
    *,
    game="HotWheelsStuntTrackChallenge-gba",
    state=HotWheelsStates.DINO_BONEYARD_MULTI,
):
    """Build one Hot Wheels retro environment wrapped in ``HotWheelsWrapper``.

    This is used as an environment factory for ``DummyVecEnv``, which calls
    it with no arguments. The keyword defaults therefore reproduce the
    original Dino Boneyard setup, while callers can pick another game or
    state.

    Args:
        game: Retro game id passed through to ``make_retro``.
        state: Initial state passed through to ``make_retro``.

    Returns:
        The environment from ``make_retro``, wrapped in ``HotWheelsWrapper``.
    """
    env = make_retro(game=game, state=state)
    return HotWheelsWrapper(env)
# Single-environment vectorized env. VecFrameStack stacks the last 4 frames;
# VecTransposeImage then re-orders image observations for the SB3 policy.
venv = VecTransposeImage(VecFrameStack(DummyVecEnv([make_env]), n_stack=4))
model_path = "model (12).zip"

try:
    # Load inside the ``try`` so the environment is still closed when the
    # checkpoint is missing or cannot be deserialized.
    model = PPO.load(
        path=model_path,
        env=venv,
        # Needed because sometimes sb3 can't find the obs and action space.
        # Seen in Colab on 8/21/23.
        custom_objects={
            "observation_space": venv.observation_space,
            "action_space": venv.action_space,
        },
    )
    eval_info = evaluate_policy(
        model,
        venv,
        n_eval_episodes=1,
        return_episode_rewards=True,
        deterministic=True,
        render=True,
    )
    # NOTE(review): this is the project's evaluate_policy, not SB3's (which
    # returns a tuple). Assumes it returns a mapping of metric name to
    # per-episode values; confirm against evaluation.eval_policy.
    for key, value in eval_info.items():
        print(f"{key}: {np.mean(value)} {value}")
finally:
    venv.close()