add logger (#295)
This PR focuses on refactoring the logging method to fix the NaN-reward and log-interval bugs. After these two PRs, the fundamental changes to tianshou/data should be finished, and we can finally concentrate on building benchmarks for Tianshou. Things changed: 1. the trainers now accept a logger (BasicLogger or LazyLogger) instead of a writer; 2. utils.SummaryWriter is removed. A minimal usage sketch of the new interface follows the commit metadata below.
This commit is contained in:
parent e99e1b0fdd
commit 9b61bc620c
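For reference, the change this commit asks of downstream code is small: wrap the existing SummaryWriter in a BasicLogger and hand that to the trainer. Below is a minimal sketch of the new interface, pieced together from the hunks that follow (the scalar values are dummies for illustration):

```
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter('log/dqn')   # TensorBoard output directory
writer.add_text("args", "...")      # free-form text still goes straight to the writer
logger = BasicLogger(writer)        # optional kwargs seen below: update_interval=..., train_interval=...

# custom scalars go through the logger instead of writer.add_scalar(...):
env_step, eps = 0, 0.1              # dummy values for illustration
logger.write('train/eps', env_step, eps)

# trainer calls now take logger=logger instead of writer=writer, e.g.
# result = offpolicy_trainer(..., stop_fn=stop_fn, save_fn=save_fn, logger=logger)
```

The diff below applies exactly this substitution across the docs, examples, and tests.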
@@ -197,6 +197,7 @@ buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
 step_per_epoch, step_per_collect = 10000, 10
 writer = SummaryWriter('log/dqn') # tensorboard is also supported!
+logger = ts.utils.BasicLogger(writer)
 ```
 
 Make environments:

@@ -237,7 +238,7 @@ result = ts.trainer.offpolicy_trainer(
     train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
     test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
     stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-    writer=writer)
+    logger=logger)
 print(f'Finished training! Use {result["duration"]}')
 ```
 
@@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
 * Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
 * Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
 * Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
+* Huayu Chen (`ChenDRAG <https://github.com/ChenDRAG>`_)
@@ -130,7 +130,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t
         train_fn=lambda epoch, env_step: policy.set_eps(0.1),
         test_fn=lambda epoch, env_step: policy.set_eps(0.05),
         stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-        writer=None)
+        logger=None)
     print(f'Finished training! Use {result["duration"]}')
 
 The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):

@@ -143,15 +143,17 @@ The meaning of each parameter is as follows (full description can be found at :f
 * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
 * ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
 * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
-* ``writer``: See below.
+* ``logger``: See below.
 
 The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:
 ::
 
     from torch.utils.tensorboard import SummaryWriter
+    from tianshou.utils import BasicLogger
     writer = SummaryWriter('log/dqn')
+    logger = BasicLogger(writer)
 
-Pass the writer into the trainer, and the training result will be recorded into the TensorBoard.
+Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
 
 The returned result is a dictionary as follows:
 ::
@@ -176,6 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
     import numpy as np
     from copy import deepcopy
     from torch.utils.tensorboard import SummaryWriter
+    from tianshou.utils import BasicLogger
 
     from tianshou.env import DummyVectorEnv
     from tianshou.utils.net.common import Net

@@ -319,11 +320,10 @@ With the above preparation, we are close to the first learned agent. The followi
     train_collector.collect(n_step=args.batch_size * args.training_num)
 
     # ======== tensorboard logging setup =========
-    if not hasattr(args, 'writer'):
-        log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
-        writer = SummaryWriter(log_path)
-    else:
-        writer = args.writer
+    log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer)
 
     # ======== callback functions used during training =========
 
@@ -359,7 +359,7 @@ With the above preparation, we are close to the first learned agent. The followi
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
-        writer=writer, test_in_train=False, reward_metric=reward_metric)
+        logger=logger, test_in_train=False, reward_metric=reward_metric)
 
     agent = policy.policies[args.agent_id - 1]
     # let's watch the match!
@@ -2,10 +2,12 @@ import os
 import torch
 import pickle
 import pprint
+import datetime
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offline_trainer
 from tianshou.utils.net.discrete import Actor

@@ -39,7 +41,7 @@ def get_args():
     parser.add_argument("--resume-path", type=str, default=None)
     parser.add_argument("--watch", default=False, action="store_true",
                         help="watch the play of pre-trained policy only")
-    parser.add_argument("--log-interval", type=int, default=1000)
+    parser.add_argument("--log-interval", type=int, default=100)
     parser.add_argument(
         "--load-buffer-name", type=str,
         default="./expert_DQN_PongNoFrameskip-v4.hdf5",
@@ -113,8 +115,13 @@ def test_discrete_bcq(args=get_args()):
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)
 
-    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
+    # log
+    log_path = os.path.join(
+        args.logdir, args.task, 'bcq',
+        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
     writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer, update_interval=args.log_interval)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -141,7 +148,7 @@ def test_discrete_bcq(args=get_args()):
     result = offline_trainer(
         policy, buffer, test_collector,
         args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         log_interval=args.log_interval,
     )
 
@@ -6,6 +6,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import C51Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer

@@ -98,6 +99,8 @@ def test_c51(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -118,7 +121,7 @@ def test_c51(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        writer.add_scalar('train/eps', eps, global_step=env_step)
+        logger.write('train/eps', env_step, eps)
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

@@ -144,7 +147,7 @@ def test_c51(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         update_per_step=args.update_per_step, test_in_train=False)
 
     pprint.pprint(result)
@@ -6,6 +6,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer

@@ -94,6 +95,8 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -114,7 +117,7 @@ def test_dqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        writer.add_scalar('train/eps', eps, global_step=env_step)
+        logger.write('train/eps', env_step, eps)
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

@@ -154,7 +157,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         update_per_step=args.update_per_step, test_in_train=False)
 
     pprint.pprint(result)
@@ -5,6 +5,7 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils import BasicLogger
 from tianshou.policy import QRDQNPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer

@@ -96,6 +97,8 @@ def test_qrdqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'qrdqn')
     writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -116,7 +119,7 @@ def test_qrdqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        writer.add_scalar('train/eps', eps, global_step=env_step)
+        logger.write('train/eps', env_step, eps)
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

@@ -142,7 +145,7 @@ def test_qrdqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         update_per_step=args.update_per_step, test_in_train=False)
 
     pprint.pprint(result)
@@ -4,7 +4,9 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+
 from tianshou.policy import A2CPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer

@@ -79,7 +81,9 @@ def test_a2c(args=get_args()):
         preprocess_fn=preprocess_fn, exploration_noise=True)
     test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
-    writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
+    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def stop_fn(mean_rewards):
         if env.env.spec.reward_threshold:
@@ -91,7 +95,7 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
-        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
@@ -6,11 +6,12 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.data import Collector, VectorReplayBuffer
 
 from atari import create_atari_environment, preprocess_fn
 

@@ -84,7 +85,9 @@ def test_ppo(args=get_args()):
         preprocess_fn=preprocess_fn, exploration_noise=True)
     test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
-    writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def stop_fn(mean_rewards):
         if env.env.spec.reward_threshold:
@@ -96,7 +99,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
-        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
+
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -81,6 +82,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -106,7 +108,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
 
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
|||||||
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import SACPolicy
|
from tianshou.policy import SACPolicy
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -134,6 +135,7 @@ def test_sac_bipedal(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'sac')
|
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
|
logger = BasicLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
@ -146,7 +148,7 @@ def test_sac_bipedal(args=get_args()):
|
|||||||
policy, train_collector, test_collector, args.epoch,
|
policy, train_collector, test_collector, args.epoch,
|
||||||
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
|
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
|
||||||
update_per_step=args.update_per_step, test_in_train=False,
|
update_per_step=args.update_per_step, test_in_train=False,
|
||||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer

@@ -83,6 +84,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -102,7 +104,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
-        test_fn=test_fn, save_fn=save_fn, writer=writer)
+        test_fn=test_fn, save_fn=save_fn, logger=logger)
 
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -7,11 +7,12 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import SACPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import OUNoise
 from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
 
 

@@ -103,6 +104,7 @@ def test_sac(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -115,7 +117,7 @@ def test_sac(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
 
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -2,11 +2,13 @@ import os
 import gym
 import torch
 import pprint
+import datetime
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import SACPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -114,8 +116,11 @@ def test_sac(args=get_args()):
         exploration_noise=True)
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
+        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
     writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer, train_interval=args.log_interval)
 
     def watch():
         # watch agent's performance

@@ -141,8 +146,8 @@ def test_sac(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
-        update_per_step=args.update_per_step, log_interval=args.log_interval)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
+        update_per_step=args.update_per_step)
     pprint.pprint(result)
     watch()
 
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -79,7 +81,9 @@ def test_ddpg(args=get_args()):
         exploration_noise=True)
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def stop_fn(mean_rewards):
         return mean_rewards >= env.spec.reward_threshold

@@ -88,7 +92,7 @@ def test_ddpg(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.exploration import GaussianNoise

@@ -88,7 +90,9 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def stop_fn(mean_rewards):
         return mean_rewards >= env.spec.reward_threshold

@@ -97,7 +101,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -2,12 +2,14 @@ import os
 import gym
 import torch
 import pprint
+import datetime
 import argparse
 import numpy as np
 import pybullet_envs
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import SACPolicy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer

@@ -88,8 +90,10 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
+        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer, train_interval=args.log_interval)
 
     def stop_fn(mean_rewards):
         return mean_rewards >= env.spec.reward_threshold

@@ -99,7 +103,7 @@ def test_sac(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn,
-        writer=writer, log_interval=args.log_interval)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.exploration import GaussianNoise

@@ -93,7 +95,9 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:

@@ -105,7 +109,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -2,7 +2,6 @@ import torch
 import numpy as np
 
 from tianshou.utils import MovAvg
-from tianshou.utils import SummaryWriter
 from tianshou.utils.net.common import MLP, Net
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic

@@ -77,25 +76,7 @@ def test_net():
     assert list(net(data, act).shape) == [bsz, 1]
 
 
-def test_summary_writer():
-    # get first instance by key of `default` or your own key
-    writer1 = SummaryWriter.get_instance(
-        key="first", log_dir="log/test_sw/first")
-    assert writer1.log_dir == "log/test_sw/first"
-    writer2 = SummaryWriter.get_instance()
-    assert writer1 is writer2
-    # create new instance by specify a new key
-    writer3 = SummaryWriter.get_instance(
-        key="second", log_dir="log/test_sw/second")
-    assert writer3.log_dir == "log/test_sw/second"
-    writer4 = SummaryWriter.get_instance(key="second")
-    assert writer3 is writer4
-    assert writer1 is not writer3
-    assert writer1.log_dir != writer4.log_dir
-
-
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_net()
-    test_summary_writer()
@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -93,6 +94,7 @@ def test_ddpg(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -105,7 +107,7 @@ def test_ddpg(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torch.distributions import Independent, Normal
 
 from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer

@@ -111,6 +112,7 @@ def test_ppo(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -123,7 +125,7 @@ def test_ppo(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -6,6 +6,7 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -102,6 +103,7 @@ def test_sac_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -114,7 +116,7 @@ def test_sac_with_il(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

@@ -146,7 +148,7 @@ def test_sac_with_il(args=get_args()):
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -106,6 +107,7 @@ def test_td3(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -118,7 +120,7 @@ def test_td3(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -6,6 +6,7 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.data import Collector, VectorReplayBuffer

@@ -89,6 +90,7 @@ def test_a2c_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'a2c')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -101,7 +103,7 @@ def test_a2c_with_il(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

@@ -130,7 +132,7 @@ def test_a2c_with_il(args=get_args()):
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import C51Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -89,6 +90,7 @@ def test_c51(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -115,7 +117,7 @@ def test_c51(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
 
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -8,6 +8,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer

@@ -91,6 +92,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -117,7 +119,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
-        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
 
     assert stop_fn(result['best_reward'])
 
@@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Recurrent

@@ -77,6 +78,7 @@ def test_drqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'drqn')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -96,7 +98,7 @@ def test_drqn(args=get_args()):
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
-save_fn=save_fn, writer=writer)
+save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])
if __name__ == '__main__':

@@ -8,6 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
+from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer

@@ -82,6 +83,7 @@ def test_discrete_bcq(args=get_args()):

log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -92,7 +94,7 @@ def test_discrete_bcq(args=get_args()):
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
-stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

@@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PGPolicy
+from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer

@@ -72,6 +73,7 @@ def test_pg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'pg')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -84,7 +86,7 @@ def test_pg(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-writer=writer)
+logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

@@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer

@@ -98,6 +99,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -110,7 +112,7 @@ def test_ppo(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-writer=writer)
+logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

@@ -6,6 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils import BasicLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net

@@ -87,6 +88,7 @@ def test_qrdqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -113,7 +115,7 @@ def test_qrdqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
-stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)

assert stop_fn(result['best_reward'])

@@ -6,12 +6,13 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DiscreteSACPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.data import Collector, VectorReplayBuffer


def get_args():

@@ -99,6 +100,7 @@ def test_discrete_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

@@ -110,7 +112,7 @@ def test_discrete_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
-args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

@@ -1,3 +1,4 @@
+import os
import gym
import torch
import pprint

@@ -6,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PSRLPolicy
+# from tianshou.utils import BasicLogger
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv

@@ -66,7 +68,10 @@ def test_psrl(args=get_args()):
exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
-writer = SummaryWriter(args.logdir + '/' + args.task)
+log_path = os.path.join(args.logdir, args.task, 'psrl')
+writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+# logger = BasicLogger(writer)

def stop_fn(mean_rewards):
if env.spec.reward_threshold:

@@ -75,11 +80,12 @@ def test_psrl(args=get_args()):
return False

train_collector.collect(n_step=args.buffer_size, random=True)
-# trainer
+# trainer, test it without logger
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, 1, args.test_num, 0,
-episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer,
+episode_per_collect=args.episode_per_collect, stop_fn=stop_fn,
+# logger=logger,
test_in_train=False)

if __name__ == '__main__':

@@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
+from tianshou.utils import BasicLogger

from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch

@@ -31,7 +32,8 @@ def gomoku(args=get_args()):

# log
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
-args.writer = SummaryWriter(log_path)
+writer = SummaryWriter(log_path)
+args.logger = BasicLogger(writer)

opponent_pool = [agent_opponent]

@@ -6,6 +6,7 @@ from copy import deepcopy
from typing import Optional, Tuple
from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer

@@ -131,12 +132,10 @@ def train_agent(
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
-if not hasattr(args, 'writer'):
-log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
-writer = SummaryWriter(log_path)
-args.writer = writer
-else:
-writer = args.writer
+log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
+writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer)

def save_fn(policy):
if hasattr(args, 'model_save_path'):

@@ -166,7 +165,7 @@ def train_agent(
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
-writer=writer, test_in_train=False, reward_metric=reward_metric)
+logger=logger, test_in_train=False, reward_metric=reward_metric)

return result, policy.policies[args.agent_id - 1]

@@ -92,8 +92,6 @@ class ReplayBuffer:
("buffer.__getattr__" is customized).
"""
self.__dict__.update(state)
-# compatible with version == 0.3.1's HDF5 data format
-self._indices = np.arange(self.maxsize)

def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""

@@ -184,9 +184,9 @@ class Collector(object):

* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
-* ``rews`` list of episode reward over collected episodes.
-* ``lens`` list of episode length over collected episodes.
-* ``idxs`` list of episode start index in buffer over collected episodes.
+* ``rews`` array of episode reward over collected episodes.
+* ``lens`` array of episode length over collected episodes.
+* ``idxs`` array of episode start index in buffer over collected episodes.
"""
assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None:

@@ -379,9 +379,9 @@ class AsyncCollector(Collector):

* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
-* ``rews`` list of episode reward over collected episodes.
-* ``lens`` list of episode length over collected episodes.
-* ``idxs`` list of episode start index in buffer over collected episodes.
+* ``rews`` array of episode reward over collected episodes.
+* ``lens`` array of episode length over collected episodes.
+* ``idxs`` array of episode start index in buffer over collected episodes.
"""
# collect at least n_step or n_episode
if n_step is not None:

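The docstring change above is what the new logging path relies on. A quick sketch of the collect() return it describes (the numbers are made up), showing why a logger can call `.mean()`/`.std()` on these entries directly:

```python
import numpy as np

# Hypothetical result of collector.collect(n_episode=3) after this change:
# "rews", "lens" and "idxs" are numpy arrays, not Python lists.
result = {
    "n/ep": 3,
    "n/st": 120,
    "rews": np.array([195.0, 200.0, 187.0]),
    "lens": np.array([40, 42, 38]),
    "idxs": np.array([0, 40, 82]),
}
print(result["rews"].mean(), result["rews"].std())  # array methods work directly
```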
@@ -4,7 +4,7 @@ import numpy as np
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
-from typing import Any, List, Union, Mapping, Optional, Callable
+from typing import Any, Dict, Union, Optional, Callable

from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy

@@ -124,12 +124,10 @@ class BasePolicy(ABC, nn.Module):
return batch

@abstractmethod
-def learn(
-self, batch: Batch, **kwargs: Any
-) -> Mapping[str, Union[float, List[float]]]:
+def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
"""Update policy with a given batch of data.

-:return: A dict which includes loss and its corresponding label.
+:return: A dict, including the data needed to be logged (e.g., loss).

.. note::

@@ -162,18 +160,20 @@ class BasePolicy(ABC, nn.Module):

def update(
self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
-) -> Mapping[str, Union[float, List[float]]]:
+) -> Dict[str, Any]:
"""Update the policy network and replay buffer.

-It includes 3 function steps: process_fn, learn, and post_process_fn.
-In addition, this function will change the value of ``self.updating``:
-it will be False before this function and will be True when executing
-:meth:`update`. Please refer to :ref:`policy_state` for more detailed
-explanation.
+It includes 3 function steps: process_fn, learn, and post_process_fn. In
+addition, this function will change the value of ``self.updating``: it will be
+False before this function and will be True when executing :meth:`update`.
+Please refer to :ref:`policy_state` for more detailed explanation.

-:param int sample_size: 0 means it will extract all the data from the
-buffer, otherwise it will sample a batch with given sample_size.
+:param int sample_size: 0 means it will extract all the data from the buffer,
+otherwise it will sample a batch with given sample_size.
:param ReplayBuffer buffer: the corresponding replay buffer.

+:return: A dict, including the data needed to be logged (e.g., loss) from
+``policy.learn()``.
"""
if buffer is None:
return {}

@@ -2,11 +2,10 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
-from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional

from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import test_episode, gather_info

@@ -23,8 +22,7 @@ def offline_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-writer: Optional[SummaryWriter] = None,
-log_interval: int = 1,
+logger: BaseLogger = LazyLogger(),
verbose: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for offline trainer procedure.

@@ -55,9 +53,8 @@ def offline_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
-:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
-if None is given, it will not write logs to TensorBoard. Default to None.
-:param int log_interval: the log interval of the writer. Default to 1.
+:param BaseLogger logger: A logger that logs statistics during updating/testing.
+Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.

:return: See :func:`~tianshou.trainer.gather_info`.

@@ -67,10 +64,9 @@ def offline_trainer(
start_time = time.time()
test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-writer, gradient_step, reward_metric)
+logger, gradient_step, reward_metric)
best_epoch = 0
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result["rews"].std()
+best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
policy.train()
with tqdm.trange(

@@ -82,27 +78,23 @@ def offline_trainer(
data = {"gradient_step": str(gradient_step)}
for k in losses.keys():
stat[k].add(losses[k])
-data[k] = f"{stat[k].get():.6f}"
-if writer and gradient_step % log_interval == 0:
-writer.add_scalar(
-"train/" + k, stat[k].get(),
-global_step=gradient_step)
+losses[k] = stat[k].get()
+data[k] = f"{losses[k]:.6f}"
+logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
# test
-test_result = test_episode(policy, test_collector, test_fn, epoch,
-episode_per_test, writer, gradient_step,
-reward_metric)
-if best_epoch == -1 or best_reward < test_result["rews"].mean():
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result['rews'].std()
+test_result = test_episode(
+policy, test_collector, test_fn, epoch, episode_per_test,
+logger, gradient_step, reward_metric)
+rew, rew_std = test_result["rew"], test_result["rew_std"]
+if best_epoch == -1 or best_reward < rew:
+best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
-print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-f"{best_reward_std:.6f} in #{best_epoch}")
+print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
-return gather_info(start_time, None, test_collector,
-best_reward, best_reward_std)
+return gather_info(start_time, None, test_collector, best_reward, best_reward_std)

@@ -2,12 +2,11 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
-from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info


@@ -26,8 +25,7 @@ def offpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-writer: Optional[SummaryWriter] = None,
-log_interval: int = 1,
+logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:

@@ -70,25 +68,24 @@ def offpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
-:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
-if None is given, it will not write logs to TensorBoard. Default to None.
-:param int log_interval: the log interval of the writer. Default to 1.
+:param BaseLogger logger: A logger that logs statistics during
+training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.

:return: See :func:`~tianshou.trainer.gather_info`.
"""
env_step, gradient_step = 0, 0
+last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-writer, env_step, reward_metric)
+logger, env_step, reward_metric)
best_epoch = 0
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result["rews"].std()
+best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
# train
policy.train()

@@ -99,34 +96,32 @@ def offpolicy_trainer(
if train_fn:
train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect)
-if len(result["rews"]) > 0 and reward_metric:
+if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
t.update(result["n/st"])
+logger.log_train_data(result, env_step)
+last_rew = result['rew'] if 'rew' in result else last_rew
+last_len = result['len'] if 'len' in result else last_len
data = {
"env_step": str(env_step),
-"rew": f"{result['rews'].mean():.2f}",
-"len": str(result["lens"].mean()),
+"rew": f"{last_rew:.2f}",
+"len": str(last_len),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
-if writer and env_step % log_interval == 0:
-writer.add_scalar(
-"train/rew", result['rews'].mean(), global_step=env_step)
-writer.add_scalar(
-"train/len", result['lens'].mean(), global_step=env_step)
-if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
+if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
-epoch, episode_per_test, writer, env_step)
-if stop_fn(test_result["rews"].mean()):
+epoch, episode_per_test, logger, env_step)
+if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
-test_result["rews"].mean(), test_result["rews"].std())
+test_result["rew"], test_result["rew_std"])
else:
policy.train()
for i in range(round(update_per_step * result["n/st"])):

@@ -134,26 +129,24 @@ def offpolicy_trainer(
losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys():
stat[k].add(losses[k])
-data[k] = f"{stat[k].get():.6f}"
-if writer and gradient_step % log_interval == 0:
-writer.add_scalar(
-k, stat[k].get(), global_step=gradient_step)
+losses[k] = stat[k].get()
+data[k] = f"{losses[k]:.6f}"
+logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
-episode_per_test, writer, env_step, reward_metric)
-if best_epoch == -1 or best_reward < test_result["rews"].mean():
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result['rews'].std()
+episode_per_test, logger, env_step, reward_metric)
+rew, rew_std = test_result["rew"], test_result["rew_std"]
+if best_epoch == -1 or best_reward < rew:
+best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
-print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-f"{best_reward_std:.6f} in #{best_epoch}")
+print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,

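A short sketch of the update-logging path introduced above (the loss value and log directory are made up): each loss key is smoothed with MovAvg, and the smoothed dict is handed to BasicLogger.log_update_data(), which writes it under the train/ namespace at most once per update_interval gradient steps (the default interval is 1000; it is set to 1 here so the single call actually writes).

```python
from collections import defaultdict

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger, MovAvg

logger = BasicLogger(SummaryWriter('log/sketch'), update_interval=1)
stat = defaultdict(MovAvg)
gradient_step = 1
losses = {"loss": 0.42}  # hypothetical return value of policy.update()
for k in losses.keys():
    stat[k].add(losses[k])
    losses[k] = stat[k].get()  # replace the raw value with its moving average
logger.log_update_data(losses, gradient_step)  # written as "train/loss"
```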
@@ -2,12 +2,11 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
-from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info


@@ -27,8 +26,7 @@ def onpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-writer: Optional[SummaryWriter] = None,
-log_interval: int = 1,
+logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:

@@ -72,9 +70,8 @@ def onpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
-:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
-if None is given, it will not write logs to TensorBoard. Default to None.
-:param int log_interval: the log interval of the writer. Default to 1.
+:param BaseLogger logger: A logger that logs statistics during
+training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.

@@ -85,16 +82,16 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified.
"""
env_step, gradient_step = 0, 0
+last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-writer, env_step, reward_metric)
+logger, env_step, reward_metric)
best_epoch = 0
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result["rews"].std()
+best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
# train
policy.train()

@@ -110,29 +107,27 @@ def onpolicy_trainer(
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
t.update(result["n/st"])
+logger.log_train_data(result, env_step)
+last_rew = result['rew'] if 'rew' in result else last_rew
+last_len = result['len'] if 'len' in result else last_len
data = {
"env_step": str(env_step),
-"rew": f"{result['rews'].mean():.2f}",
-"len": str(int(result["lens"].mean())),
+"rew": f"{last_rew:.2f}",
+"len": str(last_len),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
-if writer and env_step % log_interval == 0:
-writer.add_scalar(
-"train/rew", result['rews'].mean(), global_step=env_step)
-writer.add_scalar(
-"train/len", result['lens'].mean(), global_step=env_step)
-if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
+if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
-epoch, episode_per_test, writer, env_step)
-if stop_fn(test_result["rews"].mean()):
+epoch, episode_per_test, logger, env_step)
+if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
-test_result["rews"].mean(), test_result["rews"].std())
+test_result["rew"], test_result["rew_std"])
else:
policy.train()
losses = policy.update(

@@ -144,26 +139,24 @@ def onpolicy_trainer(
gradient_step += step
for k in losses.keys():
stat[k].add(losses[k])
-data[k] = f"{stat[k].get():.6f}"
-if writer and gradient_step % log_interval == 0:
-writer.add_scalar(
-k, stat[k].get(), global_step=gradient_step)
+losses[k] = stat[k].get()
+data[k] = f"{losses[k]:.6f}"
+logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
-episode_per_test, writer, env_step)
+episode_per_test, logger, env_step)
-if best_epoch == -1 or best_reward < test_result["rews"].mean():
-best_reward = test_result["rews"].mean()
-best_reward_std = test_result['rews'].std()
+rew, rew_std = test_result["rew"], test_result["rew_std"]
+if best_epoch == -1 or best_reward < rew:
+best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
-print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-f"{best_reward_std:.6f} in #{best_epoch}")
+print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,

@@ -1,10 +1,10 @@
import time
import numpy as np
-from torch.utils.tensorboard import SummaryWriter
from typing import Any, Dict, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
+from tianshou.utils import BaseLogger


def test_episode(

@@ -13,7 +13,7 @@ def test_episode(
test_fn: Optional[Callable[[int, Optional[int]], None]],
epoch: int,
n_episode: int,
-writer: Optional[SummaryWriter] = None,
+logger: Optional[BaseLogger] = None,
global_step: Optional[int] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]:

@@ -26,12 +26,8 @@ def test_episode(
result = collector.collect(n_episode=n_episode)
if reward_metric:
result["rews"] = reward_metric(result["rews"])
-if writer is not None and global_step is not None:
-rews, lens = result["rews"], result["lens"]
-writer.add_scalar("test/rew", rews.mean(), global_step=global_step)
-writer.add_scalar("test/rew_std", rews.std(), global_step=global_step)
-writer.add_scalar("test/len", lens.mean(), global_step=global_step)
-writer.add_scalar("test/len_std", lens.std(), global_step=global_step)
+if logger and global_step is not None:
+logger.log_test_data(result, global_step)
return result

@@ -1,9 +1,11 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
-from tianshou.utils.log_tools import SummaryWriter
+from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger

__all__ = [
"MovAvg",
"tqdm_config",
-"SummaryWriter",
+"BaseLogger",
+"BasicLogger",
+"LazyLogger",
]

@@ -1,47 +1,159 @@
-import threading
-from torch.utils import tensorboard
-from typing import Any, Dict, Optional
-
-
-class SummaryWriter(tensorboard.SummaryWriter):
-"""A more convenient Summary Writer(`tensorboard.SummaryWriter`).
-
-You can get the same instance of summary writer everywhere after you
-created one.
-::
-
->>> writer1 = SummaryWriter.get_instance(
-key="first", log_dir="log/test_sw/first")
->>> writer2 = SummaryWriter.get_instance()
->>> writer1 is writer2
-True
->>> writer4 = SummaryWriter.get_instance(
-key="second", log_dir="log/test_sw/second")
->>> writer5 = SummaryWriter.get_instance(key="second")
->>> writer1 is not writer4
-True
->>> writer4 is writer5
-True
-"""
-
-_mutex_lock = threading.Lock()
-_default_key: str
-_instance: Optional[Dict[str, "SummaryWriter"]] = None
-
-@classmethod
-def get_instance(
-cls,
-key: Optional[str] = None,
-*args: Any,
-**kwargs: Any,
-) -> "SummaryWriter":
-"""Get instance of torch.utils.tensorboard.SummaryWriter by key."""
-with SummaryWriter._mutex_lock:
-if key is None:
-key = SummaryWriter._default_key
-if SummaryWriter._instance is None:
-SummaryWriter._instance = {}
-SummaryWriter._default_key = key
-if key not in SummaryWriter._instance.keys():
-SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
-return SummaryWriter._instance[key]
+import numpy as np
+from numbers import Number
+from typing import Any, Union
+from abc import ABC, abstractmethod
+from torch.utils.tensorboard import SummaryWriter
+
+
+class BaseLogger(ABC):
+"""The base class for any logger which is compatible with trainer."""
+
+def __init__(self, writer: Any) -> None:
+super().__init__()
+self.writer = writer
+
+@abstractmethod
+def write(
+self,
+key: str,
+x: Union[Number, np.number, np.ndarray],
+y: Union[Number, np.number, np.ndarray],
+**kwargs: Any,
+) -> None:
+"""Specify how the writer is used to log data.
+
+:param key: namespace which the input data tuple belongs to.
+:param x: stands for the ordinate of the input data tuple.
+:param y: stands for the abscissa of the input data tuple.
+"""
+pass
+
+def log_train_data(self, collect_result: dict, step: int) -> None:
+"""Use writer to log statistics generated during training.
+
+:param collect_result: a dict containing information of data collected in
+training stage, i.e., returns of collector.collect().
+:param int step: stands for the timestep the collect_result being logged.
+"""
+pass
+
+def log_update_data(self, update_result: dict, step: int) -> None:
+"""Use writer to log statistics generated during updating.
+
+:param update_result: a dict containing information of data collected in
+updating stage, i.e., returns of policy.update().
+:param int step: stands for the timestep the collect_result being logged.
+"""
+pass
+
+def log_test_data(self, collect_result: dict, step: int) -> None:
+"""Use writer to log statistics generated during evaluating.
+
+:param collect_result: a dict containing information of data collected in
+evaluating stage, i.e., returns of collector.collect().
+:param int step: stands for the timestep the collect_result being logged.
+"""
+pass
+
+
+class BasicLogger(BaseLogger):
+"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
+and log statistics.
+
+You can also rewrite write() func to use your own writer.
+
+:param SummaryWriter writer: the writer to log data.
+:param int train_interval: the log interval in log_train_data(). Default to 1.
+:param int test_interval: the log interval in log_test_data(). Default to 1.
+:param int update_interval: the log interval in log_update_data(). Default to 1000.
+"""
+
+def __init__(
+self,
+writer: SummaryWriter,
+train_interval: int = 1,
+test_interval: int = 1,
+update_interval: int = 1000,
+) -> None:
+super().__init__(writer)
+self.train_interval = train_interval
+self.test_interval = test_interval
+self.update_interval = update_interval
+self.last_log_train_step = -1
+self.last_log_test_step = -1
+self.last_log_update_step = -1
+
+def write(
+self,
+key: str,
+x: Union[Number, np.number, np.ndarray],
+y: Union[Number, np.number, np.ndarray],
+**kwargs: Any,
+) -> None:
+self.writer.add_scalar(key, y, global_step=x)
+
+def log_train_data(self, collect_result: dict, step: int) -> None:
+"""Use writer to log statistics generated during training.
+
+:param collect_result: a dict containing information of data collected in
+training stage, i.e., returns of collector.collect().
+:param int step: stands for the timestep the collect_result being logged.
+
+.. note::
+
+``collect_result`` will be modified in-place with "rew" and "len" keys.
+"""
+if collect_result["n/ep"] > 0:
+collect_result["rew"] = collect_result["rews"].mean()
+collect_result["len"] = collect_result["lens"].mean()
+if step - self.last_log_train_step >= self.train_interval:
+self.write("train/n/ep", step, collect_result["n/ep"])
+self.write("train/rew", step, collect_result["rew"])
+self.write("train/len", step, collect_result["len"])
+self.last_log_train_step = step
+
+def log_test_data(self, collect_result: dict, step: int) -> None:
+"""Use writer to log statistics generated during evaluating.
+
+:param collect_result: a dict containing information of data collected in
+evaluating stage, i.e., returns of collector.collect().
+:param int step: stands for the timestep the collect_result being logged.
+
+.. note::
+
+``collect_result`` will be modified in-place with "rew", "rew_std", "len",
+and "len_std" keys.
+"""
+assert collect_result["n/ep"] > 0
+rews, lens = collect_result["rews"], collect_result["lens"]
+rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
+collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
+if step - self.last_log_test_step >= self.test_interval:
+self.write("test/rew", step, rew)
+self.write("test/len", step, len_)
+self.write("test/rew_std", step, rew_std)
+self.write("test/len_std", step, len_std)
+self.last_log_test_step = step
+
+def log_update_data(self, update_result: dict, step: int) -> None:
+if step - self.last_log_update_step >= self.update_interval:
+for k, v in update_result.items():
+self.write("train/" + k, step, v)  # save in train/
+self.last_log_update_step = step
+
+
+class LazyLogger(BasicLogger):
+"""A loggger that does nothing. Used as the placeholder in trainer."""
+
+def __init__(self) -> None:
+super().__init__(None)  # type: ignore
+
+def write(
+self,
+key: str,
+x: Union[Number, np.number, np.ndarray],
+y: Union[Number, np.number, np.ndarray],
+**kwargs: Any,
+) -> None:
+"""The LazyLogger writes nothing."""
+pass

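Since the trainers above only ever talk to the BaseLogger interface, anything that implements write() can be plugged in. A minimal sketch under that assumption (PrintLogger is an illustrative name, not part of this commit): reuse BasicLogger's interval bookkeeping but override write() so the scalars go to stdout instead of TensorBoard, then pass an instance to any trainer via logger=....

```python
from numbers import Number
from typing import Any, Union

import numpy as np

from tianshou.utils import BasicLogger


class PrintLogger(BasicLogger):
    """Hypothetical logger: keeps BasicLogger's interval logic, prints each scalar."""

    def __init__(self) -> None:
        super().__init__(None)  # type: ignore  # no SummaryWriter needed

    def write(
        self,
        key: str,
        x: Union[Number, np.number, np.ndarray],
        y: Union[Number, np.number, np.ndarray],
        **kwargs: Any,
    ) -> None:
        print(f"step {x}: {key} = {y}")


# e.g. offpolicy_trainer(..., logger=PrintLogger()) instead of the default LazyLogger()
```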