vanilla imitation learning

Trinkle23897 2020-04-13 19:37:27 +08:00
parent befdfb07e8
commit 7b65d43394
10 changed files with 88 additions and 12 deletions

View File

@@ -25,6 +25,7 @@
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- Vanilla Imitation Learning
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.

View File

@@ -43,7 +43,7 @@ This command will run automatic tests in the main directory
Test by GitHub Actions
----------------------
-1. Click the `Actions` button in your own repo:
+1. Click the ``Actions`` button in your own repo:
.. image:: _static/images/action1.jpg
   :align: center

View File

@@ -16,6 +16,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy`
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.

View File

@@ -18,6 +18,7 @@ class Actor(nn.Module):
        self._max = max_action

    def forward(self, s, **kwargs):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
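The added isinstance guard lets the network accept raw numpy observations coming from the collector as well as tensors that are already on the target device. A minimal self-contained sketch of the same pattern (ToyNet and its sizes are stand-ins for illustration, not code from this commit):

import numpy as np
import torch

class ToyNet(torch.nn.Module):
    """Stand-in module using the same input guard as the patched Actor."""
    def __init__(self):
        super().__init__()
        self.device = 'cpu'
        self.fc = torch.nn.Linear(3, 1)

    def forward(self, s, **kwargs):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        return self.fc(s.view(batch, -1))

net = ToyNet()
print(net(np.zeros((2, 3), dtype=np.float32)).shape)  # numpy input gets converted
print(net(torch.zeros(2, 3)).shape)                    # tensors pass through untouched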

View File

@@ -7,14 +7,14 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import VectorEnv
-from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import SACPolicy, ImitationPolicy

if __name__ == '__main__':
-    from net import ActorProb, Critic
+    from net import Actor, ActorProb, Critic
else:  # pytest
-    from test.continuous.net import ActorProb, Critic
+    from test.continuous.net import Actor, ActorProb, Critic


def get_args():
@@ -24,6 +24,7 @@ def get_args():
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--il-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
@@ -43,7 +44,7 @@ def get_args():
    return args


-def test_sac(args=get_args()):
+def test_sac_with_il(args=get_args()):
    torch.set_num_threads(1)  # we just need only one thread for NN
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
@@ -103,7 +104,6 @@ def test_sac(args=get_args()):
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
    assert stop_fn(result['best_reward'])
-    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
@@ -114,6 +114,31 @@ def test_sac(args=get_args()):
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()

    # here we define an imitation collector with a trivial policy
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -300  # lower the goal
    net = Actor(1, args.state_shape, args.action_shape,
                args.max_action, args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
    il_policy = ImitationPolicy(net, optim)
    il_test_collector = Collector(il_policy, test_envs)
    train_collector.reset()
    result = offpolicy_trainer(
        il_policy, train_collector, il_test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    il_test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(il_policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()
if __name__ == '__main__':
-    test_sac()
+    test_sac_with_il()

View File

@@ -97,12 +97,20 @@ class Collector(object):
                ListReplayBuffer() for _ in range(self.env_num)]
        else:
            raise TypeError('The buffer in data collector is invalid!')
        self.stat_size = stat_size
        self.reset()

    def reset(self):
        """Reset all related variables in the collector."""
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
-        self.step_speed = MovAvg(stat_size)
+        self.step_speed = MovAvg(self.stat_size)
-        self.episode_speed = MovAvg(stat_size)
+        self.episode_speed = MovAvg(self.stat_size)
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0

    def reset_buffer(self):
        """Reset the main data buffer."""

View File

@@ -1,4 +1,5 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.imitation import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
@@ -9,6 +10,7 @@ from tianshou.policy.modelfree.sac import SACPolicy

__all__ = [
    'BasePolicy',
    'ImitationPolicy',
    'DQNPolicy',
    'PGPolicy',
    'A2CPolicy',

View File

@@ -0,0 +1,36 @@
import torch
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class ImitationPolicy(BasePolicy):
    """Implementation of vanilla imitation learning (for continuous action space).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> a)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(self, model, optim):
        super().__init__()
        self.model = model
        self.optim = optim

    def forward(self, batch, state=None):
        a, h = self.model(batch.obs, state=state, info=batch.info)
        return Batch(act=a, state=h)

    def learn(self, batch, **kwargs):
        self.optim.zero_grad()
        a = self(batch).act
        a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
        loss = F.mse_loss(a, a_)
        loss.backward()
        self.optim.step()
        return {'loss': loss.item()}
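For orientation, here is a minimal usage sketch of the new policy as a plain supervised regressor on expert (s, a) pairs. The ImitationPolicy and Batch calls follow the code above; ToyActor and the random data are stand-ins of ours, not part of the commit:

import numpy as np
import torch
from tianshou.data import Batch
from tianshou.policy import ImitationPolicy

class ToyActor(torch.nn.Module):
    """Deterministic s -> a model, the interface ImitationPolicy expects."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = torch.nn.Linear(obs_dim, act_dim)

    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, dtype=torch.float)
        return self.net(s), state

model = ToyActor(obs_dim=3, act_dim=1)
policy = ImitationPolicy(model, torch.optim.Adam(model.parameters(), lr=1e-3))

# one supervised step on a batch of "expert" state-action pairs
expert = Batch(obs=np.random.randn(32, 3).astype(np.float32),
               act=np.random.randn(32, 1).astype(np.float32),
               info={})
print(policy.learn(expert))  # {'loss': ...}

In the test script above, the same role is played by the continuous Actor network, and the (s, a) pairs come from rollouts of the already-trained SAC policy through the shared train_collector.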

View File

@@ -52,6 +52,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
@@ -63,7 +64,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
-                if stop_fn and stop_fn(result['rew']):
+                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)

View File

@@ -56,6 +56,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
@@ -67,7 +68,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
-                if stop_fn and stop_fn(result['rew']):
+                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
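The test_in_train flag added to both trainers only permits the in-training early-stop test when the collecting policy is the one being optimized. In the imitation stage of the SAC test, train_collector still holds the SAC policy while offpolicy_trainer optimizes il_policy, so rewards from those SAC rollouts should not trigger a stop test for the imitation learner. A self-contained toy illustration of the check (the stand-in classes are ours, not Tianshou code):

class StubPolicy:
    """Stand-in for a Tianshou policy object."""

class StubCollector:
    """Stand-in mirroring Collector's public `policy` attribute."""
    def __init__(self, policy):
        self.policy = policy

sac_policy, il_policy = StubPolicy(), StubPolicy()
train_collector = StubCollector(sac_policy)

# SAC stage: the collector rolls out the policy being trained.
print(train_collector.policy == sac_policy)  # True  -> early-stop test allowed
# Imitation stage: the same collector keeps rolling out SAC while il_policy trains.
print(train_collector.policy == il_policy)   # False -> early-stop test skipped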