Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type hints errors in gymnasium/spaces #327

Merged
merged 10 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def __init__(
this value across all dimensions.

Args:
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
low (SupportsFloat | np.ndarray): Lower bounds of the intervals.
high (SupportsFloat | np.ndarray]): Upper bounds of the intervals.
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
`low` and `high` scalars defaulting to a shape of (1,)
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
Expand Down Expand Up @@ -104,12 +104,13 @@ def __init__(

# Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
self.bounded_below: bool = -np.inf < _low
self.bounded_below: NDArray[np.bool_] = -np.inf < _low

_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
self.bounded_above: bool = np.inf > _high
self.bounded_above: NDArray[np.bool_] = np.inf > _high

low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
low = _broadcast(low, self.dtype, shape, inf_sign="-")
high = _broadcast(high, self.dtype, shape, inf_sign="+")

assert isinstance(low, np.ndarray)
assert (
Expand Down Expand Up @@ -280,7 +281,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
self.high_repr = _short_repr(self.high)


def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
def get_inf(dtype: np.dtype, sign: str) -> int | float:
"""Returns an infinite that doesn't break things.

Args:
Expand Down
6 changes: 4 additions & 2 deletions gymnasium/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
), f"Dict space element is not an instance of Space: key='{key}', space={space}"

# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
super().__init__(None, None, seed) # type: ignore

@property
def is_np_flattenable(self):
Expand Down Expand Up @@ -226,7 +226,9 @@ def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]
for key, space in self.spaces.items()
}

def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
def from_jsonable(
self, sample_n: dict[str, list[Any]]
) -> list[OrderedDict[str, Any]]:
vcharraut marked this conversation as resolved.
Show resolved Hide resolved
"""Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: dict[str, list[Any]] = {
key: space.from_jsonable(sample_n[key])
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> int:
def sample(self, mask: MaskNDArray | None = None) -> np.int64:
"""Generates a single random sample from this space.

A sample will be chosen uniformly at random with the mask if provided
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def __eq__(self, other: Any) -> bool:

def to_jsonable(
self, sample_n: Sequence[GraphInstance]
) -> list[dict[str, list[int] | list[float]]]:
) -> list[dict[str, list[int | float]]]:
"""Convert a batch of samples from this space to a JSONable data type."""
ret_n: list[dict[str, list[int | float]]] = []
ret_n = []
for sample in sample_n:
ret = {"nodes": sample.nodes.tolist()}
if sample.edges is not None and sample.edge_links is not None:
Expand Down
16 changes: 6 additions & 10 deletions gymnasium/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Any, Sequence

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

from gymnasium.spaces.space import MaskNDArray, Space


class MultiBinary(Space[npt.NDArray[np.int8]]):
class MultiBinary(Space[NDArray[np.int8]]):
"""An n-shape binary space.

Elements of this space are binary arrays of a shape that is fixed during construction.
Expand All @@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):

def __init__(
self,
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
n: NDArray[np.integer[Any]] | Sequence[int] | int,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiBinary` space.
Expand Down Expand Up @@ -58,7 +58,7 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
"""Generates a single random sample from this space.

A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
Expand Down Expand Up @@ -104,15 +104,11 @@ def contains(self, x: Any) -> bool:
and np.all(np.logical_or(x == 0, x == 1))
)

def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.int8]]
) -> list[Sequence[int]]:
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()

def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.int8]]:
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample, self.dtype) for sample in sample_n]

Expand Down
14 changes: 7 additions & 7 deletions gymnasium/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from typing import Any, Sequence

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.space import MaskNDArray, Space


class MultiDiscrete(Space[npt.NDArray[np.integer]]):
class MultiDiscrete(Space[NDArray[np.integer]]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.

It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
Expand Down Expand Up @@ -41,7 +41,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):

def __init__(
self,
nvec: npt.NDArray[np.integer[Any]] | list[int],
nvec: NDArray[np.integer[Any]] | list[int],
dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None,
):
Expand Down Expand Up @@ -72,7 +72,7 @@ def is_np_flattenable(self):

def sample(
self, mask: tuple[MaskNDArray, ...] | None = None
) -> npt.NDArray[np.integer[Any]]:
) -> NDArray[np.integer[Any]]:
"""Generates a single random sample this space.

Args:
Expand All @@ -88,7 +88,7 @@ def sample(
def _apply_mask(
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any],
) -> int | Sequence[int]:
) -> int | list[Any]:
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
Expand Down Expand Up @@ -144,14 +144,14 @@ def contains(self, x: Any) -> bool:
)

def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
self, sample_n: Sequence[NDArray[np.integer[Any]]]
) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [sample.tolist() for sample in sample_n]

def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]:
) -> list[NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.array(sample) for sample in sample_n]

Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Union

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.space import Space
Expand Down Expand Up @@ -69,11 +69,11 @@ def sample(
mask: None
| (
tuple[
None | np.integer | npt.NDArray[np.integer],
None | np.integer | NDArray[np.integer],
Any,
]
) = None,
) -> tuple[Any]:
) -> tuple[Any] | Any:
"""Generates a single random sample from this space.

Args:
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def sample(self, mask: Any | None = None) -> T_cov:

def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
self._np_random, seed = seeding.np_random(seed)
return [seed]
self._np_random, np_random_seed = seeding.np_random(seed)
return [np_random_seed]

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

from gymnasium.spaces.space import Space

Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(
max_length: int,
*,
min_length: int = 1,
charset: set[str] | str = alphanumeric,
charset: frozenset[str] | str = alphanumeric,
seed: int | np.random.Generator | None = None,
):
r"""Constructor of :class:`Text` space.
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(

def sample(
self,
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
) -> str:
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

Expand Down
6 changes: 4 additions & 2 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,17 @@ def _flatten_multidiscrete(
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
[np.array(flatten(s, x_part)) for x_part, s in zip(x, space.spaces)]
)
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))


@flatten.register(Dict)
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return np.concatenate(
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
)
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())


Expand Down