Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type hints errors in gymnasium/spaces #327

Merged
merged 10 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def __init__(
this value across all dimensions.

Args:
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
low (SupportsFloat | np.ndarray): Lower bounds of the intervals.
high (SupportsFloat | np.ndarray]): Upper bounds of the intervals.
shape (Optional[Sequence[int]]): The shape is inferred from the shape of `low` or `high` `np.ndarray`s with
`low` and `high` scalars defaulting to a shape of (1,)
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
Expand Down Expand Up @@ -104,12 +104,13 @@ def __init__(

# Capture the boundedness information before replacing np.inf with get_inf
_low = np.full(shape, low, dtype=float) if is_float_integer(low) else low
self.bounded_below: bool = -np.inf < _low
self.bounded_below: NDArray[np.bool_] = -np.inf < _low

_high = np.full(shape, high, dtype=float) if is_float_integer(high) else high
self.bounded_above: bool = np.inf > _high
self.bounded_above: NDArray[np.bool_] = np.inf > _high

low: NDArray[Any] = _broadcast(low, dtype, shape, inf_sign="-")
high: NDArray[Any] = _broadcast(high, dtype, shape, inf_sign="+")
low = _broadcast(low, self.dtype, shape, inf_sign="-")
high = _broadcast(high, self.dtype, shape, inf_sign="+")

assert isinstance(low, np.ndarray)
assert (
Expand Down Expand Up @@ -280,7 +281,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
self.high_repr = _short_repr(self.high)


def get_inf(dtype: np.dtype, sign: str) -> SupportsFloat:
def get_inf(dtype: np.dtype, sign: str) -> int | float:
"""Returns an infinite that doesn't break things.

Args:
Expand Down
6 changes: 4 additions & 2 deletions gymnasium/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
), f"Dict space element is not an instance of Space: key='{key}', space={space}"

# None for shape and dtype, since it'll require special handling
super().__init__(None, None, seed)
super().__init__(None, None, seed) # type: ignore

@property
def is_np_flattenable(self):
Expand Down Expand Up @@ -226,7 +226,9 @@ def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]
for key, space in self.spaces.items()
}

def from_jsonable(self, sample_n: dict[str, list[Any]]) -> list[dict[str, Any]]:
def from_jsonable(
self, sample_n: dict[str, list[Any]]
) -> list[OrderedDict[str, Any]]:
vcharraut marked this conversation as resolved.
Show resolved Hide resolved
"""Convert a JSONable data type to a batch of samples from this space."""
dict_of_list: dict[str, list[Any]] = {
key: space.from_jsonable(sample_n[key])
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Iterable, Mapping, Sequence

import numpy as np
from numpy.typing import NDArray

from gymnasium.spaces.space import MaskNDArray, Space

Expand Down Expand Up @@ -55,7 +56,7 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> int:
def sample(self, mask: MaskNDArray | None = None) -> np.int64 | NDArray[np.int64]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think this should just be np.int64

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Generates a single random sample from this space.

A sample will be chosen uniformly at random with the mask if provided
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def __eq__(self, other: Any) -> bool:

def to_jsonable(
self, sample_n: Sequence[GraphInstance]
) -> list[dict[str, list[int] | list[float]]]:
) -> list[dict[str, list[int | float]]]:
"""Convert a batch of samples from this space to a JSONable data type."""
ret_n: list[dict[str, list[int | float]]] = []
ret_n = []
for sample in sample_n:
ret = {"nodes": sample.nodes.tolist()}
if sample.edges is not None and sample.edge_links is not None:
Expand Down
16 changes: 6 additions & 10 deletions gymnasium/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Any, Sequence

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

from gymnasium.spaces.space import MaskNDArray, Space


class MultiBinary(Space[npt.NDArray[np.int8]]):
class MultiBinary(Space[NDArray[np.int8]]):
"""An n-shape binary space.

Elements of this space are binary arrays of a shape that is fixed during construction.
Expand All @@ -28,7 +28,7 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):

def __init__(
self,
n: npt.NDArray[np.integer[Any]] | Sequence[int] | int,
n: NDArray[np.integer[Any]] | Sequence[int] | int,
seed: int | np.random.Generator | None = None,
):
"""Constructor of :class:`MultiBinary` space.
Expand Down Expand Up @@ -58,7 +58,7 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> npt.NDArray[np.int8]:
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.bool_ | np.int8]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What case causes np.bool_ as sample return type as it should be the same type as class MultiBinary(Space[NDArray[np.int8]]):

Copy link
Contributor Author

@vcharraut vcharraut Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, this comes from np_random.integers that has a wild possiblity of type return, like int, bool and ndarray.
I don't know how to deal with it.
Since np_random.integers accepts several type in args, I guess there would be a way to force the return type of the function based on the type of the args.

I have actually the exact same proble in the graph.py and utils.py files: there are some functions that can accept multiples spaces types but when used with specific spaces, the type raised errors (like flatten_space in utils.py.

That is already the same kind of problem I encountered with this part of code #327 (comment) and I ended up splitting the function into 2 functions to handle each type cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, even if self.n has a static type np.int64, this line is raises an error

tmp: np.int64 = self.np_random.integers(self.n)

Expression of type "ndarray[Any, dtype[int64]]" cannot be assigned to declared type "int64"
  "ndarray[Any, dtype[int64]]" is incompatible with "int64"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cast the data in some way to fix this, otherwise, raise an issue in numpy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try to look for solutions and do another PR

"""Generates a single random sample from this space.

A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
Expand Down Expand Up @@ -104,15 +104,11 @@ def contains(self, x: Any) -> bool:
and np.all(np.logical_or(x == 0, x == 1))
)

