#!/usr/bin/env python3
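"""Run PPO on an Atari task via Tianshou's high-level API, optionally wrapping the
policy with an Intrinsic Curiosity Module (ICM) for exploration."""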

import datetime
import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.atari.atari_network import (
    ActorFactoryAtariDQN,
    IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    PPOExperimentBuilder,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: bool = True,
    buffer_size: int = 100000,
    lr: float = 2.5e-4,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 1000,
    repeat_per_collect: int = 4,
    batch_size: int = 256,
    hidden_sizes: int | Sequence[int] = 512,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = False,
    vf_coef: float = 0.25,
    ent_coef: float = 0.01,
    gae_lambda: float = 0.95,
    lr_decay: bool = True,
    max_grad_norm: float = 0.5,
    eps_clip: float = 0.1,
    dual_clip: float | None = None,
    value_clip: bool = True,
    norm_adv: bool = True,
    recompute_adv: bool = False,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO add support in high-level API?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
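
    # Training schedule and replay-buffer options: frame stacking is done by the
    # buffer (replay_buffer_stack_num), which stores only the last frame of each
    # observation and skips obs_next to save memory.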
    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )
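
    # The factory only needs task, seed, and frame stacking; the number of
    # train/test envs comes from sampling_config when the experiment is built.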
    env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
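
    # Assemble the experiment: PPO hyperparameters, a DQN-style CNN actor whose
    # feature extractor is shared with the critic, and an Atari-specific stop callback.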
    builder = (
        PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_ppo_params(
            PPOParams(
                discount_factor=gamma,
                gae_lambda=gae_lambda,
                reward_normalization=rew_norm,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                value_clip=value_clip,
                advantage_normalization=norm_adv,
                eps_clip=eps_clip,
                dual_clip=dual_clip,
                recompute_advantage=recompute_adv,
                lr=lr,
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
            ),
        )
        .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
        .with_critic_factory_use_actor()
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
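    # Optionally wrap the policy with an Intrinsic Curiosity Module, which adds an
    # intrinsic exploration reward; disabled by default (icm_lr_scale == 0).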
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                IntermediateModuleFactoryAtariDQNFeatures(),
                [hidden_sizes],
                lr,
                icm_lr_scale,
                icm_reward_scale,
                icm_forward_loss_weight,
            ),
        )
    experiment = builder.build()
    experiment.run(log_name)
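

# Build a CLI from main's signature via jsonargparse and run it under
# Tianshou's logging setup.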
if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))