fix atari examples (#206)

This commit is contained in:
n+e 2020-09-06 23:05:33 +08:00 committed by GitHub
parent 8bb8ecba6e
commit 380e9e911d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
save_last_obs=True, stack_num=args.frames_stack)
save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)
@ -100,6 +100,8 @@ def test_dqn(args=get_args()):
return x >= env.spec.reward_threshold
elif 'Pong' in args.task:
return x >= 20
else:
return False
def train_fn(x):
# nature DQN setting, linear decay in the first 1M steps
@ -107,10 +109,10 @@ def test_dqn(args=get_args()):
if now <= 1e6:
eps = args.eps_train - now / 1e6 * \
(args.eps_train - args.eps_train_final)
policy.set_eps(eps)
else:
policy.set_eps(args.eps_train_final)
print("set eps =", policy.eps)
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=now)
def test_fn(x):
policy.set_eps(args.eps_test)