Allow to configure number of test episodes in high-level API

This commit is contained in:
Dominik Jain 2024-02-14 19:06:01 +01:00
parent 8742e3645c
commit bf391853dc
2 changed files with 18 additions and 3 deletions

View File

@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger,

View File

@ -1,3 +1,4 @@
import math
import multiprocessing
from dataclasses import dataclass
@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
* collects environment steps/transitions (collection step), adding them to the (replay)
buffer (see :attr:`step_per_collect`)
* performs one or more gradient updates (see :attr:`update_per_step`).
* performs one or more gradient updates (see :attr:`update_per_step`),
and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate
agent performance.
The number of training steps in each epoch is indirectly determined by
:attr:`step_per_epoch`: As many training steps will be performed as are required in
@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
num_test_envs: int = 1
"""the number of test environments to use"""
num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be the nearest multiple (rounded up).
"""
buffer_size: int = 4096
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
stored"""
@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None:
if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count()
@property
def num_test_episodes_per_test_env(self) -> int:
""":return: the number of episodes to collect per test environment in every test step"""
return math.ceil(self.num_test_episodes / self.num_test_envs)