Update WandbLogger implementation (#558)

* Use `global_step` as the x-axis for wandb
* Use TensorBoard `SummaryWriter` as the core, with `wandb.init(..., sync_tensorboard=True)`
* Update all atari examples with wandb

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Costa Huang authored on 2022-03-06 17:40:47 -05:00; committed by GitHub
parent 2377f2f186
commit df3d7f582b
10 changed files with 482 additions and 320 deletions
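
All of the atari examples below move to the same logger wiring: build a run name from task, algorithm, seed, and timestamp, create the W&B run first when `--logger wandb` is chosen, and then hand the TensorBoard `SummaryWriter` to whichever logger was selected. A minimal sketch of that pattern, assembled from the diffs below (`make_logger` is a hypothetical helper, and `args` is assumed to come from each example's `get_args()`):

    import datetime
    import os

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.utils import TensorboardLogger, WandbLogger


    def make_logger(args, algo_name):
        # run name / log dir: <logdir>/<task>/<algo>/<seed>/<timestamp>
        now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
        log_name = os.path.join(args.task, algo_name, str(args.seed), now)
        log_path = os.path.join(args.logdir, log_name)

        if args.logger == "wandb":
            # the wandb run is created before the SummaryWriter so that
            # sync_tensorboard can pick the writer up
            logger = WandbLogger(
                save_interval=1,
                name=log_name.replace(os.path.sep, "__"),
                run_id=args.resume_id,
                config=args,
                project=args.wandb_project,
            )
        writer = SummaryWriter(log_path)
        writer.add_text("args", str(args))
        if args.logger == "tensorboard":
            logger = TensorboardLogger(writer)
        else:  # wandb
            logger.load(writer)
        return logger, log_path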

View File

@@ -34,8 +34,12 @@ WandbLogger
 ::

     from tianshou.utils import WandbLogger
+    from torch.utils.tensorboard import SummaryWriter

     logger = WandbLogger(...)
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger.load(writer)
     result = trainer(..., logger=logger)

 Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration.
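
Because the run is opened with `wandb.init(..., sync_tensorboard=True)`, whatever goes through the underlying `SummaryWriter` is mirrored into W&B, which is how `global_step` ends up as the shared x-axis. A standalone sketch of that mechanism (project name and log directory are placeholders, not taken from this commit):

    import wandb
    from torch.utils.tensorboard import SummaryWriter

    run = wandb.init(project="demo", sync_tensorboard=True)
    writer = SummaryWriter("log/demo")  # created after wandb.init so it gets patched
    for global_step in range(100):
        # appears in W&B with `global_step` on the x-axis
        writer.add_scalar("train/reward", global_step * 0.1, global_step)
    writer.close()
    run.finish()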

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -11,46 +12,54 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger, WandbLogger
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51) parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.) parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.) parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--watch', "--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args() return parser.parse_args()
@@ -101,19 +110,36 @@ def test_c51(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'c51')
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "c51"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -159,7 +185,7 @@ def test_c51(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -190,5 +216,5 @@ def test_c51(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_c51(get_args()) test_c51(get_args())

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -18,62 +19,63 @@ from tianshou.utils.net.discrete import IntrinsicCuriosityModule
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None) parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--logger', "--logger",
type=str, type=str,
default="tensorboard", default="tensorboard",
choices=["tensorboard", "wandb"], choices=["tensorboard", "wandb"],
) )
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument( parser.add_argument(
'--watch', "--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument( parser.add_argument(
'--icm-lr-scale', "--icm-lr-scale",
type=float, type=float,
default=0., default=0.,
help='use intrinsic curiosity module with this lr scale' help="use intrinsic curiosity module with this lr scale"
) )
parser.add_argument( parser.add_argument(
'--icm-reward-scale', "--icm-reward-scale",
type=float, type=float,
default=0.01, default=0.01,
help='scaling factor for intrinsic curiosity reward' help="scaling factor for intrinsic curiosity reward"
) )
parser.add_argument( parser.add_argument(
'--icm-forward-loss-weight', "--icm-forward-loss-weight",
type=float, type=float,
default=0.2, default=0.2,
help='weight for the forward model loss in ICM' help="weight for the forward model loss in ICM"
) )
return parser.parse_args() return parser.parse_args()
@@ -140,29 +142,36 @@ def test_dqn(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
-    log_path = os.path.join(args.logdir, args.task, log_name)
-    if args.logger == "tensorboard":
-        writer = SummaryWriter(log_path)
-        writer.add_text("args", str(args))
-        logger = TensorboardLogger(writer)
-    else:
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
         logger = WandbLogger(
             save_interval=1,
-            project=args.task,
-            name=log_name,
+            name=log_name.replace(os.path.sep, "__"),
             run_id=args.resume_id,
             config=args,
+            project=args.wandb_project,
         )
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -183,8 +192,8 @@ def test_dqn(args=get_args()):
def save_checkpoint_fn(epoch, env_step, gradient_step): def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth') ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({'model': policy.state_dict()}, ckpt_path) torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path return ckpt_path
# watch agent's performance # watch agent's performance
@ -214,7 +223,7 @@ def test_dqn(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -247,5 +256,5 @@ def test_dqn(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_dqn(get_args()) test_dqn(get_args())

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import FQFPolicy from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=3128) parser.add_argument("--seed", type=int, default=3128)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=5e-5) parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument('--fraction-lr', type=float, default=2.5e-9) parser.add_argument("--fraction-lr", type=float, default=2.5e-9)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-fractions', type=int, default=32) parser.add_argument("--num-fractions", type=int, default=32)
parser.add_argument('--num-cosines', type=int, default=64) parser.add_argument("--num-cosines", type=int, default=64)
parser.add_argument('--ent-coef', type=float, default=10.) parser.add_argument("--ent-coef", type=float, default=10.)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--watch', "--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args() return parser.parse_args()
@@ -118,19 +127,36 @@ def test_fqf(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'fqf')
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "fqf"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -176,7 +202,7 @@ def test_fqf(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -207,5 +233,5 @@ def test_fqf(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_fqf(get_args()) test_fqf(get_args())

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import IQNPolicy from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import ImplicitQuantileNetwork from tianshou.utils.net.discrete import ImplicitQuantileNetwork
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=1234) parser.add_argument("--seed", type=int, default=1234)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--sample-size', type=int, default=32) parser.add_argument("--sample-size", type=int, default=32)
parser.add_argument('--online-sample-size', type=int, default=8) parser.add_argument("--online-sample-size", type=int, default=8)
parser.add_argument('--target-sample-size', type=int, default=8) parser.add_argument("--target-sample-size", type=int, default=8)
parser.add_argument('--num-cosines', type=int, default=64) parser.add_argument("--num-cosines", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--watch', "--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args() return parser.parse_args()
@@ -113,19 +122,36 @@ def test_iqn(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'iqn')
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "iqn"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -171,7 +197,7 @@ def test_iqn(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -202,5 +228,5 @@ def test_iqn(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_iqn(get_args()) test_iqn(get_args())

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -19,69 +20,70 @@ from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=4213) parser.add_argument("--seed", type=int, default=4213)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=5e-5) parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=1000) parser.add_argument("--step-per-collect", type=int, default=1000)
parser.add_argument('--repeat-per-collect', type=int, default=4) parser.add_argument("--repeat-per-collect", type=int, default=4)
parser.add_argument('--batch-size', type=int, default=256) parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument('--hidden-size', type=int, default=512) parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--rew-norm', type=int, default=False) parser.add_argument("--rew-norm", type=int, default=False)
parser.add_argument('--vf-coef', type=float, default=0.5) parser.add_argument("--vf-coef", type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.01) parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument('--gae-lambda', type=float, default=0.95) parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument('--lr-decay', type=int, default=True) parser.add_argument("--lr-decay", type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5) parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2) parser.add_argument("--eps-clip", type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None) parser.add_argument("--dual-clip", type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0) parser.add_argument("--value-clip", type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=1) parser.add_argument("--norm-adv", type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0) parser.add_argument("--recompute-adv", type=int, default=0)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None) parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--logger', "--logger",
type=str, type=str,
default="tensorboard", default="tensorboard",
choices=["tensorboard", "wandb"], choices=["tensorboard", "wandb"],
) )
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument( parser.add_argument(
'--watch', "--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument( parser.add_argument(
'--icm-lr-scale', "--icm-lr-scale",
type=float, type=float,
default=0., default=0.,
help='use intrinsic curiosity module with this lr scale' help="use intrinsic curiosity module with this lr scale"
) )
parser.add_argument( parser.add_argument(
'--icm-reward-scale', "--icm-reward-scale",
type=float, type=float,
default=0.01, default=0.01,
help='scaling factor for intrinsic curiosity reward' help="scaling factor for intrinsic curiosity reward"
) )
parser.add_argument( parser.add_argument(
'--icm-forward-loss-weight', "--icm-forward-loss-weight",
type=float, type=float,
default=0.2, default=0.2,
help='weight for the forward model loss in ICM' help="weight for the forward model loss in ICM"
) )
return parser.parse_args() return parser.parse_args()
@@ -184,37 +186,44 @@ def test_ppo(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
-    log_path = os.path.join(args.logdir, args.task, log_name)
-    if args.logger == "tensorboard":
-        writer = SummaryWriter(log_path)
-        writer.add_text("args", str(args))
-        logger = TensorboardLogger(writer)
-    else:
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
         logger = WandbLogger(
             save_interval=1,
-            project=args.task,
-            name=log_name,
+            name=log_name.replace(os.path.sep, "__"),
             run_id=args.resume_id,
             config=args,
+            project=args.wandb_project,
         )
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
-        torch.save({'model': policy.state_dict()}, ckpt_path)
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        torch.save({"model": policy.state_dict()}, ckpt_path)
         return ckpt_path

     # watch agent's performance
@ -243,7 +252,7 @@ def test_ppo(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -274,5 +283,5 @@ def test_ppo(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_ppo(get_args()) test_ppo(get_args())

View File

@ -1,4 +1,5 @@
import argparse import argparse
import datetime
import os import os
import pprint import pprint
@ -6,7 +7,7 @@ import numpy as np
import torch import torch
from atari_network import QRDQN from atari_network import QRDQN
from atari_wrapper import make_atari_env from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter, WandbLogger
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import QRDQNPolicy from tianshou.policy import QRDQNPolicy
@ -16,39 +17,47 @@ from tianshou.utils import TensorboardLogger
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200) parser.add_argument("--num-quantiles", type=int, default=200)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--watch', "--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args() return parser.parse_args()
@@ -97,19 +106,36 @@ def test_qrdqn(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "qrdqn"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -155,7 +181,7 @@ def test_qrdqn(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -186,5 +212,5 @@ def test_qrdqn(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_qrdqn(get_args()) test_qrdqn(get_args())

View File

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from atari_network import Rainbow from atari_network import Rainbow
from atari_wrapper import make_atari_env from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter, WandbLogger
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.policy import RainbowPolicy from tianshou.policy import RainbowPolicy
@ -17,50 +17,58 @@ from tianshou.utils import TensorboardLogger
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument('--seed', type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0) parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005) parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.) parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0000625) parser.add_argument("--lr", type=float, default=0.0000625)
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51) parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.) parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.) parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument('--noisy-std', type=float, default=0.1) parser.add_argument("--noisy-std", type=float, default=0.1)
parser.add_argument('--no-dueling', action='store_true', default=False) parser.add_argument("--no-dueling", action="store_true", default=False)
parser.add_argument('--no-noisy', action='store_true', default=False) parser.add_argument("--no-noisy", action="store_true", default=False)
parser.add_argument('--no-priority', action='store_true', default=False) parser.add_argument("--no-priority", action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.5) parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.4) parser.add_argument("--beta", type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.) parser.add_argument("--beta-final", type=float, default=1.)
parser.add_argument('--beta-anneal-step', type=int, default=5000000) parser.add_argument("--beta-anneal-step", type=int, default=5000000)
parser.add_argument('--no-weight-norm', action='store_true', default=False) parser.add_argument("--no-weight-norm", action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3) parser.add_argument("--n-step", type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument('--epoch', type=int, default=100) parser.add_argument("--epoch", type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--training-num', type=int, default=10) parser.add_argument("--training-num", type=int, default=10)
parser.add_argument('--test-num', type=int, default=10) parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--logdir', type=str, default='log') parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.) parser.add_argument("--render", type=float, default=0.)
parser.add_argument( parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
) )
parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None) parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument( parser.add_argument(
'--watch', "--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False, default=False,
action='store_true', action="store_true",
help='watch the play of pre-trained policy only' help="watch the play of pre-trained policy only"
) )
parser.add_argument('--save-buffer-name', type=str, default=None) parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args() return parser.parse_args()
@@ -131,22 +139,36 @@ def test_rainbow(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(
-        args.logdir, args.task, 'rainbow',
-        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
-    )
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "rainbow"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)

     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -203,7 +225,7 @@ def test_rainbow(args=get_args()):
n_episode=args.test_num, render=args.render n_episode=args.test_num, render=args.render
) )
rew = result["rews"].mean() rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
if args.watch: if args.watch:
watch() watch()
@ -234,5 +256,5 @@ def test_rainbow(args=get_args()):
watch() watch()
if __name__ == '__main__': if __name__ == "__main__":
test_rainbow(get_args()) test_rainbow(get_args())

View File

@@ -79,11 +79,14 @@ def test_psrl(args=get_args()):
         logger = WandbLogger(
             save_interval=1, project='psrl', name='wandb_test', config=args
         )
-    elif args.logger == "tensorboard":
+    if args.logger != "none":
         log_path = os.path.join(args.logdir, args.task, 'psrl')
         writer = SummaryWriter(log_path)
         writer.add_text("args", str(args))
-        logger = TensorboardLogger(writer)
+        if args.logger == "tensorboard":
+            logger = TensorboardLogger(writer)
+        else:
+            logger.load(writer)
     else:
         logger = LazyLogger()

View File

@@ -2,7 +2,9 @@ import argparse
 import os
 from typing import Callable, Optional, Tuple

-from tianshou.utils import BaseLogger
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils import BaseLogger, TensorboardLogger
 from tianshou.utils.logger.base import LOG_DATA_TYPE

 try:
@@ -17,15 +19,11 @@ class WandbLogger(BaseLogger):
     This logger creates three panels with plots: train, test, and update.
     Make sure to select the correct access for each panel in weights and biases:

-    - ``train/env_step`` for train plots
-    - ``test/env_step`` for test plots
-    - ``update/gradient_step`` for update plots
-
     Example of usage:
     ::

-        with wandb.init(project="My Project"):
-            logger = WandBLogger()
-            result = onpolicy_trainer(policy, train_collector, test_collector,
-                                      logger=logger)
+        logger = WandbLogger()
+        logger.load(SummaryWriter(log_path))
+        result = onpolicy_trainer(policy, train_collector, test_collector,
+                                  logger=logger)
@@ -46,7 +44,7 @@ class WandbLogger(BaseLogger):
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1000,
-        project: str = 'tianshou',
+        project: Optional[str] = None,
         name: Optional[str] = None,
         entity: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -56,6 +54,8 @@ class WandbLogger(BaseLogger):
         self.last_save_step = -1
         self.save_interval = save_interval
         self.restored = False
+        if project is None:
+            project = os.getenv("WANDB_PROJECT", "tianshou")

         self.wandb_run = wandb.init(
             project=project,
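
Since `project` is now optional, the constructor reads the `WANDB_PROJECT` environment variable when no project is passed and only then falls back to "tianshou", so the project can be chosen per shell instead of per script. A small usage sketch (the project name below is made up):

    import os

    from tianshou.utils import WandbLogger

    os.environ.setdefault("WANDB_PROJECT", "my-atari-runs")
    logger = WandbLogger()  # project resolves to WANDB_PROJECT, else "tianshou"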
@@ -63,14 +63,25 @@ class WandbLogger(BaseLogger):
             id=run_id,
             resume="allow",
             entity=entity,
+            sync_tensorboard=True,
             monitor_gym=True,
             config=config,  # type: ignore
         ) if not wandb.run else wandb.run
         self.wandb_run._label(repo="tianshou")  # type: ignore
+        self.tensorboard_logger: Optional[TensorboardLogger] = None
+
+    def load(self, writer: SummaryWriter) -> None:
+        self.writer = writer
+        self.tensorboard_logger = TensorboardLogger(writer)

     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
-        data[step_type] = step
-        wandb.log(data)
+        if self.tensorboard_logger is None:
+            raise Exception(
+                "`logger` needs to load the Tensorboard Writer before "
+                "writing data. Try `logger.load(SummaryWriter(log_path))`"
+            )
+        else:
+            self.tensorboard_logger.write(step_type, step, data)

     def save_data(
         self,