Change default number of environments (train=#CPUs, test=1)
This commit is contained in:
parent
3cd6dcc307
commit
d684dae6cd
@ -1,3 +1,4 @@
|
|||||||
|
import multiprocessing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
@ -11,8 +12,9 @@ class SamplingConfig(ToStringMixin):
|
|||||||
num_epochs: int = 100
|
num_epochs: int = 100
|
||||||
step_per_epoch: int = 30000
|
step_per_epoch: int = 30000
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
num_train_envs: int = 64
|
num_train_envs: int = -1
|
||||||
num_test_envs: int = 10
|
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
|
||||||
|
num_test_envs: int = 1
|
||||||
buffer_size: int = 4096
|
buffer_size: int = 4096
|
||||||
step_per_collect: int = 2048
|
step_per_collect: int = 2048
|
||||||
repeat_per_collect: int | None = 10
|
repeat_per_collect: int | None = 10
|
||||||
@ -27,3 +29,7 @@ class SamplingConfig(ToStringMixin):
|
|||||||
replay_buffer_ignore_obs_next: bool = False
|
replay_buffer_ignore_obs_next: bool = False
|
||||||
replay_buffer_save_only_last_obs: bool = False
|
replay_buffer_save_only_last_obs: bool = False
|
||||||
replay_buffer_stack_num: int = 1
|
replay_buffer_stack_num: int = 1
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.num_train_envs == -1:
|
||||||
|
self.num_train_envs = multiprocessing.cpu_count()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user