imitation with discrete action space
This commit is contained in:
parent
6bf1ea644d
commit
815f3522bb
@ -121,7 +121,7 @@ def test_sac_with_il(args=get_args()):
|
|||||||
net = Actor(1, args.state_shape, args.action_shape,
|
net = Actor(1, args.state_shape, args.action_shape,
|
||||||
args.max_action, args.device).to(args.device)
|
args.max_action, args.device).to(args.device)
|
||||||
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||||
il_policy = ImitationPolicy(net, optim)
|
il_policy = ImitationPolicy(net, optim, mode='continuous')
|
||||||
il_test_collector = Collector(il_policy, test_envs)
|
il_test_collector = Collector(il_policy, test_envs)
|
||||||
train_collector.reset()
|
train_collector.reset()
|
||||||
result = offpolicy_trainer(
|
result = offpolicy_trainer(
|
||||||
|
@ -6,10 +6,10 @@ import argparse
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.policy import A2CPolicy
|
|
||||||
from tianshou.env import VectorEnv
|
from tianshou.env import VectorEnv
|
||||||
from tianshou.trainer import onpolicy_trainer
|
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
from tianshou.policy import A2CPolicy, ImitationPolicy
|
||||||
|
from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from net import Net, Actor, Critic
|
from net import Net, Actor, Critic
|
||||||
@ -23,6 +23,7 @@ def get_args():
|
|||||||
parser.add_argument('--seed', type=int, default=1626)
|
parser.add_argument('--seed', type=int, default=1626)
|
||||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
parser.add_argument('--lr', type=float, default=3e-4)
|
parser.add_argument('--lr', type=float, default=3e-4)
|
||||||
|
parser.add_argument('--il-lr', type=float, default=1e-3)
|
||||||
parser.add_argument('--gamma', type=float, default=0.9)
|
parser.add_argument('--gamma', type=float, default=0.9)
|
||||||
parser.add_argument('--epoch', type=int, default=10)
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||||
@ -95,7 +96,6 @@ def test_a2c(args=get_args()):
|
|||||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||||
writer=writer)
|
writer=writer)
|
||||||
assert stop_fn(result['best_reward'])
|
assert stop_fn(result['best_reward'])
|
||||||
train_collector.close()
|
|
||||||
test_collector.close()
|
test_collector.close()
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pprint.pprint(result)
|
pprint.pprint(result)
|
||||||
@ -106,6 +106,31 @@ def test_a2c(args=get_args()):
|
|||||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
collector.close()
|
collector.close()
|
||||||
|
|
||||||
|
# here we define an imitation collector with a trivial policy
|
||||||
|
if args.task == 'Pendulum-v0':
|
||||||
|
env.spec.reward_threshold = -300 # lower the goal
|
||||||
|
net = Net(1, args.state_shape, device=args.device)
|
||||||
|
net = Actor(net, args.action_shape).to(args.device)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||||
|
il_policy = ImitationPolicy(net, optim, mode='discrete')
|
||||||
|
il_test_collector = Collector(il_policy, test_envs)
|
||||||
|
train_collector.reset()
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
il_policy, train_collector, il_test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||||
|
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
train_collector.close()
|
||||||
|
il_test_collector.close()
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
collector = Collector(il_policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
|
collector.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_a2c()
|
test_a2c()
|
@ -167,5 +167,5 @@ def test_pg(args=get_args()):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_fn()
|
# test_fn()
|
||||||
test_pg()
|
test_pg()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from tianshou.policy.base import BasePolicy
|
from tianshou.policy.base import BasePolicy
|
||||||
from tianshou.policy.imitation import ImitationPolicy
|
from tianshou.policy.imitation.base import ImitationPolicy
|
||||||
from tianshou.policy.modelfree.dqn import DQNPolicy
|
from tianshou.policy.modelfree.dqn import DQNPolicy
|
||||||
from tianshou.policy.modelfree.pg import PGPolicy
|
from tianshou.policy.modelfree.pg import PGPolicy
|
||||||
from tianshou.policy.modelfree.a2c import A2CPolicy
|
from tianshou.policy.modelfree.a2c import A2CPolicy
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from tianshou.data import Batch
|
|
||||||
from tianshou.policy import BasePolicy
|
|
||||||
|
|
||||||
|
|
||||||
class ImitationPolicy(BasePolicy):
|
|
||||||
"""Implementation of vanilla imitation learning (for continuous action space).
|
|
||||||
|
|
||||||
:param torch.nn.Module model: a model following the rules in
|
|
||||||
:class:`~tianshou.policy.BasePolicy`. (s -> a)
|
|
||||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
|
||||||
|
|
||||||
.. seealso::
|
|
||||||
|
|
||||||
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
|
||||||
explanation.
|
|
||||||
"""
|
|
||||||
def __init__(self, model, optim):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
self.optim = optim
|
|
||||||
|
|
||||||
def forward(self, batch, state=None):
|
|
||||||
a, h = self.model(batch.obs, state=state, info=batch.info)
|
|
||||||
return Batch(act=a, state=h)
|
|
||||||
|
|
||||||
def learn(self, batch, **kwargs):
|
|
||||||
self.optim.zero_grad()
|
|
||||||
a = self(batch).act
|
|
||||||
a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
|
|
||||||
loss = F.mse_loss(a, a_)
|
|
||||||
loss.backward()
|
|
||||||
self.optim.step()
|
|
||||||
return {'loss': loss.item()}
|
|
50
tianshou/policy/imitation/base.py
Normal file
50
tianshou/policy/imitation/base.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tianshou.data import Batch
|
||||||
|
from tianshou.policy import BasePolicy
|
||||||
|
|
||||||
|
|
||||||
|
class ImitationPolicy(BasePolicy):
|
||||||
|
"""Implementation of vanilla imitation learning (for continuous action space).
|
||||||
|
|
||||||
|
:param torch.nn.Module model: a model following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> a)
|
||||||
|
:param torch.optim.Optimizer optim: for optimizing the model.
|
||||||
|
:param str mode: indicate the imitation type ("continuous" or "discrete"
|
||||||
|
action space), defaults to "continuous".
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, optim, mode='continuous'):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.optim = optim
|
||||||
|
assert mode in ['continuous', 'discrete'], \
|
||||||
|
f'Mode {mode} is not in ["continuous", "discrete"]'
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def forward(self, batch, state=None):
|
||||||
|
logits, h = self.model(batch.obs, state=state, info=batch.info)
|
||||||
|
if self.mode == 'discrete':
|
||||||
|
a = logits.max(dim=1)[1]
|
||||||
|
else:
|
||||||
|
a = logits
|
||||||
|
return Batch(logits=logits, act=a, state=h)
|
||||||
|
|
||||||
|
def learn(self, batch, **kwargs):
|
||||||
|
self.optim.zero_grad()
|
||||||
|
if self.mode == 'continuous':
|
||||||
|
a = self(batch).act
|
||||||
|
a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
|
||||||
|
loss = F.mse_loss(a, a_)
|
||||||
|
elif self.mode == 'discrete': # classification
|
||||||
|
a = self(batch).logits
|
||||||
|
a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
|
||||||
|
loss = F.nll_loss(a, a_)
|
||||||
|
loss.backward()
|
||||||
|
self.optim.step()
|
||||||
|
return {'loss': loss.item()}
|
Loading…
x
Reference in New Issue
Block a user