Allow configuring the number of test episodes in the high-level API
This commit is contained in:
parent
8742e3645c
commit
bf391853dc
@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
max_epoch=sampling_config.num_epochs,
|
max_epoch=sampling_config.num_epochs,
|
||||||
step_per_epoch=sampling_config.step_per_epoch,
|
step_per_epoch=sampling_config.step_per_epoch,
|
||||||
repeat_per_collect=sampling_config.repeat_per_collect,
|
repeat_per_collect=sampling_config.repeat_per_collect,
|
||||||
episode_per_test=sampling_config.num_test_envs,
|
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
step_per_collect=sampling_config.step_per_collect,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
max_epoch=sampling_config.num_epochs,
|
max_epoch=sampling_config.num_epochs,
|
||||||
step_per_epoch=sampling_config.step_per_epoch,
|
step_per_epoch=sampling_config.step_per_epoch,
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
step_per_collect=sampling_config.step_per_collect,
|
||||||
episode_per_test=sampling_config.num_test_envs,
|
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
logger=world.logger,
|
logger=world.logger,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
|
|||||||
|
|
||||||
* collects environment steps/transitions (collection step), adding them to the (replay)
|
* collects environment steps/transitions (collection step), adding them to the (replay)
|
||||||
buffer (see :attr:`step_per_collect`)
|
buffer (see :attr:`step_per_collect`)
|
||||||
* performs one or more gradient updates (see :attr:`update_per_step`).
|
* performs one or more gradient updates (see :attr:`update_per_step`),
|
||||||
|
|
||||||
|
and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate
|
||||||
|
agent performance.
|
||||||
|
|
||||||
The number of training steps in each epoch is indirectly determined by
|
The number of training steps in each epoch is indirectly determined by
|
||||||
:attr:`step_per_epoch`: As many training steps will be performed as are required in
|
:attr:`step_per_epoch`: As many training steps will be performed as are required in
|
||||||
@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
|
|||||||
num_test_envs: int = 1
|
num_test_envs: int = 1
|
||||||
"""the number of test environments to use"""
|
"""the number of test environments to use"""
|
||||||
|
|
||||||
|
num_test_episodes: int = 1
|
||||||
|
"""the total number of episodes to collect in each test step (across all test environments).
|
||||||
|
This should be a multiple of the number of test environments; if it is not, the effective
|
||||||
|
number of episodes collected will be the nearest multiple (rounded up).
|
||||||
|
"""
|
||||||
|
|
||||||
buffer_size: int = 4096
|
buffer_size: int = 4096
|
||||||
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
|
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
|
||||||
stored"""
|
stored"""
|
||||||
@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.num_train_envs == -1:
|
if self.num_train_envs == -1:
|
||||||
self.num_train_envs = multiprocessing.cpu_count()
|
self.num_train_envs = multiprocessing.cpu_count()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_test_episodes_per_test_env(self) -> int:
|
||||||
|
""":return: the number of episodes to collect per test environment in every test step"""
|
||||||
|
return math.ceil(self.num_test_episodes / self.num_test_envs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user