Fix critic network for Discrete CRR (#485)
- Fixes an inconsistency in the implementation of Discrete CRR: it now uses the `Critic` class for its critic, following the convention of the other actor-critic policies.
- Updates several offline policies to use the `ActorCritic` class for their optimizers, eliminating randomness caused by parameter sharing between actor and critic.
- Adds `writer.flush()` in `TensorboardLogger` to ensure real-time results.
- Enables `test_collector=None` in three trainers to turn off testing during training.
- Updates the Atari offline results in README.md.
- Moves the Atari offline RL examples to `examples/offline` and the tests to `test/offline`, per review comments.
This commit is contained in: parent 5c5a3db94e, commit 3592f45446
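A hedged usage sketch (not part of the commit itself) of the two API-level changes described above, mirroring the updated example scripts: one optimizer built over the `ActorCritic` wrapper, and `test_collector=None` to disable in-training testing. The network sizes, the dummy replay buffer, and the trainer budget below are placeholder assumptions, not values taken from the diff.

```python
import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

state_shape, action_shape = (4,), 2  # placeholder CartPole-like sizes

# shared feature net plus actor/critic heads, as in the updated test script
net = Net(state_shape, 64)
actor = Actor(net, action_shape, hidden_sizes=[64], softmax_output=False)
critic = Critic(net, hidden_sizes=[64], last_size=int(np.prod(action_shape)))

# one optimizer over the ActorCritic wrapper instead of concatenated parameter
# lists, so the shared feature-net parameters are not registered twice
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

policy = DiscreteCRRPolicy(actor, critic, optim, discount_factor=0.99)

# a tiny random buffer just to make the trainer call below executable
buffer = ReplayBuffer(size=64)
for i in range(64):
    buffer.add(
        Batch(
            obs=np.random.rand(4), act=i % 2, rew=0.0,
            done=(i % 10 == 9), obs_next=np.random.rand(4),
        )
    )

# test_collector=None turns off in-training testing (enabled by this commit)
result = offline_trainer(
    policy, buffer, test_collector=None,
    max_epoch=1, update_per_epoch=10, episode_per_test=1, batch_size=8,
)
```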
examples/__init__.py (new file, 0 lines)
@@ -1,4 +1,4 @@
# Atari General
# Atari

The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.

@@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 |  | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 |  | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 |  | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |

# BCQ

To run the BCQ algorithm on Atari, you need to do the following things:

- Train an expert by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our BCQ implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |

# CQL

To run the CQL algorithm on Atari, you need to do the following things:

- Train an expert by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our CQL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |

# CRR

To run the CRR algorithm on Atari, you need to do the following things:

- Train an expert by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our CRR implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well in Atari tasks, but adding the CQL loss/regularizer helps.
examples/atari/__init__.py (new file, 0 lines)
@@ -2,9 +2,11 @@
In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy, and the agent does not interact with the environment anymore.

Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
## Continuous control

## Train
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.

### Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer`, and set it as the `buffer` parameter of `offline_trainer`. `offline_bcq.py` is an example of offline RL using a d4rl dataset.
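The following sketch is not part of the diff; it only illustrates the workflow described above. The d4rl helper `qlearning_dataset` and its field names (`observations`, `actions`, `rewards`, `terminals`, `next_observations`) come from d4rl's public API, and the transition-by-transition copy is simply the most basic way to fill a `ReplayBuffer`.

```python
import d4rl  # noqa: F401  -- registers the offline gym tasks
import gym

from tianshou.data import Batch, ReplayBuffer

env = gym.make("halfcheetah-expert-v1")
dataset = d4rl.qlearning_dataset(env)     # assumed d4rl helper and key names
n = dataset["rewards"].shape[0]

buffer = ReplayBuffer(size=n)
for i in range(n):
    buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=float(dataset["rewards"][i]),
            done=bool(dataset["terminals"][i]),
            obs_next=dataset["next_observations"][i],
        )
    )
# `buffer` can now be passed as the `buffer` argument of offline_trainer,
# the way offline_bcq.py does.
```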
@@ -26,3 +28,59 @@ After 1M steps:
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |

## Discrete control

For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).

### Gather Data

To run an offline algorithm on Atari, you need to do the following things:

- Train an expert by using the command listed in the QRDQN section of the Atari examples: `python3 atari_qrdqn.py --task {your_task}`
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train the offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.

### BCQ

We test our BCQ implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | BCQ | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |

### CQL

We test our CQL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |

### CRR

We test our CRR implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well in Atari tasks, but adding the CQL loss/regularizer helps.
examples/offline/__init__.py (new file, 0 lines)
@@ -6,15 +6,16 @@ import pprint

import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor


@@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()):
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
        softmax_output=False
        softmax_output=False,
    ).to(args.device)
    imitation_net = Actor(
        feature_net,
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
        softmax_output=False
        softmax_output=False,
    ).to(args.device)
    optim = torch.optim.Adam(
        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
    )
    actor_critic = ActorCritic(policy_net, imitation_net)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
@@ -171,7 +171,7 @@ def test_discrete_bcq(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
        logger=logger,
    )

    pprint.pprint(result)
@@ -6,10 +6,10 @@ import pprint

import numpy as np
import torch
from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import QRDQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
@@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()):
        args.num_quantiles,
        args.n_step,
        args.target_update_freq,
        min_q_weight=args.min_q_weight
        min_q_weight=args.min_q_weight,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
@@ -156,7 +156,7 @@ def test_discrete_cql(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
        logger=logger,
    )

    pprint.pprint(result)
@@ -6,16 +6,17 @@ import pprint

import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import Actor
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic


def get_args():
@@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()):
    actor = Actor(
        feature_net,
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
        softmax_output=False
        device=args.device,
        softmax_output=False,
    ).to(args.device)
    critic = DQN(*args.state_shape, args.action_shape,
                 device=args.device).to(args.device)
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
    )
    critic = Critic(
        feature_net,
        hidden_sizes=args.hidden_sizes,
        last_size=np.prod(args.action_shape),
        device=args.device,
    ).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCRRPolicy(
        actor,
@@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()):
        ratio_upper_bound=args.ratio_upper_bound,
        beta=args.beta,
        min_q_weight=args.min_q_weight,
        target_update_freq=args.target_update_freq
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
@@ -171,7 +175,7 @@ def test_discrete_crr(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
        logger=logger,
    )

    pprint.pprint(result)
@@ -1,6 +1,5 @@
import argparse
import os
import pickle
import pprint

import gym
@@ -42,9 +41,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
@@ -85,7 +81,7 @@ def test_dqn(args=get_args()):
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq
        target_update_freq=args.target_update_freq,
    )
    # buffer
    if args.prioritized_replay:
@@ -93,7 +89,7 @@ def test_dqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -142,7 +138,7 @@ def test_dqn(args=get_args()):
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

@@ -157,14 +153,6 @@ def test_dqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.2)
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())


def test_pdqn(args=get_args()):
    args.prioritized_replay = True
@@ -1,6 +1,5 @@
import argparse
import os
import pickle
import pprint

import gym
@@ -43,9 +42,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
@@ -80,7 +76,7 @@ def test_qrdqn(args=get_args()):
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
        num_atoms=args.num_quantiles
        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
@@ -89,7 +85,7 @@ def test_qrdqn(args=get_args()):
        args.gamma,
        args.num_quantiles,
        args.n_step,
        target_update_freq=args.target_update_freq
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
@@ -97,7 +93,7 @@ def test_qrdqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -146,7 +142,7 @@ def test_qrdqn(args=get_args()):
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step
        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])

@@ -161,14 +157,6 @@ def test_qrdqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.9)  # 10% of expert data as demonstrated in the original paper
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())


def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True
test/offline/gather_cartpole_data.py (new file, 160 lines)

@@ -0,0 +1,160 @@
import argparse
import os
import pickle

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args


def gather_data():
    args = get_args()
    env = gym.make(args.task)
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
        net,
        optim,
        args.gamma,
        args.num_quantiles,
        args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annnealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])

    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.2)
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())
    return buf
@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor

if __name__ == "__main__":
    from gather_cartpole_data import gather_data
else:  # pytest
    from test.offline.gather_cartpole_data import gather_data


def get_args():
@@ -37,7 +43,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
        default="./expert_DQN_CartPole-v0.pkl",
        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",
@@ -65,21 +71,15 @@ def test_discrete_bcq(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    policy_net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device
    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
    policy_net = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
    imitation_net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device
    imitation_net = Actor(
        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
    optim = torch.optim.Adam(
        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
    )
    actor_critic = ActorCritic(policy_net, imitation_net)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteBCQPolicy(
        policy_net,
@@ -93,9 +93,10 @@ def test_discrete_bcq(args=get_args()):
        args.imitation_logits_penalty,
    )
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_dqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    else:
        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -15,6 +15,11 @@ from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net

if __name__ == "__main__":
    from gather_cartpole_data import gather_data
else:  # pytest
    from test.offline.gather_cartpole_data import gather_data


def get_args():
    parser = argparse.ArgumentParser()
@@ -83,9 +88,10 @@ def test_discrete_cql(args=get_args()):
        min_q_weight=args.min_q_weight
    ).to(args.device)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_qrdqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    else:
        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

if __name__ == "__main__":
    from gather_cartpole_data import gather_data
else:  # pytest
    from test.offline.gather_cartpole_data import gather_data


def get_args():
@@ -34,7 +40,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
        default="./expert_DQN_CartPole-v0.pkl",
        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",
@@ -60,23 +66,22 @@ def test_discrete_crr(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = Net(
        args.state_shape,
    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
    actor = Actor(
        net,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False
        softmax_output=False
    )
    critic = Net(
        args.state_shape,
        args.action_shape,
    critic = Critic(
        net,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False
    )
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
        last_size=np.prod(args.action_shape),
        device=args.device
    )
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteCRRPolicy(
        actor,
@@ -86,14 +91,15 @@ def test_discrete_crr(args=get_args()):
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_dqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    else:
        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)
@@ -1,6 +1,6 @@
from tianshou import data, env, exploration, policy, trainer, utils

__version__ = "0.4.4"
__version__ = "0.4.5"

__all__ = [
    "env",
@@ -83,14 +83,14 @@ class DiscreteCRRPolicy(PGPolicy):
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q_t, _ = self.critic(batch.obs)
        q_t = self.critic(batch.obs)
        act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
        qa_t = q_t.gather(1, act.unsqueeze(1))
        # Critic loss
        with torch.no_grad():
            target_a_t, _ = self.actor_old(batch.obs_next)
            target_m = Categorical(logits=target_a_t)
            q_t_target, _ = self.critic_old(batch.obs_next)
            q_t_target = self.critic_old(batch.obs_next)
            rew = to_torch_as(batch.rew, q_t_target)
            expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
            expected_target_q[batch.done > 0] = 0.0
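For context (not shown in the diff): the tuple unpacking is dropped because `tianshou.utils.net.discrete.Critic` returns only the Q-value tensor, whereas the `DQN` network used previously returns a `(logits, state)` pair. A minimal sketch of the new critic head follows; the sizes are placeholder assumptions, not values from the commit.

```python
import numpy as np
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic

state_shape, action_shape = (4,), 2            # placeholder sizes
feature_net = Net(state_shape, 64)             # shared feature extractor
critic = Critic(feature_net, hidden_sizes=[64], last_size=int(np.prod(action_shape)))

obs = torch.zeros(8, *state_shape)             # dummy batch of observations
q_t = critic(obs)                              # plain tensor of shape (8, num_actions)
print(q_t.shape)
```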
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
    test_collector: Collector,
    test_collector: Optional[Collector],
    max_epoch: int,
    update_per_epoch: int,
    episode_per_test: int,
@@ -33,7 +33,8 @@
    The "step" in offline trainer means a gradient step.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector test_collector: the collector used for testing.
    :param Collector test_collector: the collector used for testing. If it's None, then
        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int update_per_epoch: the number of policy network updates, so-called
@@ -73,14 +74,16 @@
        start_epoch, _, gradient_step = logger.restore_data()
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    test_collector.reset_stat()

    test_result = test_episode(
        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
        gradient_step, reward_metric
    )
    best_epoch = start_epoch
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if test_collector is not None:
        test_c: Collector = test_collector
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger,
            gradient_step, reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

@@ -97,22 +100,32 @@
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
        # test
        test_result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test, logger,
            gradient_step, reward_metric
        )
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch < 0 or best_reward < rew:
            best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
            if save_fn:
                save_fn(policy)
        logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger,
                gradient_step, reward_metric
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, None, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, None, test_collector, best_reward, best_reward_std
        )
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    step_per_collect: int,
@@ -38,7 +38,8 @@

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing.
    :param Collector test_collector: the collector used for testing. If it's None, then
        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
@@ -90,14 +91,19 @@
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    test_result = test_episode(
        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
        env_step, reward_metric
    test_in_train = test_in_train and (
        train_collector.policy == policy and test_collector is not None
    )
    best_epoch = start_epoch
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

    if test_collector is not None:
        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

@@ -129,8 +135,8 @@
            if result["n/ep"] > 0:
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn, epoch, episode_per_test,
                        logger, env_step
                        policy, test_c, test_fn, epoch, episode_per_test, logger,
                        env_step
                    )
                    if stop_fn(test_result["rew"]):
                        if save_fn:
@@ -156,24 +162,32 @@
            t.set_postfix(**data)
        if t.n <= t.total:
            t.update()
        # test
        test_result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch < 0 or best_reward < rew:
            best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
            if save_fn:
                save_fn(policy)
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward, best_reward_std
    )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, train_collector, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,
@@ -39,7 +39,8 @@

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing.
    :param Collector test_collector: the collector used for testing. If it's None, then
        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
@@ -96,14 +97,19 @@
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    test_result = test_episode(
        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
        env_step, reward_metric
    test_in_train = test_in_train and (
        train_collector.policy == policy and test_collector is not None
    )
    best_epoch = start_epoch
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

    if test_collector is not None:
        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

@@ -137,8 +143,8 @@
            if result["n/ep"] > 0:
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn, epoch, episode_per_test,
                        logger, env_step
                        policy, test_c, test_fn, epoch, episode_per_test, logger,
                        env_step
                    )
                    if stop_fn(test_result["rew"]):
                        if save_fn:
@@ -172,24 +178,32 @@
            t.set_postfix(**data)
        if t.n <= t.total:
            t.update()
        # test
        test_result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch < 0 or best_reward < rew:
            best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
            if save_fn:
                save_fn(policy)
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward, best_reward_std
    )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, train_collector, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )
@@ -36,7 +36,7 @@ def test_episode(
def gather_info(
    start_time: float,
    train_c: Optional[Collector],
    test_c: Collector,
    test_c: Optional[Collector],
    best_reward: float,
    best_reward_std: float,
) -> Dict[str, Union[float, str]]:
@@ -58,21 +58,32 @@ def gather_info(
    * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
    model_time = duration - test_c.collect_time
    test_speed = test_c.collect_step / test_c.collect_time
    model_time = duration
    result: Dict[str, Union[float, str]] = {
        "test_step": test_c.collect_step,
        "test_episode": test_c.collect_episode,
        "test_time": f"{test_c.collect_time:.2f}s",
        "test_speed": f"{test_speed:.2f} step/s",
        "best_reward": best_reward,
        "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
        "duration": f"{duration:.2f}s",
        "train_time/model": f"{model_time:.2f}s",
    }
    if test_c is not None:
        model_time = duration - test_c.collect_time
        test_speed = test_c.collect_step / test_c.collect_time
        result.update(
            {
                "test_step": test_c.collect_step,
                "test_episode": test_c.collect_episode,
                "test_time": f"{test_c.collect_time:.2f}s",
                "test_speed": f"{test_speed:.2f} step/s",
                "best_reward": best_reward,
                "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
                "duration": f"{duration:.2f}s",
                "train_time/model": f"{model_time:.2f}s",
            }
        )
    if train_c is not None:
        model_time -= train_c.collect_time
        train_speed = train_c.collect_step / (duration - test_c.collect_time)
        if test_c is not None:
            train_speed = train_c.collect_step / (duration - test_c.collect_time)
        else:
            train_speed = train_c.collect_step / duration
        result.update(
            {
                "train_step": train_c.collect_step,
@@ -35,6 +35,7 @@ class TensorboardLogger(BaseLogger):
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        for k, v in data.items():
            self.writer.add_scalar(k, v, global_step=step)
        self.writer.flush()  # issue #482

    def save_data(
        self,