update examples
This commit is contained in:
parent
e499bed8b0
commit
8cb17de190
@ -162,7 +162,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -175,14 +174,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -204,7 +204,6 @@ def main(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -217,14 +216,18 @@ def main(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -175,7 +175,6 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -188,14 +187,18 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -172,7 +172,6 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -185,14 +184,18 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -229,7 +229,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
print(f"Generate buffer with size {args.buffer_size}")
|
print(f"Generate buffer with size {args.buffer_size}")
|
||||||
@ -241,14 +240,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -166,7 +166,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -179,14 +178,18 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -200,7 +200,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -215,14 +214,18 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
|
|||||||
beta=args.beta,
|
beta=args.beta,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -216,7 +216,6 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
print(f"Generate buffer with size {args.buffer_size}")
|
print(f"Generate buffer with size {args.buffer_size}")
|
||||||
@ -228,14 +227,18 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -144,11 +144,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -162,11 +162,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -207,10 +207,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,11 +141,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,10 +153,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,7 +80,6 @@ def main() -> None:
|
|||||||
print(f"Finished training in {result.timing.total_time} seconds")
|
print(f"Finished training in {result.timing.total_time} seconds")
|
||||||
|
|
||||||
# watch performance
|
# watch performance
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(eps_test)
|
policy.set_eps(eps_test)
|
||||||
collector = ts.data.Collector(policy, env, exploration_noise=True)
|
collector = ts.data.Collector(policy, env, exploration_noise=True)
|
||||||
collector.collect(n_episode=100, render=1 / 35)
|
collector.collect(n_episode=100, render=1 / 35)
|
||||||
|
@ -264,10 +264,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -238,10 +238,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
collector_stats.pprint_asdict()
|
collector_stats.pprint_asdict()
|
||||||
|
|
||||||
|
|
||||||
|
@ -219,10 +219,13 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,10 +168,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -216,10 +216,13 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -224,10 +224,13 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -196,10 +196,13 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -196,10 +196,13 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -190,10 +190,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,10 +188,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -221,10 +221,13 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,12 +187,11 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -171,12 +171,11 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -188,11 +188,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -145,11 +145,10 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -206,9 +206,8 @@ def test_bcq() -> None:
|
|||||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||||
|
|
||||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||||
policy.eval()
|
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
collector.collect(n_episode=1, render=1 / 35)
|
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||||
|
|
||||||
if not args.watch:
|
if not args.watch:
|
||||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||||
@ -229,10 +228,13 @@ def test_bcq() -> None:
|
|||||||
watch()
|
watch()
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -344,9 +344,8 @@ def test_cql() -> None:
|
|||||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||||
|
|
||||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||||
policy.eval()
|
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
collector.collect(n_episode=1, render=1 / 35)
|
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||||
|
|
||||||
if not args.watch:
|
if not args.watch:
|
||||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||||
@ -367,10 +366,13 @@ def test_cql() -> None:
|
|||||||
watch()
|
watch()
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,9 +142,8 @@ def test_il() -> None:
|
|||||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||||
|
|
||||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||||
policy.eval()
|
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
collector.collect(n_episode=1, render=1 / 35)
|
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||||
|
|
||||||
if not args.watch:
|
if not args.watch:
|
||||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||||
@ -165,10 +164,13 @@ def test_il() -> None:
|
|||||||
watch()
|
watch()
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,9 +191,8 @@ def test_td3_bc() -> None:
|
|||||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||||
|
|
||||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||||
policy.eval()
|
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
collector.collect(n_episode=1, render=1 / 35)
|
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||||
|
|
||||||
if not args.watch:
|
if not args.watch:
|
||||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||||
@ -217,10 +216,13 @@ def test_td3_bc() -> None:
|
|||||||
watch()
|
watch()
|
||||||
|
|
||||||
# Let's watch its performance!
|
# Let's watch its performance!
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
collector_stats = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
print(collector_stats)
|
print(collector_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,7 +168,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
policy.set_eps(args.eps_test)
|
policy.set_eps(args.eps_test)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
@ -181,14 +180,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
@ -231,7 +231,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
# watch agent's performance
|
# watch agent's performance
|
||||||
def watch() -> None:
|
def watch() -> None:
|
||||||
print("Setup test envs ...")
|
print("Setup test envs ...")
|
||||||
policy.eval()
|
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
if args.save_buffer_name:
|
if args.save_buffer_name:
|
||||||
print(f"Generate buffer with size {args.buffer_size}")
|
print(f"Generate buffer with size {args.buffer_size}")
|
||||||
@ -243,14 +242,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
|||||||
stack_num=args.frames_stack,
|
stack_num=args.frames_stack,
|
||||||
)
|
)
|
||||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||||
result = collector.collect(n_step=args.buffer_size)
|
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||||
print(f"Save buffer into {args.save_buffer_name}")
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
buffer.save_hdf5(args.save_buffer_name)
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
else:
|
else:
|
||||||
print("Testing agent ...")
|
print("Testing agent ...")
|
||||||
test_collector.reset()
|
test_collector.reset()
|
||||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
result = test_collector.collect(
|
||||||
|
n_episode=args.test_num,
|
||||||
|
render=args.render,
|
||||||
|
is_eval=True,
|
||||||
|
)
|
||||||
result.pprint_asdict()
|
result.pprint_asdict()
|
||||||
|
|
||||||
if args.watch:
|
if args.watch:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user