Change the way in which deterministic evaluation is controlled:

* Remove flag `eval_mode` from Collector.collect
* Replace flag `is_eval` in BasePolicy with `is_within_training_step` (negating usages)
  and set it appropriately in BaseTrainer
Dominik Jain 2024-05-02 18:31:03 +02:00
parent 18f236167f
commit ca69e79b4a
51 changed files with 126 additions and 241 deletions
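
In short: collectors no longer take a per-call evaluation flag, and the training/evaluation distinction now lives on the policy, managed by the trainer. A minimal before/after sketch of the calling convention (the collector variable and argument values are illustrative, not taken from this diff):

    # Before this commit: callers requested evaluation behavior per call.
    result = test_collector.collect(n_episode=10, render=0.0, eval_mode=True)

    # After this commit: collect() always runs the policy in eval mode, and
    # BaseTrainer sets policy.is_within_training_step = True around training
    # steps, so stochastic sampling still happens where it is needed.
    result = test_collector.collect(n_episode=10, render=0.0)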

View File

@@ -174,18 +174,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -216,18 +216,14 @@ def main(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -187,18 +187,14 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -184,18 +184,14 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -240,18 +240,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -178,18 +178,14 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -214,18 +214,14 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
             beta=args.beta,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -227,18 +227,14 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -147,11 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -165,11 +165,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(collector_stats)

View File

@@ -209,11 +209,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -144,11 +144,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -155,11 +155,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -266,11 +266,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -240,11 +240,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     collector_stats.pprint_asdict()

View File

@@ -221,11 +221,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -170,11 +170,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -218,11 +218,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -226,11 +226,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -198,11 +198,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -198,11 +198,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -192,11 +192,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -190,11 +190,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -223,11 +223,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     test_envs.seed(args.seed)
     print("Testing agent ...")
     test_collector.reset()
-    result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
     result.pprint_asdict()
     if args.watch:

View File

@@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     test_envs.seed(args.seed)
     print("Testing agent ...")
     test_collector.reset()
-    result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
     result.pprint_asdict()
     if args.watch:

View File

@@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     test_envs.seed(args.seed)
     print("Testing agent ...")
     test_collector.reset()
-    result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
     result.pprint_asdict()
     if args.watch:

View File

@@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
     test_envs.seed(args.seed)
     print("Testing agent ...")
     test_collector.reset()
-    result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
     result.pprint_asdict()
     if args.watch:

View File

@@ -207,7 +207,7 @@ def test_bcq() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
     collector = Collector(policy, env)
-    collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+    collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -230,11 +230,7 @@ def test_bcq() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -345,7 +345,7 @@ def test_cql() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
     collector = Collector(policy, env)
-    collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+    collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -368,11 +368,7 @@ def test_cql() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -143,7 +143,7 @@ def test_il() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
     collector = Collector(policy, env)
-    collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+    collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -166,11 +166,7 @@ def test_il() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -192,7 +192,7 @@ def test_td3_bc() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
     collector = Collector(policy, env)
-    collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+    collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -218,11 +218,7 @@ def test_td3_bc() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

View File

@@ -180,18 +180,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -242,18 +242,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+        result = collector.collect(n_step=args.buffer_size)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
     else:
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            eval_mode=True,
-        )
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()
         if args.watch:

View File

@@ -64,7 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:
 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
-        policy.is_eval = True
+        policy.is_within_training_step = False
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]

View File

@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
+    collector_stats = collector.collect(n_step=args.buffer_size)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:

View File

@@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
         torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
     )
     collector = Collector(policy, env)
-    collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+    collector.collect(n_episode=1, render=1 / 35)
     # trainer
     result = OfflineTrainer(

View File

@@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     policy, _, _ = get_agents(args)
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()

View File

@@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     )
     policy, _, _ = get_agents(args)
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    collector_result = collector.collect(n_episode=1, render=args.render)
     collector_result.pprint_asdict()

View File

@@ -230,5 +230,5 @@ def watch(
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()

View File

@@ -260,17 +260,8 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass

-    def collect(
-        self,
-        n_step: int | None = None,
-        n_episode: int | None = None,
-        random: bool = False,
-        render: float | None = None,
-        no_grad: bool = True,
-        reset_before_collect: bool = False,
-        gym_reset_kwargs: dict[str, Any] | None = None,
-        eval_mode: bool = False,
-    ) -> CollectStats:
+    def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
+                no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
         """Collect a specified number of steps or episodes.

         To ensure an unbiased sampling result with the n_episode option, this function will

@@ -286,9 +277,6 @@ class BaseCollector(ABC):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
-        :param eval_mode: whether to collect data in evaluation mode. Will
-            set the policy to training mode otherwise.

         .. note::

             One and only one collection number specification is permitted, either

@@ -302,8 +290,7 @@ class BaseCollector(ABC):
         if reset_before_collect:
             self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

-        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
-        with policy_mode_context(self.policy):
+        with in_eval_mode(self.policy):  # safety precaution only
             return self._collect(
                 n_step=n_step,
                 n_episode=n_episode,
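
The in_eval_mode helper used above is not part of this diff; a plausible sketch of such a context manager, assuming it simply toggles the torch module's training mode and restores it afterwards:

    from collections.abc import Iterator
    from contextlib import contextmanager

    import torch.nn as nn

    @contextmanager
    def in_eval_mode(module: nn.Module) -> Iterator[None]:
        """Temporarily switch `module` to eval mode, restoring its previous mode on exit."""
        was_training = module.training
        module.eval()
        try:
            yield
        finally:
            module.train(was_training)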

View File

@@ -130,10 +130,7 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(
-                n_step=self.sampling_config.start_timesteps,
-                random=self.sampling_config.start_timesteps_random,
-            )
+            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
         return train_collector, test_collector

     def set_policy_wrapper_factory(

View File

@@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
     ) -> None:
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(

View File

@@ -226,8 +226,18 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
-        # whether the policy is in evaluation mode
-        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
+        self.is_within_training_step = False
+        """
+        flag indicating whether we are currently within a training step, which encompasses data collection
+        for training and the policy update (gradient steps).
+
+        It can be used, for example, to control whether a flag controlling deterministic evaluation should
+        indeed be applied, because within a training step, we typically always want to apply stochastic
+        evaluation (even if such a flag is enabled).
+
+        This flag should normally remain False and should be set to True only by the algorithm which performs
+        training steps.
+        """
         self._compile()

     @property
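
Taken together with the policy hunks below, the intended decision logic at action-selection time is (a paraphrase for illustration, not code from this diff):

    # deterministic_eval applies only *outside* a training step; within a
    # training step, actions are always sampled stochastically.
    if policy.deterministic_eval and not policy.is_within_training_step:
        act = dist.mode        # deterministic evaluation
    else:
        act = dist.sample()    # stochastic (always used during training steps)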

View File

@@ -107,7 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:

View File

@@ -197,7 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)

-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)

View File

@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()

View File

@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()

View File

@@ -3,6 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import Optional, Tuple

@@ -407,7 +408,18 @@ class BaseTrainer(ABC):
         return test_stat, stop_fn_flag

+    @contextmanager
+    def _is_within_training_step_enabled(self, is_within_training_step: bool):
+        old_value = self.policy.is_within_training_step
+        try:
+            self.policy.is_within_training_step = is_within_training_step
+            yield
+        finally:
+            self.policy.is_within_training_step = old_value
+
     def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
-        should_stop_training = False
-
-        if self.train_collector is not None:
+        with self._is_within_training_step_enabled(True):
+            should_stop_training = False
+
+            if self.train_collector is not None:

@@ -434,10 +446,7 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
+        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)

         self.env_step += collect_stats.n_collected_steps

@@ -467,6 +476,8 @@ class BaseTrainer(ABC):
         """
         should_stop_training = False

-        if (
-            collect_stats.n_collected_episodes > 0
-            and self.test_in_train
+        # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
+        with self._is_within_training_step_enabled(False):
+            if (
+                collect_stats.n_collected_episodes > 0
+                and self.test_in_train
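
Saving and restoring old_value makes the context manager nest correctly, which matters when test-in-train evaluation runs inside a training step. A self-contained demonstration of that behavior (the _Policy/_Trainer classes are stand-ins for illustration, not code from this diff):

    from collections.abc import Iterator
    from contextlib import contextmanager

    class _Policy:
        is_within_training_step = False

    class _Trainer:
        def __init__(self) -> None:
            self.policy = _Policy()

        @contextmanager
        def _is_within_training_step_enabled(self, value: bool) -> Iterator[None]:
            old_value = self.policy.is_within_training_step
            try:
                self.policy.is_within_training_step = value
                yield
            finally:
                self.policy.is_within_training_step = old_value

    trainer = _Trainer()
    with trainer._is_within_training_step_enabled(True):       # training step
        assert trainer.policy.is_within_training_step
        with trainer._is_within_training_step_enabled(False):  # test-in-train
            assert not trainer.policy.is_within_training_step
        assert trainer.policy.is_within_training_step          # restored on exit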

View File

@@ -27,7 +27,7 @@ def test_episode(
     collector.reset(reset_stats=False)
     if test_fn:
         test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode, eval_mode=True)
+    result = collector.collect(n_episode=n_episode)
     if reward_metric:  # TODO: move into collector
         rew = reward_metric(result.returns)
         result.returns = rew