From 815f3522bbe2a116176082835c7d677326666afa Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Mon, 20 Apr 2020 11:25:20 +0800
Subject: [PATCH] imitation with discrete action space

---
 test/continuous/test_sac_with_il.py        |  2 +-
 .../{test_a2c.py => test_a2c_with_il.py}   | 31 ++++++++++--
 test/discrete/test_pg.py                   |  2 +-
 tianshou/policy/__init__.py                |  2 +-
 tianshou/policy/imitation.py               | 36 -------------
 tianshou/policy/imitation/base.py          | 50 +++++++++++++++++++
 6 files changed, 81 insertions(+), 42 deletions(-)
 rename test/discrete/{test_a2c.py => test_a2c_with_il.py} (76%)
 delete mode 100644 tianshou/policy/imitation.py
 create mode 100644 tianshou/policy/imitation/base.py

diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 0d39203..65d9e47 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -121,7 +121,7 @@ def test_sac_with_il(args=get_args()):
     net = Actor(1, args.state_shape, args.action_shape,
                 args.max_action, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim)
+    il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(il_policy, test_envs)
     train_collector.reset()
     result = offpolicy_trainer(
diff --git a/test/discrete/test_a2c.py b/test/discrete/test_a2c_with_il.py
similarity index 76%
rename from test/discrete/test_a2c.py
rename to test/discrete/test_a2c_with_il.py
index 0c167fa..065c4b6 100644
--- a/test/discrete/test_a2c.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -6,10 +6,10 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.policy import A2CPolicy
 from tianshou.env import VectorEnv
-from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.policy import A2CPolicy, ImitationPolicy
+from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
 
 if __name__ == '__main__':
     from net import Net, Actor, Critic
@@ -23,6 +23,7 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=10)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
@@ -95,7 +96,6 @@ def test_a2c(args=get_args()):
         args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
         writer=writer)
     assert stop_fn(result['best_reward'])
-    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -106,6 +106,31 @@ def test_a2c(args=get_args()):
         print(f'Final reward: {result["rew"]}, length: {result["len"]}')
         collector.close()
 
+    # here we define an imitation collector with a trivial policy
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -300  # lower the goal
+    net = Net(1, args.state_shape, device=args.device)
+    net = Actor(net, args.action_shape).to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
+    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_test_collector = Collector(il_policy, test_envs)
+    train_collector.reset()
+    result = offpolicy_trainer(
+        il_policy, train_collector, il_test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    il_test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(il_policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
 
 if __name__ == '__main__':
     test_a2c()
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 9b9e955..c1ba4be 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -167,5 +167,5 @@ def test_pg(args=get_args()):
 
 
 if __name__ == '__main__':
-    test_fn()
+    # test_fn()
     test_pg()
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 39a861f..37f11e9 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -1,5 +1,5 @@
 from tianshou.policy.base import BasePolicy
-from tianshou.policy.imitation import ImitationPolicy
+from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.modelfree.dqn import DQNPolicy
 from tianshou.policy.modelfree.pg import PGPolicy
 from tianshou.policy.modelfree.a2c import A2CPolicy
diff --git a/tianshou/policy/imitation.py b/tianshou/policy/imitation.py
deleted file mode 100644
index bdea50a..0000000
--- a/tianshou/policy/imitation.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-from tianshou.data import Batch
-from tianshou.policy import BasePolicy
-
-
-class ImitationPolicy(BasePolicy):
-    """Implementation of vanilla imitation learning (for continuous action space).
-
-    :param torch.nn.Module model: a model following the rules in
-        :class:`~tianshou.policy.BasePolicy`. (s -> a)
-    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
-
-    .. seealso::
-
-        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
-        explanation.
-    """
-    def __init__(self, model, optim):
-        super().__init__()
-        self.model = model
-        self.optim = optim
-
-    def forward(self, batch, state=None):
-        a, h = self.model(batch.obs, state=state, info=batch.info)
-        return Batch(act=a, state=h)
-
-    def learn(self, batch, **kwargs):
-        self.optim.zero_grad()
-        a = self(batch).act
-        a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
-        loss = F.mse_loss(a, a_)
-        loss.backward()
-        self.optim.step()
-        return {'loss': loss.item()}
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
new file mode 100644
index 0000000..10bcef6
--- /dev/null
+++ b/tianshou/policy/imitation/base.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class ImitationPolicy(BasePolicy):
+    """Implementation of vanilla imitation learning (for continuous action space).
+
+    :param torch.nn.Module model: a model following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> a)
+    :param torch.optim.Optimizer optim: for optimizing the model.
+    :param str mode: indicate the imitation type ("continuous" or "discrete"
+        action space), defaults to "continuous".
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
+    """
+    def __init__(self, model, optim, mode='continuous'):
+        super().__init__()
+        self.model = model
+        self.optim = optim
+        assert mode in ['continuous', 'discrete'], \
+            f'Mode {mode} is not in ["continuous", "discrete"]'
+        self.mode = mode
+
+    def forward(self, batch, state=None):
+        logits, h = self.model(batch.obs, state=state, info=batch.info)
+        if self.mode == 'discrete':
+            a = logits.max(dim=1)[1]
+        else:
+            a = logits
+        return Batch(logits=logits, act=a, state=h)
+
+    def learn(self, batch, **kwargs):
+        self.optim.zero_grad()
+        if self.mode == 'continuous':
+            a = self(batch).act
+            a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
+            loss = F.mse_loss(a, a_)
+        elif self.mode == 'discrete':  # classification
+            a = self(batch).logits
+            a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
+            loss = F.nll_loss(a, a_)
+        loss.backward()
+        self.optim.step()
+        return {'loss': loss.item()}