Method to compute actions from observations (#991)
This PR adds a new method for getting actions from an env's observation and info. This is useful for standard inference and stands in contrast to the batch-based methods that are currently used in training and evaluation. Without this, users have to do some kind of gymnastics to actually perform inference with a trained policy. I have also added a test for the new method. In future PRs, this method should be included in the examples (in the "watch" section).

Adding this required improving several type annotations and, importantly, _simplifying the signature of `forward` in many policies!_ This is a **breaking change**, but it will likely affect no users. The `input` parameter of `forward` was a rather hacky mechanism; I believe it is good that it's gone now. It will also help with #948.

The main functional change is the addition of `compute_action` to `BasePolicy`. Other minor changes:

- improvements in typing
- updated PR and issue templates
- improved handling of `max_action_num`

Closes #981
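For context, a minimal usage sketch of the new method; the environment choice and the surrounding loop are illustrative only, and `policy` is assumed to be an already trained tianshou policy rather than anything defined in this PR:

```python
import gymnasium as gym

# `policy` is assumed to be a trained BasePolicy subclass (hypothetical setup).
env = gym.make("CartPole-v1")  # illustrative env choice
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # compute_action takes a single observation (plus optional info/state)
    # and returns an int for discrete action spaces or an np.ndarray for
    # continuous ones, so it can be fed straight back into env.step.
    action = policy.compute_action(obs, info=info)
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```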
Commit 3a1bc18add (parent 6d6c85e594)
.github/ISSUE_TEMPLATE.md (vendored, 1 change)
@@ -3,6 +3,7 @@
+ [ ] RL algorithm bug
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
+ [ ] design request (i.e. "X should be changed to Y.")
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
.github/PULL_REQUEST_TEMPLATE.md (vendored, 18 changes)
@@ -1,9 +1,9 @@
- [ ] I have marked all applicable categories:
+ [ ] exception-raising fix
+ [ ] algorithm implementation fix
+ [ ] documentation modification
+ [ ] new feature
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every items in this Pull Request below
- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
- [ ] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes
- [ ] If applicable, I have added tests to cover my changes.
- [ ] I have reformatted the code using `poe format`
- [ ] I have checked style and types with `poe lint` and `poe type-check`
- [ ] (Optional) I ran tests locally with `poe test`
  (or a subset of them with `poe test-reduced`) ,and they pass
- [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build`
@@ -172,8 +172,8 @@ if __name__ == "__main__":
print(env.spec.reward_threshold)
print(obs.shape, action_num)
for _ in range(4000):
obs, rew, done, info = env.step(0)
if done:
obs, rew, terminated, truncated, info = env.step(0)
if terminated or truncated:
env.reset()
print(obs.shape, rew, done)
print(obs.shape, rew, terminated, truncated)
cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
test/base/test_policy.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)


def _to_hashable(x: np.ndarray | int):
    return x if isinstance(x, int) else tuple(x.tolist())


@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )
        dist_fn = lambda *logits: Independent(Normal(*logits), 1)
    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )
        dist_fn = lambda logits: Categorical(logits=logits)
    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )

    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
    policy.eval()
    return policy


class TestPolicyBasics:
    def test_get_action(self, policy):
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)

        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1
@@ -5,16 +5,24 @@ from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol, arr_type


class RolloutBatchProtocol(BatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""
class ObsBatchProtocol(BatchProtocol):
"""Observations of an environment that a policy can turn into actions.

Typically used inside a policy's forward
"""

obs: arr_type | BatchProtocol
info: arr_type


class RolloutBatchProtocol(ObsBatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""

obs_next: arr_type | BatchProtocol
act: arr_type
rew: np.ndarray
terminated: arr_type
truncated: arr_type
info: arr_type


class BatchWithReturnsProtocol(RolloutBatchProtocol):

@@ -39,11 +47,17 @@ class RecurrentStateBatch(BatchProtocol):
class ActBatchProtocol(BatchProtocol):
"""Simplest batch, just containing the action. Useful e.g., for random policy."""

act: np.ndarray
act: arr_type


class ModelOutputBatchProtocol(ActBatchProtocol):
"""Contains model output: (logits) and potentially hidden states."""
class ActStateBatchProtocol(ActBatchProtocol):
"""Contains action and state (which can be None), useful for policies that can support RNNs."""

state: dict | BatchProtocol | np.ndarray | None


class ModelOutputBatchProtocol(ActStateBatchProtocol):
"""In addition to state and action, contains model output: (logits)."""

logits: torch.Tensor
state: dict | BatchProtocol | np.ndarray | None
@@ -10,6 +10,8 @@ import torch
from tianshou.data.batch import Batch, _parse_value


# TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
"""Return an object without torch.Tensor."""
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Literal, TypeAlias, cast, overload
from typing import Any, Literal, TypeAlias, cast

import gymnasium as gym
import numpy as np

@@ -11,14 +11,18 @@ from numba import njit
from torch import nn

from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.batch import Batch, BatchProtocol, arr_type
from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    ActBatchProtocol,
    BatchWithReturnsProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.utils import MultipleLRSchedulers

logger = logging.getLogger(__name__)


TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers


@@ -149,13 +153,39 @@ class BasePolicy(ABC, nn.Module):
for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)

    def compute_action(
        self,
        obs: arr_type,
        info: dict[str, Any] | None = None,
        state: dict | BatchProtocol | np.ndarray | None = None,
    ) -> np.ndarray | int:
        """Get action as int (for discrete env's) or array (for continuous ones) from
        an env's observation and info.

        :param obs: observation from the gym's env.
        :param info: information given by the gym's env.
        :param state: the hidden state of RNN policy, used for recurrent policy.
        :return: action as int (for discrete env's) or array (for continuous ones).
        """
        # need to add empty batch dimension
        obs = obs[None, :]
        obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
        act = self.forward(obs_batch, state=state).act.squeeze()
        if isinstance(act, torch.Tensor):
            act = act.detach().cpu().numpy()
        act = self.map_action(act)
        if isinstance(self.action_space, Discrete):
            # could be an array of shape (), easier to just convert to int
            act = int(act)  # type: ignore
        return act

@abstractmethod
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
) -> ActBatchProtocol:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:

@@ -190,22 +220,19 @@ class BasePolicy(ABC, nn.Module):
act = policy.map_action(act, batch)
"""

@overload
def map_action(self, act: BatchProtocol) -> BatchProtocol:
...

@overload
def map_action(self, act: np.ndarray) -> np.ndarray:
...

@overload
def map_action(self, act: torch.Tensor) -> torch.Tensor:
...
@staticmethod
def _action_to_numpy(act: arr_type) -> np.ndarray:
act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch
if not isinstance(act, np.ndarray):
raise ValueError(
f"act should have been be a numpy.ndarray, but got {type(act)}.",
)
return act

def map_action(
self,
act: BatchProtocol | np.ndarray | torch.Tensor,
) -> BatchProtocol | np.ndarray | torch.Tensor:
act: arr_type,
) -> np.ndarray:
"""Map raw network output to action range in gym's env.action_space.

This function is called in :meth:`~tianshou.data.Collector.collect` and only

@@ -223,24 +250,24 @@ class BasePolicy(ABC, nn.Module):
:return: action in the same form of input "act" but remap to the target action
space.
"""
if isinstance(self.action_space, gym.spaces.Box) and isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
if self.action_bound_method == "clip":
act = np.clip(act, -1.0, 1.0)
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:
assert (
np.min(act) >= -1.0 and np.max(act) <= 1.0 # type: ignore
np.min(act) >= -1.0 and np.max(act) <= 1.0
), f"action scaling only accepts raw action range = [-1, 1], but got: {act}"
low, high = self.action_space.low, self.action_space.high
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
act = low + (high - low) * (act + 1.0) / 2.0
return act

def map_action_inverse(
self,
act: BatchProtocol | list | np.ndarray,
) -> BatchProtocol | list | np.ndarray:
act: arr_type,
) -> np.ndarray:
"""Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.

This function is called in :meth:`~tianshou.data.Collector.collect` for

@@ -252,17 +279,17 @@ class BasePolicy(ABC, nn.Module):
:return: action remapped.
"""
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
act = to_numpy(act)
if isinstance(act, np.ndarray):
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0

return act

def process_buffer(self, buffer: TBuffer) -> TBuffer:
@@ -7,7 +7,11 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ModelOutputBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    ModelOutputBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -55,7 +59,7 @@ class ImitationPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:
@@ -1,5 +1,5 @@
import copy
from typing import Any, Literal, Self
from typing import Any, Literal, Self, cast

import gymnasium as gym
import numpy as np

@@ -8,7 +8,7 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.continuous import VAE

@@ -112,10 +112,10 @@ class BCQPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
) -> ActBatchProtocol:
"""Compute action over the given batch data."""
# There is "obs" in the Batch
# obs_group: several groups. Each group has a state.

@@ -134,7 +134,7 @@ class BCQPolicy(BasePolicy):
max_indice = q1.argmax(0)
act_group.append(act[max_indice].cpu().data.numpy().flatten())
act_group = np.array(act_group)
return Batch(act=act_group)
return cast(ActBatchProtocol, Batch(act=act_group))

def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
@@ -156,7 +156,7 @@ class CQLPolicy(SACPolicy):
self.soft_update(self.critic2_old, self.critic2, self.tau)

def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch = Batch(obs=obs, info=None)
batch = Batch(obs=obs, info=[None] * len(obs))
obs_result = self(batch)
return obs_result.act, obs_result.log_prob
@@ -7,7 +7,11 @@ import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.types import ImitationBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    ImitationBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -101,26 +105,25 @@ class DiscreteBCQPolicy(DQNPolicy):
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch))
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
act = self(batch, input="obs_next").act
act = self(next_obs_batch).act
target_q, _ = self.model_old(batch.obs_next)
return target_q[np.arange(len(act)), act]

def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> ImitationBatchProtocol:
# TODO: Liskov substitution principle is violated here, the superclass
# produces a batch with the field logits, but this one doesn't.
# Should be fixed in the future!
obs = batch[input]
q_value, state = self.model(obs, state=state, info=batch.info)
if not hasattr(self, "max_action_num"):
q_value, state = self.model(batch.obs, state=state, info=batch.info)
if self.max_action_num is None:
self.max_action_num = q_value.shape[1]
imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)
imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info)

# mask actions for argmax
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

@@ -72,10 +72,10 @@ class ICMPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
) -> ActBatchProtocol:
"""Compute action over the given batch data by inner policy.

.. seealso::
@@ -6,7 +6,7 @@ import torch
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -198,7 +198,7 @@ class PSRLPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:

@@ -213,9 +213,9 @@ class PSRLPolicy(BasePolicy):
more detailed explanation.
"""
assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation"
# TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)?
act = self.model(batch.obs, state=state, info=batch.info)
result = Batch(act=act)
return cast(ActBatchProtocol, result)
return cast(ActBatchProtocol, Batch(act=act))

def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
n_s, n_a = self.model.n_state, self.model.n_action
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast

import gymnasium as gym
import numpy as np

@@ -9,6 +9,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    ModelOutputBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import DQNPolicy

@@ -73,9 +74,10 @@ class BranchingDQNPolicy(DQNPolicy):
)
self.model = cast(BranchingNet, self.model)

# TODO: mypy complains b/c max_action_num is declared in base class, see todo there
# TODO: this used to be a public property called max_action_num,
# but it collides with an attr of the same name in base class
@property
def max_action_num(self) -> int: # type: ignore
def _action_per_branch(self) -> int:
return self.model.action_per_branch # type: ignore

@property

@@ -83,15 +85,18 @@ class BranchingDQNPolicy(DQNPolicy):
return self.model.num_branches # type: ignore

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
result = self(batch, input="obs_next")
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
result = self(obs_next_batch)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
target_q = self(obs_next_batch, model="model_old").logits
else:
target_q = result.logits
if self.is_double:
act = np.expand_dims(self(batch, input="obs_next").act, -1)
act = np.expand_dims(self(obs_next_batch).act, -1)
act = to_torch(act, dtype=torch.long, device=target_q.device)
else:
act = target_q.max(-1).indices.unsqueeze(-1)

@@ -114,7 +119,7 @@ class BranchingDQNPolicy(DQNPolicy):
mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
_target_q = rew + gamma * mean_target_q * (1 - end_flag)
target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1)

batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update

@@ -132,14 +137,14 @@ class BranchingDQNPolicy(DQNPolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> ModelOutputBatchProtocol:
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1])

@@ -174,7 +179,11 @@ class BranchingDQNPolicy(DQNPolicy):
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
rand_act = np.random.randint(low=0, high=self.max_action_num, size=(bsz, act.shape[-1]))
rand_act = np.random.randint(
    low=0,
    high=self._action_per_branch,
    size=(bsz, act.shape[-1]),
)
if hasattr(batch.obs, "mask"):
rand_act += batch.obs.mask
act[rand_mask] = rand_act[rand_mask]
@@ -4,7 +4,7 @@ import gymnasium as gym
import numpy as np
import torch

from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -90,11 +90,12 @@ class C51Policy(DQNPolicy):
return super().compute_q_value((logits * self.support).sum(2), mask)

def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor:
obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch))
if self._target:
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
act = self(obs_next_batch).act
next_dist = self(obs_next_batch, model="model_old").logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
next_dist = next_dist[np.arange(len(act)), act, :]
@@ -1,6 +1,6 @@
import warnings
from copy import deepcopy
from typing import Any, Literal, Self
from typing import Any, Literal, Self, cast

import gymnasium as gym
import numpy as np

@@ -8,7 +8,12 @@ import torch
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    ActStateBatchProtocol,
    BatchWithReturnsProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -118,8 +123,11 @@ class DDPGPolicy(BasePolicy):
self.soft_update(self.critic_old, self.critic, self.tau)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
return self.critic_old(batch.obs_next, self(batch, model="actor_old", input="obs_next").act)
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act)

def process_fn(
self,

@@ -138,12 +146,11 @@ class DDPGPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: Literal["actor", "actor_old"] = "actor",
input: str = "obs",
**kwargs: Any,
) -> BatchProtocol:
) -> ActStateBatchProtocol:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 2 keys:

@@ -157,9 +164,8 @@ class DDPGPolicy(BasePolicy):
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
actions, hidden = model(obs, state=state, info=batch.info)
return Batch(act=actions, state=hidden)
actions, hidden = model(batch.obs, state=state, info=batch.info)
return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden))

@staticmethod
def _mse_optimizer(
@@ -8,7 +8,7 @@ from torch.distributions import Categorical
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -92,13 +92,11 @@ class DiscreteSACPolicy(SACPolicy):
def forward( # type: ignore
self,
batch: Batch,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> Batch:
obs = batch[input]
logits, hidden = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
if self.deterministic_eval and not self.training:
act = logits.argmax(axis=-1)

@@ -107,12 +105,15 @@ class DiscreteSACPolicy(SACPolicy):
return Batch(logits=logits, act=act, state=hidden, dist=dist)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
dist = obs_next_result.dist
target_q = dist.probs * torch.min(
    self.critic_old(batch.obs_next),
    self.critic2_old(batch.obs_next),
    self.critic_old(obs_next_batch.obs),
    self.critic2_old(obs_next_batch.obs),
)
return target_q.sum(dim=-1) + self.alpha * dist.entropy()
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, Self, cast
from typing import Any, Literal, Self, cast

import gymnasium as gym
import numpy as np

@@ -10,6 +10,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    ModelOutputBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy

@@ -91,7 +92,7 @@ class DQNPolicy(BasePolicy):
self.clip_loss_grad = clip_loss_grad

# TODO: set in forward, fix this!
self.max_action_num: int
self.max_action_num: int | None = None

def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""

@@ -108,11 +109,14 @@ class DQNPolicy(BasePolicy):
self.model_old.load_state_dict(self.model.state_dict())

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
result = self(batch, input="obs_next")
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
result = self(obs_next_batch)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
target_q = self(obs_next_batch, model="model_old").logits
else:
target_q = result.logits
if self.is_double:

@@ -151,10 +155,9 @@ class DQNPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> ModelOutputBatchProtocol:
"""Compute action over the given batch data.

@@ -185,11 +188,12 @@ class DQNPolicy(BasePolicy):
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
if self.max_action_num is None:
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden)

@@ -226,6 +230,9 @@ class DQNPolicy(BasePolicy):
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
assert (
    self.max_action_num is not None
), "Can't call this method before max_action_num was set in first forward"
q = np.random.rand(bsz, self.max_action_num) # [0, 1]
if hasattr(batch.obs, "mask"):
q += batch.obs.mask
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast

import gymnasium as gym
import numpy as np

@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.data.types import FQFBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import DQNPolicy, QRDQNPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction

@@ -84,13 +84,16 @@ class FQFPolicy(QRDQNPolicy):
self.fraction_optim = fraction_optim

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
if self._target:
result = self(batch, input="obs_next")
result = self(obs_next_batch)
act, fractions = result.act, result.fractions
next_dist = self(batch, model="model_old", input="obs_next", fractions=fractions).logits
next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
return next_dist[np.arange(len(act)), act, :]

@@ -98,15 +101,15 @@ class FQFPolicy(QRDQNPolicy):
# TODO: fix Liskov substitution principle violation
def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
fractions: Batch | None = None,
**kwargs: Any,
) -> FQFBatchProtocol:
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done
obs_next = obs.obs if hasattr(obs, "obs") else obs
if fractions is None:
(logits, fractions, quantiles_tau), hidden = model(

@@ -125,7 +128,7 @@ class FQFPolicy(QRDQNPolicy):
)
weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits
q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
if self.max_action_num is None: # type: ignore
# TODO: see same thing in DQNPolicy! Also reduce code duplication.
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, Literal, cast

import gymnasium as gym
import numpy as np

@@ -7,7 +7,11 @@ import torch.nn.functional as F
from tianshou.data import Batch, to_numpy
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import QuantileRegressionBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    ObsBatchProtocol,
    QuantileRegressionBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import QRDQNPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -88,10 +92,9 @@ class IQNPolicy(QRDQNPolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
model: str = "model",
input: str = "obs",
model: Literal["model", "model_old"] = "model",
**kwargs: Any,
) -> QuantileRegressionBatchProtocol:
if model == "model_old":

@@ -101,7 +104,8 @@ class IQNPolicy(QRDQNPolicy):
else:
sample_size = self.sample_size
model = getattr(self, model)
obs = batch[input]
obs = batch.obs
# TODO: this seems very contrived!
obs_next = obs.obs if hasattr(obs, "obs") else obs
(logits, taus), hidden = model(
obs_next,

@@ -110,8 +114,8 @@ class IQNPolicy(QRDQNPolicy):
info=batch.info,
)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
# TODO: see same thing in DQNPolicy! Also reduce code duplication.
if self.max_action_num is None: # type: ignore
# TODO: see same thing in DQNPolicy!
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden, taus=taus)
@@ -11,6 +11,7 @@ from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    DistBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy

@@ -157,7 +158,7 @@ class PGPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> DistBatchProtocol:
@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy import DQNPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -79,12 +79,15 @@ class QRDQNPolicy(DQNPolicy):
warnings.filterwarnings("ignore", message="Using a target size")

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
if self._target:
act = self(batch, input="obs_next").act
next_dist = self(batch, model="model_old", input="obs_next").logits
act = self(obs_next_batch).act
next_dist = self(obs_next_batch, model="model_old").logits
else:
next_batch = self(batch, input="obs_next")
next_batch = self(obs_next_batch)
act = next_batch.act
next_dist = next_batch.logits
return next_dist[np.arange(len(act)), act, :]
@@ -6,7 +6,7 @@ import torch
from torch.distributions import Independent, Normal

from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -131,13 +131,11 @@ class REDQPolicy(DDPGPolicy):
def forward( # type: ignore
self,
batch: Batch,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> Batch:
obs = batch[input]
loc_scale, h = self.actor(obs, state=state, info=batch.info)
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
act = loc if self.deterministic_eval and not self.training else dist.rsample()

@@ -153,11 +151,14 @@ class REDQPolicy(DDPGPolicy):
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
a_ = obs_next_result.act
sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False)
qs = self.critic_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...]
if self.target_mode == "min":
target_q, _ = torch.min(qs, dim=0)
elif self.target_mode == "mean":
@@ -7,7 +7,11 @@ import torch
from torch.distributions import Independent, Normal

from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import DistLogProbBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
    DistLogProbBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -147,13 +151,11 @@ class SACPolicy(DDPGPolicy):
# TODO: violates Liskov substitution principle
def forward( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | Batch | np.ndarray | None = None,
input: str = "obs",
**kwargs: Any,
) -> DistLogProbBatchProtocol:
obs = batch[input]
logits, hidden = self.actor(obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training:

@@ -179,13 +181,16 @@ class SACPolicy(DDPGPolicy):
return cast(DistLogProbBatchProtocol, result)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
obs_next_result = self(obs_next_batch)
act_ = obs_next_result.act
return (
torch.min(
    self.critic_old(batch.obs_next, act_),
    self.critic2_old(batch.obs_next, act_),
    self.critic_old(obs_next_batch.obs, act_),
    self.critic2_old(obs_next_batch.obs, act_),
)
- self.alpha * obs_next_result.log_prob
)
@@ -5,7 +5,7 @@ import gymnasium as gym
import numpy as np
import torch

from tianshou.data import ReplayBuffer
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy

@@ -114,15 +114,18 @@ class TD3Policy(DDPGPolicy):
self.soft_update(self.actor_old, self.actor, self.tau)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
act_ = self(batch, model="actor_old", input="obs_next").act
obs_next_batch = Batch(
    obs=buffer[indices].obs_next,
    info=[None] * len(indices),
) # obs_next: s_{t+n}
act_ = self(obs_next_batch, model="actor_old").act
noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise
if self.noise_clip > 0.0:
noise = noise.clamp(-self.noise_clip, self.noise_clip)
act_ += noise
return torch.min(
    self.critic_old(batch.obs_next, act_),
    self.critic2_old(batch.obs_next, act_),
    self.critic_old(obs_next_batch.obs, act_),
    self.critic2_old(obs_next_batch.obs, act_),
)

def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
@@ -4,7 +4,7 @@ import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy

@@ -16,7 +16,7 @@ class RandomPolicy(BasePolicy):
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:
@@ -180,6 +180,7 @@ class ActorProb(BaseActor):
of how preprocess_net is suggested to be defined.
"""

# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module,