This commit is contained in:
Trinkle23897 2020-04-11 16:54:27 +08:00
parent 74407e13da
commit 6a244d1fbb
14 changed files with 95 additions and 29 deletions

2
.gitignore vendored
View File

@ -141,3 +141,5 @@ flake8.sh
log/
MUJOCO_LOG.TXT
*.pth
.vscode/
.DS_Store

View File

@ -112,6 +112,10 @@ footer p {
font-size: 100%;
}
.ethical-rtd {
display: none;
}
/* For hidden headers that appear in TOC tree */
/* see http://stackoverflow.com/a/32363545/3343043 */
.rst-content .hidden-section {

View File

@ -20,7 +20,6 @@ else: # pytest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-4)
@ -84,9 +83,12 @@ def test_ddpg(args=get_args()):
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
log_path = os.path.join(args.logdir, args.task, 'ddpg')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -94,7 +96,7 @@ def test_ddpg(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -20,7 +20,6 @@ else: # pytest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
@ -92,9 +91,12 @@ def _test_ppo(args=get_args()):
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -102,7 +104,8 @@ def _test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -20,7 +20,6 @@ else: # pytest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
@ -89,9 +88,12 @@ def test_sac(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -99,7 +101,7 @@ def test_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -20,7 +20,6 @@ else: # pytest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
@ -93,9 +92,12 @@ def test_td3(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
log_path = os.path.join(args.logdir, args.task, 'td3')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -103,7 +105,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -76,7 +77,11 @@ def test_a2c(args=get_args()):
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + 'a2c')
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -85,7 +90,8 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -74,7 +75,11 @@ def test_dqn(args=get_args()):
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'dqn')
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -90,7 +95,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -78,7 +79,11 @@ def test_drqn(args=get_args()):
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'dqn')
log_path = os.path.join(args.logdir, args.task, 'drqn')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -94,7 +99,7 @@ def test_drqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()

View File

@ -1,3 +1,4 @@
import os
import gym
import time
import torch
@ -125,7 +126,11 @@ def test_pg(args=get_args()):
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + 'pg')
log_path = os.path.join(args.logdir, args.task, 'pg')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -134,7 +139,8 @@ def test_pg(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -81,7 +82,11 @@ def test_ppo(args=get_args()):
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -90,7 +95,8 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -49,7 +49,7 @@ class ReplayBuffer(object):
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
>>> print(buf)
ReplayBuffer(
obs: [ 9. 10. 11. 12. 13. 14. 15. 7. 8.],
@ -74,6 +74,17 @@ class ReplayBuffer(object):
>>> # (stack only for obs and obs_next)
>>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
0.0
>>> # we can get obs_next through __getitem__, even if it doesn't store
>>> print(buf[:].obs_next)
[[ 7. 8. 9. 10.]
[ 7. 8. 9. 10.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
"""
def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):

View File

@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, episode_per_test,
batch_size,
train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
writer=None, log_interval=1, verbose=True, task='',
**kwargs):
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
task='', **kwargs):
"""A wrapper for off-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@ -35,6 +35,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
@ -66,6 +68,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
policy, test_collector, test_fn,
epoch, episode_per_test)
if stop_fn and stop_fn(test_result['rew']):
if save_fn:
save_fn(policy)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
t.set_postfix(**data)
@ -105,6 +109,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
if best_epoch == -1 or best_reward < result['rew']:
best_reward = result['rew']
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')

View File

@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, repeat_per_collect,
episode_per_test, batch_size,
train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
writer=None, log_interval=1, verbose=True, task='',
**kwargs):
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
task='', **kwargs):
"""A wrapper for on-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@ -39,6 +39,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
@ -70,6 +72,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
policy, test_collector, test_fn,
epoch, episode_per_test)
if stop_fn and stop_fn(test_result['rew']):
if save_fn:
save_fn(policy)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
t.set_postfix(**data)
@ -113,6 +117,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
if best_epoch == -1 or best_reward < result['rew']:
best_reward = result['rew']
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')