Add orthogonal weight initialization for the PPO actor and critic in the test scripts
This commit is contained in:
parent
0eef0ca198
commit
3271c92609
@ -78,6 +78,10 @@ def test_ppo(args=get_args()):
|
|||||||
critic = Critic(
|
critic = Critic(
|
||||||
args.layer_num, args.state_shape, device=args.device
|
args.layer_num, args.state_shape, device=args.device
|
||||||
).to(args.device)
|
).to(args.device)
|
||||||
|
# orthogonal initialization of the weight of every torch.nn.Linear module in the actor and critic (biases are left unchanged)
|
||||||
|
for m in list(actor.modules()) + list(critic.modules()):
|
||||||
|
if isinstance(m, torch.nn.Linear):
|
||||||
|
torch.nn.init.orthogonal_(m.weight)
|
||||||
optim = torch.optim.Adam(list(
|
optim = torch.optim.Adam(list(
|
||||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Normal
|
dist = torch.distributions.Normal
|
||||||
|
@ -71,6 +71,10 @@ def test_ppo(args=get_args()):
|
|||||||
net = Net(args.layer_num, args.state_shape, device=args.device)
|
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||||
actor = Actor(net, args.action_shape).to(args.device)
|
actor = Actor(net, args.action_shape).to(args.device)
|
||||||
critic = Critic(net).to(args.device)
|
critic = Critic(net).to(args.device)
|
||||||
|
# orthogonal initialization of the weight of every torch.nn.Linear module in the actor and critic (biases are left unchanged)
|
||||||
|
for m in list(actor.modules()) + list(critic.modules()):
|
||||||
|
if isinstance(m, torch.nn.Linear):
|
||||||
|
torch.nn.init.orthogonal_(m.weight)
|
||||||
optim = torch.optim.Adam(list(
|
optim = torch.optim.Adam(list(
|
||||||
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||||
dist = torch.distributions.Categorical
|
dist = torch.distributions.Categorical
|
||||||
|
Loading…
x
Reference in New Issue
Block a user