Trainer refactor: some definition change (#293)
This PR focuses on some definition changes in the trainer to make it friendlier to use and consistent with typical usage in research papers: `collect-per-step` is renamed to `step-per-collect`, `update-per-step` / `episode-per-collect` are added accordingly, and the documentation is updated.
Parent: 150d0ec51b
Commit: 7036073649
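At a glance, the renaming looks like this for the off-policy trainer, pieced together from the diffs below (a sketch rather than a verbatim excerpt; the CartPole quickstart diff is the authoritative version):

```python
# Old call: the sixth positional argument was `collect_per_step`.
# result = ts.trainer.offpolicy_trainer(
#     policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
#     test_num, batch_size, ...)

# New call: `step_per_collect` counts environment steps gathered per collect,
# and `update_per_step` sets how many gradient updates follow each collected step.
result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect,
    test_num, batch_size, update_per_step=1 / step_per_collect,
)
# On-policy trainers gain `episode_per_collect` for collecting whole episodes instead.
```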
@@ -191,11 +191,11 @@ Define some hyper-parameters:
 ```python
 task = 'CartPole-v0'
 lr, epoch, batch_size = 1e-3, 10, 64
-train_num, test_num = 8, 100
+train_num, test_num = 10, 100
 gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
-step_per_epoch, collect_per_step = 1000, 8
+step_per_epoch, step_per_collect = 10000, 10
 writer = SummaryWriter('log/dqn') # tensorboard is also supported!
 ```

@@ -232,8 +232,8 @@ Let's train it:

 ```python
 result = ts.trainer.offpolicy_trainer(
-    policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
-    test_num, batch_size,
+    policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect,
+    test_num, batch_size, update_per_step=1 / step_per_collect,
     train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
     test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
     stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
@@ -284,7 +284,7 @@ policy.process_fn

 The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns.

-Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as:
+Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as:

 .. math::

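(The body of the ``.. math::`` block falls outside this hunk. For context only, the textbook 2-step target that the sentence refers to is, in my own notation,

    G_t = r_t + \gamma r_{t+1} + \gamma^2 \max_a Q_{\mathrm{target}}(s_{t+2}, a),

i.e. two real rewards followed by a bootstrapped value of the state two steps ahead.)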
@@ -35,10 +35,10 @@ If you want to use the original ``gym.Env``:
 Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`)
 ::

-    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
+    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
     test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

-Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
+Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``.

 For the demonstration, here we use the second code-block.

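As an aside, the other three wrappers listed above take the same list-of-factories constructor, so switching to subprocess-based sampling is a one-line change (my example, not part of the diff):

```python
# Same interface as DummyVectorEnv, but each environment runs in its own worker process.
train_envs = ts.env.SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(10)])
```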
@@ -87,7 +87,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour
     net = Net(state_shape, action_shape)
     optim = torch.optim.Adam(net.parameters(), lr=1e-3)

-It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:
+You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:

 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
 2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy.
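To make rules 1 and 2 concrete, a minimal self-defined network that satisfies them could look like the following sketch (my illustration of the stated rules, not code from this commit):

```python
import numpy as np
import torch
from torch import nn

class MyNet(nn.Module):
    """Rule 1: forward(obs, state, info). Rule 2: return (logits, state)."""

    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), 128), nn.ReLU(inplace=True),
            nn.Linear(128, int(np.prod(action_shape))),
        )

    def forward(self, obs, state=None, info={}):
        # obs may arrive as a numpy array; convert it before the forward pass.
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, dtype=torch.float32)
        logits = self.model(obs.reshape(obs.shape[0], -1))
        return logits, state  # no RNN here, so the hidden state passes through unchanged
```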
@@ -113,7 +113,7 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit
 In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
 ::

-    train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True)
+    train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
     test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)


@@ -125,8 +125,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t

     result = ts.trainer.offpolicy_trainer(
         policy, train_collector, test_collector,
-        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
-        episode_per_test=100, batch_size=64,
+        max_epoch=10, step_per_epoch=10000, step_per_collect=10,
+        update_per_step=0.1, episode_per_test=100, batch_size=64,
         train_fn=lambda epoch, env_step: policy.set_eps(0.1),
         test_fn=lambda epoch, env_step: policy.set_eps(0.05),
         stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
@@ -136,8 +136,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t
 The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):

 * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
-* ``step_per_epoch``: The number of step for updating policy network in one epoch;
-* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
+* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch;
+* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update";
 * ``episode_per_test``: The number of episodes for one policy evaluation.
 * ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
 * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
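A quick worked example of how the new knobs interact, as I read the renamed parameters: with ``step_per_collect=10`` and ``update_per_step=0.1``, every collect of 10 transitions is followed by 0.1 × 10 = 1 gradient update, which reproduces the old "collect 10 frames and do one policy network update" behaviour; writing ``update_per_step=1 / step_per_collect`` (as in the quickstart diff above) expresses exactly that ratio.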
@@ -205,7 +205,7 @@ Train a Policy with Customized Codes
 Tianshou supports user-defined training code. Here is the code snippet:
 ::

-    # pre-collect at least 5000 frames with random action before training
+    # pre-collect at least 5000 transitions with random action before training
     train_collector.collect(n_step=5000, random=True)

     policy.set_eps(0.1)
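The hunk cuts the snippet off here; under the new definitions the rest of such a hand-written loop would typically continue along these lines (a sketch with illustrative numbers, not text from the commit):

```python
for i in range(10000):  # arbitrary iteration budget
    train_collector.collect(n_step=10)  # one collect of step_per_collect transitions
    # roughly update_per_step=0.1: one 64-sample gradient update per 10 collected steps
    losses = policy.update(64, train_collector.buffer)
```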
@@ -200,8 +200,9 @@ The explanation of each Tianshou class/function will be deferred to their first
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=500)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=5000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128, 128, 128])
@@ -293,7 +294,7 @@ With the above preparation, we are close to the first learned agent. The followi
     policy.policies[args.agent_id - 1].set_eps(args.eps_test)
     collector = Collector(policy, env)
     result = collector.collect(n_episode=1, render=args.render)
-    print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}')
+    print(f'Final reward: {result["rews"][:, args.agent_id - 1].mean()}, length: {result["lens"].mean()}')
     if args.watch:
         watch(args)
         exit(0)
@@ -355,10 +356,10 @@ With the above preparation, we are close to the first learned agent. The followi
     # start training, this may require about three minutes
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric,
-        writer=writer, test_in_train=False)
+        stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
+        writer=writer, test_in_train=False, reward_metric=reward_metric)

     agent = policy.policies[args.agent_id - 1]
     # let's watch the match!
@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
     parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
     parser.add_argument("--epoch", type=int, default=100)
-    parser.add_argument("--step-per-epoch", type=int, default=10000)
+    parser.add_argument("--update-per-epoch", type=int, default=10000)
     parser.add_argument("--batch-size", type=int, default=32)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[512])
@@ -140,7 +140,7 @@ def test_discrete_bcq(args=get_args()):

     result = offline_trainer(
         policy, buffer, test_collector,
-        args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
+        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
         stop_fn=stop_fn, save_fn=save_fn, writer=writer,
         log_interval=args.log_interval,
     )
@@ -30,8 +30,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=500)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=10000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=10)
@@ -141,9 +142,10 @@ def test_c51(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        update_per_step=args.update_per_step, test_in_train=False)

     pprint.pprint(result)
     watch()
@@ -27,8 +27,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=500)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=10000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=10)
@@ -151,9 +152,10 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        update_per_step=args.update_per_step, test_in_train=False)

     pprint.pprint(result)
     watch()
@@ -28,8 +28,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=500)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=10000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=10)
@@ -139,9 +140,10 @@ def test_qrdqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        update_per_step=args.update_per_step, test_in_train=False)

     pprint.pprint(result)
     watch()
@@ -4,7 +4,6 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
@@ -24,7 +23,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--episode-per-collect', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
@@ -91,8 +90,8 @@ def test_a2c(args=get_args()):
     # trainer
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
@@ -24,7 +24,7 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--episode-per-collect', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
@@ -95,8 +95,8 @@ def test_ppo(args=get_args()):
     # trainer
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
@@ -25,8 +25,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=100)
+    parser.add_argument('--update-per-step', type=float, default=0.01)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128])
     parser.add_argument('--dueling-q-hidden-sizes', type=int,
@@ -103,8 +104,8 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, save_fn=save_fn, writer=writer)

     assert stop_fn(result['best_reward'])
@@ -27,8 +27,9 @@ def get_args():
     parser.add_argument('--auto-alpha', type=int, default=1)
     parser.add_argument('--alpha-lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=10000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=100000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -143,9 +144,9 @@ def test_sac_bipedal(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
-        test_in_train=False)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, test_in_train=False,
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

     if __name__ == '__main__':
         pprint.pprint(result)
@@ -26,8 +26,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=500)
     parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=5000)
-    parser.add_argument('--collect-per-step', type=int, default=16)
+    parser.add_argument('--step-per-epoch', type=int, default=80000)
+    parser.add_argument('--step-per-collect', type=int, default=16)
+    parser.add_argument('--update-per-step', type=float, default=0.0625)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -99,10 +100,9 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
-        test_in_train=False)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
+        test_fn=test_fn, save_fn=save_fn, writer=writer)

     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -29,8 +29,9 @@ def get_args():
     parser.add_argument('--auto_alpha', type=int, default=1)
     parser.add_argument('--alpha', type=float, default=0.2)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=5)
+    parser.add_argument('--step-per-epoch', type=int, default=12000)
+    parser.add_argument('--step-per-collect', type=int, default=5)
+    parser.add_argument('--update-per-step', type=float, default=0.2)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -112,8 +113,10 @@ def test_sac(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, stop_fn=stop_fn,
+        save_fn=save_fn, writer=writer)

     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -28,8 +28,9 @@ def get_args():
     parser.add_argument('--alpha-lr', type=float, default=3e-4)
     parser.add_argument('--n-step', type=int, default=2)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=10000)
-    parser.add_argument('--collect-per-step', type=int, default=4)
+    parser.add_argument('--step-per-epoch', type=int, default=40000)
+    parser.add_argument('--step-per-collect', type=int, default=4)
+    parser.add_argument('--update-per-step', type=float, default=0.25)
     parser.add_argument('--update-per-step', type=int, default=1)
     parser.add_argument('--pre-collect-step', type=int, default=10000)
     parser.add_argument('--batch-size', type=int, default=256)
@@ -139,10 +140,9 @@ def test_sac(args=get_args()):
     train_collector.collect(n_step=args.pre_collect_step, random=True)
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, args.update_per_step,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
-        log_interval=args.log_interval)
+        args.step_per_epoch, args.step_per_collect, args.test_num,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        update_per_step=args.update_per_step, log_interval=args.log_interval)
     pprint.pprint(result)
     watch()

@@ -26,7 +26,7 @@ def get_args():
     parser.add_argument('--exploration-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=4)
+    parser.add_argument('--step-per-collect', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -87,7 +87,7 @@ def test_ddpg(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--update-actor-freq', type=int, default=2)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -96,7 +96,7 @@ def test_td3(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument('--alpha', type=float, default=0.2)
     parser.add_argument('--epoch', type=int, default=200)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -97,7 +97,7 @@ def test_sac(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn,
         writer=writer, log_interval=args.log_interval)
     assert stop_fn(result['best_reward'])
@@ -31,7 +31,7 @@ def get_args():
     parser.add_argument('--update-actor-freq', type=int, default=2)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -104,7 +104,7 @@ def test_td3(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -26,8 +26,9 @@ def get_args():
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=4)
+    parser.add_argument('--step-per-epoch', type=int, default=9600)
+    parser.add_argument('--step-per-collect', type=int, default=4)
+    parser.add_argument('--update-per-step', type=float, default=0.25)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -102,8 +103,9 @@ def test_ddpg(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, stop_fn=stop_fn,
+        save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -23,8 +23,8 @@ def get_args():
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=16)
+    parser.add_argument('--step-per-epoch', type=int, default=150000)
+    parser.add_argument('--episode-per-collect', type=int, default=16)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
@@ -121,8 +121,8 @@ def test_ppo(args=get_args()):
     # trainer
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
         writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -26,8 +26,10 @@ def get_args():
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=24000)
+    parser.add_argument('--il-step-per-epoch', type=int, default=500)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -110,8 +112,9 @@ def test_sac_with_il(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, stop_fn=stop_fn,
+        save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -142,7 +145,7 @@ def test_sac_with_il(args=get_args()):
     train_collector.reset()
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
-        args.step_per_epoch // 5, args.collect_per_step, args.test_num,
+        args.il_step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -29,8 +29,9 @@ def get_args():
     parser.add_argument('--noise-clip', type=float, default=0.5)
    parser.add_argument('--update-actor-freq', type=int, default=2)
     parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=20000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128])
@@ -115,8 +116,9 @@ def test_td3(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
+        update_per_step=args.update_per_step, stop_fn=stop_fn,
+        save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -23,8 +23,11 @@ def get_args():
     parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=8)
+    parser.add_argument('--step-per-epoch', type=int, default=50000)
+    parser.add_argument('--il-step-per-epoch', type=int, default=1000)
+    parser.add_argument('--episode-per-collect', type=int, default=8)
+    parser.add_argument('--step-per-collect', type=int, default=8)
+    parser.add_argument('--update-per-step', type=float, default=0.125)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
@@ -96,8 +99,8 @@ def test_a2c_with_il(args=get_args()):
     # trainer
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
         writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -121,13 +124,12 @@ def test_a2c_with_il(args=get_args()):
     il_policy = ImitationPolicy(net, optim, mode='discrete')
     il_test_collector = Collector(
         il_policy,
-        DummyVectorEnv(
-            [lambda: gym.make(args.task) for _ in range(args.test_num)])
+        DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
     )
     train_collector.reset()
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.il_step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -28,8 +28,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=8)
+    parser.add_argument('--step-per-epoch', type=int, default=8000)
+    parser.add_argument('--step-per-collect', type=int, default=8)
+    parser.add_argument('--update-per-step', type=float, default=0.125)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128, 128, 128])
@@ -112,7 +113,7 @@ def test_c51(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, save_fn=save_fn, writer=writer)

@@ -25,9 +25,10 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--epoch', type=int, default=20)
+    parser.add_argument('--step-per-epoch', type=int, default=10000)
+    parser.add_argument('--step-per-collect', type=int, default=10)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
         nargs='*', default=[128, 128, 128, 128])
@@ -114,9 +115,9 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num,
+        args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer)

     assert stop_fn(result['best_reward'])

@@ -26,8 +26,9 @@ def get_args():
     parser.add_argument('--n-step', type=int, default=4)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=10)
-    parser.add_argument('--step-per-epoch', type=int, default=1000)
-    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=10000)
+    parser.add_argument('--update-per-step', type=float, default=0.1)
+    parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=10)
@@ -92,9 +93,10 @@ def test_drqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.step_per_epoch, args.step_per_collect, args.test_num,
+        args.batch_size, update_per_step=args.update_per_step,
+        train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
+        save_fn=save_fn, writer=writer)

     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@ -26,7 +26,7 @@ def get_args():
|
|||||||
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
|
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
|
||||||
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
|
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
|
||||||
parser.add_argument("--epoch", type=int, default=5)
|
parser.add_argument("--epoch", type=int, default=5)
|
||||||
parser.add_argument("--step-per-epoch", type=int, default=1000)
|
parser.add_argument("--update-per-epoch", type=int, default=1000)
|
||||||
parser.add_argument("--batch-size", type=int, default=64)
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
nargs='*', default=[128, 128, 128])
|
nargs='*', default=[128, 128, 128])
|
||||||
@ -91,7 +91,7 @@ def test_discrete_bcq(args=get_args()):
|
|||||||
|
|
||||||
result = offline_trainer(
|
result = offline_trainer(
|
||||||
policy, buffer, test_collector,
|
policy, buffer, test_collector,
|
||||||
args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
|
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
|
||||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||||
|
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
|
|||||||
@ -21,8 +21,8 @@ def get_args():
|
|||||||
parser.add_argument('--lr', type=float, default=1e-3)
|
parser.add_argument('--lr', type=float, default=1e-3)
|
||||||
parser.add_argument('--gamma', type=float, default=0.95)
|
parser.add_argument('--gamma', type=float, default=0.95)
|
||||||
parser.add_argument('--epoch', type=int, default=10)
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
parser.add_argument('--step-per-epoch', type=int, default=40000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=8)
|
parser.add_argument('--episode-per-collect', type=int, default=8)
|
||||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||||
parser.add_argument('--batch-size', type=int, default=64)
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
@ -82,8 +82,8 @@ def test_pg(args=get_args()):
|
|||||||
# trainer
|
# trainer
|
||||||
result = onpolicy_trainer(
|
result = onpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
|
||||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||||
writer=writer)
|
writer=writer)
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
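The on-policy examples change in the same spirit, except that the collection granularity becomes a keyword argument; a hedged sketch with placeholder objects:

```python
from tianshou.trainer import onpolicy_trainer

# repeat_per_collect stays positional (right after step_per_epoch); how much
# to collect per update is now passed as a keyword, here whole episodes.
result = onpolicy_trainer(
    policy, train_collector, test_collector, max_epoch=10,
    step_per_epoch=40000, repeat_per_collect=2, episode_per_test=100,
    batch_size=64, episode_per_collect=8,
)
```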
@ -22,8 +22,8 @@ def get_args():
|
|||||||
parser.add_argument('--lr', type=float, default=1e-3)
|
parser.add_argument('--lr', type=float, default=1e-3)
|
||||||
parser.add_argument('--gamma', type=float, default=0.99)
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
parser.add_argument('--epoch', type=int, default=10)
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=2000)
|
parser.add_argument('--step-per-epoch', type=int, default=50000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=20)
|
parser.add_argument('--episode-per-collect', type=int, default=20)
|
||||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||||
parser.add_argument('--batch-size', type=int, default=64)
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
@ -108,8 +108,8 @@ def test_ppo(args=get_args()):
|
|||||||
# trainer
|
# trainer
|
||||||
result = onpolicy_trainer(
|
result = onpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
|
||||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||||
writer=writer)
|
writer=writer)
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -26,8 +26,9 @@ def get_args():
|
|||||||
parser.add_argument('--n-step', type=int, default=3)
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
parser.add_argument('--epoch', type=int, default=10)
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
parser.add_argument('--step-per-epoch', type=int, default=10000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||||
parser.add_argument('--batch-size', type=int, default=64)
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
nargs='*', default=[128, 128, 128, 128])
|
nargs='*', default=[128, 128, 128, 128])
|
||||||
@ -110,9 +111,10 @@ def test_qrdqn(args=get_args()):
|
|||||||
# trainer
|
# trainer
|
||||||
result = offpolicy_trainer(
|
result = offpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||||
|
update_per_step=args.update_per_step)
|
||||||
|
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -27,8 +27,9 @@ def get_args():
|
|||||||
parser.add_argument('--alpha', type=float, default=0.05)
|
parser.add_argument('--alpha', type=float, default=0.05)
|
||||||
parser.add_argument('--auto_alpha', type=int, default=0)
|
parser.add_argument('--auto_alpha', type=int, default=0)
|
||||||
parser.add_argument('--epoch', type=int, default=5)
|
parser.add_argument('--epoch', type=int, default=5)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=5)
|
parser.add_argument('--step-per-collect', type=int, default=5)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.2)
|
||||||
parser.add_argument('--batch-size', type=int, default=64)
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
nargs='*', default=[128, 128])
|
nargs='*', default=[128, 128])
|
||||||
@ -108,9 +109,9 @@ def test_discrete_sac(args=get_args()):
|
|||||||
# trainer
|
# trainer
|
||||||
result = offpolicy_trainer(
|
result = offpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||||
test_in_train=False)
|
update_per_step=args.update_per_step, test_in_train=False)
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|||||||
@ -17,8 +17,8 @@ def get_args():
|
|||||||
parser.add_argument('--seed', type=int, default=1626)
|
parser.add_argument('--seed', type=int, default=1626)
|
||||||
parser.add_argument('--buffer-size', type=int, default=50000)
|
parser.add_argument('--buffer-size', type=int, default=50000)
|
||||||
parser.add_argument('--epoch', type=int, default=5)
|
parser.add_argument('--epoch', type=int, default=5)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=5)
|
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=1)
|
parser.add_argument('--episode-per-collect', type=int, default=1)
|
||||||
parser.add_argument('--training-num', type=int, default=1)
|
parser.add_argument('--training-num', type=int, default=1)
|
||||||
parser.add_argument('--test-num', type=int, default=100)
|
parser.add_argument('--test-num', type=int, default=100)
|
||||||
parser.add_argument('--logdir', type=str, default='log')
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
@ -78,8 +78,8 @@ def test_psrl(args=get_args()):
|
|||||||
# trainer
|
# trainer
|
||||||
result = onpolicy_trainer(
|
result = onpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, 1,
|
args.step_per_epoch, 1, args.test_num, 0,
|
||||||
args.test_num, 0, stop_fn=stop_fn, writer=writer,
|
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer,
|
||||||
test_in_train=False)
|
test_in_train=False)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -28,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument('--n-step', type=int, default=3)
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
parser.add_argument('--epoch', type=int, default=20)
|
parser.add_argument('--epoch', type=int, default=20)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=500)
|
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||||
parser.add_argument('--batch-size', type=int, default=64)
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
parser.add_argument('--hidden-sizes', type=int,
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
nargs='*', default=[128, 128, 128, 128])
|
nargs='*', default=[128, 128, 128, 128])
|
||||||
@ -162,10 +163,10 @@ def train_agent(
|
|||||||
# trainer
|
# trainer
|
||||||
result = offpolicy_trainer(
|
result = offpolicy_trainer(
|
||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric,
|
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
|
||||||
writer=writer, test_in_train=False)
|
writer=writer, test_in_train=False, reward_metric=reward_metric)
|
||||||
|
|
||||||
return result, policy.policies[args.agent_id - 1]
|
return result, policy.policies[args.agent_id - 1]
|
||||||
|
|
||||||
@ -183,4 +184,4 @@ def watch(
|
|||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
rews, lens = result["rews"], result["lens"]
|
rews, lens = result["rews"], result["lens"]
|
||||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
||||||
|
|||||||
@ -605,7 +605,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage
|
|||||||
class VectorReplayBuffer(ReplayBufferManager):
|
class VectorReplayBuffer(ReplayBufferManager):
|
||||||
"""VectorReplayBuffer contains n ReplayBuffer with the same size.
|
"""VectorReplayBuffer contains n ReplayBuffer with the same size.
|
||||||
|
|
||||||
It is used for storing data frame from different environments yet keeping the order
|
It is used for storing transitions from different environments yet keeping the order
|
||||||
of time.
|
of time.
|
||||||
|
|
||||||
:param int total_size: the total size of VectorReplayBuffer.
|
:param int total_size: the total size of VectorReplayBuffer.
|
||||||
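A small usage sketch of the buffer described above; it assumes the `(total_size, buffer_num)` constructor of the released API, and the sizes are arbitrary:

```python
from tianshou.data import VectorReplayBuffer

# 20000 total slots split across 8 per-environment sub-buffers, so each
# environment's transitions stay in time order inside its own slice.
buf = VectorReplayBuffer(total_size=20000, buffer_num=8)
# e.g. train_collector = Collector(policy, train_envs, buf)  # placeholders
```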
@ -631,7 +631,7 @@ class VectorReplayBuffer(ReplayBufferManager):
|
|||||||
class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
|
class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
|
||||||
"""PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size.
|
"""PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size.
|
||||||
|
|
||||||
It is used for storing data frame from different environments yet keeping the order
|
It is used for storing transitions from different environments yet keeping the order
|
||||||
of time.
|
of time.
|
||||||
|
|
||||||
:param int total_size: the total size of PrioritizedVectorReplayBuffer.
|
:param int total_size: the total size of PrioritizedVectorReplayBuffer.
|
||||||
|
|||||||
@ -198,7 +198,7 @@ class Collector(object):
|
|||||||
if not n_step % self.env_num == 0:
|
if not n_step % self.env_num == 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
|
f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
|
||||||
"which may cause extra frame collected into the buffer."
|
"which may cause extra transitions collected into the buffer."
|
||||||
)
|
)
|
||||||
ready_env_ids = np.arange(self.env_num)
|
ready_env_ids = np.arange(self.env_num)
|
||||||
elif n_episode is not None:
|
elif n_episode is not None:
|
||||||
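The arithmetic behind that warning, with made-up numbers: the parallel environments are stepped together, so collection can only stop on a multiple of the number of ready environments.

```python
# Requesting n_step=20 from 8 parallel environments: the round in progress is
# finished, so roughly ceil(20 / 8) * 8 = 24 transitions may reach the buffer.
env_num, n_step = 8, 20
rounds = -(-n_step // env_num)   # ceiling division -> 3 rounds
print(rounds * env_num)          # 24 collected vs. 20 requested
```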
@ -357,9 +357,9 @@ class AsyncCollector(Collector):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Collect a specified number of step or episode with async env setting.
|
"""Collect a specified number of step or episode with async env setting.
|
||||||
|
|
||||||
This function doesn't collect exactly n_step or n_episode number of frames.
|
This function doesn't collect exactly n_step or n_episode number of
|
||||||
Instead, in order to support async setting, it may collect more than given
|
transitions. Instead, in order to support the async setting, it may collect more
|
||||||
n_step or n_episode frames and save into buffer.
|
than the given n_step or n_episode transitions and save them into the buffer.
|
||||||
|
|
||||||
:param int n_step: how many steps you want to collect.
|
:param int n_step: how many steps you want to collect.
|
||||||
:param int n_episode: how many episodes you want to collect.
|
:param int n_episode: how many episodes you want to collect.
|
||||||
@ -395,7 +395,7 @@ class AsyncCollector(Collector):
|
|||||||
else:
|
else:
|
||||||
raise TypeError("Please specify at least one (either n_step or n_episode) "
|
raise TypeError("Please specify at least one (either n_step or n_episode) "
|
||||||
"in AsyncCollector.collect().")
|
"in AsyncCollector.collect().")
|
||||||
warnings.warn("Using async setting may collect extra frames into buffer.")
|
warnings.warn("Using async setting may collect extra transitions into buffer.")
|
||||||
|
|
||||||
ready_env_ids = self._ready_env_ids
|
ready_env_ids = self._ready_env_ids
|
||||||
|
|
||||||
|
|||||||
@ -220,21 +220,17 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
|
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
|
||||||
to calculate q function/reward to go of given batch.
|
to calculate q function/reward to go of given batch.
|
||||||
|
|
||||||
:param batch: a data batch which contains several episodes of data
|
:param Batch batch: a data batch which contains several episodes of data
|
||||||
in sequential order. Mind that the end of each finished episode of batch
|
in sequential order. Mind that the end of each finished episode of batch
|
||||||
should be marked by done flag, unfinished (or collecting) episodes will be
|
should be marked by done flag, unfinished (or collecting) episodes will be
|
||||||
recognized by buffer.unfinished_index().
|
recognized by buffer.unfinished_index().
|
||||||
:type batch: :class:`~tianshou.data.Batch`
|
:param np.ndarray indice: tell batch's location in buffer, batch is
|
||||||
:param numpy.ndarray indice: tell batch's location in buffer, batch is
|
|
||||||
equal to buffer[indice].
|
equal to buffer[indice].
|
||||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
|
||||||
:type v_s_: numpy.ndarray
|
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
|
||||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
|
||||||
to 0.99.
|
should be in [0, 1]. Default to 0.95.
|
||||||
:param float gae_lambda: the parameter for Generalized Advantage
|
:param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
|
||||||
Estimation, should be in [0, 1], defaults to 0.95.
|
|
||||||
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
|
|
||||||
to False.
|
|
||||||
|
|
||||||
:return: a Batch. The result will be stored in batch.returns as a numpy
|
:return: a Batch. The result will be stored in batch.returns as a numpy
|
||||||
array with shape (bsz, ).
|
array with shape (bsz, ).
|
||||||
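For reference, a plain-Python sketch of the GAE recursion from the cited paper; the argument names (`rew`, `done`, `v_s`, `v_s_`) mirror the docstring, not the library's vectorized implementation:

```python
import numpy as np

def gae_returns(rew, done, v_s, v_s_, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
    # A_t     = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
    adv = np.zeros(len(rew))
    gae = 0.0
    for t in reversed(range(len(rew))):
        delta = rew[t] + gamma * v_s_[t] * (1.0 - done[t]) - v_s[t]
        gae = delta + gamma * gae_lambda * (1.0 - done[t]) * gae
        adv[t] = gae
    return adv + np.asarray(v_s)  # the reward-to-go stored in batch.returns
```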
@ -273,18 +269,14 @@ class BasePolicy(ABC, nn.Module):
|
|||||||
where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
|
where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
|
||||||
:math:`d_t` is the done flag of step :math:`t`.
|
:math:`d_t` is the done flag of step :math:`t`.
|
||||||
|
|
||||||
:param batch: a data batch, which is equal to buffer[indice].
|
:param Batch batch: a data batch, which is equal to buffer[indice].
|
||||||
:type batch: :class:`~tianshou.data.Batch`
|
:param ReplayBuffer buffer: the data buffer.
|
||||||
:param buffer: the data buffer.
|
|
||||||
:type buffer: :class:`~tianshou.data.ReplayBuffer`
|
|
||||||
:param function target_q_fn: a function which computes the target Q value
|
:param function target_q_fn: a function which computes the target Q value
|
||||||
of "obs_next" given data buffer and wanted indices.
|
of "obs_next" given data buffer and wanted indices.
|
||||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
|
||||||
to 0.99.
|
:param int n_step: the number of estimation steps, should be an int greater
|
||||||
:param int n_step: the number of estimation steps, should be an int
|
than 0. Default to 1.
|
||||||
greater than 0, defaults to 1.
|
:param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
|
||||||
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
|
|
||||||
to False.
|
|
||||||
|
|
||||||
:return: a Batch. The result will be stored in batch.returns as a
|
:return: a Batch. The result will be stored in batch.returns as a
|
||||||
torch.Tensor with the same shape as target_q_fn's return tensor.
|
torch.Tensor with the same shape as target_q_fn's return tensor.
|
||||||
|
|||||||
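Similarly, a single-trajectory rendering of the n-step target described above; `q_next_n[t]` stands in for the target Q-value of the state n steps after t (fetched via `target_q_fn` from the buffer in the real code), and truncation at the slice end is glossed over:

```python
import numpy as np

def nstep_returns(rew, done, q_next_n, gamma=0.99, n_step=1):
    T = len(rew)
    returns = np.zeros(T)
    for t in range(T):
        ret, discount = 0.0, 1.0
        for k in range(t, min(t + n_step, T)):
            ret += discount * rew[k]   # gamma^(k-t) * r_k
            discount *= gamma
            if done[k]:                # episode ended: drop the bootstrap term
                discount = 0.0
                break
        returns[t] = ret + discount * q_next_n[t]
    return returns
```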
@ -45,7 +45,7 @@ class PGPolicy(BasePolicy):
|
|||||||
def process_fn(
|
def process_fn(
|
||||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||||
) -> Batch:
|
) -> Batch:
|
||||||
r"""Compute the discounted returns for each frame.
|
r"""Compute the discounted returns for each transition.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
G_t = \sum_{i=t}^T \gamma^{i-t}r_i
|
G_t = \sum_{i=t}^T \gamma^{i-t}r_i
|
||||||
|
|||||||
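The formula reduces to one reverse pass over the trajectory; a short sketch of that recursion (the library version is vectorized over buffer indices):

```python
import numpy as np

def discounted_returns(rew, done, gamma=0.99):
    # G_t = r_t + gamma * G_{t+1}, with the future term dropped at episode ends
    returns = np.zeros(len(rew))
    running = 0.0
    for t in reversed(range(len(rew))):
        running = rew[t] + gamma * running * (1.0 - done[t])
        returns[t] = running
    return returns

# e.g. discounted_returns([1, 1, 1], [0, 0, 1], gamma=0.9) -> [2.71, 1.9, 1.0]
```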
@ -38,5 +38,5 @@ class RandomPolicy(BasePolicy):
|
|||||||
return Batch(act=logits.argmax(axis=-1))
|
return Batch(act=logits.argmax(axis=-1))
|
||||||
|
|
||||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
"""Since a random agent learn nothing, it returns an empty dict."""
|
"""Since a random agent learns nothing, it returns an empty dict."""
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@ -16,7 +16,7 @@ def offline_trainer(
|
|||||||
buffer: ReplayBuffer,
|
buffer: ReplayBuffer,
|
||||||
test_collector: Collector,
|
test_collector: Collector,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
step_per_epoch: int,
|
update_per_epoch: int,
|
||||||
episode_per_test: int,
|
episode_per_test: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||||
@ -29,50 +29,52 @@ def offline_trainer(
|
|||||||
) -> Dict[str, Union[float, str]]:
|
) -> Dict[str, Union[float, str]]:
|
||||||
"""A wrapper for offline trainer procedure.
|
"""A wrapper for offline trainer procedure.
|
||||||
|
|
||||||
The "step" in trainer means a policy network update.
|
The "step" in offline trainer means a gradient step.
|
||||||
|
|
||||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||||
:param test_collector: the collector used for testing.
|
:param Collector test_collector: the collector used for testing.
|
||||||
:type test_collector: :class:`~tianshou.data.Collector`
|
:param int max_epoch: the maximum number of epochs for training. The training
|
||||||
:param int max_epoch: the maximum number of epochs for training. The
|
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
||||||
training process might be finished before reaching the ``max_epoch``.
|
:param int update_per_epoch: the number of policy network updates, so-called
|
||||||
:param int step_per_epoch: the number of policy network updates, so-called
|
|
||||||
gradient steps, per epoch.
|
gradient steps, per epoch.
|
||||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||||
:param int batch_size: the batch size of sample data, which is going to
|
:param int batch_size: the batch size of sample data, which is going to feed in
|
||||||
feed in the policy network.
|
the policy network.
|
||||||
:param function test_fn: a hook called at the beginning of testing in each
|
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
||||||
epoch. It can be used to perform custom additional operations, with the
|
It can be used to perform custom additional operations, with the signature ``f(
|
||||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
num_epoch: int, step_idx: int) -> None``.
|
||||||
:param function save_fn: a hook called when the undiscounted average mean
|
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||||
reward in evaluation phase gets better, with the signature ``f(policy:
|
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||||
BasePolicy) -> None``.
|
None``.
|
||||||
:param function stop_fn: a function with signature ``f(mean_rewards: float)
|
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||||
-> bool``, receives the average undiscounted returns of the testing
|
bool``, receives the average undiscounted returns of the testing result,
|
||||||
result, returns a boolean which indicates whether reaching the goal.
|
returns a boolean which indicates whether reaching the goal.
|
||||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
||||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
||||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
used in multi-agent RL. We need to return a single scalar for each episode's
|
||||||
result to monitor training in the multi-agent RL setting. This function
|
result to monitor training in the multi-agent RL setting. This function
|
||||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
||||||
average reward over all agents.
|
average reward over all agents.
|
||||||
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
|
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
|
||||||
SummaryWriter; if None is given, it will not write logs to TensorBoard.
|
if None is given, it will not write logs to TensorBoard. Default to None.
|
||||||
:param int log_interval: the log interval of the writer.
|
:param int log_interval: the log interval of the writer. Default to 1.
|
||||||
:param bool verbose: whether to print the information.
|
:param bool verbose: whether to print the information. Default to True.
|
||||||
|
|
||||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||||
"""
|
"""
|
||||||
gradient_step = 0
|
gradient_step = 0
|
||||||
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
|
|
||||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
test_collector.reset_stat()
|
test_collector.reset_stat()
|
||||||
|
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||||
|
writer, gradient_step, reward_metric)
|
||||||
|
best_epoch = 0
|
||||||
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result["rews"].std()
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
policy.train()
|
policy.train()
|
||||||
with tqdm.trange(
|
with tqdm.trange(
|
||||||
step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||||
) as t:
|
) as t:
|
||||||
for i in t:
|
for i in t:
|
||||||
gradient_step += 1
|
gradient_step += 1
|
||||||
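Tying the renamed argument back to the call sites earlier in this diff, a hedged sketch of invoking the offline trainer; `policy`, `buffer` and `test_collector` are assumed to exist already:

```python
from tianshou.trainer import offline_trainer

# Epoch length is counted in gradient steps over the fixed buffer, hence
# `update_per_epoch` rather than an environment-step budget.
result = offline_trainer(
    policy, buffer, test_collector, max_epoch=5,
    update_per_epoch=1000, episode_per_test=100, batch_size=64,
)
```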
@ -87,16 +89,18 @@ def offline_trainer(
|
|||||||
global_step=gradient_step)
|
global_step=gradient_step)
|
||||||
t.set_postfix(**data)
|
t.set_postfix(**data)
|
||||||
# test
|
# test
|
||||||
result = test_episode(policy, test_collector, test_fn, epoch,
|
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||||
episode_per_test, writer, gradient_step, reward_metric)
|
episode_per_test, writer, gradient_step,
|
||||||
if best_epoch == -1 or best_reward < result["rews"].mean():
|
reward_metric)
|
||||||
best_reward, best_reward_std = result["rews"].mean(), result['rews'].std()
|
if best_epoch == -1 or best_reward < test_result["rews"].mean():
|
||||||
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result['rews'].std()
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
if save_fn:
|
if save_fn:
|
||||||
save_fn(policy)
|
save_fn(policy)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± "
|
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
|
||||||
f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
||||||
f"{best_reward_std:.6f} in #{best_epoch}")
|
f"{best_reward_std:.6f} in #{best_epoch}")
|
||||||
if stop_fn and stop_fn(best_reward):
|
if stop_fn and stop_fn(best_reward):
|
||||||
break
|
break
|
||||||
|
|||||||
@ -17,10 +17,10 @@ def offpolicy_trainer(
|
|||||||
test_collector: Collector,
|
test_collector: Collector,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
step_per_epoch: int,
|
step_per_epoch: int,
|
||||||
collect_per_step: int,
|
step_per_collect: int,
|
||||||
episode_per_test: int,
|
episode_per_test: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
update_per_step: int = 1,
|
update_per_step: Union[int, float] = 1,
|
||||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||||
@ -33,60 +33,62 @@ def offpolicy_trainer(
|
|||||||
) -> Dict[str, Union[float, str]]:
|
) -> Dict[str, Union[float, str]]:
|
||||||
"""A wrapper for off-policy trainer procedure.
|
"""A wrapper for off-policy trainer procedure.
|
||||||
|
|
||||||
The "step" in trainer means a policy network update.
|
The "step" in trainer means an environment step (a.k.a. transition).
|
||||||
|
|
||||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||||
:param train_collector: the collector used for training.
|
:param Collector train_collector: the collector used for training.
|
||||||
:type train_collector: :class:`~tianshou.data.Collector`
|
:param Collector test_collector: the collector used for testing.
|
||||||
:param test_collector: the collector used for testing.
|
:param int max_epoch: the maximum number of epochs for training. The training
|
||||||
:type test_collector: :class:`~tianshou.data.Collector`
|
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
||||||
:param int max_epoch: the maximum number of epochs for training. The
|
:param int step_per_epoch: the number of transitions collected per epoch.
|
||||||
training process might be finished before reaching the ``max_epoch``.
|
:param int step_per_collect: the number of transitions the collector would collect
|
||||||
:param int step_per_epoch: the number of policy network updates, so-called
|
before the network update, i.e., trainer will collect "step_per_collect"
|
||||||
gradient steps, per epoch.
|
transitions and do some policy network updates repeatedly in each epoch.
|
||||||
:param int collect_per_step: the number of frames the collector would
|
|
||||||
collect before the network update. In other words, collect some frames
|
|
||||||
and do some policy network update.
|
|
||||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||||
:param int batch_size: the batch size of sample data, which is going to
|
:param int batch_size: the batch size of sample data, which is going to feed in the
|
||||||
feed in the policy network.
|
policy network.
|
||||||
:param int update_per_step: the number of times the policy network would
|
:param int/float update_per_step: the number of times the policy network would be
|
||||||
be updated after frames are collected, for example, set it to 256 means
|
updated per transition after (step_per_collect) transitions are collected,
|
||||||
it updates policy 256 times once after ``collect_per_step`` frames are
|
e.g., if update_per_step is set to 0.3 and step_per_collect is 256, the policy will
|
||||||
collected.
|
be updated round(256 * 0.3) = round(76.8) = 77 times after 256 transitions are
|
||||||
:param function train_fn: a hook called at the beginning of training in
|
collected by the collector. Default to 1.
|
||||||
each epoch. It can be used to perform custom additional operations,
|
:param function train_fn: a hook called at the beginning of training in each epoch.
|
||||||
with the signature ``f(num_epoch: int, step_idx: int) -> None``.
|
It can be used to perform custom additional operations, with the signature ``f(
|
||||||
:param function test_fn: a hook called at the beginning of testing in each
|
num_epoch: int, step_idx: int) -> None``.
|
||||||
epoch. It can be used to perform custom additional operations, with the
|
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
||||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
It can be used to perform custom additional operations, with the signature ``f(
|
||||||
:param function save_fn: a hook called when the undiscounted average mean
|
num_epoch: int, step_idx: int) -> None``.
|
||||||
reward in evaluation phase gets better, with the signature ``f(policy:
|
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||||
BasePolicy) -> None``.
|
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||||
:param function stop_fn: a function with signature ``f(mean_rewards: float)
|
None``.
|
||||||
-> bool``, receives the average undiscounted returns of the testing
|
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||||
result, returns a boolean which indicates whether reaching the goal.
|
bool``, receives the average undiscounted returns of the testing result,
|
||||||
|
returns a boolean which indicates whether reaching the goal.
|
||||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
||||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
||||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
used in multi-agent RL. We need to return a single scalar for each episode's
|
||||||
result to monitor training in the multi-agent RL setting. This function
|
result to monitor training in the multi-agent RL setting. This function
|
||||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
||||||
average reward over all agents.
|
average reward over all agents.
|
||||||
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
|
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
|
||||||
SummaryWriter; if None is given, it will not write logs to TensorBoard.
|
if None is given, it will not write logs to TensorBoard. Default to None.
|
||||||
:param int log_interval: the log interval of the writer.
|
:param int log_interval: the log interval of the writer. Default to 1.
|
||||||
:param bool verbose: whether to print the information.
|
:param bool verbose: whether to print the information. Default to True.
|
||||||
:param bool test_in_train: whether to test in the training phase.
|
:param bool test_in_train: whether to test in the training phase. Default to True.
|
||||||
|
|
||||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||||
"""
|
"""
|
||||||
env_step, gradient_step = 0, 0
|
env_step, gradient_step = 0, 0
|
||||||
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
|
|
||||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
train_collector.reset_stat()
|
train_collector.reset_stat()
|
||||||
test_collector.reset_stat()
|
test_collector.reset_stat()
|
||||||
test_in_train = test_in_train and train_collector.policy == policy
|
test_in_train = test_in_train and train_collector.policy == policy
|
||||||
|
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||||
|
writer, env_step, reward_metric)
|
||||||
|
best_epoch = 0
|
||||||
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result["rews"].std()
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
# train
|
# train
|
||||||
policy.train()
|
policy.train()
|
||||||
@ -96,10 +98,11 @@ def offpolicy_trainer(
|
|||||||
while t.n < t.total:
|
while t.n < t.total:
|
||||||
if train_fn:
|
if train_fn:
|
||||||
train_fn(epoch, env_step)
|
train_fn(epoch, env_step)
|
||||||
result = train_collector.collect(n_step=collect_per_step)
|
result = train_collector.collect(n_step=step_per_collect)
|
||||||
if len(result["rews"]) > 0 and reward_metric:
|
if len(result["rews"]) > 0 and reward_metric:
|
||||||
result["rews"] = reward_metric(result["rews"])
|
result["rews"] = reward_metric(result["rews"])
|
||||||
env_step += int(result["n/st"])
|
env_step += int(result["n/st"])
|
||||||
|
t.update(result["n/st"])
|
||||||
data = {
|
data = {
|
||||||
"env_step": str(env_step),
|
"env_step": str(env_step),
|
||||||
"rew": f"{result['rews'].mean():.2f}",
|
"rew": f"{result['rews'].mean():.2f}",
|
||||||
@ -126,8 +129,7 @@ def offpolicy_trainer(
|
|||||||
test_result["rews"].mean(), test_result["rews"].std())
|
test_result["rews"].mean(), test_result["rews"].std())
|
||||||
else:
|
else:
|
||||||
policy.train()
|
policy.train()
|
||||||
for i in range(update_per_step * min(
|
for i in range(round(update_per_step * result["n/st"])):
|
||||||
result["n/st"] // collect_per_step, t.total - t.n)):
|
|
||||||
gradient_step += 1
|
gradient_step += 1
|
||||||
losses = policy.update(batch_size, train_collector.buffer)
|
losses = policy.update(batch_size, train_collector.buffer)
|
||||||
for k in losses.keys():
|
for k in losses.keys():
|
||||||
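The loop bound above is where a fractional `update_per_step` takes effect; a quick numeric check with illustrative values:

```python
# number of gradient steps performed right after one collect
update_per_step, n_collected = 0.1, 10
print(round(update_per_step * n_collected))  # 1 update per 10 transitions
print(round(0.3 * 256))                      # 77, matching the docstring example
```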
@ -136,21 +138,21 @@ def offpolicy_trainer(
|
|||||||
if writer and gradient_step % log_interval == 0:
|
if writer and gradient_step % log_interval == 0:
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
k, stat[k].get(), global_step=gradient_step)
|
k, stat[k].get(), global_step=gradient_step)
|
||||||
t.update(1)
|
|
||||||
t.set_postfix(**data)
|
t.set_postfix(**data)
|
||||||
if t.n <= t.total:
|
if t.n <= t.total:
|
||||||
t.update()
|
t.update()
|
||||||
# test
|
# test
|
||||||
result = test_episode(policy, test_collector, test_fn, epoch,
|
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||||
episode_per_test, writer, env_step, reward_metric)
|
episode_per_test, writer, env_step, reward_metric)
|
||||||
if best_epoch == -1 or best_reward < result["rews"].mean():
|
if best_epoch == -1 or best_reward < test_result["rews"].mean():
|
||||||
best_reward, best_reward_std = result["rews"].mean(), result["rews"].std()
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result['rews'].std()
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
if save_fn:
|
if save_fn:
|
||||||
save_fn(policy)
|
save_fn(policy)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± "
|
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
|
||||||
f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
||||||
f"{best_reward_std:.6f} in #{best_epoch}")
|
f"{best_reward_std:.6f} in #{best_epoch}")
|
||||||
if stop_fn and stop_fn(best_reward):
|
if stop_fn and stop_fn(best_reward):
|
||||||
break
|
break
|
||||||
|
|||||||
@ -17,10 +17,11 @@ def onpolicy_trainer(
|
|||||||
test_collector: Collector,
|
test_collector: Collector,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
step_per_epoch: int,
|
step_per_epoch: int,
|
||||||
collect_per_step: int,
|
|
||||||
repeat_per_collect: int,
|
repeat_per_collect: int,
|
||||||
episode_per_test: int,
|
episode_per_test: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
step_per_collect: Optional[int] = None,
|
||||||
|
episode_per_collect: Optional[int] = None,
|
||||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||||
@ -33,60 +34,67 @@ def onpolicy_trainer(
|
|||||||
) -> Dict[str, Union[float, str]]:
|
) -> Dict[str, Union[float, str]]:
|
||||||
"""A wrapper for on-policy trainer procedure.
|
"""A wrapper for on-policy trainer procedure.
|
||||||
|
|
||||||
The "step" in trainer means a policy network update.
|
The "step" in trainer means an environment step (a.k.a. transition).
|
||||||
|
|
||||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||||
:param train_collector: the collector used for training.
|
:param Collector train_collector: the collector used for training.
|
||||||
:type train_collector: :class:`~tianshou.data.Collector`
|
:param Collector test_collector: the collector used for testing.
|
||||||
:param test_collector: the collector used for testing.
|
:param int max_epoch: the maximum number of epochs for training. The training
|
||||||
:type test_collector: :class:`~tianshou.data.Collector`
|
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
||||||
:param int max_epoch: the maximum number of epochs for training. The
|
:param int step_per_epoch: the number of transitions collected per epoch.
|
||||||
training process might be finished before reaching the ``max_epoch``.
|
:param int repeat_per_collect: the number of repeat times for policy learning, for
|
||||||
:param int step_per_epoch: the number of policy network updates, so-called
|
example, setting it to 2 means the policy needs to learn each given batch of data
|
||||||
gradient steps, per epoch.
|
twice.
|
||||||
:param int collect_per_step: the number of episodes the collector would
|
:param int episode_per_test: the number of episodes for one policy evaluation.
|
||||||
collect before the network update. In other words, collect some
|
:param int batch_size: the batch size of sample data, which is going to feed in the
|
||||||
episodes and do one policy network update.
|
policy network.
|
||||||
:param int repeat_per_collect: the number of repeat time for policy
|
:param int step_per_collect: the number of transitions the collector would collect
|
||||||
learning, for example, set it to 2 means the policy needs to learn each
|
before the network update, i.e., trainer will collect "step_per_collect"
|
||||||
given batch data twice.
|
transitions and do some policy network updates repeatedly in each epoch.
|
||||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
:param int episode_per_collect: the number of episodes the collector would collect
|
||||||
:type episode_per_test: int or list of ints
|
before the network update, i.e., trainer will collect "episode_per_collect"
|
||||||
:param int batch_size: the batch size of sample data, which is going to
|
episodes and do some policy network updates repeatedly in each epoch.
|
||||||
feed in the policy network.
|
:param function train_fn: a hook called at the beginning of training in each epoch.
|
||||||
:param function train_fn: a hook called at the beginning of training in
|
It can be used to perform custom additional operations, with the signature ``f(
|
||||||
each epoch. It can be used to perform custom additional operations,
|
num_epoch: int, step_idx: int) -> None``.
|
||||||
with the signature ``f(num_epoch: int, step_idx: int) -> None``.
|
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
||||||
:param function test_fn: a hook called at the beginning of testing in each
|
It can be used to perform custom additional operations, with the signature ``f(
|
||||||
epoch. It can be used to perform custom additional operations, with the
|
num_epoch: int, step_idx: int) -> None``.
|
||||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||||
:param function save_fn: a hook called when the undiscounted average mean
|
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||||
reward in evaluation phase gets better, with the signature ``f(policy:
|
None``.
|
||||||
BasePolicy) -> None``.
|
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||||
:param function stop_fn: a function with signature ``f(mean_rewards: float)
|
bool``, receives the average undiscounted returns of the testing result,
|
||||||
-> bool``, receives the average undiscounted returns of the testing
|
returns a boolean which indicates whether reaching the goal.
|
||||||
result, returns a boolean which indicates whether reaching the goal.
|
|
||||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
||||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
||||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
used in multi-agent RL. We need to return a single scalar for each episode's
|
||||||
result to monitor training in the multi-agent RL setting. This function
|
result to monitor training in the multi-agent RL setting. This function
|
||||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
||||||
average reward over all agents.
|
average reward over all agents.
|
||||||
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
|
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
|
||||||
SummaryWriter; if None is given, it will not write logs to TensorBoard.
|
if None is given, it will not write logs to TensorBoard. Default to None.
|
||||||
:param int log_interval: the log interval of the writer.
|
:param int log_interval: the log interval of the writer. Default to 1.
|
||||||
:param bool verbose: whether to print the information.
|
:param bool verbose: whether to print the information. Default to True.
|
||||||
:param bool test_in_train: whether to test in the training phase.
|
:param bool test_in_train: whether to test in the training phase. Default to True.
|
||||||
|
|
||||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Only one of step_per_collect and episode_per_collect can be specified.
|
||||||
"""
|
"""
|
||||||
env_step, gradient_step = 0, 0
|
env_step, gradient_step = 0, 0
|
||||||
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
|
|
||||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
train_collector.reset_stat()
|
train_collector.reset_stat()
|
||||||
test_collector.reset_stat()
|
test_collector.reset_stat()
|
||||||
test_in_train = test_in_train and train_collector.policy == policy
|
test_in_train = test_in_train and train_collector.policy == policy
|
||||||
|
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||||
|
writer, env_step, reward_metric)
|
||||||
|
best_epoch = 0
|
||||||
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result["rews"].std()
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
# train
|
# train
|
||||||
policy.train()
|
policy.train()
|
||||||
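A brief sketch of the step-based alternative the note above allows (placeholder objects, arbitrary numbers); only one of the two collection keywords should be given at a time:

```python
# step-based collection: gather a fixed number of transitions per update cycle
result = onpolicy_trainer(
    policy, train_collector, test_collector, max_epoch=10, step_per_epoch=50000,
    repeat_per_collect=2, episode_per_test=100, batch_size=64,
    step_per_collect=2048,
)
```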
@ -96,10 +104,12 @@ def onpolicy_trainer(
|
|||||||
while t.n < t.total:
|
while t.n < t.total:
|
||||||
if train_fn:
|
if train_fn:
|
||||||
train_fn(epoch, env_step)
|
train_fn(epoch, env_step)
|
||||||
result = train_collector.collect(n_episode=collect_per_step)
|
result = train_collector.collect(n_step=step_per_collect,
|
||||||
|
n_episode=episode_per_collect)
|
||||||
if reward_metric:
|
if reward_metric:
|
||||||
result["rews"] = reward_metric(result["rews"])
|
result["rews"] = reward_metric(result["rews"])
|
||||||
env_step += int(result["n/st"])
|
env_step += int(result["n/st"])
|
||||||
|
t.update(result["n/st"])
|
||||||
data = {
|
data = {
|
||||||
"env_step": str(env_step),
|
"env_step": str(env_step),
|
||||||
"rew": f"{result['rews'].mean():.2f}",
|
"rew": f"{result['rews'].mean():.2f}",
|
||||||
@ -138,21 +148,21 @@ def onpolicy_trainer(
|
|||||||
if writer and gradient_step % log_interval == 0:
|
if writer and gradient_step % log_interval == 0:
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
k, stat[k].get(), global_step=gradient_step)
|
k, stat[k].get(), global_step=gradient_step)
|
||||||
t.update(step)
|
|
||||||
t.set_postfix(**data)
|
t.set_postfix(**data)
|
||||||
if t.n <= t.total:
|
if t.n <= t.total:
|
||||||
t.update()
|
t.update()
|
||||||
# test
|
# test
|
||||||
result = test_episode(policy, test_collector, test_fn, epoch,
|
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||||
episode_per_test, writer, env_step)
|
episode_per_test, writer, env_step)
|
||||||
if best_epoch == -1 or best_reward < result["rews"].mean():
|
if best_epoch == -1 or best_reward < test_result["rews"].mean():
|
||||||
best_reward, best_reward_std = result["rews"].mean(), result["rews"].std()
|
best_reward = test_result["rews"].mean()
|
||||||
|
best_reward_std = test_result['rews'].std()
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
if save_fn:
|
if save_fn:
|
||||||
save_fn(policy)
|
save_fn(policy)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± "
|
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
|
||||||
f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
|
||||||
f"{best_reward_std:.6f} in #{best_epoch}")
|
f"{best_reward_std:.6f} in #{best_epoch}")
|
||||||
if stop_fn and stop_fn(best_reward):
|
if stop_fn and stop_fn(best_reward):
|
||||||
break
|
break
|
||||||
|
|||||||
@ -48,14 +48,14 @@ def gather_info(
|
|||||||
|
|
||||||
* ``train_step`` the total collected step of training collector;
|
* ``train_step`` the total collected step of training collector;
|
||||||
* ``train_episode`` the total collected episode of training collector;
|
* ``train_episode`` the total collected episode of training collector;
|
||||||
* ``train_time/collector`` the time for collecting frames in the \
|
* ``train_time/collector`` the time for collecting transitions in the \
|
||||||
training collector;
|
training collector;
|
||||||
* ``train_time/model`` the time for training models;
|
* ``train_time/model`` the time for training models;
|
||||||
* ``train_speed`` the speed of training (frames per second);
|
* ``train_speed`` the speed of training (env_step per second);
|
||||||
* ``test_step`` the total collected step of test collector;
|
* ``test_step`` the total collected step of test collector;
|
||||||
* ``test_episode`` the total collected episode of test collector;
|
* ``test_episode`` the total collected episode of test collector;
|
||||||
* ``test_time`` the time for testing;
|
* ``test_time`` the time for testing;
|
||||||
* ``test_speed`` the speed of testing (frames per second);
|
* ``test_speed`` the speed of testing (env_step per second);
|
||||||
* ``best_reward`` the best reward over the test results;
|
* ``best_reward`` the best reward over the test results;
|
||||||
* ``duration`` the total elapsed time.
|
* ``duration`` the total elapsed time.
|
||||||
"""
|
"""
|
||||||
|
|||||||