From 7ce62a6ad48667ee7ef00ccc9f0d116d4ef27c1c Mon Sep 17 00:00:00 2001 From: Gen Date: Sat, 29 Apr 2023 06:43:22 +0200 Subject: [PATCH] actor critic share head bug for example code without sharing head - unify code style (#860) --- examples/mujoco/mujoco_a2c.py | 8 +++++--- examples/mujoco/mujoco_ppo.py | 10 +++++----- tianshou/env/worker/base.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 79cfebb..14cbc72 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import A2CPolicy from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -101,8 +101,10 @@ def test_a2c(args=get_args()): device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) + torch.nn.init.constant_(actor.sigma_param, -0.5) - for m in list(actor.modules()) + list(critic.modules()): + for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) @@ -116,7 +118,7 @@ def test_a2c(args=get_args()): m.weight.data.copy_(0.01 * m.weight.data) optim = torch.optim.RMSprop( - list(actor.parameters()) + list(critic.parameters()), + actor_critic.parameters(), lr=args.lr, eps=1e-5, alpha=0.99, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index cb79a63..354c9ef 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -106,8 +106,10 @@ def test_ppo(args=get_args()): device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) + torch.nn.init.constant_(actor.sigma_param, -0.5) - for m in list(actor.modules()) + list(critic.modules()): + for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) @@ -120,9 +122,7 @@ def test_ppo(args=get_args()): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) lr_scheduler = None if args.lr_decay: diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 773d56b..794f85a 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -47,7 +47,7 @@ class EnvWorker(ABC): def recv( self - ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]: # noqa:E125 + ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125 """Receive result from low-level worker. If the last "send" function sends a NULL action, it only returns a