orthogonal init for ppo in test script

This commit is contained in:
Trinkle23897 2020-05-16 20:27:01 +08:00
parent 0eef0ca198
commit 3271c92609
2 changed files with 8 additions and 0 deletions

View File

@ -78,6 +78,10 @@ def test_ppo(args=get_args()):
critic = Critic( critic = Critic(
args.layer_num, args.state_shape, device=args.device args.layer_num, args.state_shape, device=args.device
).to(args.device) ).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
optim = torch.optim.Adam(list( optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr) actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Normal dist = torch.distributions.Normal

View File

@ -71,6 +71,10 @@ def test_ppo(args=get_args()):
net = Net(args.layer_num, args.state_shape, device=args.device) net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(net, args.action_shape).to(args.device) actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device) critic = Critic(net).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
optim = torch.optim.Adam(list( optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr) actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical dist = torch.distributions.Categorical