Fix actor-critic shared-head bug pattern in example code (no head is shared there) and unify code style (#860)

Gen 2023-04-29 06:43:22 +02:00 committed by GitHub
parent 1423eeb3b2
commit 7ce62a6ad4
3 changed files with 11 additions and 9 deletions
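Background on the bug this pattern avoids: when an actor and a critic share a preprocessing head, building the optimizer from list(actor.parameters()) + list(critic.parameters()) puts every shared tensor into the parameter group twice, so each optimizer step updates the shared head twice. Wrapping both modules in tianshou's ActorCritic (an nn.Module holding both) avoids this, because Module.parameters() yields each parameter exactly once. A minimal sketch of the difference; the names shared and ActorCriticLike are illustrative stand-ins, not code from this commit:

import torch.nn as nn

shared = nn.Linear(4, 16)                         # hypothetical shared head
actor = nn.Sequential(shared, nn.Linear(16, 2))   # policy branch
critic = nn.Sequential(shared, nn.Linear(16, 1))  # value branch

# Concatenating parameter lists hands the shared head to the optimizer twice:
params = list(actor.parameters()) + list(critic.parameters())
assert len(params) == 8  # shared weight and bias are counted twice

class ActorCriticLike(nn.Module):
    """Stand-in for tianshou.utils.net.common.ActorCritic."""

    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic

# nn.Module.parameters() deduplicates shared tensors, so each parameter
# reaches the optimizer exactly once:
wrapper = ActorCriticLike(actor, critic)
assert len(list(wrapper.parameters())) == 6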


@@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import A2CPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -101,8 +101,10 @@ def test_a2c(args=get_args()):
         device=args.device,
     )
     critic = Critic(net_c, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     torch.nn.init.constant_(actor.sigma_param, -0.5)
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             # orthogonal initialization
             torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
@@ -116,7 +118,7 @@ def test_a2c(args=get_args()):
             m.weight.data.copy_(0.01 * m.weight.data)
     optim = torch.optim.RMSprop(
-        list(actor.parameters()) + list(critic.parameters()),
+        actor_critic.parameters(),
         lr=args.lr,
         eps=1e-5,
         alpha=0.99,


@@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -106,8 +106,10 @@ def test_ppo(args=get_args()):
         device=args.device,
     )
     critic = Critic(net_c, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     torch.nn.init.constant_(actor.sigma_param, -0.5)
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             # orthogonal initialization
             torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
@@ -120,9 +122,7 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
             m.weight.data.copy_(0.01 * m.weight.data)
-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
     lr_scheduler = None
     if args.lr_decay:

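Note that in both examples the critic is built on its own net_c rather than reusing the actor's preprocessing network, so no parameter is actually duplicated today; as the commit title says, routing the optimizer through actor_critic.parameters() here is a style unification, and it keeps the examples correct if they are later copied into a setup where the two networks do share a head.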

@@ -47,7 +47,7 @@ class EnvWorker(ABC):
     def recv(
         self
-    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]: # noqa:E125
+    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
         """Receive result from low-level worker.

         If the last "send" function sends a NULL action, it only returns a