final fix for actor_critic shared head parameters (#458)

commit 5df64800f4 (parent 22d7bf38c8)
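Background for the diff below: when the actor and the critic share a preprocessing head (the same `Net` instance), the old ways of collecting their parameters are both flawed. `list(actor.parameters()) + list(critic.parameters())` keeps the shared head twice, while `set(actor.parameters()).union(critic.parameters())` de-duplicates but iterates in no fixed order, so the parameter order handed to the optimizer can change between runs; the docstring of the new `ActorCritic` wrapper points to issue #449 for details. A minimal illustrative sketch (not part of the commit; the toy `head`/`actor`/`critic` modules below are invented for demonstration):

# Illustrative sketch only: why neither list+list nor set.union works well
# when actor and critic share a head.
import torch
from torch import nn

head = nn.Linear(4, 8)                        # stands in for the shared Net
actor = nn.Sequential(head, nn.Linear(8, 2))
critic = nn.Sequential(head, nn.Linear(8, 1))

# list + list keeps the shared head twice, so the optimizer sees duplicates:
params_list = list(actor.parameters()) + list(critic.parameters())
print(len(params_list))                       # 8 tensors, only 6 distinct

# set.union de-duplicates, but yields parameters in no reproducible order:
params_set = set(actor.parameters()).union(critic.parameters())
print(len(params_set))                        # 6 tensors, unordered

# A single wrapper module (what the new ActorCritic class does) both
# de-duplicates shared tensors and keeps a stable registration order:
class ActorCritic(nn.Module):
    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic

ac = ActorCritic(actor, critic)
print(len(list(ac.parameters())))             # 6 tensors, deterministic order
optim = torch.optim.Adam(ac.parameters(), lr=1e-3)
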
@@ -13,7 +13,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 
 
@@ -84,14 +84,13 @@ def test_ppo(args=get_args()):
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
     ).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
 
     # replace DiagGuassian with Independent(Normal) which is equivalent
     # pass *logits to be consistent with policy.forward
|
@@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
 
 
@@ -74,9 +74,7 @@ def test_a2c_with_il(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor,
|
@@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
 
 
@@ -73,14 +73,13 @@ def test_ppo(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor,
|
@@ -338,7 +338,7 @@ class AsyncCollector(Collector):
         preprocess_fn: Optional[Callable[..., Batch]] = None,
         exploration_noise: bool = False,
     ) -> None:
-        assert env.is_async
+        # assert env.is_async
         super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
 
     def reset_env(self) -> None:
@@ -452,6 +452,9 @@ class AsyncCollector(Collector):
             obs_next, rew, done, info = result
 
             # change self.data here because ready_env_ids has changed
-            ready_env_ids = np.array([i["env_id"] for i in info])
+            try:
+                ready_env_ids = info["env_id"]
+            except Exception:
+                ready_env_ids = np.array([i["env_id"] for i in info])
             self.data = whole_data[ready_env_ids]
 
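The two AsyncCollector hunks above are smaller changes in the same commit: commenting out `assert env.is_async` drops the hard requirement that the wrapped vector env be asynchronous, and the new `try`/`except` accepts `info` either as a mapping that already carries an `env_id` array or as the older per-environment list of dicts. A short illustrative sketch of the two layouts the fallback handles (invented values only):

# Illustrative only: the two info layouts handled by the try/except above.
import numpy as np

info_as_mapping = {"env_id": np.array([0, 2, 3])}             # key access works
ready_env_ids = info_as_mapping["env_id"]

info_as_list = [{"env_id": 0}, {"env_id": 2}, {"env_id": 3}]  # per-env dicts
ready_env_ids = np.array([i["env_id"] for i in info_as_list])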
|
@@ -7,6 +7,7 @@ from torch import nn
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import PGPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class A2CPolicy(PGPolicy):
@@ -70,6 +71,7 @@ class A2CPolicy(PGPolicy):
         self._weight_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = max_batchsize
+        self._actor_critic = ActorCritic(self.actor, self.critic)
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
@@ -136,8 +138,7 @@ class A2CPolicy(PGPolicy):
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 actor_losses.append(actor_loss.item())
|
@@ -140,8 +140,7 @@ class PPOPolicy(A2CPolicy):
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 clip_losses.append(clip_loss.item())
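The A2CPolicy and PPOPolicy hunks above apply the same idea to gradient clipping: the norm is computed over the stored wrapper's parameters instead of a freshly built set. A small self-contained sketch of that clipping pattern (the toy module and loss below are invented for illustration):

# Illustrative only: gradient clipping over a single parameter collection.
import torch
from torch import nn

ac = nn.Linear(4, 2)                # stands in for the ActorCritic wrapper
optim = torch.optim.Adam(ac.parameters(), lr=1e-3)
max_grad_norm = 0.5                 # plays the role of the max_grad_norm option

loss = ac(torch.randn(8, 4)).pow(2).mean()
optim.zero_grad()
loss.backward()
if max_grad_norm:  # clip large gradient
    nn.utils.clip_grad_norm_(ac.parameters(), max_norm=max_grad_norm)
optim.step()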
|
@@ -262,3 +262,19 @@ class Recurrent(nn.Module):
         s = self.fc2(s[:, -1])
         # please ensure the first dim is batch size: [bsz, len, ...]
         return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
+
+
+class ActorCritic(nn.Module):
+    """An actor-critic network for parsing parameters.
+
+    Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid
+    issue #449.
+
+    :param nn.Module actor: the actor network.
+    :param nn.Module critic: the critic network.
+    """
+
+    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
+        super().__init__()
+        self.actor = actor
+        self.critic = critic
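With `ActorCritic` now defined in `tianshou.utils.net.common`, user code can mirror what the updated test scripts above do. A minimal usage sketch following those tests; the shape and learning-rate values below are invented example settings, not part of the commit:

# Usage sketch mirroring the updated tests (example values only).
import torch
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

state_shape, action_shape, lr = (4,), 2, 3e-4
net = Net(state_shape, hidden_sizes=[64, 64])   # shared preprocessing head
actor = Actor(net, action_shape)
critic = Critic(net)
actor_critic = ActorCritic(actor, critic)
# orthogonal initialization over the whole shared network
for m in actor_critic.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.orthogonal_(m.weight)
        torch.nn.init.zeros_(m.bias)
# one optimizer over de-duplicated, consistently ordered parameters
optim = torch.optim.Adam(actor_critic.parameters(), lr=lr)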