update examples: drop explicit policy.eval() calls and pass is_eval=True to Collector.collect

This commit is contained in:
Maximilian Huettenrauch 2024-04-24 17:06:54 +02:00
parent e499bed8b0
commit 8cb17de190
35 changed files with 168 additions and 87 deletions
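
The change repeated across these files: the explicit `policy.eval()` call before watching/evaluating is removed, and the evaluation intent is instead passed to `Collector.collect` via the new `is_eval=True` flag. A minimal sketch of the updated watch pattern, assuming a trained `policy` and vectorized `test_envs` as constructed earlier in each script (the function and parameter names here are illustrative, not part of the commit):

```python
from tianshou.data import Collector

def watch(policy, test_envs, test_num: int, eps_test: float, render: float = 0.0) -> None:
    # DQN-family policies still configure their exploration epsilon explicitly.
    policy.set_eps(eps_test)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    test_collector.reset()
    # is_eval=True marks this as an evaluation rollout; it replaces the
    # explicit policy.eval() call that previously preceded collection.
    result = test_collector.collect(n_episode=test_num, render=render, is_eval=True)
    result.pprint_asdict()
```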

View File

@@ -162,7 +162,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -175,14 +174,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -204,7 +204,6 @@ def main(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -217,14 +216,18 @@ def main(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -175,7 +175,6 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -188,14 +187,18 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -172,7 +172,6 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -185,14 +184,18 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -229,7 +229,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         test_envs.seed(args.seed)
         if args.save_buffer_name:
             print(f"Generate buffer with size {args.buffer_size}")
@@ -241,14 +240,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -166,7 +166,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -179,14 +178,18 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -200,7 +200,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -215,14 +214,18 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
                 beta=args.beta,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -216,7 +216,6 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         test_envs.seed(args.seed)
         if args.save_buffer_name:
             print(f"Generate buffer with size {args.buffer_size}")
@@ -228,14 +227,18 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -144,11 +144,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         print(collector_stats)

View File

@@ -162,11 +162,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         print(collector_stats)

View File

@@ -207,10 +207,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         print(collector_stats)

View File

@@ -141,11 +141,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         print(collector_stats)

View File

@@ -153,10 +153,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         print(collector_stats)

View File

@@ -80,7 +80,6 @@ def main() -> None:
     print(f"Finished training in {result.timing.total_time} seconds")
     # watch performance
-    policy.eval()
     policy.set_eps(eps_test)
     collector = ts.data.Collector(policy, env, exploration_noise=True)
     collector.collect(n_episode=100, render=1 / 35)

View File

@@ -264,10 +264,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -238,10 +238,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     collector_stats.pprint_asdict()

View File

@@ -219,10 +219,13 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -168,10 +168,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -216,10 +216,13 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -224,10 +224,13 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -196,10 +196,13 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -196,10 +196,13 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -190,10 +190,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -188,10 +188,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -221,10 +221,13 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
     pprint.pprint(result)
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -187,12 +187,11 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         result.pprint_asdict()
     if args.watch:

View File

@@ -171,12 +171,11 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         result.pprint_asdict()
     if args.watch:

View File

@@ -188,11 +188,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         result.pprint_asdict()
     if args.watch:

View File

@@ -145,11 +145,10 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         result.pprint_asdict()
     if args.watch:

View File

@@ -206,9 +206,8 @@ def test_bcq() -> None:
             args.resume_path = os.path.join(log_path, "policy.pth")
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -229,10 +228,13 @@ def test_bcq() -> None:
         watch()
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)
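
The offline d4rl examples thread the same flag through the checkpoint-watching collector as well. A condensed sketch of the watch pattern above, assuming a policy object and single `env` built by the surrounding script (`watch_checkpoint` is an illustrative name, not part of the commit):

```python
import torch
from tianshou.data import Collector

def watch_checkpoint(policy, env, resume_path: str) -> None:
    # Restore trained weights, then roll out one rendered evaluation episode.
    policy.load_state_dict(torch.load(resume_path, map_location=torch.device("cpu")))
    collector = Collector(policy, env)
    # is_eval=True marks the rollout as evaluation, replacing policy.eval().
    collector.collect(n_episode=1, render=1 / 35, is_eval=True)
```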

View File

@@ -344,9 +344,8 @@ def test_cql() -> None:
             args.resume_path = os.path.join(log_path, "policy.pth")
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -367,10 +366,13 @@ def test_cql() -> None:
         watch()
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -142,9 +142,8 @@ def test_il() -> None:
             args.resume_path = os.path.join(log_path, "policy.pth")
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -165,10 +164,13 @@ def test_il() -> None:
         watch()
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -191,9 +191,8 @@ def test_td3_bc() -> None:
             args.resume_path = os.path.join(log_path, "policy.pth")
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -217,10 +216,13 @@ def test_td3_bc() -> None:
        watch()
     # Let's watch its performance!
-    policy.eval()
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+    collector_stats = test_collector.collect(
+        n_episode=args.test_num,
+        render=args.render,
+        is_eval=True,
+    )
     print(collector_stats)

View File

@@ -168,7 +168,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         if args.save_buffer_name:
@@ -181,14 +180,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch:

View File

@@ -231,7 +231,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     # watch agent's performance
     def watch() -> None:
         print("Setup test envs ...")
-        policy.eval()
         test_envs.seed(args.seed)
         if args.save_buffer_name:
             print(f"Generate buffer with size {args.buffer_size}")
@@ -243,14 +242,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size)
+            result = collector.collect(n_step=args.buffer_size, is_eval=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.test_num, render=args.render)
+            result = test_collector.collect(
+                n_episode=args.test_num,
+                render=args.render,
+                is_eval=True,
+            )
             result.pprint_asdict()
         if args.watch: