diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 16f76f9..4e3f24e 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -3,6 +3,7 @@ + [ ] RL algorithm bug + [ ] documentation request (i.e. "X is missing from the documentation.") + [ ] new feature request + + [ ] design request (i.e. "X should be changed to Y.") - [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/) - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2805835..029220c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,9 +1,9 @@ -- [ ] I have marked all applicable categories: - + [ ] exception-raising fix - + [ ] algorithm implementation fix - + [ ] documentation modification - + [ ] new feature -- [ ] I have reformatted the code using `make format` (**required**) -- [ ] I have checked the code using `make commit-checks` (**required**) -- [ ] If applicable, I have mentioned the relevant/related issue(s) -- [ ] If applicable, I have listed every items in this Pull Request below +- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) +- [ ] I have provided a description of the changes in this Pull Request +- [ ] I have added documentation for my changes +- [ ] If applicable, I have added tests to cover my changes. +- [ ] I have reformatted the code using `poe format` +- [ ] I have checked style and types with `poe lint` and `poe type-check` +- [ ] (Optional) I ran tests locally with `poe test` +(or a subset of them with `poe test-reduced`) ,and they pass +- [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build` \ No newline at end of file diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 877d103..9128eaa 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -172,8 +172,8 @@ if __name__ == "__main__": print(env.spec.reward_threshold) print(obs.shape, action_num) for _ in range(4000): - obs, rew, done, info = env.step(0) - if done: + obs, rew, terminated, truncated, info = env.step(0) + if terminated or truncated: env.reset() - print(obs.shape, rew, done) + print(obs.shape, rew, terminated, truncated) cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3]) diff --git a/test/base/test_policy.py b/test/base/test_policy.py new file mode 100644 index 0000000..672f194 --- /dev/null +++ b/test/base/test_policy.py @@ -0,0 +1,71 @@ +import gymnasium as gym +import numpy as np +import pytest +import torch +from torch.distributions import Categorical, Independent, Normal + +from tianshou.policy import PPOPolicy +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor + +obs_shape = (5,) + + +def _to_hashable(x: np.ndarray | int): + return x if isinstance(x, int) else tuple(x.tolist()) + + +@pytest.fixture(params=["continuous", "discrete"]) +def policy(request): + action_type = request.param + if action_type == "continuous": + action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) + actor = ActorProb( + Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), + action_shape=action_space.shape, + ) + dist_fn = lambda *logits: Independent(Normal(*logits), 1) + elif 
action_type == "discrete": + action_space = gym.spaces.Discrete(3) + actor = Actor( + Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), + action_shape=action_space.n, + ) + dist_fn = lambda logits: Categorical(logits=logits) + else: + raise ValueError(f"Unknown action type: {action_type}") + + critic = Critic( + Net(obs_shape, hidden_sizes=[64, 64]), + ) + + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) + + policy = PPOPolicy( + actor=actor, + critic=critic, + dist_fn=dist_fn, + optim=optim, + action_space=action_space, + action_scaling=False, + ) + policy.eval() + return policy + + +class TestPolicyBasics: + def test_get_action(self, policy): + sample_obs = torch.randn(obs_shape) + policy.deterministic_eval = False + actions = [policy.compute_action(sample_obs) for _ in range(10)] + assert all(policy.action_space.contains(a) for a in actions) + + # check that the actions are different in non-deterministic mode + assert len(set(map(_to_hashable, actions))) > 1 + + policy.deterministic_eval = True + actions = [policy.compute_action(sample_obs) for _ in range(10)] + # check that the actions are the same in deterministic mode + assert len(set(map(_to_hashable, actions))) == 1 diff --git a/tianshou/data/types.py b/tianshou/data/types.py index 3a79b5c..a63a9d1 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -5,16 +5,24 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol, arr_type -class RolloutBatchProtocol(BatchProtocol): - """Typically, the outcome of sampling from a replay buffer.""" +class ObsBatchProtocol(BatchProtocol): + """Observations of an environment that a policy can turn into actions. + + Typically used inside a policy's forward + """ obs: arr_type | BatchProtocol + info: arr_type + + +class RolloutBatchProtocol(ObsBatchProtocol): + """Typically, the outcome of sampling from a replay buffer.""" + obs_next: arr_type | BatchProtocol act: arr_type rew: np.ndarray terminated: arr_type truncated: arr_type - info: arr_type class BatchWithReturnsProtocol(RolloutBatchProtocol): @@ -39,11 +47,17 @@ class RecurrentStateBatch(BatchProtocol): class ActBatchProtocol(BatchProtocol): """Simplest batch, just containing the action. Useful e.g., for random policy.""" - act: np.ndarray + act: arr_type -class ModelOutputBatchProtocol(ActBatchProtocol): - """Contains model output: (logits) and potentially hidden states.""" +class ActStateBatchProtocol(ActBatchProtocol): + """Contains action and state (which can be None), useful for policies that can support RNNs.""" + + state: dict | BatchProtocol | np.ndarray | None + + +class ModelOutputBatchProtocol(ActStateBatchProtocol): + """In addition to state and action, contains model output: (logits).""" logits: torch.Tensor state: dict | BatchProtocol | np.ndarray | None diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index c355d6a..205c2d5 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -10,6 +10,8 @@ import torch from tianshou.data.batch import Batch, _parse_value +# TODO: confusing name, could actually return a batch... 
+# Overrides and generic types should be added @no_type_check def to_numpy(x: Any) -> Batch | np.ndarray: """Return an object without torch.Tensor.""" diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index b4b85e3..50db30b 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any, Literal, TypeAlias, cast, overload +from typing import Any, Literal, TypeAlias, cast import gymnasium as gym import numpy as np @@ -11,14 +11,18 @@ from numba import njit from torch import nn from tianshou.data import ReplayBuffer, to_numpy, to_torch_as -from tianshou.data.batch import BatchProtocol +from tianshou.data.batch import Batch, BatchProtocol, arr_type from tianshou.data.buffer.base import TBuffer -from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.utils import MultipleLRSchedulers logger = logging.getLogger(__name__) - TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers @@ -149,13 +153,39 @@ class BasePolicy(ABC, nn.Module): for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + def compute_action( + self, + obs: arr_type, + info: dict[str, Any] | None = None, + state: dict | BatchProtocol | np.ndarray | None = None, + ) -> np.ndarray | int: + """Get action as int (for discrete env's) or array (for continuous ones) from + an env's observation and info. + + :param obs: observation from the gym's env. + :param info: information given by the gym's env. + :param state: the hidden state of RNN policy, used for recurrent policy. + :return: action as int (for discrete env's) or array (for continuous ones). + """ + # need to add empty batch dimension + obs = obs[None, :] + obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) + act = self.forward(obs_batch, state=state).act.squeeze() + if isinstance(act, torch.Tensor): + act = act.detach().cpu().numpy() + act = self.map_action(act) + if isinstance(self.action_space, Discrete): + # could be an array of shape (), easier to just convert to int + act = int(act) # type: ignore + return act + @abstractmethod def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> BatchProtocol: + ) -> ActBatchProtocol: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: @@ -190,22 +220,19 @@ class BasePolicy(ABC, nn.Module): act = policy.map_action(act, batch) """ - @overload - def map_action(self, act: BatchProtocol) -> BatchProtocol: - ... - - @overload - def map_action(self, act: np.ndarray) -> np.ndarray: - ... - - @overload - def map_action(self, act: torch.Tensor) -> torch.Tensor: - ... 
+ @staticmethod + def _action_to_numpy(act: arr_type) -> np.ndarray: + act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch + if not isinstance(act, np.ndarray): + raise ValueError( + f"act should have been be a numpy.ndarray, but got {type(act)}.", + ) + return act def map_action( self, - act: BatchProtocol | np.ndarray | torch.Tensor, - ) -> BatchProtocol | np.ndarray | torch.Tensor: + act: arr_type, + ) -> np.ndarray: """Map raw network output to action range in gym's env.action_space. This function is called in :meth:`~tianshou.data.Collector.collect` and only @@ -223,24 +250,24 @@ class BasePolicy(ABC, nn.Module): :return: action in the same form of input "act" but remap to the target action space. """ - if isinstance(self.action_space, gym.spaces.Box) and isinstance(act, np.ndarray): - # currently this action mapping only supports np.ndarray action + act = self._action_to_numpy(act) + if isinstance(self.action_space, gym.spaces.Box): if self.action_bound_method == "clip": act = np.clip(act, -1.0, 1.0) elif self.action_bound_method == "tanh": act = np.tanh(act) if self.action_scaling: assert ( - np.min(act) >= -1.0 and np.max(act) <= 1.0 # type: ignore + np.min(act) >= -1.0 and np.max(act) <= 1.0 ), f"action scaling only accepts raw action range = [-1, 1], but got: {act}" low, high = self.action_space.low, self.action_space.high - act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore + act = low + (high - low) * (act + 1.0) / 2.0 return act def map_action_inverse( self, - act: BatchProtocol | list | np.ndarray, - ) -> BatchProtocol | list | np.ndarray: + act: arr_type, + ) -> np.ndarray: """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`. This function is called in :meth:`~tianshou.data.Collector.collect` for @@ -252,17 +279,17 @@ class BasePolicy(ABC, nn.Module): :return: action remapped. 
""" + act = self._action_to_numpy(act) if isinstance(self.action_space, gym.spaces.Box): - act = to_numpy(act) - if isinstance(act, np.ndarray): - if self.action_scaling: - low, high = self.action_space.low, self.action_space.high - scale = high - low - eps = np.finfo(np.float32).eps.item() - scale[scale < eps] += eps - act = (act - low) * 2.0 / scale - 1.0 - if self.action_bound_method == "tanh": - act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore + if self.action_scaling: + low, high = self.action_space.low, self.action_space.high + scale = high - low + eps = np.finfo(np.float32).eps.item() + scale[scale < eps] += eps + act = (act - low) * 2.0 / scale - 1.0 + if self.action_bound_method == "tanh": + act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 + return act def process_buffer(self, buffer: TBuffer) -> TBuffer: diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 27336ac..5acacf1 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -7,7 +7,11 @@ import torch.nn.functional as F from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ModelOutputBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler @@ -55,7 +59,7 @@ class ImitationPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ModelOutputBatchProtocol: diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index b3d1f80..14f9056 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Literal, Self +from typing import Any, Literal, Self, cast import gymnasium as gym import numpy as np @@ -8,7 +8,7 @@ import torch.nn.functional as F from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.utils.net.continuous import VAE @@ -112,10 +112,10 @@ class BCQPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> Batch: + ) -> ActBatchProtocol: """Compute action over the given batch data.""" # There is "obs" in the Batch # obs_group: several groups. Each group has a state. 
@@ -134,7 +134,7 @@ class BCQPolicy(BasePolicy): max_indice = q1.argmax(0) act_group.append(act[max_indice].cpu().data.numpy().flatten()) act_group = np.array(act_group) - return Batch(act=act_group) + return cast(ActBatchProtocol, Batch(act=act_group)) def sync_weight(self) -> None: """Soft-update the weight for the target network.""" diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 876ac19..c3bbf6b 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -156,7 +156,7 @@ class CQLPolicy(SACPolicy): self.soft_update(self.critic2_old, self.critic2, self.tau) def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - batch = Batch(obs=obs, info=None) + batch = Batch(obs=obs, info=[None] * len(obs)) obs_result = self(batch) return obs_result.act, obs_result.log_prob diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 7ce25b1..0ab54c1 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -7,7 +7,11 @@ import torch import torch.nn.functional as F from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.types import ImitationBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ImitationBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler @@ -101,26 +105,25 @@ class DiscreteBCQPolicy(DQNPolicy): def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} + next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - act = self(batch, input="obs_next").act + act = self(next_obs_batch).act target_q, _ = self.model_old(batch.obs_next) return target_q[np.arange(len(act)), act] def forward( # type: ignore self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - input: str = "obs", **kwargs: Any, ) -> ImitationBatchProtocol: # TODO: Liskov substitution principle is violated here, the superclass # produces a batch with the field logits, but this one doesn't. # Should be fixed in the future! 
- obs = batch[input] - q_value, state = self.model(obs, state=state, info=batch.info) - if not hasattr(self, "max_action_num"): + q_value, state = self.model(batch.obs, state=state, info=batch.info) + if self.max_action_num is None: self.max_action_num = q_value.shape[1] - imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) + imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info) # mask actions for argmax ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 6d9b98f..016399e 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -72,10 +72,10 @@ class ICMPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> BatchProtocol: + ) -> ActBatchProtocol: """Compute action over the given batch data by inner policy. .. seealso:: diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 2c7572f..52dadfe 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -6,7 +6,7 @@ import torch from tianshou.data import Batch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler @@ -198,7 +198,7 @@ class PSRLPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActBatchProtocol: @@ -213,9 +213,9 @@ class PSRLPolicy(BasePolicy): more detailed explanation. """ assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation" + # TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)? 
act = self.model(batch.obs, state=state, info=batch.info) - result = Batch(act=act) - return cast(ActBatchProtocol, result) + return cast(ActBatchProtocol, Batch(act=act)) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index 79bd6f5..3623289 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Literal, cast import gymnasium as gym import numpy as np @@ -9,6 +9,7 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, ModelOutputBatchProtocol, + ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.policy import DQNPolicy @@ -73,9 +74,10 @@ class BranchingDQNPolicy(DQNPolicy): ) self.model = cast(BranchingNet, self.model) - # TODO: mypy complains b/c max_action_num is declared in base class, see todo there + # TODO: this used to be a public property called max_action_num, + # but it collides with an attr of the same name in base class @property - def max_action_num(self) -> int: # type: ignore + def _action_per_branch(self) -> int: return self.model.action_per_branch # type: ignore @property @@ -83,15 +85,18 @@ class BranchingDQNPolicy(DQNPolicy): return self.model.num_branches # type: ignore def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} - result = self(batch, input="obs_next") + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self(obs_next_batch) if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(batch, model="model_old", input="obs_next").logits + target_q = self(obs_next_batch, model="model_old").logits else: target_q = result.logits if self.is_double: - act = np.expand_dims(self(batch, input="obs_next").act, -1) + act = np.expand_dims(self(obs_next_batch).act, -1) act = to_torch(act, dtype=torch.long, device=target_q.device) else: act = target_q.max(-1).indices.unsqueeze(-1) @@ -114,7 +119,7 @@ class BranchingDQNPolicy(DQNPolicy): mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q _target_q = rew + gamma * mean_target_q * (1 - end_flag) target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1) - target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1) + target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update @@ -132,14 +137,14 @@ class BranchingDQNPolicy(DQNPolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, - model: str = "model", - input: str = "obs", + model: Literal["model", "model_old"] = "model", **kwargs: Any, ) -> ModelOutputBatchProtocol: model = getattr(self, model) - obs = batch[input] + obs = batch.obs + # TODO: this is very contrived, see also iqn.py obs_next = obs.obs if hasattr(obs, "obs") else obs logits, hidden = model(obs_next, state=state, info=batch.info) act = to_numpy(logits.max(dim=-1)[1]) @@ -174,7 +179,11 @@ class BranchingDQNPolicy(DQNPolicy): if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps - rand_act = 
np.random.randint(low=0, high=self.max_action_num, size=(bsz, act.shape[-1])) + rand_act = np.random.randint( + low=0, + high=self._action_per_branch, + size=(bsz, act.shape[-1]), + ) if hasattr(batch.obs, "mask"): rand_act += batch.obs.mask act[rand_mask] = rand_act[rand_mask] diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index a6f0921..600694f 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -4,7 +4,7 @@ import gymnasium as gym import numpy as np import torch -from tianshou.data import ReplayBuffer +from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler @@ -90,11 +90,12 @@ class C51Policy(DQNPolicy): return super().compute_q_value((logits * self.support).sum(2), mask) def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: + obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) if self._target: - act = self(batch, input="obs_next").act - next_dist = self(batch, model="model_old", input="obs_next").logits + act = self(obs_next_batch).act + next_dist = self(obs_next_batch, model="model_old").logits else: - next_batch = self(batch, input="obs_next") + next_batch = self(obs_next_batch) act = next_batch.act next_dist = next_batch.logits next_dist = next_dist[np.arange(len(act)), act, :] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 2283254..c00d199 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,6 +1,6 @@ import warnings from copy import deepcopy -from typing import Any, Literal, Self +from typing import Any, Literal, Self, cast import gymnasium as gym import numpy as np @@ -8,7 +8,12 @@ import torch from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol -from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActStateBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler @@ -118,8 +123,11 @@ class DDPGPolicy(BasePolicy): self.soft_update(self.critic_old, self.critic, self.tau) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} - return self.critic_old(batch.obs_next, self(batch, model="actor_old", input="obs_next").act) + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act) def process_fn( self, @@ -138,12 +146,11 @@ class DDPGPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, model: Literal["actor", "actor_old"] = "actor", - input: str = "obs", **kwargs: Any, - ) -> BatchProtocol: + ) -> ActStateBatchProtocol: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 2 keys: @@ -157,9 +164,8 @@ class DDPGPolicy(BasePolicy): more detailed explanation. 
""" model = getattr(self, model) - obs = batch[input] - actions, hidden = model(obs, state=state, info=batch.info) - return Batch(act=actions, state=hidden) + actions, hidden = model(batch.obs, state=state, info=batch.info) + return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) @staticmethod def _mse_optimizer( diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 5425ead..1b4a342 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -8,7 +8,7 @@ from torch.distributions import Categorical from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler @@ -92,13 +92,11 @@ class DiscreteSACPolicy(SACPolicy): def forward( # type: ignore self, - batch: Batch, + batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - input: str = "obs", **kwargs: Any, ) -> Batch: - obs = batch[input] - logits, hidden = self.actor(obs, state=state, info=batch.info) + logits, hidden = self.actor(batch.obs, state=state, info=batch.info) dist = Categorical(logits=logits) if self.deterministic_eval and not self.training: act = logits.argmax(axis=-1) @@ -107,12 +105,15 @@ class DiscreteSACPolicy(SACPolicy): return Batch(logits=logits, act=act, state=hidden, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs: s_{t+n} - obs_next_result = self(batch, input="obs_next") + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) dist = obs_next_result.dist target_q = dist.probs * torch.min( - self.critic_old(batch.obs_next), - self.critic2_old(batch.obs_next), + self.critic_old(obs_next_batch.obs), + self.critic2_old(obs_next_batch.obs), ) return target_q.sum(dim=-1) + self.alpha * dist.entropy() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 5de9541..f8c4582 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Self, cast +from typing import Any, Literal, Self, cast import gymnasium as gym import numpy as np @@ -10,6 +10,7 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, ModelOutputBatchProtocol, + ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.policy import BasePolicy @@ -91,7 +92,7 @@ class DQNPolicy(BasePolicy): self.clip_loss_grad = clip_loss_grad # TODO: set in forward, fix this! 
- self.max_action_num: int + self.max_action_num: int | None = None def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" @@ -108,11 +109,14 @@ class DQNPolicy(BasePolicy): self.model_old.load_state_dict(self.model.state_dict()) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} - result = self(batch, input="obs_next") + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self(obs_next_batch) if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(batch, model="model_old", input="obs_next").logits + target_q = self(obs_next_batch, model="model_old").logits else: target_q = result.logits if self.is_double: @@ -151,10 +155,9 @@ class DQNPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, - model: str = "model", - input: str = "obs", + model: Literal["model", "model_old"] = "model", **kwargs: Any, ) -> ModelOutputBatchProtocol: """Compute action over the given batch data. @@ -185,11 +188,12 @@ class DQNPolicy(BasePolicy): more detailed explanation. """ model = getattr(self, model) - obs = batch[input] + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done. obs_next = obs.obs if hasattr(obs, "obs") else obs logits, hidden = model(obs_next, state=state, info=batch.info) q = self.compute_q_value(logits, getattr(obs, "mask", None)) - if not hasattr(self, "max_action_num"): + if self.max_action_num is None: self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) result = Batch(logits=logits, act=act, state=hidden) @@ -226,6 +230,9 @@ class DQNPolicy(BasePolicy): if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps + assert ( + self.max_action_num is not None + ), "Can't call this method before max_action_num was set in first forward" q = np.random.rand(bsz, self.max_action_num) # [0, 1] if hasattr(batch.obs, "mask"): q += batch.obs.mask diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index a615620..c680436 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Literal, cast import gymnasium as gym import numpy as np @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from tianshou.data import Batch, ReplayBuffer, to_numpy -from tianshou.data.types import FQFBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import DQNPolicy, QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -84,13 +84,16 @@ class FQFPolicy(QRDQNPolicy): self.fraction_optim = fraction_optim def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} if self._target: - result = self(batch, input="obs_next") + result = self(obs_next_batch) act, fractions = result.act, result.fractions - next_dist = self(batch, model="model_old", input="obs_next", fractions=fractions).logits + next_dist = 
self(obs_next_batch, model="model_old", fractions=fractions).logits else: - next_batch = self(batch, input="obs_next") + next_batch = self(obs_next_batch) act = next_batch.act next_dist = next_batch.logits return next_dist[np.arange(len(act)), act, :] @@ -98,15 +101,15 @@ class FQFPolicy(QRDQNPolicy): # TODO: fix Liskov substitution principle violation def forward( # type: ignore self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - model: str = "model", - input: str = "obs", + model: Literal["model", "model_old"] = "model", fractions: Batch | None = None, **kwargs: Any, ) -> FQFBatchProtocol: model = getattr(self, model) - obs = batch[input] + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done obs_next = obs.obs if hasattr(obs, "obs") else obs if fractions is None: (logits, fractions, quantiles_tau), hidden = model( @@ -125,7 +128,7 @@ class FQFPolicy(QRDQNPolicy): ) weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) - if not hasattr(self, "max_action_num"): + if self.max_action_num is None: # type: ignore # TODO: see same thing in DQNPolicy! Also reduce code duplication. self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 20ead50..c87242f 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Literal, cast import gymnasium as gym import numpy as np @@ -7,7 +7,11 @@ import torch.nn.functional as F from tianshou.data import Batch, to_numpy from tianshou.data.batch import BatchProtocol -from tianshou.data.types import QuantileRegressionBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ObsBatchProtocol, + QuantileRegressionBatchProtocol, + RolloutBatchProtocol, +) from tianshou.policy import QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler @@ -88,10 +92,9 @@ class IQNPolicy(QRDQNPolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, - model: str = "model", - input: str = "obs", + model: Literal["model", "model_old"] = "model", **kwargs: Any, ) -> QuantileRegressionBatchProtocol: if model == "model_old": @@ -101,7 +104,8 @@ class IQNPolicy(QRDQNPolicy): else: sample_size = self.sample_size model = getattr(self, model) - obs = batch[input] + obs = batch.obs + # TODO: this seems very contrived! obs_next = obs.obs if hasattr(obs, "obs") else obs (logits, taus), hidden = model( obs_next, @@ -110,8 +114,8 @@ class IQNPolicy(QRDQNPolicy): info=batch.info, ) q = self.compute_q_value(logits, getattr(obs, "mask", None)) - if not hasattr(self, "max_action_num"): - # TODO: see same thing in DQNPolicy! Also reduce code duplication. + if self.max_action_num is None: # type: ignore + # TODO: see same thing in DQNPolicy! 
self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) result = Batch(logits=logits, act=act, state=hidden, taus=taus) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index c9f5ccc..8d6e22f 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -11,6 +11,7 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, DistBatchProtocol, + ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.policy import BasePolicy @@ -157,7 +158,7 @@ class PGPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> DistBatchProtocol: diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 888b296..f4b9615 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn.functional as F -from tianshou.data import ReplayBuffer +from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler @@ -79,12 +79,15 @@ class QRDQNPolicy(DQNPolicy): warnings.filterwarnings("ignore", message="Using a target size") def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} if self._target: - act = self(batch, input="obs_next").act - next_dist = self(batch, model="model_old", input="obs_next").logits + act = self(obs_next_batch).act + next_dist = self(obs_next_batch, model="model_old").logits else: - next_batch = self(batch, input="obs_next") + next_batch = self(obs_next_batch) act = next_batch.act next_dist = next_batch.logits return next_dist[np.arange(len(act)), act, :] diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index c452101..cdc5ad9 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -6,7 +6,7 @@ import torch from torch.distributions import Independent, Normal from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler @@ -131,13 +131,11 @@ class REDQPolicy(DDPGPolicy): def forward( # type: ignore self, - batch: Batch, + batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - input: str = "obs", **kwargs: Any, ) -> Batch: - obs = batch[input] - loc_scale, h = self.actor(obs, state=state, info=batch.info) + loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) loc, scale = loc_scale dist = Independent(Normal(loc, scale), 1) act = loc if self.deterministic_eval and not self.training else dist.rsample() @@ -153,11 +151,14 @@ class REDQPolicy(DDPGPolicy): return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs: s_{t+n} - obs_next_result = self(batch, input="obs_next") + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + 
) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) a_ = obs_next_result.act sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) - qs = self.critic_old(batch.obs_next, a_)[sample_ensemble_idx, ...] + qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...] if self.target_mode == "min": target_q, _ = torch.min(qs, dim=0) elif self.target_mode == "mean": diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index de07857..2c494d4 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -7,7 +7,11 @@ import torch from torch.distributions import Independent, Normal from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import DistLogProbBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler @@ -147,13 +151,11 @@ class SACPolicy(DDPGPolicy): # TODO: violates Liskov substitution principle def forward( # type: ignore self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - input: str = "obs", **kwargs: Any, ) -> DistLogProbBatchProtocol: - obs = batch[input] - logits, hidden = self.actor(obs, state=state, info=batch.info) + logits, hidden = self.actor(batch.obs, state=state, info=batch.info) assert isinstance(logits, tuple) dist = Independent(Normal(*logits), 1) if self.deterministic_eval and not self.training: @@ -179,13 +181,16 @@ class SACPolicy(DDPGPolicy): return cast(DistLogProbBatchProtocol, result) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs: s_{t+n} - obs_next_result = self(batch, input="obs_next") + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) act_ = obs_next_result.act return ( torch.min( - self.critic_old(batch.obs_next, act_), - self.critic2_old(batch.obs_next, act_), + self.critic_old(obs_next_batch.obs, act_), + self.critic2_old(obs_next_batch.obs, act_), ) - self.alpha * obs_next_result.log_prob ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 0f363c6..c364fc8 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -5,7 +5,7 @@ import gymnasium as gym import numpy as np import torch -from tianshou.data import ReplayBuffer +from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy @@ -114,15 +114,18 @@ class TD3Policy(DDPGPolicy): self.soft_update(self.actor_old, self.actor, self.tau) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs: s_{t+n} - act_ = self(batch, model="actor_old", input="obs_next").act + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + act_ = self(obs_next_batch, model="actor_old").act noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise if self.noise_clip > 0.0: noise = noise.clamp(-self.noise_clip, self.noise_clip) act_ += noise return torch.min( - self.critic_old(batch.obs_next, act_), - self.critic2_old(batch.obs_next, 
act_), + self.critic_old(obs_next_batch.obs, act_), + self.critic2_old(obs_next_batch.obs, act_), ) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index ffa2a01..38ff7d0 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -4,7 +4,7 @@ import numpy as np from tianshou.data import Batch from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy @@ -16,7 +16,7 @@ class RandomPolicy(BasePolicy): def forward( self, - batch: RolloutBatchProtocol, + batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ActBatchProtocol: diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 5595750..31c9efb 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -180,6 +180,7 @@ class ActorProb(BaseActor): of how preprocess_net is suggested to be defined. """ + # TODO: force kwargs, adjust downstream code def __init__( self, preprocess_net: nn.Module,
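The new `BasePolicy.compute_action` in `tianshou/policy/base.py` gives single-observation inference without hand-building a `Batch` and calling `forward`. A minimal sketch, reusing the discrete setup from the new `test/base/test_policy.py` fixture; CartPole and an untrained `PPOPolicy` are stand-ins chosen only to illustrate the API, not part of the diff:

```python
import gymnasium as gym
import torch
from torch.distributions import Categorical

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import Critic
from tianshou.utils.net.discrete import Actor

env = gym.make("CartPole-v1")  # stand-in env, not part of the diff
action_space = env.action_space
obs_shape = env.observation_space.shape

# Mirrors the discrete branch of the fixture in test/base/test_policy.py.
actor = Actor(
    Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
    action_shape=action_space.n,
)
critic = Critic(Net(obs_shape, hidden_sizes=[64, 64]))
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=1e-3)
policy = PPOPolicy(
    actor=actor,
    critic=critic,
    dist_fn=lambda logits: Categorical(logits=logits),
    optim=optim,
    action_space=action_space,
    action_scaling=False,
)
policy.eval()

obs, info = env.reset()
done = False
while not done:
    # compute_action adds the batch dimension, runs forward, maps the raw output
    # through map_action, and returns an int here because the space is Discrete.
    action = policy.compute_action(obs, info=info)
    obs, rew, terminated, truncated, info = env.step(action)
    done = terminated or truncated
```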
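A pattern repeated across the policy diffs above (ddpg, dqn, c51, qrdqn, fqf, sac, td3, redq, discrete_sac, bdq, discrete_bcq): instead of calling `self(batch, input="obs_next")`, a dedicated batch is built whose `obs` field holds s_{t+n} and whose `info` is a per-sample placeholder, because `forward` now reads only `batch.obs` and `batch.info`. A sketch of the convention in a hypothetical `DQNPolicy` subclass (the subclass is illustrative; the batch construction matches the diff):

```python
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import DQNPolicy


class MyDQNVariant(DQNPolicy):  # hypothetical subclass, shown only to illustrate the convention
    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        # Old convention (removed above): self(buffer[indices], input="obs_next").
        # New convention: wrap obs_next in its own batch; info is a placeholder list
        # because forward expects the field to exist.
        obs_next_batch = Batch(
            obs=buffer[indices].obs_next,
            info=[None] * len(indices),
        )  # obs is s_{t+n}
        act = self(obs_next_batch).act
        # Assumes target networks are enabled (model_old exists), as in double DQN.
        target_q = self(obs_next_batch, model="model_old").logits
        return target_q[np.arange(len(act)), act]
```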
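The reworked protocols in `tianshou/data/types.py` narrow `forward` to consume an `ObsBatchProtocol` (only `obs` and `info`) and return an `ActBatchProtocol` or one of its extensions (`ActStateBatchProtocol`, `ModelOutputBatchProtocol`). A sketch of a toy policy annotated against the new types, following the same `cast` idiom the diff applies in `psrl.py`, `ddpg.py`, and `bcq.py` (the policy itself is hypothetical):

```python
from typing import Any, cast

import numpy as np

from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy


class ConstantPolicy(BasePolicy):
    """Toy policy that always emits action 0; it exists only to show the typing."""

    def forward(
        self,
        batch: ObsBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        # Only batch.obs and batch.info are part of the input contract now;
        # obs_next, rew, terminated, etc. belong to RolloutBatchProtocol instead.
        act = np.zeros(len(batch.obs), dtype=np.int64)
        return cast(ActBatchProtocol, Batch(act=act))

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        return {}  # nothing to learn in this toy policy
```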
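`map_action` and `map_action_inverse` now funnel their input through the new `_action_to_numpy` helper and always return an `np.ndarray`, instead of echoing whatever container came in. A small round-trip sketch for a Box space; `policy` stands for any `BasePolicy` instance constructed with this action space, `action_scaling=True`, and `action_bound_method="clip"` (an assumption, not shown in the diff):

```python
import gymnasium as gym
import numpy as np
import torch

# Assumption: `policy` is any BasePolicy whose action_space is `box` and which was
# built with action_scaling=True and action_bound_method="clip".
box = gym.spaces.Box(low=np.array([0.0, -2.0]), high=np.array([1.0, 2.0]))

raw = torch.tensor([[0.5, -1.0]])          # raw network output, already in [-1, 1]
env_act = policy.map_action(raw)           # torch input, but the result is an ndarray now
assert isinstance(env_act, np.ndarray)     # [[0.75, -2.0]]: rescaled from [-1, 1] to [low, high]

restored = policy.map_action_inverse(env_act)  # back to the [-1, 1] representation
np.testing.assert_allclose(restored, raw.numpy(), atol=1e-6)
```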