Collector: extracted interface BaseCollector, minor simplifications

Renamed the is_eval kwarg of collect() to eval_mode
Michael Panchenko 2024-04-26 16:46:03 +02:00
parent 12d4262f80
commit 4b619c51ba
41 changed files with 303 additions and 275 deletions
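The API change visible in every file below is mechanical: the is_eval keyword of collect() becomes eval_mode. A minimal before/after sketch of a call site (variable names are illustrative):

    # before this commit:
    # stats = collector.collect(n_episode=10, render=args.render, is_eval=True)
    # after this commit (same semantics; eval_mode=False now also explicitly
    # puts the policy into training mode for the duration of the call):
    stats = collector.collect(n_episode=10, render=args.render, eval_mode=True)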

@@ -174,7 +174,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -184,7 +184,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -216,7 +216,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -226,7 +226,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -187,7 +187,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -197,7 +197,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -184,7 +184,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -194,7 +194,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -240,7 +240,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -250,7 +250,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -178,7 +178,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -188,7 +188,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -214,7 +214,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
             beta=args.beta,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -224,7 +224,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -227,7 +227,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -237,7 +237,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
        )
        result.pprint_asdict()

@@ -150,7 +150,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -168,7 +168,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -212,7 +212,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -147,7 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -158,7 +158,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -269,7 +269,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -243,7 +243,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         collector_stats.pprint_asdict()

@@ -224,7 +224,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -173,7 +173,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -221,7 +221,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -229,7 +229,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -201,7 +201,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -201,7 +201,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -195,7 +195,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -193,7 +193,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -226,7 +226,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()
 
     if args.watch:

@@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()
 
     if args.watch:

@@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()
 
     if args.watch:

@@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()
 
     if args.watch:

@@ -207,7 +207,7 @@ def test_bcq() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
 
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -233,7 +233,7 @@ def test_bcq() -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -345,7 +345,7 @@ def test_cql() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
 
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -371,7 +371,7 @@ def test_cql() -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -143,7 +143,7 @@ def test_il() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
 
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -169,7 +169,7 @@ def test_il() -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -192,7 +192,7 @@ def test_td3_bc() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
 
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -221,7 +221,7 @@ def test_td3_bc() -> None:
         collector_stats = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         print(collector_stats)

@@ -180,7 +180,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -190,7 +190,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -242,7 +242,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             stack_num=args.frames_stack,
         )
         collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-        result = collector.collect(n_step=args.buffer_size, is_eval=True)
+        result = collector.collect(n_step=args.buffer_size, eval_mode=True)
         print(f"Save buffer into {args.save_buffer_name}")
         # Unfortunately, pickle will cause oom with 1M buffer size
         buffer.save_hdf5(args.save_buffer_name)
@@ -252,7 +252,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         result = test_collector.collect(
             n_episode=args.test_num,
             render=args.render,
-            is_eval=True,
+            eval_mode=True,
         )
         result.pprint_asdict()

@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
+    collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:

@@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
 
     # trainer
     result = OfflineTrainer(

@@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     policy, _, _ = get_agents(args)
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
     result.pprint_asdict()

@@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     )
     policy, _, _ = get_agents(args)
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+    collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
     collector_result.pprint_asdict()

@@ -230,5 +230,5 @@ def watch(
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
     result.pprint_asdict()

@@ -1,5 +1,7 @@
+import logging
 import time
 import warnings
+from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
 from typing import Any, Self, TypeVar, cast
@@ -7,11 +9,11 @@ from typing import Any, Self, TypeVar, cast
 import gymnasium as gym
 import numpy as np
 import torch
+from overrides import override
 
 from tianshou.data import (
     Batch,
     CachedReplayBuffer,
-    PrioritizedReplayBuffer,
     ReplayBuffer,
     ReplayBufferManager,
     SequenceSummaryStats,
@@ -25,6 +27,9 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+
+log = logging.getLogger(__name__)
 
 
 @dataclass(kw_only=True)
@@ -122,23 +127,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch:
     return result_batch_parent.info
 
 
-class Collector:
-    """Collector enables the policy to interact with different types of envs with exact number of steps or episodes.
-
-    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param env: a ``gym.Env`` environment or an instance of the
-        :class:`~tianshou.env.BaseVectorEnv` class.
-    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
-        If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
-        as the default buffer.
-    :param exploration_noise: determine whether the action needs to be modified
-        with the corresponding policy's exploration noise. If so, "policy.
-        exploration_noise(act, batch)" will be called automatically to add the
-        exploration noise into action. Default to False.
+class BaseCollector(ABC):
+    """Used to collect data from a vector environment into a buffer using a given policy.
 
     .. note::
 
-        Please make sure the given environment has a time limitation if using n_episode
+        Please make sure the given environment has a time limitation if using `n_episode`
         collect option.
 
     .. note::
@@ -150,72 +144,70 @@ class Collector:
     def __init__(
         self,
         policy: BasePolicy,
-        env: gym.Env | BaseVectorEnv,
+        env: BaseVectorEnv | gym.Env,
         buffer: ReplayBuffer | None = None,
         exploration_noise: bool = False,
     ) -> None:
+        super().__init__()
         if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
             warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
             # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
-            self.env = DummyVectorEnv([lambda: env])
-        else:
-            self.env = env  # type: ignore
-        self.env_num = len(self.env)
-        self.exploration_noise = exploration_noise
-        self.buffer = self._assign_buffer(buffer)
+            env = DummyVectorEnv([lambda: env])  # type: ignore
+
+        if buffer is None:
+            buffer = VectorReplayBuffer(len(env), len(env))
+
+        self.buffer: ReplayBuffer = buffer
         self.policy = policy
-        self._action_space = self.env.action_space
-
-        self._pre_collect_obs_RO: np.ndarray | None = None
-        self._pre_collect_info_R: np.ndarray | None = None
-        self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None
-
-        self._is_closed = False
+        self.env = cast(BaseVectorEnv, env)
+        self.exploration_noise = exploration_noise
         self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
+        self._action_space = self.env.action_space
+        self._is_closed = False
+
+        self._validate_buffer()
+
+    def _validate_buffer(self) -> None:
+        buf = self.buffer
+        # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager.
+        #  We should probably rename the manager
+        if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num:
+            raise ValueError(
+                f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.",
+            )
+        if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num:
+            raise ValueError(
+                f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.",
+            )
+        # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance
+        if not isinstance(buf, ReplayBufferManager):
+            if buf.maxsize == 0:
+                raise ValueError("Buffer maxsize should be greater than 0.")
+            if self.env_num > 1:
+                raise ValueError(
+                    f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). "
+                    f"Please use the corresponding VectorReplayBuffer instead.",
+                )
+
+    @property
+    def env_num(self) -> int:
+        return len(self.env)
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space
 
     def close(self) -> None:
         """Close the collector and the environment."""
         self.env.close()
-        self._pre_collect_obs_RO = None
-        self._pre_collect_info_R = None
         self._is_closed = True
 
-    def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer:
-        """Check if the buffer matches the constraint."""
-        if buffer is None:
-            buffer = VectorReplayBuffer(self.env_num, self.env_num)
-        elif isinstance(buffer, ReplayBufferManager):
-            assert buffer.buffer_num >= self.env_num
-            if isinstance(buffer, CachedReplayBuffer):
-                assert buffer.cached_buffer_num >= self.env_num
-        else:  # ReplayBuffer or PrioritizedReplayBuffer
-            assert buffer.maxsize > 0
-            if self.env_num > 1:
-                if isinstance(buffer, ReplayBuffer):
-                    buffer_type = "ReplayBuffer"
-                    vector_type = "VectorReplayBuffer"
-                if isinstance(buffer, PrioritizedReplayBuffer):
-                    buffer_type = "PrioritizedReplayBuffer"
-                    vector_type = "PrioritizedVectorReplayBuffer"
-                raise TypeError(
-                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
-                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
-                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
-                )
-        return buffer
+    @property
+    def is_closed(self) -> bool:
+        """Return True if the collector is closed."""
+        return self._is_closed
 
     def reset(
         self,
         reset_buffer: bool = True,
         reset_stats: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environment, statistics, and data needed to start the collection.
 
         :param reset_buffer: if true, reset the replay buffer attached
@@ -224,12 +216,13 @@ class Collector:
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Defaults to None (extra keyword arguments)
         """
-        self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
+        obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
         if reset_buffer:
             self.reset_buffer()
         if reset_stats:
             self.reset_stat()
         self._is_closed = False
+        return obs_NO, info_N
 
     def reset_stat(self) -> None:
         """Reset the statistic variables."""
@@ -242,18 +235,168 @@ class Collector:
     def reset_env(
         self,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environments and the initial obs, info, and hidden state of the collector."""
         gym_reset_kwargs = gym_reset_kwargs or {}
-        self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
+        obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
         # TODO: hack, wrap envpool envs such that they don't return a dict
-        if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
+        if isinstance(info_N, dict):  # type: ignore[unreachable]
             # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
             # with array entries instead of an array of dicts
             # We use Batch to turn it into an array of dicts
-            self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]
+            info_N = _dict_of_arr_to_arr_of_dicts(info_N)  # type: ignore[unreachable]
+        return obs_NO, info_N
+
+    @abstractmethod
+    def _collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
+        pass
+
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+        eval_mode: bool = False,
+    ) -> CollectStats:
+        """Collect a specified number of steps or episodes.
+
+        To ensure an unbiased sampling result with the n_episode option, this function will
+        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
+        episodes, they will be collected evenly from each env.
+
+        :param n_step: how many steps you want to collect.
+        :param n_episode: how many episodes you want to collect.
+        :param random: whether to use random policy for collecting data.
+        :param render: the sleep time between rendering consecutive frames.
+        :param no_grad: whether to retain gradient in policy.forward().
+        :param reset_before_collect: whether to reset the environment before collecting data.
+            (The collector needs the initial obs and info to function properly.)
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Only used if reset_before_collect is True.
+        :param eval_mode: whether to collect data in evaluation mode. Will
+            set the policy to training mode otherwise.
+
+        .. note::
+
+            One and only one collection number specification is permitted, either
+            ``n_step`` or ``n_episode``.
+
+        :return: The collected stats
+        """
+        # check that exactly one of n_step or n_episode is set and that the other is larger than 0
+        self._validate_n_step_n_episode(n_episode, n_step)
+
+        if reset_before_collect:
+            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
+
+        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
+        with policy_mode_context(self.policy):
+            return self._collect(
+                n_step=n_step,
+                n_episode=n_episode,
+                random=random,
+                render=render,
+                no_grad=no_grad,
+                gym_reset_kwargs=gym_reset_kwargs,
+            )
+
+    def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None:
+        if not n_step and not n_episode:
+            raise ValueError(
+                f"Only one of n_step and n_episode should be set to a value larger than zero "
+                f"but got {n_step=}, {n_episode=}.",
+            )
+        if n_step is None and n_episode is None:
+            raise ValueError(
+                "Exactly one of n_step and n_episode should be set but got None for both.",
+            )
+        if n_step and n_step % self.env_num != 0:
+            warnings.warn(
+                f"{n_step=} is not a multiple of ({self.env_num=}), "
+                "which may cause extra transitions being collected into the buffer.",
+            )
+        if n_episode and self.env_num > n_episode:
+            warnings.warn(
+                f"{n_episode=} should be larger than {self.env_num=} to "
+                f"collect at least one trajectory in each environment.",
+            )
+
+
+class Collector(BaseCollector):
+    # NAMING CONVENTION (mostly suffixes):
+    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
+    #     the corresponding env is either reset or removed from the ready envs.
+    # N - number of envs, always fixed and >= R.
+    # R - number ready env ids. Note that this might change when envs get idle.
+    #     This can only happen in n_episode case, see explanation in the corresponding block.
+    #     For n_step, we always use all envs to collect the data, while for n_episode,
+    #     R will be at most n_episode at the beginning, but can decrease during the collection.
+    # O - dimension(s) of observations
+    # A - dimension(s) of actions
+    # H - dimension(s) of hidden state
+    # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
+    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
+    #     Only used in n_episode case. Then, R becomes R-S.
+
+    # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+    # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+    # policy.deterministic_eval)
+
+    def __init__(
+        self,
+        policy: BasePolicy,
+        env: gym.Env | BaseVectorEnv,
+        buffer: ReplayBuffer | None = None,
+        exploration_noise: bool = False,
+    ) -> None:
+        """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+        :param env: a ``gym.Env`` environment or an instance of the
+            :class:`~tianshou.env.BaseVectorEnv` class.
+        :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+            If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
+            as the default buffer.
+        :param exploration_noise: determine whether the action needs to be modified
+            with the corresponding policy's exploration noise. If so, "policy.
+            exploration_noise(act, batch)" will be called automatically to add the
+            exploration noise into action. Default to False.
+        """
+        super().__init__(policy, env, buffer, exploration_noise=exploration_noise)
+        self._pre_collect_obs_RO: np.ndarray | None = None
+        self._pre_collect_info_R: np.ndarray | None = None
+        self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None
+
+        self._is_closed = False
+        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
+
+    def close(self) -> None:
+        super().close()
+        self._pre_collect_obs_RO = None
+        self._pre_collect_info_R = None
+
+    def reset_env(
+        self,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Reset the environments and the initial obs, info, and hidden state of the collector."""
+        obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
+        # We assume that R = N when reset is called.
+        # TODO: there is currently no mechanism that ensures this and it's a public method!
+        self._pre_collect_obs_RO = obs_NO
+        self._pre_collect_info_R = info_N
         self._pre_collect_hidden_state_RH = None
+        return obs_NO, info_N
 
     def _compute_action_policy_hidden(
         self,
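The hunk above completes a template-method split: the public BaseCollector.collect() validates n_step/n_episode, optionally resets, and enters a train/eval mode context, while subclasses implement only the abstract _collect(). A hedged usage sketch of the resulting API (env count, buffer size, and step counts are illustrative):

    envs = DummyVectorEnv([make_env for _ in range(4)])  # make_env is a user-supplied factory
    collector = Collector(policy, envs, VectorReplayBuffer(20000, 4))
    collector.reset()
    train_stats = collector.collect(n_step=1024)  # default: policy in train mode
    eval_stats = collector.collect(n_episode=8, eval_mode=True)  # policy in eval mode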
@@ -309,98 +452,30 @@ class Collector:
         return act_RA, act_normalized_RA, policy_R, hidden_state_RH
 
     # TODO: reduce complexity, remove the noqa
-    def collect(
+    def _collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
         no_grad: bool = True,
-        reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
-        is_eval: bool = False,
     ) -> CollectStats:
-        """Collect a specified number of steps or episodes.
-
-        To ensure an unbiased sampling result with the n_episode option, this function will
-        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
-        episodes, they will be collected evenly from each env.
-
-        :param n_step: how many steps you want to collect.
-        :param n_episode: how many episodes you want to collect.
-        :param random: whether to use random policy for collecting data.
-        :param render: the sleep time between rendering consecutive frames.
-        :param no_grad: whether to retain gradient in policy.forward().
-        :param reset_before_collect: whether to reset the environment before collecting data.
-            (The collector needs the initial obs and info to function properly.)
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Only used if reset_before_collect is True.
-        :param is_eval: whether to collect data in evaluation mode.
-
-        .. note::
-
-            One and only one collection number specification is permitted, either
-            ``n_step`` or ``n_episode``.
-
-        :return: The collected stats
-        """
-        # NAMING CONVENTION (mostly suffixes):
-        # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
-        #     the corresponding env is either reset or removed from the ready envs.
-        # R - number ready env ids. Note that this might change when envs get idle.
-        #     This can only happen in n_episode case, see explanation in the corresponding block.
-        #     For n_step, we always use all envs to collect the data, while for n_episode,
-        #     R will be at most n_episode at the beginning, but can decrease during the collection.
-        # O - dimension(s) of observations
-        # A - dimension(s) of actions
-        # H - dimension(s) of hidden state
-        # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
-        # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
-        #     Only used in n_episode case. Then, R becomes R-S.
-
-        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
-        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
-        # policy.deterministic_eval)
-        self.policy.eval()
-        pre_collect_is_eval = self.policy.is_eval
-        self.policy.is_eval = is_eval
-
-        use_grad = not no_grad
-        gym_reset_kwargs = gym_reset_kwargs or {}
-
-        # Input validation
-        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
-
-        if n_step is not None:
-            assert n_episode is None, (
-                f"Only one of n_step or n_episode is allowed in Collector."
-                f"collect, got {n_step=}, {n_episode=}."
-            )
-            assert n_step > 0
-            if n_step % self.env_num != 0:
-                warnings.warn(
-                    f"{n_step=} is not a multiple of ({self.env_num=}), "
-                    "which may cause extra transitions being collected into the buffer.",
-                )
+        # TODO: can't do it init since AsyncCollector is currently a subclass of Collector
+        if self.env.is_async:
+            raise ValueError(
+                f"Please use {AsyncCollector.__name__} for asynchronous environments. "
+                f"Env class: {self.env.__class__.__name__}.",
+            )
+
+        if n_step is not None:
             ready_env_ids_R = np.arange(self.env_num)
         elif n_episode is not None:
-            assert n_episode > 0
-            if self.env_num > n_episode:
-                warnings.warn(
-                    f"{n_episode=} should be larger than {self.env_num=} to "
-                    f"collect at least one trajectory in each environment.",
-                )
             ready_env_ids_R = np.arange(min(self.env_num, n_episode))
-        else:
-            raise TypeError(
-                "Please specify at least one (either n_step or n_episode) "
-                "in AsyncCollector.collect().",
-            )
+
+        use_grad = not no_grad
         start_time = time.time()
 
-        if reset_before_collect:
-            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
-
         if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
             raise ValueError(
                 "Initial obs and info should not be None. "
@@ -491,7 +566,8 @@ class Collector:
             step_count += len(ready_env_ids_R)
 
             # preparing for the next iteration
-            # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer
+            # obs_next, info and hidden_state will be modified inplace in the code below,
+            # so we copy to not affect the data in the buffer
             last_obs_RO = copy(obs_next_RO)
             last_info_R = copy(info_R)
             last_hidden_state_RH = copy(hidden_state_RH)
@@ -509,6 +585,7 @@ class Collector:
                 # now we copy obs_next to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
+                gym_reset_kwargs = gym_reset_kwargs or {}
                 obs_reset_DO, info_reset_D = self.env.reset(
                     env_id=env_ind_global_D,
                     **gym_reset_kwargs,
@@ -577,9 +654,6 @@ class Collector:
             # reset envs and the _pre_collect fields
             self.reset_env(gym_reset_kwargs)  # todo still necessary?
 
-        # set the policy back to pre collect mode
-        self.policy.is_eval = pre_collect_is_eval
-
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -608,8 +682,7 @@ class Collector:
 class AsyncCollector(Collector):
     """Async Collector handles async vector environment.
 
-    The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
-    refer to :class:`~tianshou.data.Collector` for more detailed explanation.
+    Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation.
     """
 
     def __init__(
@@ -619,6 +692,12 @@ class AsyncCollector(Collector):
         buffer: ReplayBuffer | None = None,
         exploration_noise: bool = False,
     ) -> None:
+        if not env.is_async:
+            # TODO: raise an exception?
+            log.error(
+                f"Please use {Collector.__name__} if not using async venv. "
+                f"Env class: {env.__class__.__name__}",
+            )
         # assert env.is_async
         warnings.warn("Using async setting may collect extra transitions into buffer.")
         super().__init__(
@@ -644,7 +723,7 @@ class AsyncCollector(Collector):
         reset_buffer: bool = True,
         reset_stats: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environment, statistics, and data needed to start the collection.
 
         :param reset_buffer: if true, reset the replay buffer attached
@@ -654,7 +733,7 @@ class AsyncCollector(Collector):
             reset function. Defaults to None (extra keyword arguments)
         """
         # This sets the _pre_collect attrs
-        super().reset(
+        result = super().reset(
             reset_buffer=reset_buffer,
             reset_stats=reset_stats,
             gym_reset_kwargs=gym_reset_kwargs,
@@ -667,78 +746,29 @@ class AsyncCollector(Collector):
         self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH)
         self._current_action_in_all_envs_EA = np.empty(self.env_num)
         self._current_policy_in_all_envs_E = None
+        return result
 
-    def collect(
+    @override
+    def reset_env(
+        self,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        # we need to step through the envs and wait until they are ready to be able to interact with them
+        if self.env.waiting_id:
+            self.env.step(None, id=self.env.waiting_id)
+        return super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
+
+    @override
+    def _collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
         no_grad: bool = True,
-        reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
-        is_eval: bool = False,
     ) -> CollectStats:
-        """Collect a specified number of steps or episodes with async env setting.
-
-        This function does not collect an exact number of transitions specified by n_step or
-        n_episode. Instead, to support the asynchronous setting, it may collect more transitions
-        than requested by n_step or n_episode and save them into the buffer.
-
-        :param n_step: how many steps you want to collect.
-        :param n_episode: how many episodes you want to collect.
-        :param random: whether to use random policy_R for collecting data. Default
-            to False.
-        :param render: the sleep time between rendering consecutive frames.
-            Default to None (no rendering).
-        :param no_grad: whether to retain gradient in policy_R.forward(). Default to
-            True (no gradient retaining).
-        :param reset_before_collect: whether to reset the environment before
-            collecting data. It has only an effect if n_episode is not None, i.e.
-            if one wants to collect a fixed number of episodes.
-            (The collector needs the initial obs and info to function properly.)
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Defaults to None (extra keyword arguments)
-        :param is_eval: whether to collect data in evaluation mode.
-
-        .. note::
-
-            One and only one collection number specification is permitted, either
-            ``n_step`` or ``n_episode``.
-
-        :return: A dataclass object
-        """
-        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
-        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
-        # policy.deterministic_eval)
-        self.policy.eval()
-        pre_collect_is_eval = self.policy.is_eval
-        self.policy.is_eval = is_eval
-
         use_grad = not no_grad
-        gym_reset_kwargs = gym_reset_kwargs or {}
-
-        # collect at least n_step or n_episode
-        if n_step is not None:
-            assert n_episode is None, (
-                "Only one of n_step or n_episode is allowed in Collector."
-                f"collect, got n_step={n_step}, n_episode={n_episode}."
-            )
-            assert n_step > 0
-        elif n_episode is not None:
-            assert n_episode > 0
-        else:
-            raise TypeError(
-                "Please specify at least one (either n_step or n_episode) "
-                "in AsyncCollector.collect().",
-            )
-
-        if reset_before_collect:
-            # first we need to step all envs to be able to interact with them
-            if self.env.waiting_id:
-                self.env.step(None, id=self.env.waiting_id)
-            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
-
         start_time = time.time()
 
         step_count = 0
@@ -868,12 +898,12 @@ class AsyncCollector(Collector):
             num_collected_episodes += num_episodes_done_this_iter
 
             # preparing for the next iteration
-            # todo do we need the copy stuff (tests pass also without)
             # todo seem we can get rid of this last_sth stuff altogether
             last_obs_RO = copy(obs_next_RO)
             last_info_R = copy(info_R)
-            last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R])  # type: ignore[index]
+            last_hidden_state_RH = copy(
+                self._current_hidden_state_in_all_envs_EH[ready_env_ids_R],  # type: ignore[index]
+            )
 
             if num_episodes_done_this_iter:
                 env_ind_local_D = np.where(done_R)[0]
                 env_ind_global_D = ready_env_ids_R[env_ind_local_D]
@@ -883,6 +913,7 @@ class AsyncCollector(Collector):
                 # now we copy obs_next_RO to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
+                gym_reset_kwargs = gym_reset_kwargs or {}
                 obs_reset_DO, info_reset_D = self.env.reset(
                     env_id=env_ind_global_D,
                     **gym_reset_kwargs,
@@ -923,9 +954,6 @@ class AsyncCollector(Collector):
         # persist for future collect iterations
         self._ready_env_ids_R = ready_env_ids_R
 
-        # set the policy back to pre collect mode
-        self.policy.is_eval = pre_collect_is_eval
-
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
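The eval_mode flag is routed through in_eval_mode/in_train_mode from tianshou.utils.torch_utils (imported at the top of this file). Their implementation is not part of this diff; a minimal sketch of what such a context manager plausibly does, assuming it only toggles the torch module mode and restores it afterwards:

    from collections.abc import Iterator
    from contextlib import contextmanager

    import torch.nn as nn

    @contextmanager
    def in_eval_mode(module: nn.Module) -> Iterator[None]:
        # remember the current mode, switch to eval, restore on exit
        was_training = module.training
        module.eval()
        try:
            yield
        finally:
            module.train(was_training)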

@@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
     ) -> None:
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render, is_eval=True)
+        result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(