fix atari examples (#206)
This commit is contained in:
parent
8bb8ecba6e
commit
380e9e911d
@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
|
|||||||
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||||
# when you have enough RAM
|
# when you have enough RAM
|
||||||
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
|
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
|
||||||
save_last_obs=True, stack_num=args.frames_stack)
|
save_only_last_obs=True, stack_num=args.frames_stack)
|
||||||
# collector
|
# collector
|
||||||
train_collector = Collector(policy, train_envs, buffer)
|
train_collector = Collector(policy, train_envs, buffer)
|
||||||
test_collector = Collector(policy, test_envs)
|
test_collector = Collector(policy, test_envs)
|
||||||
@ -100,6 +100,8 @@ def test_dqn(args=get_args()):
|
|||||||
return x >= env.spec.reward_threshold
|
return x >= env.spec.reward_threshold
|
||||||
elif 'Pong' in args.task:
|
elif 'Pong' in args.task:
|
||||||
return x >= 20
|
return x >= 20
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def train_fn(x):
|
def train_fn(x):
|
||||||
# nature DQN setting, linear decay in the first 1M steps
|
# nature DQN setting, linear decay in the first 1M steps
|
||||||
@ -107,10 +109,10 @@ def test_dqn(args=get_args()):
|
|||||||
if now <= 1e6:
|
if now <= 1e6:
|
||||||
eps = args.eps_train - now / 1e6 * \
|
eps = args.eps_train - now / 1e6 * \
|
||||||
(args.eps_train - args.eps_train_final)
|
(args.eps_train - args.eps_train_final)
|
||||||
policy.set_eps(eps)
|
|
||||||
else:
|
else:
|
||||||
policy.set_eps(args.eps_train_final)
|
eps = args.eps_train_final
|
||||||
print("set eps =", policy.eps)
|
policy.set_eps(eps)
|
||||||
|
writer.add_scalar('train/eps', eps, global_step=now)
|
||||||
|
|
||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user