nstep all (fix #51)
Parent: ff81a18f42
Commit: dc451dfe88
@@ -20,7 +20,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -30,7 +30,7 @@
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)

-Tianshou supports parallel workers for all algorithms as well since all of them are reformatted as replay-buffer based algorithms. All of the algorithms support recurrent state representation in actor network (RNN-style training in POMDP). The environment state can be any type (Dict, self-defined class, ...).
+**Tianshou also supports parallel workers for all algorithms, since all of them are reformulated as replay-buffer based algorithms. All algorithms support a recurrent state representation in the actor network (RNN-style training for POMDPs). The environment state can be any type (dict, self-defined class, ...). All Q-learning algorithms support n-step return estimation.**

 In Chinese, Tianshou means divinely ordained, and refers to a gift one is born with. Tianshou is a reinforcement learning platform: its algorithms do not learn from humans. Taking the name "Tianshou" means that there is no teacher to study under; the agent learns by itself through constant interaction with the environment.
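Editor's note illustrating the bolded statement above (this snippet is not part of the commit): n-step targets are enabled through the new `estimation_step` constructor argument. The call pattern mirrors the `test_ddpg` changes further down in this commit; `actor`, `critic`, their optimizers, `env`, and the hyperparameter values are placeholders.

```python
# Schematic fragment (not from the diff): enable 3-step return targets for DDPG.
from tianshou.policy import DDPGPolicy

policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    0.005, 0.99, 0.1,                                   # tau, gamma, exploration noise
    [env.action_space.low[0], env.action_space.high[0]],
    reward_normalization=True,
    ignore_done=False,
    estimation_step=3)                                  # n-step target (default: 1)
```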
@@ -102,7 +102,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
 | Algo - Task | PyTorch | TensorFlow | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
 | PG - CartPole | 6.09±4.60s | None | None | 19.26±2.29s | None | ? |
 | DQN - CartPole | 6.09±0.87s | 1046.34±291.27s | 93.47±58.05s | 28.56±4.60s | 31.58±11.30s \*\* | ? |
-| A2C - CartPole | 6.36±1.63s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
+| A2C - CartPole | 10.59±2.04s | \*(~1612s) | 57.56±12.87s | 57.92±9.94s | \*(Not converged) | ? |
 | PPO - CartPole | 31.82±7.76s | \*(~1179s) | 34.79±17.02s | 44.60±17.04s | 23.99±9.26s \*\* | ? |
 | PPO - Pendulum | 16.18±2.49s | 745.43±160.82s | 259.73±27.37s | 123.62±44.23s | Runtime Error | ? |
 | DDPG - Pendulum | 37.26±9.55s | \*(>1h) | 277.52±92.67s | 314.70±7.92s | 59.05±10.03s \*\* | 172.18±62.48s |
@@ -36,7 +36,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
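A likely motivation for the `type=bool` → `type=int` switch in these argument definitions (the commit message does not state it, so treat this as an editor's note): `argparse` applies `bool()` to the raw command-line string, so any non-empty value, including the literal `False`, parses as `True`. A quick self-contained demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rew-norm-bool', type=bool, default=True)
parser.add_argument('--rew-norm-int', type=int, default=1)
args = parser.parse_args(['--rew-norm-bool', 'False', '--rew-norm-int', '0'])

print(args.rew_norm_bool)  # True  -- bool('False') is truthy, the flag cannot be switched off
print(args.rew_norm_int)   # 0     -- integer parsing behaves as expected
```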
@@ -78,7 +80,9 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -44,9 +44,9 @@ def get_args():
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.95)
-    parser.add_argument('--rew-norm', type=bool, default=True)
-    # parser.add_argument('--dual-clip', type=float, default=5.)
-    parser.add_argument('--value-clip', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=1)
     args = parser.parse_known_args()[0]
     return args

@@ -37,7 +37,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -83,7 +85,9 @@ def test_sac_with_il(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -126,7 +130,7 @@ def test_sac_with_il(args=get_args()):
     train_collector.reset()
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch // 5, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -39,7 +39,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--ignore-done', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -86,7 +88,9 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=args.rew_norm, ignore_done=True)
+        reward_normalization=args.rew_norm,
+        ignore_done=args.ignore_done,
+        estimation_step=args.n_step)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -31,7 +31,7 @@ def get_args():
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
@@ -40,7 +40,7 @@ def get_args():
         default='cuda' if torch.cuda.is_available() else 'cpu')
     # a2c special
     parser.add_argument('--vf-coef', type=float, default=0.5)
-    parser.add_argument('--ent-coef', type=float, default=0.001)
+    parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--max-grad-norm', type=float, default=None)
     parser.add_argument('--gae-lambda', type=float, default=1.)
     parser.add_argument('--rew-norm', type=bool, default=False)
@@ -102,7 +102,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -43,9 +43,9 @@ def get_args():
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
     parser.add_argument('--gae-lambda', type=float, default=0.8)
-    parser.add_argument('--rew-norm', type=bool, default=True)
+    parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument('--dual-clip', type=float, default=None)
-    parser.add_argument('--value-clip', type=bool, default=True)
+    parser.add_argument('--value-clip', type=int, default=1)
     args = parser.parse_known_args()[0]
     return args

@@ -1,5 +1,6 @@
 from tianshou.data.batch import Batch
-from tianshou.data.utils import to_numpy, to_torch
+from tianshou.data.utils import to_numpy, to_torch, \
+    to_torch_as
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
@@ -8,6 +9,7 @@ __all__ = [
     'Batch',
     'to_numpy',
     'to_torch',
+    'to_torch_as',
     'ReplayBuffer',
     'ListReplayBuffer',
     'PrioritizedReplayBuffer',
@@ -38,3 +38,13 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
     elif isinstance(x, Batch):
         x.to_torch(dtype, device)
     return x
+
+
+def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray],
+                y: torch.Tensor
+                ) -> Union[dict, Batch, torch.Tensor]:
+    """Return an object without np.ndarray. Same as
+    ``to_torch(x, dtype=y.dtype, device=y.device)``.
+    """
+    assert isinstance(y, torch.Tensor)
+    return to_torch(x, dtype=y.dtype, device=y.device)
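A minimal usage sketch of the new helper (not part of the diff; the tensor names are made up): it converts numpy data to a tensor matching an existing tensor's dtype and device, which is exactly how the policy code below uses it.

```python
import numpy as np
import torch

from tianshou.data import to_torch_as

template = torch.zeros(4, dtype=torch.float32)   # supplies the target dtype/device
returns = np.array([1.0, 0.5, 0.0, 2.0])         # e.g. data coming out of a ReplayBuffer
r = to_torch_as(returns, template)

# per the docstring above, equivalent to to_torch(returns, template.dtype, template.device)
assert isinstance(r, torch.Tensor)
assert r.dtype == template.dtype and r.device == template.device
```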
@@ -4,7 +4,7 @@ from torch import nn
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union, Optional, Callable

-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class BasePolicy(ABC, nn.Module):
@@ -138,9 +138,10 @@ class BasePolicy(ABC, nn.Module):
             batch: Batch,
             buffer: ReplayBuffer,
             indice: np.ndarray,
-            target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
+            target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
             gamma: float = 0.99,
-            n_step: int = 1
+            n_step: int = 1,
+            rew_norm: bool = False
             ) -> np.ndarray:
         r"""Compute n-step return for Q-learning targets:

@@ -159,13 +160,25 @@ class BasePolicy(ABC, nn.Module):
         :type buffer: :class:`~tianshou.data.ReplayBuffer`
         :param indice: sampled timestep.
         :type indice: numpy.ndarray
         :param function target_q_fn: a function receives :math:`t+n-1` step's
             data and compute target Q value.
         :param float gamma: the discount factor, should be in [0, 1], defaults
             to 0.99.
+        :param int n_step: the number of estimation step, should be an int
+            greater than 0, defaults to 1.
+        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
+            to ``False``.

-        :return: a Batch. The result will be stored in batch.returns.
+        :return: a Batch. The result will be stored in batch.returns as a
+            torch.Tensor with shape (bsz, ).
         """
+        if rew_norm:
+            bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
+            mean, std = bfr.mean(), bfr.std()
+            if np.isclose(std, 0):
+                mean, std = 0, 1
+        else:
+            mean, std = 0, 1
         returns = np.zeros_like(indice)
         gammas = np.zeros_like(indice) + n_step
         done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
@@ -173,10 +186,11 @@ class BasePolicy(ABC, nn.Module):
             now = (indice + n) % buf_len
             gammas[done[now] > 0] = n
             returns[done[now] > 0] = 0
-            returns = rew[now] + gamma * returns
+            returns = (rew[now] - mean) / std + gamma * returns
         terminal = (indice + n_step - 1) % buf_len
-        target_q = target_q_fn(buffer, terminal)
+        target_q = target_q_fn(buffer, terminal).squeeze()
         target_q[gammas != n_step] = 0
-        returns += (gamma ** gammas) * target_q
-        batch.returns = returns
+        returns = to_torch_as(returns, target_q)
+        gammas = to_torch_as(gamma ** gammas, target_q)
+        batch.returns = target_q * gammas + returns
         return batch
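Editor's note, not part of the diff: written out, the target this loop stores in `batch.returns` for a sampled index t is the standard n-step return

```latex
G_t^{(n)} \;=\; \sum_{i=0}^{n-1} \gamma^{\,i}\, r_{t+i} \;+\; \gamma^{\,n}\, Q_{\mathrm{target}}\!\left(s_{t+n}\right)
```

when no terminal flag occurs inside the n-step window. If a `done` is hit within the window, `returns[done[now] > 0] = 0` truncates the reward sum at the episode boundary (the terminal transition's reward is still counted) and `target_q[gammas != n_step] = 0` drops the bootstrap term. With `rew_norm=True`, each reward is first standardized using the mean and std of at most the first 1000 buffer rewards, as set up above.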
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, List, Union, Optional

 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


 class A2CPolicy(PGPolicy):
@@ -106,14 +106,14 @@ class A2CPolicy(PGPolicy):
                 self.optim.zero_grad()
                 dist = self(b).dist
                 v = self.critic(b.obs)
-                a = to_torch(b.act, device=v.device)
-                r = to_torch(b.returns, device=v.device)
+                a = to_torch_as(b.act, v)
+                r = to_torch_as(b.returns, v)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()
                 loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                 loss.backward()
-                if self._grad_norm:
+                if self._grad_norm is not None:
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) +
                         list(self.critic.parameters()),
@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional

 from tianshou.policy import BasePolicy
 # from tianshou.exploration import OUNoise
-from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class DDPGPolicy(BasePolicy):
@@ -29,6 +29,8 @@ class DDPGPolicy(BasePolicy):
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+    :param int estimation_step: greater than 1, the number of steps to look
+        ahead.

     .. seealso::

@@ -47,6 +49,7 @@ class DDPGPolicy(BasePolicy):
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
+                estimation_step: int = 1,
                 **kwargs) -> None:
         super().__init__(**kwargs)
         if actor is not None:
@@ -71,6 +74,8 @@ class DDPGPolicy(BasePolicy):
         # self.noise = OUNoise()
         self._rm_done = ignore_done
         self._rew_norm = reward_normalization
+        assert estimation_step > 0, 'estimation_step should greater than 0'
+        self._n_step = estimation_step

     def set_eps(self, eps: float) -> None:
         """Set the eps for exploration."""
@@ -96,15 +101,21 @@ class DDPGPolicy(BasePolicy):
                 self.critic_old.parameters(), self.critic.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs_next: s_{t+n}
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+        return target_q
+
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
-        if self._rew_norm:
-            bfr = buffer.rew[:min(len(buffer), 1000)]  # avoid large buffer
-            mean, std = bfr.mean(), bfr.std()
-            if not np.isclose(std, 0):
-                batch.rew = (batch.rew - mean) / std
         if self._rm_done:
             batch.done = batch.done * 0.
+        batch = self.compute_nstep_return(
+            batch, buffer, indice, self._target_q,
+            self._gamma, self._n_step, self._rew_norm)
         return batch

     def forward(self, batch: Batch,
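To make the refactoring pattern explicit (editor's illustration, not part of the diff): each Q-learning-style policy now only supplies a `_target_q(buffer, indice)` callback, and `BasePolicy.compute_nstep_return` handles reward accumulation, episode termination, and bootstrapping. A hypothetical subclass would wire it up the same way as `DDPGPolicy` above; the names `MyPolicy`, `actor_old`, and `critic_old` are made up, and `__init__`/`forward`/`learn` are omitted for brevity.

```python
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy


class MyPolicy(BasePolicy):
    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]               # batch.obs_next is s_{t+n}
        with torch.no_grad():
            act = self.actor_old(batch.obs_next)
            return self.critic_old(batch.obs_next, act)

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        # delegate n-step return computation to the shared base-class helper
        return self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            gamma=0.99, n_step=3, rew_norm=False)
```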
@@ -143,16 +154,9 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)

     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
-        with torch.no_grad():
-            target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next', eps=0).act)
-            dev = target_q.device
-            rew = to_torch(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-            done = to_torch(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
-            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
+        target_q = to_torch_as(batch.returns, current_q)
+        target_q = target_q[:, None]
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
         critic_loss.backward()
@@ -6,7 +6,7 @@ from typing import Dict, Union, Optional

 from tianshou.policy import BasePolicy
 from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch, to_numpy
+    to_torch_as, to_numpy


 class DQNPolicy(BasePolicy):
@@ -69,18 +69,18 @@ class DQNPolicy(BasePolicy):
         self.model_old.load_state_dict(self.model.state_dict())

     def _target_q(self, buffer: ReplayBuffer,
-                  indice: np.ndarray) -> np.ndarray:
-        data = buffer[indice]
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs_next: s_{t+n}
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(data, input='obs_next', eps=0).act
+            a = self(batch, input='obs_next', eps=0).act
             with torch.no_grad():
                 target_q = self(
-                    data, model='model_old', input='obs_next').logits
-            target_q = to_numpy(target_q)
+                    batch, model='model_old', input='obs_next').logits
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(data, input='obs_next').logits
-            target_q = to_numpy(target_q).max(axis=1)
+            with torch.no_grad():
+                target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
         return target_q

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
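Editor's note, not part of the diff: the comment `target_Q = Q_old(s_, argmax(Q_new(s_, *)))` is the Double DQN bootstrap term. `_target_q` returns only this term; the reward and discounting are applied afterwards by `compute_nstep_return`.

```latex
% Double DQN branch (self._target is True): select with the online net, evaluate with the target net
Q_{\mathrm{target}}(s_{t+n}) \;=\; Q_{\theta^{-}}\!\bigl(s_{t+n},\ \arg\max_{a} Q_{\theta}(s_{t+n}, a)\bigr)

% plain DQN branch (else): both selection and evaluation use the same net
Q_{\mathrm{target}}(s_{t+n}) \;=\; \max_{a} Q_{\theta}(s_{t+n}, a)
```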
@@ -144,12 +144,11 @@ class DQNPolicy(BasePolicy):
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
-        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
+        r = to_torch_as(batch.returns, q)
         if hasattr(batch, 'update_weight'):
             td = r - q
             batch.update_weight(batch.indice, to_numpy(td))
-            impt_weight = to_torch(batch.impt_weight,
-                                   device=q.device, dtype=torch.float)
+            impt_weight = to_torch_as(batch.impt_weight, q)
             loss = (td.pow(2) * impt_weight).mean()
         else:
             loss = F.mse_loss(q, r)
@@ -4,7 +4,7 @@ from torch import nn
 from typing import Dict, List, Tuple, Union, Optional

 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as


 class PPOPolicy(PGPolicy):
@@ -129,14 +129,12 @@ class PPOPolicy(PGPolicy):
             for b in batch.split(batch_size, shuffle=False):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
-                    to_torch(b.act, device=v[0].device)))
+                    to_torch_as(b.act, v[0])))
         batch.v = torch.cat(v, dim=0)  # old value
-        dev = batch.v.device
-        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch_as(batch.act, v[0])
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch(
-            batch.returns, dtype=torch.float, device=dev
-        ).reshape(batch.v.shape)
+        batch.returns = to_torch_as(
+            batch.returns, v[0]).reshape(batch.v.shape)
         if self._rew_norm:
             mean, std = batch.returns.mean(), batch.returns.std()
             if not np.isclose(std.item(), 0):
@@ -4,9 +4,9 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional

-from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
+from tianshou.data import Batch, to_torch_as, ReplayBuffer


 class SACPolicy(DDPGPolicy):
@@ -55,10 +55,11 @@ class SACPolicy(DDPGPolicy):
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
+                estimation_step: int = 1,
                 **kwargs) -> None:
         super().__init__(None, None, None, None, tau, gamma, 0,
                          action_range, reward_normalization, ignore_done,
-                         **kwargs)
+                         estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -105,23 +106,23 @@ class SACPolicy(DDPGPolicy):
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

-    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> torch.Tensor:
+        batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
             obs_next_result = self(batch, input='obs_next')
             a_ = obs_next_result.act
-            dev = a_.device
-            batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
+            batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
                 self.critic1_old(batch.obs_next, a_),
                 self.critic2_old(batch.obs_next, a_),
             ) - self._alpha * obs_next_result.log_prob
-            rew = to_torch(batch.rew,
-                           dtype=torch.float, device=dev)[:, None]
-            done = to_torch(batch.done,
-                            dtype=torch.float, device=dev)[:, None]
-            target_q = (rew + (1. - done) * self._gamma * target_q)
+        return target_q
+
+    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
+        target_q = to_torch_as(batch.returns, current_q1)[:, None]
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
         critic1_loss.backward()
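Editor's note, not part of the diff: the bootstrap term SAC's new `_target_q` returns is the clipped double-Q soft value, with the entropy bonus weighted by `self._alpha`; as before, the reward and discount are added by `compute_nstep_return`.

```latex
Q_{\mathrm{target}}(s_{t+n}) \;=\; \min_{i=1,2} Q_{\theta_i^{-}}\!\bigl(s_{t+n}, \tilde a\bigr) \;-\; \alpha \log \pi\bigl(\tilde a \mid s_{t+n}\bigr),
\qquad \tilde a \sim \pi\bigl(\cdot \mid s_{t+n}\bigr)
```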
@@ -1,10 +1,11 @@
 import torch
 import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Optional

-from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
+from tianshou.data import Batch, ReplayBuffer


 class TD3Policy(DDPGPolicy):
@@ -62,10 +63,11 @@ class TD3Policy(DDPGPolicy):
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
+                estimation_step: int = 1,
                 **kwargs) -> None:
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
-                         ignore_done, **kwargs)
+                         ignore_done, estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
|
||||
self.critic2_old.parameters(), self.critic2.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
def _target_q(self, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs: s_{t+n}
|
||||
with torch.no_grad():
|
||||
a_ = self(batch, model='actor_old', input='obs_next').act
|
||||
dev = a_.device
|
||||
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
|
||||
if self._noise_clip >= 0:
|
||||
if self._noise_clip > 0:
|
||||
noise = noise.clamp(-self._noise_clip, self._noise_clip)
|
||||
a_ += noise
|
||||
a_ = a_.clamp(self._range[0], self._range[1])
|
||||
target_q = torch.min(
|
||||
self.critic1_old(batch.obs_next, a_),
|
||||
self.critic2_old(batch.obs_next, a_))
|
||||
rew = to_torch(batch.rew,
|
||||
dtype=torch.float, device=dev)[:, None]
|
||||
done = to_torch(batch.done,
|
||||
dtype=torch.float, device=dev)[:, None]
|
||||
target_q = (rew + (1. - done) * self._gamma * target_q)
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act)
|
||||
target_q = batch.returns[:, None]
|
||||
critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
|
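Editor's note, not part of the diff: TD3's `_target_q` performs target policy smoothing with a clipped double-Q bootstrap; the `>= 0` → `> 0` change means the noise is only clipped when `noise_clip` is strictly positive. The discounted reward is once more added by `compute_nstep_return`.

```latex
\epsilon \sim \mathcal{N}\!\left(0, \sigma^2\right), \qquad
\tilde a \;=\; \operatorname{clip}\!\bigl(\mu_{\theta^{-}}(s_{t+n}) + \operatorname{clip}(\epsilon, -c, c),\ a_{\min},\ a_{\max}\bigr)

Q_{\mathrm{target}}(s_{t+n}) \;=\; \min_{i=1,2} Q_{\theta_i^{-}}\!\bigl(s_{t+n}, \tilde a\bigr)
```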