vanilla imitation learning
This commit is contained in:
parent
befdfb07e8
commit
7b65d43394
@ -25,6 +25,7 @@
|
|||||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
|
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
|
||||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||||
|
- Vanilla Imitation Learning
|
||||||
|
|
||||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
|
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. Our team is working on supporting more algorithms and more scenarios on Tianshou in this period of development.
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ This command will run automatic tests in the main directory
|
|||||||
Test by GitHub Actions
|
Test by GitHub Actions
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
1. Click the `Actions` button in your own repo:
|
1. Click the ``Actions`` button in your own repo:
|
||||||
|
|
||||||
.. image:: _static/images/action1.jpg
|
.. image:: _static/images/action1.jpg
|
||||||
:align: center
|
:align: center
|
||||||
|
@ -16,6 +16,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||||
|
* :class:`~tianshou.policy.ImitationPolicy`
|
||||||
|
|
||||||
|
|
||||||
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms.
|
||||||
|
@ -18,7 +18,8 @@ class Actor(nn.Module):
|
|||||||
self._max = max_action
|
self._max = max_action
|
||||||
|
|
||||||
def forward(self, s, **kwargs):
|
def forward(self, s, **kwargs):
|
||||||
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
if not isinstance(s, torch.Tensor):
|
||||||
|
s = torch.tensor(s, device=self.device, dtype=torch.float)
|
||||||
batch = s.shape[0]
|
batch = s.shape[0]
|
||||||
s = s.view(batch, -1)
|
s = s.view(batch, -1)
|
||||||
logits = self.model(s)
|
logits = self.model(s)
|
||||||
|
@ -7,14 +7,14 @@ import numpy as np
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.env import VectorEnv
|
from tianshou.env import VectorEnv
|
||||||
from tianshou.policy import SACPolicy
|
|
||||||
from tianshou.trainer import offpolicy_trainer
|
from tianshou.trainer import offpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
from tianshou.policy import SACPolicy, ImitationPolicy
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from net import ActorProb, Critic
|
from net import Actor, ActorProb, Critic
|
||||||
else: # pytest
|
else: # pytest
|
||||||
from test.continuous.net import ActorProb, Critic
|
from test.continuous.net import Actor, ActorProb, Critic
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -24,6 +24,7 @@ def get_args():
|
|||||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--il-lr', type=float, default=1e-3)
|
||||||
parser.add_argument('--gamma', type=float, default=0.99)
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
parser.add_argument('--tau', type=float, default=0.005)
|
parser.add_argument('--tau', type=float, default=0.005)
|
||||||
parser.add_argument('--alpha', type=float, default=0.2)
|
parser.add_argument('--alpha', type=float, default=0.2)
|
||||||
@ -43,7 +44,7 @@ def get_args():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def test_sac(args=get_args()):
|
def test_sac_with_il(args=get_args()):
|
||||||
torch.set_num_threads(1) # we just need only one thread for NN
|
torch.set_num_threads(1) # we just need only one thread for NN
|
||||||
env = gym.make(args.task)
|
env = gym.make(args.task)
|
||||||
if args.task == 'Pendulum-v0':
|
if args.task == 'Pendulum-v0':
|
||||||
@ -103,7 +104,6 @@ def test_sac(args=get_args()):
|
|||||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
train_collector.close()
|
|
||||||
test_collector.close()
|
test_collector.close()
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
@ -114,6 +114,31 @@ def test_sac(args=get_args()):
|
|||||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
collector.close()
|
collector.close()
|
||||||
|
|
||||||
|
# here we define an imitation collector with a trivial policy
|
||||||
|
if args.task == 'Pendulum-v0':
|
||||||
|
env.spec.reward_threshold = -300 # lower the goal
|
||||||
|
net = Actor(1, args.state_shape, args.action_shape,
|
||||||
|
args.max_action, args.device).to(args.device)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||||
|
il_policy = ImitationPolicy(net, optim)
|
||||||
|
il_test_collector = Collector(il_policy, test_envs)
|
||||||
|
train_collector.reset()
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
il_policy, train_collector, il_test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||||
|
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
train_collector.close()
|
||||||
|
il_test_collector.close()
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
collector = Collector(il_policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
|
collector.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_sac()
|
test_sac_with_il()
|
@ -97,12 +97,20 @@ class Collector(object):
|
|||||||
ListReplayBuffer() for _ in range(self.env_num)]
|
ListReplayBuffer() for _ in range(self.env_num)]
|
||||||
else:
|
else:
|
||||||
raise TypeError('The buffer in data collector is invalid!')
|
raise TypeError('The buffer in data collector is invalid!')
|
||||||
|
self.stat_size = stat_size
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset all related variables in the collector."""
|
||||||
self.reset_env()
|
self.reset_env()
|
||||||
self.reset_buffer()
|
self.reset_buffer()
|
||||||
# state over batch is either a list, an np.ndarray, or a torch.Tensor
|
# state over batch is either a list, an np.ndarray, or a torch.Tensor
|
||||||
self.state = None
|
self.state = None
|
||||||
self.step_speed = MovAvg(stat_size)
|
self.step_speed = MovAvg(self.stat_size)
|
||||||
self.episode_speed = MovAvg(stat_size)
|
self.episode_speed = MovAvg(self.stat_size)
|
||||||
|
self.collect_step = 0
|
||||||
|
self.collect_episode = 0
|
||||||
|
self.collect_time = 0
|
||||||
|
|
||||||
def reset_buffer(self):
|
def reset_buffer(self):
|
||||||
"""Reset the main data buffer."""
|
"""Reset the main data buffer."""
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from tianshou.policy.base import BasePolicy
|
from tianshou.policy.base import BasePolicy
|
||||||
|
from tianshou.policy.imitation import ImitationPolicy
|
||||||
from tianshou.policy.modelfree.dqn import DQNPolicy
|
from tianshou.policy.modelfree.dqn import DQNPolicy
|
||||||
from tianshou.policy.modelfree.pg import PGPolicy
|
from tianshou.policy.modelfree.pg import PGPolicy
|
||||||
from tianshou.policy.modelfree.a2c import A2CPolicy
|
from tianshou.policy.modelfree.a2c import A2CPolicy
|
||||||
@ -9,6 +10,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BasePolicy',
|
'BasePolicy',
|
||||||
|
'ImitationPolicy',
|
||||||
'DQNPolicy',
|
'DQNPolicy',
|
||||||
'PGPolicy',
|
'PGPolicy',
|
||||||
'A2CPolicy',
|
'A2CPolicy',
|
||||||
|
36
tianshou/policy/imitation.py
Normal file
36
tianshou/policy/imitation.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tianshou.data import Batch
|
||||||
|
from tianshou.policy import BasePolicy
|
||||||
|
|
||||||
|
|
||||||
|
class ImitationPolicy(BasePolicy):
|
||||||
|
"""Implementation of vanilla imitation learning (for continuous action space).
|
||||||
|
|
||||||
|
:param torch.nn.Module model: a model following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> a)
|
||||||
|
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, optim):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.optim = optim
|
||||||
|
|
||||||
|
def forward(self, batch, state=None):
|
||||||
|
a, h = self.model(batch.obs, state=state, info=batch.info)
|
||||||
|
return Batch(act=a, state=h)
|
||||||
|
|
||||||
|
def learn(self, batch, **kwargs):
|
||||||
|
self.optim.zero_grad()
|
||||||
|
a = self(batch).act
|
||||||
|
a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
|
||||||
|
loss = F.mse_loss(a, a_)
|
||||||
|
loss.backward()
|
||||||
|
self.optim.step()
|
||||||
|
return {'loss': loss.item()}
|
@ -52,6 +52,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
|||||||
best_epoch, best_reward = -1, -1
|
best_epoch, best_reward = -1, -1
|
||||||
stat = {}
|
stat = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
test_in_train = train_collector.policy == policy
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
# train
|
# train
|
||||||
policy.train()
|
policy.train()
|
||||||
@ -63,7 +64,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
|||||||
result = train_collector.collect(n_step=collect_per_step,
|
result = train_collector.collect(n_step=collect_per_step,
|
||||||
log_fn=log_fn)
|
log_fn=log_fn)
|
||||||
data = {}
|
data = {}
|
||||||
if stop_fn and stop_fn(result['rew']):
|
if test_in_train and stop_fn and stop_fn(result['rew']):
|
||||||
test_result = test_episode(
|
test_result = test_episode(
|
||||||
policy, test_collector, test_fn,
|
policy, test_collector, test_fn,
|
||||||
epoch, episode_per_test)
|
epoch, episode_per_test)
|
||||||
|
@ -56,6 +56,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
|||||||
best_epoch, best_reward = -1, -1
|
best_epoch, best_reward = -1, -1
|
||||||
stat = {}
|
stat = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
test_in_train = train_collector.policy == policy
|
||||||
for epoch in range(1, 1 + max_epoch):
|
for epoch in range(1, 1 + max_epoch):
|
||||||
# train
|
# train
|
||||||
policy.train()
|
policy.train()
|
||||||
@ -67,7 +68,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
|
|||||||
result = train_collector.collect(n_episode=collect_per_step,
|
result = train_collector.collect(n_episode=collect_per_step,
|
||||||
log_fn=log_fn)
|
log_fn=log_fn)
|
||||||
data = {}
|
data = {}
|
||||||
if stop_fn and stop_fn(result['rew']):
|
if test_in_train and stop_fn and stop_fn(result['rew']):
|
||||||
test_result = test_episode(
|
test_result = test_episode(
|
||||||
policy, test_collector, test_fn,
|
policy, test_collector, test_fn,
|
||||||
epoch, episode_per_test)
|
epoch, episode_per_test)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user