vanilla imitation learning

This commit is contained in:
Trinkle23897 2020-04-13 19:37:27 +08:00
parent befdfb07e8
commit 7b65d43394
10 changed files with 88 additions and 12 deletions

View File

@ -25,6 +25,7 @@
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- Vanilla Imitation Learning
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.

View File

@ -43,7 +43,7 @@ This command will run automatic tests in the main directory
Test by GitHub Actions
----------------------
1. Click the `Actions` button in your own repo:
1. Click the ``Actions`` button in your own repo:
.. image:: _static/images/action1.jpg
:align: center

View File

@ -16,6 +16,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy`
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.

View File

@ -18,7 +18,8 @@ class Actor(nn.Module):
self._max = max_action
def forward(self, s, **kwargs):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)

View File

@ -7,14 +7,14 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import SACPolicy, ImitationPolicy
if __name__ == '__main__':
from net import ActorProb, Critic
from net import Actor, ActorProb, Critic
else: # pytest
from test.continuous.net import ActorProb, Critic
from test.continuous.net import Actor, ActorProb, Critic
def get_args():
@ -24,6 +24,7 @@ def get_args():
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--il-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
@ -43,7 +44,7 @@ def get_args():
return args
def test_sac(args=get_args()):
def test_sac_with_il(args=get_args()):
torch.set_num_threads(1) # we just need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
@ -103,7 +104,6 @@ def test_sac(args=get_args()):
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
@ -114,6 +114,31 @@ def test_sac(args=get_args()):
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
# here we define an imitation collector with a trivial policy
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -300 # lower the goal
net = Actor(1, args.state_shape, args.action_shape,
args.max_action, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim)
il_test_collector = Collector(il_policy, test_envs)
train_collector.reset()
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
il_test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(il_policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
if __name__ == '__main__':
test_sac()
test_sac_with_il()

View File

@ -97,12 +97,20 @@ class Collector(object):
ListReplayBuffer() for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self.stat_size = stat_size
self.reset()
def reset(self):
"""Reset all related variables in the collector."""
self.reset_env()
self.reset_buffer()
# state over batch is either a list, an np.ndarray, or a torch.Tensor
self.state = None
self.step_speed = MovAvg(stat_size)
self.episode_speed = MovAvg(stat_size)
self.step_speed = MovAvg(self.stat_size)
self.episode_speed = MovAvg(self.stat_size)
self.collect_step = 0
self.collect_episode = 0
self.collect_time = 0
def reset_buffer(self):
"""Reset the main data buffer."""

View File

@ -1,4 +1,5 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.imitation import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
@ -9,6 +10,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
__all__ = [
'BasePolicy',
'ImitationPolicy',
'DQNPolicy',
'PGPolicy',
'A2CPolicy',

View File

@ -0,0 +1,36 @@
import torch
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class ImitationPolicy(BasePolicy):
"""Implementation of vanilla imitation learning (for continuous action space).
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> a)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, model, optim):
super().__init__()
self.model = model
self.optim = optim
def forward(self, batch, state=None):
a, h = self.model(batch.obs, state=state, info=batch.info)
return Batch(act=a, state=h)
def learn(self, batch, **kwargs):
self.optim.zero_grad()
a = self(batch).act
a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
loss = F.mse_loss(a, a_)
loss.backward()
self.optim.step()
return {'loss': loss.item()}

View File

@ -52,6 +52,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
best_epoch, best_reward = -1, -1
stat = {}
start_time = time.time()
test_in_train = train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
@ -63,7 +64,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
result = train_collector.collect(n_step=collect_per_step,
log_fn=log_fn)
data = {}
if stop_fn and stop_fn(result['rew']):
if test_in_train and stop_fn and stop_fn(result['rew']):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test)

View File

@ -56,6 +56,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
best_epoch, best_reward = -1, -1
stat = {}
start_time = time.time()
test_in_train = train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
@ -67,7 +68,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
result = train_collector.collect(n_episode=collect_per_step,
log_fn=log_fn)
data = {}
if stop_fn and stop_fn(result['rew']):
if test_in_train and stop_fn and stop_fn(result['rew']):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test)