High-level API improvements (#1014)
- [X] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
- [X] I have provided a description of the changes in this Pull Request
- [X] I have added documentation for my changes
- [ ] If applicable, I have added tests to cover my changes.
- [X] I have reformatted the code using `poe format`
- [X] I have checked style and types with `poe lint` and `poe type-check`
- [ ] (Optional) I ran tests locally with `poe test` (or a subset of them with `poe test-reduced`), and they pass
- [X] (Optional) I have tested that documentation builds correctly with `poe doc-build`

Changes in this PR (see individual commits):

* Fix: SamplingConfig.start_timesteps_random was not used
* Environments: Support use of a different test environment factory in the convenience constructors `from_factory*`
* SamplingConfig: Improve/extend docstrings, clearly explaining the parameters
* SamplingConfig: Change the default of repeat_per_collect to 1
* Improve logging
* Fix doc-build on Windows
Commit 5d09645a2c
@@ -72,7 +72,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
     files_in_dir = os.listdir(src_root)
     module_names = [f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")]
     subdir_refs = [
-        os.path.join(f, "index")
+        f"{f}/index"
         for f in files_in_dir
         if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_")
     ]
@@ -108,7 +108,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
                 f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")
             ]
             subdir_refs = [
-                os.path.join(f, "index")
+                f"{f}/index"
                 for f in files_in_dir
                 if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_")
             ]
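For context on the doc-build fix: on Windows, `os.path.join` produces a backslash-separated path, which is presumably what broke the generated toctree references, whereas an explicit forward slash is portable. A small illustration (the module directory name is made up):

```python
import os

f = "data"  # hypothetical module directory name
print(os.path.join(f, "index"))  # 'data\\index' on Windows, 'data/index' elsewhere
print(f"{f}/index")              # 'data/index' on every platform
```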
@@ -23,7 +23,7 @@ class ReplayBuffer:
     :param size: the maximum size of replay buffer.
     :param stack_num: the frame-stack sampling argument, should be greater than or
         equal to 1. Default to 1 (no stacking).
-    :param ignore_obs_next: whether to store obs_next. Default to False.
+    :param ignore_obs_next: whether to not store obs_next. Default to False.
     :param save_only_last_obs: only save the last obs/obs_next when it has a shape
         of (timestep, ...) because of temporal stacking. Default to False.
     :param sample_avail: the parameter indicating sampling only available index
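The corrected wording matches what the flag actually does: when `ignore_obs_next` is set, `obs_next` is simply not stored. A minimal usage sketch (the buffer size is chosen arbitrarily):

```python
from tianshou.data import ReplayBuffer

# With ignore_obs_next=True the buffer does not store obs_next, saving memory
# when the next observation can be recovered from the following transition.
buf = ReplayBuffer(size=1_000, ignore_obs_next=True)
```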
@@ -115,7 +115,13 @@ class AgentFactory(ABC, ToStringMixin):
         train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
         test_collector = Collector(policy, envs.test_envs)
         if self.sampling_config.start_timesteps > 0:
-            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
+            log.info(
+                f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
+            )
+            train_collector.collect(
+                n_step=self.sampling_config.start_timesteps,
+                random=self.sampling_config.start_timesteps_random,
+            )
         return train_collector, test_collector

     def set_policy_wrapper_factory(
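With this change, the `random` argument of the warm-up collection follows the configuration instead of being hard-coded to True. A hedged sketch of how the relevant fields would be set (assuming `SamplingConfig` lives in `tianshou.highlevel.config`, as the next hunk suggests):

```python
from tianshou.highlevel.config import SamplingConfig

# Previously the warm-up collection always ran with random=True, so setting
# start_timesteps_random=False was silently ignored; it is now forwarded to
# Collector.collect.
sampling_config = SamplingConfig(
    start_timesteps=10_000,        # collect 10k transitions before training starts
    start_timesteps_random=False,  # ...using the policy being trained, not a random one
)
```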
@@ -6,29 +6,115 @@ from tianshou.utils.string import ToStringMixin

 @dataclass
 class SamplingConfig(ToStringMixin):
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+    """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""

     # TODO: What are the most reasonable defaults?
     num_epochs: int = 100
+    """
+    the number of epochs to run training for. An epoch is the outermost iteration level and each
+    epoch consists of a number of training steps and a test step, where each training step
+
+    * collects environment steps/transitions (collection step), adding them to the (replay)
+      buffer (see :attr:`step_per_collect`)
+    * performs one or more gradient updates (see :attr:`update_per_step`).
+
+    The number of training steps in each epoch is indirectly determined by
+    :attr:`step_per_epoch`: As many training steps will be performed as are required in
+    order to reach :attr:`step_per_epoch` total steps in the training environments.
+    Specifically, if the number of transitions collected per step is `c` (see
+    :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number
+    of training steps per epoch is `ceil(s / c)`.
+
+    Therefore, if `num_epochs = e`, the total number of environment steps taken during training
+    can be computed as `e * ceil(s / c) * c`.
+    """
+
     step_per_epoch: int = 30000
+    """
+    the total number of environment steps to be made per epoch. See :attr:`num_epochs` for
+    an explanation of epoch semantics.
+    """
+
     batch_size: int = 64
+    """for off-policy algorithms, this is the number of environment steps/transitions to sample
+    from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific.
+    On-policy algorithms use the full buffer that was collected in the preceding collection step
+    but they may use this parameter to perform the gradient update using mini-batches of this size
+    (causing the gradient to be less accurate, a form of regularization).
+    """
+
     num_train_envs: int = -1
     """the number of training environments to use. If set to -1, use number of CPUs/threads."""
+
     num_test_envs: int = 1
+    """the number of test environments to use"""
+
     buffer_size: int = 4096
+    """the total size of the sample/replay buffer, in which environment steps (transitions) are
+    stored"""
+
     step_per_collect: int = 2048
-    repeat_per_collect: int | None = 10
+    """
+    the number of environment steps/transitions to collect in each collection step before the
+    network update within each training step.
+    Note that the exact number can be reached only if this is a multiple of the number of
+    training environments being used, as each training environment will produce the same
+    (non-zero) number of transitions.
+    Specifically, if this is set to `n` and `m` training environments are used, then the total
+    number of transitions collected per collection step is `ceil(n / m) * m =: c`.
+
+    See :attr:`num_epochs` for information on the total number of environment steps being
+    collected during training.
+    """
+
+    repeat_per_collect: int | None = 1
+    """
+    controls, within one gradient update step of an on-policy algorithm, the number of times an
+    actual gradient update is applied using the full collected dataset, i.e. if the parameter is
+    `n`, then the collected data shall be used `n` times to update the policy within the same
+    training step.
+
+    The parameter is ignored and may be set to None for off-policy and offline algorithms.
+    """
+
     update_per_step: float = 1.0
     """
-    Only used in off-policy algorithms.
-    How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
+    for off-policy algorithms only: the number of gradient steps to perform per sample
+    collected (see :attr:`step_per_collect`).
+    Specifically, if this is set to `u` and the number of samples collected in the preceding
+    collection step is `n`, then `round(u * n)` gradient steps will be performed.
+
+    Note that for on-policy algorithms, only a single gradient update is usually performed,
+    because thereafter, the samples no longer reflect the behavior of the updated policy.
+    To change the number of gradient updates for an on-policy algorithm, use parameter
+    :attr:`repeat_per_collect` instead.
     """
+
     start_timesteps: int = 0
+    """
+    the number of environment steps to collect before the actual training loop begins
+    """
+
     start_timesteps_random: bool = False
-    # TODO can we set the parameters below intelligently? Perhaps based on env. representation?
+    """
+    whether to use a random policy (instead of the initial or restored policy to be trained)
+    when collecting the initial :attr:`start_timesteps` environment steps before training
+    """
+
     replay_buffer_ignore_obs_next: bool = False
+
     replay_buffer_save_only_last_obs: bool = False
+    """if True, only the most recent frame is saved when appending to experiences rather than the
+    full stacked frames. This avoids duplicating observations in buffer memory. Set to False to
+    save stacked frames in full.
+    """
+
     replay_buffer_stack_num: int = 1
+    """
+    the number of consecutive environment observations to stack and use as the observation input
+    to the agent for each time step. Setting this to a value greater than 1 can help agents learn
+    temporal aspects (e.g. velocities of moving objects for which only positions are observed).
+    """

     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
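The arithmetic spelled out in these docstrings (training steps per epoch, total environment steps, and off-policy gradient steps) can be checked with a short sketch. The numbers below are hypothetical, chosen only to make the formulas concrete; they are not values this PR recommends:

```python
import math

# Hypothetical configuration values, used only to illustrate the docstring formulas.
num_epochs = 100          # e
step_per_epoch = 30_000   # s
step_per_collect = 2_048  # n
num_train_envs = 10       # m
update_per_step = 1.0     # u (off-policy only)

# Transitions actually collected per collection step: ceil(n / m) * m
c = math.ceil(step_per_collect / num_train_envs) * num_train_envs  # 2050

# Training steps per epoch: ceil(s / c)
train_steps_per_epoch = math.ceil(step_per_epoch / c)  # 15

# Total environment steps over the whole run: e * ceil(s / c) * c
total_env_steps = num_epochs * train_steps_per_epoch * c  # 3,075,000

# Off-policy gradient steps per collection step: round(u * collected samples)
gradient_steps_per_collect = round(update_per_step * c)  # 2050
```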
@@ -80,6 +80,7 @@ class Environments(ToStringMixin, ABC):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "Environments":
         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).

@@ -88,10 +89,14 @@ class Environments(ToStringMixin, ABC):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param test_factory_fn: the factory to use for the creation of test environment instances;
+            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
+        if test_factory_fn is None:
+            test_factory_fn = factory_fn
         train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
-        test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
+        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
         env = factory_fn()
         match env_type:
             case EnvType.CONTINUOUS:
@@ -152,6 +157,7 @@ class ContinuousEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "ContinuousEnvironments":
         """Creates an instance from a factory function that creates a single instance.

@@ -159,6 +165,8 @@ class ContinuousEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param test_factory_fn: the factory to use for the creation of test environment instances;
+            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -169,6 +177,7 @@ class ContinuousEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                test_factory_fn=test_factory_fn,
             ),
         )

@@ -217,6 +226,7 @@ class DiscreteEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "DiscreteEnvironments":
         """Creates an instance from a factory function that creates a single instance.

@@ -224,6 +234,8 @@ class DiscreteEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param test_factory_fn: the factory to use for the creation of test environment instances;
+            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -234,6 +246,7 @@ class DiscreteEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                test_factory_fn=test_factory_fn,
             ),
         )

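A hedged usage sketch of the new parameter, based only on the signature visible in this diff; the environment ids, the keyword arguments, and `VectorEnvType.DUMMY` are assumptions for illustration, not something this PR prescribes:

```python
import gymnasium as gym

from tianshou.highlevel.env import ContinuousEnvironments, VectorEnvType


def make_train_env() -> gym.Env:
    return gym.make("Pendulum-v1")


def make_test_env() -> gym.Env:
    # Evaluate on a slightly different variant of the task than the one trained on.
    return gym.make("Pendulum-v1", g=9.81)


envs = ContinuousEnvironments.from_factory(
    make_train_env,
    VectorEnvType.DUMMY,
    num_training_envs=4,
    num_test_envs=2,
    test_factory_fn=make_test_env,  # new in this PR; when omitted, make_train_env is also used for the test envs
)
```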
@@ -251,7 +251,9 @@ class Experiment(ToStringMixin):
             logger = LazyLogger()

         # create policy and collectors
+        log.info("Creating policy")
         policy = self.agent_factory.create_policy(envs, self.config.device)
+        log.info("Creating collectors")
         train_collector, test_collector = self.agent_factory.create_train_test_collector(
             policy,
             envs,
@@ -277,15 +279,17 @@ class Experiment(ToStringMixin):
         )

         # train policy
+        log.info("Starting training")
         trainer_result: dict[str, Any] | None = None
         if self.config.train:
             trainer = self.agent_factory.create_trainer(world, policy_persistence)
             world.trainer = trainer
             trainer_result = trainer.run()
-            log.info(f"Trainer result:\n{pformat(trainer_result)}")
+            log.info(f"Training result:\n{pformat(trainer_result)}")

         # watch agent performance
         if self.config.watch:
+            log.info("Watching agent performance")
             self._watch_agent(
                 self.config.watch_num_episodes,
                 policy,