Add Weights and Biases Logger (#427)
- rename BasicLogger to TensorboardLogger
- refactor logger code
- add WandbLogger

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
parent
e4f4f0e144
commit
8a5e2190f7
@ -192,7 +192,7 @@ buffer_size = 20000
|
|||||||
eps_train, eps_test = 0.1, 0.05
|
eps_train, eps_test = 0.1, 0.05
|
||||||
step_per_epoch, step_per_collect = 10000, 10
|
step_per_epoch, step_per_collect = 10000, 10
|
||||||
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
|
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
|
||||||
logger = ts.utils.BasicLogger(writer)
|
logger = ts.utils.TensorboardLogger(writer)
|
||||||
```
|
```
|
||||||
|
|
||||||
Make environments:
|
Make environments:
|
||||||
|
|||||||
@ -40,8 +40,8 @@ This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.
|
|||||||
To resume training process from an existing checkpoint, you need to do the following things in the training process:
|
To resume training process from an existing checkpoint, you need to do the following things in the training process:
|
||||||
|
|
||||||
1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer;
|
1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer;
|
||||||
2. Use ``BasicLogger`` which contains a tensorboard;
|
2. Use ``TensorboardLogger``;
|
||||||
3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger.
|
3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger.
|
||||||
|
|
||||||
And to successfully resume from a checkpoint:
|
And to successfully resume from a checkpoint:
|
||||||
|
|
||||||
|
|||||||
@ -148,9 +148,9 @@ The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for
|
|||||||
::
|
::
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
writer = SummaryWriter('log/dqn')
|
writer = SummaryWriter('log/dqn')
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
|
Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
|
||||||
|
|
||||||
|
|||||||
@ -176,7 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
|
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
@ -323,7 +323,7 @@ With the above preparation, we are close to the first learned agent. The followi
|
|||||||
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
# ======== callback functions used during training =========
|
# ======== callback functions used during training =========
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
from tianshou.utils.net.discrete import Actor
|
from tianshou.utils.net.discrete import Actor
|
||||||
@ -116,7 +116,7 @@ def test_discrete_bcq(args=get_args()):
|
|||||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=args.log_interval)
|
logger = TensorboardLogger(writer, update_interval=args.log_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import C51Policy
|
from tianshou.policy import C51Policy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -101,7 +101,7 @@ def test_c51(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'c51')
|
log_path = os.path.join(args.logdir, args.task, 'c51')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
from tianshou.policy import DiscreteCQLPolicy
|
from tianshou.policy import DiscreteCQLPolicy
|
||||||
@ -108,7 +108,7 @@ def test_discrete_cql(args=get_args()):
|
|||||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=args.log_interval)
|
logger = TensorboardLogger(writer, update_interval=args.log_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
from tianshou.utils.net.discrete import Actor
|
from tianshou.utils.net.discrete import Actor
|
||||||
@ -117,7 +117,7 @@ def test_discrete_crr(args=get_args()):
|
|||||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=args.log_interval)
|
logger = TensorboardLogger(writer, update_interval=args.log_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -96,7 +96,7 @@ def test_dqn(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import FQFPolicy
|
from tianshou.policy import FQFPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -112,7 +112,7 @@ def test_fqf(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import IQNPolicy
|
from tianshou.policy import IQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -109,7 +109,7 @@ def test_iqn(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'iqn')
|
log_path = os.path.join(args.logdir, args.task, 'iqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.policy import QRDQNPolicy
|
from tianshou.policy import QRDQNPolicy
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -99,7 +99,7 @@ def test_qrdqn(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
|
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import RainbowPolicy
|
from tianshou.policy import RainbowPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
||||||
@ -121,7 +121,7 @@ def test_rainbow(args=get_args()):
|
|||||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -82,7 +82,7 @@ def test_dqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import SACPolicy
|
from tianshou.policy import SACPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -134,7 +134,7 @@ def test_sac_bipedal(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'sac')
|
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import SACPolicy
|
from tianshou.policy import SACPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.exploration import OUNoise
|
from tianshou.exploration import OUNoise
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
@ -104,7 +104,7 @@ def test_sac(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'sac')
|
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import A2CPolicy
|
from tianshou.policy import A2CPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -141,7 +141,7 @@ def test_a2c(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DDPGPolicy
|
from tianshou.policy import DDPGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.exploration import GaussianNoise
|
from tianshou.exploration import GaussianNoise
|
||||||
@ -110,7 +110,7 @@ def test_ddpg(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import NPGPolicy
|
from tianshou.policy import NPGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -142,7 +142,7 @@ def test_npg(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import PPOPolicy
|
from tianshou.policy import PPOPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -149,7 +149,7 @@ def test_ppo(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import PGPolicy
|
from tianshou.policy import PGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -131,7 +131,7 @@ def test_reinforce(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=10, train_interval=100)
|
logger = TensorboardLogger(writer, update_interval=10, train_interval=100)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import SACPolicy
|
from tianshou.policy import SACPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -122,7 +122,7 @@ def test_sac(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import TD3Policy
|
from tianshou.policy import TD3Policy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.exploration import GaussianNoise
|
from tianshou.exploration import GaussianNoise
|
||||||
@ -123,7 +123,7 @@ def test_td3(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import TRPOPolicy
|
from tianshou.policy import TRPOPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -147,7 +147,7 @@ def test_trpo(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
|
log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import C51Policy
|
from tianshou.policy import C51Policy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -101,7 +101,7 @@ def test_c51(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'c51')
|
log_path = os.path.join(args.logdir, args.task, 'c51')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
1
setup.py
1
setup.py
@ -63,6 +63,7 @@ setup(
|
|||||||
"pytest",
|
"pytest",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"ray>=1.0.0",
|
"ray>=1.0.0",
|
||||||
|
"wandb>=0.12.0",
|
||||||
"networkx",
|
"networkx",
|
||||||
"mypy",
|
"mypy",
|
||||||
"pydocstyle",
|
"pydocstyle",
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DDPGPolicy
|
from tianshou.policy import DDPGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -90,7 +90,7 @@ def test_ddpg(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'ddpg')
|
log_path = os.path.join(args.logdir, args.task, 'ddpg')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import NPGPolicy
|
from tianshou.policy import NPGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -105,7 +105,7 @@ def test_npg(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'npg')
|
log_path = os.path.join(args.logdir, args.task, 'npg')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import PPOPolicy
|
from tianshou.policy import PPOPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -116,7 +116,7 @@ def test_ppo(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -109,7 +109,7 @@ def test_sac_with_il(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'sac')
|
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import TD3Policy
|
from tianshou.policy import TD3Policy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -104,7 +104,7 @@ def test_td3(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'td3')
|
log_path = os.path.join(args.logdir, args.task, 'td3')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from torch.distributions import Independent, Normal
|
from torch.distributions import Independent, Normal
|
||||||
|
|
||||||
from tianshou.policy import TRPOPolicy
|
from tianshou.policy import TRPOPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -109,7 +109,7 @@ def test_trpo(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'trpo')
|
log_path = os.path.join(args.logdir, args.task, 'trpo')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
@ -91,7 +91,7 @@ def test_a2c_with_il(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'a2c')
|
log_path = os.path.join(args.logdir, args.task, 'a2c')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import C51Policy
|
from tianshou.policy import C51Policy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -93,7 +93,7 @@ def test_c51(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'c51')
|
log_path = os.path.join(args.logdir, args.task, 'c51')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -92,7 +92,7 @@ def test_dqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import DQNPolicy
|
from tianshou.policy import DQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.utils.net.common import Recurrent
|
from tianshou.utils.net.common import Recurrent
|
||||||
@ -78,7 +78,7 @@ def test_drqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'drqn')
|
log_path = os.path.join(args.logdir, args.task, 'drqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import FQFPolicy
|
from tianshou.policy import FQFPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -100,7 +100,7 @@ def test_fqf(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
@ -87,7 +87,7 @@ def test_discrete_bcq(args=get_args()):
|
|||||||
|
|
||||||
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
|
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
@ -80,7 +80,7 @@ def test_discrete_crr(args=get_args()):
|
|||||||
|
|
||||||
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
|
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import IQNPolicy
|
from tianshou.policy import IQNPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -96,7 +96,7 @@ def test_iqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'iqn')
|
log_path = os.path.join(args.logdir, args.task, 'iqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import PGPolicy
|
from tianshou.policy import PGPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -78,7 +78,7 @@ def test_pg(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'pg')
|
log_path = os.path.join(args.logdir, args.task, 'pg')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import PPOPolicy
|
from tianshou.policy import PPOPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
@ -102,7 +102,7 @@ def test_ppo(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.policy import QRDQNPolicy
|
from tianshou.policy import QRDQNPolicy
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
@ -94,7 +94,7 @@ def test_qrdqn(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
|
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offline_trainer
|
from tianshou.trainer import offline_trainer
|
||||||
@ -79,7 +79,7 @@ def test_discrete_cql(args=get_args()):
|
|||||||
|
|
||||||
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
|
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import RainbowPolicy
|
from tianshou.policy import RainbowPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.net.discrete import NoisyLinear
|
from tianshou.utils.net.discrete import NoisyLinear
|
||||||
@ -102,7 +102,7 @@ def test_rainbow(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'rainbow')
|
log_path = os.path.join(args.logdir, args.task, 'rainbow')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import SubprocVectorEnv
|
from tianshou.env import SubprocVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.policy import DiscreteSACPolicy
|
from tianshou.policy import DiscreteSACPolicy
|
||||||
@ -97,7 +97,7 @@ def test_discrete_sac(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
|
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import PSRLPolicy
|
from tianshou.policy import PSRLPolicy
|
||||||
# from tianshou.utils import BasicLogger
|
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, VectorReplayBuffer
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||||
@ -71,7 +70,6 @@ def test_psrl(args=get_args()):
|
|||||||
log_path = os.path.join(args.logdir, args.task, 'psrl')
|
log_path = os.path.join(args.logdir, args.task, 'psrl')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
# logger = BasicLogger(writer)
|
|
||||||
|
|
||||||
def stop_fn(mean_rewards):
|
def stop_fn(mean_rewards):
|
||||||
if env.spec.reward_threshold:
|
if env.spec.reward_threshold:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.policy import RandomPolicy
|
from tianshou.policy import RandomPolicy
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
|
|
||||||
from tic_tac_toe_env import TicTacToeEnv
|
from tic_tac_toe_env import TicTacToeEnv
|
||||||
from tic_tac_toe import get_parser, get_agents, train_agent, watch
|
from tic_tac_toe import get_parser, get_agents, train_agent, watch
|
||||||
@ -33,7 +33,7 @@ def gomoku(args=get_args()):
|
|||||||
# log
|
# log
|
||||||
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
|
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
args.logger = BasicLogger(writer)
|
args.logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
opponent_pool = [agent_opponent]
|
opponent_pool = [agent_opponent]
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from copy import deepcopy
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BasicLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
from tianshou.env import DummyVectorEnv
|
from tianshou.env import DummyVectorEnv
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
@ -135,7 +135,7 @@ def train_agent(
|
|||||||
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
writer.add_text("args", str(args))
|
writer.add_text("args", str(args))
|
||||||
logger = BasicLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
def save_fn(policy):
|
def save_fn(policy):
|
||||||
if hasattr(args, 'model_save_path'):
|
if hasattr(args, 'model_save_path'):
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class PPOPolicy(A2CPolicy):
|
|||||||
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
|
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
|
||||||
where c > 1 is a constant indicating the lower bound.
|
where c > 1 is a constant indicating the lower bound.
|
||||||
Default to 5.0 (set None if you do not want to use it).
|
Default to 5.0 (set None if you do not want to use it).
|
||||||
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
|
:param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
|
||||||
Default to True.
|
Default to True.
|
||||||
:param bool advantage_normalization: whether to do per mini-batch advantage
|
:param bool advantage_normalization: whether to do per mini-batch advantage
|
||||||
normalization. Default to True.
|
normalization. Default to True.
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
from tianshou.utils.config import tqdm_config
|
from tianshou.utils.config import tqdm_config
|
||||||
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
||||||
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
|
from tianshou.utils.logger.base import BaseLogger, LazyLogger
|
||||||
|
from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger
|
||||||
|
from tianshou.utils.logger.wandb import WandBLogger
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MovAvg",
|
"MovAvg",
|
||||||
"RunningMeanStd",
|
"RunningMeanStd",
|
||||||
"tqdm_config",
|
"tqdm_config",
|
||||||
"BaseLogger",
|
"BaseLogger",
|
||||||
|
"TensorboardLogger",
|
||||||
"BasicLogger",
|
"BasicLogger",
|
||||||
"LazyLogger",
|
"LazyLogger",
|
||||||
|
"WandBLogger"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,210 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
from numbers import Number
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
from typing import Any, Tuple, Union, Callable, Optional
|
|
||||||
from tensorboard.backend.event_processing import event_accumulator
|
|
||||||
|
|
||||||
|
|
||||||
# Value types accepted as the ``y`` argument of ``BaseLogger.write``.
WRITE_TYPE = Union[int, Number, np.number, np.ndarray]


class BaseLogger(ABC):
    """The base class for any logger which is compatible with trainer.

    Subclasses must implement :meth:`write`; every other hook defaults to a
    no-op so that a subclass only overrides what it needs.

    :param writer: the backend object used to record data. It is only stored as
        ``self.writer`` here; this base class never calls it.
    """

    def __init__(self, writer: Any) -> None:
        super().__init__()
        self.writer = writer

    @abstractmethod
    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
        """Specify how the writer is used to log data.

        :param str key: namespace which the input data tuple belongs to.
        :param int x: stands for the abscissa of the input data tuple.
        :param y: stands for the ordinate of the input data tuple.
        """
        pass

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during training.

        :param collect_result: a dict containing information of data collected in
            training stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.
        """
        pass

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Use writer to log statistics generated during updating.

        :param update_result: a dict containing information of data collected in
            updating stage, i.e., returns of policy.update().
        :param int step: stands for the timestep the update_result being logged.
        """
        pass

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during evaluating.

        :param collect_result: a dict containing information of data collected in
            evaluating stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.
        """
        pass

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    ) -> None:
        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail.
        """
        pass

    def restore_data(self) -> Tuple[int, int, int]:
        """Return the metadata from existing log.

        If it finds nothing or an error occurs during the recover process, it will
        return the default parameters. This base implementation has no log to
        read, so it always returns the defaults.

        :return: epoch, env_step, gradient_step.
        """
        # Return the documented defaults instead of None so that callers can
        # always unpack the result into three integers.
        return 0, 0, 0
|
|
||||||
|
|
||||||
|
|
||||||
class BasicLogger(BaseLogger):
    """A logger that relies on tensorboard SummaryWriter by default to visualize \
    and log statistics.

    You can also rewrite write() func to use your own writer.

    :param SummaryWriter writer: the writer to log data.
    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data(). Default to 1000.
    :param int save_interval: the save interval in save_data(). Default to 1 (save at
        the end of each epoch).
    """

    def __init__(
        self,
        writer: SummaryWriter,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        save_interval: int = 1,
    ) -> None:
        super().__init__(writer)
        # Minimum step distance between two consecutive records of each kind.
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        self.save_interval = save_interval
        # Step of the most recent record of each kind; starts at -1 and is
        # compared against the current step in the log_*/save_data methods.
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1
        self.last_save_step = -1

    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
        """Record scalar ``y`` under tag ``key`` at global step ``x``."""
        self.writer.add_scalar(key, y, global_step=x)

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during training.

        :param collect_result: a dict containing information of data collected in
            training stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.

        .. note::

            ``collect_result`` will be modified in-place with "rew" and "len" keys.
        """
        # Nothing to summarise unless at least one episode was reported.
        if not collect_result["n/ep"] > 0:
            return
        collect_result["rew"] = collect_result["rews"].mean()
        collect_result["len"] = collect_result["lens"].mean()
        if step - self.last_log_train_step < self.train_interval:
            return
        for tag, field in (
            ("train/n/ep", "n/ep"),
            ("train/rew", "rew"),
            ("train/len", "len"),
        ):
            self.write(tag, step, collect_result[field])
        self.last_log_train_step = step

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during evaluating.

        :param collect_result: a dict containing information of data collected in
            evaluating stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.

        .. note::

            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
            and "len_std" keys.
        """
        assert collect_result["n/ep"] > 0
        rews = collect_result["rews"]
        lens = collect_result["lens"]
        rew, rew_std = rews.mean(), rews.std()
        len_, len_std = lens.mean(), lens.std()
        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
        if step - self.last_log_test_step < self.test_interval:
            return
        for tag, value in (
            ("test/rew", rew),
            ("test/len", len_),
            ("test/rew_std", rew_std),
            ("test/len_std", len_std),
        ):
            self.write(tag, step, value)
        self.last_log_test_step = step

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Write every entry of ``update_result`` under its own key.

        :param update_result: a dict of values to record; each key is used
            unchanged as the tag.
        :param int step: the step at which the values are recorded.
        """
        if step - self.last_log_update_step < self.update_interval:
            return
        for tag, value in update_result.items():
            self.write(tag, step, value)
        self.last_log_update_step = step

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    ) -> None:
        """Call ``save_checkpoint_fn`` and record the three counters.

        Does nothing when no hook is given or when fewer than ``save_interval``
        epochs have passed since the last save.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail.
        """
        if not save_checkpoint_fn:
            return
        if epoch - self.last_save_step < self.save_interval:
            return
        self.last_save_step = epoch
        save_checkpoint_fn(epoch, env_step, gradient_step)
        # Each counter is written with itself as both step and value.
        for tag, counter in (
            ("save/epoch", epoch),
            ("save/env_step", env_step),
            ("save/gradient_step", gradient_step),
        ):
            self.write(tag, counter, counter)

    def restore_data(self) -> Tuple[int, int, int]:
        """Read the last saved counters back from the writer's log directory.

        The ``last_*`` bookkeeping attributes are updated from the restored
        values. A counter whose lookup raises ``KeyError`` falls back to 0.

        :return: epoch, env_step, gradient_step.
        """
        accumulator = event_accumulator.EventAccumulator(self.writer.log_dir)
        accumulator.Reload()

        def last_step(tag: str) -> int:
            return accumulator.scalars.Items(tag)[-1].step

        try:  # epoch / gradient_step
            epoch = last_step("save/epoch")
            self.last_save_step = self.last_log_test_step = epoch
            gradient_step = last_step("save/gradient_step")
            self.last_log_update_step = gradient_step
        except KeyError:
            epoch, gradient_step = 0, 0
        try:  # offline trainer doesn't have env_step
            env_step = last_step("save/env_step")
            self.last_log_train_step = env_step
        except KeyError:
            env_step = 0

        return epoch, env_step, gradient_step
|
|
||||||
|
|
||||||
|
|
||||||
class LazyLogger(BasicLogger):
    """A logger that does nothing. Used as the placeholder in trainer."""

    def __init__(self) -> None:
        # There is no backend: the writer is None and is never used.
        super().__init__(None)  # type: ignore

    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
        """The LazyLogger writes nothing."""
        pass

    def restore_data(self) -> Tuple[int, int, int]:
        """Return the default counters, since nothing was ever logged.

        The inherited implementation reads ``self.writer.log_dir``, which would
        fail here because the writer is ``None``.

        :return: epoch, env_step, gradient_step (always ``0, 0, 0``).
        """
        return 0, 0, 0
|
|
||||||
0
tianshou/utils/logger/__init__.py
Normal file
0
tianshou/utils/logger/__init__.py
Normal file
141
tianshou/utils/logger/base.py
Normal file
141
tianshou/utils/logger/base.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import numpy as np
|
||||||
|
from numbers import Number
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Tuple, Union, Callable, Optional
|
||||||
|
|
||||||
|
# Mapping from metric name to the value passed to ``BaseLogger.write``.
LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]


class BaseLogger(ABC):
    """The base class for any logger which is compatible with trainer.

    Try to overwrite write() method to use your own writer.

    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data(). Default to 1000.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
    ) -> None:
        super().__init__()
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        # Step of the most recent record of each kind. Each log_* method only
        # writes when ``step - last_log_*_step`` reaches the matching interval.
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1

    @abstractmethod
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """Specify how the writer is used to log data.

        :param str step_type: namespace which the data dict belongs to.
        :param int step: stands for the abscissa of the data dict.
        :param dict data: the data to write with format ``{key: value}``.
        """
        pass

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during training.

        :param collect_result: a dict containing information of data collected in
            training stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.

        .. note::

            ``collect_result`` will be modified in-place with "rew" and "len" keys.
        """
        # Means are only computed (and logged) when an episode was reported.
        if collect_result["n/ep"] > 0:
            collect_result["rew"] = collect_result["rews"].mean()
            collect_result["len"] = collect_result["lens"].mean()
            if step - self.last_log_train_step >= self.train_interval:
                log_data = {
                    "train/episode": collect_result["n/ep"],
                    "train/reward": collect_result["rew"],
                    "train/length": collect_result["len"],
                }
                self.write("train/env_step", step, log_data)
                self.last_log_train_step = step

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during evaluating.

        :param collect_result: a dict containing information of data collected in
            evaluating stage, i.e., returns of collector.collect().
        :param int step: stands for the timestep the collect_result being logged.

        .. note::

            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
            and "len_std" keys.
        """
        assert collect_result["n/ep"] > 0
        rews, lens = collect_result["rews"], collect_result["lens"]
        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
        # The summary is stored on the result even when nothing is written.
        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
        if step - self.last_log_test_step >= self.test_interval:
            log_data = {
                "test/env_step": step,
                "test/reward": rew,
                "test/length": len_,
                "test/reward_std": rew_std,
                "test/length_std": len_std,
            }
            self.write("test/env_step", step, log_data)
            self.last_log_test_step = step

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Use writer to log statistics generated during updating.

        Every key of ``update_result`` is prefixed with ``update/`` before it is
        written.

        :param update_result: a dict containing information of data collected in
            updating stage, i.e., returns of policy.update().
        :param int step: stands for the timestep the update_result being logged.
        """
        if step - self.last_log_update_step >= self.update_interval:
            log_data = {f"update/{k}": v for k, v in update_result.items()}
            self.write("update/gradient_step", step, log_data)
            self.last_log_update_step = step

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    ) -> None:
        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

        This base implementation does nothing.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail.
        """
        pass

    def restore_data(self) -> Tuple[int, int, int]:
        """Return the metadata from existing log.

        If it finds nothing or an error occurs during the recover process, it will
        return the default parameters. This base implementation has no log to
        read, so it always returns the defaults.

        :return: epoch, env_step, gradient_step.
        """
        # Return the documented defaults instead of None so that callers can
        # always unpack the result into three integers.
        return 0, 0, 0
|
||||||
|
|
||||||
|
|
||||||
|
class LazyLogger(BaseLogger):
    """A logger that does nothing. Used as the placeholder in trainer."""

    def __init__(self) -> None:
        super().__init__()

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """The LazyLogger writes nothing."""
        pass

    def restore_data(self) -> Tuple[int, int, int]:
        """Return the default counters, since nothing was ever logged.

        :return: epoch, env_step, gradient_step (always ``0, 0, 0``).
        """
        # Explicit defaults so that callers can always unpack three integers.
        return 0, 0, 0
|
||||||
83
tianshou/utils/logger/tensorboard.py
Normal file
83
tianshou/utils/logger/tensorboard.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import Any, Tuple, Callable, Optional
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tensorboard.backend.event_processing import event_accumulator
|
||||||
|
|
||||||
|
from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE
|
||||||
|
|
||||||
|
|
||||||
|
class TensorboardLogger(BaseLogger):
    """A logger that relies on tensorboard SummaryWriter by default to visualize \
    and log statistics.

    :param SummaryWriter writer: the writer to log data.
    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data(). Default to 1000.
    :param int save_interval: the save interval in save_data(). Default to 1 (save at
        the end of each epoch).
    """

    def __init__(
        self,
        writer: SummaryWriter,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        save_interval: int = 1,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval)
        self.writer = writer
        # Minimum number of epochs between two checkpoints, and the epoch of
        # the most recent checkpoint (-1 until one has been saved).
        self.save_interval = save_interval
        self.last_save_step = -1

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """Record every ``{tag: value}`` entry of ``data`` as a scalar at ``step``.

        ``step_type`` is not used by this logger.
        """
        for tag, value in data.items():
            self.writer.add_scalar(tag, value, global_step=step)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    ) -> None:
        """Call ``save_checkpoint_fn`` and record the three counters.

        Does nothing when no hook is given or when fewer than ``save_interval``
        epochs have passed since the last save.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail.
        """
        if not save_checkpoint_fn:
            return
        if epoch - self.last_save_step < self.save_interval:
            return
        self.last_save_step = epoch
        save_checkpoint_fn(epoch, env_step, gradient_step)
        # Each counter is written under its own tag, with itself as the step.
        for tag, counter in (
            ("save/epoch", epoch),
            ("save/env_step", env_step),
            ("save/gradient_step", gradient_step),
        ):
            self.write(tag, counter, {tag: counter})

    def restore_data(self) -> Tuple[int, int, int]:
        """Read the last saved counters back from the writer's log directory.

        The ``last_*`` bookkeeping attributes are updated from the restored
        values. A counter whose lookup raises ``KeyError`` falls back to 0.

        :return: epoch, env_step, gradient_step.
        """
        accumulator = event_accumulator.EventAccumulator(self.writer.log_dir)
        accumulator.Reload()

        def last_step(tag: str) -> int:
            return accumulator.scalars.Items(tag)[-1].step

        try:  # epoch / gradient_step
            epoch = last_step("save/epoch")
            self.last_save_step = self.last_log_test_step = epoch
            gradient_step = last_step("save/gradient_step")
            self.last_log_update_step = gradient_step
        except KeyError:
            epoch, gradient_step = 0, 0
        try:  # offline trainer doesn't have env_step
            env_step = last_step("save/env_step")
            self.last_log_train_step = env_step
        except KeyError:
            env_step = 0

        return epoch, env_step, gradient_step
|
||||||
|
|
||||||
|
|
||||||
|
class BasicLogger(TensorboardLogger):
    """BasicLogger has changed its name to TensorboardLogger in #427.

    This class is for compatibility. It accepts the same arguments as
    TensorboardLogger and emits a warning on construction.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # stacklevel=2 attributes the warning to the code that instantiates
        # BasicLogger rather than to this compatibility shim.
        warnings.warn(
            "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.",
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
|
||||||
44
tianshou/utils/logger/wandb.py
Normal file
44
tianshou/utils/logger/wandb.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from tianshou.utils import BaseLogger
|
||||||
|
from tianshou.utils.logger.base import LOG_DATA_TYPE
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WandBLogger(BaseLogger):
    """Weights and Biases logger that sends data to Weights and Biases.

    Creates three panels with plots: train, test, and update.
    Make sure to select the correct x-axis for each panel in weights and biases:

    - ``train/env_step`` for train plots
    - ``test/env_step`` for test plots
    - ``update/gradient_step`` for update plots

    Example of usage:
    ::

        with wandb.init(project="My Project"):
            logger = WandBLogger()
            result = onpolicy_trainer(policy, train_collector, test_collector,
                    logger=logger)

    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data().
        Default to 1000.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval)

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """Send ``data`` to wandb together with the step stored under ``step_type``.

        :param str step_type: key under which ``step`` is added to the payload.
        :param int step: the step value sent alongside the metrics.
        :param dict data: the data to write with format ``{key: value}``.
        """
        # Build a new payload so the caller's dict is not modified in place.
        wandb.log({**data, step_type: step})
|
||||||
Loading…
x
Reference in New Issue
Block a user