High-level API improvements (#1014)

- [X] I have added the correct label(s) to this Pull Request or linked
the relevant issue(s)
- [X] I have provided a description of the changes in this Pull Request
- [X] I have added documentation for my changes
- [ ] If applicable, I have added tests to cover my changes.
- [X] I have reformatted the code using `poe format` 
- [X] I have checked style and types with `poe lint` and `poe
type-check`
- [ ] (Optional) I ran tests locally with `poe test` 
(or a subset of them with `poe test-reduced`), and they pass
- [X] (Optional) I have tested that documentation builds correctly with
`poe doc-build`

Changes in this PR (see individual commits):
* Fix: SamplingConfig.start_timesteps_random was not used
* Environments: Support use of a different test environment factory in the
convenience constructors `from_factory*` (usage sketch below)
* SamplingConfig: Improve/extend docstrings, clearly explaining the
parameters
* SamplingConfig: Change default of repeat_per_collect to 1
* Improve logging
* Fix doc-build on Windows
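
To illustrate the new `test_factory_fn` parameter of the `from_factory*` constructors (see the `Environments` diff below), here is a minimal usage sketch. The environment id, the `VectorEnvType.DUMMY` member, and the import paths are assumptions for illustration; the parameter names follow the diff.

```python
import gymnasium as gym

from tianshou.highlevel.env import DiscreteEnvironments, VectorEnvType

envs = DiscreteEnvironments.from_factory(
    # factory used for the training environments
    lambda: gym.make("CartPole-v1"),
    venv_type=VectorEnvType.DUMMY,
    num_training_envs=4,
    num_test_envs=2,
    # new in this PR: test environments may come from a different factory,
    # e.g. with rendering or different wrappers enabled
    test_factory_fn=lambda: gym.make("CartPole-v1", render_mode="rgb_array"),
)
```

If `test_factory_fn` is omitted (or `None`), `factory_fn` is used for both training and test environments, so existing code is unaffected.
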
Commit 5d09645a2c by Michael Panchenko, 2023-12-21 10:04:14 -06:00 (committed via GitHub)
6 changed files with 120 additions and 11 deletions

@@ -72,7 +72,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
files_in_dir = os.listdir(src_root)
module_names = [f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")]
subdir_refs = [
os.path.join(f, "index")
f"{f}/index"
for f in files_in_dir
if os.path.isdir(os.path.join(src_root, f)) and not f.startswith("_")
]
@@ -108,7 +108,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
f[:-3] for f in files_in_dir if f.endswith(".py") and not f.startswith("_")
]
subdir_refs = [
os.path.join(f, "index")
f"{f}/index"
for f in files_in_dir
if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_")
]

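The two hunks above replace `os.path.join(f, "index")` with `f"{f}/index"` when building the toctree references for generated RST files. The reason this fixes the Windows doc-build (my reading of the change, not stated in the diff): `os.path.join` uses the platform separator, so on Windows the reference contains a backslash, which Sphinx does not resolve; document references must always use forward slashes. A quick illustration:

```python
import os

f = "subpackage"          # illustrative directory name
os.path.join(f, "index")  # 'subpackage\\index' on Windows, 'subpackage/index' elsewhere
f"{f}/index"              # 'subpackage/index' on every platform, as Sphinx expects
```
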
@@ -23,7 +23,7 @@ class ReplayBuffer:
:param size: the maximum size of replay buffer.
:param stack_num: the frame-stack sampling argument, should be greater than or
equal to 1. Default to 1 (no stacking).
:param ignore_obs_next: whether to store obs_next. Default to False.
:param ignore_obs_next: whether to not store obs_next. Default to False.
:param save_only_last_obs: only save the last obs/obs_next when it has a shape
of (timestep, ...) because of temporal stacking. Default to False.
:param sample_avail: the parameter indicating sampling only available index

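The corrected docstring spells out the double negative: `ignore_obs_next=True` means `obs_next` is *not* stored when transitions are added, which saves memory for large observations. A minimal sketch with an illustrative buffer size:

```python
from tianshou.data import ReplayBuffer

# obs_next is not stored when adding transitions, roughly halving observation memory
buf = ReplayBuffer(size=100_000, ignore_obs_next=True)

# default: obs_next is stored alongside obs
buf_default = ReplayBuffer(size=100_000)
```
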
@@ -115,7 +115,13 @@ class AgentFactory(ABC, ToStringMixin):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
if self.sampling_config.start_timesteps > 0:
train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
train_collector.collect(
n_step=self.sampling_config.start_timesteps,
random=self.sampling_config.start_timesteps_random,
)
return train_collector, test_collector
def set_policy_wrapper_factory(

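The hunk above implements the first bullet of the description: previously the initial collection passed `random=True` unconditionally, so `SamplingConfig.start_timesteps_random` had no effect; now the configured value is used and the collection is logged. A minimal sketch of the user-facing configuration (import path assumed, values illustrative):

```python
from tianshou.highlevel.config import SamplingConfig

sampling_config = SamplingConfig(
    start_timesteps=10_000,       # collect 10k environment steps before training begins
    start_timesteps_random=True,  # collect them with a random policy; False uses the initial/restored policy
)
```
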
@@ -6,29 +6,115 @@ from tianshou.utils.string import ToStringMixin
@dataclass
class SamplingConfig(ToStringMixin):
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are the most reasonable defaults?
num_epochs: int = 100
"""
the number of epochs to run training for. An epoch is the outermost iteration level and each
epoch consists of a number of training steps and a test step, where each training step
* collects environment steps/transitions (collection step), adding them to the (replay)
buffer (see :attr:`step_per_collect`)
* performs one or more gradient updates (see :attr:`update_per_step`).
The number of training steps in each epoch is indirectly determined by
:attr:`step_per_epoch`: As many training steps will be performed as are required in
order to reach :attr:`step_per_epoch` total steps in the training environments.
Specifically, if the number of transitions collected per step is `c` (see
:attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number
of training steps per epoch is `ceil(s / c)`.
Therefore, if `num_epochs = e`, the total number of environment steps taken during training
can be computed as `e * ceil(s / c) * c`.
"""
step_per_epoch: int = 30000
"""
the total number of environment steps to be made per epoch. See :attr:`num_epochs` for
an explanation of epoch semantics.
"""
batch_size: int = 64
"""for off-policy algorithms, this is the number of environment steps/transitions to sample
from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific.
On-policy algorithms use the full buffer that was collected in the preceding collection step
but they may use this parameter to perform the gradient update using mini-batches of this size
(causing the gradient to be less accurate, a form of regularization).
"""
num_train_envs: int = -1
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
num_test_envs: int = 1
"""the number of test environments to use"""
buffer_size: int = 4096
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
stored"""
step_per_collect: int = 2048
repeat_per_collect: int | None = 10
"""
the number of environment steps/transitions to collect in each collection step before the
network update within each training step.
Note that the exact number can be reached only if this is a multiple of the number of
training environments being used, as each training environment will produce the same
(non-zero) number of transitions.
Specifically, if this is set to `n` and `m` training environments are used, then the total
number of transitions collected per collection step is `ceil(n / m) * m =: c`.
See :attr:`num_epochs` for information on the total number of environment steps being
collected during training.
"""
repeat_per_collect: int | None = 1
"""
controls, within one gradient update step of an on-policy algorithm, the number of times an
actual gradient update is applied using the full collected dataset, i.e. if the parameter is
`n`, then the collected data shall be used `n` times to update the policy within the same
training step.
The parameter is ignored and may be set to None for off-policy and offline algorithms.
"""
update_per_step: float = 1.0
"""
Only used in off-policy algorithms.
How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
for off-policy algorithms only: the number of gradient steps to perform per sample
collected (see :attr:`step_per_collect`).
Specifically, if this is set to `u` and the number of samples collected in the preceding
collection step is `n`, then `round(u * n)` gradient steps will be performed.
Note that for on-policy algorithms, only a single gradient update is usually performed,
because thereafter, the samples no longer reflect the behavior of the updated policy.
To change the number of gradient updates for an on-policy algorithm, use parameter
:attr:`repeat_per_collect` instead.
"""
start_timesteps: int = 0
"""
the number of environment steps to collect before the actual training loop begins
"""
start_timesteps_random: bool = False
# TODO can we set the parameters below intelligently? Perhaps based on env. representation?
"""
whether to use a random policy (instead of the initial or restored policy to be trained)
when collecting the initial :attr:`start_timesteps` environment steps before training
"""
replay_buffer_ignore_obs_next: bool = False
replay_buffer_save_only_last_obs: bool = False
"""if True, only the most recent frame is saved when appending to experiences rather than the
full stacked frames. This avoids duplicating observations in buffer memory. Set to False to
save stacked frames in full.
"""
replay_buffer_stack_num: int = 1
"""
the number of consecutive environment observations to stack and use as the observation input
to the agent for each time step. Setting this to a value greater than 1 can help agents learn
temporal aspects (e.g. velocities of moving objects for which only positions are observed).
"""
def __post_init__(self) -> None:
if self.num_train_envs == -1:

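To make the arithmetic in the new docstrings concrete, here is a worked example using the defaults shown above (`step_per_epoch=30000`, `step_per_collect=2048`) and an assumed 8 training environments (the default `num_train_envs=-1` resolves to the number of CPUs):

```python
import math

num_epochs = 100         # e
step_per_epoch = 30_000  # s
step_per_collect = 2048  # n
num_train_envs = 8       # m, illustrative

# every training env contributes the same number of transitions per collection step
c = math.ceil(step_per_collect / num_train_envs) * num_train_envs  # 256 * 8 = 2048

# training steps per epoch: as many collection steps as needed to reach step_per_epoch
train_steps_per_epoch = math.ceil(step_per_epoch / c)  # ceil(30000 / 2048) = 15

# total environment steps over the whole run: e * ceil(s / c) * c
total_env_steps = num_epochs * train_steps_per_epoch * c  # 100 * 15 * 2048 = 3,072,000

# off-policy, with update_per_step = 1.0: each collection step of c transitions triggers
# round(1.0 * c) = 2048 gradient steps; on-policy algorithms instead reuse the whole
# collected batch repeat_per_collect times (new default: 1).
```
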
@@ -80,6 +80,7 @@ class Environments(ToStringMixin, ABC):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
@@ -88,10 +89,14 @@ class Environments(ToStringMixin, ABC):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
if test_factory_fn is None:
test_factory_fn = factory_fn
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
env = factory_fn()
match env_type:
case EnvType.CONTINUOUS:
@@ -152,6 +157,7 @@ class ContinuousEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "ContinuousEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@@ -159,6 +165,8 @@ class ContinuousEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
@@ -169,6 +177,7 @@ class ContinuousEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)
@@ -217,6 +226,7 @@ class DiscreteEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "DiscreteEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@@ -224,6 +234,8 @@ class DiscreteEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
@@ -234,6 +246,7 @@ class DiscreteEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)

@@ -251,7 +251,9 @@ class Experiment(ToStringMixin):
logger = LazyLogger()
# create policy and collectors
log.info("Creating policy")
policy = self.agent_factory.create_policy(envs, self.config.device)
log.info("Creating collectors")
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
@@ -277,15 +279,17 @@
)
# train policy
log.info("Starting training")
trainer_result: dict[str, Any] | None = None
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
trainer_result = trainer.run()
log.info(f"Trainer result:\n{pformat(trainer_result)}")
log.info(f"Training result:\n{pformat(trainer_result)}")
# watch agent performance
if self.config.watch:
log.info("Watching agent performance")
self._watch_agent(
self.config.watch_num_episodes,
policy,