fix logger.write error in atari script (#444)

- fix a bug in #427: logger.write should pass a dict
- change SubprocVectorEnv to ShmemVectorEnv in atari
- increase logger interval for eps
Jiayi Weng 2021-09-09 00:51:39 +08:00 committed by GitHub
parent fc251ab0b8
commit e8f8cdfa41
11 changed files with 44 additions and 35 deletions
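
For context: TensorboardLogger.write takes a step name, a step number, and a dict of scalar values, so passing a bare float (as the old logger.write('train/eps', env_step, eps) call did) breaks when the logger iterates the dict's items. ShmemVectorEnv replaces SubprocVectorEnv because it passes observations through shared memory instead of pipes, which is cheaper for stacked Atari frames. A minimal sketch of the corrected call, assuming only the TensorboardLogger API used by these scripts (the log directory and the example values are made up):

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/demo")   # hypothetical log directory
logger = TensorboardLogger(writer)

env_step, eps = 1000, 0.05           # example values

# old, broken: logger.write('train/eps', env_step, eps)  -- scalar, not a dict
# fixed: wrap the value in a dict and log only every 1000 env steps
if env_step % 1000 == 0:
    logger.write("train/env_step", env_step, {"train/eps": eps})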

.gitignore

@@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
*.pkl
*.hdf5
wandb/
+videos/

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -75,10 +75,10 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -141,7 +141,8 @@ def test_c51(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
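
The hunks above show only the tail of train_fn. As a self-contained sketch of the full epsilon schedule these scripts use: make_train_fn is a hypothetical factory, and the decay horizon and default epsilon values are assumptions here (the scripts read them from args.eps_train and args.eps_train_final).

def make_train_fn(policy, logger, eps_train=1.0, eps_train_final=0.05,
                  decay_steps=1_000_000, log_interval=1000):
    """Return the exploration-annealing callback passed to offpolicy_trainer."""
    def train_fn(epoch, env_step):
        # linearly decay epsilon over the first `decay_steps` env steps,
        # then hold it at the final value
        if env_step <= decay_steps:
            eps = eps_train - env_step / decay_steps * (eps_train - eps_train_final)
        else:
            eps = eps_train_final
        policy.set_eps(eps)
        # log at a fixed interval, passing a dict as logger.write expects
        if env_step % log_interval == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
    return train_fn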

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed

@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -135,7 +135,8 @@ def test_dqn(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -158,7 +158,8 @@ def test_fqf(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -153,7 +153,8 @@ def test_iqn(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -137,7 +137,8 @@ def test_qrdqn(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

@@ -10,7 +10,7 @@ from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import RainbowPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -174,7 +174,8 @@ def test_rainbow(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
if not args.no_priority:
if env_step <= args.beta_anneal_step:
beta = args.beta - env_step / args.beta_anneal_step * \
@@ -182,7 +183,8 @@ def test_rainbow(args=get_args()):
else:
beta = args.beta_final
buffer.set_beta(beta)
-logger.write('train/beta', env_step, beta)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/beta": beta})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
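
The second Rainbow hunk cuts the beta expression off at the line-continuation backslash. A self-contained sketch of what that branch computes; make_beta_fn is a hypothetical helper and the default values are assumptions (the script takes them from args.beta, args.beta_final and args.beta_anneal_step).

def make_beta_fn(buffer, logger, beta=0.4, beta_final=1.0,
                 beta_anneal_step=5_000_000, log_interval=1000):
    """Return a callback that anneals the prioritized-replay beta inside train_fn."""
    def beta_fn(env_step):
        # linearly anneal beta toward beta_final, then hold it constant
        if env_step <= beta_anneal_step:
            b = beta - env_step / beta_anneal_step * (beta - beta_final)
        else:
            b = beta_final
        buffer.set_beta(b)  # importance-sampling exponent of the prioritized buffer
        if env_step % log_interval == 0:
            logger.write("train/env_step", env_step, {"train/beta": b})
    return beta_fn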

@@ -9,7 +9,7 @@ from network import C51
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -72,13 +72,13 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)
]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))
@@ -144,7 +144,8 @@ def test_c51(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)