Fix SAC loss explosion (#333)

* change SAC action_bound_method to "clip" (tanh squashing is hardcoded in SACPolicy.forward; see the sketch below)

* docstring update

* modelbase -> modelbased
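
The reasoning behind the first item, as a minimal sketch (plain PyTorch, not part of the diff): SACPolicy.forward already squashes the sampled Gaussian action with tanh and corrects the log-probability for that squashing, so the policy-level bound method only needs to scale/clip. If it were still "tanh", the action would be squashed a second time and would no longer match the corrected log-probability used in the actor and alpha losses.

    import torch

    # raw (unbounded) action sampled from the Gaussian actor
    u = torch.tensor([3.0, -2.5, 0.1])

    # what forward() returns after this fix: tanh-squashed action, roughly [0.995, -0.987, 0.100]
    a = torch.tanh(u)

    # what an extra "tanh" bound method would do on top of that: squash again,
    # roughly [0.76, -0.76, 0.099] -- visibly shrunk near the bounds
    a_double = torch.tanh(a)

Clipping, by contrast, leaves in-range actions untouched, so the action actually executed matches the one whose log-probability was corrected.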
ChenDRAG 2021-04-04 17:33:35 +08:00 committed by GitHub
parent 825da9bc53
commit dd4a01132c
14 changed files with 80 additions and 54 deletions


@@ -1 +0,0 @@
-../../test/modelbase/test_psrl.py

examples/modelbased/psrl.py (symbolic link)
@@ -0,0 +1 @@
+../../test/modelbased/test_psrl.py


@@ -6,6 +6,6 @@ TASK=$1
 echo "Experiments started."
 for seed in $(seq 0 9)
 do
-    python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
+    python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
 done
 echo "Experiments ended."


@@ -20,12 +20,14 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--auto-alpha', type=int, default=1)
+    parser.add_argument('--alpha-lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=24000)
     parser.add_argument('--il-step-per-epoch', type=int, default=500)
@@ -41,7 +43,7 @@ def get_args():
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument('--rew-norm', action="store_true", default=False)
-    parser.add_argument('--n-step', type=int, default=4)
+    parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,6 +87,13 @@ def test_sac_with_il(args=get_args()):
         concat=True, device=args.device)
     critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
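
When auto_alpha is enabled, alpha is handed to the policy as a (target_entropy, log_alpha, alpha_optim) tuple and the entropy coefficient is learned rather than fixed. A hedged sketch of what such a temperature update looks like (names and the exact loss form inside SACPolicy.learn may differ from the real implementation):

    import torch

    target_entropy = -3.0  # e.g. -dim(action_space), as in the diff above
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

    def update_alpha(log_prob: torch.Tensor) -> torch.Tensor:
        # temperature objective J(alpha) ~ E[-alpha * (log_prob + target_entropy)];
        # here log_alpha is optimized directly, a common variant
        alpha_loss = -(log_alpha * (log_prob.detach() + target_entropy)).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        return log_alpha.exp().detach()  # alpha used in the actor/critic losses

The 3e-4 learning rate mirrors the new --alpha-lr default introduced above.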
@@ -135,11 +144,12 @@ def test_sac_with_il(args=get_args()):
         args.action_shape, max_action=args.max_action, device=args.device
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='continuous')
+    il_policy = ImitationPolicy(
+        net, optim, mode='continuous', action_space=env.action_space,
+        action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,
-        DummyVectorEnv(
-            [lambda: gym.make(args.task) for _ in range(args.test_num)])
+        DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
     )
     train_collector.reset()
     result = offpolicy_trainer(
@@ -147,6 +157,7 @@ def test_sac_with_il(args=get_args()):
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])

     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!


@@ -12,7 +12,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
 from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
 from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
-from tianshou.policy.modelbase.psrl import PSRLPolicy
+from tianshou.policy.modelbased.psrl import PSRLPolicy
 from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager


@@ -12,39 +12,44 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 class BasePolicy(ABC, nn.Module):
     """The base class for any RL policy.

-    Tianshou aims to modularizing RL algorithms. It comes into several classes
-    of policies in Tianshou. All of the policy classes must inherit
+    Tianshou aims to modularizing RL algorithms. It comes into several classes of
+    policies in Tianshou. All of the policy classes must inherit
     :class:`~tianshou.policy.BasePolicy`.

-    A policy class typically has four parts:
+    A policy class typically has the following parts:

-    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
-        including coping the target network and so on;
+    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \
+        coping the target network and so on;
     * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
         observation;
-    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
-        the replay buffer (this function can interact with replay buffer);
-    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
-        batch of data.
+    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \
+        replay buffer (this function can interact with replay buffer);
+    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \
+        data.
+    * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \
+        from the learning process (e.g., prioritized replay buffer needs to update \
+        the weight);
+    * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \
+        i.e., `process_fn -> learn -> post_process_fn`.

     Most of the policy needs a neural network to predict the action and an
     optimizer to optimize the policy. The rules of self-defined networks are:

-    1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
-        ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
-        usage), and other information "info" provided by the environment.
-    2. Output: some "logits", the next hidden state "state", and the \
-        intermediate result during policy forwarding procedure "policy". The \
-        "logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\
-        the policy process the network output. For example, in PPO, the return of \
-        the network might be ``(mu, sigma), state`` for Gaussian policy. The \
-        "policy" can be a Batch of torch.Tensor or other things, which will be \
-        stored in the replay buffer, and can be accessed in the policy update \
-        process (e.g. in "policy.learn()", the "batch.policy" is what you need).
+    1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \
+        dict or any others), hidden state "state" (for RNN usage), and other \
+        information "info" provided by the environment.
+    2. Output: some "logits", the next hidden state "state", and the intermediate \
+        result during policy forwarding procedure "policy". The "logits" could be a \
+        tuple instead of a ``torch.Tensor``. It depends on how the policy process the \
+        network output. For example, in PPO, the return of the network might be \
+        ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \
+        torch.Tensor or other things, which will be stored in the replay buffer, and \
+        can be accessed in the policy update process (e.g. in "policy.learn()", the \
+        "batch.policy" is what you need).

-    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
-    you can use :class:`~tianshou.policy.BasePolicy` almost the same as
-    ``torch.nn.Module``, for instance, loading and saving the model:
+    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can
+    use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``,
+    for instance, loading and saving the model:

     ::

         torch.save(policy.state_dict(), "policy.pth")
@@ -117,6 +122,15 @@ class BasePolicy(ABC, nn.Module):
         return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
         # and in the sampled data batch, you can directly use
         # batch.policy.log_prob to get your data.
+
+    .. note::
+
+        In continuous action space, you should do another step "map_action" to get
+        the real action:
+        ::
+
+            act = policy(batch).act  # doesn't map to the target action range
+            act = policy.map_action(act, batch)
     """

     pass
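
To make the new note concrete, a small usage sketch: it assumes `policy` is an already constructed continuous-action policy such as the SACPolicy built in the test above, uses the old gym API of this era (Pendulum-v0, 4-tuple step), and the exact map_action signature may differ between versions.

    import gym
    import numpy as np
    from tianshou.data import Batch, to_numpy

    env = gym.make('Pendulum-v0')
    obs = env.reset()

    # forward() returns an action in the network's output range (after tanh),
    # not yet scaled to the environment's action bounds
    batch = Batch(obs=np.expand_dims(obs, 0), info={})
    act = to_numpy(policy(batch).act)

    # map_action scales/clips it into [action_space.low, action_space.high]
    act = policy.map_action(act, batch)
    obs, rew, done, info = env.step(act[0])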


@@ -27,8 +27,8 @@ class DDPGPolicy(BasePolicy):
     :param bool action_scaling: whether to map actions from range [-1, 1] to range
         [action_spaces.low, action_spaces.high]. Default to True.
     :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "clip".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
@@ -55,6 +55,9 @@ class DDPGPolicy(BasePolicy):
     ) -> None:
         super().__init__(action_scaling=action_scaling,
                          action_bound_method=action_bound_method, **kwargs)
+        assert action_bound_method != "tanh", "tanh mapping is not supported " \
+            "in policies where action is used as input of critic, because " \
+            "raw action in range (-inf, inf) will cause instability in training"
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)


@@ -37,8 +37,8 @@ class SACPolicy(DDPGPolicy):
     :param bool action_scaling: whether to map actions from range [-1, 1] to range
         [action_spaces.low, action_spaces.high]. Default to True.
     :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "tanh".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
@@ -63,13 +63,11 @@ class SACPolicy(DDPGPolicy):
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
-        action_bound_method: str = "tanh",
         **kwargs: Any,
     ) -> None:
         super().__init__(
             None, None, None, None, tau, gamma, exploration_noise,
-            reward_normalization, estimation_step,
-            action_bound_method=action_bound_method, **kwargs)
+            reward_normalization, estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -120,11 +118,10 @@ class SACPolicy(DDPGPolicy):
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
-        if self.action_bound_method == "tanh" and self.action_space is not None:
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
         # in appendix C to get some understanding of this equation.
-        if self.action_scaling:
+        if self.action_scaling and self.action_space is not None:
             action_scale = to_torch_as(
                 (self.action_space.high - self.action_space.low) / 2.0, act)
         else:
@@ -133,7 +130,8 @@
         log_prob = log_prob - torch.log(
             action_scale * (1 - squashed_action.pow(2)) + self.__eps
         ).sum(-1, keepdim=True)
-        return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
+        return Batch(logits=logits, act=squashed_action,
+                     state=h, dist=dist, log_prob=log_prob)

     def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
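
For reference, the correction retained above is the change-of-variables term from the SAC paper (arXiv 1801.01290, Eq. 21 in appendix C), extended with the scale used by action_scaling; the eps in the code is only for numerical stability. A sketch of the algebra:

    % raw sample u ~ N(mu, sigma), squashed and scaled action a_i = s_i * tanh(u_i),
    % with s_i = (action_space.high_i - action_space.low_i) / 2:
    \log \pi(a \mid s) = \log \mu(u \mid s) - \sum_i \log\!\left( s_i \, \bigl(1 - \tanh^2(u_i)\bigr) \right)

This is exactly `log_prob - torch.log(action_scale * (1 - squashed_action.pow(2)) + eps).sum(-1, keepdim=True)` in the hunk above.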


@@ -35,8 +35,8 @@ class TD3Policy(DDPGPolicy):
     :param bool action_scaling: whether to map actions from range [-1, 1] to range
         [action_spaces.low, action_spaces.high]. Default to True.
     :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "clip".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.