Change default number of environments (train=#CPUs, test=1)
This commit is contained in:
parent
3cd6dcc307
commit
d684dae6cd
@ -1,3 +1,4 @@
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
@ -11,8 +12,9 @@ class SamplingConfig(ToStringMixin):
|
||||
num_epochs: int = 100
|
||||
step_per_epoch: int = 30000
|
||||
batch_size: int = 64
|
||||
num_train_envs: int = 64
|
||||
num_test_envs: int = 10
|
||||
num_train_envs: int = -1
|
||||
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
|
||||
num_test_envs: int = 1
|
||||
buffer_size: int = 4096
|
||||
step_per_collect: int = 2048
|
||||
repeat_per_collect: int | None = 10
|
||||
@ -27,3 +29,7 @@ class SamplingConfig(ToStringMixin):
|
||||
replay_buffer_ignore_obs_next: bool = False
|
||||
replay_buffer_save_only_last_obs: bool = False
|
||||
replay_buffer_stack_num: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.num_train_envs == -1:
|
||||
self.num_train_envs = multiprocessing.cpu_count()
|
||||
|
Loading…
x
Reference in New Issue
Block a user