Minor: use Self type where appropriate (#942)
Small typing improvement, related to https://github.com/thu-ml/tianshou/pull/915#discussion_r1329734222
This commit is contained in:
parent
2cc34fb72b
commit
c8e7d02cba
@ -6,6 +6,7 @@ from numbers import Number
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol,
|
||||
Self,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -232,7 +233,7 @@ class BatchProtocol(Protocol):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
|
||||
def __getitem__(self, index: IndexType) -> Self:
|
||||
...
|
||||
|
||||
def __getitem__(self, index: str | IndexType) -> Any:
|
||||
@ -241,22 +242,22 @@ class BatchProtocol(Protocol):
|
||||
def __setitem__(self, index: str | IndexType, value: Any) -> None:
|
||||
...
|
||||
|
||||
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
||||
def __iadd__(self, other: Self | Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
||||
def __add__(self, other: Self | Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __imul__(self, value: Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __mul__(self, value: Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __itruediv__(self, value: Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __truediv__(self, value: Number | np.number) -> Self:
|
||||
...
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -274,7 +275,7 @@ class BatchProtocol(Protocol):
|
||||
"""Change all numpy.ndarray to torch.Tensor in-place."""
|
||||
...
|
||||
|
||||
def cat_(self, batches: TBatch | Sequence[dict | TBatch]) -> None:
|
||||
def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
|
||||
"""Concatenate a list of (or one) Batch objects into current batch."""
|
||||
...
|
||||
|
||||
@ -298,7 +299,7 @@ class BatchProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def stack_(self, batches: Sequence[dict | TBatch], axis: int = 0) -> None:
|
||||
def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
|
||||
"""Stack a list of Batch objects into the current batch."""
|
||||
...
|
||||
|
||||
@ -327,7 +328,7 @@ class BatchProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
|
||||
def empty_(self, index: slice | IndexType | None = None) -> Self:
|
||||
"""Return an empty Batch object with 0 or None filled.
|
||||
|
||||
If "index" is specified, only the data at that index is reset.
|
||||
@ -362,7 +363,7 @@ class BatchProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
|
||||
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
|
||||
"""Update this batch from another dict/Batch."""
|
||||
...
|
||||
|
||||
@ -373,11 +374,11 @@ class BatchProtocol(Protocol):
|
||||
...
|
||||
|
||||
def split(
|
||||
self: TBatch,
|
||||
self,
|
||||
size: int,
|
||||
shuffle: bool = True,
|
||||
merge_last: bool = False,
|
||||
) -> Iterator[TBatch]:
|
||||
) -> Iterator[Self]:
|
||||
"""Split whole data into multiple small batches.
|
||||
|
||||
:param int size: divide the data batch with the given size, but one
|
||||
@ -457,7 +458,7 @@ class Batch(BatchProtocol):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
|
||||
def __getitem__(self, index: IndexType) -> Self:
|
||||
...
|
||||
|
||||
def __getitem__(self, index: str | IndexType) -> Any:
|
||||
@ -501,7 +502,7 @@ class Batch(BatchProtocol):
|
||||
else:
|
||||
self.__dict__[key][index] = None
|
||||
|
||||
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
||||
def __iadd__(self, other: Self | Number | np.number) -> Self:
|
||||
"""Algebraic addition with another Batch instance in-place."""
|
||||
if isinstance(other, Batch):
|
||||
for (batch_key, obj), value in zip(
|
||||
@ -521,11 +522,11 @@ class Batch(BatchProtocol):
|
||||
return self
|
||||
raise TypeError("Only addition of Batch or number is supported.")
|
||||
|
||||
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
|
||||
def __add__(self, other: Self | Number | np.number) -> Self:
|
||||
"""Algebraic addition with another Batch instance out-of-place."""
|
||||
return deepcopy(self).__iadd__(other)
|
||||
|
||||
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __imul__(self, value: Number | np.number) -> Self:
|
||||
"""Algebraic multiplication with a scalar value in-place."""
|
||||
assert _is_number(value), "Only multiplication by a number is supported."
|
||||
for batch_key, obj in self.__dict__.items():
|
||||
@ -534,11 +535,11 @@ class Batch(BatchProtocol):
|
||||
self.__dict__[batch_key] *= value
|
||||
return self
|
||||
|
||||
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __mul__(self, value: Number | np.number) -> Self:
|
||||
"""Algebraic multiplication with a scalar value out-of-place."""
|
||||
return deepcopy(self).__imul__(value)
|
||||
|
||||
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __itruediv__(self, value: Number | np.number) -> Self:
|
||||
"""Algebraic division with a scalar value in-place."""
|
||||
assert _is_number(value), "Only division by a number is supported."
|
||||
for batch_key, obj in self.__dict__.items():
|
||||
@ -547,7 +548,7 @@ class Batch(BatchProtocol):
|
||||
self.__dict__[batch_key] /= value
|
||||
return self
|
||||
|
||||
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
|
||||
def __truediv__(self, value: Number | np.number) -> Self:
|
||||
"""Algebraic division with a scalar value out-of-place."""
|
||||
return deepcopy(self).__itruediv__(value)
|
||||
|
||||
@ -604,7 +605,7 @@ class Batch(BatchProtocol):
|
||||
obj = obj.type(dtype) # noqa: PLW2901
|
||||
self.__dict__[batch_key] = obj
|
||||
|
||||
def __cat(self: TBatch, batches: Sequence[dict | TBatch], lens: list[int]) -> None:
|
||||
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
|
||||
"""Private method for Batch.cat_.
|
||||
|
||||
::
|
||||
@ -798,7 +799,7 @@ class Batch(BatchProtocol):
|
||||
# can't cast to a generic type, so we have to ignore the type here
|
||||
return batch # type: ignore
|
||||
|
||||
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
|
||||
def empty_(self, index: slice | IndexType | None = None) -> Self:
|
||||
for batch_key, obj in self.items():
|
||||
if isinstance(obj, torch.Tensor): # most often case
|
||||
self.__dict__[batch_key][index] = 0
|
||||
@ -826,7 +827,7 @@ class Batch(BatchProtocol):
|
||||
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
|
||||
return deepcopy(batch).empty_(index)
|
||||
|
||||
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
|
||||
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
|
||||
if batch is None:
|
||||
self.update(kwargs)
|
||||
return
|
||||
@ -902,11 +903,11 @@ class Batch(BatchProtocol):
|
||||
)
|
||||
|
||||
def split(
|
||||
self: TBatch,
|
||||
self,
|
||||
size: int,
|
||||
shuffle: bool = True,
|
||||
merge_last: bool = False,
|
||||
) -> Iterator[TBatch]:
|
||||
) -> Iterator[Self]:
|
||||
length = len(self)
|
||||
if size == -1:
|
||||
size = length
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, cast
|
||||
from typing import Any, Self, cast
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
@ -111,7 +111,7 @@ class ReplayBuffer:
|
||||
to_hdf5(self.__dict__, f, compression=compression)
|
||||
|
||||
@classmethod
|
||||
def load_hdf5(cls, path: str, device: str | None = None) -> "ReplayBuffer":
|
||||
def load_hdf5(cls, path: str, device: str | None = None) -> Self:
|
||||
"""Load replay buffer from HDF5 file."""
|
||||
with h5py.File(path, "r") as f:
|
||||
buf = cls.__new__(cls)
|
||||
@ -128,7 +128,7 @@ class ReplayBuffer:
|
||||
truncated: h5py.Dataset,
|
||||
done: h5py.Dataset,
|
||||
obs_next: h5py.Dataset,
|
||||
) -> "ReplayBuffer":
|
||||
) -> Self:
|
||||
size = len(obs)
|
||||
assert all(
|
||||
len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]
|
||||
|
@ -62,13 +62,13 @@ class Collector:
|
||||
policy: BasePolicy,
|
||||
env: gym.Env | BaseVectorEnv,
|
||||
buffer: ReplayBuffer | None = None,
|
||||
preprocess_fn: Callable[..., Batch] | None = None,
|
||||
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
|
||||
exploration_noise: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
||||
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
||||
self.env = DummyVectorEnv([lambda: env])
|
||||
self.env = DummyVectorEnv([lambda: env]) # type: ignore
|
||||
else:
|
||||
self.env = env # type: ignore
|
||||
self.env_num = len(self.env)
|
||||
@ -413,7 +413,7 @@ class AsyncCollector(Collector):
|
||||
policy: BasePolicy,
|
||||
env: BaseVectorEnv,
|
||||
buffer: ReplayBuffer | None = None,
|
||||
preprocess_fn: Callable[..., Batch] | None = None,
|
||||
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
|
||||
exploration_noise: bool = False,
|
||||
) -> None:
|
||||
# assert env.is_async
|
||||
|
9
tianshou/env/utils.py
vendored
9
tianshou/env/utils.py
vendored
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import Any
|
||||
|
||||
import cloudpickle
|
||||
import gymnasium
|
||||
@ -6,12 +6,7 @@ import numpy as np
|
||||
|
||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gym
|
||||
|
||||
# TODO: remove gym entirely? Currently mypy complains in several places
|
||||
# if gym.Env is removed from the Union
|
||||
ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
|
||||
ENV_TYPE = gymnasium.Env | PettingZooEnv
|
||||
|
||||
gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
|
||||
|
||||
|
@ -11,7 +11,7 @@ from numba import njit
|
||||
from torch import nn
|
||||
|
||||
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
|
||||
from tianshou.data.batch import BatchProtocol, TBatch
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
|
||||
@ -185,7 +185,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
"""
|
||||
|
||||
@overload
|
||||
def map_action(self, act: TBatch) -> TBatch:
|
||||
def map_action(self, act: BatchProtocol) -> BatchProtocol:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
Loading…
x
Reference in New Issue
Block a user