diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index deaf92e..8fd5042 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -82,6 +82,7 @@ def test_ppo(args=get_args()):
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
+            torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Normal
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 3342707..2d2a484 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -42,7 +42,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
-    parser.add_argument('--gae-lambda', type=float, default=1)
+    parser.add_argument('--gae-lambda', type=float, default=0.8)
    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument('--dual-clip', type=float, default=None)
     parser.add_argument('--value-clip', type=bool, default=True)
@@ -75,6 +75,7 @@ def test_ppo(args=get_args()):
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
+            torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical