Change the way in which deterministic evaluation is controlled:
* Remove flag `eval_mode` from Collector.collect
* Replace flag `is_eval` in BasePolicy with `is_within_training_step` (negating usages) and set it appropriately in BaseTrainer
parent 18f236167f
commit ca69e79b4a
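As a rough usage sketch (illustrative only, not part of the diff; it assumes an already constructed stochastic `policy` with `deterministic_eval=True` and a `test_envs` vector environment), evaluation collection now simply omits the removed flag, while the trainer is responsible for toggling `is_within_training_step` around its training steps:

    # Hypothetical snippet illustrating the new control flow (not from this commit).
    from tianshou.data import Collector

    test_collector = Collector(policy, test_envs)

    # Outside of a training step (the default), deterministic_eval takes effect.
    policy.is_within_training_step = False  # default value, shown only for emphasis
    result = test_collector.collect(n_episode=10, reset_before_collect=True)  # no eval_mode kwarg anymore
    result.pprint_asdict()

    # During training, BaseTrainer wraps each training step in
    # _is_within_training_step_enabled(True), so data collection stays stochastic
    # even when deterministic_eval is enabled.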
@@ -174,18 +174,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -216,18 +216,14 @@ def main(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -187,18 +187,14 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -184,18 +184,14 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -240,18 +240,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -178,18 +178,14 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -214,18 +214,14 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
                 beta=args.beta,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -227,18 +227,14 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -147,11 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -165,11 +165,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -209,11 +209,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -144,11 +144,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -155,11 +155,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -266,11 +266,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -240,11 +240,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     collector_stats.pprint_asdict()


@@ -221,11 +221,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -170,11 +170,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -218,11 +218,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -226,11 +226,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -198,11 +198,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -198,11 +198,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -192,11 +192,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -190,11 +190,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -223,11 +223,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
@@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
@@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
@@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
@@ -207,7 +207,7 @@ def test_bcq() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -230,11 +230,7 @@ def test_bcq() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -345,7 +345,7 @@ def test_cql() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -368,11 +368,7 @@ def test_cql() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -143,7 +143,7 @@ def test_il() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -166,11 +166,7 @@ def test_il() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -192,7 +192,7 @@ def test_td3_bc() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -218,11 +218,7 @@ def test_td3_bc() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


@@ -180,18 +180,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -242,18 +242,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
@@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
         envs.tracker = MetricTracker()
         try:
             # TODO: why on earth 10**18?
-            test_collector.collect(n_step=10**18)
+            test_collector.collect(n_step=10 ** 18)
         except StopIteration:
             envs.tracker.validate()

@@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
     for _ in range(3):
         envs.tracker = MetricTracker()
         try:
-            test_collector.collect(n_step=10**18)
+            test_collector.collect(n_step=10 ** 18)
         except StopIteration:
             envs.tracker.validate()
@@ -64,7 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:

 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
-        policy.is_eval = True
+        policy.is_within_training_step = False
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
+    collector_stats = collector.collect(n_step=args.buffer_size)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:
@@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     # trainer
     result = OfflineTrainer(
@@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
         policy, _, _ = get_agents(args)
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()
@@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     )
     policy, _, _ = get_agents(args)
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    collector_result = collector.collect(n_episode=1, render=args.render)
     collector_result.pprint_asdict()
@@ -230,5 +230,5 @@ def watch(
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()
@@ -260,17 +260,8 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass

-    def collect(
-        self,
-        n_step: int | None = None,
-        n_episode: int | None = None,
-        random: bool = False,
-        render: float | None = None,
-        no_grad: bool = True,
-        reset_before_collect: bool = False,
-        gym_reset_kwargs: dict[str, Any] | None = None,
-        eval_mode: bool = False,
-    ) -> CollectStats:
+    def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
+                no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
         """Collect a specified number of steps or episodes.

         To ensure an unbiased sampling result with the n_episode option, this function will
@@ -286,9 +277,6 @@ class BaseCollector(ABC):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
-        :param eval_mode: whether to collect data in evaluation mode. Will
-            set the policy to training mode otherwise.
-
         .. note::

             One and only one collection number specification is permitted, either
@@ -302,8 +290,7 @@ class BaseCollector(ABC):
         if reset_before_collect:
             self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

-        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
-        with policy_mode_context(self.policy):
+        with in_eval_mode(self.policy):  # safety precaution only
             return self._collect(
                 n_step=n_step,
                 n_episode=n_episode,
@@ -130,10 +130,7 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(
-                n_step=self.sampling_config.start_timesteps,
-                random=self.sampling_config.start_timesteps_random,
-            )
+            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
         return train_collector, test_collector

     def set_policy_wrapper_factory(
@@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
     ) -> None:
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
@@ -226,8 +226,18 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
-        # whether the policy is in evaluation mode
-        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
+        self.is_within_training_step = False
+        """
+        flag indicating whether we are currently within a training step, which encompasses data collection
+        for training and the policy update (gradient steps).
+
+        It can be used, for example, to control whether a flag controlling deterministic evaluation should
+        indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
+        (even if such a flag is enabled).
+
+        This flag should normally remain False and should be set to True only by the algorithm which performs
+        training steps.
+        """
         self._compile()

     @property
@@ -107,7 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
@@ -197,7 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
             # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
             dist = self.dist_fn(action_dist_input_BD)

-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -3,6 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import Optional, Tuple

@@ -407,23 +408,34 @@ class BaseTrainer(ABC):
         return test_stat, stop_fn_flag

+    @contextmanager
+    def _is_within_training_step_enabled(self, is_within_training_step: bool):
+        old_value = self.policy.is_within_training_step
+        try:
+            self.policy.is_within_training_step = is_within_training_step
+            yield
+        finally:
+            self.policy.is_within_training_step = old_value
+
     def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
         should_stop_training = False
+        with self._is_within_training_step_enabled(True):

-        if self.train_collector is not None:
-            collect_stats = self._collect_training_data()
-            should_stop_training = self._test_in_train(collect_stats)
-        else:
-            collect_stats = CollectStatsBase(
-                n_collected_episodes=len(self.buffer),
-            )
+            if self.train_collector is not None:
+                collect_stats = self._collect_training_data()
+                should_stop_training = self._test_in_train(collect_stats)
+            else:
+                collect_stats = CollectStatsBase(
+                    n_collected_episodes=len(self.buffer),
+                )
+                should_stop_training = False

-        if not should_stop_training:
-            training_stats = self.policy_update_fn(collect_stats)
-        else:
-            training_stats = None
+            if not should_stop_training:
+                training_stats = self.policy_update_fn(collect_stats)
+            else:
+                training_stats = None

-        return collect_stats, training_stats, should_stop_training
+            return collect_stats, training_stats, should_stop_training

     def _collect_training_data(self) -> CollectStats:
         """Performs training data collection
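For custom training loops that bypass BaseTrainer, the same effect can be obtained by toggling the flag manually (a minimal sketch under the assumption of existing `policy` and `train_collector` objects, not code from this commit):

    # Minimal sketch of the pattern BaseTrainer now encapsulates.
    policy.is_within_training_step = True
    try:
        train_collector.collect(n_step=256)          # data collection for training: stochastic
        policy.update(256, train_collector.buffer)   # gradient steps
    finally:
        policy.is_within_training_step = False       # back to evaluation semantics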
@@ -434,10 +446,7 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
+        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)

         self.env_step += collect_stats.n_collected_steps

@@ -467,26 +476,28 @@ class BaseTrainer(ABC):
         """
         should_stop_training = False

-        if (
-            collect_stats.n_collected_episodes > 0
-            and self.test_in_train
-            and self.stop_fn
-            and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
-        ):
-            assert self.test_collector is not None
-            test_result = test_episode(
-                self.test_collector,
-                self.test_fn,
-                self.epoch,
-                self.episode_per_test,
-                self.logger,
-                self.env_step,
-            )
-            assert test_result.returns_stat is not None  # for mypy
-            if self.stop_fn(test_result.returns_stat.mean):
-                should_stop_training = True
-                self.best_reward = test_result.returns_stat.mean
-                self.best_reward_std = test_result.returns_stat.std
+        # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
+        with self._is_within_training_step_enabled(False):
+            if (
+                collect_stats.n_collected_episodes > 0
+                and self.test_in_train
+                and self.stop_fn
+                and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
+            ):
+                assert self.test_collector is not None
+                test_result = test_episode(
+                    self.test_collector,
+                    self.test_fn,
+                    self.epoch,
+                    self.episode_per_test,
+                    self.logger,
+                    self.env_step,
+                )
+                assert test_result.returns_stat is not None  # for mypy
+                if self.stop_fn(test_result.returns_stat.mean):
+                    should_stop_training = True
+                    self.best_reward = test_result.returns_stat.mean
+                    self.best_reward_std = test_result.returns_stat.std

         return should_stop_training
@@ -27,7 +27,7 @@ def test_episode(
     collector.reset(reset_stats=False)
     if test_fn:
         test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode, eval_mode=True)
+    result = collector.collect(n_episode=n_episode)
     if reward_metric:  # TODO: move into collector
         rew = reward_metric(result.returns)
         result.returns = rew