Fix actor-critic shared-head handling in the example code (the examples do not share a head) and unify code style (#860)
Parent: 1423eeb3b2
Commit: 7ce62a6ad4
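In short: the A2C and PPO example scripts previously gathered actor and critic modules and parameters by concatenating Python lists. This commit wraps the two networks in tianshou's ActorCritic container and iterates and optimizes over that combined module instead. These particular examples do not share a head between actor and critic, so the change is mainly a style unification, but the combined-module form is also the one that stays correct when layers are shared. A small return-type annotation cleanup in tianshou/env/worker/base.py rides along.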
A2C example:

@@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import A2CPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -101,8 +101,10 @@ def test_a2c(args=get_args()):
         device=args.device,
     )
     critic = Critic(net_c, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
+
     torch.nn.init.constant_(actor.sigma_param, -0.5)
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             # orthogonal initialization
             torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
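For context, a minimal sketch of why iterating over a combined module matters once a head is shared. This is not part of the commit; ActorCriticPair and the layer sizes are hypothetical stand-ins for tianshou.utils.net.common.ActorCritic, and only documented torch.nn.Module behavior is assumed: Module.modules() and Module.parameters() return each shared submodule or parameter only once, while concatenating the actor's and critic's lists returns it twice.

from torch import nn


class ActorCriticPair(nn.Module):
    """Minimal stand-in for tianshou.utils.net.common.ActorCritic."""

    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic


# Hypothetical shared head: both networks reuse the same first layer.
shared = nn.Linear(4, 16)
actor = nn.Sequential(shared, nn.Linear(16, 2))
critic = nn.Sequential(shared, nn.Linear(16, 1))
pair = ActorCriticPair(actor, critic)

# The combined module visits the shared layer once ...
assert sum(m is shared for m in pair.modules()) == 1
# ... while list concatenation visits it twice.
assert sum(m is shared for m in list(actor.modules()) + list(critic.modules())) == 2

# The same holds for parameters(): 6 unique tensors vs. 8 with duplicates.
assert len(list(pair.parameters())) == 6
assert len(list(actor.parameters()) + list(critic.parameters())) == 8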
@@ -116,7 +118,7 @@ def test_a2c(args=get_args()):
             m.weight.data.copy_(0.01 * m.weight.data)

     optim = torch.optim.RMSprop(
-        list(actor.parameters()) + list(critic.parameters()),
+        actor_critic.parameters(),
         lr=args.lr,
         eps=1e-5,
         alpha=0.99,
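The optimizer change is the same idea applied to torch.optim: actor_critic.parameters() hands RMSprop each parameter exactly once, whereas the old concatenated list would contain duplicates whenever a layer were shared, which PyTorch flags with a duplicate-parameter warning and which can effectively apply the update to those weights more than once per step. Since these examples keep the networks separate, both forms behave identically here; the new form is simply the safer, unified style.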
PPO example:

@@ -17,7 +17,7 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -106,8 +106,10 @@ def test_ppo(args=get_args()):
         device=args.device,
     )
     critic = Critic(net_c, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
+
     torch.nn.init.constant_(actor.sigma_param, -0.5)
-    for m in list(actor.modules()) + list(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             # orthogonal initialization
             torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
@@ -120,9 +122,7 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
             m.weight.data.copy_(0.01 * m.weight.data)

-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

     lr_scheduler = None
     if args.lr_decay:
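The PPO example gets the identical treatment; with only a single positional argument left, the multi-line Adam call also collapses to one line.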
tianshou/env/worker/base.py (2 changes, vendored):

@@ -47,7 +47,7 @@ class EnvWorker(ABC):

     def recv(
         self
-    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict], ]:  # noqa:E125
+    ) -> Union[gym_new_venv_step_type, Tuple[np.ndarray, dict]]:  # noqa:E125
         """Receive result from low-level worker.

         If the last "send" function sends a NULL action, it only returns a
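The EnvWorker.recv change only drops a stray trailing ", " inside the Union return annotation; the signature's meaning and the # noqa:E125 marker are unchanged.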