diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 8105964..8121d83 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -94,13 +94,8 @@ def test_sac_bipedal(args=get_args()):

     # model
     net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

     net_c1 = Net(
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index fb62d2a..c76ae75 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -67,13 +67,8 @@ def test_sac(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py
index c2f3dab..ec27d82 100644
--- a/examples/inverse/irl_gail.py
+++ b/examples/inverse/irl_gail.py
@@ -119,13 +119,8 @@ def test_gail(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     net_c = Net(
         args.state_shape,
         hidden_sizes=args.hidden_sizes,
diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index 19d5961..79cfebb 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -91,7 +91,6 @@ def test_a2c(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 1bd2d19..e65c79a 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -93,7 +93,6 @@ def test_npg(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py
index 14c5a42..cb79a63 100755
--- a/examples/mujoco/mujoco_ppo.py
+++ b/examples/mujoco/mujoco_ppo.py
@@ -96,7 +96,6 @@ def test_ppo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py
index 312d1d6..695c59c 100755
--- a/examples/mujoco/mujoco_redq.py
+++ b/examples/mujoco/mujoco_redq.py
@@ -86,7 +86,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index 40f0a1f..974e233 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -88,7 +88,6 @@ def test_reinforce(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 0450c58..a4b05fe 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -81,7 +81,6 @@ def test_sac(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index 7754c0c..0576b72 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -96,7 +96,6 @@ def test_trpo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py
index 23cd215..3b2ba0a 100644
--- a/examples/offline/d4rl_cql.py
+++ b/examples/offline/d4rl_cql.py
@@ -108,7 +108,6 @@ def test_cql():
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index aab35fe..ddd47c4 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -82,13 +82,8 @@ def test_npg(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 323b85d..695ac49 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -81,9 +81,8 @@ def test_ppo(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape, max_action=args.max_action, device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py
index b6bd526..1226aa7 100644
--- a/test/continuous/test_redq.py
+++ b/test/continuous/test_redq.py
@@ -82,7 +82,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 599c470..71f421b 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -79,13 +79,8 @@ def test_sac_with_il(args=get_args()):
     torch.manual_seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 9bcc014..d007f0c 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -85,13 +85,8 @@ def test_trpo(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py
index b124aad..98f59dc 100644
--- a/test/offline/gather_pendulum_data.py
+++ b/test/offline/gather_pendulum_data.py
@@ -93,7 +93,6 @@ def gather_data():
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
     ).to(args.device)
diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py
index b5dfafc..3356863 100644
--- a/test/offline/test_cql.py
+++ b/test/offline/test_cql.py
@@ -108,7 +108,6 @@ def test_cql(args=get_args()):
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 0637d05..8bad342 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -64,6 +64,19 @@ class DDPGPolicy(BasePolicy):
         assert action_bound_method != "tanh", "tanh mapping is not supported" \
             "in policies where action is used as input of critic , because" \
             "raw action in range (-inf, inf) will cause instability in training"
+        try:
+            if actor is not None and action_scaling and \
+                    not np.isclose(actor.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended to deal"
+                    "with unbounded model action space, but find actor model bound"
+                    f"action space with max_action={actor.max_action}."
+                    "Consider using unbounded=True option of the actor model,"
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 4f0b1b1..a394940 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -53,6 +53,18 @@ class PGPolicy(BasePolicy):
             **kwargs
         )
         self.actor = model
+        try:
+            if action_scaling and not np.isclose(model.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended"
+                    "to deal with unbounded model action space, but find actor model"
+                    f"bound action space with max_action={model.max_action}."
+                    "Consider using unbounded=True option of the actor model,"
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         self.optim = optim
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index fb75e33..8d91480 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
@@ -54,7 +55,7 @@ class Actor(nn.Module):
             hidden_sizes,
             device=self.device
         )
-        self._max = max_action
+        self.max_action = max_action

     def forward(
         self,
@@ -64,7 +65,7 @@
     ) -> Tuple[torch.Tensor, Any]:
         """Mapping: obs -> logits -> action."""
         logits, hidden = self.preprocess(obs, state)
-        logits = self._max * torch.tanh(self.last(logits))
+        logits = self.max_action * torch.tanh(self.last(logits))
         return logits, hidden


@@ -178,6 +179,11 @@
         preprocess_net_output_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
@@ -198,7 +204,7 @@
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded

     def forward(
@@ -211,7 +217,7 @@
         logits, hidden = self.preprocess(obs, state)
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
@@ -240,6 +246,11 @@
         conditioned_sigma: bool = False,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.device = device
         self.nn = nn.LSTM(
             input_size=int(np.prod(state_shape)),
@@ -254,7 +265,7 @@
             self.sigma = nn.Linear(hidden_layer_size, output_dim)
         else:
             self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded

     def forward(
@@ -289,7 +300,7 @@
         logits = obs[:, -1]
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
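Taken together, the patch enforces one usage rule: when ActorProb (or RecurrentActorProb) is built with unbounded=True, max_action should not be passed, and the bound is exposed through the renamed public max_action attribute. Below is a minimal sketch of that pattern, not part of the patch; the state/action shapes and hidden sizes are hypothetical (a Pendulum-like task), while the imports and constructor arguments are the tianshou APIs touched above.

# Sketch only: illustrates the ActorProb usage adopted in the examples above.
import warnings

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

state_shape, action_shape = (3,), (1,)  # hypothetical shapes
net = Net(state_shape, hidden_sizes=(64, 64))

# New style: keep the mean unbounded and let the policy's action_scaling /
# action_bound_method map raw actions into the environment's range.
actor = ActorProb(net, action_shape, unbounded=True)
print(actor.max_action)  # 1.0 -- the attribute renamed from the private _max

# Passing a non-unit max_action together with unbounded=True is now flagged:
# the value is discarded and a warning is emitted (see the ActorProb change).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ActorProb(net, action_shape, max_action=2.0, unbounded=True)
assert any("max_action" in str(w.message) for w in caught)

The renamed max_action attribute is also what the new checks in DDPGPolicy and PGPolicy read when deciding whether to warn that action_scaling is being combined with an actor that already bounds its output.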