Collector: extracted interface BaseCollector, minor simplifications
Renamed is_eval kwarg
parent 12d4262f80
commit 4b619c51ba
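The rename is purely mechanical at call sites: every `collect(..., is_eval=True)` becomes `collect(..., eval_mode=True)`. A minimal before/after sketch (hypothetical `policy` and `test_envs` objects, mirroring the call sites changed below):

```python
from tianshou.data import Collector, VectorReplayBuffer

# `policy` and `test_envs` are assumed to be set up as in the example scripts below.
collector = Collector(policy, test_envs, VectorReplayBuffer(20000, len(test_envs)))

# before this commit:
#   stats = collector.collect(n_episode=10, is_eval=True)
# after this commit:
stats = collector.collect(n_episode=10, eval_mode=True)
stats.pprint_asdict()
```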
@@ -174,7 +174,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -184,7 +184,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -216,7 +216,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -226,7 +226,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -187,7 +187,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -197,7 +197,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -184,7 +184,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -194,7 +194,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -240,7 +240,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -250,7 +250,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -178,7 +178,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -188,7 +188,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -214,7 +214,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
                 beta=args.beta,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -224,7 +224,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -227,7 +227,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -237,7 +237,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()
@@ -150,7 +150,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -168,7 +168,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -212,7 +212,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -147,7 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -158,7 +158,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -269,7 +269,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -243,7 +243,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     collector_stats.pprint_asdict()

@@ -224,7 +224,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -173,7 +173,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -221,7 +221,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -229,7 +229,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -201,7 +201,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -201,7 +201,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -195,7 +195,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -193,7 +193,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -226,7 +226,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)
@@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()

     if args.watch:

@@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()

     if args.watch:

@@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()

     if args.watch:

@@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
         result.pprint_asdict()

     if args.watch:
@@ -207,7 +207,7 @@ def test_bcq() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -233,7 +233,7 @@ def test_bcq() -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -345,7 +345,7 @@ def test_cql() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -371,7 +371,7 @@ def test_cql() -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -143,7 +143,7 @@ def test_il() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -169,7 +169,7 @@ def test_il() -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)

@@ -192,7 +192,7 @@ def test_td3_bc() -> None:

         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)

@@ -221,7 +221,7 @@ def test_td3_bc() -> None:
     collector_stats = test_collector.collect(
         n_episode=args.test_num,
         render=args.render,
-        is_eval=True,
+        eval_mode=True,
     )
     print(collector_stats)
@@ -180,7 +180,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -190,7 +190,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()

@@ -242,7 +242,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, is_eval=True)
+            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)

@@ -252,7 +252,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
             result = test_collector.collect(
                 n_episode=args.test_num,
                 render=args.render,
-                is_eval=True,
+                eval_mode=True,
             )
         result.pprint_asdict()
@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
+    collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:

@@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, is_eval=True)
+        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

     # trainer
     result = OfflineTrainer(

@@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
        policy, _, _ = get_agents(args)
    [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
    collector = Collector(policy, env, exploration_noise=True)
-   result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+   result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
    result.pprint_asdict()

@@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
    )
    policy, _, _ = get_agents(args)
    collector = Collector(policy, env)
-   collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+   collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
    collector_result.pprint_asdict()

@@ -230,5 +230,5 @@ def watch(
    policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
    policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
    collector = Collector(policy, env, exploration_noise=True)
-   result = collector.collect(n_episode=1, render=args.render, is_eval=True)
+   result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
    result.pprint_asdict()
@@ -1,5 +1,7 @@
+import logging
 import time
 import warnings
+from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
 from typing import Any, Self, TypeVar, cast

@@ -7,11 +9,11 @@ from typing import Any, Self, TypeVar, cast
 import gymnasium as gym
 import numpy as np
 import torch
+from overrides import override

 from tianshou.data import (
     Batch,
     CachedReplayBuffer,
-    PrioritizedReplayBuffer,
     ReplayBuffer,
     ReplayBufferManager,
     SequenceSummaryStats,

@@ -25,6 +27,9 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+
+log = logging.getLogger(__name__)


 @dataclass(kw_only=True)
@@ -122,23 +127,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch:
     return result_batch_parent.info


-class Collector:
-    """Collector enables the policy to interact with different types of envs with exact number of steps or episodes.
-
-    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param env: a ``gym.Env`` environment or an instance of the
-        :class:`~tianshou.env.BaseVectorEnv` class.
-    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
-        If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
-        as the default buffer.
-    :param exploration_noise: determine whether the action needs to be modified
-        with the corresponding policy's exploration noise. If so, "policy.
-        exploration_noise(act, batch)" will be called automatically to add the
-        exploration noise into action. Default to False.
+class BaseCollector(ABC):
+    """Used to collect data from a vector environment into a buffer using a given policy.

     .. note::

-        Please make sure the given environment has a time limitation if using n_episode
+        Please make sure the given environment has a time limitation if using `n_episode`
         collect option.

     .. note::
@@ -150,72 +144,70 @@ class Collector:
     def __init__(
         self,
         policy: BasePolicy,
-        env: gym.Env | BaseVectorEnv,
+        env: BaseVectorEnv | gym.Env,
         buffer: ReplayBuffer | None = None,
         exploration_noise: bool = False,
     ) -> None:
-        super().__init__()
         if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
             warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
             # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
-            self.env = DummyVectorEnv([lambda: env])
-        else:
-            self.env = env  # type: ignore
-        self.env_num = len(self.env)
-        self.exploration_noise = exploration_noise
-        self.buffer = self._assign_buffer(buffer)
+            env = DummyVectorEnv([lambda: env])  # type: ignore
+
+        if buffer is None:
+            buffer = VectorReplayBuffer(len(env), len(env))
+
+        self.buffer: ReplayBuffer = buffer
         self.policy = policy
-        self._action_space = self.env.action_space
-
-        self._pre_collect_obs_RO: np.ndarray | None = None
-        self._pre_collect_info_R: np.ndarray | None = None
-        self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None
-
-        self._is_closed = False
+        self.env = cast(BaseVectorEnv, env)
+        self.exploration_noise = exploration_noise
         self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
+
+        self._action_space = self.env.action_space
+        self._is_closed = False
+
+        self._validate_buffer()
+
+    def _validate_buffer(self) -> None:
+        buf = self.buffer
+        # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager.
+        #  We should probably rename the manager
+        if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num:
+            raise ValueError(
+                f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.",
+            )
+        if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num:
+            raise ValueError(
+                f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.",
+            )
+        # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance
+        if not isinstance(buf, ReplayBufferManager):
+            if buf.maxsize == 0:
+                raise ValueError("Buffer maxsize should be greater than 0.")
+            if self.env_num > 1:
+                raise ValueError(
+                    f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). "
+                    f"Please use the corresponding VectorReplayBuffer instead.",
+                )
+
+    @property
+    def env_num(self) -> int:
+        return len(self.env)
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space

     def close(self) -> None:
         """Close the collector and the environment."""
         self.env.close()
-        self._pre_collect_obs_RO = None
-        self._pre_collect_info_R = None
         self._is_closed = True

     @property
     def is_closed(self) -> bool:
         """Return True if the collector is closed."""
         return self._is_closed

-    def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer:
-        """Check if the buffer matches the constraint."""
-        if buffer is None:
-            buffer = VectorReplayBuffer(self.env_num, self.env_num)
-        elif isinstance(buffer, ReplayBufferManager):
-            assert buffer.buffer_num >= self.env_num
-            if isinstance(buffer, CachedReplayBuffer):
-                assert buffer.cached_buffer_num >= self.env_num
-        else:  # ReplayBuffer or PrioritizedReplayBuffer
-            assert buffer.maxsize > 0
-            if self.env_num > 1:
-                if isinstance(buffer, ReplayBuffer):
-                    buffer_type = "ReplayBuffer"
-                    vector_type = "VectorReplayBuffer"
-                if isinstance(buffer, PrioritizedReplayBuffer):
-                    buffer_type = "PrioritizedReplayBuffer"
-                    vector_type = "PrioritizedVectorReplayBuffer"
-                raise TypeError(
-                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
-                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
-                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
-                )
-        return buffer
-
     def reset(
         self,
         reset_buffer: bool = True,
         reset_stats: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environment, statistics, and data needed to start the collection.

         :param reset_buffer: if true, reset the replay buffer attached
@@ -224,12 +216,13 @@ class Collector:
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Defaults to None (extra keyword arguments)
         """
-        self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
+        obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
         if reset_buffer:
             self.reset_buffer()
         if reset_stats:
             self.reset_stat()
         self._is_closed = False
+        return obs_NO, info_N

     def reset_stat(self) -> None:
         """Reset the statistic variables."""
@@ -242,18 +235,168 @@ class Collector:
     def reset_env(
         self,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environments and the initial obs, info, and hidden state of the collector."""
         gym_reset_kwargs = gym_reset_kwargs or {}
-        self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
+        obs_NO, info_N = self.env.reset(**gym_reset_kwargs)
         # TODO: hack, wrap envpool envs such that they don't return a dict
-        if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
+        if isinstance(info_N, dict):  # type: ignore[unreachable]
             # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
             # with array entries instead of an array of dicts
             # We use Batch to turn it into an array of dicts
-            self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]
+            info_N = _dict_of_arr_to_arr_of_dicts(info_N)  # type: ignore[unreachable]
+        return obs_NO, info_N
+
+    @abstractmethod
+    def _collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
+        pass
+
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+        eval_mode: bool = False,
+    ) -> CollectStats:
+        """Collect a specified number of steps or episodes.
+
+        To ensure an unbiased sampling result with the n_episode option, this function will
+        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
+        episodes, they will be collected evenly from each env.
+
+        :param n_step: how many steps you want to collect.
+        :param n_episode: how many episodes you want to collect.
+        :param random: whether to use random policy for collecting data.
+        :param render: the sleep time between rendering consecutive frames.
+        :param no_grad: whether to retain gradient in policy.forward().
+        :param reset_before_collect: whether to reset the environment before collecting data.
+            (The collector needs the initial obs and info to function properly.)
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Only used if reset_before_collect is True.
+        :param eval_mode: whether to collect data in evaluation mode. Will
+            set the policy to training mode otherwise.
+
+        .. note::
+
+            One and only one collection number specification is permitted, either
+            ``n_step`` or ``n_episode``.
+
+        :return: The collected stats
+        """
+        # check that exactly one of n_step or n_episode is set and that the other is larger than 0
+        self._validate_n_step_n_episode(n_episode, n_step)
+
+        if reset_before_collect:
+            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
+
+        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
+        with policy_mode_context(self.policy):
+            return self._collect(
+                n_step=n_step,
+                n_episode=n_episode,
+                random=random,
+                render=render,
+                no_grad=no_grad,
+                gym_reset_kwargs=gym_reset_kwargs,
+            )
+
+    def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None:
+        if not n_step and not n_episode:
+            raise ValueError(
+                f"Only one of n_step and n_episode should be set to a value larger than zero "
+                f"but got {n_step=}, {n_episode=}.",
+            )
+        if n_step is None and n_episode is None:
+            raise ValueError(
+                "Exactly one of n_step and n_episode should be set but got None for both.",
+            )
+        if n_step and n_step % self.env_num != 0:
+            warnings.warn(
+                f"{n_step=} is not a multiple of ({self.env_num=}), "
+                "which may cause extra transitions being collected into the buffer.",
+            )
+        if n_episode and self.env_num > n_episode:
+            warnings.warn(
+                f"{n_episode=} should be larger than {self.env_num=} to "
+                f"collect at least one trajectory in each environment.",
+            )
+
+
+class Collector(BaseCollector):
+    # NAMING CONVENTION (mostly suffixes):
+    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
+    #     the corresponding env is either reset or removed from the ready envs.
+    # N - number of envs, always fixed and >= R.
+    # R - number ready env ids. Note that this might change when envs get idle.
+    #     This can only happen in n_episode case, see explanation in the corresponding block.
+    #     For n_step, we always use all envs to collect the data, while for n_episode,
+    #     R will be at most n_episode at the beginning, but can decrease during the collection.
+    # O - dimension(s) of observations
+    # A - dimension(s) of actions
+    # H - dimension(s) of hidden state
+    # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
+    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
+    #     Only used in n_episode case. Then, R becomes R-S.
+
+    # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+    # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+    # policy.deterministic_eval)
+
+    def __init__(
+        self,
+        policy: BasePolicy,
+        env: gym.Env | BaseVectorEnv,
+        buffer: ReplayBuffer | None = None,
+        exploration_noise: bool = False,
+    ) -> None:
+        """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+        :param env: a ``gym.Env`` environment or an instance of the
+            :class:`~tianshou.env.BaseVectorEnv` class.
+        :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+            If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
+            as the default buffer.
+        :param exploration_noise: determine whether the action needs to be modified
+            with the corresponding policy's exploration noise. If so, "policy.
+            exploration_noise(act, batch)" will be called automatically to add the
+            exploration noise into action. Default to False.
+        """
+        super().__init__(policy, env, buffer, exploration_noise=exploration_noise)
+        self._pre_collect_obs_RO: np.ndarray | None = None
+        self._pre_collect_info_R: np.ndarray | None = None
+        self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None
+
+        self._is_closed = False
+        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
+
+    def close(self) -> None:
+        super().close()
+        self._pre_collect_obs_RO = None
+        self._pre_collect_info_R = None
+
+    def reset_env(
+        self,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Reset the environments and the initial obs, info, and hidden state of the collector."""
+        obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
+        # We assume that R = N when reset is called.
+        # TODO: there is currently no mechanism that ensures this and it's a public method!
+        self._pre_collect_obs_RO = obs_NO
+        self._pre_collect_info_R = info_N
+        self._pre_collect_hidden_state_RH = None
+        return obs_NO, info_N

     def _compute_action_policy_hidden(
         self,
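Condensed view of the control flow introduced above (an illustrative sketch, not the real class: `MiniCollector` and the two-argument `_collect` are simplifications; the imported helpers are the ones this diff starts using):

```python
from abc import ABC, abstractmethod

from tianshou.data import CollectStats
from tianshou.policy import BasePolicy
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode


class MiniCollector(ABC):
    """Template-method skeleton mirroring BaseCollector.collect -> _collect."""

    def __init__(self, policy: BasePolicy) -> None:
        self.policy = policy

    @abstractmethod
    def _collect(self, n_step: int | None, n_episode: int | None) -> CollectStats:
        """Subclass hook: Collector and AsyncCollector implement the actual rollout here."""

    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        eval_mode: bool = False,
    ) -> CollectStats:
        # `eval_mode` (the renamed `is_eval`) picks the context manager that puts the
        # policy into eval or train mode for the duration of the rollout; the previous
        # mode is restored on exit, replacing the old manual save/restore of policy.is_eval.
        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
        with policy_mode_context(self.policy):
            return self._collect(n_step=n_step, n_episode=n_episode)
```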
@@ -309,98 +452,30 @@ class Collector:
         return act_RA, act_normalized_RA, policy_R, hidden_state_RH

     # TODO: reduce complexity, remove the noqa
-    def collect(
+    def _collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
         no_grad: bool = True,
-        reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
-        is_eval: bool = False,
     ) -> CollectStats:
-        """Collect a specified number of steps or episodes.
-
-        To ensure an unbiased sampling result with the n_episode option, this function will
-        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
-        episodes, they will be collected evenly from each env.
-
-        :param n_step: how many steps you want to collect.
-        :param n_episode: how many episodes you want to collect.
-        :param random: whether to use random policy for collecting data.
-        :param render: the sleep time between rendering consecutive frames.
-        :param no_grad: whether to retain gradient in policy.forward().
-        :param reset_before_collect: whether to reset the environment before collecting data.
-            (The collector needs the initial obs and info to function properly.)
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Only used if reset_before_collect is True.
-        :param is_eval: whether to collect data in evaluation mode.
-
-        .. note::
-
-            One and only one collection number specification is permitted, either
-            ``n_step`` or ``n_episode``.
-
-        :return: The collected stats
-        """
-        # NAMING CONVENTION (mostly suffixes):
-        # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
-        #     the corresponding env is either reset or removed from the ready envs.
-        # R - number ready env ids. Note that this might change when envs get idle.
-        #     This can only happen in n_episode case, see explanation in the corresponding block.
-        #     For n_step, we always use all envs to collect the data, while for n_episode,
-        #     R will be at most n_episode at the beginning, but can decrease during the collection.
-        # O - dimension(s) of observations
-        # A - dimension(s) of actions
-        # H - dimension(s) of hidden state
-        # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
-        # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
-        #     Only used in n_episode case. Then, R becomes R-S.
-
-        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
-        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
-        # policy.deterministic_eval)
-        self.policy.eval()
-        pre_collect_is_eval = self.policy.is_eval
-        self.policy.is_eval = is_eval
-
-        use_grad = not no_grad
-        gym_reset_kwargs = gym_reset_kwargs or {}
-
-        # Input validation
-        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
-        if n_step is not None:
-            assert n_episode is None, (
-                f"Only one of n_step or n_episode is allowed in Collector."
-                f"collect, got {n_step=}, {n_episode=}."
+        # TODO: can't do it init since AsyncCollector is currently a subclass of Collector
+        if self.env.is_async:
+            raise ValueError(
+                f"Please use {AsyncCollector.__name__} for asynchronous environments. "
+                f"Env class: {self.env.__class__.__name__}.",
             )
-            assert n_step > 0
-            if n_step % self.env_num != 0:
-                warnings.warn(
-                    f"{n_step=} is not a multiple of ({self.env_num=}), "
-                    "which may cause extra transitions being collected into the buffer.",
-                )
+
+        if n_step is not None:
             ready_env_ids_R = np.arange(self.env_num)
         elif n_episode is not None:
-            assert n_episode > 0
-            if self.env_num > n_episode:
-                warnings.warn(
-                    f"{n_episode=} should be larger than {self.env_num=} to "
-                    f"collect at least one trajectory in each environment.",
-                )
             ready_env_ids_R = np.arange(min(self.env_num, n_episode))
         else:
             raise TypeError(
                 "Please specify at least one (either n_step or n_episode) "
                 "in AsyncCollector.collect().",
             )

+        use_grad = not no_grad
+
         start_time = time.time()

-        if reset_before_collect:
-            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
-
         if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
             raise ValueError(
                 "Initial obs and info should not be None. "
@@ -491,7 +566,8 @@ class Collector:
             step_count += len(ready_env_ids_R)

             # preparing for the next iteration
-            # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer
+            # obs_next, info and hidden_state will be modified inplace in the code below,
+            # so we copy to not affect the data in the buffer
             last_obs_RO = copy(obs_next_RO)
             last_info_R = copy(info_R)
             last_hidden_state_RH = copy(hidden_state_RH)

@@ -509,6 +585,7 @@ class Collector:
                 # now we copy obs_next to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.

+                gym_reset_kwargs = gym_reset_kwargs or {}
                 obs_reset_DO, info_reset_D = self.env.reset(
                     env_id=env_ind_global_D,
                     **gym_reset_kwargs,

@@ -577,9 +654,6 @@ class Collector:
         # reset envs and the _pre_collect fields
         self.reset_env(gym_reset_kwargs)  # todo still necessary?

-        # set the policy back to pre collect mode
-        self.policy.is_eval = pre_collect_is_eval
-
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -608,8 +682,7 @@ class AsyncCollector(Collector):
 class AsyncCollector(Collector):
     """Async Collector handles async vector environment.

-    The arguments are exactly the same as :class:`~tianshou.data.Collector`, please
-    refer to :class:`~tianshou.data.Collector` for more detailed explanation.
+    Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation.
     """

     def __init__(

@@ -619,6 +692,12 @@ class AsyncCollector(Collector):
         buffer: ReplayBuffer | None = None,
         exploration_noise: bool = False,
     ) -> None:
+        if not env.is_async:
+            # TODO: raise an exception?
+            log.error(
+                f"Please use {Collector.__name__} if not using async venv. "
+                f"Env class: {env.__class__.__name__}",
+            )
         # assert env.is_async
         warnings.warn("Using async setting may collect extra transitions into buffer.")
         super().__init__(

@@ -644,7 +723,7 @@ class AsyncCollector(Collector):
         reset_buffer: bool = True,
         reset_stats: bool = True,
         gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Reset the environment, statistics, and data needed to start the collection.

         :param reset_buffer: if true, reset the replay buffer attached

@@ -654,7 +733,7 @@ class AsyncCollector(Collector):
             reset function. Defaults to None (extra keyword arguments)
         """
         # This sets the _pre_collect attrs
-        super().reset(
+        result = super().reset(
             reset_buffer=reset_buffer,
             reset_stats=reset_stats,
             gym_reset_kwargs=gym_reset_kwargs,
@@ -667,78 +746,29 @@ class AsyncCollector(Collector):
         self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH)
         self._current_action_in_all_envs_EA = np.empty(self.env_num)
         self._current_policy_in_all_envs_E = None
+        return result

-    def collect(
+    @override
+    def reset_env(
+        self,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        # we need to step through the envs and wait until they are ready to be able to interact with them
+        if self.env.waiting_id:
+            self.env.step(None, id=self.env.waiting_id)
+        return super().reset_env(gym_reset_kwargs=gym_reset_kwargs)
+
+    @override
+    def _collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
         no_grad: bool = True,
-        reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
-        is_eval: bool = False,
     ) -> CollectStats:
-        """Collect a specified number of steps or episodes with async env setting.
-
-        This function does not collect an exact number of transitions specified by n_step or
-        n_episode. Instead, to support the asynchronous setting, it may collect more transitions
-        than requested by n_step or n_episode and save them into the buffer.
-
-        :param n_step: how many steps you want to collect.
-        :param n_episode: how many episodes you want to collect.
-        :param random: whether to use random policy_R for collecting data. Default
-            to False.
-        :param render: the sleep time between rendering consecutive frames.
-            Default to None (no rendering).
-        :param no_grad: whether to retain gradient in policy_R.forward(). Default to
-            True (no gradient retaining).
-        :param reset_before_collect: whether to reset the environment before
-            collecting data. It has only an effect if n_episode is not None, i.e.
-            if one wants to collect a fixed number of episodes.
-            (The collector needs the initial obs and info to function properly.)
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Defaults to None (extra keyword arguments)
-        :param is_eval: whether to collect data in evaluation mode.
-
-        .. note::
-
-            One and only one collection number specification is permitted, either
-            ``n_step`` or ``n_episode``.
-
-        :return: A dataclass object
-        """
-        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
-        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
-        # policy.deterministic_eval)
-        self.policy.eval()
-        pre_collect_is_eval = self.policy.is_eval
-        self.policy.is_eval = is_eval
-
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}

-        # collect at least n_step or n_episode
-        if n_step is not None:
-            assert n_episode is None, (
-                "Only one of n_step or n_episode is allowed in Collector."
-                f"collect, got n_step={n_step}, n_episode={n_episode}."
-            )
-            assert n_step > 0
-        elif n_episode is not None:
-            assert n_episode > 0
-        else:
-            raise TypeError(
-                "Please specify at least one (either n_step or n_episode) "
-                "in AsyncCollector.collect().",
-            )
-
-        if reset_before_collect:
-            # first we need to step all envs to be able to interact with them
-            if self.env.waiting_id:
-                self.env.step(None, id=self.env.waiting_id)
-            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
-
         start_time = time.time()

         step_count = 0
@@ -868,12 +898,12 @@ class AsyncCollector(Collector):
             num_collected_episodes += num_episodes_done_this_iter

             # preparing for the next iteration
+            # todo do we need the copy stuff (tests pass also without)
+            # todo seem we can get rid of this last_sth stuff altogether
             last_obs_RO = copy(obs_next_RO)
             last_info_R = copy(info_R)
-            last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R])  # type: ignore[index]
-
+            last_hidden_state_RH = copy(
+                self._current_hidden_state_in_all_envs_EH[ready_env_ids_R],  # type: ignore[index]
+            )
             if num_episodes_done_this_iter:
                 env_ind_local_D = np.where(done_R)[0]
                 env_ind_global_D = ready_env_ids_R[env_ind_local_D]

@@ -883,6 +913,7 @@ class AsyncCollector(Collector):

                 # now we copy obs_next_RO to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
+                gym_reset_kwargs = gym_reset_kwargs or {}
                 obs_reset_DO, info_reset_D = self.env.reset(
                     env_id=env_ind_global_D,
                     **gym_reset_kwargs,

@@ -923,9 +954,6 @@ class AsyncCollector(Collector):
         # persist for future collect iterations
         self._ready_env_ids_R = ready_env_ids_R

-        # set the policy back to pre collect mode
-        self.policy.is_eval = pre_collect_is_eval
-
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
     ) -> None:
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render, is_eval=True)
+        result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
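For context on the two helpers this diff starts relying on: `in_eval_mode` and `in_train_mode` are context managers over an `nn.Module`. A minimal sketch of the semantics assumed here (the real implementations live in `tianshou.utils.torch_utils` and may differ in detail):

```python
from collections.abc import Iterator
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch a module to eval mode, restoring the previous mode on exit."""
    prev_training = module.training
    module.eval()
    try:
        yield
    finally:
        module.train(prev_training)


# usage: dropout and batchnorm behave deterministically inside the block
net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5))
with in_eval_mode(net):
    assert not net.training
assert net.training  # the previous (training) mode is restored
```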