* Enable to convert Batch data back to torch. * Add torch converter to collector. * Fix * Move to_numpy/to_torch convert in dedicated utils.py. * Use to_numpy/to_torch to convert arrays. * fix lint * fix * Add unit test to check Batch from/to numpy. * Fix Batch over Batch. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
122 lines
4.4 KiB
Python
122 lines
4.4 KiB
Python
import torch
|
|
import numpy as np
|
|
from typing import Dict, List, Union, Optional
|
|
|
|
from tianshou.policy import BasePolicy
|
|
from tianshou.data import Batch, ReplayBuffer, to_torch
|
|
|
|
|
|
class PGPolicy(BasePolicy):
    """Implementation of Vanilla Policy Gradient.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the
        model.
    :param torch.distributions.Distribution dist_fn: for computing the
        action. It is called with the model output (unpacked when that
        output is a tuple) and must return a distribution object.
    :param float discount_factor: in [0, 1].
    :param bool reward_normalization: if ``True``, the returns are
        standardized (zero mean, unit standard deviation) in
        :meth:`learn` unless their standard deviation is close to zero,
        defaults to ``False``.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 dist_fn: torch.distributions.Distribution
                 = torch.distributions.Categorical,
                 discount_factor: float = 0.99,
                 reward_normalization: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.dist_fn = dist_fn
        assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
        self._gamma = discount_factor
        self._rew_norm = reward_normalization

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        r"""Compute the discounted returns for each frame:

        .. math::
            G_t = \sum_{i=t}^T \gamma^{i-t}r_i

        , where :math:`T` is the terminal time step, :math:`\gamma` is the
        discount factor, :math:`\gamma \in [0, 1]`.

        The computation is delegated to ``compute_episodic_return`` with
        ``gae_lambda=1.``; ``buffer`` and ``indice`` are not used here.
        """
        # batch.returns = self._vanilla_returns(batch)
        # batch.returns = self._vectorized_returns(batch)
        # return batch
        return self.compute_episodic_return(
            batch, gamma=self._gamma, gae_lambda=1.)

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 4 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``dist`` the action distribution.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        logits, h = self.model(batch.obs, state=state, info=batch.info)
        # A tuple output (e.g. several distribution parameters) is unpacked
        # into positional arguments of the distribution constructor.
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch: Batch, batch_size: int, repeat: int,
              **kwargs) -> Dict[str, List[float]]:
        """Update the model with the policy-gradient loss.

        The loss of each mini-batch is ``-(log_prob(act) * returns).sum()``.

        :param Batch batch: the data to learn from; ``act`` and ``returns``
            are read here, ``obs`` and ``info`` are read by :meth:`forward`.
        :param int batch_size: forwarded to ``batch.split`` to build the
            mini-batches.
        :param int repeat: number of passes over the whole batch.

        :return: A dict whose ``loss`` entry lists the loss value of every
            optimization step, in order.
        """
        losses = []
        returns = batch.returns
        if self._rew_norm:
            # Compute the standard deviation once; skip the normalization
            # when it is (close to) zero to avoid dividing by zero.
            std = returns.std()
            if not np.isclose(std, 0):
                batch.returns = (returns - returns.mean()) / std
        for _ in range(repeat):
            for b in batch.split(batch_size):
                self.optim.zero_grad()
                result = self(b)
                dist = result.dist
                # Take the device from the sampled action rather than from
                # ``dist.logits``: distributions built from a tuple output
                # (e.g. Normal) do not expose a ``logits`` attribute.
                device = result.act.device
                a = to_torch(b.act, device=device)
                r = to_torch(b.returns, device=device)
                loss = -(dist.log_prob(a) * r).sum()
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        return {'loss': losses}
|
|
|
|
# def _vanilla_returns(self, batch):
|
|
# returns = batch.rew[:]
|
|
# last = 0
|
|
# for i in range(len(returns) - 1, -1, -1):
|
|
# if not batch.done[i]:
|
|
# returns[i] += self._gamma * last
|
|
# last = returns[i]
|
|
# return returns
|
|
|
|
# def _vectorized_returns(self, batch):
|
|
# # according to my tests, it is slower than _vanilla_returns
|
|
# # import scipy.signal
|
|
# convolve = np.convolve
|
|
# # convolve = scipy.signal.convolve
|
|
# rew = batch.rew[::-1]
|
|
# batch_size = len(rew)
|
|
# gammas = self._gamma ** np.arange(batch_size)
|
|
# c = convolve(rew, gammas)[:batch_size]
|
|
# T = np.where(batch.done[::-1])[0]
|
|
# d = np.zeros_like(rew)
|
|
# d[T] += c[T] - rew[T]
|
|
# d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
|
|
# return (c - convolve(d, gammas)[:batch_size])[::-1]
|