Change the way in which deterministic evaluation is controlled:

* Remove flag `eval_mode` from `Collector.collect`
* Replace flag `is_eval` in `BasePolicy` with `is_within_training_step` (negating usages)
  and set it appropriately in `BaseTrainer`
Dominik Jain 2024-05-02 18:31:03 +02:00
parent 18f236167f
commit ca69e79b4a
51 changed files with 126 additions and 241 deletions

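Before the per-file hunks, a minimal self-contained sketch of the semantics this change establishes (the class below is an illustrative stand-in, not Tianshou's `BasePolicy`): deterministic evaluation now applies only while `is_within_training_step` is `False`, which is its default outside trainer-managed training steps.

```python
import random


class MockStochasticPolicy:
    """Simplified stand-in for a stochastic policy after this change."""

    def __init__(self, deterministic_eval: bool = True) -> None:
        self.deterministic_eval = deterministic_eval
        # default False; only a trainer's training step sets it to True
        self.is_within_training_step = False

    def compute_action(self, obs: float) -> float:
        # mirrors the pattern in the policy hunks below:
        # dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
        if self.deterministic_eval and not self.is_within_training_step:
            return obs  # "mode" of a dummy distribution centred on obs
        return obs + random.gauss(0.0, 0.1)  # stochastic sample


policy = MockStochasticPolicy()
print(policy.compute_action(0.5))      # deterministic: outside any training step

policy.is_within_training_step = True  # what BaseTrainer does around a training step
print(policy.compute_action(0.5))      # stochastic, even though deterministic_eval=True
```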
View File

@ -174,18 +174,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -216,18 +216,14 @@ def main(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -187,18 +187,14 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -184,18 +184,14 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -240,18 +240,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -178,18 +178,14 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -214,18 +214,14 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
beta=args.beta,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -227,18 +227,14 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -147,11 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -165,11 +165,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -209,11 +209,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -144,11 +144,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -155,11 +155,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -266,11 +266,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -240,11 +240,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
collector_stats.pprint_asdict()

View File

@ -221,11 +221,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -170,11 +170,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -218,11 +218,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -226,11 +226,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -198,11 +198,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -198,11 +198,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -192,11 +192,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -190,11 +190,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -223,11 +223,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -207,7 +207,7 @@ def test_bcq() -> None:
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
replay_buffer = load_buffer_d4rl(args.expert_data_task)
@ -230,11 +230,7 @@ def test_bcq() -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -345,7 +345,7 @@ def test_cql() -> None:
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
replay_buffer = load_buffer_d4rl(args.expert_data_task)
@ -368,11 +368,7 @@ def test_cql() -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -143,7 +143,7 @@ def test_il() -> None:
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
replay_buffer = load_buffer_d4rl(args.expert_data_task)
@ -166,11 +166,7 @@ def test_il() -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -192,7 +192,7 @@ def test_td3_bc() -> None:
policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
replay_buffer = load_buffer_d4rl(args.expert_data_task)
@ -218,11 +218,7 @@ def test_td3_bc() -> None:
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)

View File

@ -180,18 +180,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -242,18 +242,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size, eval_mode=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num,
render=args.render,
eval_mode=True,
)
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
if args.watch:

View File

@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
envs.tracker = MetricTracker()
try:
# TODO: why on earth 10**18?
test_collector.collect(n_step=10**18)
test_collector.collect(n_step=10 ** 18)
except StopIteration:
envs.tracker.validate()
@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
for _ in range(3):
envs.tracker = MetricTracker()
try:
test_collector.collect(n_step=10**18)
test_collector.collect(n_step=10 ** 18)
except StopIteration:
envs.tracker.validate()

View File

@ -64,7 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:
class TestPolicyBasics:
def test_get_action(self, policy: PPOPolicy) -> None:
policy.is_eval = True
policy.is_within_training_step = False
sample_obs = torch.randn(obs_shape)
policy.deterministic_eval = False
actions = [policy.compute_action(sample_obs) for _ in range(10)]

View File

@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset()
collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
collector_stats = collector.collect(n_step=args.buffer_size)
if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name)
else:

View File

@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
)
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
collector.collect(n_episode=1, render=1 / 35)
# trainer
result = OfflineTrainer(

View File

@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
policy, _, _ = get_agents(args)
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
result = collector.collect(n_episode=1, render=args.render)
result.pprint_asdict()

View File

@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
)
policy, _, _ = get_agents(args)
collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
collector_result = collector.collect(n_episode=1, render=args.render)
collector_result.pprint_asdict()

View File

@ -230,5 +230,5 @@ def watch(
policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
result = collector.collect(n_episode=1, render=args.render)
result.pprint_asdict()

View File

@ -260,17 +260,8 @@ class BaseCollector(ABC):
) -> CollectStats:
pass
def collect(
self,
n_step: int | None = None,
n_episode: int | None = None,
random: bool = False,
render: float | None = None,
no_grad: bool = True,
reset_before_collect: bool = False,
gym_reset_kwargs: dict[str, Any] | None = None,
eval_mode: bool = False,
) -> CollectStats:
def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
"""Collect a specified number of steps or episodes.
To ensure an unbiased sampling result with the n_episode option, this function will
@ -286,9 +277,6 @@ class BaseCollector(ABC):
(The collector needs the initial obs and info to function properly.)
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Only used if reset_before_collect is True.
:param eval_mode: whether to collect data in evaluation mode. Will
set the policy to training mode otherwise.
.. note::
One and only one collection number specification is permitted, either
@ -302,8 +290,7 @@ class BaseCollector(ABC):
if reset_before_collect:
self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
policy_mode_context = in_eval_mode if eval_mode else in_train_mode
with policy_mode_context(self.policy):
with in_eval_mode(self.policy): # safety precaution only
return self._collect(
n_step=n_step,
n_episode=n_episode,

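The replaced lines above show that the mode context is no longer selected per call: the new `collect` always wraps `_collect` in `in_eval_mode(self.policy)`, and the comment marks this as a safety precaution only, because action selection now depends on `is_within_training_step` rather than on the torch train/eval switch. For illustration, a context manager with that behaviour could look as follows (an assumption about the helper, not necessarily Tianshou's actual `in_eval_mode` implementation):

```python
from collections.abc import Iterator
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch `module` to eval mode, restoring its previous mode on exit."""
    was_training = module.training
    module.eval()
    try:
        yield
    finally:
        module.train(was_training)


# usage sketch
net = nn.Linear(4, 2)
with in_eval_mode(net):
    assert not net.training  # dropout/batch-norm layers behave as in inference
assert net.training          # original mode restored
```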
View File

@ -130,10 +130,7 @@ class AgentFactory(ABC, ToStringMixin):
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
train_collector.collect(
n_step=self.sampling_config.start_timesteps,
random=self.sampling_config.start_timesteps_random,
)
train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
return train_collector, test_collector
def set_policy_wrapper_factory(

View File

@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
) -> None:
collector = Collector(policy, env)
collector.reset()
result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
result = collector.collect(n_episode=num_episodes, render=render)
assert result.returns_stat is not None # for mypy
assert result.lens_stat is not None # for mypy
log.info(

View File

@ -226,8 +226,18 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
self.action_scaling = action_scaling
self.action_bound_method = action_bound_method
self.lr_scheduler = lr_scheduler
# whether the policy is in evaluation mode
self.is_eval = False # TODO: remove in favor of kwarg in compute_action/forward?
self.is_within_training_step = False
"""
flag indicating whether we are currently within a training step, which encompasses data collection
for training and the policy update (gradient steps).
It can be used, for example, to control whether a flag controlling deterministic evaluation should
indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
(even if such a flag is enabled).
This flag should normally remain False and should be set to True only by the algorithm which performs
training steps.
"""
self._compile()
@property

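The docstring above states that the flag should be raised only by the component performing training steps. As a hedged sketch of the resulting discipline for a hand-rolled training loop that does not go through `BaseTrainer` (the function name is hypothetical; `policy` and `train_collector` are assumed to be configured Tianshou objects):

```python
from tianshou.data import Collector
from tianshou.policy import BasePolicy


def manual_training_step(policy: BasePolicy, train_collector: Collector, n_step: int = 512) -> None:
    """Illustrative only: toggle the flag around your own training step the way the trainer does."""
    policy.is_within_training_step = True
    try:
        train_collector.collect(n_step=n_step)  # stochastic collection for training
        # ... perform gradient steps here, e.g. via policy.update(...) ...
    finally:
        # restore evaluation semantics: deterministic_eval applies again outside the step
        policy.is_within_training_step = False
```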
View File

@ -107,7 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
) -> Batch:
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits_BA)
act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:

View File

@ -197,7 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(action_dist_input_BD)
act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
# act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result)

View File

@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
) -> Batch:
(loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc_B, scale_B), 1)
if self.deterministic_eval and self.is_eval:
if self.deterministic_eval and not self.is_within_training_step:
act_B = dist.mode
else:
act_B = dist.rsample()

View File

@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
) -> DistLogProbBatchProtocol:
(loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
if self.deterministic_eval and self.is_eval:
if self.deterministic_eval and not self.is_within_training_step:
act_B = dist.mode
else:
act_B = dist.rsample()

View File

@ -3,6 +3,7 @@ import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import asdict
from typing import Optional, Tuple
@ -407,23 +408,34 @@ class BaseTrainer(ABC):
return test_stat, stop_fn_flag
@contextmanager
def _is_within_training_step_enabled(self, is_within_training_step: bool):
old_value = self.policy.is_within_training_step
try:
self.policy.is_within_training_step = is_within_training_step
yield
finally:
self.policy.is_within_training_step = old_value
def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
should_stop_training = False
with self._is_within_training_step_enabled(True):
if self.train_collector is not None:
collect_stats = self._collect_training_data()
should_stop_training = self._test_in_train(collect_stats)
else:
collect_stats = CollectStatsBase(
n_collected_episodes=len(self.buffer),
)
should_stop_training = False
if not should_stop_training:
training_stats = self.policy_update_fn(collect_stats)
else:
training_stats = None
if self.train_collector is not None:
collect_stats = self._collect_training_data()
should_stop_training = self._test_in_train(collect_stats)
else:
collect_stats = CollectStatsBase(
n_collected_episodes=len(self.buffer),
)
return collect_stats, training_stats, should_stop_training
if not should_stop_training:
training_stats = self.policy_update_fn(collect_stats)
else:
training_stats = None
return collect_stats, training_stats, should_stop_training
def _collect_training_data(self) -> CollectStats:
"""Performs training data collection
@ -434,10 +446,7 @@ class BaseTrainer(ABC):
assert self.train_collector is not None
if self.train_fn:
self.train_fn(self.epoch, self.env_step)
collect_stats = self.train_collector.collect(
n_step=self.step_per_collect,
n_episode=self.episode_per_collect,
)
collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
self.env_step += collect_stats.n_collected_steps
@ -467,26 +476,28 @@ class BaseTrainer(ABC):
"""
should_stop_training = False
if (
collect_stats.n_collected_episodes > 0
and self.test_in_train
and self.stop_fn
and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore
):
assert self.test_collector is not None
test_result = test_episode(
self.test_collector,
self.test_fn,
self.epoch,
self.episode_per_test,
self.logger,
self.env_step,
)
assert test_result.returns_stat is not None # for mypy
if self.stop_fn(test_result.returns_stat.mean):
should_stop_training = True
self.best_reward = test_result.returns_stat.mean
self.best_reward_std = test_result.returns_stat.std
# Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics
with self._is_within_training_step_enabled(False):
if (
collect_stats.n_collected_episodes > 0
and self.test_in_train
and self.stop_fn
and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore
):
assert self.test_collector is not None
test_result = test_episode(
self.test_collector,
self.test_fn,
self.epoch,
self.episode_per_test,
self.logger,
self.env_step,
)
assert test_result.returns_stat is not None # for mypy
if self.stop_fn(test_result.returns_stat.mean):
should_stop_training = True
self.best_reward = test_result.returns_stat.mean
self.best_reward_std = test_result.returns_stat.std
return should_stop_training

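The restructured `training_step` and `_test_in_train` above rely on the context manager restoring the previous flag value on exit, so the test-in-train evaluation can temporarily leave training-step semantics and fall back cleanly afterwards. A self-contained demonstration of that nesting behaviour (simplified stand-ins, not the actual `BaseTrainer`):

```python
from collections.abc import Iterator
from contextlib import contextmanager


class _FlagHolder:
    """Stand-in for the policy object whose flag the trainer toggles."""

    is_within_training_step = False


@contextmanager
def _is_within_training_step_enabled(holder: _FlagHolder, value: bool) -> Iterator[None]:
    old_value = holder.is_within_training_step
    try:
        holder.is_within_training_step = value
        yield
    finally:
        holder.is_within_training_step = old_value


policy = _FlagHolder()
with _is_within_training_step_enabled(policy, True):        # training step
    assert policy.is_within_training_step
    with _is_within_training_step_enabled(policy, False):    # test-in-train evaluation
        assert not policy.is_within_training_step
    assert policy.is_within_training_step                     # restored after evaluation
assert not policy.is_within_training_step                     # restored after the step
```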
View File

@ -27,7 +27,7 @@ def test_episode(
collector.reset(reset_stats=False)
if test_fn:
test_fn(epoch, global_step)
result = collector.collect(n_episode=n_episode, eval_mode=True)
result = collector.collect(n_episode=n_episode)
if reward_metric: # TODO: move into collector
rew = reward_metric(result.returns)
result.returns = rew