def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.int8]]
) -> list[Sequence[int]]:
def to_jsonable(self, sample_n: Sequence[NDArray[np.int8]]) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return np.array(sample_n).tolist()

def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.int8]]:
def from_jsonable(self, sample_n: list[Sequence[int]]) -> list[NDArray[np.int8]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.asarray(sample, self.dtype) for sample in sample_n]

Expand Down
14 changes: 7 additions & 7 deletions gymnasium/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from typing import Any, Sequence

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.space import MaskNDArray, Space


class MultiDiscrete(Space[npt.NDArray[np.integer]]):
class MultiDiscrete(Space[NDArray[np.integer]]):
"""This represents the cartesian product of arbitrary :class:`Discrete` spaces.

It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space.
Expand Down Expand Up @@ -41,7 +41,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):

def __init__(
self,
nvec: npt.NDArray[np.integer[Any]] | list[int],
nvec: NDArray[np.integer[Any]] | list[int],
dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None,
):
Expand Down Expand Up @@ -72,7 +72,7 @@ def is_np_flattenable(self):

def sample(
self, mask: tuple[MaskNDArray, ...] | None = None
) -> npt.NDArray[np.integer[Any]]:
) -> NDArray[np.integer[Any]]:
"""Generates a single random sample this space.

Args:
Expand All @@ -88,7 +88,7 @@ def sample(
def _apply_mask(
sub_mask: MaskNDArray | tuple[MaskNDArray, ...],
sub_nvec: MaskNDArray | np.integer[Any],
) -> int | Sequence[int]:
) -> int | list[Any]:
if isinstance(sub_nvec, np.ndarray):
assert isinstance(
sub_mask, tuple
Expand Down Expand Up @@ -144,14 +144,14 @@ def contains(self, x: Any) -> bool:
)

def to_jsonable(
self, sample_n: Sequence[npt.NDArray[np.integer[Any]]]
self, sample_n: Sequence[NDArray[np.integer[Any]]]
) -> list[Sequence[int]]:
"""Convert a batch of samples from this space to a JSONable data type."""
return [sample.tolist() for sample in sample_n]

