Refactor PG algorithm and change behavior of compute_episodic_return (#319)

- simplify code
- apply value normalization (global) and advantage normalization (per-batch) in on-policy algorithms (sketched below)
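In short: value targets (returns) are normalized with running statistics that persist across updates, while advantages are standardized within each collected batch. A rough sketch of the scheme, assuming tianshou's RunningMeanStd helper; the function and variable names below are illustrative, not code from this commit:

import numpy as np
from tianshou.utils import RunningMeanStd

ret_rms = RunningMeanStd()  # global running statistics of unnormalized returns
eps = 1e-8  # numerical floor added under the variance

def normalize_targets(unnormalized_returns: np.ndarray, advantages: np.ndarray):
    # value targets: scaled by statistics shared across the whole training run
    returns = (unnormalized_returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)
    ret_rms.update(unnormalized_returns)
    # advantages: standardized using the current batch only
    advantages = (advantages - advantages.mean()) / advantages.std()
    return returns, advantages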
ChenDRAG 2021-03-23 22:05:48 +08:00 committed by GitHub
parent 2c11b6e43b
commit e27b5a26f3
9 changed files with 109 additions and 192 deletions

View File

@@ -30,9 +30,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -41,9 +41,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,9 +52,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -64,12 +64,12 @@ def test_episodic_returns(size=2560):
         b.obs = b.act = 1
         buf.add(b)
     v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
-    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
-    returns = np.array([
+    returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    ground_truth = np.array([
         454.8344, 376.1143, 291.298, 200.,
         464.5610, 383.1085, 295.387, 201.,
         474.2876, 390.1027, 299.476, 202.])
-    assert np.allclose(ret.returns, returns)
+    assert np.allclose(returns, ground_truth)
     buf.reset()
     batch = Batch(
         done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -82,12 +82,12 @@ def test_episodic_returns(size=2560):
         b.obs = b.act = 1
         buf.add(b)
     v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
-    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
-    returns = np.array([
+    returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    ground_truth = np.array([
         454.0109, 375.2386, 290.3669, 199.01,
         462.9138, 381.3571, 293.5248, 199.02,
         474.2876, 390.1027, 299.476, 202.])
-    assert np.allclose(ret.returns, returns)
+    assert np.allclose(returns, ground_truth)
     if __name__ == '__main__':
         buf = ReplayBuffer(size)

View File

@@ -91,7 +91,8 @@ def test_ppo(args=get_args()):
     def dist(*logits):
         return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
-        actor, critic, optim, dist, args.gamma,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma,
         max_grad_norm=args.max_grad_norm,
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,

View File

@@ -78,7 +78,8 @@ def test_a2c_with_il(args=get_args()):
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
         max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
         action_space=env.action_space)

View File

@@ -17,7 +17,7 @@ from tianshou.data import Collector, VectorReplayBuffer
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.95)
@@ -27,7 +27,7 @@ def get_args():
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128, 128, 128])
+                        nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -65,6 +65,11 @@ def test_pg(args=get_args()):
     policy = PGPolicy(net, optim, dist, args.gamma,
                       reward_normalization=args.rew_norm,
                       action_space=env.action_space)
+    for m in net.modules():
+        if isinstance(m, torch.nn.Linear):
+            # orthogonal initialization
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
     # collector
     train_collector = Collector(
         policy, train_envs,

View File

@@ -80,7 +80,8 @@ def test_ppo(args=get_args()):
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
-        actor, critic, optim, dist, args.gamma,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma,
         max_grad_norm=args.max_grad_norm,
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,

View File

@@ -4,7 +4,7 @@ import numpy as np
 from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Union, Optional, Callable
+from typing import Any, Dict, Tuple, Union, Optional, Callable

 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -254,14 +254,14 @@ class BasePolicy(ABC, nn.Module):
         buffer: ReplayBuffer,
         indice: np.ndarray,
         v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
-        rew_norm: bool = False,
-    ) -> Batch:
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """Compute returns over given batch.

         Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
-        to calculate q function/reward to go of given batch.
+        to calculate q/advantage value of given batch.

         :param Batch batch: a data batch which contains several episodes of data in
             sequential order. Mind that the end of each finished episode of batch
@@ -273,10 +273,8 @@ class BasePolicy(ABC, nn.Module):
         :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage Estimation,
             should be in [0, 1]. Default to 0.95.
-        :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.

-        :return: a Batch. The result will be stored in batch.returns as a numpy
-            array with shape (bsz, ).
+        :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
         """
         rew = batch.rew
         if v_s_ is None:
@@ -284,14 +282,14 @@ class BasePolicy(ABC, nn.Module):
             v_s_ = np.zeros_like(rew)
         else:
             v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
+        v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
         end_flag = batch.done.copy()
         end_flag[np.isin(indice, buffer.unfinished_index())] = True
-        returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
-        if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
-            returns = (returns - returns.mean()) / returns.std()
-        batch.returns = returns
-        return batch
+        advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
+        returns = advantage + v_s
+        # normalization varies from each policy, so we don't do it here
+        return returns, advantage

     @staticmethod
     def compute_nstep_return(
@@ -355,8 +353,6 @@ class BasePolicy(ABC, nn.Module):
         i64 = np.array([[0, 1]], dtype=np.int64)
         _gae_return(f64, f64, f64, b, 0.1, 0.1)
         _gae_return(f32, f32, f64, b, 0.1, 0.1)
-        _episodic_return(f64, f64, b, 0.1, 0.1)
-        _episodic_return(f32, f64, b, 0.1, 0.1)
         _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)
@@ -379,19 +375,6 @@ def _gae_return(
     return returns


-@njit
-def _episodic_return(
-    v_s_: np.ndarray,
-    rew: np.ndarray,
-    end_flag: np.ndarray,
-    gamma: float,
-    gae_lambda: float,
-) -> np.ndarray:
-    """Numba speedup: 4.1s -> 0.057s."""
-    v_s = np.roll(v_s_, 1)
-    return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s
-
-
 @njit
 def _nstep_return(
     rew: np.ndarray,
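For reference, a minimal usage sketch of the updated helper, mirroring the test changes above; the reward values are reconstructed to reproduce the expected output shown in the test and are otherwise illustrative:

import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

buf = ReplayBuffer(size=8)
batch = Batch(
    done=np.array([0, 1, 0, 1, 0, 1, 0.]),
    rew=np.array([7., 6, 1, 2, 3, 4, 5]))
for b in batch:
    b.obs = b.act = 1
    buf.add(b)
# a (returns, advantage) pair is now returned instead of a Batch with
# batch.returns filled in, and the rew_norm flag is gone; with v_s_=None
# the result is the discounted reward-to-go
returns, advantage = BasePolicy.compute_episodic_return(
    batch, buf, buf.sample_index(0), v_s_=None, gamma=.1, gae_lambda=1.)
assert np.allclose(returns, [7.6, 6, 1.2, 2, 3.4, 4, 5])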

View File

@@ -2,7 +2,7 @@ import torch
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
-from typing import Any, Dict, List, Type, Union, Optional
+from typing import Any, Dict, List, Type, Optional

 from tianshou.policy import PGPolicy
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -53,17 +53,14 @@ class A2CPolicy(PGPolicy):
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        discount_factor: float = 0.99,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
         max_grad_norm: Optional[float] = None,
         gae_lambda: float = 0.95,
-        reward_normalization: bool = False,
         max_batchsize: int = 256,
         **kwargs: Any
     ) -> None:
-        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
-        self.actor = actor
+        super().__init__(actor, optim, dist_fn, **kwargs)
         self.critic = critic
         assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
         self._lambda = gae_lambda
@@ -71,51 +68,27 @@ class A2CPolicy(PGPolicy):
         self._weight_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = max_batchsize
-        self._rew_norm = reward_normalization

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._lambda in [0.0, 1.0]:
-            return self.compute_episodic_return(
-                batch, buffer, indice,
-                None, gamma=self._gamma, gae_lambda=self._lambda)
-        v_ = []
+        v_s_ = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_.append(to_numpy(self.critic(b.obs_next)))
-        v_ = np.concatenate(v_, axis=0)
-        return self.compute_episodic_return(
-            batch, buffer, indice, v_,
-            gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)
-
-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        **kwargs: Any
-    ) -> Batch:
-        """Compute action over the given batch data.
-
-        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
-
-            * ``act`` the action.
-            * ``logits`` the network's raw output.
-            * ``dist`` the action distribution.
-            * ``state`` the hidden state.
-
-        .. seealso::
-
-            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
-            more detailed explanation.
-        """
-        logits, h = self.actor(batch.obs, state=state, info=batch.info)
-        if isinstance(logits, tuple):
-            dist = self.dist_fn(*logits)
+                v_s_.append(to_numpy(self.critic(b.obs_next)))
+        v_s_ = np.concatenate(v_s_, axis=0)
+        if self._rew_norm:  # unnormalize v_s_
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+        unnormalized_returns, _ = self.compute_episodic_return(
+            batch, buffer, indice, v_s_=v_s_,
+            gamma=self._gamma, gae_lambda=self._lambda)
+        if self._rew_norm:
+            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+                np.sqrt(self.ret_rms.var + self._eps)
+            self.ret_rms.update(unnormalized_returns)
         else:
-            dist = self.dist_fn(logits)
-        act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+            batch.returns = unnormalized_returns
+        return batch

     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

View File

@@ -4,10 +4,11 @@ from typing import Any, Dict, List, Type, Union, Optional

 from tianshou.policy import BasePolicy
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.utils import RunningMeanStd


 class PGPolicy(BasePolicy):
-    """Implementation of Vanilla Policy Gradient.
+    """Implementation of REINFORCE algorithm.

     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -33,7 +34,7 @@ class PGPolicy(BasePolicy):
     def __init__(
         self,
-        model: Optional[torch.nn.Module],
+        model: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
@@ -45,14 +46,15 @@ class PGPolicy(BasePolicy):
     ) -> None:
         super().__init__(action_scaling=action_scaling,
                          action_bound_method=action_bound_method, **kwargs)
-        if model is not None:
-            self.model: torch.nn.Module = model
+        self.actor = model
         self.optim = optim
         self.lr_scheduler = lr_scheduler
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
         self._gamma = discount_factor
         self._rew_norm = reward_normalization
+        self.ret_rms = RunningMeanStd()
+        self._eps = 1e-8

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -65,11 +67,16 @@ class PGPolicy(BasePolicy):
         where :math:`T` is the terminal time step, :math:`\gamma` is the
         discount factor, :math:`\gamma \in [0, 1]`.
         """
-        # batch.returns = self._vanilla_returns(batch)
-        # batch.returns = self._vectorized_returns(batch)
-        return self.compute_episodic_return(
-            batch, buffer, indice, gamma=self._gamma,
-            gae_lambda=1.0, rew_norm=self._rew_norm)
+        v_s_ = np.full(indice.shape, self.ret_rms.mean)
+        unnormalized_returns, _ = self.compute_episodic_return(
+            batch, buffer, indice, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0)
+        if self._rew_norm:
+            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+                np.sqrt(self.ret_rms.var + self._eps)
+            self.ret_rms.update(unnormalized_returns)
+        else:
+            batch.returns = unnormalized_returns
+        return batch

     def forward(
         self,
@@ -91,7 +98,7 @@ class PGPolicy(BasePolicy):
             Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
-        logits, h = self.model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@@ -106,9 +113,10 @@ class PGPolicy(BasePolicy):
         for _ in range(repeat):
             for b in batch.split(batch_size, merge_last=True):
                 self.optim.zero_grad()
-                dist = self(b).dist
-                a = to_torch_as(b.act, dist.logits)
-                r = to_torch_as(b.returns, dist.logits)
+                result = self(b)
+                dist = result.dist
+                a = to_torch_as(b.act, result.act)
+                r = to_torch_as(b.returns, result.act)
                 log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                 loss = -(log_prob * r).mean()
                 loss.backward()
@@ -119,27 +127,3 @@ class PGPolicy(BasePolicy):
             self.lr_scheduler.step()

         return {"loss": losses}
-
-    # def _vanilla_returns(self, batch):
-    #     returns = batch.rew[:]
-    #     last = 0
-    #     for i in range(len(returns) - 1, -1, -1):
-    #         if not batch.done[i]:
-    #             returns[i] += self._gamma * last
-    #         last = returns[i]
-    #     return returns
-
-    # def _vectorized_returns(self, batch):
-    #     # according to my tests, it is slower than _vanilla_returns
-    #     # import scipy.signal
-    #     convolve = np.convolve
-    #     # convolve = scipy.signal.convolve
-    #     rew = batch.rew[::-1]
-    #     batch_size = len(rew)
-    #     gammas = self._gamma ** np.arange(batch_size)
-    #     c = convolve(rew, gammas)[:batch_size]
-    #     T = np.where(batch.done[::-1])[0]
-    #     d = np.zeros_like(rew)
-    #     d[T] += c[T] - rew[T]
-    #     d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
-    #     return (c - convolve(d, gammas)[:batch_size])[::-1]
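Worth noting: with gae_lambda=1.0 the GAE recursion used by compute_episodic_return collapses to the discounted return-to-go, bootstrapped by v_s_ wherever an episode was cut off (PGPolicy now fills v_s_ with ret_rms.mean for exactly that purpose). A hand-rolled sketch of this special case with made-up numbers, not code from this commit:

import numpy as np

gamma = 0.9
rew = np.array([1., 1., 1.])
end_flag = np.array([False, False, True])  # episode boundary markers
v_s_ = np.zeros(3)  # value of the next state; zero at true terminal states

ret = np.zeros(3)
running = 0.0
for i in reversed(range(3)):
    # bootstrap with v_s_ at a boundary, otherwise accumulate the tail return
    running = rew[i] + gamma * (v_s_[i] if end_flag[i] else running)
    ret[i] = running
assert np.allclose(ret, [2.71, 1.9, 1.0])  # 1 + 0.9 * (1 + 0.9 * 1), etc.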

View File

@@ -1,13 +1,13 @@
 import torch
 import numpy as np
 from torch import nn
-from typing import Any, Dict, List, Type, Union, Optional
+from typing import Any, Dict, List, Type, Optional

-from tianshou.policy import PGPolicy
+from tianshou.policy import A2CPolicy
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as


-class PPOPolicy(PGPolicy):
+class PPOPolicy(A2CPolicy):
     r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.

     :param torch.nn.Module actor: the actor network following the rules in
@@ -30,8 +30,8 @@ class PPOPolicy(PGPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
-    :param bool reward_normalization: normalize the returns to Normal(0, 1).
-        Default to True.
+    :param bool reward_normalization: normalize the returns and advantage to
+        Normal(0, 1). Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
@@ -58,7 +58,6 @@ class PPOPolicy(PGPolicy):
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        discount_factor: float = 0.99,
         max_grad_norm: Optional[float] = None,
         eps_clip: float = 0.2,
         vf_coef: float = 0.5,
@@ -66,81 +65,50 @@ class PPOPolicy(PGPolicy):
         gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
-        reward_normalization: bool = True,
         max_batchsize: int = 256,
         **kwargs: Any,
     ) -> None:
-        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
-        self._max_grad_norm = max_grad_norm
+        super().__init__(
+            actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
+            vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
+            max_batchsize=max_batchsize, **kwargs)
         self._eps_clip = eps_clip
-        self._weight_vf = vf_coef
-        self._weight_ent = ent_coef
-        self.actor = actor
-        self.critic = critic
-        self._batch = max_batchsize
-        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
-        self._lambda = gae_lambda
         assert dual_clip is None or dual_clip > 1.0, \
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        self._rew_norm = reward_normalization

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._rew_norm:
-            mean, std = batch.rew.mean(), batch.rew.std()
-            if not np.isclose(std, 0.0, 1e-2):
-                batch.rew = (batch.rew - mean) / std
-        v, v_, old_log_prob = [], [], []
+        v_s, v_s_, old_log_prob = [], [], []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_.append(self.critic(b.obs_next))
-                v.append(self.critic(b.obs))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v[0])))
-        v_ = to_numpy(torch.cat(v_, dim=0))
-        batch = self.compute_episodic_return(
-            batch, buffer, indice, v_, gamma=self._gamma,
-            gae_lambda=self._lambda, rew_norm=self._rew_norm)
-        batch.v = torch.cat(v, dim=0).flatten()  # old value
-        batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, v[0])
-        batch.adv = batch.returns - batch.v
-        if self._rew_norm:
-            mean, std = batch.adv.mean(), batch.adv.std()
-            if not np.isclose(std.item(), 0.0, 1e-2):
-                batch.adv = (batch.adv - mean) / std
-        return batch
-
-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        **kwargs: Any,
-    ) -> Batch:
-        """Compute action over the given batch data.
-
-        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
-
-            * ``act`` the action.
-            * ``logits`` the network's raw output.
-            * ``dist`` the action distribution.
-            * ``state`` the hidden state.
-
-        .. seealso::
-
-            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
-            more detailed explanation.
-        """
-        logits, h = self.actor(batch.obs, state=state, info=batch.info)
-        if isinstance(logits, tuple):
-            dist = self.dist_fn(*logits)
+                v_s.append(self.critic(b.obs))
+                v_s_.append(self.critic(b.obs_next))
+                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
+        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
+        v_s = to_numpy(batch.v_s)
+        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        if self._rew_norm:  # unnormalize v_s & v_s_
+            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+        unnormalized_returns, advantages = self.compute_episodic_return(
+            batch, buffer, indice, v_s_, v_s,
+            gamma=self._gamma, gae_lambda=self._lambda)
+        if self._rew_norm:
+            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+                np.sqrt(self.ret_rms.var + self._eps)
+            self.ret_rms.update(unnormalized_returns)
+            mean, std = np.mean(advantages), np.std(advantages)
+            advantages = (advantages - mean) / std  # per-batch norm
         else:
-            dist = self.dist_fn(logits)
-        act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+            batch.returns = unnormalized_returns
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
+        batch.returns = to_torch_as(batch.returns, batch.v_s)
+        batch.adv = to_torch_as(advantages, batch.v_s)
+        return batch

     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -162,7 +130,8 @@ class PPOPolicy(PGPolicy):
                 clip_loss = -torch.min(surr1, surr2).mean()
                 clip_losses.append(clip_loss.item())
                 if self._value_clip:
-                    v_clip = b.v + (value - b.v).clamp(-self._eps_clip, self._eps_clip)
+                    v_clip = b.v_s + (value - b.v_s).clamp(
+                        -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
                     vf_loss = 0.5 * torch.max(vf1, vf2).mean()
@@ -176,10 +145,10 @@ class PPOPolicy(PGPolicy):
                 losses.append(loss.item())
                 self.optim.zero_grad()
                 loss.backward()
-                if self._max_grad_norm:
+                if self._grad_norm is not None:
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
-                        self._max_grad_norm)
+                        self._grad_norm)
                 self.optim.step()
         # update learning rate if lr_scheduler is given
         if self.lr_scheduler is not None:
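Since discount_factor and reward_normalization are no longer explicit PPOPolicy parameters (they travel through A2CPolicy's **kwargs down to PGPolicy), callers now supply them by keyword, as the updated tests do. A minimal construction sketch; the network helpers and hyperparameter values are illustrative rather than the tests' exact setup:

import gym
import torch
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

env = gym.make('CartPole-v0')
net = Net(env.observation_space.shape, hidden_sizes=[64, 64])
actor = Actor(net, env.action_space.n)
critic = Critic(net)
optim = torch.optim.Adam(
    set(actor.parameters()).union(critic.parameters()), lr=3e-4)
policy = PPOPolicy(
    actor, critic, optim, torch.distributions.Categorical,
    discount_factor=0.99,        # forwarded to PGPolicy via **kwargs
    gae_lambda=0.95, eps_clip=0.2, vf_coef=0.5, ent_coef=0.01,
    max_grad_norm=0.5,
    reward_normalization=True,   # global return norm + per-batch adv norm
    action_space=env.action_space)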