use envpool in vizdoom example, update doc (#634)

Jiayi Weng 2022-05-08 12:42:16 -04:00 committed by GitHub
parent 2a7c151738
commit bf8f63ffc3
8 changed files with 335 additions and 149 deletions

docs/api/tianshou.data.rst

@@ -88,3 +88,30 @@ AsyncCollector
:members:
:undoc-members:
:show-inheritance:
Utils
-----
to_numpy
~~~~~~~~
.. autofunction:: tianshou.data.to_numpy
to_torch
~~~~~~~~
.. autofunction:: tianshou.data.to_torch
to_torch_as
~~~~~~~~~~~
.. autofunction:: tianshou.data.to_torch_as
SegmentTree
~~~~~~~~~~~
.. autoclass:: tianshou.data.SegmentTree
:members:
:undoc-members:
:show-inheritance:

docs/api/tianshou.env.rst

@@ -46,6 +46,26 @@ RayVectorEnv
:show-inheritance:
Wrapper
-------
VectorEnvWrapper
~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.env.VectorEnvWrapper
:members:
:undoc-members:
:show-inheritance:
VectorEnvNormObs
~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.env.VectorEnvNormObs
:members:
:undoc-members:
:show-inheritance:
Worker
------
@@ -80,3 +100,15 @@ RayEnvWorker
:members:
:undoc-members:
:show-inheritance:
Utils
-----
PettingZooEnv
~~~~~~~~~~~~~
.. autoclass:: tianshou.env.PettingZooEnv
:members:
:undoc-members:
:show-inheritance:

docs/spelling_wordlist.txt

@@ -158,3 +158,4 @@ Enduro
Qbert
Seaquest
subnets
subprocesses

docs/tutorials/cheatsheet.rst

@@ -123,7 +123,11 @@ EnvPool Integration
`EnvPool <https://github.com/sail-sg/envpool/>`_ is a C++-based vectorized environment implementation and is considerably faster than the above solutions. Its API is almost the same as the above four classes, so you can directly switch the vectorized environment to envpool and get an immediate speed-up.
Currently it supports Atari, VizDoom, toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation <https://envpool.readthedocs.io/en/latest/>`_.
Currently it supports
`Atari <https://github.com/thu-ml/tianshou/tree/master/examples/atari#envpool>`_,
`Mujoco <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco#envpool>`_,
`VizDoom <https://github.com/thu-ml/tianshou/tree/master/examples/vizdoom#envpool>`_,
toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation <https://envpool.readthedocs.io/en/latest/>`_.
::
@@ -133,7 +137,7 @@ Currently it supports Atari, VizDoom, toy_text and classic_control environments.
envs = envpool.make_gym("CartPole-v0", num_envs=10)
collector = Collector(policy, envs, buffer)
Here are some examples: https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples
Here are some other `examples <https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples>`_.
.. _preprocess_fn:
@@ -177,7 +181,7 @@ For example, you can write your hook as:
self.episode_log[i].append(kwargs['rew'][i])
kwargs['rew'][i] -= self.baseline
for i in range(n):
if kwargs['done']:
if kwargs['done'][i]:
self.main_log.append(np.mean(self.episode_log[i]))
self.episode_log[i] = []
self.baseline = np.mean(self.main_log)
@@ -191,6 +195,40 @@ And finally,
Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
Another solution is to create a vector environment wrapper through :class:`~tianshou.env.VectorEnvWrapper`, e.g.
::
import numpy as np
from collections import deque
from tianshou.env import VectorEnvWrapper
class MyWrapper(VectorEnvWrapper):
def __init__(self, venv, size=100):
super().__init__(venv)  # initialize VectorEnvWrapper so that self.venv is set
self.episode_log = None
self.main_log = deque(maxlen=size)
self.main_log.append(0)
self.baseline = 0
def step(self, action, env_id):
obs, rew, done, info = self.venv.step(action, env_id)
n = len(rew)
if self.episode_log is None:
self.episode_log = [[] for i in range(n)]
for i in range(n):
self.episode_log[i].append(rew[i])
rew[i] -= self.baseline
for i in range(n):
if done[i]:
self.main_log.append(np.mean(self.episode_log[i]))
self.episode_log[i] = []
self.baseline = np.mean(self.main_log)
return obs, rew, done, info
env = MyWrapper(env, size=100)
collector = Collector(policy, env, buffer)
We provide an observation normalization vector env wrapper: :class:`~tianshou.env.VectorEnvNormObs`.
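A minimal usage sketch (assuming a CartPole setup; :class:`~tianshou.env.VectorEnvNormObs` is the class documented above, while the ``update_obs_rms`` flag and the ``get_obs_rms``/``set_obs_rms`` calls are assumed here as the way to freeze test-time statistics and copy them from the training envs):
::

    import gym
    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    # training envs keep updating the running mean/std of observations
    train_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
    )
    # test envs do not update the statistics; reuse the ones from training
    test_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(2)]),
        update_obs_rms=False,
    )
    test_envs.set_obs_rms(train_envs.get_obs_rms())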
.. _rnn_training:

examples/vizdoom/README.md

@@ -2,12 +2,24 @@
[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL environment based on the famous first-person shooter Doom. Here we provide some results and intuitions for this scenario.
## EnvPool
We highly recommend using envpool to run the following experiments. To install it on a Linux machine, run:
```bash
pip install envpool
```
After that, `make_vizdoom_env` will automatically switch to envpool's ViZDoom env. EnvPool's implementation is much faster (about 2\~3x for pure execution speed, 1.5x for the overall RL training pipeline) than the Python vectorized env implementation.
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/) and [Docs](https://envpool.readthedocs.io/en/latest/api/vizdoom.html).
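As a quick sanity check that the EnvPool backend works, you can create the envs directly (a hypothetical snippet; the task id `D1Basic-v1` follows the naming scheme used in `env.py`, which converts `D1_basic` into `D1Basic-v1`):

```python
import envpool

# create 8 parallel ViZDoom D1_basic environments through EnvPool's gym API
env = envpool.make_gym("D1Basic-v1", num_envs=8, stack_num=4)
obs = env.reset()
print(obs.shape)  # expect something like (8, 4, 84, 84)
```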
## Train
To train an agent:
```bash
python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2}
python3 vizdoom_c51.py --task {D1_basic|D2_navigation|D3_battle|D4_battle2}
```
D1 (health gathering) should finish training (no death) in less than 500k env steps (5 epochs);

examples/vizdoom/env.py

@@ -5,6 +5,13 @@ import gym
import numpy as np
import vizdoom as vzd
from tianshou.env import ShmemVectorEnv
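# envpool is optional: make_vizdoom_env below falls back to ShmemVectorEnv when it is not installed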
try:
import envpool
except ImportError:
envpool = None
def normal_button_comb():
actions = []
@@ -112,6 +119,58 @@ class Env(gym.Env):
self.game.close()
def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num):
test_num = min(os.cpu_count() - 1, test_num)
if envpool is not None:
task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1"
lmp_save_dir = "lmps/" if save_lmp else ""
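# reward shaping passed to EnvPool's ViZDoom env: each game variable maps to a
# [reward per unit increase, reward per unit decrease] pair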
reward_config = {
"KILLCOUNT": [20.0, -20.0],
"HEALTH": [1.0, 0.0],
"AMMO2": [1.0, -1.0],
}
if "battle" in task:
reward_config["HEALTH"] = [1.0, -1.0]
env = train_envs = envpool.make_gym(
task_id,
frame_skip=frame_skip,
stack_num=res[0],
seed=seed,
num_envs=training_num,
reward_config=reward_config,
use_combined_action=True,
max_episode_steps=2625,
use_inter_area_resize=False,
)
test_envs = envpool.make_gym(
task_id,
frame_skip=frame_skip,
stack_num=res[0],
lmp_save_dir=lmp_save_dir,
seed=seed,
num_envs=test_num,
reward_config=reward_config,
use_combined_action=True,
max_episode_steps=2625,
use_inter_area_resize=False,
)
else:
cfg_path = f"maps/{task}.cfg"
env = Env(cfg_path, frame_skip, res)
train_envs = ShmemVectorEnv(
[lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)]
)
test_envs = ShmemVectorEnv(
[
lambda: Env(cfg_path, frame_skip, res, save_lmp)
for _ in range(test_num)
]
)
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
if __name__ == '__main__':
# env = Env("maps/D1_basic.cfg", 4, (4, 84, 84))
env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))

examples/vizdoom/vizdoom_c51.py

@@ -1,94 +1,88 @@
import argparse
import datetime
import os
import pprint
import numpy as np
import torch
from env import Env
from env import make_vizdoom_env
from network import C51
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='D1_basic')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=2000000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="D1_basic")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=2000000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=300)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--skip-num', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--skip-num", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only",
)
parser.add_argument(
'--save-lmp',
"--save-lmp",
default=False,
action='store_true',
help='save lmp file for replay whole episode'
action="store_true",
help="save lmp file for replay whole episode",
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
def test_c51(args=get_args()):
args.cfg_path = f"maps/{args.task}.cfg"
args.wad_path = f"maps/{args.task}.wad"
args.res = (args.skip_num, 84, 84)
env = Env(args.cfg_path, args.frames_stack, args.res)
args.state_shape = args.res
# make environments
env, train_envs, test_envs = make_vizdoom_env(
args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp,
args.seed, args.training_num, args.test_num
)
args.state_shape = env.observation_space.shape
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)
]
)
test_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
@@ -101,7 +95,7 @@ def test_c51(args=get_args()):
args.v_min,
args.v_max,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# load a previous policy
if args.resume_path:
@@ -114,25 +108,40 @@ def test_c51(args=get_args()):
buffer_num=len(train_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "c51"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False
@@ -163,7 +172,7 @@ def test_c51(args=get_args()):
buffer_num=len(test_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
@@ -203,12 +212,12 @@ def test_c51(args=get_args()):
save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
)
pprint.pprint(result)
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_c51(get_args())

examples/vizdoom/vizdoom_ppo.py

@@ -1,126 +1,120 @@
import argparse
import datetime
import os
import pprint
import numpy as np
import torch
from env import Env
from env import make_vizdoom_env
from network import DQN
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='D2_navigation')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=1000)
parser.add_argument('--repeat-per-collect', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-size', type=int, default=512)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--rew-norm', type=int, default=False)
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.01)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="D1_basic")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.00002)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--epoch", type=int, default=300)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=1000)
parser.add_argument("--repeat-per-collect", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--rew-norm", type=int, default=False)
parser.add_argument("--vf-coef", type=float, default=0.5)
parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--lr-decay", type=int, default=True)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--eps-clip", type=float, default=0.2)
parser.add_argument("--dual-clip", type=float, default=None)
parser.add_argument("--value-clip", type=int, default=0)
parser.add_argument("--norm-adv", type=int, default=1)
parser.add_argument("--recompute-adv", type=int, default=0)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--skip-num', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--skip-num", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only",
)
parser.add_argument(
'--save-lmp',
"--save-lmp",
default=False,
action='store_true',
help='save lmp file for replay whole episode'
action="store_true",
help="save lmp file for replay whole episode",
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
"--icm-lr-scale",
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
help="use intrinsic curiosity module with this lr scale",
)
parser.add_argument(
'--icm-reward-scale',
"--icm-reward-scale",
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
help="scaling factor for intrinsic curiosity reward",
)
parser.add_argument(
'--icm-forward-loss-weight',
"--icm-forward-loss-weight",
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
help="weight for the forward model loss in ICM",
)
return parser.parse_args()
def test_ppo(args=get_args()):
args.cfg_path = f"maps/{args.task}.cfg"
args.wad_path = f"maps/{args.task}.wad"
args.res = (args.skip_num, 84, 84)
env = Env(args.cfg_path, args.frames_stack, args.res)
args.state_shape = args.res
# make environments
env, train_envs, test_envs = make_vizdoom_env(
args.task, args.skip_num, (args.frames_stack, 84, 84), args.save_lmp,
args.seed, args.training_num, args.test_num
)
args.state_shape = env.observation_space.shape
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)
]
)
test_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True,
output_dim=args.hidden_size
output_dim=args.hidden_size,
)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device)
@@ -159,7 +153,7 @@ def test_ppo(args=get_args()):
value_clip=args.value_clip,
dual_clip=args.dual_clip,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv
recompute_advantage=args.recompute_adv,
).to(args.device)
if args.icm_lr_scale > 0:
feature_net = DQN(
@@ -167,7 +161,7 @@ def test_ppo(args=get_args()):
args.action_shape,
device=args.device,
features_only=True,
output_dim=args.hidden_size
output_dim=args.hidden_size,
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
@@ -190,26 +184,40 @@ def test_ppo(args=get_args()):
buffer_num=len(train_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
log_path = os.path.join(args.logdir, args.task, log_name)
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False
@@ -225,7 +233,7 @@ def test_ppo(args=get_args()):
buffer_num=len(test_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
@@ -263,12 +271,12 @@ def test_ppo(args=get_args()):
stop_fn=stop_fn,
save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
test_in_train=False,
)
pprint.pprint(result)
watch()
if __name__ == '__main__':
if __name__ == "__main__":
test_ppo(get_args())