Fix SAC loss explosion (#333)

* change SAC action_bound_method to "clip" (tanh is hardcoded in forward)
* docstring update
* modelbase -> modelbased
parent 825da9bc53
commit dd4a01132c
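Background for the first bullet: SACPolicy.forward now always applies tanh squashing (see the SACPolicy.forward hunk below), so the policy output is already in [-1, 1] before map_action runs. A minimal sketch, written for this note rather than taken from the repo, of why "clip" is the right bound method there and a second "tanh" would only shrink actions toward zero:

    import numpy as np

    raw = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
    once = np.tanh(raw)              # what SAC's forward now outputs (already in [-1, 1])
    twice = np.tanh(once)            # an extra "tanh" bound would squash a second time
    clipped = np.clip(once, -1, 1)   # "clip" leaves values already in [-1, 1] untouched
    print(once)     # approx. [-0.995 -0.462  0.     0.462  0.995]
    print(twice)    # approx. [-0.760 -0.432  0.     0.432  0.760]
    print(clipped)  # same as `once`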
examples/modelbased/psrl.py (symbolic link)
@@ -1 +0,0 @@
-../../test/modelbase/test_psrl.py
@@ -0,0 +1 @@
+../../test/modelbased/test_psrl.py
@@ -6,6 +6,6 @@ TASK=$1
echo "Experiments started."
for seed in $(seq 0 9)
do
-    python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
+    python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
done
echo "Experiments ended."
@@ -20,12 +20,14 @@ def get_args():
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--il-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--auto-alpha', type=int, default=1)
+    parser.add_argument('--alpha-lr', type=float, default=3e-4)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=24000)
    parser.add_argument('--il-step-per-epoch', type=int, default=500)
@@ -41,7 +43,7 @@ def get_args():
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--rew-norm', action="store_true", default=False)
-    parser.add_argument('--n-step', type=int, default=4)
+    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,6 +87,13 @@ def test_sac_with_il(args=get_args()):
        concat=True, device=args.device)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
    policy = SACPolicy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
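For context on the auto-alpha block added above: passing a (target_entropy, log_alpha, alpha_optim) tuple as alpha switches on automatic entropy-temperature tuning. A minimal sketch of the standard SAC temperature update this enables, with a hypothetical helper name and not copied from SACPolicy.learn:

    import torch

    def update_alpha(log_prob: torch.Tensor, target_entropy: float,
                     log_alpha: torch.Tensor, alpha_optim: torch.optim.Optimizer) -> float:
        # Standard SAC temperature loss: alpha rises when policy entropy (-log_prob)
        # is below target_entropy and falls when it is above.
        alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        return log_alpha.exp().item()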
@@ -135,11 +144,12 @@ def test_sac_with_il(args=get_args()):
        args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='continuous')
+    il_policy = ImitationPolicy(
+        net, optim, mode='continuous', action_space=env.action_space,
+        action_scaling=True, action_bound_method="clip")
    il_test_collector = Collector(
        il_policy,
-        DummyVectorEnv(
-            [lambda: gym.make(args.task) for _ in range(args.test_num)])
+        DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    )
    train_collector.reset()
    result = offpolicy_trainer(
@@ -147,6 +157,7 @@ def test_sac_with_il(args=get_args()):
        args.il_step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
@@ -12,7 +12,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
-from tianshou.policy.modelbase.psrl import PSRLPolicy
+from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
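Since tianshou/policy/__init__.py still imports PSRLPolicy, code that imports from the package root should be unaffected by the modelbase -> modelbased rename; only imports that spell out the old submodule path need updating. An illustrative sketch, not part of the diff:

    from tianshou.policy import PSRLPolicy                    # still works after the rename
    # from tianshou.policy.modelbase.psrl import PSRLPolicy   # old submodule path, now broken
    from tianshou.policy.modelbased.psrl import PSRLPolicy    # new submodule path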
@@ -12,39 +12,44 @@ from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class BasePolicy(ABC, nn.Module):
    """The base class for any RL policy.

-    Tianshou aims to modularizing RL algorithms. It comes into several classes
-    of policies in Tianshou. All of the policy classes must inherit
+    Tianshou aims to modularizing RL algorithms. It comes into several classes of
+    policies in Tianshou. All of the policy classes must inherit
    :class:`~tianshou.policy.BasePolicy`.

-    A policy class typically has four parts:
+    A policy class typically has the following parts:

-    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
-        including coping the target network and so on;
+    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \
+        coping the target network and so on;
    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
        observation;
-    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
-        the replay buffer (this function can interact with replay buffer);
-    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
-        batch of data.
+    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \
+        replay buffer (this function can interact with replay buffer);
+    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \
+        data.
+    * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \
+        from the learning process (e.g., prioritized replay buffer needs to update \
+        the weight);
+    * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \
+        i.e., `process_fn -> learn -> post_process_fn`.

    Most of the policy needs a neural network to predict the action and an
    optimizer to optimize the policy. The rules of self-defined networks are:

-    1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
-        ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
-        usage), and other information "info" provided by the environment.
-    2. Output: some "logits", the next hidden state "state", and the \
-        intermediate result during policy forwarding procedure "policy". The \
-        "logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\
-        the policy process the network output. For example, in PPO, the return of \
-        the network might be ``(mu, sigma), state`` for Gaussian policy. The \
-        "policy" can be a Batch of torch.Tensor or other things, which will be \
-        stored in the replay buffer, and can be accessed in the policy update \
-        process (e.g. in "policy.learn()", the "batch.policy" is what you need).
+    1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \
+        dict or any others), hidden state "state" (for RNN usage), and other information \
+        "info" provided by the environment.
+    2. Output: some "logits", the next hidden state "state", and the intermediate \
+        result during policy forwarding procedure "policy". The "logits" could be a tuple \
+        instead of a ``torch.Tensor``. It depends on how the policy process the network \
+        output. For example, in PPO, the return of the network might be \
+        ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \
+        torch.Tensor or other things, which will be stored in the replay buffer, and can \
+        be accessed in the policy update process (e.g. in "policy.learn()", the \
+        "batch.policy" is what you need).

-    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
-    you can use :class:`~tianshou.policy.BasePolicy` almost the same as
-    ``torch.nn.Module``, for instance, loading and saving the model:
+    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can
+    use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``,
+    for instance, loading and saving the model:
    ::

        torch.save(policy.state_dict(), "policy.pth")
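As a companion to the input/output rules in the docstring above, here is a minimal sketch of a self-defined network following the (obs, state, info) -> (logits, state) convention; the hidden size and Gaussian head are invented for illustration and are not part of this commit:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyGaussianActor(nn.Module):
        """Follows the convention: forward(obs, state, info) -> (logits, state)."""

        def __init__(self, obs_dim: int, act_dim: int, hidden: int = 128):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
            self.mu = nn.Linear(hidden, act_dim)
            self.sigma = nn.Linear(hidden, act_dim)

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            h = self.body(obs)
            # "logits" may be a tuple, e.g. (mu, sigma) for a Gaussian policy.
            return (self.mu(h), F.softplus(self.sigma(h))), state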
@@ -117,6 +122,15 @@ class BasePolicy(ABC, nn.Module):
            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
            # and in the sampled data batch, you can directly use
            # batch.policy.log_prob to get your data.
+
+        .. note::
+
+            In continuous action space, you should do another step "map_action" to get
+            the real action:
+            ::
+
+                act = policy(batch).act  # doesn't map to the target action range
+                act = policy.map_action(act, batch)
        """
        pass
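The two-line snippet in the new note is the user-facing contract. Roughly, for a Box action space with action_bound_method="clip" and action_scaling=True, the mapping behaves like the following sketch (an approximation written for this note, not the actual BasePolicy.map_action body):

    import numpy as np

    def clip_and_scale(act: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        act = np.clip(act, -1.0, 1.0)                    # action_bound_method="clip"
        return low + (high - low) * (act + 1.0) / 2.0    # action_scaling=True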
@@ -27,8 +27,8 @@ class DDPGPolicy(BasePolicy):
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "clip".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.

@@ -55,6 +55,9 @@ class DDPGPolicy(BasePolicy):
    ) -> None:
        super().__init__(action_scaling=action_scaling,
                         action_bound_method=action_bound_method, **kwargs)
+        assert action_bound_method != "tanh", "tanh mapping is not supported" \
+            "in policies where action is used as input of critic , because" \
+            "raw action in range (-inf, inf) will cause instability in training"
        if actor is not None and actor_optim is not None:
            self.actor: torch.nn.Module = actor
            self.actor_old = deepcopy(actor)
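A small numerical illustration of the reasoning stated in the new assert message (illustrative only, not from the diff): with "tanh" bounding the environment sees tanh(raw_act), so raw actions of very different magnitude map to nearly identical env actions, and nothing keeps the raw action, which DDPG-style critics consume, from drifting to huge values; "clip" keeps that input bounded.

    import numpy as np

    raw = np.array([1.0, 5.0, 500.0])
    print(np.tanh(raw))         # approx. [0.762, 0.9999, 1.0] -> nearly the same env action
    print(np.clip(raw, -1, 1))  # [1.0, 1.0, 1.0]              -> critic input stays bounded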
@@ -37,8 +37,8 @@ class SACPolicy(DDPGPolicy):
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "tanh".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
@@ -63,13 +63,11 @@ class SACPolicy(DDPGPolicy):
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
-        action_bound_method: str = "tanh",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            None, None, None, None, tau, gamma, exploration_noise,
-            reward_normalization, estimation_step,
-            action_bound_method=action_bound_method, **kwargs)
+            reward_normalization, estimation_step, **kwargs)
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
@@ -120,20 +118,20 @@ class SACPolicy(DDPGPolicy):
        else:
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
-        if self.action_bound_method == "tanh" and self.action_space is not None:
-            # apply correction for Tanh squashing when computing logprob from Gaussian
-            # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
-            # in appendix C to get some understanding of this equation.
-            if self.action_scaling:
-                action_scale = to_torch_as(
-                    (self.action_space.high - self.action_space.low) / 2.0, act)
-            else:
-                action_scale = 1.0  # type: ignore
-            squashed_action = torch.tanh(act)
-            log_prob = log_prob - torch.log(
-                action_scale * (1 - squashed_action.pow(2)) + self.__eps
-            ).sum(-1, keepdim=True)
-        return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
+        # apply correction for Tanh squashing when computing logprob from Gaussian
+        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+        # in appendix C to get some understanding of this equation.
+        if self.action_scaling and self.action_space is not None:
+            action_scale = to_torch_as(
+                (self.action_space.high - self.action_space.low) / 2.0, act)
+        else:
+            action_scale = 1.0  # type: ignore
+        squashed_action = torch.tanh(act)
+        log_prob = log_prob - torch.log(
+            action_scale * (1 - squashed_action.pow(2)) + self.__eps
+        ).sum(-1, keepdim=True)
+        return Batch(logits=logits, act=squashed_action,
+                     state=h, dist=dist, log_prob=log_prob)

    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
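The correction term above comes from the change of variables a = scale * tanh(u): log p(a) = log p(u) - sum_i log(scale * (1 - tanh(u_i)^2)), which is Eq. 21 / Appendix C of the SAC paper (arXiv 1801.01290). A self-contained numerical sketch of that identity, written for this note and independent of SACPolicy:

    import torch
    from torch.distributions import Normal

    torch.manual_seed(0)
    dist = Normal(torch.zeros(3), torch.ones(3))
    u = dist.rsample()                       # pre-squash sample, like `act` above
    scale, eps = 2.0, 1e-7                   # e.g. an action space of [-2, 2]
    log_prob_u = dist.log_prob(u).sum(-1, keepdim=True)
    correction = torch.log(scale * (1 - torch.tanh(u).pow(2)) + eps).sum(-1, keepdim=True)
    log_prob_a = log_prob_u - correction     # log-density of a = scale * tanh(u)
    print(log_prob_a)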
@@ -35,8 +35,8 @@ class TD3Policy(DDPGPolicy):
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
-        either "clip" (for simply clipping the action), "tanh" (for applying tanh
-        squashing) for now, or empty string for no bounding. Default to "clip".
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.