fix logger.write error in atari script (#444)
- fix a bug in #427: logger.write should pass a dict
- change SubprocVectorEnv to ShmemVectorEnv in atari
- increase logger interval for eps
parent fc251ab0b8
commit e8f8cdfa41
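The substantive fix is the logger call inside each script's train_fn: Tianshou's TensorboardLogger.write expects a step type, the step value, and a dict of scalars, and the call is now made only every 1000 environment steps. A minimal sketch of the corrected pattern follows; it assumes the enclosing example script provides policy and args, and the make_train_fn wrapper and log directory are illustrative rather than part of the scripts themselves.

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/atari_example")  # hypothetical log directory
logger = TensorboardLogger(writer)


def make_train_fn(policy, args, log_interval=1000):
    """Build the per-step hook that the example scripts pass to the trainer."""

    def train_fn(epoch, env_step):
        # linear epsilon decay over the first 1M env steps, then a constant floor
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        # logger.write takes (step_type, step, data_dict); throttle to every
        # log_interval env steps instead of logging on every call
        if env_step % log_interval == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    return train_fn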
.gitignore
@@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
 *.pkl
 *.hdf5
 wandb/
+videos/
@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
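The same SubprocVectorEnv to ShmemVectorEnv swap is applied in every script below. Both classes take a list of zero-argument environment factories and expose the same vectorized interface, so the change is a drop-in replacement; ShmemVectorEnv additionally keeps observations in shared-memory buffers, which should cut inter-process copying for Atari image frames. A minimal usage sketch with the same names the scripts use (make_atari_env_watch and args come from the example scripts):

from tianshou.env import ShmemVectorEnv

# any zero-argument callable returning a gym.Env works as a factory
test_envs = ShmemVectorEnv(
    [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
obs = test_envs.reset()  # batched observations backed by shared memory
test_envs.close()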
@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -75,10 +75,10 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -141,7 +141,8 @@ def test_c51(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCQLPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -11,7 +11,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DiscreteCRRPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -135,7 +135,8 @@ def test_dqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import FQFPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -158,7 +158,8 @@ def test_fqf(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import IQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -153,7 +153,8 @@ def test_iqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -9,7 +9,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import QRDQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -137,7 +137,8 @@ def test_qrdqn(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -10,7 +10,7 @@ from atari_wrapper import wrap_deepmind
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import RainbowPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [lambda: make_atari_env(args) for _ in range(args.training_num)]
    )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
     )
     # seed
@@ -174,7 +174,8 @@ def test_rainbow(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
         if not args.no_priority:
             if env_step <= args.beta_anneal_step:
                 beta = args.beta - env_step / args.beta_anneal_step * \
@@ -182,7 +183,8 @@ def test_rainbow(args=get_args()):
             else:
                 beta = args.beta_final
             buffer.set_beta(beta)
-            logger.write('train/beta', env_step, beta)
+            if env_step % 1000 == 0:
+                logger.write("train/env_step", env_step, {"train/beta": beta})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)
@@ -9,7 +9,7 @@ from network import C51
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
 from tianshou.policy import C51Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
@@ -72,13 +72,13 @@ def test_c51(args=get_args()):
     print("Observations shape:", args.state_shape)
     print("Actions shape:", args.action_shape)
     # make environments
-    train_envs = SubprocVectorEnv(
+    train_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res)
             for _ in range(args.training_num)
         ]
     )
-    test_envs = SubprocVectorEnv(
+    test_envs = ShmemVectorEnv(
         [
             lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
             for _ in range(min(os.cpu_count() - 1, args.test_num))
@@ -144,7 +144,8 @@ def test_c51(args=get_args()):
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        logger.write('train/eps', env_step, eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
 
     def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)