diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index e94a917..deaf92e 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -78,6 +78,13 @@ def test_ppo(args=get_args()):
     critic = Critic(
         args.layer_num, args.state_shape, device=args.device
     ).to(args.device)
+    # Orthogonal initialization: give every Linear layer of the actor and
+    # the critic an orthogonal weight matrix and a zero bias.
+    for m in list(actor.modules()) + list(critic.modules()):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight)
+            if m.bias is not None:
+                torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Normal
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index f112b69..3342707 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -71,6 +71,13 @@ def test_ppo(args=get_args()):
     net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
+    # Orthogonal initialization: give every Linear layer of the actor and
+    # the critic an orthogonal weight matrix and a zero bias.
+    for m in list(actor.modules()) + list(critic.modules()):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight)
+            if m.bias is not None:
+                torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical