final fix for actor_critic shared head parameters (#458)

Jiayi Weng 2021-10-04 11:19:07 -04:00 committed by GitHub
parent 22d7bf38c8
commit 5df64800f4
7 changed files with 35 additions and 20 deletions

View File

@@ -13,7 +13,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,14 +84,13 @@ def test_ppo(args=get_args()):
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
     ).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
     # replace DiagGuassian with Independent(Normal) which is equivalent
     # pass *logits to be consistent with policy.forward

View File

@@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
@@ -74,9 +74,7 @@ def test_a2c_with_il(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor,

View File

@@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
@@ -73,14 +73,13 @@ def test_ppo(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor,

View File

@@ -338,7 +338,7 @@ class AsyncCollector(Collector):
         preprocess_fn: Optional[Callable[..., Batch]] = None,
         exploration_noise: bool = False,
     ) -> None:
-        assert env.is_async
+        # assert env.is_async
        super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)

     def reset_env(self) -> None:
@@ -452,7 +452,10 @@ class AsyncCollector(Collector):
             obs_next, rew, done, info = result
             # change self.data here because ready_env_ids has changed
-            ready_env_ids = np.array([i["env_id"] for i in info])
+            try:
+                ready_env_ids = info["env_id"]
+            except Exception:
+                ready_env_ids = np.array([i["env_id"] for i in info])
             self.data = whole_data[ready_env_ids]
             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
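
The try/except added above covers the two shapes info can take after an async step: a batched, dict-like object that supports vectorized key lookup, and a plain sequence of per-env dicts. Below is a minimal sketch of the same fallback pattern; the data and the helper name extract_env_ids are illustrative, not part of the commit.

# Illustrative sketch, not part of the commit.
import numpy as np

def extract_env_ids(info):
    # Mirror the fallback in the diff: try vectorized access first,
    # then rebuild the array from the individual per-env dicts.
    try:
        return info["env_id"]
    except Exception:
        return np.array([i["env_id"] for i in info])

# Case 1: info behaves like a dict/Batch of arrays -> direct key lookup works.
assert list(extract_env_ids({"env_id": np.array([0, 2, 3])})) == [0, 2, 3]
# Case 2: info is a list of per-env dicts -> the lookup raises and we fall back.
assert list(extract_env_ids([{"env_id": 0}, {"env_id": 2}, {"env_id": 3}])) == [0, 2, 3]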

View File

@@ -7,6 +7,7 @@ from torch import nn
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import PGPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class A2CPolicy(PGPolicy):
@@ -70,6 +71,7 @@ class A2CPolicy(PGPolicy):
         self._weight_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = max_batchsize
+        self._actor_critic = ActorCritic(self.actor, self.critic)
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
@@ -136,8 +138,7 @@
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 actor_losses.append(actor_loss.item())

View File

@@ -140,8 +140,7 @@ class PPOPolicy(A2CPolicy):
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 clip_losses.append(clip_loss.item())

View File

@@ -262,3 +262,19 @@ class Recurrent(nn.Module):
         s = self.fc2(s[:, -1])
         # please ensure the first dim is batch size: [bsz, len, ...]
         return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
+
+
+class ActorCritic(nn.Module):
+    """An actor-critic network for parsing parameters.
+
+    Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid
+    issue #449.
+
+    :param nn.Module actor: the actor network.
+    :param nn.Module critic: the critic network.
+    """
+
+    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
+        super().__init__()
+        self.actor = actor
+        self.critic = critic
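
The docstring above only hints at the motivation. A minimal sketch of the difference, assuming a toy shared head (the modules below are illustrative, not part of the commit): when actor and critic share a preprocessing network, parameters() on the wrapper visits each tensor once and in a stable order, whereas list + list hands the shared parameters to the optimizer twice and set.union gives no deterministic ordering.

# Illustrative sketch, not part of the commit.
import torch
from torch import nn
from tianshou.utils.net.common import ActorCritic

shared = nn.Linear(4, 16)                         # shared preprocessing head
actor = nn.Sequential(shared, nn.Linear(16, 2))   # policy head on top of it
critic = nn.Sequential(shared, nn.Linear(16, 1))  # value head on top of it
actor_critic = ActorCritic(actor, critic)

# The wrapper deduplicates the shared tensors:
# 2 shared + 2 actor-only + 2 critic-only parameters.
assert len(list(actor_critic.parameters())) == 6
# list + list counts the shared layer twice; set.union removes duplicates
# but iterates in no guaranteed order between runs.
assert len(list(actor.parameters()) + list(critic.parameters())) == 8

optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)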