Add warnings for duplicate usage of action-bounded actor and action scaling method (#850)
- Fix the bug discussed in #844 in `test_ppo.py`.
- Add a warning in `ActorProb` when both `max_action` and `unbounded=True` are passed at model initialization.
- Add warnings in `PGPolicy` and `DDPGPolicy` when they detect duplicate usage of an action-bounded actor and the action scaling method.
parent e7c2c3711e
commit 1423eeb3b2
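Why these warnings matter: a tanh-bounded actor already squashes its raw output into [-max_action, max_action], and the policy's action scaling then maps [-1, 1] onto the environment bounds, so combining the two applies the bound twice. A minimal plain-PyTorch sketch (illustrative values; the affine scaling formula here is an assumption about how action scaling typically works, not code from this commit):

```python
import torch

max_action = 2.0
raw = torch.tensor([3.0])               # unbounded network output
bounded = max_action * torch.tanh(raw)  # ~1.99, already inside [-2, 2]

# Action scaling assumes its input lies in [-1, 1] and maps it onto the
# environment bounds, so a bounded actor gets "bounded" a second time:
low, high = -2.0, 2.0
scaled = low + (high - low) * (bounded + 1.0) / 2.0
print(scaled.item())  # ~3.98, outside [-2, 2]: the action mapping is distorted
```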
```diff
@@ -94,13 +94,8 @@ def test_sac_bipedal(args=get_args()):
     # model
     net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
```
```diff
@@ -67,13 +67,8 @@ def test_sac(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
```
```diff
@@ -119,13 +119,8 @@ def test_gail(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net_a,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net_a, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     net_c = Net(
         args.state_shape,
         hidden_sizes=args.hidden_sizes,
```
```diff
@@ -91,7 +91,6 @@ def test_a2c(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
```
```diff
@@ -93,7 +93,6 @@ def test_npg(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
```
```diff
@@ -96,7 +96,6 @@ def test_ppo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
```
```diff
@@ -86,7 +86,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
```
```diff
@@ -88,7 +88,6 @@ def test_reinforce(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
```
```diff
@@ -81,7 +81,6 @@ def test_sac(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
```
```diff
@@ -96,7 +96,6 @@ def test_trpo(args=get_args()):
     actor = ActorProb(
         net_a,
         args.action_shape,
-        max_action=args.max_action,
         unbounded=True,
         device=args.device,
     ).to(args.device)
```
```diff
@@ -108,7 +108,6 @@ def test_cql():
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
```
```diff
@@ -82,13 +82,8 @@ def test_npg(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
```
```diff
@@ -81,9 +81,8 @@ def test_ppo(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape, max_action=args.max_action, device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
```
```diff
@@ -82,7 +82,6 @@ def test_redq(args=get_args()):
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True
```
```diff
@@ -79,13 +79,8 @@ def test_sac_with_il(args=get_args()):
     torch.manual_seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        device=args.device,
-        unbounded=True
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, device=args.device,
+                      unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net_c1 = Net(
         args.state_shape,
```
```diff
@@ -85,13 +85,8 @@ def test_trpo(args=get_args()):
         activation=nn.Tanh,
         device=args.device
     )
-    actor = ActorProb(
-        net,
-        args.action_shape,
-        max_action=args.max_action,
-        unbounded=True,
-        device=args.device
-    ).to(args.device)
+    actor = ActorProb(net, args.action_shape, unbounded=True,
+                      device=args.device).to(args.device)
     critic = Critic(
         Net(
             args.state_shape,
```
```diff
@@ -93,7 +93,6 @@ def gather_data():
     actor = ActorProb(
         net,
         args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
     ).to(args.device)
```
```diff
@@ -108,7 +108,6 @@ def test_cql(args=get_args()):
     actor = ActorProb(
         net_a,
         action_shape=args.action_shape,
-        max_action=args.max_action,
         device=args.device,
         unbounded=True,
         conditioned_sigma=True,
```
```diff
@@ -64,6 +64,19 @@ class DDPGPolicy(BasePolicy):
         assert action_bound_method != "tanh", "tanh mapping is not supported" \
             "in policies where action is used as input of critic , because" \
             "raw action in range (-inf, inf) will cause instability in training"
+        try:
+            if actor is not None and action_scaling and \
+                    not np.isclose(actor.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended "
+                    "to deal with unbounded model action space, but find actor "
+                    f"model bound action space with max_action={actor.max_action}. "
+                    "Consider using unbounded=True option of the actor model, "
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
```
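For readability, here is the same guard as a standalone sketch (the helper name is hypothetical and the warning text paraphrased). The try/except mirrors the code above: arbitrary actor modules need not define a `max_action` attribute, so the attribute access may fail.

```python
import warnings
from typing import Optional

import numpy as np
import torch


def warn_if_double_bounded(actor: Optional[torch.nn.Module],
                           action_scaling: bool) -> None:
    """Warn when a bounded actor (max_action != 1) is also scaled."""
    try:
        if actor is not None and action_scaling and \
                not np.isclose(actor.max_action, 1.):
            warnings.warn(
                "action_scaling is intended for unbounded actors, but this "
                f"actor bounds its output with max_action={actor.max_action}; "
                "consider constructing it with unbounded=True."
            )
    except Exception:
        # Actors without a max_action attribute: nothing to check.
        pass
```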
```diff
@@ -53,6 +53,18 @@ class PGPolicy(BasePolicy):
             **kwargs
         )
         self.actor = model
+        try:
+            if action_scaling and not np.isclose(model.max_action, 1.):  # type: ignore
+                import warnings
+                warnings.warn(
+                    "action_scaling and action_bound_method are only intended "
+                    "to deal with unbounded model action space, but find actor "
+                    f"model bound action space with max_action={model.max_action}. "
+                    "Consider using unbounded=True option of the actor model, "
+                    "or set action_scaling to False and action_bound_method to \"\"."
+                )
+        except Exception:
+            pass
         self.optim = optim
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
```
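A hypothetical repro of the new `PGPolicy` warning (shapes and hyperparameters are illustrative): constructing the policy with a bounded actor while `action_scaling` is left at its default of True should now emit it.

```python
import torch
from torch.distributions import Independent, Normal

from tianshou.policy import PGPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

net = Net(state_shape=4, hidden_sizes=(64,))
actor = ActorProb(net, action_shape=2, max_action=2.0)  # bounded actor
optim = torch.optim.Adam(actor.parameters(), lr=1e-3)


def dist(*logits):
    return Independent(Normal(*logits), 1)


policy = PGPolicy(actor, optim, dist, action_scaling=True)  # warns here
```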
```diff
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

 import numpy as np
```
```diff
@@ -54,7 +55,7 @@ class Actor(nn.Module):
             hidden_sizes,
             device=self.device
         )
-        self._max = max_action
+        self.max_action = max_action

     def forward(
         self,
```
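The `_max` to `max_action` rename makes the bound a public attribute, which is what the `PGPolicy`/`DDPGPolicy` checks above read. A small illustration (shapes are illustrative):

```python
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor

net = Net(state_shape=4, hidden_sizes=(64,))
actor = Actor(net, action_shape=2, max_action=2.0)
print(actor.max_action)  # 2.0, now public (previously the private actor._max)
```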
```diff
@@ -64,7 +65,7 @@ class Actor(nn.Module):
     ) -> Tuple[torch.Tensor, Any]:
         """Mapping: obs -> logits -> action."""
         logits, hidden = self.preprocess(obs, state)
-        logits = self._max * torch.tanh(self.last(logits))
+        logits = self.max_action * torch.tanh(self.last(logits))
         return logits, hidden


```
```diff
@@ -178,6 +179,11 @@ class ActorProb(nn.Module):
         preprocess_net_output_dim: Optional[int] = None,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
```
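A quick sketch of the new `ActorProb` behavior (illustrative shapes): passing a non-default `max_action` together with `unbounded=True` now warns, and the bound is reset to 1.0.

```python
import warnings

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

net = Net(state_shape=4, hidden_sizes=(64,))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    actor = ActorProb(net, action_shape=2, max_action=2.0, unbounded=True)

print(caught[0].message)  # "Note that max_action input will be discarded..."
print(actor.max_action)   # 1.0
```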
```diff
@@ -198,7 +204,7 @@ class ActorProb(nn.Module):
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded

     def forward(
```
```diff
@@ -211,7 +217,7 @@ class ActorProb(nn.Module):
         logits, hidden = self.preprocess(obs, state)
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
```
```diff
@@ -240,6 +246,11 @@ class RecurrentActorProb(nn.Module):
         conditioned_sigma: bool = False,
     ) -> None:
         super().__init__()
+        if unbounded and not np.isclose(max_action, 1.0):
+            warnings.warn(
+                "Note that max_action input will be discarded when unbounded is True."
+            )
+            max_action = 1.0
         self.device = device
         self.nn = nn.LSTM(
             input_size=int(np.prod(state_shape)),
```
```diff
@@ -254,7 +265,7 @@ class RecurrentActorProb(nn.Module):
             self.sigma = nn.Linear(hidden_layer_size, output_dim)
         else:
             self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
-        self._max = max_action
+        self.max_action = max_action
         self._unbounded = unbounded

     def forward(
```
```diff
@@ -289,7 +300,7 @@ class RecurrentActorProb(nn.Module):
         logits = obs[:, -1]
         mu = self.mu(logits)
         if not self._unbounded:
-            mu = self._max * torch.tanh(mu)
+            mu = self.max_action * torch.tanh(mu)
         if self._c_sigma:
             sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
         else:
```