Allow to configure number of test episodes in high-level API

This commit is contained in:
Dominik Jain 2024-02-14 19:06:01 +01:00
parent 8742e3645c
commit bf391853dc
2 changed files with 18 additions and 3 deletions

View File

@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs, max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch, step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect, repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs, episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size, batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect, step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world), save_best_fn=policy_persistence.get_save_best_fn(world),
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs, max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch, step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect, step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs, episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size, batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world), save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger, logger=world.logger,

View File

@ -1,3 +1,4 @@
import math
import multiprocessing import multiprocessing
from dataclasses import dataclass from dataclasses import dataclass
@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
* collects environment steps/transitions (collection step), adding them to the (replay) * collects environment steps/transitions (collection step), adding them to the (replay)
buffer (see :attr:`step_per_collect`) buffer (see :attr:`step_per_collect`)
* performs one or more gradient updates (see :attr:`update_per_step`). * performs one or more gradient updates (see :attr:`update_per_step`),
and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate
agent performance.
The number of training steps in each epoch is indirectly determined by The number of training steps in each epoch is indirectly determined by
:attr:`step_per_epoch`: As many training steps will be performed as are required in :attr:`step_per_epoch`: As many training steps will be performed as are required in
@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
num_test_envs: int = 1 num_test_envs: int = 1
"""the number of test environments to use""" """the number of test environments to use"""
num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be the nearest multiple (rounded up).
"""
buffer_size: int = 4096 buffer_size: int = 4096
"""the total size of the sample/replay buffer, in which environment steps (transitions) are """the total size of the sample/replay buffer, in which environment steps (transitions) are
stored""" stored"""
@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.num_train_envs == -1: if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count() self.num_train_envs = multiprocessing.cpu_count()
@property
def num_test_episodes_per_test_env(self) -> int:
""":return: the number of episodes to collect per test environment in every test step"""
return math.ceil(self.num_test_episodes / self.num_test_envs)