Update WandbLogger implementation (#558)

* Use `global_step` as the x-axis for wandb
* Use a TensorBoard `SummaryWriter` as the core writer, with `wandb.init(..., sync_tensorboard=True)` (see the sketch below)
* Update all atari examples with wandb
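
A minimal sketch of the resulting setup, assuming a `log_path`, an argparse `args` namespace, and a trainer call (all placeholders):

```python
# Hedged sketch: WandbLogger calls wandb.init(..., sync_tensorboard=True) and then
# wraps a TensorBoard SummaryWriter, so everything the trainer writes to TensorBoard
# is mirrored to wandb with `global_step` on the x-axis.
import argparse

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import WandbLogger

args = argparse.Namespace(task="PongNoFrameskip-v4", seed=0)  # placeholder config
log_path = "log/PongNoFrameskip-v4/c51/0"                     # placeholder path

logger = WandbLogger(save_interval=1, name="demo-run", run_id=None, config=args)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger.load(writer)  # WandbLogger delegates write() to a TensorboardLogger internally
# result = offpolicy_trainer(..., logger=logger)
```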

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Costa Huang 2022-03-06 17:40:47 -05:00 committed by GitHub
parent 2377f2f186
commit df3d7f582b
10 changed files with 482 additions and 320 deletions

View File

@ -34,8 +34,12 @@ WandbLogger
::

    from tianshou.utils import WandbLogger
    from torch.utils.tensorboard import SummaryWriter

    logger = WandbLogger(...)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger.load(writer)
    result = trainer(..., logger=logger)
Please refer to the :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration.
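
For reference, the updated Atari examples below configure the logger roughly like this (the run name and project value are illustrative):

::

    logger = WandbLogger(
        save_interval=1,
        name="PongNoFrameskip-v4__c51__0",  # illustrative run name
        run_id=None,                        # pass an existing wandb run id to resume
        config=args,
        project="atari.benchmark",
    )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger.load(writer)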

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -11,46 +12,54 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
@ -101,19 +110,36 @@ def test_c51(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "c51"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -159,7 +185,7 @@ def test_c51(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -190,5 +216,5 @@ def test_c51(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_c51(get_args())

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -18,62 +19,63 @@ from tianshou.utils.net.discrete import IntrinsicCuriosityModule
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--logger',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
'--watch',
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
"--icm-lr-scale",
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
help="use intrinsic curiosity module with this lr scale"
)
parser.add_argument(
'--icm-reward-scale',
"--icm-reward-scale",
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
help="scaling factor for intrinsic curiosity reward"
)
parser.add_argument(
'--icm-forward-loss-weight',
"--icm-forward-loss-weight",
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
help="weight for the forward model loss in ICM"
)
return parser.parse_args()
@ -140,29 +142,36 @@ def test_dqn(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
project=args.task,
name=log_name,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -183,8 +192,8 @@ def test_dqn(args=get_args()):
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path
# watch agent's performance
@ -214,7 +223,7 @@ def test_dqn(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -247,5 +256,5 @@ def test_dqn(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_dqn(get_args())

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=3128)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-fractions', type=int, default=32)
parser.add_argument('--num-cosines', type=int, default=64)
parser.add_argument('--ent-coef', type=float, default=10.)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=3128)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--fraction-lr", type=float, default=2.5e-9)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-fractions", type=int, default=32)
parser.add_argument("--num-cosines", type=int, default=64)
parser.add_argument("--ent-coef", type=float, default=10.)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
@ -118,19 +127,36 @@ def test_fqf(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'fqf')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "fqf"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -176,7 +202,7 @@ def test_fqf(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -207,5 +233,5 @@ def test_fqf(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_fqf(get_args())

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--sample-size', type=int, default=32)
parser.add_argument('--online-sample-size', type=int, default=8)
parser.add_argument('--target-sample-size', type=int, default=8)
parser.add_argument('--num-cosines', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--sample-size", type=int, default=32)
parser.add_argument("--online-sample-size", type=int, default=8)
parser.add_argument("--target-sample-size", type=int, default=8)
parser.add_argument("--num-cosines", type=int, default=64)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
@ -113,19 +122,36 @@ def test_iqn(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'iqn')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "iqn"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -171,7 +197,7 @@ def test_iqn(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -202,5 +228,5 @@ def test_iqn(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_iqn(get_args())

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -19,69 +20,70 @@ from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=4213)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=1000)
parser.add_argument('--repeat-per-collect', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-size', type=int, default=512)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--rew-norm', type=int, default=False)
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.01)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=4213)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=1000)
parser.add_argument("--repeat-per-collect", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--rew-norm", type=int, default=False)
parser.add_argument("--vf-coef", type=float, default=0.5)
parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--lr-decay", type=int, default=True)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--eps-clip", type=float, default=0.2)
parser.add_argument("--dual-clip", type=float, default=None)
parser.add_argument("--value-clip", type=int, default=0)
parser.add_argument("--norm-adv", type=int, default=1)
parser.add_argument("--recompute-adv", type=int, default=0)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--logger',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
'--watch',
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
"--icm-lr-scale",
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
help="use intrinsic curiosity module with this lr scale"
)
parser.add_argument(
'--icm-reward-scale',
"--icm-reward-scale",
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
help="scaling factor for intrinsic curiosity reward"
)
parser.add_argument(
'--icm-forward-loss-weight',
"--icm-forward-loss-weight",
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
help="weight for the forward model loss in ICM"
)
return parser.parse_args()
@ -184,37 +186,44 @@ def test_ppo(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
project=args.task,
name=log_name,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path
# watch agent's performance
@ -243,7 +252,7 @@ def test_ppo(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -274,5 +283,5 @@ def test_ppo(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_ppo(get_args())

View File

@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint
@ -6,7 +7,7 @@ import numpy as np
import torch
from atari_network import QRDQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import QRDQNPolicy
@ -16,39 +17,47 @@ from tianshou.utils import TensorboardLogger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-quantiles", type=int, default=200)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
@ -97,19 +106,36 @@ def test_qrdqn(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "qrdqn"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -155,7 +181,7 @@ def test_qrdqn(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -186,5 +212,5 @@ def test_qrdqn(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_qrdqn(get_args())

View File

@ -7,7 +7,7 @@ import numpy as np
import torch
from atari_network import Rainbow
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.policy import RainbowPolicy
@ -17,50 +17,58 @@ from tianshou.utils import TensorboardLogger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0000625)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--noisy-std', type=float, default=0.1)
parser.add_argument('--no-dueling', action='store_true', default=False)
parser.add_argument('--no-noisy', action='store_true', default=False)
parser.add_argument('--no-priority', action='store_true', default=False)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.)
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
parser.add_argument('--no-weight-norm', action='store_true', default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0000625)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument("--noisy-std", type=float, default=0.1)
parser.add_argument("--no-dueling", action="store_true", default=False)
parser.add_argument("--no-noisy", action="store_true", default=False)
parser.add_argument("--no-priority", action="store_true", default=False)
parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument("--beta", type=float, default=0.4)
parser.add_argument("--beta-final", type=float, default=1.)
parser.add_argument("--beta-anneal-step", type=int, default=5000000)
parser.add_argument("--no-weight-norm", action="store_true", default=False)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
@ -131,22 +139,36 @@ def test_rainbow(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'rainbow',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
)
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "rainbow"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@ -203,7 +225,7 @@ def test_rainbow(args=get_args()):
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch:
watch()
@ -234,5 +256,5 @@ def test_rainbow(args=get_args()):
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_rainbow(get_args())

View File

@ -79,11 +79,14 @@ def test_psrl(args=get_args()):
logger = WandbLogger(
save_interval=1, project='psrl', name='wandb_test', config=args
)
elif args.logger == "tensorboard":
if args.logger != "none":
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else:
logger.load(writer)
else:
logger = LazyLogger()

View File

@ -2,7 +2,9 @@ import argparse
import os
from typing import Callable, Optional, Tuple
from tianshou.utils import BaseLogger
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BaseLogger, TensorboardLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE
try:
@ -17,17 +19,13 @@ class WandbLogger(BaseLogger):
This logger creates three panels with plots: train, test, and update.
Make sure to select the correct x-axis for each panel in Weights & Biases:
- ``train/env_step`` for train plots
- ``test/env_step`` for test plots
- ``update/gradient_step`` for update plots
Example of usage:
::
with wandb.init(project="My Project"):
logger = WandBLogger()
result = onpolicy_trainer(policy, train_collector, test_collector,
logger=logger)
logger = WandbLogger()
logger.load(SummaryWriter(log_path))
result = onpolicy_trainer(policy, train_collector, test_collector,
logger=logger)
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
@ -46,7 +44,7 @@ class WandbLogger(BaseLogger):
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1000,
project: str = 'tianshou',
project: Optional[str] = None,
name: Optional[str] = None,
entity: Optional[str] = None,
run_id: Optional[str] = None,
@ -56,6 +54,8 @@ class WandbLogger(BaseLogger):
self.last_save_step = -1
self.save_interval = save_interval
self.restored = False
if project is None:
project = os.getenv("WANDB_PROJECT", "tianshou")
self.wandb_run = wandb.init(
project=project,
@ -63,14 +63,25 @@ class WandbLogger(BaseLogger):
id=run_id,
resume="allow",
entity=entity,
sync_tensorboard=True,
monitor_gym=True,
config=config, # type: ignore
) if not wandb.run else wandb.run
self.wandb_run._label(repo="tianshou") # type: ignore
self.tensorboard_logger: Optional[TensorboardLogger] = None
def load(self, writer: SummaryWriter) -> None:
self.writer = writer
self.tensorboard_logger = TensorboardLogger(writer)
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
data[step_type] = step
wandb.log(data)
if self.tensorboard_logger is None:
raise Exception(
"`logger` needs to load the Tensorboard Writer before "
"writing data. Try `logger.load(SummaryWriter(log_path))`"
)
else:
self.tensorboard_logger.write(step_type, step, data)
def save_data(
self,