added dataset_size attribute to minari datasets #158

Merged
merged 10 commits into from
Nov 22, 2023
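In short: datasets created with the updated create_dataset_from_buffers and create_dataset_from_collector_env functions now record their on-disk size (in MB) under a dataset_size metadata key, and MinariStorage gains a get_size() method that recomputes it from the files on disk. A minimal sketch of reading the new attribute; the dataset id below is only an example of a locally stored dataset created after this change:

    import minari

    # Hypothetical dataset id; any dataset created with the updated
    # create_dataset_from_* functions carries the new metadata entry.
    dataset = minari.load_dataset("cartpole-test-v0")

    # Size recorded in the metadata at creation time, in MB.
    print(dataset.storage.metadata["dataset_size"])

    # Size recomputed from the files currently on disk, also in MB.
    print(dataset.storage.get_size())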
5 changes: 5 additions & 0 deletions minari/cli.py
@@ -31,12 +31,16 @@ def _show_dataset_table(datasets, table_title):
table.add_column("Name", justify="left", style="cyan", no_wrap=True)
table.add_column("Total Episodes", justify="right", style="green")
table.add_column("Total Steps", justify="right", style="green")
table.add_column("Dataset Size", justify="left", style="green")
table.add_column("Description", justify="left", style="yellow")
table.add_column("Author", justify="left", style="magenta")
table.add_column("Email", justify="left", style="magenta")

for dst_metadata in datasets.values():
author = dst_metadata.get("author", "Unknown")
dataset_size = dst_metadata.get("dataset_size", "Unknown")
if dataset_size != "Unknown":
dataset_size = f"{dataset_size} MB"
author_email = dst_metadata.get("author_email", "Unknown")

assert isinstance(dst_metadata["dataset_id"], str)
@@ -46,6 +50,7 @@ def _show_dataset_table(datasets, table_title):
dst_metadata["dataset_id"],
str(dst_metadata["total_episodes"]),
str(dst_metadata["total_steps"]),
dataset_size,
"Coming soon ...",
author,
author_email,
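The CLI change above only affects how the size is displayed: the table shows "<size> MB" when the metadata key is present and falls back to "Unknown" otherwise (e.g. for datasets created before this change). An illustrative standalone helper mirroring that fallback; the function name is hypothetical and not part of the PR:

    def _format_dataset_size(metadata: dict) -> str:
        # Mirror of the table logic above: append the unit when the key exists,
        # otherwise keep the "Unknown" placeholder.
        dataset_size = metadata.get("dataset_size", "Unknown")
        if dataset_size != "Unknown":
            dataset_size = f"{dataset_size} MB"
        return dataset_size

    print(_format_dataset_size({"dataset_size": 2.7}))  # "2.7 MB"
    print(_format_dataset_size({}))                     # "Unknown"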
18 changes: 17 additions & 1 deletion minari/dataset/minari_storage.py
@@ -37,7 +37,6 @@ def __init__(self, data_path: PathLike):
if not os.path.exists(file_path):
raise ValueError(f"No data found in data path {data_path}")
self._file_path = file_path

self._observation_space = None
self._action_space = None

@@ -261,6 +260,23 @@ def update_episodes(self, episodes: Iterable[dict]):
file.attrs.modify("total_episodes", total_episodes)
file.attrs.modify("total_steps", total_steps)

def get_size(self):
"""Returns the dataset size in MB.

Returns:
datasize (float): size of the dataset in MB
"""
datasize_list = []
if os.path.exists(self.data_path):
for filename in os.listdir(self.data_path):
datasize = os.path.getsize(os.path.join(self.data_path, filename))
datasize_list.append(datasize)

datasize = np.round(np.sum(datasize_list) / 1000000, 1)

return datasize

def update_from_storage(self, storage: MinariStorage):
"""Update the dataset using another MinariStorage.

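get_size sums the sizes of the files directly under the dataset's data path and reports the total in MB (10**6 bytes), rounded to one decimal place. An equivalent standalone sketch of that computation, using a placeholder path argument:

    import os

    import numpy as np


    def directory_size_mb(data_path: str) -> float:
        # Same computation as MinariStorage.get_size(): sum the sizes of the
        # entries directly inside data_path and convert bytes to MB (10**6 bytes).
        sizes = []
        if os.path.exists(data_path):
            for filename in os.listdir(data_path):
                sizes.append(os.path.getsize(os.path.join(data_path, filename)))
        return float(np.round(np.sum(sizes) / 1_000_000, 1))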
13 changes: 12 additions & 1 deletion minari/utils.py
@@ -555,8 +555,13 @@ def create_dataset_from_buffers(
env_spec=env_spec,
)

# Update the metadata before writing the episodes as well: for small environments, missing metadata changes the file size by a few tenths of a MB, which breaks the size checks in the unit tests.
storage.update_metadata(metadata)
storage.update_episodes(buffer)

metadata['dataset_size'] = storage.get_size()
storage.update_metadata(metadata)

return MinariDataset(storage)


@@ -618,7 +623,13 @@ def create_dataset_from_collector_env(
)

collector_env.save_to_disk(dataset_path, metadata)
return MinariDataset(dataset_path)

# The dataset size can only be computed after the data has been saved to disk, so update the metadata after calling `save_to_disk`.

dataset = MinariDataset(dataset_path)
metadata['dataset_size'] = dataset.storage.get_size()
dataset.storage.update_metadata(metadata)
return dataset


def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray:
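Both creation paths therefore end with the same invariant, which the new tests below assert: the size written into the metadata equals the size recomputed from disk. A small hedged sketch capturing that check; the helper name is not part of the PR:

    from minari import MinariDataset


    def check_recorded_size(dataset: MinariDataset) -> None:
        # Invariant established by create_dataset_from_buffers and
        # create_dataset_from_collector_env after this change.
        recorded = dataset.storage.metadata["dataset_size"]  # stored at creation time
        recomputed = dataset.storage.get_size()              # measured from disk
        assert recorded == recomputed, (recorded, recomputed)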
157 changes: 157 additions & 0 deletions tests/dataset/test_minari_storage.py
@@ -1,8 +1,24 @@
import copy
import os

import gymnasium as gym
import numpy as np
import pytest
from gymnasium import spaces

import minari
from minari import DataCollectorV0
from minari.dataset.minari_storage import MinariStorage
from tests.common import (
check_data_integrity,
check_load_and_delete_dataset,
register_dummy_envs,
)


register_dummy_envs()

file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets")


def _generate_episode_dict(
@@ -170,3 +186,144 @@ def test_episode_metadata(tmp_dataset_dir):

ep_indices = [1, 4, 5]
storage.update_episode_metadata(ep_metadatas, episode_indices=ep_indices)


@pytest.mark.parametrize(
"dataset_id,env_id",
[
("cartpole-test-v0", "CartPole-v1"),
("dummy-dict-test-v0", "DummyDictEnv-v0"),
("dummy-box-test-v0", "DummyBoxEnv-v0"),
("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
("dummy-combo-test-v0", "DummyComboEnv-v0"),
("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
],
)
def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id):
"""Test get_dataset_size method for dataset made using create_dataset_from_collector_env method."""
# delete the test dataset if it already exists
local_datasets = minari.list_local_datasets()
if dataset_id in local_datasets:
minari.delete_dataset(dataset_id)

env = gym.make(env_id)

env = DataCollectorV0(env)
num_episodes = 100

# Step the environment; the DataCollectorV0 wrapper handles the data collection.
env.reset(seed=42)

for episode in range(num_episodes):
done = False
while not done:
action = env.action_space.sample() # User-defined policy function
_, _, terminated, truncated, _ = env.step(action)
done = terminated or truncated

env.reset()

# Create Minari dataset and store locally
dataset = minari.create_dataset_from_collector_env(
dataset_id=dataset_id,
collector_env=env,
algorithm_name="random_policy",
code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
author="WillDudley",
author_email="[email protected]",
)

assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size()

check_data_integrity(dataset.storage, dataset.episode_indices)

env.close()

check_load_and_delete_dataset(dataset_id)


@pytest.mark.parametrize(
"dataset_id,env_id",
[
("cartpole-test-v0", "CartPole-v1"),
("dummy-dict-test-v0", "DummyDictEnv-v0"),
("dummy-box-test-v0", "DummyBoxEnv-v0"),
("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
("dummy-text-test-v0", "DummyTextEnv-v0"),
("dummy-combo-test-v0", "DummyComboEnv-v0"),
("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
],
)
def test_minari_get_dataset_size_from_buffer(dataset_id, env_id):
"""Test get_dataset_size method for dataset made using create_dataset_from_buffers method."""
buffer = []

# delete the test dataset if it already exists
local_datasets = minari.list_local_datasets()
if dataset_id in local_datasets:
minari.delete_dataset(dataset_id)

env = gym.make(env_id)

observations = []
actions = []
rewards = []
terminations = []
truncations = []

num_episodes = 10

observation, info = env.reset(seed=42)

# Step the environment and record the transitions manually into the episode buffer.
observation, _ = env.reset()
observations.append(observation)
for episode in range(num_episodes):
terminated = False
truncated = False

while not terminated and not truncated:
action = env.action_space.sample() # User-defined policy function
observation, reward, terminated, truncated, _ = env.step(action)
observations.append(observation)
actions.append(action)
rewards.append(reward)
terminations.append(terminated)
truncations.append(truncated)

episode_buffer = {
"observations": copy.deepcopy(observations),
"actions": copy.deepcopy(actions),
"rewards": np.asarray(rewards),
"terminations": np.asarray(terminations),
"truncations": np.asarray(truncations),
}
buffer.append(episode_buffer)

observations.clear()
actions.clear()
rewards.clear()
terminations.clear()
truncations.clear()

observation, _ = env.reset()
observations.append(observation)

# Create Minari dataset and store locally
dataset = minari.create_dataset_from_buffers(
dataset_id=dataset_id,
env=env,
buffer=buffer,
algorithm_name="random_policy",
code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
author="WillDudley",
author_email="[email protected]",
)

assert dataset.storage.metadata['dataset_size'] == dataset.storage.get_size()

check_data_integrity(dataset.storage, dataset.episode_indices)

env.close()

check_load_and_delete_dataset(dataset_id)