Fix SAC loss explosion (#333)

* change SAC action_bound_method to "clip" (tanh is hardcoded in forward)

* docstring update

* modelbase -> modelbased
ChenDRAG 2021-04-04 17:33:35 +08:00 committed by GitHub
parent 825da9bc53
commit dd4a01132c
14 changed files with 80 additions and 54 deletions


@@ -1 +0,0 @@
../../test/modelbase/test_psrl.py

examples/modelbased/psrl.py Symbolic link

@@ -0,0 +1 @@
../../test/modelbased/test_psrl.py


@@ -6,6 +6,6 @@ TASK=$1
echo "Experiments started."
for seed in $(seq 0 9)
do
python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
done
echo "Experiments ended."


@@ -20,12 +20,14 @@ def get_args():
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--il-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', type=int, default=1)
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=24000)
parser.add_argument('--il-step-per-epoch', type=int, default=500)
@@ -41,7 +43,7 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,6 +87,13 @@ def test_sac_with_il(args=get_args()):
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:
    target_entropy = -np.prod(env.action_space.shape)
    log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
    alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
    args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
@@ -135,11 +144,12 @@ def test_sac_with_il(args=get_args()):
args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='continuous')
il_policy = ImitationPolicy(
net, optim, mode='continuous', action_space=env.action_space,
action_scaling=True, action_bound_method="clip")
il_test_collector = Collector(
il_policy,
DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
)
train_collector.reset()
result = offpolicy_trainer(
@@ -147,6 +157,7 @@ def test_sac_with_il(args=get_args()):
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
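For context, the (target_entropy, log_alpha, alpha_optim) tuple the test now passes as alpha enables automatic entropy tuning. A minimal sketch of what such a temperature update typically looks like (the update_alpha helper and its constants are hypothetical; the actual update lives inside SACPolicy.learn):

import torch

target_entropy = -3.0  # typically -np.prod(action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(log_prob: torch.Tensor) -> torch.Tensor:
    """One gradient step on the temperature given the policy's log-probabilities."""
    # Only log_alpha receives gradient; the log-probs are treated as constants.
    alpha_loss = -(log_alpha * (log_prob.detach() + target_entropy)).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    return log_alpha.exp()  # temperature used in the actor and critic losses

When the policy's entropy falls below the target, log_prob + target_entropy becomes positive on average and the update raises the temperature, and vice versa.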


@@ -12,7 +12,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.modelbase.psrl import PSRLPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager


@@ -12,39 +12,44 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class BasePolicy(ABC, nn.Module):
"""The base class for any RL policy.
Tianshou aims to modularizing RL algorithms. It comes into several classes
of policies in Tianshou. All of the policy classes must inherit
Tianshou aims to modularize RL algorithms. It comes with several classes of
policies. All of the policy classes must inherit
:class:`~tianshou.policy.BasePolicy`.
A policy class typically has four parts:
A policy class typically has the following parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \
copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
batch of data.
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \
replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \
data.
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \
from the learning process (e.g., prioritized replay buffer needs to update \
the weight);
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \
i.e., `process_fn -> learn -> post_process_fn`.
Most policies need a neural network to predict the action and an
optimizer to optimize the policy. The rules for self-defined networks are:
1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
usage), and other information "info" provided by the environment.
2. Output: some "logits", the next hidden state "state", and the \
intermediate result during policy forwarding procedure "policy". The \
"logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\
the policy process the network output. For example, in PPO, the return of \
the network might be ``(mu, sigma), state`` for Gaussian policy. The \
"policy" can be a Batch of torch.Tensor or other things, which will be \
stored in the replay buffer, and can be accessed in the policy update \
process (e.g. in "policy.learn()", the "batch.policy" is what you need).
1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \
dict or any others), hidden state "state" (for RNN usage), and other information \
"info" provided by the environment.
2. Output: some "logits", the next hidden state "state", and the intermediate \
result during policy forwarding procedure "policy". The "logits" could be a tuple \
instead of a ``torch.Tensor``. It depends on how the policy processes the network \
output. For example, in PPO, the return of the network might be \
``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \
torch.Tensor or other things, which will be stored in the replay buffer, and can \
be accessed in the policy update process (e.g. in "policy.learn()", the \
"batch.policy" is what you need).
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
you can use :class:`~tianshou.policy.BasePolicy` almost the same as
``torch.nn.Module``, for instance, loading and saving the model:
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can
use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``,
for instance, loading and saving the model:
::
torch.save(policy.state_dict(), "policy.pth")
@@ -117,6 +122,15 @@ class BasePolicy(ABC, nn.Module):
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.
.. note::
In continuous action space, you should do another step "map_action" to get
the real action:
::
act = policy(batch).act # doesn't map to the target action range
act = policy.map_action(act, batch)
"""
pass
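
To make the input/output convention described in the docstring concrete, a minimal self-defined network could look like the sketch below (not part of this commit; it assumes a flat observation vector and a Gaussian policy head, and all names are illustrative):

import torch
from torch import nn

class MyGaussianActor(nn.Module):
    """Follows the (obs, state, info) -> (logits, state) convention described above."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU())
        self.mu = nn.Linear(128, act_dim)
        self.log_sigma = nn.Linear(128, act_dim)

    def forward(self, obs, state=None, info=None):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        hidden = self.body(obs)
        mu = self.mu(hidden)
        sigma = self.log_sigma(hidden).clamp(-20, 2).exp()
        # "logits" may be a tuple, e.g. (mu, sigma) for a Gaussian policy;
        # "state" is passed through unchanged since this network has no RNN.
        return (mu, sigma), state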


@@ -27,8 +27,8 @@ class DDPGPolicy(BasePolicy):
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action), "tanh" (for applying tanh
squashing) for now, or empty string for no bounding. Default to "clip".
either "clip" (for simply clipping the action) or empty string for no bounding.
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
@@ -55,6 +55,9 @@ class DDPGPolicy(BasePolicy):
) -> None:
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
assert action_bound_method != "tanh", "tanh mapping is not supported " \
    "in policies where the action is used as input to the critic, because " \
    "raw actions in range (-inf, inf) will cause instability in training"
if actor is not None and actor_optim is not None:
    self.actor: torch.nn.Module = actor
    self.actor_old = deepcopy(actor)

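For intuition, the remaining bounding options mentioned in the docstring behave roughly as follows (bound_action is a hypothetical helper, not Tianshou's actual map_action implementation):

import numpy as np

def bound_action(act: np.ndarray, method: str) -> np.ndarray:
    # Sketch of what "action_bound_method" means for a raw network output.
    if method == "clip":
        return np.clip(act, -1.0, 1.0)  # hard clip into [-1, 1]
    if method == "":
        return act  # no bounding
    # "tanh" (smooth squashing) is rejected by the assert above: with deferred
    # tanh squashing, the critic would be trained on raw actions in (-inf, inf),
    # which is exactly the instability the assert warns about.
    raise ValueError(f"unsupported action_bound_method: {method}")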

@@ -37,8 +37,8 @@ class SACPolicy(DDPGPolicy):
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action), "tanh" (for applying tanh
squashing) for now, or empty string for no bounding. Default to "tanh".
either "clip" (for simply clipping the action) or empty string for no bounding.
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
@@ -63,13 +63,11 @@ class SACPolicy(DDPGPolicy):
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True,
action_bound_method: str = "tanh",
**kwargs: Any,
) -> None:
super().__init__(
None, None, None, None, tau, gamma, exploration_noise,
reward_normalization, estimation_step,
action_bound_method=action_bound_method, **kwargs)
reward_normalization, estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@@ -120,20 +118,20 @@ class SACPolicy(DDPGPolicy):
else:
    act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
if self.action_bound_method == "tanh" and self.action_space is not None:
    # apply correction for Tanh squashing when computing logprob from Gaussian
    # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
    # in appendix C to get some understanding of this equation.
    if self.action_scaling:
        action_scale = to_torch_as(
            (self.action_space.high - self.action_space.low) / 2.0, act)
    else:
        action_scale = 1.0  # type: ignore
    squashed_action = torch.tanh(act)
    log_prob = log_prob - torch.log(
        action_scale * (1 - squashed_action.pow(2)) + self.__eps
    ).sum(-1, keepdim=True)
return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
if self.action_scaling and self.action_space is not None:
    action_scale = to_torch_as(
        (self.action_space.high - self.action_space.low) / 2.0, act)
else:
    action_scale = 1.0  # type: ignore
squashed_action = torch.tanh(act)
log_prob = log_prob - torch.log(
    action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True)
return Batch(logits=logits, act=squashed_action,
             state=h, dist=dist, log_prob=log_prob)

def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
    batch = buffer[indice]  # batch.obs: s_{t+n}

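For reference, the log-probability correction in the block above is the change-of-variables term for a tanh-squashed Gaussian; in my notation (cf. Eq. 21 and Appendix C of the SAC paper, arXiv 1801.01290):

% u ~ N(mu(o), sigma(o)) is the pre-squash sample, a_i = s_i * tanh(u_i) the executed action,
% s_i the per-dimension action scale (1 when action_scaling is off), eps a small stability constant.
\log \pi(a \mid o) = \log \mathcal{N}\bigl(u \mid \mu(o), \sigma(o)\bigr)
    - \sum_i \log\!\Bigl( s_i \bigl(1 - \tanh^2(u_i)\bigr) + \epsilon \Bigr)

The + self.__eps term in the code plays the role of epsilon and only guards against log(0).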

@@ -35,8 +35,8 @@ class TD3Policy(DDPGPolicy):
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action), "tanh" (for applying tanh
squashing) for now, or empty string for no bounding. Default to "clip".
either "clip" (for simply clipping the action) or empty string for no bounding.
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.