final fix for actor_critic shared head parameters (#458)
This commit is contained in:
parent
22d7bf38c8
commit
5df64800f4
@ -13,7 +13,7 @@ from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
@ -84,14 +84,13 @@ def test_ppo(args=get_args()):
|
||||
Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
# orthogonal initialization
|
||||
for m in set(actor.modules()).union(critic.modules()):
|
||||
for m in actor_critic.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr
|
||||
)
|
||||
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
|
@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import A2CPolicy, ImitationPolicy
|
||||
from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@ -74,9 +74,7 @@ def test_a2c_with_il(args=get_args()):
|
||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||
critic = Critic(net, device=args.device).to(args.device)
|
||||
optim = torch.optim.Adam(
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr
|
||||
)
|
||||
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = A2CPolicy(
|
||||
actor,
|
||||
|
@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@ -73,14 +73,13 @@ def test_ppo(args=get_args()):
|
||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||
critic = Critic(net, device=args.device).to(args.device)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
# orthogonal initialization
|
||||
for m in set(actor.modules()).union(critic.modules()):
|
||||
for m in actor_critic.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr
|
||||
)
|
||||
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = PPOPolicy(
|
||||
actor,
|
||||
|
@ -338,7 +338,7 @@ class AsyncCollector(Collector):
|
||||
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
||||
exploration_noise: bool = False,
|
||||
) -> None:
|
||||
assert env.is_async
|
||||
# assert env.is_async
|
||||
super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
|
||||
|
||||
def reset_env(self) -> None:
|
||||
@ -452,7 +452,10 @@ class AsyncCollector(Collector):
|
||||
obs_next, rew, done, info = result
|
||||
|
||||
# change self.data here because ready_env_ids has changed
|
||||
ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
try:
|
||||
ready_env_ids = info["env_id"]
|
||||
except Exception:
|
||||
ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
self.data = whole_data[ready_env_ids]
|
||||
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
|
||||
|
@ -7,6 +7,7 @@ from torch import nn
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
|
||||
|
||||
class A2CPolicy(PGPolicy):
|
||||
@ -70,6 +71,7 @@ class A2CPolicy(PGPolicy):
|
||||
self._weight_ent = ent_coef
|
||||
self._grad_norm = max_grad_norm
|
||||
self._batch = max_batchsize
|
||||
self._actor_critic = ActorCritic(self.actor, self.critic)
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
|
||||
@ -136,8 +138,7 @@ class A2CPolicy(PGPolicy):
|
||||
loss.backward()
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
set(self.actor.parameters()).union(self.critic.parameters()),
|
||||
max_norm=self._grad_norm
|
||||
self._actor_critic.parameters(), max_norm=self._grad_norm
|
||||
)
|
||||
self.optim.step()
|
||||
actor_losses.append(actor_loss.item())
|
||||
|
@ -140,8 +140,7 @@ class PPOPolicy(A2CPolicy):
|
||||
loss.backward()
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
set(self.actor.parameters()).union(self.critic.parameters()),
|
||||
max_norm=self._grad_norm
|
||||
self._actor_critic.parameters(), max_norm=self._grad_norm
|
||||
)
|
||||
self.optim.step()
|
||||
clip_losses.append(clip_loss.item())
|
||||
|
@ -262,3 +262,19 @@ class Recurrent(nn.Module):
|
||||
s = self.fc2(s[:, -1])
|
||||
# please ensure the first dim is batch size: [bsz, len, ...]
|
||||
return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
|
||||
"""An actor-critic network for parsing parameters.
|
||||
|
||||
Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid
|
||||
issue #449.
|
||||
|
||||
:param nn.Module actor: the actor network.
|
||||
:param nn.Module critic: the critic network.
|
||||
"""
|
||||
|
||||
def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
|
Loading…
x
Reference in New Issue
Block a user