Trinkle23897 2020-04-11 16:54:27 +08:00
parent 74407e13da
commit 6a244d1fbb
14 changed files with 95 additions and 29 deletions

.gitignore (2 additions)

@@ -141,3 +141,5 @@ flake8.sh
 log/
 MUJOCO_LOG.TXT
 *.pth
+.vscode/
+.DS_Store

@@ -112,6 +112,10 @@ footer p {
   font-size: 100%;
 }

+.ethical-rtd {
+  display: none;
+}
+
 /* For hidden headers that appear in TOC tree */
 /* see http://stackoverflow.com/a/32363545/3343043 */
 .rst-content .hidden-section {

@@ -20,7 +20,6 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)

@@ -84,9 +83,12 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)

+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -94,7 +96,7 @@ def test_ddpg(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
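For reference, a minimal sketch of how the checkpoint written by the new save_fn could be loaded back for evaluation. This is not part of the commit; load_best_policy is a hypothetical helper, and it assumes the policy object is rebuilt exactly as in the test script above so that the state_dict keys match.

import os

import torch


def load_best_policy(policy, logdir='log', task='Pendulum-v0', algo='ddpg'):
    """Load the checkpoint written by save_fn into a freshly built policy."""
    ckpt = os.path.join(logdir, task, algo, 'policy.pth')
    # map_location='cpu' lets a GPU-trained checkpoint load on any machine
    policy.load_state_dict(torch.load(ckpt, map_location='cpu'))
    policy.eval()
    return policy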

@@ -20,7 +20,6 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)

@@ -92,9 +91,12 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)

+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -102,7 +104,8 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -20,7 +20,6 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)

@@ -89,9 +88,12 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)

+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -99,7 +101,7 @@ def test_sac(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -20,7 +20,6 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)

@@ -93,9 +92,12 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)

+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -103,7 +105,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -76,7 +77,11 @@ def test_a2c(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'a2c')
+    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -85,7 +90,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -74,7 +75,11 @@ def test_dqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -90,7 +95,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -78,7 +79,11 @@ def test_drqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'drqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -94,7 +99,7 @@ def test_drqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()

@@ -1,3 +1,4 @@
+import os
 import gym
 import time
 import torch

@@ -125,7 +126,11 @@ def test_pg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'pg')
+    log_path = os.path.join(args.logdir, args.task, 'pg')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -134,7 +139,8 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint

@@ -81,7 +82,11 @@ def test_ppo(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(x):
         return x >= env.spec.reward_threshold

@@ -90,7 +95,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -49,7 +49,7 @@ class ReplayBuffer(object):
     >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
     >>> for i in range(16):
     ...     done = i % 5 == 0
-    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
+    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
     >>> print(buf)
     ReplayBuffer(
         obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],

@@ -74,6 +74,17 @@ class ReplayBuffer(object):
     >>> # (stack only for obs and obs_next)
     >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
     0.0
+    >>> # obs_next is available through __getitem__ even though it is not stored
+    >>> print(buf[:].obs_next)
+    [[ 7.  8.  9. 10.]
+     [ 7.  8.  9. 10.]
+     [11. 11. 11. 12.]
+     [11. 11. 12. 13.]
+     [11. 12. 13. 14.]
+     [12. 13. 14. 15.]
+     [12. 13. 14. 15.]
+     [ 7.  7.  7.  8.]
+     [ 7.  7.  8.  9.]]
     """

     def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
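The new doctest lines can also be run as a standalone script; a small sketch, assuming only the ReplayBuffer API already shown in the docstring above:

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
for i in range(16):
    done = i % 5 == 0
    buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)

# obs_next is never stored; __getitem__ rebuilds it from the following frame,
# repeating the last frame of an episode at its boundary.
print(buf[:].obs_next)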

@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size,
-                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                      writer=None, log_interval=1, verbose=True, task='',
-                      **kwargs):
+                      train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                      log_fn=None, writer=None, log_interval=1, verbose=True,
+                      task='', **kwargs):
     """A wrapper for off-policy trainer procedure.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

@@ -35,6 +35,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in this
         epoch.
+    :param function save_fn: a function that saves the policy whenever the
+        undiscounted average reward in the evaluation phase improves.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.

@@ -66,6 +68,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                 policy, test_collector, test_fn,
                 epoch, episode_per_test)
             if stop_fn and stop_fn(test_result['rew']):
+                if save_fn:
+                    save_fn(policy)
                 for k in result.keys():
                     data[k] = f'{result[k]:.2f}'
                 t.set_postfix(**data)

@@ -105,6 +109,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
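To summarise the two call sites added above: save_fn fires whenever the evaluation reward sets a new best, and once more right before the trainer stops early through stop_fn. A simplified plain-Python illustration of that contract (names and the evaluate placeholder are not from the diff):

def training_loop(policy, max_epoch, evaluate, stop_fn=None, save_fn=None):
    """Simplified sketch of when the trainers invoke save_fn."""
    best_reward, best_epoch = 0, -1
    for epoch in range(1, max_epoch + 1):
        rew = evaluate(epoch)               # stands in for the test episodes
        if stop_fn and stop_fn(rew):
            if save_fn:
                save_fn(policy)             # saved before early stopping
            break
        if best_epoch == -1 or best_reward < rew:
            best_reward, best_epoch = rew, epoch
            if save_fn:
                save_fn(policy)             # saved on every new best reward
    return best_reward, best_epoch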

@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                     writer=None, log_interval=1, verbose=True, task='',
-                     **kwargs):
+                     train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                     log_fn=None, writer=None, log_interval=1, verbose=True,
+                     task='', **kwargs):
     """A wrapper for on-policy trainer procedure.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

@@ -39,6 +39,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in this
         epoch.
+    :param function save_fn: a function that saves the policy whenever the
+        undiscounted average reward in the evaluation phase improves.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.

@@ -70,6 +72,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                 policy, test_collector, test_fn,
                 epoch, episode_per_test)
             if stop_fn and stop_fn(test_result['rew']):
+                if save_fn:
+                    save_fn(policy)
                 for k in result.keys():
                     data[k] = f'{result[k]:.2f}'
                 t.set_postfix(**data)

@@ -113,6 +117,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
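Since every test script above defines the same inline save_fn, a small reusable factory is one possible way to use the new hook. This is a sketch only, not something the commit adds; make_save_fn is a hypothetical name:

import os

import torch


def make_save_fn(log_path):
    """Return a save_fn closure that overwrites a single best-so-far file."""
    os.makedirs(log_path, exist_ok=True)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    return save_fn

It would be passed exactly like the inline versions in the tests, e.g. offpolicy_trainer(..., stop_fn=stop_fn, save_fn=make_save_fn(log_path), writer=writer).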