Minor: use Self type where appropriate (#942)

Small typing improvement, related to
https://github.com/thu-ml/tianshou/pull/915#discussion_r1329734222
This commit is contained in:
Michael Panchenko 2023-09-20 00:40:32 +02:00 committed by GitHub
parent 2cc34fb72b
commit c8e7d02cba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 40 deletions

View File

@ -6,6 +6,7 @@ from numbers import Number
from typing import (
Any,
Protocol,
Self,
TypeVar,
Union,
cast,
@ -232,7 +233,7 @@ class BatchProtocol(Protocol):
...
@overload
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
def __getitem__(self, index: IndexType) -> Self:
...
def __getitem__(self, index: str | IndexType) -> Any:
@ -241,22 +242,22 @@ class BatchProtocol(Protocol):
def __setitem__(self, index: str | IndexType, value: Any) -> None:
...
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __iadd__(self, other: Self | Number | np.number) -> Self:
...
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __add__(self, other: Self | Number | np.number) -> Self:
...
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
def __imul__(self, value: Number | np.number) -> Self:
...
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
def __mul__(self, value: Number | np.number) -> Self:
...
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __itruediv__(self, value: Number | np.number) -> Self:
...
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __truediv__(self, value: Number | np.number) -> Self:
...
def __repr__(self) -> str:
@ -274,7 +275,7 @@ class BatchProtocol(Protocol):
"""Change all numpy.ndarray to torch.Tensor in-place."""
...
def cat_(self, batches: TBatch | Sequence[dict | TBatch]) -> None:
def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
"""Concatenate a list of (or one) Batch objects into current batch."""
...
@ -298,7 +299,7 @@ class BatchProtocol(Protocol):
"""
...
def stack_(self, batches: Sequence[dict | TBatch], axis: int = 0) -> None:
def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
"""Stack a list of Batch object into current batch."""
...
@ -327,7 +328,7 @@ class BatchProtocol(Protocol):
"""
...
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
def empty_(self, index: slice | IndexType | None = None) -> Self:
"""Return an empty Batch object with 0 or None filled.
If "index" is specified, it will only reset the specific indexed-data.
@ -362,7 +363,7 @@ class BatchProtocol(Protocol):
"""
...
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
"""Update this batch from another dict/Batch."""
...
@ -373,11 +374,11 @@ class BatchProtocol(Protocol):
...
def split(
self: TBatch,
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[TBatch]:
) -> Iterator[Self]:
"""Split whole data into multiple small batches.
:param int size: divide the data batch with the given size, but one
@ -457,7 +458,7 @@ class Batch(BatchProtocol):
...
@overload
def __getitem__(self: TBatch, index: IndexType) -> TBatch:
def __getitem__(self, index: IndexType) -> Self:
...
def __getitem__(self, index: str | IndexType) -> Any:
@ -501,7 +502,7 @@ class Batch(BatchProtocol):
else:
self.__dict__[key][index] = None
def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __iadd__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (batch_key, obj), value in zip(
@ -521,11 +522,11 @@ class Batch(BatchProtocol):
return self
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
def __add__(self, other: Self | Number | np.number) -> Self:
"""Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
@ -534,11 +535,11 @@ class Batch(BatchProtocol):
self.__dict__[batch_key] *= value
return self
def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
def __mul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(value)
def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
@ -547,7 +548,7 @@ class Batch(BatchProtocol):
self.__dict__[batch_key] /= value
return self
def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
def __truediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(value)
@ -604,7 +605,7 @@ class Batch(BatchProtocol):
obj = obj.type(dtype) # noqa: PLW2901
self.__dict__[batch_key] = obj
def __cat(self: TBatch, batches: Sequence[dict | TBatch], lens: list[int]) -> None:
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
"""Private method for Batch.cat_.
::
@ -798,7 +799,7 @@ class Batch(BatchProtocol):
# can't cast to a generic type, so we have to ignore the type here
return batch # type: ignore
def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
def empty_(self, index: slice | IndexType | None = None) -> Self:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor): # most often case
self.__dict__[batch_key][index] = 0
@ -826,7 +827,7 @@ class Batch(BatchProtocol):
def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
return deepcopy(batch).empty_(index)
def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
if batch is None:
self.update(kwargs)
return
@ -902,11 +903,11 @@ class Batch(BatchProtocol):
)
def split(
self: TBatch,
self,
size: int,
shuffle: bool = True,
merge_last: bool = False,
) -> Iterator[TBatch]:
) -> Iterator[Self]:
length = len(self)
if size == -1:
size = length

View File

@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Self, cast
import h5py
import numpy as np
@ -111,7 +111,7 @@ class ReplayBuffer:
to_hdf5(self.__dict__, f, compression=compression)
@classmethod
def load_hdf5(cls, path: str, device: str | None = None) -> "ReplayBuffer":
def load_hdf5(cls, path: str, device: str | None = None) -> Self:
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
@ -128,7 +128,7 @@ class ReplayBuffer:
truncated: h5py.Dataset,
done: h5py.Dataset,
obs_next: h5py.Dataset,
) -> "ReplayBuffer":
) -> Self:
size = len(obs)
assert all(
len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]

View File

@ -62,13 +62,13 @@ class Collector:
policy: BasePolicy,
env: gym.Env | BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., Batch] | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
self.env = DummyVectorEnv([lambda: env])
self.env = DummyVectorEnv([lambda: env]) # type: ignore
else:
self.env = env # type: ignore
self.env_num = len(self.env)
@ -413,7 +413,7 @@ class AsyncCollector(Collector):
policy: BasePolicy,
env: BaseVectorEnv,
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., Batch] | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
) -> None:
# assert env.is_async

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Union
from typing import Any
import cloudpickle
import gymnasium
@ -6,12 +6,7 @@ import numpy as np
from tianshou.env.pettingzoo_env import PettingZooEnv
if TYPE_CHECKING:
import gym
# TODO: remove gym entirely? Currently mypy complains in several places
# if gym.Env is removed from the Union
ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
ENV_TYPE = gymnasium.Env | PettingZooEnv
gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

View File

@ -11,7 +11,7 @@ from numba import njit
from torch import nn
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol, TBatch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.utils import MultipleLRSchedulers
@ -185,7 +185,7 @@ class BasePolicy(ABC, nn.Module):
"""
@overload
def map_action(self, act: TBatch) -> TBatch:
def map_action(self, act: BatchProtocol) -> BatchProtocol:
...
@overload