From 3271c926094fa9500d5bcf234622e3d8b69be874 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 16 May 2020 20:27:01 +0800
Subject: [PATCH] orthogonal init for ppo in test script

---
 test/continuous/test_ppo.py | 4 ++++
 test/discrete/test_ppo.py   | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index e94a917..deaf92e 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -78,6 +78,10 @@ def test_ppo(args=get_args()):
     critic = Critic(
         args.layer_num, args.state_shape, device=args.device
     ).to(args.device)
+    # orthogonal initialization
+    for m in list(actor.modules()) + list(critic.modules()):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Normal
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index f112b69..3342707 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -71,6 +71,10 @@ def test_ppo(args=get_args()):
     net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
+    # orthogonal initialization
+    for m in list(actor.modules()) + list(critic.modules()):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical