fix atari examples (#206)
commit 380e9e911d
parent 8bb8ecba6e
@@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
     # replay buffer: `save_last_obs` and `stack_num` can be removed together
     # when you have enough RAM
     buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
-                          save_last_obs=True, stack_num=args.frames_stack)
+                          save_only_last_obs=True, stack_num=args.frames_stack)
     # collector
     train_collector = Collector(policy, train_envs, buffer)
     test_collector = Collector(policy, test_envs)
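Side note (not part of the commit): the hunk only switches the ReplayBuffer keyword from save_last_obs to save_only_last_obs. A minimal sketch of the two buffer configurations the comment above contrasts, with a literal size standing in for args.buffer_size and a 4-frame stack standing in for args.frames_stack:

from tianshou.data import ReplayBuffer

buffer_size = 100_000  # illustrative value in place of args.buffer_size

# Memory-saving form used by the example: keep only the last frame of each
# stacked observation and rebuild the 4-frame stack when sampling.
buffer = ReplayBuffer(buffer_size, ignore_obs_next=True,
                      save_only_last_obs=True, stack_num=4)

# Plain form the comment refers to, viable when RAM is not a constraint.
plain_buffer = ReplayBuffer(buffer_size, ignore_obs_next=True)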
@@ -100,6 +100,8 @@ def test_dqn(args=get_args()):
             return x >= env.spec.reward_threshold
         elif 'Pong' in args.task:
             return x >= 20
+        else:
+            return False
 
     def train_fn(x):
         # nature DQN setting, linear decay in the first 1M steps
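The two added lines give stop_fn an explicit fallback, so tasks without a reward threshold (and other than Pong) never trigger early stopping. A standalone sketch of the resulting rule, with the env/args lookups turned into hypothetical parameters so it runs on its own:

def stop_fn(mean_reward, reward_threshold=None, task='PongNoFrameskip-v4'):
    # reward_threshold mirrors env.spec.reward_threshold, task mirrors args.task
    if reward_threshold is not None:
        return mean_reward >= reward_threshold
    elif 'Pong' in task:
        return mean_reward >= 20
    else:
        return False  # added branch: never stop early otherwise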
@@ -107,10 +109,10 @@ def test_dqn(args=get_args()):
         if now <= 1e6:
             eps = args.eps_train - now / 1e6 * \
                 (args.eps_train - args.eps_train_final)
-            policy.set_eps(eps)
         else:
-            policy.set_eps(args.eps_train_final)
-        print("set eps =", policy.eps)
+            eps = args.eps_train_final
+        policy.set_eps(eps)
+        writer.add_scalar('train/eps', eps, global_step=now)
 
     def test_fn(x):
         policy.set_eps(args.eps_test)
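After this hunk, train_fn computes a single eps value in both branches, applies it once, and logs it to TensorBoard via writer.add_scalar instead of printing it. A standalone sketch of the linear schedule it implements (the step counter `now` is computed earlier in the example, outside this hunk; the default values below are only illustrative stand-ins for args.eps_train and args.eps_train_final):

def linear_eps(now, eps_train=1.0, eps_train_final=0.05):
    # nature DQN setting: decay linearly over the first 1M steps,
    # then hold the final value.
    if now <= 1e6:
        return eps_train - now / 1e6 * (eps_train - eps_train_final)
    return eps_train_final

# e.g. linear_eps(0) -> 1.0, linear_eps(5e5) -> 0.525, linear_eps(2e6) -> 0.05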