Fix critic network for Discrete CRR (#485)

- Fixes an inconsistency in the implementation of Discrete CRR: it now uses the `Critic` class for its critic, following the convention of the other actor-critic policies.
- Updates several offline policies to use the `ActorCritic` class for their optimizers, which eliminates the randomness caused by parameter sharing between actor and critic.
- Adds `writer.flush()` in TensorboardLogger to ensure results are written out in real time.
- Enables `test_collector=None` in three trainers to turn off testing during training.
- Updates the Atari offline results in README.md.
- Moves the Atari offline RL examples to `examples/offline` and the tests to `test/offline`, per review comments.
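For context, a minimal sketch (not taken from this diff) of the optimizer pattern these changes standardize on; the shapes and hyper-parameters below are illustrative only:

```python
# Illustrative only: a shared feature net feeds both heads. Wrapping them in
# ActorCritic and building one optimizer over actor_critic.parameters() makes
# the shared parameters appear only once in the optimizer, instead of twice as
# with list(actor.parameters()) + list(critic.parameters()).
import torch

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

state_shape, action_dim, hidden_sizes = (4,), 2, [64, 64]
feature_net = Net(state_shape, hidden_sizes[0])  # shared backbone
actor = Actor(feature_net, action_dim, hidden_sizes=hidden_sizes, softmax_output=False)
critic = Critic(feature_net, hidden_sizes=hidden_sizes, last_size=action_dim)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)
```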
Parent: 5c5a3db94e
Commit: 3592f45446

examples/__init__.py (new empty file)
@@ -1,4 +1,4 @@
-# Atari General
+# Atari

The sample speed is ~3000 env step per second (~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.

@@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 |  | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 |  | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 |  | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
-
-# BCQ
-
-To run the BCQ algorithm on Atari, you need to do the following things:
-
-- Train an expert, by using the command listed in the above DQN section;
-- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
-- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
-
-We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
-
-| Task | Online DQN | Behavioral | BCQ |
-| ---------------------- | ---------- | ---------- | --------------------------------- |
-| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
-| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
-
-# CQL
-
-To run the CQL algorithm on Atari, you need to do the following things:
-
-- Train an expert, by using the command listed in the above QRDQN section;
-- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
-- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
-
-We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
-
-| Task | Online QRDQN | Behavioral | CQL | parameters |
-| ---------------------- | ------------ | ---------- | ---------------- | ---------- |
-| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
-
-We reduce the size of the offline data to 10% and 1% of the above and get:
-
-Buffer size 100000:
-
-| Task | Online QRDQN | Behavioral | CQL | parameters |
-| ---------------------- | ------------ | ---------- | --------------- | ---------- |
-| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
-
-Buffer size 10000:
-
-| Task | Online QRDQN | Behavioral | CQL | parameters |
-| ---------------------- | ------------ | ---------- | --------------- | ---------- |
-| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
-| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
-
-# CRR
-
-To run the CRR algorithm on Atari, you need to do the following things:
-
-- Train an expert, by using the command listed in the above QRDQN section;
-- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
-- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.
-
-We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
-
-| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
-| ---------------------- | ------------ | ---------- | --------------- | ---------------- | ---------- |
-| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
-
-Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
examples/atari/__init__.py (new empty file)
@@ -2,9 +2,11 @@

In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. The agent does not interact with the environment anymore.

+## Continuous control
+
-Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
+Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.

-## Train
+### Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer`, and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.

@@ -26,3 +28,59 @@ After 1M steps:
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |

+## Discrete control
+
+For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
+
+### Gather Data
+
+To run an offline algorithm on Atari, you need to do the following things:
+
+- Train an expert, by using the command listed in the QRDQN section of the Atari examples: `python3 atari_qrdqn.py --task {your_task}`
+- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
+- Train the offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.
+
+### BCQ
+
+We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
+
+| Task | Online QRDQN | Behavioral | BCQ | parameters |
+| ---------------------- | ------------ | ---------- | -------------------------------- | ---------- |
+| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |
+
+### CQL
+
+We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
+
+| Task | Online QRDQN | Behavioral | CQL | parameters |
+| ---------------------- | ------------ | ---------- | ---------------- | ---------- |
+| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+
+We reduce the size of the offline data to 10% and 1% of the above and get:
+
+Buffer size 100000:
+
+| Task | Online QRDQN | Behavioral | CQL | parameters |
+| ---------------------- | ------------ | ---------- | --------------- | ---------- |
+| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
+
+Buffer size 10000:
+
+| Task | Online QRDQN | Behavioral | CQL | parameters |
+| ---------------------- | ------------ | ---------- | --------------- | ---------- |
+| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
+| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
+
+### CRR
+
+We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
+
+| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
+| ---------------------- | ------------ | ---------- | --------------- | --------------- | ---------- |
+| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+
+Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
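As a rough, non-authoritative sketch of the workflow this README describes (assuming a buffer saved as `expert.hdf5` by the command above; the helper name and hyper-parameter values are made up for illustration):

```python
from typing import Optional

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer import offline_trainer


def train_from_expert_buffer(
    policy: BasePolicy,
    buffer_path: str = "expert.hdf5",
    test_collector: Optional[Collector] = None,
) -> dict:
    """Load a saved expert buffer and run offline training on it."""
    # HDF5 is used because a 1M-transition Atari buffer is too large to pickle
    # reliably (see the note in the Gather Data section).
    buffer = VectorReplayBuffer.load_hdf5(buffer_path)
    return offline_trainer(
        policy,
        buffer,
        test_collector,  # may be None to skip evaluation during training
        max_epoch=5,
        update_per_epoch=10000,
        episode_per_test=10,
        batch_size=32,
    )
```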
examples/offline/__init__.py (new empty file)
@@ -6,15 +6,16 @@ import pprint

import numpy as np
import torch
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

+from examples.atari.atari_network import DQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor

@@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()):
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        softmax_output=False,
    ).to(args.device)
    imitation_net = Actor(
        feature_net,
        args.action_shape,
        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        softmax_output=False,
    ).to(args.device)
-    optim = torch.optim.Adam(
-        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
-    )
+    actor_critic = ActorCritic(policy_net, imitation_net)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,

@@ -171,7 +171,7 @@ def test_discrete_bcq(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )

    pprint.pprint(result)
@@ -6,10 +6,10 @@ import pprint

import numpy as np
import torch
-from atari_network import QRDQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

+from examples.atari.atari_network import QRDQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy

@@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()):
        args.num_quantiles,
        args.n_step,
        args.target_update_freq,
-        min_q_weight=args.min_q_weight
+        min_q_weight=args.min_q_weight,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:

@@ -156,7 +156,7 @@ def test_discrete_cql(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )

    pprint.pprint(result)
@@ -6,16 +6,17 @@ import pprint

import numpy as np
import torch
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

+from examples.atari.atari_network import DQN
+from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.discrete import Actor
+from tianshou.utils.net.common import ActorCritic
+from tianshou.utils.net.discrete import Actor, Critic


def get_args():

@@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()):
    actor = Actor(
        feature_net,
        args.action_shape,
-        device=args.device,
        hidden_sizes=args.hidden_sizes,
-        softmax_output=False
+        device=args.device,
+        softmax_output=False,
    ).to(args.device)
-    critic = DQN(*args.state_shape, args.action_shape,
-                 device=args.device).to(args.device)
-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
-    )
+    critic = Critic(
+        feature_net,
+        hidden_sizes=args.hidden_sizes,
+        last_size=np.prod(args.action_shape),
+        device=args.device,
+    ).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCRRPolicy(
        actor,

@@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()):
        ratio_upper_bound=args.ratio_upper_bound,
        beta=args.beta,
        min_q_weight=args.min_q_weight,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # load a previous policy
    if args.resume_path:

@@ -171,7 +175,7 @@ def test_discrete_crr(args=get_args()):
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )

    pprint.pprint(result)
@@ -1,6 +1,5 @@
import argparse
import os
-import pickle
import pprint

import gym

@@ -42,9 +41,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
-    parser.add_argument(
-        '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
-    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )

@@ -85,7 +81,7 @@ def test_dqn(args=get_args()):
        optim,
        args.gamma,
        args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    )
    # buffer
    if args.prioritized_replay:

@@ -93,7 +89,7 @@ def test_dqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
-            beta=args.beta
+            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))

@@ -142,7 +138,7 @@ def test_dqn(args=get_args()):
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
    assert stop_fn(result['best_reward'])

@@ -157,14 +153,6 @@ def test_dqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

-    # save buffer in pickle format, for imitation learning unittest
-    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
-    policy.set_eps(0.2)
-    collector = Collector(policy, test_envs, buf, exploration_noise=True)
-    result = collector.collect(n_step=args.buffer_size)
-    pickle.dump(buf, open(args.save_buffer_name, "wb"))
-    print(result["rews"].mean())


def test_pdqn(args=get_args()):
    args.prioritized_replay = True
@@ -1,6 +1,5 @@
import argparse
import os
-import pickle
import pprint

import gym

@@ -43,9 +42,6 @@ def get_args():
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
-    parser.add_argument(
-        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
-    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )

@@ -80,7 +76,7 @@ def test_qrdqn(args=get_args()):
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
-        num_atoms=args.num_quantiles
+        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(

@@ -89,7 +85,7 @@ def test_qrdqn(args=get_args()):
        args.gamma,
        args.num_quantiles,
        args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:

@@ -97,7 +93,7 @@ def test_qrdqn(args=get_args()):
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
-            beta=args.beta
+            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))

@@ -146,7 +142,7 @@ def test_qrdqn(args=get_args()):
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
-        update_per_step=args.update_per_step
+        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])

@@ -161,14 +157,6 @@ def test_qrdqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

-    # save buffer in pickle format, for imitation learning unittest
-    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
-    policy.set_eps(0.9)  # 10% of expert data as demonstrated in the original paper
-    collector = Collector(policy, test_envs, buf, exploration_noise=True)
-    result = collector.collect(n_step=args.buffer_size)
-    pickle.dump(buf, open(args.save_buffer_name, "wb"))
-    print(result["rews"].mean())


def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True
test/offline/gather_cartpole_data.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import argparse
import os
import pickle

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument(
        '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
    )
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args


def gather_data():
    args = get_args()
    env = gym.make(args.task)
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        softmax=False,
        num_atoms=args.num_quantiles,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
        net,
        optim,
        args.gamma,
        args.num_quantiles,
        args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
    )
    assert stop_fn(result['best_reward'])

    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.2)
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())
    return buf
@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.discrete import Actor
+
+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data


def get_args():

@@ -37,7 +43,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
-        default="./expert_DQN_CartPole-v0.pkl",
+        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",

@@ -65,21 +71,15 @@ def test_discrete_bcq(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
-    policy_net = Net(
-        args.state_shape,
-        args.action_shape,
-        hidden_sizes=args.hidden_sizes,
-        device=args.device
+    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
+    policy_net = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
-    imitation_net = Net(
-        args.state_shape,
-        args.action_shape,
-        hidden_sizes=args.hidden_sizes,
-        device=args.device
+    imitation_net = Actor(
+        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
    ).to(args.device)
-    optim = torch.optim.Adam(
-        list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
-    )
+    actor_critic = ActorCritic(policy_net, imitation_net)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteBCQPolicy(
        policy_net,

@@ -93,9 +93,10 @@ def test_discrete_bcq(args=get_args()):
        args.imitation_logits_penalty,
    )
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_dqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -15,6 +15,11 @@ from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net

+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data
+

def get_args():
    parser = argparse.ArgumentParser()

@@ -83,9 +88,10 @@ def test_discrete_cql(args=get_args()):
        min_q_weight=args.min_q_weight
    ).to(args.device)
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_qrdqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -13,7 +13,13 @@ from tianshou.env import DummyVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.discrete import Actor, Critic
+
+if __name__ == "__main__":
+    from gather_cartpole_data import gather_data
+else:  # pytest
+    from test.offline.gather_cartpole_data import gather_data


def get_args():

@@ -34,7 +40,7 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name",
        type=str,
-        default="./expert_DQN_CartPole-v0.pkl",
+        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device",

@@ -60,23 +66,22 @@ def test_discrete_crr(args=get_args()):
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
-    actor = Net(
-        args.state_shape,
+    net = Net(args.state_shape, args.hidden_sizes[0], device=args.device)
+    actor = Actor(
+        net,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
-        softmax=False
+        softmax_output=False
    )
-    critic = Net(
-        args.state_shape,
-        args.action_shape,
+    critic = Critic(
+        net,
        hidden_sizes=args.hidden_sizes,
-        device=args.device,
-        softmax=False
-    )
-    optim = torch.optim.Adam(
-        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
+        last_size=np.prod(args.action_shape),
+        device=args.device
    )
+    actor_critic = ActorCritic(actor, critic)
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    policy = DiscreteCRRPolicy(
        actor,

@@ -86,14 +91,15 @@ def test_discrete_crr(args=get_args()):
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # buffer
-    assert os.path.exists(args.load_buffer_name), \
-        "Please run test_dqn.py first to get expert's data buffer."
-    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

-    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
+    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)
@@ -1,6 +1,6 @@
from tianshou import data, env, exploration, policy, trainer, utils

-__version__ = "0.4.4"
+__version__ = "0.4.5"

__all__ = [
    "env",
@@ -83,14 +83,14 @@ class DiscreteCRRPolicy(PGPolicy):
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
-        q_t, _ = self.critic(batch.obs)
+        q_t = self.critic(batch.obs)
        act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
        qa_t = q_t.gather(1, act.unsqueeze(1))
        # Critic loss
        with torch.no_grad():
            target_a_t, _ = self.actor_old(batch.obs_next)
            target_m = Categorical(logits=target_a_t)
-            q_t_target, _ = self.critic_old(batch.obs_next)
+            q_t_target = self.critic_old(batch.obs_next)
            rew = to_torch_as(batch.rew, q_t_target)
            expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
            expected_target_q[batch.done > 0] = 0.0
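Not part of the diff: a small illustration of the interface difference behind the `q_t, _` to `q_t` change. The old critic was the `DQN` network, whose forward pass returns a `(logits, state)` tuple, while `tianshou.utils.net.discrete.Critic` returns the Q-value tensor directly (the shapes below are made up):

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic

state_shape, action_dim = (4,), 6
feature_net = Net(state_shape, 64)  # shared feature extractor, output dim 64
critic = Critic(feature_net, hidden_sizes=[64], last_size=action_dim)

obs = torch.zeros(8, *state_shape)  # dummy batch of 8 observations
q_values = critic(obs)              # plain tensor, no (output, state) tuple
assert q_values.shape == (8, action_dim)
```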
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    update_per_epoch: int,
    episode_per_test: int,

@@ -33,7 +33,8 @@ def offline_trainer(
    The "step" in offline trainer means a gradient step.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int update_per_epoch: the number of policy network updates, so-called

@@ -73,10 +74,12 @@ def offline_trainer(
    start_epoch, _, gradient_step = logger.restore_data()
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
-    test_collector.reset_stat()
-    test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
-        gradient_step, reward_metric
-    )
+
+    if test_collector is not None:
+        test_c: Collector = test_collector
+        test_collector.reset_stat()
+        test_result = test_episode(
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger,
+            gradient_step, reward_metric
+        )
        best_epoch = start_epoch

@@ -97,9 +100,11 @@ def offline_trainer(
                data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
+        logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-                policy, test_collector, test_fn, epoch, episode_per_test, logger,
+                policy, test_c, test_fn, epoch, episode_per_test, logger,
                gradient_step, reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]

@@ -107,7 +112,6 @@ def offline_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-            logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"

@@ -115,4 +119,13 @@ def offline_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
-    return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
+
+    if test_collector is None and save_fn:
+        save_fn(policy)
+
+    if test_collector is None:
+        return gather_info(start_time, None, None, 0.0, 0.0)
+    else:
+        return gather_info(
+            start_time, None, test_collector, best_reward, best_reward_std
+        )
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    step_per_collect: int,

@@ -38,7 +38,8 @@ def offpolicy_trainer(

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.

@@ -90,11 +91,16 @@ def offpolicy_trainer(
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
-    test_collector.reset_stat()
-    test_in_train = test_in_train and train_collector.policy == policy
-    test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
-        env_step, reward_metric
-    )
+    test_in_train = test_in_train and (
+        train_collector.policy == policy and test_collector is not None
+    )
+
+    if test_collector is not None:
+        test_c: Collector = test_collector  # for mypy
+        test_collector.reset_stat()
+        test_result = test_episode(
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
+            reward_metric
+        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

@@ -129,8 +135,8 @@ def offpolicy_trainer(
            if result["n/ep"] > 0:
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
-                        policy, test_collector, test_fn, epoch, episode_per_test,
-                        logger, env_step
+                        policy, test_c, test_fn, epoch, episode_per_test, logger,
+                        env_step
                    )
                    if stop_fn(test_result["rew"]):
                        if save_fn:

@@ -156,9 +162,11 @@ def offpolicy_trainer(
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
+        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-                policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]

@@ -166,7 +174,6 @@ def offpolicy_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"

@@ -174,6 +181,13 @@ def offpolicy_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
+
+    if test_collector is None and save_fn:
+        save_fn(policy)
+
+    if test_collector is None:
+        return gather_info(start_time, train_collector, None, 0.0, 0.0)
+    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )
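A hedged usage sketch (not from the diff) of what the `Optional[Collector]` change allows: running off-policy training with evaluation switched off entirely. The helper name and numbers are illustrative; when `test_collector` is None, `save_fn` is still called once at the end and `gather_info` receives no test statistics.

```python
from typing import Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer import offpolicy_trainer


def train_without_eval(policy: BasePolicy, train_collector: Collector) -> dict:
    """Off-policy training with test_collector=None, i.e. no in-training tests."""
    return offpolicy_trainer(
        policy,
        train_collector,
        None,  # test_collector=None: skip testing during training
        max_epoch=10,
        step_per_epoch=10000,
        step_per_collect=10,
        episode_per_test=0,  # unused when test_collector is None
        batch_size=64,
    )
```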
@@ -14,7 +14,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
-    test_collector: Collector,
+    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,

@@ -39,7 +39,8 @@ def onpolicy_trainer(

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
-    :param Collector test_collector: the collector used for testing.
+    :param Collector test_collector: the collector used for testing. If it's None, then
+        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.

@@ -96,11 +97,16 @@ def onpolicy_trainer(
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
-    test_collector.reset_stat()
-    test_in_train = test_in_train and train_collector.policy == policy
-    test_result = test_episode(
-        policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
-        env_step, reward_metric
-    )
+    test_in_train = test_in_train and (
+        train_collector.policy == policy and test_collector is not None
+    )
+
+    if test_collector is not None:
+        test_c: Collector = test_collector  # for mypy
+        test_collector.reset_stat()
+        test_result = test_episode(
+            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
+            reward_metric
+        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

@@ -137,8 +143,8 @@ def onpolicy_trainer(
            if result["n/ep"] > 0:
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
-                        policy, test_collector, test_fn, epoch, episode_per_test,
-                        logger, env_step
+                        policy, test_c, test_fn, epoch, episode_per_test, logger,
+                        env_step
                    )
                    if stop_fn(test_result["rew"]):
                        if save_fn:

@@ -172,9 +178,11 @@ def onpolicy_trainer(
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
+        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
+        if test_collector is not None:
            test_result = test_episode(
-                policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]

@@ -182,7 +190,6 @@ def onpolicy_trainer(
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
-            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"

@@ -190,6 +197,13 @@ def onpolicy_trainer(
                )
            if stop_fn and stop_fn(best_reward):
                break
+
+    if test_collector is None and save_fn:
+        save_fn(policy)
+
+    if test_collector is None:
+        return gather_info(start_time, train_collector, None, 0.0, 0.0)
+    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )
@@ -36,7 +36,7 @@ def test_episode(
def gather_info(
    start_time: float,
    train_c: Optional[Collector],
-    test_c: Collector,
+    test_c: Optional[Collector],
    best_reward: float,
    best_reward_std: float,
) -> Dict[str, Union[float, str]]:

@@ -58,9 +58,16 @@ def gather_info(
    * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
+    model_time = duration
+    result: Dict[str, Union[float, str]] = {
+        "duration": f"{duration:.2f}s",
+        "train_time/model": f"{model_time:.2f}s",
+    }
+    if test_c is not None:
        model_time = duration - test_c.collect_time
        test_speed = test_c.collect_step / test_c.collect_time
-    result: Dict[str, Union[float, str]] = {
+        result.update(
+            {
                "test_step": test_c.collect_step,
                "test_episode": test_c.collect_episode,
                "test_time": f"{test_c.collect_time:.2f}s",

@@ -70,9 +77,13 @@ def gather_info(
                "duration": f"{duration:.2f}s",
                "train_time/model": f"{model_time:.2f}s",
            }
+        )
    if train_c is not None:
        model_time -= train_c.collect_time
+        if test_c is not None:
            train_speed = train_c.collect_step / (duration - test_c.collect_time)
+        else:
+            train_speed = train_c.collect_step / duration
        result.update(
            {
                "train_step": train_c.collect_step,
@@ -35,6 +35,7 @@ class TensorboardLogger(BaseLogger):
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        for k, v in data.items():
            self.writer.add_scalar(k, v, global_step=step)
+        self.writer.flush()  # issue #482

    def save_data(
        self,
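A tiny, hedged example of what the flush changes in practice; calling `write()` directly here is just for illustration:

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/demo")
logger = TensorboardLogger(writer)
# Each write() now flushes, so the scalar shows up in TensorBoard immediately
# instead of waiting for the SummaryWriter's internal flush interval.
logger.write("train/env_step", 1, {"train/reward": 100.0})
```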