fix logger.write error in atari script (#444)
- fix a bug in #427: logger.write should pass a dict - change SubprocVectorEnv to ShmemVectorEnv in atari - increase logger interval for eps
This commit is contained in:
parent
fc251ab0b8
commit
e8f8cdfa41
1
.gitignore
vendored
1
.gitignore
vendored
@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
|
||||
*.pkl
|
||||
*.hdf5
|
||||
wandb/
|
||||
videos/
|
||||
|
@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import DiscreteBCQPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
|
@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import C51Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -75,10 +75,10 @@ def test_c51(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -141,7 +141,8 @@ def test_c51(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import DiscreteCQLPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
|
@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import DiscreteCRRPolicy
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
|
@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -135,7 +135,8 @@ def test_dqn(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import FQFPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -158,7 +158,8 @@ def test_fqf(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import IQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -153,7 +153,8 @@ def test_iqn(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import QRDQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -137,7 +137,8 @@ def test_qrdqn(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -10,7 +10,7 @@ from atari_wrapper import wrap_deepmind
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import RainbowPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env(args) for _ in range(args.training_num)]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
|
||||
)
|
||||
# seed
|
||||
@ -174,7 +174,8 @@ def test_rainbow(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
if not args.no_priority:
|
||||
if env_step <= args.beta_anneal_step:
|
||||
beta = args.beta - env_step / args.beta_anneal_step * \
|
||||
@ -182,7 +183,8 @@ def test_rainbow(args=get_args()):
|
||||
else:
|
||||
beta = args.beta_final
|
||||
buffer.set_beta(beta)
|
||||
logger.write('train/beta', env_step, beta)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/beta": beta})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
@ -9,7 +9,7 @@ from network import C51
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import C51Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
@ -72,13 +72,13 @@ def test_c51(args=get_args()):
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv(
|
||||
train_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: Env(args.cfg_path, args.frames_stack, args.res)
|
||||
for _ in range(args.training_num)
|
||||
]
|
||||
)
|
||||
test_envs = SubprocVectorEnv(
|
||||
test_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
|
||||
for _ in range(min(os.cpu_count() - 1, args.test_num))
|
||||
@ -144,7 +144,8 @@ def test_c51(args=get_args()):
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
Loading…
x
Reference in New Issue
Block a user