fix logger.write error in atari script (#444)

- fix a bug in #427: logger.write should pass a dict
- change SubprocVectorEnv to ShmemVectorEnv in atari
- increase logger interval for eps
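The logging fix is the same in every script below: TensorboardLogger.write(step_type, step, data) expects data to be a dict keyed by metric name, and the call is now issued only once every 1000 environment steps. A minimal sketch of the corrected train_fn hook, written (as in the example scripts) as a closure over args, policy, and logger; the epsilon schedule itself is unchanged by this commit:

    def train_fn(epoch, env_step):
        # nature-DQN style schedule used in these scripts:
        # linear decay over the first 1M steps, then a constant floor
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        # before (regression from #427): logger.write('train/eps', env_step, eps)
        # after: pass a dict, and only log every 1000 env steps
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

The SubprocVectorEnv -> ShmemVectorEnv swap is mechanical: ShmemVectorEnv takes the same list of environment factories but exchanges observations through shared memory instead of pipes, which avoids copying the image observations between worker processes.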
Jiayi Weng 2021-09-09 00:51:39 +08:00 committed by GitHub
parent fc251ab0b8
commit e8f8cdfa41
11 changed files with 44 additions and 35 deletions

.gitignore

@@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
 *.pkl
 *.hdf5
 wandb/
+videos/

examples/atari/atari_bcq.py

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed

examples/atari/atari_c51.py

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -75,10 +75,10 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -141,7 +141,8 @@ def test_c51(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/atari/atari_cql.py

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed

examples/atari/atari_crr.py

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed

examples/atari/atari_dqn.py

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -135,7 +135,8 @@ def test_dqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/atari/atari_fqf.py

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import FQFPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -158,7 +158,8 @@ def test_fqf(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/atari/atari_iqn.py

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import IQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -153,7 +153,8 @@ def test_iqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/atari/atari_qrdqn.py

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import QRDQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -137,7 +137,8 @@ def test_qrdqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/atari/atari_rainbow.py

@@ -10,7 +10,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import RainbowPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -174,7 +174,8 @@ def test_rainbow(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
         if not args.no_priority:
             if env_step <= args.beta_anneal_step:
                 beta = args.beta - env_step / args.beta_anneal_step * \
@@ -182,7 +183,8 @@ def test_rainbow(args=get_args()):
             else:
                 beta = args.beta_final
             buffer.set_beta(beta)
-            logger.write('train/beta', env_step, beta)
+            if env_step % 1000 == 0:
+                logger.write("train/env_step", env_step, {"train/beta": beta})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

examples/vizdoom/vizdoom_c51.py

@@ -9,7 +9,7 @@ from network import C51
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,13 +72,13 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res)
             for _ in range(args.training_num)
         ]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
             for _ in range(min(os.cpu_count() - 1, args.test_num))
@@ -144,7 +144,8 @@ def test_c51(args=get_args()):
         else:
            eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)