Closes: #1058 ### API Extensions - Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 ### Internal Improvements - `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 ### Breaking Changes - Removed `.data` attribute from `Collector` and its child classes. #1063 - Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` explicitly or pass `reset_before_collect=True`. #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 --------- Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
177 lines
5.9 KiB
Python
177 lines
5.9 KiB
Python
from typing import Any, Literal, Self, TypeVar
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
|
|
from tianshou.data.batch import BatchProtocol
|
|
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
|
|
from tianshou.policy import BasePolicy
|
|
from tianshou.policy.base import (
|
|
TLearningRateScheduler,
|
|
TrainingStats,
|
|
TrainingStatsWrapper,
|
|
TTrainingStats,
|
|
)
|
|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
|
|
|
|
|
class ICMTrainingStats(TrainingStatsWrapper):
    """Training statistics of the ICM wrapper.

    Wraps the statistics returned by the inner policy's ``learn`` and adds the
    three ICM loss values as plain float attributes.
    """

    def __init__(
        self,
        wrapped_stats: TrainingStats,
        *,
        icm_loss: float,
        icm_forward_loss: float,
        icm_inverse_loss: float,
    ) -> None:
        """
        :param wrapped_stats: the training stats produced by the wrapped policy.
        :param icm_loss: the combined (weighted and scaled) ICM loss.
        :param icm_forward_loss: the forward-model loss of the ICM.
        :param icm_inverse_loss: the inverse-model loss of the ICM.
        """
        # NOTE(review): the ICM attributes are set BEFORE super().__init__;
        # the base wrapper (not shown here) may depend on this order, so keep
        # the super() call last unless the base class is checked.
        self.icm_loss = icm_loss
        self.icm_forward_loss = icm_forward_loss
        self.icm_inverse_loss = icm_inverse_loss
        super().__init__(wrapped_stats)
|
|
|
|
|
|
class ICMPolicy(BasePolicy[ICMTrainingStats]):
    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.

    This policy wraps an inner policy. Action selection, exploration noise and
    the inner policy's own learning are delegated to ``policy``. On the same
    batches, ``model`` (the ICM) is trained, and its forward-model loss is
    added to the batch reward as an intrinsic bonus.

    :param policy: a base policy to add ICM to.
    :param model: the ICM model.
    :param optim: a torch.optim for optimizing the model.
    :param lr_scale: the scaling factor for ICM learning. The combined ICM loss
        is multiplied by this value before backpropagation.
    :param reward_scale: the scaling factor for the intrinsic reward. The
        forward-model loss returned by ``model`` is multiplied by this value
        and added to the batch reward in :meth:`process_fn`.
    :param forward_loss_weight: the weight for forward model loss. The ICM loss
        is ``(1 - forward_loss_weight) * inverse_loss
        + forward_loss_weight * forward_loss``.
    :param action_space: Env's action space.
    :param observation_space: Env's observation space.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
        Only used if the action_space is continuous.
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        policy: BasePolicy[TTrainingStats],
        model: IntrinsicCuriosityModule,
        optim: torch.optim.Optimizer,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
        action_space: gym.Space,
        observation_space: gym.Space | None = None,
        action_scaling: bool = False,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )
        self.policy = policy
        self.model = model
        self.optim = optim
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def train(self, mode: bool = True) -> Self:
        """Set this module, the wrapped policy and the ICM model to training
        (``mode=True``) or evaluation (``mode=False``) mode.
        """
        self.policy.train(mode)
        self.training = mode
        self.model.train(mode)
        return self

    def forward(
        self,
        batch: ObsBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        """Compute action over the given batch data by inner policy.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        return self.policy.forward(batch, state, **kwargs)

    _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")

    def exploration_noise(
        self,
        act: _TArrOrActBatch,
        batch: ObsBatchProtocol,
    ) -> _TArrOrActBatch:
        """Add exploration noise by delegating to the wrapped policy."""
        return self.policy.exploration_noise(act, batch)

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration.

        :raises NotImplementedError: if the wrapped policy has no ``set_eps``.
        """
        if hasattr(self.policy, "set_eps"):
            self.policy.set_eps(eps)
        else:
            raise NotImplementedError

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> RolloutBatchProtocol:
        """Pre-process the data from the provided replay buffer.

        Runs the ICM on ``(obs, act, obs_next)``. The scaled forward-model loss
        is added to ``batch.rew`` as an intrinsic bonus. The original reward,
        the predicted actions and the loss are stashed in ``batch.policy`` for
        :meth:`learn` and :meth:`post_process_fn`. The batch is then passed on
        to the wrapped policy's ``process_fn``.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
        # Store a *copy* of the reward. The in-place ``+=`` below mutates the
        # ``batch.rew`` array object. If ``orig_rew`` referenced that same
        # object (assuming Batch does not copy it), the "original" reward
        # would silently include the bonus, and post_process_fn would restore
        # the wrong value.
        batch.policy = Batch(orig_rew=batch.rew.copy(), act_hat=act_hat, mse_loss=mse_loss)
        batch.rew += to_numpy(mse_loss * self.reward_scale)
        return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(
        self,
        batch: BatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Delegates to the wrapped policy, then restores the reward saved by
        :meth:`process_fn` (i.e. removes the intrinsic bonus).

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.
        """
        self.policy.post_process_fn(batch, buffer, indices)
        batch.rew = batch.policy.orig_rew  # restore original reward

    def learn(
        self,
        batch: RolloutBatchProtocol,
        *args: Any,
        **kwargs: Any,
    ) -> ICMTrainingStats:
        """Update the wrapped policy, then take one optimizer step on the ICM.

        Expects ``batch.policy.act_hat`` and ``batch.policy.mse_loss`` to have
        been set by :meth:`process_fn`.

        :param batch: the (pre-processed) training batch.
        :param args: extra positional arguments, forwarded to the wrapped
            policy's ``learn``.
        :param kwargs: extra keyword arguments, forwarded to the wrapped
            policy's ``learn``.
        :return: the wrapped policy's stats together with the ICM losses.
        """
        # Forward *args as well as **kwargs; previously positional arguments
        # were accepted here but silently discarded.
        training_stat = self.policy.learn(batch, *args, **kwargs)
        self.optim.zero_grad()
        act_hat = batch.policy.act_hat
        # Actions are cast to long class indices for cross-entropy, so the
        # inverse model is treated as a classifier over discrete actions.
        act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
        inverse_loss = F.cross_entropy(act_hat, act).mean()
        forward_loss = batch.policy.mse_loss.mean()
        loss = (
            (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss
        ) * self.lr_scale
        loss.backward()
        self.optim.step()

        return ICMTrainingStats(
            training_stat,
            icm_loss=loss.item(),
            icm_forward_loss=forward_loss.item(),
            icm_inverse_loss=inverse_loss.item(),
        )
|