Minor: use Self type where appropriate (#942)
Small typing improvement, related to https://github.com/thu-ml/tianshou/pull/915#discussion_r1329734222
This commit is contained in:
parent
2cc34fb72b
commit
c8e7d02cba
@ -6,6 +6,7 @@ from numbers import Number
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Protocol,
|
Protocol,
|
||||||
|
Self,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -232,7 +233,7 @@ class BatchProtocol(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
|
def __getitem__(self, index: IndexType) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __getitem__(self, index: str | IndexType) -> Any:
|
def __getitem__(self, index: str | IndexType) -> Any:
|
||||||
@ -241,22 +242,22 @@ class BatchProtocol(Protocol):
|
|||||||
def __setitem__(self, index: str | IndexType, value: Any) -> None:
|
def __setitem__(self, index: str | IndexType, value: Any) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
def __iadd__(self, other: Self | Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
def __add__(self, other: Self | Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __imul__(self, value: Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __mul__(self, value: Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __itruediv__(self, value: Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __truediv__(self, value: Number | np.number) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -274,7 +275,7 @@ class BatchProtocol(Protocol):
|
|||||||
"""Change all numpy.ndarray to torch.Tensor in-place."""
|
"""Change all numpy.ndarray to torch.Tensor in-place."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def cat_(self, batches: TBatch | Sequence[dict | TBatch]) -> None:
|
def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
|
||||||
"""Concatenate a list of (or one) Batch objects into current batch."""
|
"""Concatenate a list of (or one) Batch objects into current batch."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -298,7 +299,7 @@ class BatchProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def stack_(self, batches: Sequence[dict | TBatch], axis: int = 0) -> None:
|
def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
|
||||||
"""Stack a list of Batch objects into current batch."""
|
"""Stack a list of Batch objects into current batch."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -327,7 +328,7 @@ class BatchProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
|
def empty_(self, index: slice | IndexType | None = None) -> Self:
|
||||||
"""Return an empty Batch object with 0 or None filled.
|
"""Return an empty Batch object with 0 or None filled.
|
||||||
|
|
||||||
If "index" is specified, it will only reset the specific indexed-data.
|
If "index" is specified, it will only reset the specific indexed-data.
|
||||||
@ -362,7 +363,7 @@ class BatchProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
|
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
|
||||||
"""Update this batch from another dict/Batch."""
|
"""Update this batch from another dict/Batch."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -373,11 +374,11 @@ class BatchProtocol(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def split(
|
def split(
|
||||||
self: TBatch,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
merge_last: bool = False,
|
merge_last: bool = False,
|
||||||
) -> Iterator[TBatch]:
|
) -> Iterator[Self]:
|
||||||
"""Split whole data into multiple small batches.
|
"""Split whole data into multiple small batches.
|
||||||
|
|
||||||
:param int size: divide the data batch with the given size, but one
|
:param int size: divide the data batch with the given size, but one
|
||||||
@ -457,7 +458,7 @@ class Batch(BatchProtocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
|
def __getitem__(self, index: IndexType) -> Self:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __getitem__(self, index: str | IndexType) -> Any:
|
def __getitem__(self, index: str | IndexType) -> Any:
|
||||||
@ -501,7 +502,7 @@ class Batch(BatchProtocol):
|
|||||||
else:
|
else:
|
||||||
self.__dict__[key][index] = None
|
self.__dict__[key][index] = None
|
||||||
|
|
||||||
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
def __iadd__(self, other: Self | Number | np.number) -> Self:
|
||||||
"""Algebraic addition with another Batch instance in-place."""
|
"""Algebraic addition with another Batch instance in-place."""
|
||||||
if isinstance(other, Batch):
|
if isinstance(other, Batch):
|
||||||
for (batch_key, obj), value in zip(
|
for (batch_key, obj), value in zip(
|
||||||
@ -521,11 +522,11 @@ class Batch(BatchProtocol):
|
|||||||
return self
|
return self
|
||||||
raise TypeError("Only addition of Batch or number is supported.")
|
raise TypeError("Only addition of Batch or number is supported.")
|
||||||
|
|
||||||
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
def __add__(self, other: Self | Number | np.number) -> Self:
|
||||||
"""Algebraic addition with another Batch instance out-of-place."""
|
"""Algebraic addition with another Batch instance out-of-place."""
|
||||||
return deepcopy(self).__iadd__(other)
|
return deepcopy(self).__iadd__(other)
|
||||||
|
|
||||||
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __imul__(self, value: Number | np.number) -> Self:
|
||||||
"""Algebraic multiplication with a scalar value in-place."""
|
"""Algebraic multiplication with a scalar value in-place."""
|
||||||
assert _is_number(value), "Only multiplication by a number is supported."
|
assert _is_number(value), "Only multiplication by a number is supported."
|
||||||
for batch_key, obj in self.__dict__.items():
|
for batch_key, obj in self.__dict__.items():
|
||||||
@ -534,11 +535,11 @@ class Batch(BatchProtocol):
|
|||||||
self.__dict__[batch_key] *= value
|
self.__dict__[batch_key] *= value
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __mul__(self, value: Number | np.number) -> Self:
|
||||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||||
return deepcopy(self).__imul__(value)
|
return deepcopy(self).__imul__(value)
|
||||||
|
|
||||||
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __itruediv__(self, value: Number | np.number) -> Self:
|
||||||
"""Algebraic division with a scalar value in-place."""
|
"""Algebraic division with a scalar value in-place."""
|
||||||
assert _is_number(value), "Only division by a number is supported."
|
assert _is_number(value), "Only division by a number is supported."
|
||||||
for batch_key, obj in self.__dict__.items():
|
for batch_key, obj in self.__dict__.items():
|
||||||
@ -547,7 +548,7 @@ class Batch(BatchProtocol):
|
|||||||
self.__dict__[batch_key] /= value
|
self.__dict__[batch_key] /= value
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
def __truediv__(self, value: Number | np.number) -> Self:
|
||||||
"""Algebraic division with a scalar value out-of-place."""
|
"""Algebraic division with a scalar value out-of-place."""
|
||||||
return deepcopy(self).__itruediv__(value)
|
return deepcopy(self).__itruediv__(value)
|
||||||
|
|
||||||
@ -604,7 +605,7 @@ class Batch(BatchProtocol):
|
|||||||
obj = obj.type(dtype) # noqa: PLW2901
|
obj = obj.type(dtype) # noqa: PLW2901
|
||||||
self.__dict__[batch_key] = obj
|
self.__dict__[batch_key] = obj
|
||||||
|
|
||||||
def __cat(self: TBatch, batches: Sequence[dict | TBatch], lens: list[int]) -> None:
|
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
|
||||||
"""Private method for Batch.cat_.
|
"""Private method for Batch.cat_.
|
||||||
|
|
||||||
::
|
::
|
||||||
@ -798,7 +799,7 @@ class Batch(BatchProtocol):
|
|||||||
# can't cast to a generic type, so we have to ignore the type here
|
# can't cast to a generic type, so we have to ignore the type here
|
||||||
return batch # type: ignore
|
return batch # type: ignore
|
||||||
|
|
||||||
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
|
def empty_(self, index: slice | IndexType | None = None) -> Self:
|
||||||
for batch_key, obj in self.items():
|
for batch_key, obj in self.items():
|
||||||
if isinstance(obj, torch.Tensor): # most often case
|
if isinstance(obj, torch.Tensor): # most often case
|
||||||
self.__dict__[batch_key][index] = 0
|
self.__dict__[batch_key][index] = 0
|
||||||
@ -826,7 +827,7 @@ class Batch(BatchProtocol):
|
|||||||
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
|
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
|
||||||
return deepcopy(batch).empty_(index)
|
return deepcopy(batch).empty_(index)
|
||||||
|
|
||||||
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
|
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
|
||||||
if batch is None:
|
if batch is None:
|
||||||
self.update(kwargs)
|
self.update(kwargs)
|
||||||
return
|
return
|
||||||
@ -902,11 +903,11 @@ class Batch(BatchProtocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def split(
|
def split(
|
||||||
self: TBatch,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
merge_last: bool = False,
|
merge_last: bool = False,
|
||||||
) -> Iterator[TBatch]:
|
) -> Iterator[Self]:
|
||||||
length = len(self)
|
length = len(self)
|
||||||
if size == -1:
|
if size == -1:
|
||||||
size = length
|
size = length
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, cast
|
from typing import Any, Self, cast
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -111,7 +111,7 @@ class ReplayBuffer:
|
|||||||
to_hdf5(self.__dict__, f, compression=compression)
|
to_hdf5(self.__dict__, f, compression=compression)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_hdf5(cls, path: str, device: str | None = None) -> "ReplayBuffer":
|
def load_hdf5(cls, path: str, device: str | None = None) -> Self:
|
||||||
"""Load replay buffer from HDF5 file."""
|
"""Load replay buffer from HDF5 file."""
|
||||||
with h5py.File(path, "r") as f:
|
with h5py.File(path, "r") as f:
|
||||||
buf = cls.__new__(cls)
|
buf = cls.__new__(cls)
|
||||||
@ -128,7 +128,7 @@ class ReplayBuffer:
|
|||||||
truncated: h5py.Dataset,
|
truncated: h5py.Dataset,
|
||||||
done: h5py.Dataset,
|
done: h5py.Dataset,
|
||||||
obs_next: h5py.Dataset,
|
obs_next: h5py.Dataset,
|
||||||
) -> "ReplayBuffer":
|
) -> Self:
|
||||||
size = len(obs)
|
size = len(obs)
|
||||||
assert all(
|
assert all(
|
||||||
len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]
|
len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]
|
||||||
|
@ -62,13 +62,13 @@ class Collector:
|
|||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
env: gym.Env | BaseVectorEnv,
|
env: gym.Env | BaseVectorEnv,
|
||||||
buffer: ReplayBuffer | None = None,
|
buffer: ReplayBuffer | None = None,
|
||||||
preprocess_fn: Callable[..., Batch] | None = None,
|
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
|
||||||
exploration_noise: bool = False,
|
exploration_noise: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
||||||
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
||||||
self.env = DummyVectorEnv([lambda: env])
|
self.env = DummyVectorEnv([lambda: env]) # type: ignore
|
||||||
else:
|
else:
|
||||||
self.env = env # type: ignore
|
self.env = env # type: ignore
|
||||||
self.env_num = len(self.env)
|
self.env_num = len(self.env)
|
||||||
@ -413,7 +413,7 @@ class AsyncCollector(Collector):
|
|||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
env: BaseVectorEnv,
|
env: BaseVectorEnv,
|
||||||
buffer: ReplayBuffer | None = None,
|
buffer: ReplayBuffer | None = None,
|
||||||
preprocess_fn: Callable[..., Batch] | None = None,
|
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
|
||||||
exploration_noise: bool = False,
|
exploration_noise: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# assert env.is_async
|
# assert env.is_async
|
||||||
|
9
tianshou/env/utils.py
vendored
9
tianshou/env/utils.py
vendored
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import Any
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import gymnasium
|
import gymnasium
|
||||||
@ -6,12 +6,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
ENV_TYPE = gymnasium.Env | PettingZooEnv
|
||||||
import gym
|
|
||||||
|
|
||||||
# TODO: remove gym entirely? Currently mypy complains in several places
|
|
||||||
# if gym.Env is removed from the Union
|
|
||||||
ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
|
|
||||||
|
|
||||||
gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from numba import njit
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
|
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
|
||||||
from tianshou.data.batch import BatchProtocol, TBatch
|
from tianshou.data.batch import BatchProtocol
|
||||||
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
|
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
|
||||||
from tianshou.utils import MultipleLRSchedulers
|
from tianshou.utils import MultipleLRSchedulers
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def map_action(self, act: TBatch) -> TBatch:
|
def map_action(self, act: BatchProtocol) -> BatchProtocol:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
Loading…
x
Reference in New Issue
Block a user