Allow to configure number of test episodes in high-level API
This commit is contained in:
parent
8742e3645c
commit
bf391853dc
@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
||||
max_epoch=sampling_config.num_epochs,
|
||||
step_per_epoch=sampling_config.step_per_epoch,
|
||||
repeat_per_collect=sampling_config.repeat_per_collect,
|
||||
episode_per_test=sampling_config.num_test_envs,
|
||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||
batch_size=sampling_config.batch_size,
|
||||
step_per_collect=sampling_config.step_per_collect,
|
||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
||||
max_epoch=sampling_config.num_epochs,
|
||||
step_per_epoch=sampling_config.step_per_epoch,
|
||||
step_per_collect=sampling_config.step_per_collect,
|
||||
episode_per_test=sampling_config.num_test_envs,
|
||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||
batch_size=sampling_config.batch_size,
|
||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||
logger=world.logger,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import math
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
|
||||
|
||||
* collects environment steps/transitions (collection step), adding them to the (replay)
|
||||
buffer (see :attr:`step_per_collect`)
|
||||
* performs one or more gradient updates (see :attr:`update_per_step`).
|
||||
* performs one or more gradient updates (see :attr:`update_per_step`),
|
||||
|
||||
and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate
|
||||
agent performance.
|
||||
|
||||
The number of training steps in each epoch is indirectly determined by
|
||||
:attr:`step_per_epoch`: As many training steps will be performed as are required in
|
||||
@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
|
||||
num_test_envs: int = 1
|
||||
"""the number of test environments to use"""
|
||||
|
||||
num_test_episodes: int = 1
|
||||
"""the total number of episodes to collect in each test step (across all test environments).
|
||||
This should be a multiple of the number of test environments; if it is not, the effective
|
||||
number of episodes collected will be the nearest multiple (rounded up).
|
||||
"""
|
||||
|
||||
buffer_size: int = 4096
|
||||
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
|
||||
stored"""
|
||||
@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
|
||||
def __post_init__(self) -> None:
|
||||
if self.num_train_envs == -1:
|
||||
self.num_train_envs = multiprocessing.cpu_count()
|
||||
|
||||
@property
|
||||
def num_test_episodes_per_test_env(self) -> int:
|
||||
""":return: the number of episodes to collect per test environment in every test step"""
|
||||
return math.ceil(self.num_test_episodes / self.num_test_envs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user