Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be re-tuned after the `done` issue is fixed.

Co-authored-by: n+e <trinkle23897@gmail.com>
parent a633a6a028
commit a511cb4779

2  .gitignore (vendored)
@@ -145,3 +145,5 @@ MUJOCO_LOG.TXT
*.zip
*.pstats
*.swp
*.pkl
*.hdf5
@@ -31,6 +31,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
@@ -20,8 +20,9 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
@@ -201,7 +201,7 @@ Trainer

Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.

Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.


.. _pseudocode:
@@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n
Train Policy with a Trainer
---------------------------

Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows:
::

    result = ts.trainer.offpolicy_trainer(
@@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
        writer=None)
    print(f'Finished training! Use {result["duration"]}')

The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):

* ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``;
* ``step_per_epoch``: The number of step for updating policy network in one epoch;
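For comparison with the ``offpolicy_trainer`` call shown in this hunk, here is a minimal sketch of how the new :func:`~tianshou.trainer.offline_trainer` would be invoked. This is not part of the diff; ``policy``, ``buffer``, and ``test_collector`` are placeholders that must already be constructed (for instance as in ``test/discrete/test_il_bcq.py`` further down), and the hyperparameter values are illustrative:
::

    # offline training: no train_collector, gradients come from a fixed buffer
    result = ts.trainer.offline_trainer(
        policy, buffer, test_collector,
        max_epoch=5, step_per_epoch=1000, episode_per_test=100,
        batch_size=64)
    print(f'Finished training! Use {result["duration"]}')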
@@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 6226 |  | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 988.5 |  | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |

Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.

# BCQ

TODO: after the `done` issue is fixed, the result should be re-tuned and placed here.

To run the BCQ algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error; see the buffer I/O sketch below);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
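A minimal sketch of the buffer round trip referenced in the second step above (illustrative only; the toy observation shape and the `expert.hdf5` path are placeholders):

```python
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(1000)
buf.add(obs=np.zeros(4), act=0, rew=0.0, done=False, obs_next=np.zeros(4))

# hdf5 works for large (1M-step) Atari buffers where pickle fails
buf.save_hdf5("expert.hdf5")
restored = ReplayBuffer.load_hdf5("expert.hdf5")
print(len(restored))  # 1
```

`atari_bcq.py` below accepts either format and dispatches on the file extension.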
153  examples/atari/atari_bcq.py (new file)

@@ -0,0 +1,153 @@
import os
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=6.25e-5)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=8000)
    parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
    parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--step-per-epoch", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[512])
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument('--frames_stack', type=int, default=4)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--watch", default=False, action="store_true",
                        help="watch the play of pre-trained policy only")
    parser.add_argument("--log-interval", type=int, default=1000)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_DQN_PongNoFrameskip-v4.hdf5",
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_known_args()[0]
    return args


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_discrete_bcq(args=get_args()):
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    feature_net = DQN(*args.state_shape, args.action_shape,
                      device=args.device, features_only=True).to(args.device)
    policy_net = Actor(feature_net, args.action_shape,
                       hidden_sizes=args.hidden_sizes).to(args.device)
    imitation_net = Actor(feature_net, args.action_shape,
                          hidden_sizes=args.hidden_sizes).to(args.device)
    optim = torch.optim.Adam(
        set(policy_net.parameters()).union(imitation_net.parameters()),
        lr=args.lr,
    )
    # define policy
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
        args.target_update_freq, args.eps_test,
        args.unlikely_action_threshold, args.imitation_logits_penalty,
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_dqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith('.pkl'):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    elif args.load_buffer_name.endswith('.hdf5'):
        buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs)

    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
        log_interval=args.log_interval,
    )

    pprint.pprint(result)
    watch()


if __name__ == "__main__":
    test_discrete_bcq(get_args())
@@ -41,6 +41,7 @@ def get_args():
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()


@@ -120,13 +121,25 @@ def test_dqn(args=get_args()):

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = ReplayBuffer(
                args.buffer_size, ignore_obs_next=True,
                save_only_last_obs=True, stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=[1] * args.test_num,
                                            render=args.render)
        pprint.pprint(result)

    if args.watch:
@@ -44,8 +44,7 @@ class DQN(nn.Module):
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        x = torch.as_tensor(
            x, device=self.device, dtype=torch.float32)  # type: ignore
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        return self.net(x), state
2  setup.py

@@ -51,7 +51,7 @@ setup(
        "tensorboard",
        "torch>=1.4.0",
        "numba>=0.51.0",
        "h5py>=3.1.0"
        "h5py>=2.10.0",  # to match tensorflow's minimal requirements
    ],
    extras_require={
        "dev": [
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -38,6 +39,9 @@ def get_args():
                        action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str,
        default="./expert_DQN_CartPole-v0.pkl")
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -114,6 +118,7 @@ def test_dqn(args=get_args()):
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
@@ -124,6 +129,12 @@ def test_dqn(args=get_args()):
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')

        # save buffer in pickle format, for imitation learning unittest
        buf = ReplayBuffer(args.buffer_size)
        collector = Collector(policy, test_envs, buf)
        collector.collect(n_step=args.buffer_size)
        pickle.dump(buf, open(args.save_buffer_name, "wb"))


def test_pdqn(args=get_args()):
    args.prioritized_replay = True
111  test/discrete/test_il_bcq.py (new file)

@@ -0,0 +1,111 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteBCQPolicy


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v0")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=320)
    parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
    parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128])
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_DQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_known_args()[0]
    return args


def test_discrete_bcq(args=get_args()):
    # envs
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    policy_net = Net(
        args.state_shape, args.action_shape,
        hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
    imitation_net = Net(
        args.state_shape, args.action_shape,
        hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
    optim = torch.optim.Adam(
        set(policy_net.parameters()).union(imitation_net.parameters()),
        lr=args.lr)

    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
        args.target_update_freq, args.eps_test,
        args.unlikely_action_threshold, args.imitation_logits_penalty,
    )
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_dqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))

    # collector
    test_collector = Collector(policy, test_envs)

    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')


if __name__ == "__main__":
    test_discrete_bcq(get_args())
@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration


__version__ = "0.3.0"
__version__ = "0.3.1"

__all__ = [
    "env",
@@ -10,6 +10,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.modelbase.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager

@@ -27,6 +28,7 @@ __all__ = [
    "TD3Policy",
    "SACPolicy",
    "DiscreteSACPolicy",
    "DiscreteBCQPolicy",
    "PSRLPolicy",
    "MultiAgentPolicyManager",
]
@@ -57,7 +57,7 @@ class ImitationPolicy(BasePolicy):
            a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
            loss = F.mse_loss(a, a_)  # type: ignore
        elif self.mode == "discrete":  # classification
            a = self(batch).logits
            a = F.log_softmax(self(batch).logits, dim=-1)
            a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
            loss = F.nll_loss(a, a_)  # type: ignore
        loss.backward()
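The one-line change above is presumably motivated by the fact that ``F.nll_loss`` expects log-probabilities rather than raw logits. A quick illustrative check, not part of the diff:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)           # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])

# nll_loss over log-probabilities is equivalent to cross_entropy over logits
a = F.nll_loss(F.log_softmax(logits, dim=-1), target)
b = F.cross_entropy(logits, target)
assert torch.allclose(a, b)
```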
139  tianshou/policy/imitation/discrete_bcq.py (new file)

@@ -0,0 +1,139 @@
import math
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict, Union, Optional

from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch


class DiscreteBCQPolicy(DQNPolicy):
    """Implementation of discrete BCQ algorithm. arXiv:1910.01708.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> q_value)
    :param torch.nn.Module imitator: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: greater than 1, the number of steps to look
        ahead.
    :param int target_update_freq: the target network update frequency.
    :param float eval_eps: the epsilon-greedy noise added in evaluation.
    :param float unlikely_action_threshold: the threshold (tau) for unlikely
        actions, as shown in Eq. (17) in the paper, defaults to 0.3.
    :param float imitation_logits_penalty: regularization weight for imitation
        logits, defaults to 1e-2.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        imitator: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 8000,
        eval_eps: float = 1e-3,
        unlikely_action_threshold: float = 0.3,
        imitation_logits_penalty: float = 1e-2,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, optim, discount_factor, estimation_step,
                         target_update_freq, reward_normalization, **kwargs)
        assert target_update_freq > 0, "BCQ needs target network setting."
        self.imitator = imitator
        assert (
            0.0 <= unlikely_action_threshold < 1.0
        ), "unlikely_action_threshold should be in [0, 1)"
        if unlikely_action_threshold > 0:
            self._log_tau = math.log(unlikely_action_threshold)
        else:
            self._log_tau = -np.inf
        assert 0.0 <= eval_eps < 1.0
        self._eps = eval_eps
        self._weight_reg = imitation_logits_penalty

    def train(self, mode: bool = True) -> "DiscreteBCQPolicy":
        self.training = mode
        self.model.train(mode)
        self.imitator.train(mode)
        return self

    def _target_q(
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
        with torch.no_grad():
            act = self(batch, input="obs_next", eps=0.0).act
            target_q, _ = self.model_old(batch.obs_next)
            target_q = target_q[np.arange(len(act)), act]
        return target_q

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        eps: Optional[float] = None,
        **kwargs: Any,
    ) -> Batch:
        if eps is None:
            eps = self._eps
        obs = batch[input]
        q_value, state = self.model(obs, state=state, info=batch.info)
        imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)

        # mask actions for argmax
        ratio = imitation_logits - imitation_logits.max(
            dim=-1, keepdim=True).values
        mask = (ratio < self._log_tau).float()
        action = (q_value - np.inf * mask).argmax(dim=-1)

        # add eps to act
        if not np.isclose(eps, 0.0):
            bsz, action_num = q_value.shape
            mask = np.random.rand(bsz) < eps
            action_rand = torch.randint(
                action_num, size=[bsz], device=action.device)
            action[mask] = action_rand[mask]

        return Batch(act=action, state=state, q_value=q_value,
                     imitation_logits=imitation_logits)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._iter % self._freq == 0:
            self.sync_weight()
        self._iter += 1

        target_q = batch.returns.flatten()
        result = self(batch, eps=0.0)
        imitation_logits = result.imitation_logits
        current_q = result.q_value[np.arange(len(target_q)), batch.act]
        act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
        q_loss = F.smooth_l1_loss(current_q, target_q)
        i_loss = F.nll_loss(
            F.log_softmax(imitation_logits, dim=-1), act)  # type: ignore
        reg_loss = imitation_logits.pow(2).mean()
        loss = q_loss + i_loss + self._weight_reg * reg_loss

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        return {
            "loss": loss.item(),
            "q_loss": q_loss.item(),
            "i_loss": i_loss.item(),
            "reg_loss": reg_loss.item(),
        }
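To make the action masking in ``forward`` above concrete, here is a toy rerun of the threshold rule with hand-picked numbers, rewritten with ``masked_fill`` for clarity (illustrative only, not part of the new file):

```python
import math
import torch

tau = 0.3
q_value = torch.tensor([[3.0, 2.0, 1.0]])
imitation_logits = torch.tensor([[0.0, 5.0, 4.0]])

# an action is dropped if its imitator logit falls more than -log(tau)
# below the best logit; among the survivors, take the largest Q
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
masked_q = q_value.masked_fill(ratio < math.log(tau), -float("inf"))
print(masked_q.argmax(dim=-1))  # tensor([1]): action 0 has the top Q but is ruled out
```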
@@ -58,8 +58,8 @@ class A2CPolicy(PGPolicy):
        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._w_vf = vf_coef
        self._w_ent = ent_coef
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        self._batch = max_batchsize
        self._rew_norm = reward_normalization
@@ -122,7 +122,8 @@ class A2CPolicy(PGPolicy):
                a_loss = -(log_prob * (r - v).detach()).mean()
                vf_loss = F.mse_loss(r, v)  # type: ignore
                ent_loss = dist.entropy().mean()
                loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
                loss = a_loss + self._weight_vf * vf_loss - \
                    self._weight_ent * ent_loss
                loss.backward()
                if self._grad_norm is not None:
                    nn.utils.clip_grad_norm_(
@@ -124,7 +124,7 @@ class C51Policy(DQNPolicy):
        return target_dist.sum(-1)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        with torch.no_grad():
@@ -139,5 +139,5 @@ class C51Policy(DQNPolicy):
        batch.weight = cross_entropy.detach()  # prio-buffer
        loss.backward()
        self.optim.step()
        self._cnt += 1
        self._iter += 1
        return {"loss": loss.item()}
@@ -54,7 +54,7 @@ class DQNPolicy(BasePolicy):
        self._n_step = estimation_step
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._cnt = 0
        self._iter = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
@@ -78,16 +78,15 @@
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            a = self(batch, input="obs_next").act
            with torch.no_grad():
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
        with torch.no_grad():
            if self._target:
                a = self(batch, input="obs_next").act
                target_q = self(
                    batch, model="model_old", input="obs_next"
                ).logits
            target_q = target_q[np.arange(len(a)), a]
        else:
            with torch.no_grad():
                target_q = target_q[np.arange(len(a)), a]
            else:
                target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
        return target_q

@@ -162,7 +161,7 @@
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
@@ -174,5 +173,5 @@
        batch.weight = td  # prio-buffer
        loss.backward()
        self.optim.step()
        self._cnt += 1
        self._iter += 1
        return {"loss": loss.item()}
@@ -68,8 +68,8 @@ class PPOPolicy(PGPolicy):
        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
        self._max_grad_norm = max_grad_norm
        self._eps_clip = eps_clip
        self._w_vf = vf_coef
        self._w_ent = ent_coef
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._range = action_range
        self.actor = actor
        self.critic = critic
@@ -174,7 +174,8 @@ class PPOPolicy(PGPolicy):
                vf_losses.append(vf_loss.item())
                e_loss = dist.entropy().mean()
                ent_losses.append(e_loss.item())
                loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
                loss = clip_loss + self._weight_vf * vf_loss - \
                    self._weight_ent * e_loss
                losses.append(loss.item())
                self.optim.zero_grad()
                loss.backward()
@@ -143,7 +143,6 @@ class SACPolicy(DDPGPolicy):
        with torch.no_grad():
            obs_next_result = self(batch, input='obs_next')
            a_ = obs_next_result.act
            batch.act = to_torch_as(batch.act, a_)
            target_q = torch.min(
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_),
@@ -1,10 +1,12 @@
from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer

__all__ = [
    "gather_info",
    "test_episode",
    "onpolicy_trainer",
    "offpolicy_trainer",
    "offline_trainer",
]
97  tianshou/trainer/offline.py (new file)

@@ -0,0 +1,97 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import test_episode, gather_info


def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for offline trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature ``f(policy:
        BasePolicy) -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float)
        -> bool``, receives the average undiscounted returns of the testing
        result, returns a boolean which indicates whether reaching the goal.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter; if None is given, it will not write logs to TensorBoard.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    gradient_step = 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    test_collector.reset_stat()

    for epoch in range(1, 1 + max_epoch):
        policy.train()
        with tqdm.trange(
            step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            for i in t:
                gradient_step += 1
                losses = policy.update(batch_size, buffer)
                data = {"gradient_step": str(gradient_step)}
                for k in losses.keys():
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, stat[k].get(),
                            global_step=gradient_step)
                t.set_postfix(**data)
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, gradient_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, None, test_collector,
                       best_reward, best_reward_std)
@@ -1,5 +1,6 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

@@ -38,10 +39,10 @@ def offpolicy_trainer(
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum of epochs for training. The training
        process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of step for updating policy network
        in one epoch.
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, collect some frames
        and do some policy network update.
@@ -52,19 +53,20 @@ def offpolicy_trainer(
        be updated after frames are collected, for example, set it to 256 means
        it updates policy 256 times once after ``collect_per_step`` frames are
        collected.
    :param function train_fn: a function receives the current number of epoch
        and step index, and performs some operations at the beginning of
        training in this epoch.
    :param function test_fn: a function receives the current number of epoch
        and step index, and performs some operations at the beginning of
        testing in this epoch.
    :param function save_fn: a function for saving policy when the undiscounted
        average mean reward in evaluation phase gets better.
    :param function stop_fn: a function receives the average undiscounted
        returns of the testing result, return a boolean which indicates whether
        reaching the goal.
    :param function train_fn: a hook called at the beginning of training in
        each epoch. It can be used to perform custom additional operations,
        with the signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature ``f(policy:
        BasePolicy) -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float)
        -> bool``, receives the average undiscounted returns of the testing
        result, returns a boolean which indicates whether reaching the goal.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
        SummaryWriter; if None is given, it will not write logs to TensorBoard.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.
@@ -73,7 +75,7 @@ def offpolicy_trainer(
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
@@ -122,8 +124,6 @@ def offpolicy_trainer(
                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f"{stat[k].get():.6f}"
                        if writer and gradient_step % log_interval == 0:
@@ -1,5 +1,6 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

@@ -38,10 +39,10 @@ def onpolicy_trainer(
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum of epochs for training. The training
        process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of step for updating policy network
        in one epoch.
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
@@ -52,19 +53,20 @@ def onpolicy_trainer(
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function receives the current number of epoch
        and step index, and performs some operations at the beginning of
        training in this poch.
    :param function test_fn: a function receives the current number of epoch
        and step index, and performs some operations at the beginning of
        testing in this epoch.
    :param function save_fn: a function for saving policy when the undiscounted
        average mean reward in evaluation phase gets better.
    :param function stop_fn: a function receives the average undiscounted
        returns of the testing result, return a boolean which indicates whether
        reaching the goal.
    :param function train_fn: a hook called at the beginning of training in
        each epoch. It can be used to perform custom additional operations,
        with the signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature ``f(policy:
        BasePolicy) -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float)
        -> bool``, receives the average undiscounted returns of the testing
        result, returns a boolean which indicates whether reaching the goal.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
        SummaryWriter; if None is given, it will not write logs to TensorBoard.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.
@@ -73,7 +75,7 @@ def onpolicy_trainer(
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
@@ -125,8 +127,6 @@ def onpolicy_trainer(
                    len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
@@ -36,7 +36,7 @@ def test_episode(

def gather_info(
    start_time: float,
    train_c: Collector,
    train_c: Optional[Collector],
    test_c: Collector,
    best_reward: float,
    best_reward_std: float,
@@ -59,15 +59,9 @@ def gather_info(
    * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
    model_time = duration - train_c.collect_time - test_c.collect_time
    train_speed = train_c.collect_step / (duration - test_c.collect_time)
    model_time = duration - test_c.collect_time
    test_speed = test_c.collect_step / test_c.collect_time
    return {
        "train_step": train_c.collect_step,
        "train_episode": train_c.collect_episode,
        "train_time/collector": f"{train_c.collect_time:.2f}s",
        "train_time/model": f"{model_time:.2f}s",
        "train_speed": f"{train_speed:.2f} step/s",
    result: Dict[str, Union[float, str]] = {
        "test_step": test_c.collect_step,
        "test_episode": test_c.collect_episode,
        "test_time": f"{test_c.collect_time:.2f}s",
@@ -75,4 +69,16 @@ def gather_info(
        "best_reward": best_reward,
        "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
        "duration": f"{duration:.2f}s",
        "train_time/model": f"{model_time:.2f}s",
    }
    if train_c is not None:
        model_time -= train_c.collect_time
        train_speed = train_c.collect_step / (duration - test_c.collect_time)
        result.update({
            "train_step": train_c.collect_step,
            "train_episode": train_c.collect_episode,
            "train_time/collector": f"{train_c.collect_time:.2f}s",
            "train_time/model": f"{model_time:.2f}s",
            "train_speed": f"{train_speed:.2f} step/s",
        })
    return result