update examples
This commit is contained in:
parent
e499bed8b0
commit
8cb17de190
@ -162,7 +162,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -175,14 +174,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -204,7 +204,6 @@ def main(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -217,14 +216,18 @@ def main(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -175,7 +175,6 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -188,14 +187,18 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -172,7 +172,6 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -185,14 +184,18 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -229,7 +229,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
@ -241,14 +240,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -166,7 +166,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -179,14 +178,18 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -200,7 +200,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -215,14 +214,18 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
|
||||
beta=args.beta,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -216,7 +216,6 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
@ -228,14 +227,18 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -144,11 +144,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -162,11 +162,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -207,10 +207,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -141,11 +141,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -153,10 +153,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -80,7 +80,6 @@ def main() -> None:
|
||||
print(f"Finished training in {result.timing.total_time} seconds")
|
||||
|
||||
# watch performance
|
||||
policy.eval()
|
||||
policy.set_eps(eps_test)
|
||||
collector = ts.data.Collector(policy, env, exploration_noise=True)
|
||||
collector.collect(n_episode=100, render=1 / 35)
|
||||
|
@ -264,10 +264,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -238,10 +238,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
collector_stats.pprint_asdict()
|
||||
|
||||
|
||||
|
@ -219,10 +219,13 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -168,10 +168,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -216,10 +216,13 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -224,10 +224,13 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -196,10 +196,13 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -196,10 +196,13 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -190,10 +190,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -188,10 +188,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -221,10 +221,13 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -187,12 +187,11 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -171,12 +171,11 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -188,11 +188,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -145,11 +145,10 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -206,9 +206,8 @@ def test_bcq() -> None:
|
||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||
|
||||
if not args.watch:
|
||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||
@ -229,10 +228,13 @@ def test_bcq() -> None:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -344,9 +344,8 @@ def test_cql() -> None:
|
||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||
|
||||
if not args.watch:
|
||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||
@ -367,10 +366,13 @@ def test_cql() -> None:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -142,9 +142,8 @@ def test_il() -> None:
|
||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||
|
||||
if not args.watch:
|
||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||
@ -165,10 +164,13 @@ def test_il() -> None:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -191,9 +191,8 @@ def test_td3_bc() -> None:
|
||||
args.resume_path = os.path.join(log_path, "policy.pth")
|
||||
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||
|
||||
if not args.watch:
|
||||
replay_buffer = load_buffer_d4rl(args.expert_data_task)
|
||||
@ -217,10 +216,13 @@ def test_td3_bc() -> None:
|
||||
watch()
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
collector_stats = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@ -168,7 +168,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
@ -181,14 +180,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
@ -231,7 +231,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
# watch agent's performance
|
||||
def watch() -> None:
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
@ -243,14 +242,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
result = collector.collect(n_step=args.buffer_size, is_eval=True)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num,
|
||||
render=args.render,
|
||||
is_eval=True,
|
||||
)
|
||||
result.pprint_asdict()
|
||||
|
||||
if args.watch:
|
||||
|
Loading…
x
Reference in New Issue
Block a user