imitation with discrete action space

Trinkle23897 2020-04-20 11:25:20 +08:00
parent 6bf1ea644d
commit 815f3522bb
6 changed files with 81 additions and 42 deletions
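Before the per-file diff, here is a minimal sketch (not part of the commit) of how the updated ImitationPolicy is expected to be constructed after this change. The MLPActor wrapper, its layer sizes, and the learning rate are hypothetical; only the ImitationPolicy(model, optim, mode=...) signature and the (obs, state, info) -> (output, state) model convention come from the diff below.

import torch
import torch.nn.functional as F
from torch import nn

from tianshou.policy import ImitationPolicy


class MLPActor(nn.Module):
    """Hypothetical model following the (obs, state, info) -> (output, state)
    convention that ImitationPolicy.forward expects from the wrapped module."""

    def __init__(self, obs_dim=4, act_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128), nn.ReLU(),
            nn.Linear(128, act_dim),
        )

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        # output log-probabilities, since the discrete branch of learn()
        # in this commit uses F.nll_loss
        return F.log_softmax(self.net(obs), dim=-1), state


model = MLPActor()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# new in this commit: mode='discrete' trains by classification on expert
# action labels; the default mode='continuous' keeps the old MSE loss
il_policy = ImitationPolicy(model, optim, mode='discrete')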

View File

@@ -121,7 +121,7 @@ def test_sac_with_il(args=get_args()):
     net = Actor(1, args.state_shape, args.action_shape,
                 args.max_action, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim)
+    il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(il_policy, test_envs)
     train_collector.reset()
     result = offpolicy_trainer(

View File

@@ -6,10 +6,10 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.policy import A2CPolicy
 from tianshou.env import VectorEnv
-from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.policy import A2CPolicy, ImitationPolicy
+from tianshou.trainer import onpolicy_trainer, offpolicy_trainer

 if __name__ == '__main__':
     from net import Net, Actor, Critic
@@ -23,6 +23,7 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
@@ -95,7 +96,6 @@ def test_a2c(args=get_args()):
         args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
         writer=writer)
     assert stop_fn(result['best_reward'])
-    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -106,6 +106,31 @@ def test_a2c(args=get_args()):
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()

+    # here we define an imitation collector with a trivial policy
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -300  # lower the goal
+    net = Net(1, args.state_shape, device=args.device)
+    net = Actor(net, args.action_shape).to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
+    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_test_collector = Collector(il_policy, test_envs)
+    train_collector.reset()
+    result = offpolicy_trainer(
+        il_policy, train_collector, il_test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    il_test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(il_policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()

 if __name__ == '__main__':
     test_a2c()

View File

@@ -167,5 +167,5 @@ def test_pg(args=get_args()):
 if __name__ == '__main__':
-    test_fn()
+    # test_fn()
     test_pg()

View File

@@ -1,5 +1,5 @@
 from tianshou.policy.base import BasePolicy
-from tianshou.policy.imitation import ImitationPolicy
+from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.modelfree.dqn import DQNPolicy
 from tianshou.policy.modelfree.pg import PGPolicy
 from tianshou.policy.modelfree.a2c import A2CPolicy

View File

@@ -1,36 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-from tianshou.data import Batch
-from tianshou.policy import BasePolicy
-
-
-class ImitationPolicy(BasePolicy):
-    """Implementation of vanilla imitation learning (for continuous action space).
-
-    :param torch.nn.Module model: a model following the rules in
-        :class:`~tianshou.policy.BasePolicy`. (s -> a)
-    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
-
-    .. seealso::
-
-        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
-        explanation.
-    """
-
-    def __init__(self, model, optim):
-        super().__init__()
-        self.model = model
-        self.optim = optim
-
-    def forward(self, batch, state=None):
-        a, h = self.model(batch.obs, state=state, info=batch.info)
-        return Batch(act=a, state=h)
-
-    def learn(self, batch, **kwargs):
-        self.optim.zero_grad()
-        a = self(batch).act
-        a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
-        loss = F.mse_loss(a, a_)
-        loss.backward()
-        self.optim.step()
-        return {'loss': loss.item()}

View File

@@ -0,0 +1,50 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class ImitationPolicy(BasePolicy):
+    """Implementation of vanilla imitation learning (for both continuous and
+    discrete action spaces).
+
+    :param torch.nn.Module model: a model following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> a)
+    :param torch.optim.Optimizer optim: an optimizer for optimizing the model.
+    :param str mode: indicates the action space type, either "continuous" or
+        "discrete"; defaults to "continuous".
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
+    """
+
+    def __init__(self, model, optim, mode='continuous'):
+        super().__init__()
+        self.model = model
+        self.optim = optim
+        assert mode in ['continuous', 'discrete'], \
+            f'Mode {mode} is not in ["continuous", "discrete"]'
+        self.mode = mode
+
+    def forward(self, batch, state=None):
+        logits, h = self.model(batch.obs, state=state, info=batch.info)
+        if self.mode == 'discrete':
+            a = logits.max(dim=1)[1]
+        else:
+            a = logits
+        return Batch(logits=logits, act=a, state=h)
+
+    def learn(self, batch, **kwargs):
+        self.optim.zero_grad()
+        if self.mode == 'continuous':
+            a = self(batch).act
+            a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
+            loss = F.mse_loss(a, a_)
+        elif self.mode == 'discrete':  # classification
+            a = self(batch).logits
+            a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
+            loss = F.nll_loss(a, a_)
+        loss.backward()
+        self.optim.step()
+        return {'loss': loss.item()}
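A side note on the discrete branch above (not part of the diff): F.nll_loss expects log-probabilities rather than raw scores, so the wrapped model is assumed to end in a log-softmax head. With raw logits, F.cross_entropy would be the equivalent single call, as this small self-contained check illustrates:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 3)              # batch of 8 observations, 3 actions
labels = torch.randint(0, 3, (8,))      # expert action labels
# nll_loss on log-softmax output equals cross_entropy on raw logits
nll = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
ce = F.cross_entropy(logits, labels)
assert torch.allclose(nll, ce)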