Fix critic network for Discrete CRR (#485)

- Fixes an inconsistency in the implementation of Discrete CRR: it now uses the `Critic` class for its critic, following the convention of other actor-critic policies;
- Updates several offline policies to use the `ActorCritic` class for their optimizers, eliminating randomness caused by parameter sharing between actor and critic (see the sketch after this list);
- Adds `writer.flush()` in `TensorboardLogger` to ensure results are written out in real time;
- Enables `test_collector=None` in 3 trainers to turn off testing during training;
- Updates the Atari offline results in README.md;
- Moves Atari offline RL examples to `examples/offline` and tests to `test/offline`, per review comments.
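A rough sketch of the optimizer change described above (shapes and the learning rate are illustrative assumptions, not values from the diff; the pattern mirrors the updated test scripts):

```python
import numpy as np
import torch

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

# Illustrative CartPole-like shapes (assumptions, not taken from the diff).
state_shape, action_shape, hidden_sizes = (4,), 2, [64, 64]

net = Net(state_shape, hidden_sizes[0])  # shared feature extractor
actor = Actor(net, action_shape, hidden_sizes=hidden_sizes, softmax_output=False)
critic = Critic(net, hidden_sizes=hidden_sizes, last_size=int(np.prod(action_shape)))

# Previously the scripts built the optimizer from
#   list(actor.parameters()) + list(critic.parameters())
# which passes the shared `net` parameters to the optimizer twice.
# ActorCritic is a single nn.Module, so .parameters() yields each shared
# parameter exactly once.
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-4)
```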
Yi Su 2021-11-28 07:10:28 -08:00 committed by GitHub
parent 5c5a3db94e
commit 3592f45446
21 changed files with 459 additions and 258 deletions

examples/__init__.py (new, empty file)

@@ -1,4 +1,4 @@
-# Atari General
+# Atari
The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.
@@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
# BCQ
To run the BCQ algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our BCQ implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
# CQL
To run the CQL algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our CQL implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
We reduce the size of the offline data to 10% and 1% of the above and get:
Buffer size 100000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
Buffer size 10000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
# CRR
To run the CRR algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our CRR implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.

@@ -2,9 +2,11 @@
In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.
-Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
-## Train
+## Continuous control
+Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
+### Train
Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer`, and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
@@ -26,3 +28,59 @@ After 1M steps:
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |
## Discrete control
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
### Gather Data
To run an offline algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the QRDQN section of the Atari examples: `python3 atari_qrdqn.py --task {your_task}`;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train the offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5` (how the buffer is loaded back is sketched below).
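The resulting buffer is then passed as the `buffer` argument of `offline_trainer`. A minimal sketch of the loading step, assuming the `load_hdf5` helper on tianshou buffers that the example scripts rely on (file names are the ones produced above):

```python
import pickle

from tianshou.data import VectorReplayBuffer


def load_expert_buffer(path: str):
    """Restore the expert data saved by the generation step above."""
    if path.endswith(".hdf5"):
        # Large Atari buffers are stored as HDF5 and restored via load_hdf5.
        return VectorReplayBuffer.load_hdf5(path)
    # Smaller buffers (e.g. the CartPole test data) are plain pickles.
    with open(path, "rb") as f:
        return pickle.load(f)


buffer = load_expert_buffer("expert.hdf5")  # handed to offline_trainer as `buffer=`
```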
### BCQ
We test our BCQ implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | BCQ | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |
### CQL
We test our CQL implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
We reduce the size of the offline data to 10% and 1% of the above and get:
Buffer size 100000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
Buffer size 10000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
### CRR
We test our CRR implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.
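The regularizer referred to here is the CQL-style penalty controlled by `--min-q-weight`. A rough sketch of the idea (an illustration of the term, not a copy of the policy's internal code):

```python
import torch


def cql_regularizer(q: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    """Conservative Q-learning style penalty for discrete actions.

    q:   (batch, num_actions) Q-values predicted by the critic
    act: (batch,) actions actually taken in the offline dataset
    """
    data_q = q.gather(1, act.unsqueeze(1)).squeeze(1)
    # Push Q-values down in general while keeping up the Q-value of the
    # dataset action; the total loss adds this term scaled by min_q_weight.
    return (torch.logsumexp(q, dim=1) - data_q).mean()
```

In the tables above, a larger `--min-q-weight` (e.g. 50 for Breakout) strengthens this penalty.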

@@ -6,15 +6,16 @@ import pprint
import numpy as np
import torch
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
+from examples.atari.atari_network import DQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor
@@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()):
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        softmax_output=False,
    ).to(args.device)
    imitation_net = Actor(
        feature_net,
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        softmax_output=False,
    ).to(args.device)
-    optim = torch.optim.Adam(
-        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
-    )
+    actor_critic = ActorCritic(policy_net, imitation_net)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
@@ -171,7 +171,7 @@ def test_discrete_bcq(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
    pprint.pprint(result)

@@ -6,10 +6,10 @@ import pprint
import numpy as np
import torch
-from atari_network import QRDQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
+from examples.atari.atari_network import QRDQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
@@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()):
        args.num_quantiles,
        args.n_step,
        args.target_update_freq,
-        min_q_weight=args.min_q_weight
+        min_q_weight=args.min_q_weight,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
@@ -156,7 +156,7 @@ def test_discrete_cql(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
    pprint.pprint(result)

@@ -6,16 +6,17 @@ import pprint
import numpy as np
import torch
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
+from examples.atari.atari_network import DQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.discrete import Actor
+from tianshou.utils.net.common import ActorCritic
+from tianshou.utils.net.discrete import Actor, Critic
def get_args():
@@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()):
    actor = Actor(
        feature_net,
        args.action_shape,
-        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        device=args.device,
+        softmax_output=False,
    ).to(args.device)
-    critic = DQN(*args.state_shape, args.action_shape,
-                 device=args.device).to(args.device)
-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
-    )
+    critic = Critic(
+        feature_net,
+        hidden_sizes=args.hidden_sizes,
+        last_size=np.prod(args.action_shape),
+        device=args.device,
+    ).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCRRPolicy(
        actor,
@@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()):
        ratio_upper_bound=args.ratio_upper_bound,
        beta=args.beta,
        min_q_weight=args.min_q_weight,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
@@ -171,7 +175,7 @@ def test_discrete_crr(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
    pprint.pprint(result)

@@ -1,6 +1,5 @@
import argparse
import os
-import pickle
import pprint
import gym
@@ -42,9 +41,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
-    parser.add_argument(
-        '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
-    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
@@ -85,7 +81,7 @@ def test_dqn(args=get_args()):
        optim,
        args.gamma,
        args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    )
    # buffer
    if args.prioritized_replay:
@@ -93,7 +89,7 @@ def test_dqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
-            beta=args.beta
+            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -142,7 +138,7 @@ def test_dqn(args=get_args()):
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
    assert stop_fn(result['best_reward'])
@@ -157,14 +153,6 @@ def test_dqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
-    # save buffer in pickle format, for imitation learning unittest
-    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
-    policy.set_eps(0.2)
-    collector = Collector(policy, test_envs, buf, exploration_noise=True)
-    result = collector.collect(n_step=args.buffer_size)
-    pickle.dump(buf, open(args.save_buffer_name, "wb"))
-    print(result["rews"].mean())
def test_pdqn(args=get_args()):
    args.prioritized_replay = True

@@ -1,6 +1,5 @@
import argparse
import os
-import pickle
import pprint
import gym
@@ -43,9 +42,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
-    parser.add_argument(
-        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
-    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
@@ -80,7 +76,7 @@ def test_qrdqn(args=get_args()):
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
-        num_atoms=args.num_quantiles
+        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
@@ -89,7 +85,7 @@ def test_qrdqn(args=get_args()):
        args.gamma,
        args.num_quantiles,
        args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
@@ -97,7 +93,7 @@ def test_qrdqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
-            beta=args.beta
+            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -146,7 +142,7 @@ def test_qrdqn(args=get_args()):
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
-        update_per_step=args.update_per_step
+        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])
@@ -161,14 +157,6 @@ def test_qrdqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
-    # save buffer in pickle format, for imitation learning unittest
-    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
-    policy.set_eps(0.9)  # 10% of expert data as demonstrated in the original paper
-    collector = Collector(policy, test_envs, buf, exploration_noise=True)
-    result = collector.collect(n_step=args.buffer_size)
-    pickle.dump(buf, open(args.save_buffer_name, "wb"))
-    print(result["rews"].mean())
def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True

@@ -0,0 +1,160 @@
import argparse
import os
import pickle

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args


def gather_data():
    args = get_args()
    env = gym.make(args.task)
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
        net,
        optim,
        args.gamma,
        args.num_quantiles,
        args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annnealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])
    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.2)
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())
    return buf

@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.discrete import Actor
+
+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data
def get_args():
@@ -37,7 +43,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
-        default="./expert_DQN_CartPole-v0.pkl",
+        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",
@@ -65,21 +71,15 @@ def test_discrete_bcq(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
-    policy_net = Net(
-        args.state_shape,
-        args.action_shape,
-        hidden_sizes=args.hidden_sizes,
-        device=args.device
+    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
+    policy_net = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
-    imitation_net = Net(
-        args.state_shape,
-        args.action_shape,
-        hidden_sizes=args.hidden_sizes,
-        device=args.device
+    imitation_net = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
-    optim = torch.optim.Adam(
-        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
-    )
+    actor_critic = ActorCritic(policy_net, imitation_net)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    policy = DiscreteBCQPolicy(
        policy_net,
@@ -93,9 +93,10 @@ def test_discrete_bcq(args=get_args()):
        args.imitation_logits_penalty,
    )
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_dqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()
    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

@@ -15,6 +15,11 @@ from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
+
+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data
def get_args():
    parser = argparse.ArgumentParser()
@@ -83,9 +88,10 @@ def test_discrete_cql(args=get_args()):
        min_q_weight=args.min_q_weight
    ).to(args.device)
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_qrdqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()
    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.discrete import Actor, Critic
+
+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data
def get_args():
@@ -34,7 +40,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
-        default="./expert_DQN_CartPole-v0.pkl",
+        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",
@@ -60,23 +66,22 @@ def test_discrete_crr(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
-    actor = Net(
-        args.state_shape,
+    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
+    actor = Actor(
+        net,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
-        softmax=False
+        softmax_output=False
    )
-    critic = Net(
-        args.state_shape,
-        args.action_shape,
+    critic = Critic(
+        net,
        hidden_sizes=args.hidden_sizes,
-        device=args.device,
-        softmax=False
-    )
-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
+        last_size=np.prod(args.action_shape),
+        device=args.device
    )
+    actor_critic = ActorCritic(actor, critic)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    policy = DiscreteCRRPolicy(
        actor,
@@ -86,14 +91,15 @@ def test_discrete_crr(args=get_args()):
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_dqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()
    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)
-    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
+    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

@@ -1,6 +1,6 @@
from tianshou import data, env, exploration, policy, trainer, utils
-__version__ = "0.4.4"
+__version__ = "0.4.5"
__all__ = [
    "env",

@@ -83,14 +83,14 @@ class DiscreteCRRPolicy(PGPolicy):
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
-        q_t, _ = self.critic(batch.obs)
+        q_t = self.critic(batch.obs)
        act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
        qa_t = q_t.gather(1, act.unsqueeze(1))
        # Critic loss
        with torch.no_grad():
            target_a_t, _ = self.actor_old(batch.obs_next)
            target_m = Categorical(logits=target_a_t)
-            q_t_target, _ = self.critic_old(batch.obs_next)
+            q_t_target = self.critic_old(batch.obs_next)
            rew = to_torch_as(batch.rew, q_t_target)
            expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
            expected_target_q[batch.done > 0] = 0.0

@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    update_per_epoch: int,
    episode_per_test: int,
@@ -33,7 +33,8 @@ def offline_trainer(
    The "step" in offline trainer means a gradient step.
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int update_per_epoch: the number of policy network updates, so-called
@@ -73,10 +74,12 @@ def offline_trainer(
    start_epoch, _, gradient_step = logger.restore_data()
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
-    test_collector.reset_stat()
+    if test_collector is not None:
+        test_c: Collector = test_collector
+        test_collector.reset_stat()
        test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger,
            gradient_step, reward_metric
        )
        best_epoch = start_epoch
@@ -97,9 +100,11 @@ def offline_trainer(
                data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
+        logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-            policy, test_collector, test_fn, epoch, episode_per_test, logger,
+                policy, test_c, test_fn, epoch, episode_per_test, logger,
                gradient_step, reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
@@ -107,7 +112,6 @@ def offline_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-        logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
@@ -115,4 +119,13 @@ def offline_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
-    return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
+    if test_collector is None and save_fn:
+        save_fn(policy)
+    if test_collector is None:
+        return gather_info(start_time, None, None, 0.0, 0.0)
+    else:
+        return gather_info(
+            start_time, None, test_collector, best_reward, best_reward_std
+        )
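A hedged usage sketch of the new `test_collector=None` option (argument values are placeholders; `policy` and `buffer` are assumed to be built beforehand as in the offline examples):

```python
from tianshou.trainer import offline_trainer

# With test_collector=None the trainer skips all evaluation episodes and only
# runs gradient steps on the fixed buffer; save_fn (if given) is called once
# at the end, and the returned info contains no test statistics.
result = offline_trainer(
    policy,                 # a BasePolicy built elsewhere (assumed)
    buffer,                 # an offline ReplayBuffer (assumed)
    test_collector=None,    # turn off testing during training
    max_epoch=5,
    update_per_epoch=10000,
    episode_per_test=10,
    batch_size=64,
)
```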

@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    step_per_collect: int,
@@ -38,7 +38,8 @@ def offpolicy_trainer(
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
@@ -90,11 +91,16 @@ def offpolicy_trainer(
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
+    test_in_train = test_in_train and (
+        train_collector.policy == policy and test_collector is not None
+    )
+    if test_collector is not None:
+        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
-    test_in_train = test_in_train and train_collector.policy == policy
        test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
-        env_step, reward_metric
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
+            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
@@ -129,8 +135,8 @@ def offpolicy_trainer(
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
-                            policy, test_collector, test_fn, epoch, episode_per_test,
-                            logger, env_step
+                            policy, test_c, test_fn, epoch, episode_per_test, logger,
+                            env_step
                        )
                        if stop_fn(test_result["rew"]):
                            if save_fn:
@@ -156,9 +162,11 @@ def offpolicy_trainer(
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
+        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-            policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
@@ -166,7 +174,6 @@ def offpolicy_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
@@ -174,6 +181,13 @@ def offpolicy_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
+    if test_collector is None and save_fn:
+        save_fn(policy)
+    if test_collector is None:
+        return gather_info(start_time, train_collector, None, 0.0, 0.0)
+    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )

@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,
@@ -39,7 +39,8 @@ def onpolicy_trainer(
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
@@ -96,11 +97,16 @@ def onpolicy_trainer(
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
+    test_in_train = test_in_train and (
+        train_collector.policy == policy and test_collector is not None
+    )
+    if test_collector is not None:
+        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
-    test_in_train = test_in_train and train_collector.policy == policy
        test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
-        env_step, reward_metric
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
+            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
@@ -137,8 +143,8 @@ def onpolicy_trainer(
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
-                            policy, test_collector, test_fn, epoch, episode_per_test,
-                            logger, env_step
+                            policy, test_c, test_fn, epoch, episode_per_test, logger,
+                            env_step
                        )
                        if stop_fn(test_result["rew"]):
                            if save_fn:
@@ -172,9 +178,11 @@ def onpolicy_trainer(
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
+        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-            policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
@@ -182,7 +190,6 @@ def onpolicy_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
@@ -190,6 +197,13 @@ def onpolicy_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
+    if test_collector is None and save_fn:
+        save_fn(policy)
+    if test_collector is None:
+        return gather_info(start_time, train_collector, None, 0.0, 0.0)
+    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )

@@ -36,7 +36,7 @@ def test_episode(
def gather_info(
    start_time: float,
    train_c: Optional[Collector],
-    test_c: Collector,
+    test_c: Optional[Collector],
    best_reward: float,
    best_reward_std: float,
) -> Dict[str, Union[float, str]]:
@@ -58,9 +58,16 @@ def gather_info(
    * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
+    model_time = duration
+    result: Dict[str, Union[float, str]] = {
+        "duration": f"{duration:.2f}s",
+        "train_time/model": f"{model_time:.2f}s",
+    }
+    if test_c is not None:
        model_time = duration - test_c.collect_time
        test_speed = test_c.collect_step / test_c.collect_time
-    result: Dict[str, Union[float, str]] = {
+        result.update(
+            {
                "test_step": test_c.collect_step,
                "test_episode": test_c.collect_episode,
                "test_time": f"{test_c.collect_time:.2f}s",
@@ -70,9 +77,13 @@ def gather_info(
                "duration": f"{duration:.2f}s",
                "train_time/model": f"{model_time:.2f}s",
            }
+        )
    if train_c is not None:
        model_time -= train_c.collect_time
+        if test_c is not None:
            train_speed = train_c.collect_step / (duration - test_c.collect_time)
+        else:
+            train_speed = train_c.collect_step / duration
        result.update(
            {
                "train_step": train_c.collect_step,

@@ -35,6 +35,7 @@ class TensorboardLogger(BaseLogger):
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        for k, v in data.items():
            self.writer.add_scalar(k, v, global_step=step)
+        self.writer.flush()  # issue #482
    def save_data(
        self,