save_fn

parent 74407e13da
commit 6a244d1fbb
.gitignore (vendored, 2 changes)

@@ -141,3 +141,5 @@ flake8.sh
 log/
 MUJOCO_LOG.TXT
 *.pth
+.vscode/
+.DS_Store
docs/_static/css/style.css (vendored, 4 changes)

@@ -112,6 +112,10 @@ footer p {
     font-size: 100%;
 }
 
+.ethical-rtd {
+    display: none;
+}
+
 /* For hidden headers that appear in TOC tree */
 /* see http://stackoverflow.com/a/32363545/3343043 */
 .rst-content .hidden-section {
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
@@ -84,9 +83,12 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -94,7 +96,7 @@ def test_ddpg(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
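For reference, the checkpointing pattern added to each test script boils down to the sketch below; a toy nn.Module stands in for the real Tianshou policy, and the hard-coded log path is only illustrative:

import os
import torch
import torch.nn as nn

# Toy stand-in for a Tianshou policy; any nn.Module with a state_dict() works.
policy = nn.Linear(4, 2)

log_path = os.path.join('log', 'Pendulum-v0', 'ddpg')
os.makedirs(log_path, exist_ok=True)

def save_fn(policy):
    # The trainer calls this whenever the best evaluation reward improves.
    torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

save_fn(policy)  # what offpolicy_trainer(..., save_fn=save_fn) does internally

In the real tests the same save_fn is simply forwarded to the trainer, e.g. offpolicy_trainer(..., stop_fn=stop_fn, save_fn=save_fn, writer=writer).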
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
@@ -92,9 +91,12 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -102,7 +104,8 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
@@ -89,9 +88,12 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -99,7 +101,7 @@ def test_sac(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
@@ -93,9 +92,12 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -103,7 +105,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -76,7 +77,11 @@ def test_a2c(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'a2c')
+    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -85,7 +90,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -74,7 +75,11 @@ def test_dqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -90,7 +95,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
 
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -78,7 +79,11 @@ def test_drqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'drqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -94,7 +99,7 @@ def test_drqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
 
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -1,3 +1,4 @@
+import os
 import gym
 import time
 import torch
@@ -125,7 +126,11 @@ def test_pg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'pg')
+    log_path = os.path.join(args.logdir, args.task, 'pg')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -134,7 +139,8 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -81,7 +82,11 @@ def test_ppo(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -90,7 +95,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -49,7 +49,7 @@ class ReplayBuffer(object):
     >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
     >>> for i in range(16):
     ...     done = i % 5 == 0
-    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
+    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
     >>> print(buf)
     ReplayBuffer(
         obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
@@ -74,6 +74,17 @@ class ReplayBuffer(object):
     >>> # (stack only for obs and obs_next)
     >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
     0.0
+    >>> # we can get obs_next through __getitem__, even if it doesn't store
+    >>> print(buf[:].obs_next)
+    [[ 7.  8.  9. 10.]
+     [ 7.  8.  9. 10.]
+     [11. 11. 11. 12.]
+     [11. 11. 12. 13.]
+     [11. 12. 13. 14.]
+     [12. 13. 14. 15.]
+     [12. 13. 14. 15.]
+     [ 7.  7.  7.  8.]
+     [ 7.  7.  8.  9.]]
     """
 
     def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
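The updated doctest can also be exercised as a plain script; a minimal sketch (assuming tianshou is installed) that reproduces it:

from tianshou.data import ReplayBuffer

# Frame-stacked buffer from the docstring: stack_num=4, obs_next not stored.
buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
for i in range(16):
    done = i % 5 == 0
    buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)

# Even though obs_next is not stored, __getitem__ reconstructs a stacked
# obs_next from the stored observations, producing the matrix shown above.
print(buf[:].obs_next)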
@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size,
-                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                      writer=None, log_interval=1, verbose=True, task='',
-                      **kwargs):
+                      train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                      log_fn=None, writer=None, log_interval=1, verbose=True,
+                      task='', **kwargs):
     """A wrapper for off-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -35,6 +35,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in this
         epoch.
+    :param function save_fn: a function for saving policy when the undiscounted
+        average mean reward in evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
@@ -66,6 +68,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                         policy, test_collector, test_fn,
                         epoch, episode_per_test)
                     if stop_fn and stop_fn(test_result['rew']):
+                        if save_fn:
+                            save_fn(policy)
                         for k in result.keys():
                             data[k] = f'{result[k]:.2f}'
                         t.set_postfix(**data)
@@ -105,6 +109,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
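The guard is the same at both new call sites: save_fn stays optional, and the trainer invokes it when an early stop is triggered by stop_fn or whenever the evaluation reward sets a new best. A self-contained sketch of that bookkeeping (the reward values are made up for illustration):

def save_fn(policy):
    print('checkpointing policy')  # the tests write policy.state_dict() here

best_epoch, best_reward = -1, 0.0
for epoch, rew in enumerate([10.0, 12.5, 11.0, 13.2], start=1):
    if best_epoch == -1 or best_reward < rew:
        best_reward, best_epoch = rew, epoch
        if save_fn:                 # save_fn may be None, so every call is guarded
            save_fn(policy=None)    # the real trainer passes its policy object
    print(f'Epoch #{epoch}: test_reward: {rew:.6f}, '
          f'best_reward: {best_reward:.6f} in #{best_epoch}')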
@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                     writer=None, log_interval=1, verbose=True, task='',
-                     **kwargs):
+                     train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                     log_fn=None, writer=None, log_interval=1, verbose=True,
+                     task='', **kwargs):
     """A wrapper for on-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -39,6 +39,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in this
         epoch.
+    :param function save_fn: a function for saving policy when the undiscounted
+        average mean reward in evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
@@ -70,6 +72,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                         policy, test_collector, test_fn,
                         epoch, episode_per_test)
                     if stop_fn and stop_fn(test_result['rew']):
+                        if save_fn:
+                            save_fn(policy)
                         for k in result.keys():
                             data[k] = f'{result[k]:.2f}'
                         t.set_postfix(**data)
@@ -113,6 +117,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
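The trainers only write the checkpoint; reading it back is the usual torch state_dict round-trip. A sketch, where a toy module stands in for a policy rebuilt with the same networks and hyperparameters as during training and the checkpoint path mirrors the test scripts:

import os
import torch
import torch.nn as nn

policy = nn.Linear(4, 2)  # stand-in for the rebuilt Tianshou policy
ckpt = os.path.join('log', 'Pendulum-v0', 'ddpg', 'policy.pth')
if os.path.exists(ckpt):
    policy.load_state_dict(torch.load(ckpt))
policy.eval()  # switch to evaluation mode before collecting test episodes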