change API of train_fn and test_fn (#229)

train_fn(epoch) -> train_fn(epoch, num_env_step)
test_fn(epoch) -> test_fn(epoch, num_env_step)
n+e 2020-09-26 16:35:37 +08:00 committed by GitHub
parent d87d31a705
commit 710966eda7
33 changed files with 168 additions and 169 deletions
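
For downstream scripts the migration is mechanical: both callbacks now take two arguments, the epoch index and the number of environment steps collected so far, and a callback that only cares about the epoch can simply ignore the second argument. A minimal, self-contained sketch of the change (the `DummyPolicy` stub and the eps values are placeholders for illustration, not code from this repository):

```python
class DummyPolicy:
    """Stand-in for a DQN-style policy with an epsilon setter."""

    def __init__(self) -> None:
        self.eps = 0.0

    def set_eps(self, eps: float) -> None:
        self.eps = eps


policy = DummyPolicy()
eps_train, eps_test = 0.1, 0.05

# old API: the trainer called train_fn(epoch) and test_fn(epoch)
# new API: the trainer calls train_fn(epoch, env_step) and test_fn(epoch, env_step)
def train_fn(epoch: int, env_step: int) -> None:
    policy.set_eps(eps_train)

def test_fn(epoch: int, env_step: int) -> None:
    policy.set_eps(eps_test)

train_fn(1, 0)     # e.g. at the start of epoch 1, before any env step
print(policy.eps)  # 0.1
```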


@ -229,9 +229,11 @@ Let's train it:
```python
result = ts.trainer.offpolicy_trainer(
policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
test_fn=lambda e: policy.set_eps(eps_test),
stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
test_num, batch_size,
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=writer, task=task)
print(f'Finished training! Use {result["duration"]}')
```


@ -123,9 +123,9 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
policy, train_collector, test_collector,
max_epoch=10, step_per_epoch=1000, collect_per_step=10,
episode_per_test=100, batch_size=64,
train_fn=lambda e: policy.set_eps(0.1),
test_fn=lambda e: policy.set_eps(0.05),
stop_fn=lambda x: x >= env.spec.reward_threshold,
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=None)
print(f'Finished training! Use {result["duration"]}')
@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m
* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
* ``episode_per_test``: The number of episodes for one policy evaluation.
* ``batch_size``: The batch size of the sample data that is going to be fed into the policy network.
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``stop_fn``: A function that receives the average undiscounted return of the test results and returns a boolean indicating whether the goal has been reached.
* ``writer``: See below.
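
Because ``train_fn`` and ``test_fn`` now also see the environment step count, an exploration schedule can be expressed in env steps instead of epochs. Below is a sketch of such a schedule; the helper name ``linear_eps`` is ours, while the 1.0/0.05 endpoints and the 1M-step horizon mirror the Atari DQN example changed in this commit:

```python
def linear_eps(env_step: int, eps_start: float = 1.0,
               eps_final: float = 0.05, decay_steps: int = 1_000_000) -> float:
    """Decay epsilon linearly from eps_start to eps_final over decay_steps env steps."""
    frac = min(env_step, decay_steps) / decay_steps
    return eps_start - frac * (eps_start - eps_final)


# plugged into the trainer (policy is whatever DQN-style policy is being trained):
# train_fn = lambda epoch, env_step: policy.set_eps(linear_eps(env_step))
print(linear_eps(0), linear_eps(500_000), linear_eps(2_000_000))  # 1.0 0.525 0.05
```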


@ -334,15 +334,15 @@ With the above preparation, we are close to the first learned agent. The followi
policy.policies[args.agent_id - 1].state_dict(),
model_save_path)
def stop_fn(x):
return x >= args.win_rate # 95% winning rate by default
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate # 95% winning rate by default
# the default args.win_rate is 0.9, but the reward is [-1, 1]
# instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate.
def train_fn(x):
def train_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
def test_fn(x):
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
# start training, this may require about three minutes
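
A quick check of the win-rate comment above, assuming a win scores +1, a loss scores -1, and draws are rare enough to ignore: the mean reward at win probability p is p - (1 - p) = 2p - 1, so a mean-reward threshold of 0.9 indeed corresponds to a 95% win rate.

```python
def win_rate(mean_reward: float) -> float:
    # invert mean_reward = 2 * p - 1 for the win probability p
    return (mean_reward + 1) / 2


print(win_rate(0.9))  # 0.95
```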


@ -12,14 +12,14 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| task | best reward | reward curve | parameters | time cost |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
Note: The eps_train_final and eps_test in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve the performance. Also, a large batch size (say 64 instead of 32) helps the agent converge faster but slows down the training speed.
Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve the performance. Also, a large batch size (say 64 instead of 32) helps the agent converge faster but slows down the training speed.
We haven't tuned this result to the best, so have fun with playing these hyperparameters!


@ -18,20 +18,20 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps_test', type=float, default=0.005)
parser.add_argument('--eps_train', type=float, default=1.)
parser.add_argument('--eps_train_final', type=float, default=0.05)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n_step', type=int, default=3)
parser.add_argument('--target_update_freq', type=int, default=500)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step_per_epoch', type=int, default=10000)
parser.add_argument('--collect_per_step', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--training_num', type=int, default=16)
parser.add_argument('--test_num', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
@ -95,26 +95,25 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return x >= 20
return mean_rewards >= 20
else:
return False
def train_fn(x):
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
now = x * args.collect_per_step * args.step_per_epoch
if now <= 1e6:
eps = args.eps_train - now / 1e6 * \
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=now)
writer.add_scalar('train/eps', eps, global_step=env_step)
def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# watch agent's performance


@ -1,3 +1,4 @@
import os
import torch
import pprint
import argparse
@ -76,11 +77,11 @@ def test_a2c(args=get_args()):
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'a2c')
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False


@ -1,3 +1,4 @@
import os
import torch
import pprint
import argparse
@ -80,11 +81,11 @@ def test_ppo(args=get_args()):
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False


@ -6,11 +6,11 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net
def get_args():
@ -75,20 +75,20 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(x):
if x <= int(0.1 * args.epoch):
def train_fn(epoch, env_step):
if env_step <= 100000:
policy.set_eps(args.eps_train)
elif x <= int(0.5 * args.epoch):
eps = args.eps_train - (x - 0.1 * args.epoch) / \
(0.4 * args.epoch) * (0.5 * args.eps_train)
elif env_step <= 500000:
eps = args.eps_train - (env_step - 100000) / \
400000 * (0.5 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.5 * args.eps_train)
def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer


@ -74,9 +74,6 @@ class EnvWrapper(object):
def test_sac_bipedal(args=get_args()):
env = EnvWrapper(args.task)
def IsStop(reward):
return reward >= env.spec.reward_threshold
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
@ -141,11 +138,14 @@ def test_sac_bipedal(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
test_in_train=False)
if __name__ == '__main__':


@ -18,7 +18,7 @@ def get_args():
# the parameters are found by Optuna
parser.add_argument('--task', type=str, default='LunarLander-v2')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-test', type=float, default=0.01)
parser.add_argument('--eps-train', type=float, default=0.73)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.013)
@ -77,14 +77,14 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(x):
args.eps_train = max(args.eps_train * 0.6, 0.01)
policy.set_eps(args.eps_train)
def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test)
policy.set_eps(eps)
def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer
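
For intuition about the exponential schedule above (using the defaults set in this hunk, eps_train = 0.73 and eps_test = 0.01): (1 - 5e-6) ** env_step behaves like exp(-5e-6 * env_step), so epsilon roughly halves every ~139k env steps and reaches the eps_test floor after roughly 858k steps. A purely illustrative check:

```python
import math

eps_train, eps_test, rate = 0.73, 0.01, 5e-6
half_life = math.log(2) / rate                          # ~139_000 env steps per halving
steps_to_floor = math.log(eps_train / eps_test) / rate  # ~858_000 steps until the eps_test floor
print(round(half_life), round(steps_to_floor))
```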


@ -98,8 +98,8 @@ def test_sac(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -6,11 +6,11 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.exploration import GaussianNoise
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic
@ -77,8 +77,8 @@ def test_ddpg(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + 'ddpg')
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -7,10 +7,10 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@ -86,8 +86,8 @@ def test_sac(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -6,11 +6,11 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@ -88,8 +88,8 @@ def test_td3(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -4,18 +4,14 @@ import torch
import pprint
import argparse
import numpy as np
import pybullet_envs
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import SubprocVectorEnv
from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
try:
import pybullet_envs
except ImportError:
pass
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@ -91,8 +87,8 @@ def test_sac(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
writer = SummaryWriter(log_path)
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -6,12 +6,13 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic
from mujoco.register import reg
@ -40,7 +41,6 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--max_episode_steps', type=int, default=2000)
return parser.parse_args()
@ -91,9 +91,9 @@ def test_td3(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')
def stop_fn(x):
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False


@ -96,8 +96,8 @@ def test_ddpg(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -113,8 +113,8 @@ def test_ppo(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = onpolicy_trainer(


@ -96,8 +96,8 @@ def test_sac_with_il(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -103,8 +103,8 @@ def test_td3(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -84,8 +84,8 @@ def test_a2c_with_il(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = onpolicy_trainer(


@ -85,21 +85,21 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(x):
def train_fn(epoch, env_step):
# eps annealing, just a demo
if x <= int(0.1 * args.epoch):
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif x <= int(0.5 * args.epoch):
eps = args.eps_train - (x - 0.1 * args.epoch) / \
(0.4 * args.epoch) * (0.9 * args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)
def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer


@ -79,13 +79,13 @@ def test_drqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(x):
def train_fn(epoch, env_step):
policy.set_eps(args.eps_train)
def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer


@ -73,8 +73,8 @@ def test_pg(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = onpolicy_trainer(


@ -98,8 +98,8 @@ def test_ppo(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = onpolicy_trainer(


@ -93,8 +93,8 @@ def test_discrete_sac(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = offpolicy_trainer(


@ -66,9 +66,9 @@ def test_psrl(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + args.task)
def stop_fn(x):
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False


@ -64,11 +64,12 @@ def get_args() -> argparse.Namespace:
return args
def get_agents(args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
def get_agents(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
env = TicTacToeEnv(args.board_size, args.win_size)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
@ -99,11 +100,12 @@ def get_agents(args: argparse.Namespace = get_args(),
return policy, optim
def train_agent(args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
def train_agent(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
@ -142,13 +144,13 @@ def train_agent(args: argparse.Namespace = get_args(),
policy.policies[args.agent_id - 1].state_dict(),
model_save_path)
def stop_fn(x):
return x >= args.win_rate
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate
def train_fn(x):
def train_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
def test_fn(x):
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
# trainer
@ -162,10 +164,11 @@ def train_agent(args: argparse.Namespace = get_args(),
return result, policy.policies[args.agent_id - 1]
def watch(args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
def watch(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = TicTacToeEnv(args.board_size, args.win_size)
policy, optim = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent)


@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.3.0rc0"
__version__ = "0.3.0"
__all__ = [
"env",


@ -230,7 +230,8 @@ class ReplayBuffer:
obs = obs[-1]
self._add_to_buffer("obs", obs)
self._add_to_buffer("act", act)
self._add_to_buffer("rew", rew)
# make sure the reward is a float instead of an int
self._add_to_buffer("rew", rew * 1.0) # type: ignore
self._add_to_buffer("done", done)
if self._save_s_:
if obs_next is None:
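
The `rew * 1.0` cast above presumably guards against the reward storage being allocated with an integer dtype when the first reward added happens to be a Python int, in which case every later float reward would be silently truncated. A plain numpy illustration of that hazard (not tianshou code):

```python
import numpy as np

rew_buffer = np.zeros(4, dtype=np.asarray(1).dtype)  # int storage, as if the first rew were the int 1
rew_buffer[1] = 0.5                                   # a later float reward
print(rew_buffer)                                     # [0 0 0 0] -- the 0.5 was truncated
```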


@ -19,8 +19,8 @@ def offpolicy_trainer(
episode_per_test: Union[int, List[int]],
batch_size: int,
update_per_step: int = 1,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
@ -53,11 +53,11 @@ def offpolicy_trainer(
it updates policy 256 times once after ``collect_per_step`` frames are
collected.
:param function train_fn: a function receives the current number of epoch
index and performs some operations at the beginning of training in this
epoch.
and step index, and performs some operations at the beginning of
training in this epoch.
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
and step index, and performs some operations at the beginning of
testing in this epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
@ -81,12 +81,12 @@ def offpolicy_trainer(
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
if train_fn:
train_fn(epoch, global_step)
result = train_collector.collect(n_step=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result["rew"]):
@ -104,8 +104,6 @@ def offpolicy_trainer(
test_result["rew"])
else:
policy.train()
if train_fn:
train_fn(epoch)
for i in range(update_per_step * min(
result["n/st"] // collect_per_step, t.total - t.n)):
global_step += collect_per_step


@ -19,8 +19,8 @@ def onpolicy_trainer(
repeat_per_collect: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
@ -53,11 +53,11 @@ def onpolicy_trainer(
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function train_fn: a function receives the current number of epoch
index and performs some operations at the beginning of training in this
epoch.
and step index, and performs some operations at the beginning of
training in this epoch.
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
and step index, and performs some operations at the beginning of
testing in this epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
@ -81,12 +81,12 @@ def onpolicy_trainer(
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
if train_fn:
train_fn(epoch, global_step)
result = train_collector.collect(n_episode=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result["rew"]):
@ -104,8 +104,6 @@ def onpolicy_trainer(
test_result["rew"])
else:
policy.train()
if train_fn:
train_fn(epoch)
losses = policy.update(
0, train_collector.buffer,
batch_size=batch_size, repeat=repeat_per_collect)


@ -10,7 +10,7 @@ from tianshou.policy import BasePolicy
def test_episode(
policy: BasePolicy,
collector: Collector,
test_fn: Optional[Callable[[int], None]],
test_fn: Optional[Callable[[int, Optional[int]], None]],
epoch: int,
n_episode: Union[int, List[int]],
writer: Optional[SummaryWriter] = None,
@ -21,7 +21,7 @@ def test_episode(
collector.reset_buffer()
policy.eval()
if test_fn:
test_fn(epoch)
test_fn(epoch, global_step)
if collector.get_env_num() > 1 and isinstance(n_episode, int):
n = collector.get_env_num()
n_ = np.zeros(n) + n_episode // n