from typing import Any from torch import nn from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import C51Policy from tianshou.utils.net.discrete import NoisyLinear # TODO: this is a hacky thing interviewing side-effects and a return. Should improve. def _sample_noise(model: nn.Module) -> bool: """Sample the random noises of NoisyLinear modules in the model. Returns True if at least one NoisyLinear submodule was found. :param model: a PyTorch module which may have NoisyLinear submodules. :returns: True if model has at least one NoisyLinear submodule; otherwise, False. """ sampled_any_noise = False for m in model.modules(): if isinstance(m, NoisyLinear): m.sample() sampled_any_noise = True return sampled_any_noise # TODO: is this class worth keeping? It barely does anything class RainbowPolicy(C51Policy): """Implementation of Rainbow DQN. arXiv:1710.02298. Same parameters as :class:`~tianshou.policy.C51Policy`. .. seealso:: Please refer to :class:`~tianshou.policy.C51Policy` for more detailed explanation. """ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: _sample_noise(self.model) if self._target and _sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect return super().learn(batch, **kwargs)