change API of train_fn and test_fn (#229)
train_fn(epoch) -> train_fn(epoch, num_env_step) test_fn(epoch) -> test_fn(epoch, num_env_step)
This commit is contained in:
parent
d87d31a705
commit
710966eda7
@ -229,9 +229,11 @@ Let's train it:
|
||||
```python
|
||||
result = ts.trainer.offpolicy_trainer(
|
||||
policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
|
||||
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
|
||||
test_fn=lambda e: policy.set_eps(eps_test),
|
||||
stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
|
||||
test_num, batch_size,
|
||||
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
|
||||
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
|
||||
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
||||
writer=writer, task=task)
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
```
|
||||
|
||||
|
@ -123,9 +123,9 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
|
||||
policy, train_collector, test_collector,
|
||||
max_epoch=10, step_per_epoch=1000, collect_per_step=10,
|
||||
episode_per_test=100, batch_size=64,
|
||||
train_fn=lambda e: policy.set_eps(0.1),
|
||||
test_fn=lambda e: policy.set_eps(0.05),
|
||||
stop_fn=lambda x: x >= env.spec.reward_threshold,
|
||||
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
|
||||
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
|
||||
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
||||
writer=None)
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
|
||||
@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m
|
||||
* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
|
||||
* ``episode_per_test``: The number of episodes for one policy evaluation.
|
||||
* ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
|
||||
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
|
||||
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
|
||||
* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
|
||||
* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
|
||||
* ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
|
||||
* ``writer``: See below.
|
||||
|
||||
|
@ -334,15 +334,15 @@ With the above preparation, we are close to the first learned agent. The followi
|
||||
policy.policies[args.agent_id - 1].state_dict(),
|
||||
model_save_path)
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= args.win_rate # 95% winning rate by default
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.win_rate # 95% winning rate by default
|
||||
# the default args.win_rate is 0.9, but the reward is [-1, 1]
|
||||
# instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate.
|
||||
|
||||
def train_fn(x):
|
||||
def train_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||
|
||||
# start training, this may require about three minutes
|
||||
|
@ -12,14 +12,14 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
|
||||
| task | best reward | reward curve | parameters | time cost |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
|
||||
| PongNoFrameskip-v4 | 20 |  | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
|
||||
| BreakoutNoFrameskip-v4 | 316 |  | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| EnduroNoFrameskip-v4 | 670 |  | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
|
||||
| QbertNoFrameskip-v4 | 7307 |  | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| MsPacmanNoFrameskip-v4 | 2107 |  | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| SeaquestNoFrameskip-v4 | 2088 |  | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| SpaceInvadersNoFrameskip-v4 | 812.2 |  | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| PongNoFrameskip-v4 | 20 |  | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) |
|
||||
| BreakoutNoFrameskip-v4 | 316 |  | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
|
||||
| EnduroNoFrameskip-v4 | 670 |  | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) |
|
||||
| QbertNoFrameskip-v4 | 7307 |  | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
|
||||
| MsPacmanNoFrameskip-v4 | 2107 |  | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
|
||||
| SeaquestNoFrameskip-v4 | 2088 |  | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
|
||||
| SpaceInvadersNoFrameskip-v4 | 812.2 |  | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
|
||||
|
||||
Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.
|
||||
Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.
|
||||
|
||||
We haven't tuned this result to the best, so have fun with playing these hyperparameters!
|
||||
|
@ -18,20 +18,20 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps_test', type=float, default=0.005)
|
||||
parser.add_argument('--eps_train', type=float, default=1.)
|
||||
parser.add_argument('--eps_train_final', type=float, default=0.05)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--n_step', type=int, default=3)
|
||||
parser.add_argument('--target_update_freq', type=int, default=500)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step_per_epoch', type=int, default=10000)
|
||||
parser.add_argument('--collect_per_step', type=int, default=10)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument('--training_num', type=int, default=16)
|
||||
parser.add_argument('--test_num', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=10000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
@ -95,26 +95,25 @@ def test_dqn(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
def stop_fn(mean_rewards):
|
||||
if env.env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
return x >= 20
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
def train_fn(x):
|
||||
def train_fn(epoch, env_step):
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
now = x * args.collect_per_step * args.step_per_epoch
|
||||
if now <= 1e6:
|
||||
eps = args.eps_train - now / 1e6 * \
|
||||
if env_step <= 1e6:
|
||||
eps = args.eps_train - env_step / 1e6 * \
|
||||
(args.eps_train - args.eps_train_final)
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
writer.add_scalar('train/eps', eps, global_step=now)
|
||||
writer.add_scalar('train/eps', eps, global_step=env_step)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# watch agent's performance
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
@ -76,11 +77,11 @@ def test_a2c(args=get_args()):
|
||||
preprocess_fn=preprocess_fn)
|
||||
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'a2c')
|
||||
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
|
||||
|
||||
def stop_fn(x):
|
||||
def stop_fn(mean_rewards):
|
||||
if env.env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
@ -80,11 +81,11 @@ def test_ppo(args=get_args()):
|
||||
preprocess_fn=preprocess_fn)
|
||||
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'ppo')
|
||||
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
|
||||
|
||||
def stop_fn(x):
|
||||
def stop_fn(mean_rewards):
|
||||
if env.env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -6,11 +6,11 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
def get_args():
|
||||
@ -75,20 +75,20 @@ def test_dqn(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
if x <= int(0.1 * args.epoch):
|
||||
def train_fn(epoch, env_step):
|
||||
if env_step <= 100000:
|
||||
policy.set_eps(args.eps_train)
|
||||
elif x <= int(0.5 * args.epoch):
|
||||
eps = args.eps_train - (x - 0.1 * args.epoch) / \
|
||||
(0.4 * args.epoch) * (0.5 * args.eps_train)
|
||||
elif env_step <= 500000:
|
||||
eps = args.eps_train - (env_step - 100000) / \
|
||||
400000 * (0.5 * args.eps_train)
|
||||
policy.set_eps(eps)
|
||||
else:
|
||||
policy.set_eps(0.5 * args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
|
@ -74,9 +74,6 @@ class EnvWrapper(object):
|
||||
def test_sac_bipedal(args=get_args()):
|
||||
env = EnvWrapper(args.task)
|
||||
|
||||
def IsStop(reward):
|
||||
return reward >= env.spec.reward_threshold
|
||||
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
@ -141,11 +138,14 @@ def test_sac_bipedal(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||
test_in_train=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -18,7 +18,7 @@ def get_args():
|
||||
# the parameters are found by Optuna
|
||||
parser.add_argument('--task', type=str, default='LunarLander-v2')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-test', type=float, default=0.01)
|
||||
parser.add_argument('--eps-train', type=float, default=0.73)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.013)
|
||||
@ -77,14 +77,14 @@ def test_dqn(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
args.eps_train = max(args.eps_train * 0.6, 0.01)
|
||||
policy.set_eps(args.eps_train)
|
||||
def train_fn(epoch, env_step): # exp decay
|
||||
eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test)
|
||||
policy.set_eps(eps)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
|
@ -98,8 +98,8 @@ def test_sac(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -6,11 +6,11 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
@ -77,8 +77,8 @@ def test_ddpg(args=get_args()):
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'ddpg')
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -7,10 +7,10 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
@ -86,8 +86,8 @@ def test_sac(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -6,11 +6,11 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
@ -88,8 +88,8 @@ def test_td3(args=get_args()):
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'td3')
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -4,18 +4,14 @@ import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pybullet_envs
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
try:
|
||||
import pybullet_envs
|
||||
except ImportError:
|
||||
pass
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
@ -91,8 +87,8 @@ def test_sac(args=get_args()):
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -6,12 +6,13 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
from mujoco.register import reg
|
||||
|
||||
|
||||
@ -40,7 +41,6 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--max_episode_steps', type=int, default=2000)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -91,9 +91,9 @@ def test_td3(args=get_args()):
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'td3')
|
||||
|
||||
def stop_fn(x):
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -96,8 +96,8 @@ def test_ddpg(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -113,8 +113,8 @@ def test_ppo(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
|
@ -96,8 +96,8 @@ def test_sac_with_il(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -103,8 +103,8 @@ def test_td3(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -84,8 +84,8 @@ def test_a2c_with_il(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
|
@ -85,21 +85,21 @@ def test_dqn(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
def train_fn(epoch, env_step):
|
||||
# eps annnealing, just a demo
|
||||
if x <= int(0.1 * args.epoch):
|
||||
if env_step <= 10000:
|
||||
policy.set_eps(args.eps_train)
|
||||
elif x <= int(0.5 * args.epoch):
|
||||
eps = args.eps_train - (x - 0.1 * args.epoch) / \
|
||||
(0.4 * args.epoch) * (0.9 * args.eps_train)
|
||||
elif env_step <= 50000:
|
||||
eps = args.eps_train - (env_step - 10000) / \
|
||||
40000 * (0.9 * args.eps_train)
|
||||
policy.set_eps(eps)
|
||||
else:
|
||||
policy.set_eps(0.1 * args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
|
@ -79,13 +79,13 @@ def test_drqn(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
def train_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
|
@ -73,8 +73,8 @@ def test_pg(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
|
@ -98,8 +98,8 @@ def test_ppo(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
|
@ -93,8 +93,8 @@ def test_discrete_sac(args=get_args()):
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
|
@ -66,9 +66,9 @@ def test_psrl(args=get_args()):
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + args.task)
|
||||
|
||||
def stop_fn(x):
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -64,11 +64,12 @@ def get_args() -> argparse.Namespace:
|
||||
return args
|
||||
|
||||
|
||||
def get_agents(args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
|
||||
def get_agents(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
|
||||
env = TicTacToeEnv(args.board_size, args.win_size)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
@ -99,11 +100,12 @@ def get_agents(args: argparse.Namespace = get_args(),
|
||||
return policy, optim
|
||||
|
||||
|
||||
def train_agent(args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
def train_agent(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
def env_func():
|
||||
return TicTacToeEnv(args.board_size, args.win_size)
|
||||
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
|
||||
@ -142,13 +144,13 @@ def train_agent(args: argparse.Namespace = get_args(),
|
||||
policy.policies[args.agent_id - 1].state_dict(),
|
||||
model_save_path)
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= args.win_rate
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.win_rate
|
||||
|
||||
def train_fn(x):
|
||||
def train_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
def test_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
@ -162,10 +164,11 @@ def train_agent(args: argparse.Namespace = get_args(),
|
||||
return result, policy.policies[args.agent_id - 1]
|
||||
|
||||
|
||||
def watch(args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
) -> None:
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
) -> None:
|
||||
env = TicTacToeEnv(args.board_size, args.win_size)
|
||||
policy, optim = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from tianshou import data, env, utils, policy, trainer, exploration
|
||||
|
||||
|
||||
__version__ = "0.3.0rc0"
|
||||
__version__ = "0.3.0"
|
||||
|
||||
__all__ = [
|
||||
"env",
|
||||
|
@ -230,7 +230,8 @@ class ReplayBuffer:
|
||||
obs = obs[-1]
|
||||
self._add_to_buffer("obs", obs)
|
||||
self._add_to_buffer("act", act)
|
||||
self._add_to_buffer("rew", rew)
|
||||
# make sure the reward is a float instead of an int
|
||||
self._add_to_buffer("rew", rew * 1.0) # type: ignore
|
||||
self._add_to_buffer("done", done)
|
||||
if self._save_s_:
|
||||
if obs_next is None:
|
||||
|
@ -19,8 +19,8 @@ def offpolicy_trainer(
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
update_per_step: int = 1,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
@ -53,11 +53,11 @@ def offpolicy_trainer(
|
||||
it updates policy 256 times once after ``collect_per_step`` frames are
|
||||
collected.
|
||||
:param function train_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of training in this
|
||||
epoch.
|
||||
and step index, and performs some operations at the beginning of
|
||||
training in this epoch.
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
and step index, and performs some operations at the beginning of
|
||||
testing in this epoch.
|
||||
:param function save_fn: a function for saving policy when the undiscounted
|
||||
average mean reward in evaluation phase gets better.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
@ -81,12 +81,12 @@ def offpolicy_trainer(
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
# train
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total:
|
||||
if train_fn:
|
||||
train_fn(epoch, global_step)
|
||||
result = train_collector.collect(n_step=collect_per_step)
|
||||
data = {}
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
@ -104,8 +104,6 @@ def offpolicy_trainer(
|
||||
test_result["rew"])
|
||||
else:
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
for i in range(update_per_step * min(
|
||||
result["n/st"] // collect_per_step, t.total - t.n)):
|
||||
global_step += collect_per_step
|
||||
|
@ -19,8 +19,8 @@ def onpolicy_trainer(
|
||||
repeat_per_collect: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
@ -53,11 +53,11 @@ def onpolicy_trainer(
|
||||
:param int batch_size: the batch size of sample data, which is going to
|
||||
feed in the policy network.
|
||||
:param function train_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of training in this
|
||||
epoch.
|
||||
and step index, and performs some operations at the beginning of
|
||||
training in this poch.
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
and step index, and performs some operations at the beginning of
|
||||
testing in this epoch.
|
||||
:param function save_fn: a function for saving policy when the undiscounted
|
||||
average mean reward in evaluation phase gets better.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
@ -81,12 +81,12 @@ def onpolicy_trainer(
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
# train
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total:
|
||||
if train_fn:
|
||||
train_fn(epoch, global_step)
|
||||
result = train_collector.collect(n_episode=collect_per_step)
|
||||
data = {}
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
@ -104,8 +104,6 @@ def onpolicy_trainer(
|
||||
test_result["rew"])
|
||||
else:
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
losses = policy.update(
|
||||
0, train_collector.buffer,
|
||||
batch_size=batch_size, repeat=repeat_per_collect)
|
||||
|
@ -10,7 +10,7 @@ from tianshou.policy import BasePolicy
|
||||
def test_episode(
|
||||
policy: BasePolicy,
|
||||
collector: Collector,
|
||||
test_fn: Optional[Callable[[int], None]],
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]],
|
||||
epoch: int,
|
||||
n_episode: Union[int, List[int]],
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
@ -21,7 +21,7 @@ def test_episode(
|
||||
collector.reset_buffer()
|
||||
policy.eval()
|
||||
if test_fn:
|
||||
test_fn(epoch)
|
||||
test_fn(epoch, global_step)
|
||||
if collector.get_env_num() > 1 and isinstance(n_episode, int):
|
||||
n = collector.get_env_num()
|
||||
n_ = np.zeros(n) + n_episode // n
|
||||
|
Loading…
x
Reference in New Issue
Block a user