actor critic share head bug for example code without sharing head - unify code style (#860)
This commit is contained in:
parent
1423eeb3b2
commit
7ce62a6ad4
@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
@ -101,8 +101,10 @@ def test_a2c(args=get_args()):
|
||||
device=args.device,
|
||||
)
|
||||
critic = Critic(net_c, device=args.device).to(args.device)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
|
||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
for m in actor_critic.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
# orthogonal initialization
|
||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
||||
@ -116,7 +118,7 @@ def test_a2c(args=get_args()):
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.RMSprop(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
actor_critic.parameters(),
|
||||
lr=args.lr,
|
||||
eps=1e-5,
|
||||
alpha=0.99,
|
||||
|
@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
@ -106,8 +106,10 @@ def test_ppo(args=get_args()):
|
||||
device=args.device,
|
||||
)
|
||||
critic = Critic(net_c, device=args.device).to(args.device)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
|
||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
for m in actor_critic.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
# orthogonal initialization
|
||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
||||
@ -120,9 +122,7 @@ def test_ppo(args=get_args()):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr
|
||||
)
|
||||
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.lr_decay:
|
||||
|
2
tianshou/env/worker/base.py
vendored
2
tianshou/env/worker/base.py
vendored
@ -47,7 +47,7 @@ class EnvWorker(ABC):
|
||||
|
||||
def recv(
|
||||
self
|
||||
) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]: # noqa:E125
|
||||
) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]: # noqa:E125
|
||||
"""Receive result from low-level worker.
|
||||
|
||||
If the last "send" function sends a NULL action, it only returns a
|
||||
|
Loading…
x
Reference in New Issue
Block a user