1. add policy.eval() in all test scripts' "watch performance" 2. remove dict return support for collector preprocess_fn 3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)` 4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184) 5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard 6. add test_returns (both GAE and nstep) 7. change the type-checking order in batch.py and converter.py in order to meet the most often case first 8. fix shape inconsistency for torch.Tensor in replay buffer 9. remove `**kwargs` in ReplayBuffer 10. remove default value in batch.split() and add merge_last argument (#185) 11. improve nstep efficiency 12. add max_batchsize in onpolicy algorithms 13. potential bugfix for subproc.wait 14. fix RecurrentActorProb 15. improve the code-coverage (from 90% to 95%) and remove the dead code 16. fix some incorrect type annotation The above improvement also increases the training FPS: on my computer, the previous version is only ~1800 FPS and after that, it can reach ~2050 (faster than v0.2.4.post1).
163 lines
6.5 KiB
Python
163 lines
6.5 KiB
Python
import torch
|
|
import numpy as np
|
|
from copy import deepcopy
|
|
from typing import Dict, Tuple, Union, Optional
|
|
|
|
from tianshou.policy import BasePolicy
|
|
from tianshou.exploration import BaseNoise, GaussianNoise
|
|
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
|
|
|
|
|
class DDPGPolicy(BasePolicy):
    """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic_optim: the optimizer for critic
        network.
    :param float tau: param for soft update of the target network, in [0, 1],
        defaults to 0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param BaseNoise exploration_noise: the exploration noise,
        add to the action, defaults to ``GaussianNoise(sigma=0.1)``. Pass
        ``None`` to disable exploration noise.
    :param action_range: the action range (minimum, maximum). Although the
        signature default is ``None``, a value is required.
    :type action_range: (float, float)
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to ``False``.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to ``False``.
    :param int estimation_step: greater than 0, the number of steps to look
        ahead, defaults to 1.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(self,
                 actor: torch.nn.Module,
                 actor_optim: torch.optim.Optimizer,
                 critic: torch.nn.Module,
                 critic_optim: torch.optim.Optimizer,
                 tau: float = 0.005,
                 gamma: float = 0.99,
                 exploration_noise: Optional[BaseNoise]
                 = GaussianNoise(sigma=0.1),
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
                 estimation_step: int = 1,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        # The networks are only registered when provided; the target copies
        # are kept in eval mode and are updated solely through sync_weight().
        if actor is not None:
            self.actor, self.actor_old = actor, deepcopy(actor)
            self.actor_old.eval()
            self.actor_optim = actor_optim
        if critic is not None:
            self.critic, self.critic_old = critic, deepcopy(critic)
            self.critic_old.eval()
            self.critic_optim = critic_optim
        assert 0 <= tau <= 1, 'tau should in [0, 1]'
        self._tau = tau
        assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
        self._gamma = gamma
        self._noise = exploration_noise
        assert action_range is not None, 'action_range is required'
        self._range = action_range
        # Midpoint and half-width of the action interval.
        self._action_bias = (action_range[0] + action_range[1]) / 2
        self._action_scale = (action_range[1] - action_range[0]) / 2
        # Using OU noise instead of Gaussian noise makes only a little
        # difference, so Gaussian noise is the default.
        self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise (``None`` disables it)."""
        self._noise = noise

    def train(self, mode: bool = True) -> torch.nn.Module:
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        network_pairs = (
            (self.actor_old, self.actor),
            (self.critic_old, self.critic),
        )
        for target_net, online_net in network_pairs:
            for o, n in zip(target_net.parameters(),
                            online_net.parameters()):
                # Polyak averaging: old <- (1 - tau) * old + tau * new.
                o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        """Evaluate the target critic on the target actor's action.

        :param buffer: the replay buffer to read the transitions from.
        :param indice: indices of the transitions inside ``buffer``.
        :return: the target critic's output, computed without gradient.
        """
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            target_q = self.critic_old(batch.obs_next, self(
                batch, model='actor_old', input='obs_next',
                explorating=False).act)
        return target_q

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Pre-process the sampled batch before :meth:`learn`.

        Zeroes the done flag when ``ignore_done`` was set, then delegates to
        ``compute_nstep_return`` with :meth:`_target_q` as the target
        function, the discount factor, the n-step length and the reward
        normalization flag.
        """
        if self._rm_done:
            batch.done = batch.done * 0.
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        return batch

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                model: str = 'actor',
                input: str = 'obs',
                explorating: bool = True,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :param batch: the input data; ``batch.info`` and the attribute named
            by ``input`` are read.
        :param state: the hidden state forwarded to the model.
        :param str model: name of the attribute holding the network to use,
            e.g. ``'actor'`` or ``'actor_old'``.
        :param str input: name of the batch attribute used as observation,
            e.g. ``'obs'`` or ``'obs_next'``.
        :param bool explorating: whether to add the exploration noise; noise
            is only added when the policy is also in training mode and a
            noise object is set.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        actions, h = model(obs, state=state, info=batch.info)
        # Out-of-place arithmetic: mutating the network output in place can
        # invalidate tensors that autograd saved for the backward pass.
        actions = actions + self._action_bias
        # The noise is optional (see set_exp_noise), so guard against None.
        if self._noise is not None and self.training and explorating:
            actions = actions + to_torch_as(
                self._noise(actions.shape), actions)
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)

    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        """Run one critic update, one actor update and a soft target sync.

        :param batch: the processed batch; ``obs``, ``act`` and ``returns``
            are read, an optional ``weight`` entry is consumed, and the TD
            error is written back to ``batch.weight``.

        :return: a dict with the scalar losses under ``'loss/actor'`` and
            ``'loss/critic'``.
        """
        # Per-sample weight; falls back to a uniform weight of 1.
        weight = batch.pop('weight', 1.)
        # Critic update: weighted mean squared TD error.
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        critic_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        # Actor update: maximize the critic's value of the noiseless action.
        action = self(batch, explorating=False).act
        actor_loss = -self.critic(batch.obs, action).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {
            'loss/actor': actor_loss.item(),
            'loss/critic': critic_loss.item(),
        }
|