save_fn
This commit is contained in:
parent
74407e13da
commit
6a244d1fbb
2
.gitignore
vendored
2
.gitignore
vendored
@ -141,3 +141,5 @@ flake8.sh
|
||||
log/
|
||||
MUJOCO_LOG.TXT
|
||||
*.pth
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
4
docs/_static/css/style.css
vendored
4
docs/_static/css/style.css
vendored
@ -112,6 +112,10 @@ footer p {
|
||||
font-size: 100%;
|
||||
}
|
||||
|
||||
.ethical-rtd {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* For hidden headers that appear in TOC tree */
|
||||
/* see http://stackoverflow.com/a/32363545/3343043 */
|
||||
.rst-content .hidden-section {
|
||||
|
@ -20,7 +20,6 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--run-id', type=str, default='test')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-4)
|
||||
@ -84,9 +83,12 @@ def test_ddpg(args=get_args()):
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
|
||||
log_path = os.path.join(args.logdir, args.task, 'ddpg')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
@ -94,7 +96,7 @@ def test_ddpg(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -20,7 +20,6 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--run-id', type=str, default='test')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
@ -92,9 +91,12 @@ def _test_ppo(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.step_per_epoch)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
@ -102,7 +104,8 @@ def _test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -20,7 +20,6 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--run-id', type=str, default='test')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||
@ -89,9 +88,12 @@ def test_sac(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.buffer_size)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
@ -99,7 +101,7 @@ def test_sac(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -20,7 +20,6 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--run-id', type=str, default='test')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||
@ -93,9 +92,12 @@ def test_td3(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.buffer_size)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
@ -103,7 +105,7 @@ def test_td3(args=get_args()):
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
@ -76,7 +77,11 @@ def test_a2c(args=get_args()):
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'a2c')
|
||||
log_path = os.path.join(args.logdir, args.task, 'a2c')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
@ -85,7 +90,8 @@ def test_a2c(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
@ -74,7 +75,11 @@ def test_dqn(args=get_args()):
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'dqn')
|
||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
@ -90,7 +95,7 @@ def test_dqn(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, writer=writer)
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
@ -78,7 +79,11 @@ def test_drqn(args=get_args()):
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'dqn')
|
||||
log_path = os.path.join(args.logdir, args.task, 'drqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
@ -94,7 +99,7 @@ def test_drqn(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, writer=writer)
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gym
|
||||
import time
|
||||
import torch
|
||||
@ -125,7 +126,11 @@ def test_pg(args=get_args()):
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'pg')
|
||||
log_path = os.path.join(args.logdir, args.task, 'pg')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
@ -134,7 +139,8 @@ def test_pg(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
@ -81,7 +82,11 @@ def test_ppo(args=get_args()):
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
writer = SummaryWriter(args.logdir + '/' + 'ppo')
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
@ -90,7 +95,8 @@ def test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
@ -49,7 +49,7 @@ class ReplayBuffer(object):
|
||||
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
|
||||
>>> for i in range(16):
|
||||
... done = i % 5 == 0
|
||||
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
|
||||
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
|
||||
>>> print(buf)
|
||||
ReplayBuffer(
|
||||
obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
|
||||
@ -74,6 +74,17 @@ class ReplayBuffer(object):
|
||||
>>> # (stack only for obs and obs_next)
|
||||
>>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
|
||||
0.0
|
||||
>>> # we can get obs_next through __getitem__, even if it doesn't store
|
||||
>>> print(buf[:].obs_next)
|
||||
[[ 7. 8. 9. 10.]
|
||||
[ 7. 8. 9. 10.]
|
||||
[11. 11. 11. 12.]
|
||||
[11. 11. 12. 13.]
|
||||
[11. 12. 13. 14.]
|
||||
[12. 13. 14. 15.]
|
||||
[12. 13. 14. 15.]
|
||||
[ 7. 7. 7. 8.]
|
||||
[ 7. 7. 8. 9.]]
|
||||
"""
|
||||
|
||||
def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
|
||||
|
@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
|
||||
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
step_per_epoch, collect_per_step, episode_per_test,
|
||||
batch_size,
|
||||
train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
|
||||
writer=None, log_interval=1, verbose=True, task='',
|
||||
**kwargs):
|
||||
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
|
||||
log_fn=None, writer=None, log_interval=1, verbose=True,
|
||||
task='', **kwargs):
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
|
||||
@ -35,6 +35,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
:param function save_fn: a function for saving policy when the undiscounted
|
||||
average mean reward in evaluation phase gets better.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
returns of the testing result, return a boolean which indicates whether
|
||||
reaching the goal.
|
||||
@ -66,6 +68,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test)
|
||||
if stop_fn and stop_fn(test_result['rew']):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
t.set_postfix(**data)
|
||||
@ -105,6 +109,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
if best_epoch == -1 or best_reward < result['rew']:
|
||||
best_reward = result['rew']
|
||||
best_epoch = epoch
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
|
||||
f'best_reward: {best_reward:.6f} in #{best_epoch}')
|
||||
|
@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
|
||||
def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
step_per_epoch, collect_per_step, repeat_per_collect,
|
||||
episode_per_test, batch_size,
|
||||
train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
|
||||
writer=None, log_interval=1, verbose=True, task='',
|
||||
**kwargs):
|
||||
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
|
||||
log_fn=None, writer=None, log_interval=1, verbose=True,
|
||||
task='', **kwargs):
|
||||
"""A wrapper for on-policy trainer procedure.
|
||||
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
|
||||
@ -39,6 +39,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
:param function test_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of testing in this
|
||||
epoch.
|
||||
:param function save_fn: a function for saving policy when the undiscounted
|
||||
average mean reward in evaluation phase gets better.
|
||||
:param function stop_fn: a function receives the average undiscounted
|
||||
returns of the testing result, return a boolean which indicates whether
|
||||
reaching the goal.
|
||||
@ -70,6 +72,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test)
|
||||
if stop_fn and stop_fn(test_result['rew']):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
t.set_postfix(**data)
|
||||
@ -113,6 +117,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
||||
if best_epoch == -1 or best_reward < result['rew']:
|
||||
best_reward = result['rew']
|
||||
best_epoch = epoch
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
|
||||
f'best_reward: {best_reward:.6f} in #{best_epoch}')
|
||||
|
Loading…
x
Reference in New Issue
Block a user