def from_jsonable(
self, sample_n: list[Sequence[int]]
) -> list[npt.NDArray[np.integer[Any]]]:
) -> list[NDArray[np.integer[Any]]]:
"""Convert a JSONable data type to a batch of samples from this space."""
return [np.array(sample) for sample in sample_n]

Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Union

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.space import Space
Expand Down Expand Up @@ -69,11 +69,11 @@ def sample(
mask: None
| (
tuple[
None | np.integer | npt.NDArray[np.integer],
None | np.integer | NDArray[np.integer],
Any,
]
) = None,
) -> tuple[Any]:
) -> tuple[Any] | Any:
"""Generates a single random sample from this space.

Args:
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/spaces/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def sample(self, mask: Any | None = None) -> T_cov:

def seed(self, seed: int | None = None) -> list[int]:
"""Seed the PRNG of this space and possibly the PRNGs of subspaces."""
self._np_random, seed = seeding.np_random(seed)
return [seed]
self._np_random, np_random_seed = seeding.np_random(seed)
return [np_random_seed]

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
6 changes: 3 additions & 3 deletions gymnasium/spaces/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

from gymnasium.spaces.space import Space

Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(
max_length: int,
*,
min_length: int = 1,
charset: set[str] | str = alphanumeric,
charset: frozenset[str] | str = alphanumeric,
seed: int | np.random.Generator | None = None,
):
r"""Constructor of :class:`Text` space.
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(

def sample(
self,
mask: None | (tuple[int | None, npt.NDArray[np.int8] | None]) = None,
mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None,
) -> str:
"""Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`.

Expand Down
22 changes: 17 additions & 5 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,17 @@ def _flatten_multidiscrete(
def _flatten_tuple(space: Tuple, x: tuple[Any, ...]) -> tuple[Any, ...] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
[np.array(flatten(s, x_part)) for x_part, s in zip(x, space.spaces)]
)
return tuple(flatten(s, x_part) for x_part, s in zip(x, space.spaces))


@flatten.register(Dict)
def _flatten_dict(space: Dict, x: dict[str, Any]) -> dict[str, Any] | NDArray[Any]:
if space.is_np_flattenable:
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
return np.concatenate(
[np.array(flatten(s, x[key])) for key, s in space.spaces.items()]
)
return OrderedDict((key, flatten(s, x[key])) for key, s in space.spaces.items())


Expand Down Expand Up @@ -354,7 +356,17 @@ def _unflatten_graph(space: Graph, x: GraphInstance) -> GraphInstance:
nodes and edges in the graph.
"""

def _graph_unflatten(unflatten_space, unflatten_x):
def _graph_unflatten_nodes(
unflatten_space: Box | Discrete, unflatten_x: NDArray[Any]
) -> NDArray[Any]:
if isinstance(unflatten_space, Box):
return unflatten_x.reshape(-1, *unflatten_space.shape)
elif isinstance(unflatten_space, Discrete):
return np.asarray(np.nonzero(unflatten_x))[-1, :]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you be more specific ?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
   raise TypeError("Expected the ... ")

def _graph_unflatten_edges(
unflatten_space: Box | Discrete | None, unflatten_x: NDArray[Any] | None
) -> NDArray[Any] | None:
result = None
if unflatten_space is not None and unflatten_x is not None:
if isinstance(unflatten_space, Box):
Expand All @@ -363,8 +375,8 @@ def _graph_unflatten(unflatten_space, unflatten_x):
result = np.asarray(np.nonzero(unflatten_x))[-1, :]
return result

nodes = _graph_unflatten(space.node_space, x.nodes)
edges = _graph_unflatten(space.edge_space, x.edges)
nodes = _graph_unflatten_nodes(space.node_space, x.nodes)
edges = _graph_unflatten_edges(space.edge_space, x.edges)

return GraphInstance(nodes, edges, x.edge_links)

Expand Down