Update WandbLogger implementation (#558)

* Use `global_step` as the x-axis for wandb
* Use the TensorBoard `SummaryWriter` as the core writer via `wandb.init(..., sync_tensorboard=True)`
* Update all Atari examples to support wandb

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>

Parent: 2377f2f186, commit: df3d7f582b
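For context, a minimal sketch of the mechanism the second bullet relies on; none of it is taken from the diff, and the project name, log directory, and tag are placeholders. With ``sync_tensorboard=True``, wandb mirrors everything written to a TensorBoard ``SummaryWriter``, so the writer's ``global_step`` becomes the x-axis of the wandb charts:

::

    # Hypothetical minimal example, not part of this commit (requires a wandb login).
    import wandb
    from torch.utils.tensorboard import SummaryWriter

    run = wandb.init(project="sync-tb-demo", sync_tensorboard=True)  # placeholder project
    writer = SummaryWriter("log/sync-tb-demo")  # placeholder log directory

    for global_step in range(100):
        # Every scalar written here is mirrored to wandb with global_step as the x-axis.
        writer.add_scalar("train/reward", 0.1 * global_step, global_step)

    writer.close()
    run.finish()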
@@ -34,8 +34,12 @@ WandbLogger
::

    from tianshou.utils import WandbLogger
    from torch.utils.tensorboard import SummaryWriter

    logger = WandbLogger(...)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger.load(writer)
    result = trainer(..., logger=logger)

Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration.
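To make the "advanced configuration" pointer above concrete, here is a hedged sketch using the constructor arguments this commit adds or changes (``project``, ``name``, ``run_id``, ``config``, ``save_interval``; see the wandb logger diff at the bottom of this commit). All literal values are placeholders, not defaults from the source:

::

    from tianshou.utils import WandbLogger
    from torch.utils.tensorboard import SummaryWriter

    log_path = "log/wandb_demo"             # placeholder directory
    logger = WandbLogger(
        save_interval=1,
        project="atari.benchmark",          # falls back to $WANDB_PROJECT, then "tianshou"
        name="PongNoFrameskip-v4__dqn__0",  # placeholder run name
        run_id=None,                        # pass an existing wandb run id to resume
        config={"seed": 0},                 # placeholder config object
    )
    writer = SummaryWriter(log_path)
    writer.add_text("args", "placeholder args")
    logger.load(writer)
    result = trainer(..., logger=logger)    # same trainer call as in the snippet above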
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -11,46 +12,54 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.policy import C51Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-atoms', type=int, default=51)
|
||||
parser.add_argument('--v-min', type=float, default=-10.)
|
||||
parser.add_argument('--v-max', type=float, default=10.)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=0.0001)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--num-atoms", type=int, default=51)
|
||||
parser.add_argument("--v-min", type=float, default=-10.)
|
||||
parser.add_argument("--v-max", type=float, default=10.)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -101,19 +110,36 @@ def test_c51(args=get_args()):
     # collector
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
 
     # log
-    log_path = os.path.join(args.logdir, args.task, 'c51')
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "c51"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = TensorboardLogger(writer)
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)
 
     def save_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 
     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
-        elif 'Pong' in args.task:
+        elif "Pong" in args.task:
             return mean_rewards >= 20
         else:
             return False
@ -159,7 +185,7 @@ def test_c51(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -190,5 +216,5 @@ def test_c51(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_c51(get_args())
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -18,62 +19,63 @@ from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=0.0001)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--resume-id', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--logger',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--icm-lr-scale',
|
||||
"--icm-lr-scale",
|
||||
type=float,
|
||||
default=0.,
|
||||
help='use intrinsic curiosity module with this lr scale'
|
||||
help="use intrinsic curiosity module with this lr scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--icm-reward-scale',
|
||||
"--icm-reward-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help='scaling factor for intrinsic curiosity reward'
|
||||
help="scaling factor for intrinsic curiosity reward"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--icm-forward-loss-weight',
|
||||
"--icm-forward-loss-weight",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help='weight for the forward model loss in ICM'
|
||||
help="weight for the forward model loss in ICM"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -140,29 +142,36 @@ def test_dqn(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
|
||||
log_path = os.path.join(args.logdir, args.task, log_name)
|
||||
if args.logger == "tensorboard":
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
else:
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
project=args.task,
|
||||
name=log_name,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
@ -183,8 +192,8 @@ def test_dqn(args=get_args()):
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
torch.save({'model': policy.state_dict()}, ckpt_path)
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
torch.save({"model": policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
# watch agent's performance
|
||||
@ -214,7 +223,7 @@ def test_dqn(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -247,5 +256,5 @@ def test_dqn(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_dqn(get_args())
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.policy import FQFPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=3128)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=5e-5)
|
||||
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-fractions', type=int, default=32)
|
||||
parser.add_argument('--num-cosines', type=int, default=64)
|
||||
parser.add_argument('--ent-coef', type=float, default=10.)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=3128)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=5e-5)
|
||||
parser.add_argument("--fraction-lr", type=float, default=2.5e-9)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--num-fractions", type=int, default=32)
|
||||
parser.add_argument("--num-cosines", type=int, default=64)
|
||||
parser.add_argument("--ent-coef", type=float, default=10.)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -118,19 +127,36 @@ def test_fqf(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "fqf"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
@ -176,7 +202,7 @@ def test_fqf(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -207,5 +233,5 @@ def test_fqf(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_fqf(get_args())
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -11,49 +12,57 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.policy import IQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=1234)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--sample-size', type=int, default=32)
|
||||
parser.add_argument('--online-sample-size', type=int, default=8)
|
||||
parser.add_argument('--target-sample-size', type=int, default=8)
|
||||
parser.add_argument('--num-cosines', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=0.0001)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--sample-size", type=int, default=32)
|
||||
parser.add_argument("--online-sample-size", type=int, default=8)
|
||||
parser.add_argument("--target-sample-size", type=int, default=8)
|
||||
parser.add_argument("--num-cosines", type=int, default=64)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512])
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -113,19 +122,36 @@ def test_iqn(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'iqn')
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "iqn"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
@ -171,7 +197,7 @@ def test_iqn(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -202,5 +228,5 @@ def test_iqn(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_iqn(get_args())
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -19,69 +20,70 @@ from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=4213)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=5e-5)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1000)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=4)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--hidden-size', type=int, default=512)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--rew-norm', type=int, default=False)
|
||||
parser.add_argument('--vf-coef', type=float, default=0.5)
|
||||
parser.add_argument('--ent-coef', type=float, default=0.01)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--lr-decay', type=int, default=True)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--dual-clip', type=float, default=None)
|
||||
parser.add_argument('--value-clip', type=int, default=0)
|
||||
parser.add_argument('--norm-adv', type=int, default=1)
|
||||
parser.add_argument('--recompute-adv', type=int, default=0)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=4213)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=5e-5)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=1000)
|
||||
parser.add_argument("--repeat-per-collect", type=int, default=4)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--hidden-size", type=int, default=512)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--rew-norm", type=int, default=False)
|
||||
parser.add_argument("--vf-coef", type=float, default=0.5)
|
||||
parser.add_argument("--ent-coef", type=float, default=0.01)
|
||||
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
||||
parser.add_argument("--lr-decay", type=int, default=True)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--eps-clip", type=float, default=0.2)
|
||||
parser.add_argument("--dual-clip", type=float, default=None)
|
||||
parser.add_argument("--value-clip", type=int, default=0)
|
||||
parser.add_argument("--norm-adv", type=int, default=1)
|
||||
parser.add_argument("--recompute-adv", type=int, default=0)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--resume-id', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--logger',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--icm-lr-scale',
|
||||
"--icm-lr-scale",
|
||||
type=float,
|
||||
default=0.,
|
||||
help='use intrinsic curiosity module with this lr scale'
|
||||
help="use intrinsic curiosity module with this lr scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--icm-reward-scale',
|
||||
"--icm-reward-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help='scaling factor for intrinsic curiosity reward'
|
||||
help="scaling factor for intrinsic curiosity reward"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--icm-forward-loss-weight',
|
||||
"--icm-forward-loss-weight",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help='weight for the forward model loss in ICM'
|
||||
help="weight for the forward model loss in ICM"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@ -184,37 +186,44 @@ def test_ppo(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
|
||||
log_path = os.path.join(args.logdir, args.task, log_name)
|
||||
if args.logger == "tensorboard":
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
else:
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
project=args.task,
|
||||
name=log_name,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
torch.save({'model': policy.state_dict()}, ckpt_path)
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
torch.save({"model": policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
# watch agent's performance
|
||||
@ -243,7 +252,7 @@ def test_ppo(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -274,5 +283,5 @@ def test_ppo(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_ppo(get_args())
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@ -6,7 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from atari_network import QRDQN
|
||||
from atari_wrapper import make_atari_env
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.policy import QRDQNPolicy
|
||||
@ -16,39 +17,47 @@ from tianshou.utils import TensorboardLogger
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-quantiles', type=int, default=200)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=0.0001)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--num-quantiles", type=int, default=200)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -97,19 +106,36 @@ def test_qrdqn(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "qrdqn"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
@ -155,7 +181,7 @@ def test_qrdqn(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -186,5 +212,5 @@ def test_qrdqn(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_qrdqn(get_args())
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from atari_network import Rainbow
|
||||
from atari_wrapper import make_atari_env
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
|
||||
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.policy import RainbowPolicy
|
||||
@ -17,50 +17,58 @@ from tianshou.utils import TensorboardLogger
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--scale-obs', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0000625)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-atoms', type=int, default=51)
|
||||
parser.add_argument('--v-min', type=float, default=-10.)
|
||||
parser.add_argument('--v-max', type=float, default=10.)
|
||||
parser.add_argument('--noisy-std', type=float, default=0.1)
|
||||
parser.add_argument('--no-dueling', action='store_true', default=False)
|
||||
parser.add_argument('--no-noisy', action='store_true', default=False)
|
||||
parser.add_argument('--no-priority', action='store_true', default=False)
|
||||
parser.add_argument('--alpha', type=float, default=0.5)
|
||||
parser.add_argument('--beta', type=float, default=0.4)
|
||||
parser.add_argument('--beta-final', type=float, default=1.)
|
||||
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
|
||||
parser.add_argument('--no-weight-norm', action='store_true', default=False)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.005)
|
||||
parser.add_argument("--eps-train", type=float, default=1.)
|
||||
parser.add_argument("--eps-train-final", type=float, default=0.05)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=0.0000625)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--num-atoms", type=int, default=51)
|
||||
parser.add_argument("--v-min", type=float, default=-10.)
|
||||
parser.add_argument("--v-max", type=float, default=10.)
|
||||
parser.add_argument("--noisy-std", type=float, default=0.1)
|
||||
parser.add_argument("--no-dueling", action="store_true", default=False)
|
||||
parser.add_argument("--no-noisy", action="store_true", default=False)
|
||||
parser.add_argument("--no-priority", action="store_true", default=False)
|
||||
parser.add_argument("--alpha", type=float, default=0.5)
|
||||
parser.add_argument("--beta", type=float, default=0.4)
|
||||
parser.add_argument("--beta-final", type=float, default=1.)
|
||||
parser.add_argument("--beta-anneal-step", type=int, default=5000000)
|
||||
parser.add_argument("--no-weight-norm", action="store_true", default=False)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='watch the play of pre-trained policy only'
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -131,22 +139,36 @@ def test_rainbow(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_path = os.path.join(
|
||||
args.logdir, args.task, 'rainbow',
|
||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
|
||||
)
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "rainbow"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
@ -203,7 +225,7 @@ def test_rainbow(args=get_args()):
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
@ -234,5 +256,5 @@ def test_rainbow(args=get_args()):
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_rainbow(get_args())
|
||||
|
@@ -79,11 +79,14 @@ def test_psrl(args=get_args()):
         logger = WandbLogger(
             save_interval=1, project='psrl', name='wandb_test', config=args
         )
-    elif args.logger == "tensorboard":
+    if args.logger != "none":
         log_path = os.path.join(args.logdir, args.task, 'psrl')
         writer = SummaryWriter(log_path)
         writer.add_text("args", str(args))
-        logger = TensorboardLogger(writer)
+        if args.logger == "tensorboard":
+            logger = TensorboardLogger(writer)
+        else:
+            logger.load(writer)
     else:
         logger = LazyLogger()
 
@@ -2,7 +2,9 @@ import argparse
 import os
 from typing import Callable, Optional, Tuple
 
-from tianshou.utils import BaseLogger
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils import BaseLogger, TensorboardLogger
 from tianshou.utils.logger.base import LOG_DATA_TYPE
 
 try:
@@ -17,17 +19,13 @@ class WandbLogger(BaseLogger):
     This logger creates three panels with plots: train, test, and update.
     Make sure to select the correct x-axis for each panel in Weights & Biases:
 
     - ``train/env_step`` for train plots
     - ``test/env_step`` for test plots
     - ``update/gradient_step`` for update plots
 
     Example of usage:
     ::
 
-        with wandb.init(project="My Project"):
-            logger = WandBLogger()
-            result = onpolicy_trainer(policy, train_collector, test_collector,
-                                      logger=logger)
+        logger = WandbLogger()
+        logger.load(SummaryWriter(log_path))
+        result = onpolicy_trainer(policy, train_collector, test_collector,
+                                  logger=logger)
 
     :param int train_interval: the log interval in log_train_data(). Default to 1000.
     :param int test_interval: the log interval in log_test_data(). Default to 1.
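Side note, not part of this diff: the panel-to-x-axis pairing described in the docstring above can also be pinned programmatically through wandb's metric API rather than selected in the web UI. A hedged sketch with a placeholder project name:

::

    # Hypothetical example; wandb.define_metric is a general W&B feature,
    # not something this commit uses.
    import wandb

    wandb.init(project="tianshou")  # placeholder project
    for group, step_metric in [
        ("train", "train/env_step"),
        ("test", "test/env_step"),
        ("update", "update/gradient_step"),
    ]:
        wandb.define_metric(step_metric)
        wandb.define_metric(f"{group}/*", step_metric=step_metric)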
@@ -46,7 +44,7 @@ class WandbLogger(BaseLogger):
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1000,
-        project: str = 'tianshou',
+        project: Optional[str] = None,
         name: Optional[str] = None,
         entity: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -56,6 +54,8 @@ class WandbLogger(BaseLogger):
         self.last_save_step = -1
         self.save_interval = save_interval
         self.restored = False
+        if project is None:
+            project = os.getenv("WANDB_PROJECT", "tianshou")
 
         self.wandb_run = wandb.init(
             project=project,
@@ -63,14 +63,25 @@ class WandbLogger(BaseLogger):
             id=run_id,
             resume="allow",
             entity=entity,
+            sync_tensorboard=True,
+            monitor_gym=True,
             config=config,  # type: ignore
         ) if not wandb.run else wandb.run
         self.wandb_run._label(repo="tianshou")  # type: ignore
+        self.tensorboard_logger: Optional[TensorboardLogger] = None
+
+    def load(self, writer: SummaryWriter) -> None:
+        self.writer = writer
+        self.tensorboard_logger = TensorboardLogger(writer)
 
     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
-        data[step_type] = step
-        wandb.log(data)
+        if self.tensorboard_logger is None:
+            raise Exception(
+                "`logger` needs to load the Tensorboard Writer before "
+                "writing data. Try `logger.load(SummaryWriter(log_path))`"
+            )
+        else:
+            self.tensorboard_logger.write(step_type, step, data)
 
     def save_data(
         self,