Fix critic network for Discrete CRR (#485)

- Fixes an inconsistency in the implementation of Discrete CRR: it now uses the `Critic` class for its critic, following the convention of other actor-critic policies;
- Updates several offline policies to use the `ActorCritic` class for their optimizers, eliminating the randomness caused by parameter sharing between actor and critic (a minimal sketch of the pattern follows this list);
- Adds `writer.flush()` to TensorboardLogger to ensure real-time results;
- Enables `test_collector=None` in the three trainers to turn off testing during training;
- Updates the Atari offline results in README.md;
- Moves Atari offline RL examples to `examples/offline` and tests to `test/offline`, per review comments.
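
The `ActorCritic` wrapper is the core of the optimizer change. Below is a minimal sketch of the pattern the updated scripts now follow; the shapes and learning rate are illustrative values rather than taken from this commit, and the Atari scripts use the DQN feature extractor where this sketch uses `Net`:

```python
import numpy as np
import torch

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

# Illustrative CartPole-sized shapes (assumed values, not from the diff).
state_shape, action_shape, hidden_sizes, device = (4,), 2, [64, 64], "cpu"

# One shared feature extractor feeds both heads.
feature_net = Net(state_shape, hidden_sizes=hidden_sizes, device=device)
actor = Actor(
    feature_net, action_shape, hidden_sizes=hidden_sizes,
    softmax_output=False, device=device,
).to(device)
critic = Critic(
    feature_net, hidden_sizes=hidden_sizes,
    last_size=int(np.prod(action_shape)), device=device,
).to(device)

# ActorCritic wraps both heads in a single nn.Module, so .parameters() yields
# each shared feature_net parameter exactly once; the previous
# list(actor.parameters()) + list(critic.parameters()) duplicated them.
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)
```

The same wrapper is applied to the policy and imitation networks of Discrete BCQ in the diffs below.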
Yi Su 2021-11-28 07:10:28 -08:00 committed by GitHub
parent 5c5a3db94e
commit 3592f45446
21 changed files with 459 additions and 258 deletions

examples/__init__.py Normal file

@ -1,4 +1,4 @@
# Atari General
# Atari
The sample speed is \~3000 env steps per second (\~12000 Atari frames per second, in fact, since we use frame_stack=4) in normal mode (using a CNN policy and a collector, and storing data into the buffer). The main bottleneck is training the convolutional neural network.
@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
# BCQ
To run the BCQ algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the DQN section above;
- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our BCQ implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
# CQL
To run the CQL algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the QRDQN section above;
- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our CQL implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
We reduce the size of the offline data to 10% and 1% of the above and get:
Buffer size 100000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
Buffer size 10000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
# CRR
To run the CRR algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the QRDQN section above;
- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our CRR implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.


@ -2,9 +2,11 @@
In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected once by an arbitrary policy; it no longer interacts with the environment.
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
## Continuous control
## Train
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
### Train
Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` and set it as the `buffer` parameter of `offline_trainer`. `offline_bcq.py` is an example of offline RL using a d4rl dataset.
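
As a rough sketch of that conversion (this is not the code of `offline_bcq.py`; it only assumes the standard keys returned by `d4rl.qlearning_dataset`):

```python
import d4rl  # noqa: F401 -- importing d4rl registers its Gym tasks
import gym

from tianshou.data import Batch, ReplayBuffer


def d4rl_to_buffer(task: str = "halfcheetah-expert-v1") -> ReplayBuffer:
    """Copy a d4rl Q-learning dataset into a Tianshou ReplayBuffer."""
    dataset = d4rl.qlearning_dataset(gym.make(task))
    buffer = ReplayBuffer(size=len(dataset["rewards"]))
    # A plain loop for clarity; this is slow for million-transition datasets.
    for i in range(len(dataset["rewards"])):
        buffer.add(
            Batch(
                obs=dataset["observations"][i],
                act=dataset["actions"][i],
                rew=dataset["rewards"][i],
                done=dataset["terminals"][i],
                obs_next=dataset["next_observations"][i],
                info={},
            )
        )
    return buffer
```

The resulting buffer is then passed as the `buffer` argument of `offline_trainer`.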
@ -26,3 +28,59 @@ After 1M steps:
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |
## Discrete control
For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
### Gather Data
To run an offline algorithm (BCQ/CQL/CRR) on Atari, you need to do the following things:
- Train an expert by using the command listed in the QRDQN section of the Atari examples: `python3 atari_qrdqn.py --task {your_task}`;
- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train the offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5` (a sketch of the buffer-loading step follows this list).
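
The loading step in those scripts can look roughly like the following (an illustrative sketch, not a verbatim copy of the example code; it assumes the buffer was saved from a `VectorReplayBuffer`, as in the command above):

```python
import pickle

from tianshou.data import VectorReplayBuffer

path = "expert.hdf5"
if path.endswith(".hdf5"):
    # HDF5 buffers are restored with the same class that saved them.
    buffer = VectorReplayBuffer.load_hdf5(path)
else:
    # Small buffers (e.g. the CartPole tests) may still be pickled.
    with open(path, "rb") as f:
        buffer = pickle.load(f)
print(len(buffer), "transitions available for offline training")
```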
### BCQ
We test our BCQ implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | BCQ | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |
### CQL
We test our CQL implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
We reduce the size of the offline data to 10% and 1% of the above and get:
Buffer size 100000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
Buffer size 10000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
### CRR
We test our CRR implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.


@ -6,15 +6,16 @@ import pprint
import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor
@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()):
args.action_shape,
device=args.device,
hidden_sizes=args.hidden_sizes,
softmax_output=False
softmax_output=False,
).to(args.device)
imitation_net = Actor(
feature_net,
args.action_shape,
device=args.device,
hidden_sizes=args.hidden_sizes,
softmax_output=False
softmax_output=False,
).to(args.device)
optim = torch.optim.Adam(
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
)
actor_critic = ActorCritic(policy_net, imitation_net)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
@ -171,7 +171,7 @@ def test_discrete_bcq(args=get_args()):
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
pprint.pprint(result)


@ -6,10 +6,10 @@ import pprint
import numpy as np
import torch
from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from examples.atari.atari_network import QRDQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()):
args.num_quantiles,
args.n_step,
args.target_update_freq,
min_q_weight=args.min_q_weight
min_q_weight=args.min_q_weight,
).to(args.device)
# load a previous policy
if args.resume_path:
@ -156,7 +156,7 @@ def test_discrete_cql(args=get_args()):
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
pprint.pprint(result)


@ -6,16 +6,17 @@ import pprint
import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import Actor
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic
def get_args():
@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()):
actor = Actor(
feature_net,
args.action_shape,
device=args.device,
hidden_sizes=args.hidden_sizes,
softmax_output=False
device=args.device,
softmax_output=False,
).to(args.device)
critic = DQN(*args.state_shape, args.action_shape,
device=args.device).to(args.device)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr
)
critic = Critic(
feature_net,
hidden_sizes=args.hidden_sizes,
last_size=np.prod(args.action_shape),
device=args.device,
).to(args.device)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# define policy
policy = DiscreteCRRPolicy(
actor,
@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()):
ratio_upper_bound=args.ratio_upper_bound,
beta=args.beta,
min_q_weight=args.min_q_weight,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# load a previous policy
if args.resume_path:
@ -171,7 +175,7 @@ def test_discrete_crr(args=get_args()):
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
pprint.pprint(result)


@ -1,6 +1,5 @@
import argparse
import os
import pickle
import pprint
import gym
@ -42,9 +41,6 @@ def get_args():
parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
@ -85,7 +81,7 @@ def test_dqn(args=get_args()):
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
)
# buffer
if args.prioritized_replay:
@ -93,7 +89,7 @@ def test_dqn(args=get_args()):
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@ -142,7 +138,7 @@ def test_dqn(args=get_args()):
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
assert stop_fn(result['best_reward'])
@ -157,14 +153,6 @@ def test_dqn(args=get_args()):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())
def test_pdqn(args=get_args()):
args.prioritized_replay = True


@ -1,6 +1,5 @@
import argparse
import os
import pickle
import pprint
import gym
@ -43,9 +42,6 @@ def get_args():
parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
@ -80,7 +76,7 @@ def test_qrdqn(args=get_args()):
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax=False,
num_atoms=args.num_quantiles
num_atoms=args.num_quantiles,
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = QRDQNPolicy(
@ -89,7 +85,7 @@ def test_qrdqn(args=get_args()):
args.gamma,
args.num_quantiles,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
if args.prioritized_replay:
@ -97,7 +93,7 @@ def test_qrdqn(args=get_args()):
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@ -146,7 +142,7 @@ def test_qrdqn(args=get_args()):
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step
update_per_step=args.update_per_step,
)
assert stop_fn(result['best_reward'])
@ -161,14 +157,6 @@ def test_qrdqn(args=get_args()):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.9) # 10% of expert data as demonstrated in the original paper
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())
def test_pqrdqn(args=get_args()):
args.prioritized_replay = True


@ -0,0 +1,160 @@
import argparse
import os
import pickle
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument(
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
args = parser.parse_known_args()[0]
return args
def gather_data():
args = get_args()
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax=False,
num_atoms=args.num_quantiles,
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = QRDQNPolicy(
net,
optim,
args.gamma,
args.num_quantiles,
args.n_step,
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step):
# eps annealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
)
assert stop_fn(result['best_reward'])
# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())
return buf


@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor
if __name__ == "__main__":
from gather_cartpole_data import gather_data
else: # pytest
from test.offline.gather_cartpole_data import gather_data
def get_args():
@ -37,7 +43,7 @@ def get_args():
parser.add_argument(
"--load-buffer-name",
type=str,
default="./expert_DQN_CartPole-v0.pkl",
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device",
@ -65,21 +71,15 @@ def test_discrete_bcq(args=get_args()):
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
policy_net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
policy_net = Actor(
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
).to(args.device)
imitation_net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
imitation_net = Actor(
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
).to(args.device)
optim = torch.optim.Adam(
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
)
actor_critic = ActorCritic(policy_net, imitation_net)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
policy = DiscreteBCQPolicy(
policy_net,
@ -93,9 +93,10 @@ def test_discrete_bcq(args=get_args()):
args.imitation_logits_penalty,
)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)


@ -15,6 +15,11 @@ from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
if __name__ == "__main__":
from gather_cartpole_data import gather_data
else: # pytest
from test.offline.gather_cartpole_data import gather_data
def get_args():
parser = argparse.ArgumentParser()
@ -83,9 +88,10 @@ def test_discrete_cql(args=get_args()):
min_q_weight=args.min_q_weight
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_qrdqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)


@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic
if __name__ == "__main__":
from gather_cartpole_data import gather_data
else: # pytest
from test.offline.gather_cartpole_data import gather_data
def get_args():
@ -34,7 +40,7 @@ def get_args():
parser.add_argument(
"--load-buffer-name",
type=str,
default="./expert_DQN_CartPole-v0.pkl",
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device",
@ -60,23 +66,22 @@ def test_discrete_crr(args=get_args()):
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
actor = Net(
args.state_shape,
net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
actor = Actor(
net,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax=False
softmax_output=False
)
critic = Net(
args.state_shape,
args.action_shape,
critic = Critic(
net,
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax=False
)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr
last_size=np.prod(args.action_shape),
device=args.device
)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
policy = DiscreteCRRPolicy(
actor,
@ -86,14 +91,15 @@ def test_discrete_crr(args=get_args()):
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)


@ -1,6 +1,6 @@
from tianshou import data, env, exploration, policy, trainer, utils
__version__ = "0.4.4"
__version__ = "0.4.5"
__all__ = [
"env",


@ -83,14 +83,14 @@ class DiscreteCRRPolicy(PGPolicy):
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
q_t, _ = self.critic(batch.obs)
q_t = self.critic(batch.obs)
act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
qa_t = q_t.gather(1, act.unsqueeze(1))
# Critic loss
with torch.no_grad():
target_a_t, _ = self.actor_old(batch.obs_next)
target_m = Categorical(logits=target_a_t)
q_t_target, _ = self.critic_old(batch.obs_next)
q_t_target = self.critic_old(batch.obs_next)
rew = to_torch_as(batch.rew, q_t_target)
expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
expected_target_q[batch.done > 0] = 0.0


@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offline_trainer(
policy: BasePolicy,
buffer: ReplayBuffer,
test_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
update_per_epoch: int,
episode_per_test: int,
@ -33,7 +33,8 @@ def offline_trainer(
The "step" in offline trainer means a gradient step.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector test_collector: the collector used for testing.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
:param int update_per_epoch: the number of policy network updates, so-called
@ -73,14 +74,16 @@ def offline_trainer(
start_epoch, _, gradient_step = logger.restore_data()
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
test_collector.reset_stat()
test_result = test_episode(
policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
gradient_step, reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if test_collector is not None:
test_c: Collector = test_collector
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger,
gradient_step, reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
@ -97,22 +100,32 @@ def offline_trainer(
data[k] = f"{losses[k]:.3f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
# test
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test, logger,
gradient_step, reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger,
gradient_step, reward_metric
)
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, None, None, 0.0, 0.0)
else:
return gather_info(
start_time, None, test_collector, best_reward, best_reward_std
)
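
The change above makes `test_collector` optional in `offline_trainer`; a minimal usage sketch (hyperparameter values are illustrative, and `policy`/`buffer` are assumed to be an already-built offline policy and expert buffer):

```python
from typing import Dict, Union

from tianshou.data import ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer import offline_trainer


def train_without_evaluation(
    policy: BasePolicy, buffer: ReplayBuffer
) -> Dict[str, Union[float, str]]:
    """Run offline training with in-training testing turned off."""
    return offline_trainer(
        policy,
        buffer,
        test_collector=None,  # no test_episode calls during training
        max_epoch=5,
        update_per_epoch=10000,
        episode_per_test=0,  # ignored when test_collector is None
        batch_size=64,
        # save_fn, if provided, is still invoked once after the final epoch.
    )
```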


@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
step_per_collect: int,
@ -38,7 +38,8 @@ def offpolicy_trainer(
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
:param int step_per_epoch: the number of transitions collected per epoch.
@ -90,14 +91,19 @@ def offpolicy_trainer(
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(
policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
env_step, reward_metric
test_in_train = test_in_train and (
train_collector.policy == policy and test_collector is not None
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if test_collector is not None:
test_c: Collector = test_collector # for mypy
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
@ -129,8 +135,8 @@ def offpolicy_trainer(
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test,
logger, env_step
policy, test_c, test_fn, epoch, episode_per_test, logger,
env_step
)
if stop_fn(test_result["rew"]):
if save_fn:
@ -156,24 +162,32 @@ def offpolicy_trainer(
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
if stop_fn and stop_fn(best_reward):
break
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, train_collector, None, 0.0, 0.0)
else:
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
)


@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def onpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
repeat_per_collect: int,
@ -39,7 +39,8 @@ def onpolicy_trainer(
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
:param int step_per_epoch: the number of transitions collected per epoch.
@ -96,14 +97,19 @@ def onpolicy_trainer(
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(
policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
env_step, reward_metric
test_in_train = test_in_train and (
train_collector.policy == policy and test_collector is not None
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if test_collector is not None:
test_c: Collector = test_collector # for mypy
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
@ -137,8 +143,8 @@ def onpolicy_trainer(
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test,
logger, env_step
policy, test_c, test_fn, epoch, episode_per_test, logger,
env_step
)
if stop_fn(test_result["rew"]):
if save_fn:
@ -172,24 +178,32 @@ def onpolicy_trainer(
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
if stop_fn and stop_fn(best_reward):
break
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, train_collector, None, 0.0, 0.0)
else:
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
)


@ -36,7 +36,7 @@ def test_episode(
def gather_info(
start_time: float,
train_c: Optional[Collector],
test_c: Collector,
test_c: Optional[Collector],
best_reward: float,
best_reward_std: float,
) -> Dict[str, Union[float, str]]:
@ -58,21 +58,32 @@ def gather_info(
* ``duration`` the total elapsed time.
"""
duration = time.time() - start_time
model_time = duration - test_c.collect_time
test_speed = test_c.collect_step / test_c.collect_time
model_time = duration
result: Dict[str, Union[float, str]] = {
"test_step": test_c.collect_step,
"test_episode": test_c.collect_episode,
"test_time": f"{test_c.collect_time:.2f}s",
"test_speed": f"{test_speed:.2f} step/s",
"best_reward": best_reward,
"best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
"duration": f"{duration:.2f}s",
"train_time/model": f"{model_time:.2f}s",
}
if test_c is not None:
model_time = duration - test_c.collect_time
test_speed = test_c.collect_step / test_c.collect_time
result.update(
{
"test_step": test_c.collect_step,
"test_episode": test_c.collect_episode,
"test_time": f"{test_c.collect_time:.2f}s",
"test_speed": f"{test_speed:.2f} step/s",
"best_reward": best_reward,
"best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
"duration": f"{duration:.2f}s",
"train_time/model": f"{model_time:.2f}s",
}
)
if train_c is not None:
model_time -= train_c.collect_time
train_speed = train_c.collect_step / (duration - test_c.collect_time)
if test_c is not None:
train_speed = train_c.collect_step / (duration - test_c.collect_time)
else:
train_speed = train_c.collect_step / duration
result.update(
{
"train_step": train_c.collect_step,


@ -35,6 +35,7 @@ class TensorboardLogger(BaseLogger):
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
for k, v in data.items():
self.writer.add_scalar(k, v, global_step=step)
self.writer.flush() # issue #482
def save_data(
self,