Method to compute actions from observations (#991)

This PR adds a new method for getting actions from an env's observation
and info. This is useful for standard inference and stands in contrast
to batch-based methods that are currently used in training and
evaluation. Without it, users have to jump through hoops (manually wrapping
observations in a `Batch`) to perform inference with a trained policy. I have
also added a test for the new method.

In future PRs, this method should be included in the examples (in the
"watch" section).

Adding this required several typing improvements and, importantly,
_simplifying the signature of `forward` in many policies!_ This is a
**breaking change**, but it will likely affect no users. The `input`
parameter of `forward` was a rather hacky mechanism, and I believe it is good
that it's gone now. It will also help with #948.
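
For downstream code that relied on the `input` keyword, the migration mirrors
what the updated `_target_q` implementations in this diff do: construct the
batch explicitly instead of redirecting `forward` to another field. A minimal
sketch (assuming `policy` is any Tianshou policy and `batch` is a rollout
batch sampled from a buffer):

```python
from tianshou.data import Batch

# Previously (removed in this PR): result = policy(batch, input="obs_next")
# Now: put the observations you want to act on into the `obs` field yourself.
obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch))
result = policy(obs_next_batch)
```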

The main functional change is the addition of `compute_action` to
`BasePolicy`.
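
As a rough illustration of the intended use (the `watch` helper and the env
are purely illustrative; it assumes an already-trained policy instance):

```python
import gymnasium as gym

from tianshou.policy import BasePolicy


def watch(policy: BasePolicy, env: gym.Env, num_episodes: int = 3) -> None:
    """Roll out a trained policy for a few episodes using compute_action."""
    policy.eval()
    for _ in range(num_episodes):
        obs, info = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            # single observation in, single action out -- no manual Batch handling
            act = policy.compute_action(obs, info)
            obs, rew, terminated, truncated, info = env.step(act)
```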

Other minor changes:
- improvements in typing
- updated PR and issue templates
- improved handling of `max_action_num`

Closes #981
Michael Panchenko 2023-11-16 18:27:53 +01:00 committed by GitHub
parent 6d6c85e594
commit 3a1bc18add
27 changed files with 338 additions and 171 deletions

View File

@ -3,6 +3,7 @@
+ [ ] RL algorithm bug
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
+ [ ] design request (i.e. "X should be changed to Y.")
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:

View File

@ -1,9 +1,9 @@
- [ ] I have marked all applicable categories:
+ [ ] exception-raising fix
+ [ ] algorithm implementation fix
+ [ ] documentation modification
+ [ ] new feature
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every items in this Pull Request below
- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
- [ ] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes
- [ ] If applicable, I have added tests to cover my changes.
- [ ] I have reformatted the code using `poe format`
- [ ] I have checked style and types with `poe lint` and `poe type-check`
- [ ] (Optional) I ran tests locally with `poe test`
(or a subset of them with `poe test-reduced`), and they pass
- [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build`

View File

@ -172,8 +172,8 @@ if __name__ == "__main__":
print(env.spec.reward_threshold)
print(obs.shape, action_num)
for _ in range(4000):
obs, rew, done, info = env.step(0)
if done:
obs, rew, terminated, truncated, info = env.step(0)
if terminated or truncated:
env.reset()
print(obs.shape, rew, done)
print(obs.shape, rew, terminated, truncated)
cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])

test/base/test_policy.py (new file, 71 lines)
View File

@ -0,0 +1,71 @@
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor
obs_shape = (5,)
def _to_hashable(x: np.ndarray | int):
return x if isinstance(x, int) else tuple(x.tolist())
@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
action_type = request.param
if action_type == "continuous":
action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
actor = ActorProb(
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape,
)
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
elif action_type == "discrete":
action_space = gym.spaces.Discrete(3)
actor = Actor(
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
action_shape=action_space.n,
)
dist_fn = lambda logits: Categorical(logits=logits)
else:
raise ValueError(f"Unknown action type: {action_type}")
critic = Critic(
Net(obs_shape, hidden_sizes=[64, 64]),
)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)
policy = PPOPolicy(
actor=actor,
critic=critic,
dist_fn=dist_fn,
optim=optim,
action_space=action_space,
action_scaling=False,
)
policy.eval()
return policy
class TestPolicyBasics:
def test_get_action(self, policy):
sample_obs = torch.randn(obs_shape)
policy.deterministic_eval = False
actions = [policy.compute_action(sample_obs) for _ in range(10)]
assert all(policy.action_space.contains(a) for a in actions)
# check that the actions are different in non-deterministic mode
assert len(set(map(_to_hashable, actions))) > 1
policy.deterministic_eval = True
actions = [policy.compute_action(sample_obs) for _ in range(10)]
# check that the actions are the same in deterministic mode
assert len(set(map(_to_hashable, actions))) == 1

View File

@ -5,16 +5,24 @@ from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol, arr_type
class RolloutBatchProtocol(BatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""
class ObsBatchProtocol(BatchProtocol):
"""Observations of an environment that a policy can turn into actions.
Typically used inside a policy's forward
"""
obs: arr_type | BatchProtocol
info: arr_type
class RolloutBatchProtocol(ObsBatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""
obs_next: arr_type | BatchProtocol
act: arr_type
rew: np.ndarray
terminated: arr_type
truncated: arr_type
info: arr_type
class BatchWithReturnsProtocol(RolloutBatchProtocol):
@ -39,11 +47,17 @@ class RecurrentStateBatch(BatchProtocol):
class ActBatchProtocol(BatchProtocol):
"""Simplest batch, just containing the action. Useful e.g., for random policy."""
act: np.ndarray
act: arr_type
class ModelOutputBatchProtocol(ActBatchProtocol):
"""Contains model output: (logits) and potentially hidden states."""
class ActStateBatchProtocol(ActBatchProtocol):
"""Contains action and state (which can be None), useful for policies that can support RNNs."""
state: dict | BatchProtocol | np.ndarray | None
class ModelOutputBatchProtocol(ActStateBatchProtocol):
"""In addition to state and action, contains model output: (logits)."""
logits: torch.Tensor
state: dict | BatchProtocol | np.ndarray | None

View File

@ -10,6 +10,8 @@ import torch
from tianshou.data.batch import Batch, _parse_value
# TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
"""Return an object without torch.Tensor."""

View File

@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Literal, TypeAlias, cast, overload
from typing import Any, Literal, TypeAlias, cast
import gymnasium as gym
import numpy as np
@ -11,14 +11,18 @@ from numba import njit
from torch import nn
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.batch import Batch, BatchProtocol, arr_type
from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.utils import MultipleLRSchedulers
logger = logging.getLogger(__name__)
TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers
@ -149,13 +153,39 @@ class BasePolicy(ABC, nn.Module):
for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)
def compute_action(
self,
obs: arr_type,
info: dict[str, Any] | None = None,
state: dict | BatchProtocol | np.ndarray | None = None,
) -> np.ndarray | int:
"""Get action as int (for discrete env's) or array (for continuous ones) from
an env's observation and info.
:param obs: observation from the gym's env.
:param info: information given by the gym's env.
:param state: the hidden state of RNN policy, used for recurrent policy.
:return: action as int (for discrete env's) or array (for continuous ones).
"""
# need to add empty batch dimension
obs = obs[None, :]
obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
act = self.forward(obs_batch, state=state).act.squeeze()
if isinstance(act, torch.Tensor):
act = act.detach().cpu().numpy()
act = self.map_action(act)
if isinstance(self.action_space, Discrete):
# could be an array of shape (), easier to just convert to int
act = int(act) # type: ignore
return act
@abstractmethod
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
) -> ActBatchProtocol:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
@ -190,22 +220,19 @@ class BasePolicy(ABC, nn.Module):
act = policy.map_action(act, batch)
"""
@overload
def map_action(self, act: BatchProtocol) -> BatchProtocol:
...
@overload
def map_action(self, act: np.ndarray) -> np.ndarray:
...
@overload
def map_action(self, act: torch.Tensor) -> torch.Tensor:
...
@staticmethod
def _action_to_numpy(act: arr_type) -> np.ndarray:
act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch
if not isinstance(act, np.ndarray):
raise ValueError(
f"act should have been be a numpy.ndarray, but got {type(act)}.",
)
return act
def map_action(
self,
act: BatchProtocol | np.ndarray | torch.Tensor,
) -> BatchProtocol | np.ndarray | torch.Tensor:
act: arr_type,
) -> np.ndarray:
"""Map raw network output to action range in gym's env.action_space.
This function is called in :meth:`~tianshou.data.Collector.collect` and only
@ -223,24 +250,24 @@ class BasePolicy(ABC, nn.Module):
:return: action in the same form of input "act" but remap to the target action
space.
"""
if isinstance(self.action_space, gym.spaces.Box) and isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
if self.action_bound_method == "clip":
act = np.clip(act, -1.0, 1.0)
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:
assert (
np.min(act) >= -1.0 and np.max(act) <= 1.0 # type: ignore
np.min(act) >= -1.0 and np.max(act) <= 1.0
), f"action scaling only accepts raw action range = [-1, 1], but got: {act}"
low, high = self.action_space.low, self.action_space.high
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
act = low + (high - low) * (act + 1.0) / 2.0
return act
def map_action_inverse(
self,
act: BatchProtocol | list | np.ndarray,
) -> BatchProtocol | list | np.ndarray:
act: arr_type,
) -> np.ndarray:
"""Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
This function is called in :meth:`~tianshou.data.Collector.collect` for
@ -252,17 +279,17 @@ class BasePolicy(ABC, nn.Module):
:return: action remapped.
"""
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
act = to_numpy(act)
if isinstance(act, np.ndarray):
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0
return act
def process_buffer(self, buffer: TBuffer) -> TBuffer:

View File

@ -7,7 +7,11 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ModelOutputBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ModelOutputBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
@ -55,7 +59,7 @@ class ImitationPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:

View File

@ -1,5 +1,5 @@
import copy
from typing import Any, Literal, Self
from typing import Any, Literal, Self, cast
import gymnasium as gym
import numpy as np
@ -8,7 +8,7 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.continuous import VAE
@ -112,10 +112,10 @@ class BCQPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
) -> ActBatchProtocol:
"""Compute action over the given batch data."""
# There is "obs" in the Batch
# obs_group: several groups. Each group has a state.
@ -134,7 +134,7 @@ class BCQPolicy(BasePolicy):
max_indice = q1.argmax(0)
act_group.append(act[max_indice].cpu().data.numpy().flatten())
act_group = np.array(act_group)
return Batch(act=act_group)
return cast(ActBatchProtocol, Batch(act=act_group))
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""

View File

@ -156,7 +156,7 @@ class CQLPolicy(SACPolicy):
self.soft_update(self.critic2_old, self.critic2, self.tau)
def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch = Batch(obs=obs, info=None)
batch = Batch(obs=obs, info=[None] * len(obs))
obs_result = self(batch)
return obs_result.act, obs_result.log_prob

View File

@ -7,7 +7,11 @@ import torch
import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.types import ImitationBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ImitationBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -101,26 +105,25 @@ class DiscreteBCQPolicy(DQNPolicy):
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch))
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
act = self(batch, input="obs_next").act
act = self(next_obs_batch).act
target_q, _ = self.model_old(batch.obs_next)
return target_q[np.arange(len(act)), act]
def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> ImitationBatchProtocol:
# TODO: Liskov substitution principle is violated here, the superclass
# produces a batch with the field logits, but this one doesn't.
# Should be fixed in the future!
obs = batch[input]
q_value, state = self.model(obs, state=state, info=batch.info)
if not hasattr(self, "max_action_num"):
q_value, state = self.model(batch.obs, state=state, info=batch.info)
if self.max_action_num is None:
self.max_action_num = q_value.shape[1]
imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)
imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info)
# mask actions for argmax
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@ -72,10 +72,10 @@ class ICMPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
) -> ActBatchProtocol:
"""Compute action over the given batch data by inner policy.
.. seealso::

View File

@ -6,7 +6,7 @@ import torch
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
@ -198,7 +198,7 @@ class PSRLPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:
@ -213,9 +213,9 @@ class PSRLPolicy(BasePolicy):
more detailed explanation.
"""
assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation"
# TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)?
act = self.model(batch.obs, state=state, info=batch.info)
result = Batch(act=act)
return cast(ActBatchProtocol, result)
return cast(ActBatchProtocol, Batch(act=act))
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
n_s, n_a = self.model.n_state, self.model.n_action

View File

@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast
import gymnasium as gym
import numpy as np
@ -9,6 +9,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
BatchWithReturnsProtocol,
ModelOutputBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import DQNPolicy
@ -73,9 +74,10 @@ class BranchingDQNPolicy(DQNPolicy):
)
self.model = cast(BranchingNet, self.model)
# TODO: mypy complains b/c max_action_num is declared in base class, see todo there
# TODO: this used to be a public property called max_action_num,
# but it collides with an attr of the same name in base class
@property
def max_action_num(self) -> int: # type: ignore
def _action_per_branch(self) -> int:
return self.model.action_per_branch # type: ignore
@property
@ -83,15 +85,18 @@ class BranchingDQNPolicy(DQNPolicy):
return self.model.num_branches # type: ignore
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
result = self(batch, input="obs_next")
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
result = self(obs_next_batch)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
target_q = self(obs_next_batch, model="model_old").logits
else:
target_q = result.logits
if self.is_double:
act = np.expand_dims(self(batch, input="obs_next").act, -1)
act = np.expand_dims(self(obs_next_batch).act, -1)
act = to_torch(act, dtype=torch.long, device=target_q.device)
else:
act = target_q.max(-1).indices.unsqueeze(-1)
@ -114,7 +119,7 @@ class BranchingDQNPolicy(DQNPolicy):
mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
_target_q = rew + gamma * mean_target_q * (1 - end_flag)
target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1)
batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update
@ -132,14 +137,14 @@ class BranchingDQNPolicy(DQNPolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> ModelOutputBatchProtocol:
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1])
@ -174,7 +179,11 @@ class BranchingDQNPolicy(DQNPolicy):
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
rand_act = np.random.randint(low=0, high=self.max_action_num, size=(bsz, act.shape[-1]))
rand_act = np.random.randint(
low=0,
high=self._action_per_branch,
size=(bsz, act.shape[-1]),
)
if hasattr(batch.obs, "mask"):
rand_act += batch.obs.mask
act[rand_mask] = rand_act[rand_mask]

View File

@ -4,7 +4,7 @@ import gymnasium as gym
import numpy as np
import torch
from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -90,11 +90,12 @@ class C51Policy(DQNPolicy):
return super().compute_q_value((logits * self.support).sum(2), mask)
def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor:
obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch))
if self._target:
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
act = self(obs_next_batch).act
next_dist = self(obs_next_batch, model="model_old").logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(act)), act, :]

View File

@ -1,6 +1,6 @@
import warnings
from copy import deepcopy
from typing import Any, Literal, Self
from typing import Any, Literal, Self, cast
import gymnasium as gym
import numpy as np
@ -8,7 +8,12 @@ import torch
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ActStateBatchProtocol,
BatchWithReturnsProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
@ -118,8 +123,11 @@ class DDPGPolicy(BasePolicy):
self.soft_update(self.critic_old, self.critic, self.tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
return self.critic_old(batch.obs_next, self(batch, model="actor_old", input="obs_next").act)
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act)
def process_fn(
self,
@ -138,12 +146,11 @@ class DDPGPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: Literal["actor", "actor_old"] = "actor",
input: str = "obs",
**kwargs: Any,
) -> BatchProtocol:
) -> ActStateBatchProtocol:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
@ -157,9 +164,8 @@ class DDPGPolicy(BasePolicy):
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
actions, hidden = model(obs, state=state, info=batch.info)
return Batch(act=actions, state=hidden)
actions, hidden = model(batch.obs, state=state, info=batch.info)
return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden))
@staticmethod
def _mse_optimizer(

View File

@ -8,7 +8,7 @@ from torch.distributions import Categorical
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -92,13 +92,11 @@ class DiscreteSACPolicy(SACPolicy):
def forward( # type: ignore
self,
batch: Batch,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> Batch:
obs = batch[input]
logits, hidden = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
if self.deterministic_eval and not self.training:
act = logits.argmax(axis=-1)
@ -107,12 +105,15 @@ class DiscreteSACPolicy(SACPolicy):
return Batch(logits=logits, act=act, state=hidden, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
dist = obs_next_result.dist
target_q = dist.probs * torch.min(
self.critic_old(batch.obs_next),
self.critic2_old(batch.obs_next),
self.critic_old(obs_next_batch.obs),
self.critic2_old(obs_next_batch.obs),
)
return target_q.sum(dim=-1) + self.alpha * dist.entropy()

View File

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, Self, cast
from typing import Any, Literal, Self, cast
import gymnasium as gym
import numpy as np
@ -10,6 +10,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
BatchWithReturnsProtocol,
ModelOutputBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
@ -91,7 +92,7 @@ class DQNPolicy(BasePolicy):
self.clip_loss_grad = clip_loss_grad
# TODO: set in forward, fix this!
self.max_action_num: int
self.max_action_num: int | None = None
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
@ -108,11 +109,14 @@ class DQNPolicy(BasePolicy):
self.model_old.load_state_dict(self.model.state_dict())
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
result = self(batch, input="obs_next")
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
result = self(obs_next_batch)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
target_q = self(obs_next_batch, model="model_old").logits
else:
target_q = result.logits
if self.is_double:
@ -151,10 +155,9 @@ class DQNPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> ModelOutputBatchProtocol:
"""Compute action over the given batch data.
@ -185,11 +188,12 @@ class DQNPolicy(BasePolicy):
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
if self.max_action_num is None:
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden)
@ -226,6 +230,9 @@ class DQNPolicy(BasePolicy):
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
assert (
self.max_action_num is not None
), "Can't call this method before max_action_num was set in first forward"
q = np.random.rand(bsz, self.max_action_num) # [0, 1]
if hasattr(batch.obs, "mask"):
q += batch.obs.mask

View File

@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast
import gymnasium as gym
import numpy as np
@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.data.types import FQFBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import DQNPolicy, QRDQNPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
@ -84,13 +84,16 @@ class FQFPolicy(QRDQNPolicy):
self.fraction_optim = fraction_optim
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
if self._target:
result = self(batch, input="obs_next")
result = self(obs_next_batch)
act, fractions = result.act, result.fractions
next_dist = self(batch, model="model_old", input="obs_next", fractions=fractions).logits
next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
return next_dist[np.arange(len(act)), act, :]
@ -98,15 +101,15 @@ class FQFPolicy(QRDQNPolicy):
# TODO: fix Liskov substitution principle violation
def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
fractions: Batch | None = None,
**kwargs: Any,
) -> FQFBatchProtocol:
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done
obs_next = obs.obs if hasattr(obs, "obs") else obs
if fractions is None:
(logits, fractions, quantiles_tau), hidden = model(
@ -125,7 +128,7 @@ class FQFPolicy(QRDQNPolicy):
)
weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits
q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
if self.max_action_num is None: # type: ignore
# TODO: see same thing in DQNPolicy! Also reduce code duplication.
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])

View File

@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast
import gymnasium as gym
import numpy as np
@ -7,7 +7,11 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_numpy
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import QuantileRegressionBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ObsBatchProtocol,
QuantileRegressionBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import QRDQNPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -88,10 +92,9 @@ class IQNPolicy(QRDQNPolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> QuantileRegressionBatchProtocol:
if model == "model_old":
@ -101,7 +104,8 @@ class IQNPolicy(QRDQNPolicy):
else:
sample_size = self.sample_size
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this seems very contrived!
obs_next = obs.obs if hasattr(obs, "obs") else obs
(logits, taus), hidden = model(
obs_next,
@ -110,8 +114,8 @@ class IQNPolicy(QRDQNPolicy):
info=batch.info,
)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
# TODO: see same thing in DQNPolicy! Also reduce code duplication.
if self.max_action_num is None: # type: ignore
# TODO: see same thing in DQNPolicy!
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden, taus=taus)

View File

@ -11,6 +11,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
BatchWithReturnsProtocol,
DistBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
@ -157,7 +158,7 @@ class PGPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> DistBatchProtocol:

View File

@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -79,12 +79,15 @@ class QRDQNPolicy(DQNPolicy):
warnings.filterwarnings("ignore", message="Using a target size")
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
if self._target:
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
act = self(obs_next_batch).act
next_dist = self(obs_next_batch, model="model_old").logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
return next_dist[np.arange(len(act)), act, :]

View File

@ -6,7 +6,7 @@ import torch
from torch.distributions import Independent, Normal
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -131,13 +131,11 @@ class REDQPolicy(DDPGPolicy):
def forward( # type: ignore
self,
batch: Batch,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> Batch:
obs = batch[input]
loc_scale, h = self.actor(obs, state=state, info=batch.info)
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
act = loc if self.deterministic_eval and not self.training else dist.rsample()
@ -153,11 +151,14 @@ class REDQPolicy(DDPGPolicy):
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
a_ = obs_next_result.act
sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False)
qs = self.critic_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...]
if self.target_mode == "min":
target_q, _ = torch.min(qs, dim=0)
elif self.target_mode == "mean":

View File

@ -7,7 +7,11 @@ import torch
from torch.distributions import Independent, Normal
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import DistLogProbBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
DistLogProbBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
@ -147,13 +151,11 @@ class SACPolicy(DDPGPolicy):
# TODO: violates Liskov substitution principle
def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> DistLogProbBatchProtocol:
obs = batch[input]
logits, hidden = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training:
@ -179,13 +181,16 @@ class SACPolicy(DDPGPolicy):
return cast(DistLogProbBatchProtocol, result)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
act_ = obs_next_result.act
return (
torch.min(
self.critic_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, act_),
self.critic_old(obs_next_batch.obs, act_),
self.critic2_old(obs_next_batch.obs, act_),
)
- self.alpha * obs_next_result.log_prob
)

View File

@ -5,7 +5,7 @@ import gymnasium as gym
import numpy as np
import torch
from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
@ -114,15 +114,18 @@ class TD3Policy(DDPGPolicy):
self.soft_update(self.actor_old, self.actor, self.tau)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
act_ = self(batch, model="actor_old", input="obs_next").act
obs_next_batch = Batch(
obs=buffer[indices].obs_next,
info=[None] * len(indices),
) # obs_next: s_{t+n}
act_ = self(obs_next_batch, model="actor_old").act
noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise
if self.noise_clip > 0.0:
noise = noise.clamp(-self.noise_clip, self.noise_clip)
act_ += noise
return torch.min(
self.critic_old(batch.obs_next, act_),
self.critic2_old(batch.obs_next, act_),
self.critic_old(obs_next_batch.obs, act_),
self.critic2_old(obs_next_batch.obs, act_),
)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:

View File

@ -4,7 +4,7 @@ import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
@ -16,7 +16,7 @@ class RandomPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:

View File

@ -180,6 +180,7 @@ class ActorProb(BaseActor):
of how preprocess_net is suggested to be defined.
"""
# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module,