orthogonal init for ppo in test script
This commit is contained in:
parent
0eef0ca198
commit
3271c92609
@ -78,6 +78,10 @@ def test_ppo(args=get_args()):
|
||||
critic = Critic(
|
||||
args.layer_num, args.state_shape, device=args.device
|
||||
).to(args.device)
|
||||
# orthogonal initialization
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
optim = torch.optim.Adam(list(
|
||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Normal
|
||||
|
@ -71,6 +71,10 @@ def test_ppo(args=get_args()):
|
||||
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||
actor = Actor(net, args.action_shape).to(args.device)
|
||||
critic = Critic(net).to(args.device)
|
||||
# orthogonal initialization
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
optim = torch.optim.Adam(list(
|
||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
|
Loading…
x
Reference in New Issue
Block a user