Add offline trainer and discrete BCQ algorithm (#263)

The result needs to be tuned after the `done` issue is fixed.

Co-authored-by: n+e <trinkle23897@gmail.com>
Jialu Zhu 2021-01-20 02:13:04 -08:00 committed by GitHub
parent a633a6a028
commit a511cb4779
26 changed files with 628 additions and 80 deletions

.gitignore

@ -145,3 +145,5 @@ MUJOCO_LOG.TXT
*.zip
*.pstats
*.swp
*.pkl
*.hdf5

@ -31,6 +31,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)

@ -20,8 +20,9 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_

@ -201,7 +201,7 @@ Trainer
Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`.
Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.
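For quick reference, here is a minimal sketch of how the new offline trainer is invoked, in the same spirit as the `offpolicy_trainer` example in the tutorial; the `policy`, `buffer`, and `test_collector` objects are assumed to exist already and the numeric values are placeholders:

```python
from tianshou.trainer import offline_trainer

# `policy`, `buffer` (a pre-collected ReplayBuffer) and `test_collector`
# are assumed to be constructed already; the numbers are illustrative only.
result = offline_trainer(
    policy, buffer, test_collector,
    max_epoch=5, step_per_epoch=1000, episode_per_test=100, batch_size=64,
    writer=None)
print(f'Finished training! Use {result["duration"]}')
```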
.. _pseudocode:

@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n
Train Policy with a Trainer
---------------------------
Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows:
Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on the test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows:
::
result = ts.trainer.offpolicy_trainer(
@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
writer=None)
print(f'Finished training! Use {result["duration"]}')
The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`):
The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):
* ``max_epoch``: The maximum number of epochs for training. The training process might be finished before reaching ``max_epoch``;
* ``step_per_epoch``: The number of steps for updating the policy network in one epoch;

@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |
Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
# BCQ
TODO: after the `done` issue is fixed, the result should be re-tuned and placed here.
To run the BCQ algorithm on Atari, you need to do the following (a short buffer-loading sketch follows this list):
- Train an expert by using the command listed in the DQN section above;
- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
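As referenced above, a small sketch of loading the generated buffer in Python before training (assuming the file was saved as `expert.hdf5` by the command in the previous step):

```python
from tianshou.data import ReplayBuffer

# Load the HDF5 buffer written by atari_dqn.py --save-buffer-name expert.hdf5
buffer = ReplayBuffer.load_hdf5("expert.hdf5")
print(len(buffer))  # should match --buffer-size (1,000,000 transitions)
```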

examples/atari/atari_bcq.py
@ -0,0 +1,153 @@
import os
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer
from atari_network import DQN
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=6.25e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=8000)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=1000)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_discrete_bcq(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
feature_net = DQN(*args.state_shape, args.action_shape,
device=args.device, features_only=True).to(args.device)
policy_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
imitation_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr,
)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
args.target_update_freq, args.eps_test,
args.unlikely_action_threshold, args.imitation_logits_penalty,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_dqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
# collector
test_collector = Collector(policy, test_envs)
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return False
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)
if args.watch:
watch()
exit(0)
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
log_interval=args.log_interval,
)
pprint.pprint(result)
watch()
if __name__ == "__main__":
test_discrete_bcq(get_args())

@ -41,6 +41,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@ -120,13 +121,25 @@ def test_dqn(args=get_args()):
# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = ReplayBuffer(
args.buffer_size, ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)
if args.watch:

@ -44,8 +44,7 @@ class DQN(nn.Module):
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(
x, device=self.device, dtype=torch.float32) # type: ignore
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x), state

@ -51,7 +51,7 @@ setup(
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=3.1.0"
"h5py>=2.10.0", # to match tensorflow's minimal requirements
],
extras_require={
"dev": [

@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@ -38,6 +39,9 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str,
default="./expert_DQN_CartPole-v0.pkl")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -114,6 +118,7 @@ def test_dqn(args=get_args()):
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@ -124,6 +129,12 @@ def test_dqn(args=get_args()):
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
# save buffer in pickle format, for imitation learning unittest
buf = ReplayBuffer(args.buffer_size)
collector = Collector(policy, test_envs, buf)
collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
def test_pdqn(args=get_args()):
args.prioritized_replay = True

@ -0,0 +1,111 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteBCQPolicy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--gamma", type=float, default=0.9)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--step-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args
def test_discrete_bcq(args=get_args()):
# envs
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
policy_net = Net(
args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
imitation_net = Net(
args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr)
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
args.target_update_freq, args.eps_test,
args.unlikely_action_threshold, args.imitation_logits_penalty,
)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
# collector
test_collector = Collector(policy, test_envs)
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.step_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
if __name__ == "__main__":
test_discrete_bcq(get_args())

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.3.0"
__version__ = "0.3.1"
__all__ = [
"env",

@ -10,6 +10,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.modelbase.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@ -27,6 +28,7 @@ __all__ = [
"TD3Policy",
"SACPolicy",
"DiscreteSACPolicy",
"DiscreteBCQPolicy",
"PSRLPolicy",
"MultiAgentPolicyManager",
]

@ -57,7 +57,7 @@ class ImitationPolicy(BasePolicy):
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_) # type: ignore
elif self.mode == "discrete": # classification
a = self(batch).logits
a = F.log_softmax(self(batch).logits, dim=-1)
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_) # type: ignore
loss.backward()
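The change above matters because `F.nll_loss` expects log-probabilities rather than raw logits; a minimal sketch of the equivalence (made-up tensors, not library code):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)            # hypothetical (batch, num_actions) logits
target = torch.tensor([0, 2, 1, 2])   # hypothetical expert actions

# nll_loss on log-softmax output is the same as cross_entropy on raw logits
loss_a = F.nll_loss(F.log_softmax(logits, dim=-1), target)
loss_b = F.cross_entropy(logits, target)
assert torch.allclose(loss_a, loss_b)
```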

@ -0,0 +1,139 @@
import math
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict, Union, Optional
from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch
class DiscreteBCQPolicy(DQNPolicy):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param torch.nn.Module imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency.
:param float eval_eps: the epsilon-greedy noise added in evaluation.
:param float unlikely_action_threshold: the threshold (tau) for unlikely
actions, as shown in Eq. (17) in the paper, defaults to 0.3.
:param float imitation_logits_penalty: regularization weight for imitation
logits, defaults to 1e-2.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
model: torch.nn.Module,
imitator: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 8000,
eval_eps: float = 1e-3,
unlikely_action_threshold: float = 0.3,
imitation_logits_penalty: float = 1e-2,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(model, optim, discount_factor, estimation_step,
target_update_freq, reward_normalization, **kwargs)
assert target_update_freq > 0, "BCQ needs target network setting."
self.imitator = imitator
assert (
0.0 <= unlikely_action_threshold < 1.0
), "unlikely_action_threshold should be in [0, 1)"
if unlikely_action_threshold > 0:
self._log_tau = math.log(unlikely_action_threshold)
else:
self._log_tau = -np.inf
assert 0.0 <= eval_eps < 1.0
self._eps = eval_eps
self._weight_reg = imitation_logits_penalty
def train(self, mode: bool = True) -> "DiscreteBCQPolicy":
self.training = mode
self.model.train(mode)
self.imitator.train(mode)
return self
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
with torch.no_grad():
act = self(batch, input="obs_next", eps=0.0).act
target_q, _ = self.model_old(batch.obs_next)
target_q = target_q[np.arange(len(act)), act]
return target_q
def forward( # type: ignore
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = "obs",
eps: Optional[float] = None,
**kwargs: Any,
) -> Batch:
if eps is None:
eps = self._eps
obs = batch[input]
q_value, state = self.model(obs, state=state, info=batch.info)
imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)
# mask actions for argmax
ratio = imitation_logits - imitation_logits.max(
dim=-1, keepdim=True).values
mask = (ratio < self._log_tau).float()
action = (q_value - np.inf * mask).argmax(dim=-1)
# add eps to act
if not np.isclose(eps, 0.0):
bsz, action_num = q_value.shape
mask = np.random.rand(bsz) < eps
action_rand = torch.randint(
action_num, size=[bsz], device=action.device)
action[mask] = action_rand[mask]
return Batch(act=action, state=state, q_value=q_value,
imitation_logits=imitation_logits)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._iter % self._freq == 0:
self.sync_weight()
self._iter += 1
target_q = batch.returns.flatten()
result = self(batch, eps=0.0)
imitation_logits = result.imitation_logits
current_q = result.q_value[np.arange(len(target_q)), batch.act]
act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
q_loss = F.smooth_l1_loss(current_q, target_q)
i_loss = F.nll_loss(
F.log_softmax(imitation_logits, dim=-1), act) # type: ignore
reg_loss = imitation_logits.pow(2).mean()
loss = q_loss + i_loss + self._weight_reg * reg_loss
self.optim.zero_grad()
loss.backward()
self.optim.step()
return {
"loss": loss.item(),
"q_loss": q_loss.item(),
"i_loss": i_loss.item(),
"reg_loss": reg_loss.item(),
}
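To make the `unlikely_action_threshold` logic in `forward` concrete, here is a small standalone sketch of the masking rule (illustrative tensors and a `masked_fill`-based formulation, not the library code): the greedy action is taken over the Q-values of only those actions whose imitation probability is at least `tau` times that of the most likely action.

```python
import torch

tau = 0.3                                              # unlikely_action_threshold
imitation_logits = torch.tensor([[2.0, 1.5, -1.0]])    # hypothetical imitator output
q_value = torch.tensor([[0.1, 0.9, 5.0]])              # hypothetical Q-values

# log(p_i / p_max) = logit_i - max(logit); the softmax normalizer cancels out
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
allowed = ratio >= torch.log(torch.tensor(tau))        # keep actions with p_i / p_max >= tau
action = q_value.masked_fill(~allowed, float("-inf")).argmax(dim=-1)
print(allowed, action)  # action 2 is masked out despite having the largest Q-value
```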

@ -58,8 +58,8 @@ class A2CPolicy(PGPolicy):
self.critic = critic
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
self._w_vf = vf_coef
self._w_ent = ent_coef
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = max_batchsize
self._rew_norm = reward_normalization
@ -122,7 +122,8 @@ class A2CPolicy(PGPolicy):
a_loss = -(log_prob * (r - v).detach()).mean()
vf_loss = F.mse_loss(r, v) # type: ignore
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss = a_loss + self._weight_vf * vf_loss - \
self._weight_ent * ent_loss
loss.backward()
if self._grad_norm is not None:
nn.utils.clip_grad_norm_(

@ -124,7 +124,7 @@ class C51Policy(DQNPolicy):
return target_dist.sum(-1)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
with torch.no_grad():
@ -139,5 +139,5 @@ class C51Policy(DQNPolicy):
batch.weight = cross_entropy.detach() # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
self._iter += 1
return {"loss": loss.item()}

@ -54,7 +54,7 @@ class DQNPolicy(BasePolicy):
self._n_step = estimation_step
self._target = target_update_freq > 0
self._freq = target_update_freq
self._cnt = 0
self._iter = 0
if self._target:
self.model_old = deepcopy(self.model)
self.model_old.eval()
@ -78,16 +78,15 @@ class DQNPolicy(BasePolicy):
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(batch, input="obs_next").act
with torch.no_grad():
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
with torch.no_grad():
if self._target:
a = self(batch, input="obs_next").act
target_q = self(
batch, model="model_old", input="obs_next"
).logits
target_q = target_q[np.arange(len(a)), a]
else:
with torch.no_grad():
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
return target_q
@ -162,7 +161,7 @@ class DQNPolicy(BasePolicy):
return Batch(logits=q, act=act, state=h)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
@ -174,5 +173,5 @@ class DQNPolicy(BasePolicy):
batch.weight = td # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
self._iter += 1
return {"loss": loss.item()}
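As a side note on the reshuffled `_target_q` above: the double-DQN style target is unchanged, i.e. the action is chosen by the online network and evaluated by the target network. A toy sketch of the indexing pattern (made-up tensors):

```python
import numpy as np
import torch

q_new = torch.tensor([[1.0, 3.0], [2.0, 0.5]])   # hypothetical Q_new(s', .)
q_old = torch.tensor([[0.8, 2.5], [1.9, 0.4]])   # hypothetical Q_old(s', .)

a = q_new.argmax(dim=1)                          # greedy action from the online network
target_q = q_old[np.arange(len(a)), a]           # evaluated by the target network
# target_q == tensor([2.5, 1.9])
```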

@ -68,8 +68,8 @@ class PPOPolicy(PGPolicy):
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
self._w_vf = vf_coef
self._w_ent = ent_coef
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self._range = action_range
self.actor = actor
self.critic = critic
@ -174,7 +174,8 @@ class PPOPolicy(PGPolicy):
vf_losses.append(vf_loss.item())
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.item())
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
loss = clip_loss + self._weight_vf * vf_loss - \
self._weight_ent * e_loss
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()

@ -143,7 +143,6 @@ class SACPolicy(DDPGPolicy):
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),

@ -1,10 +1,12 @@
from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer
__all__ = [
"gather_info",
"test_episode",
"onpolicy_trainer",
"offpolicy_trainer",
"offline_trainer",
]

@ -0,0 +1,97 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import test_episode, gather_info
def offline_trainer(
policy: BasePolicy,
buffer: ReplayBuffer,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for offline trainer procedure.
The "step" in trainer means a policy network update.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum number of epochs for training. The
training process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of policy network updates, so-called
gradient steps, per epoch.
:param episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature ``f(policy:
BasePolicy) -> None``.
:param function stop_fn: a function with signature ``f(mean_rewards: float)
-> bool``, receives the average undiscounted returns of the testing
result, and returns a boolean which indicates whether the goal is reached.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter; if None is given, it will not write logs to TensorBoard.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
gradient_step = 0
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
test_collector.reset_stat()
for epoch in range(1, 1 + max_epoch):
policy.train()
with tqdm.trange(
step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
for i in t:
gradient_step += 1
losses = policy.update(batch_size, buffer)
data = {"gradient_step": str(gradient_step)}
for k in losses.keys():
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:
writer.add_scalar(
"train/" + k, stat[k].get(),
global_step=gradient_step)
t.set_postfix(**data)
# test
result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, gradient_step)
if best_epoch == -1 or best_reward < result["rew"]:
best_reward, best_reward_std = result["rew"], result["rew_std"]
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
f"{best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector,
best_reward, best_reward_std)

@ -1,5 +1,6 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
@ -38,10 +39,10 @@ def offpolicy_trainer(
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
:param int max_epoch: the maximum number of epochs for training. The
training process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of policy network updates, so-called
gradient steps, per epoch.
:param int collect_per_step: the number of frames the collector would
collect before the network update. In other words, collect some frames
and do some policy network update.
@ -52,19 +53,20 @@ def offpolicy_trainer(
be updated after frames are collected, for example, set it to 256 means
it updates policy 256 times once after ``collect_per_step`` frames are
collected.
:param function train_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
training in this epoch.
:param function test_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
testing in this epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param function train_fn: a hook called at the beginning of training in
each epoch. It can be used to perform custom additional operations,
with the signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature ``f(policy:
BasePolicy) -> None``.
:param function stop_fn: a function with signature ``f(mean_rewards: float)
-> bool``, receives the average undiscounted returns of the testing
result, and returns a boolean which indicates whether the goal is reached.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
SummaryWriter; if None is given, it will not write logs to TensorBoard.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:param bool test_in_train: whether to test in the training phase.
@ -73,7 +75,7 @@ def offpolicy_trainer(
"""
env_step, gradient_step = 0, 0
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
stat: Dict[str, MovAvg] = {}
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
@ -122,8 +124,6 @@ def offpolicy_trainer(
gradient_step += 1
losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:

@ -1,5 +1,6 @@
import time
import tqdm
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
@ -38,10 +39,10 @@ def onpolicy_trainer(
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
:param int max_epoch: the maximum number of epochs for training. The
training process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of policy network updates, so-called
gradient steps, per epoch.
:param int collect_per_step: the number of episodes the collector would
collect before the network update. In other words, collect some
episodes and do one policy network update.
@ -52,19 +53,20 @@ def onpolicy_trainer(
:type episode_per_test: int or list of ints
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function train_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
training in this poch.
:param function test_fn: a function receives the current number of epoch
and step index, and performs some operations at the beginning of
testing in this epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param function train_fn: a hook called at the beginning of training in
each epoch. It can be used to perform custom additional operations,
with the signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature ``f(policy:
BasePolicy) -> None``.
:param function stop_fn: a function with signature ``f(mean_rewards: float)
-> bool``, receives the average undiscounted returns of the testing
result, and returns a boolean which indicates whether the goal is reached.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
SummaryWriter; if None is given, it will not write logs to TensorBoard.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:param bool test_in_train: whether to test in the training phase.
@ -73,7 +75,7 @@ def onpolicy_trainer(
"""
env_step, gradient_step = 0, 0
best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
stat: Dict[str, MovAvg] = {}
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
@ -125,8 +127,6 @@ def onpolicy_trainer(
len(v) for v in losses.values() if isinstance(v, list)])
gradient_step += step
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:

@ -36,7 +36,7 @@ def test_episode(
def gather_info(
start_time: float,
train_c: Collector,
train_c: Optional[Collector],
test_c: Collector,
best_reward: float,
best_reward_std: float,
@ -59,15 +59,9 @@ def gather_info(
* ``duration`` the total elapsed time.
"""
duration = time.time() - start_time
model_time = duration - train_c.collect_time - test_c.collect_time
train_speed = train_c.collect_step / (duration - test_c.collect_time)
model_time = duration - test_c.collect_time
test_speed = test_c.collect_step / test_c.collect_time
return {
"train_step": train_c.collect_step,
"train_episode": train_c.collect_episode,
"train_time/collector": f"{train_c.collect_time:.2f}s",
"train_time/model": f"{model_time:.2f}s",
"train_speed": f"{train_speed:.2f} step/s",
result: Dict[str, Union[float, str]] = {
"test_step": test_c.collect_step,
"test_episode": test_c.collect_episode,
"test_time": f"{test_c.collect_time:.2f}s",
@ -75,4 +69,16 @@ def gather_info(
"best_reward": best_reward,
"best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
"duration": f"{duration:.2f}s",
"train_time/model": f"{model_time:.2f}s",
}
if train_c is not None:
model_time -= train_c.collect_time
train_speed = train_c.collect_step / (duration - test_c.collect_time)
result.update({
"train_step": train_c.collect_step,
"train_episode": train_c.collect_episode,
"train_time/collector": f"{train_c.collect_time:.2f}s",
"train_time/model": f"{model_time:.2f}s",
"train_speed": f"{train_speed:.2f} step/s",
})
return result
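For an offline run (``train_c is None``) the returned dictionary therefore contains only the test-side statistics plus the model time; roughly (keys taken from the code above, values are placeholders):

```python
# Sketch of gather_info(start_time, None, test_collector, best_reward, best_reward_std)
# for an offline run -- placeholder values, plus the test_speed entry computed above:
# {
#     "test_step": 10000, "test_episode": 100, "test_time": "12.34s",
#     "best_reward": 195.0, "best_result": "195.00 ± 3.21",
#     "duration": "60.00s", "train_time/model": "47.66s",
# }
```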