import torch
import numpy as np
from torch import nn
from copy import deepcopy
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import PGPolicy


class PPOPolicy(PGPolicy):
    """Implementation of Proximal Policy Optimization (arXiv:1707.06347)
    with the clipped surrogate objective."""

def __init__(self, actor, critic, optim, dist_fn,
discount_factor=0.99,
max_grad_norm=.5,
eps_clip=.2,
vf_coef=.5,
ent_coef=.0,
action_range=None):
        # the actor/critic and optimizer are handled here, so pass None for
        # the model and optimizer expected by PGPolicy
        super().__init__(None, None, dist_fn, discount_factor)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
        self._w_vf = vf_coef
        self._w_ent = ent_coef
        self._range = action_range
        # keep frozen copies of the networks; they provide the old policy for
        # the probability ratio and the old value estimates in learn()
        self.actor, self.actor_old = actor, deepcopy(actor)
        self.actor_old.eval()
        self.critic, self.critic_old = critic, deepcopy(critic)
        self.critic_old.eval()
        self.optim = optim

    def train(self):
        """Switch the actor and critic to training mode."""
        self.training = True
        self.actor.train()
        self.critic.train()

    def eval(self):
        """Switch the actor and critic to evaluation mode."""
        self.training = False
        self.actor.eval()
        self.critic.eval()

    def __call__(self, batch, state=None, model='actor'):
        """Compute the action distribution with either the current or the
        old actor and sample an action from it."""
        model = getattr(self, model)
        logits, h = model(batch.obs, state=state, info=batch.info)
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits)
act = dist.sample()
        if self._range:
            # clip continuous actions into the allowed range
            act = act.clamp(self._range[0], self._range[1])
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def sync_weight(self):
        """Copy the current actor/critic weights into the old networks."""
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic_old.load_state_dict(self.critic.state_dict())

    def learn(self, batch, batch_size=None, repeat=1):
        """Update the actor and the critic on the collected batch with the
        clipped surrogate objective."""
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        # normalize the returns (self._eps guards against a zero std)
        r = batch.returns
        batch.returns = (r - r.mean()) / (r.std() + self._eps)
        batch.act = torch.tensor(batch.act)
        batch.returns = torch.tensor(batch.returns)[:, None]
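        # PPO clipped surrogate objective (Schulman et al., 2017):
        #   L_clip = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
        # where r_t = pi(a_t|s_t) / pi_old(a_t|s_t) is the probability ratio
        # and A_t is the advantage estimated with the old critic.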
        for _ in range(repeat):
            for b in batch.split(batch_size):
                # evaluate the old critic on obs and obs_next in one pass
                vs_old, vs_next_old = self.critic_old(np.concatenate([
                    b.obs, b.obs_next])).split(b.obs.shape[0])
                dist = self(b).dist
                dist_old = self(b, model='actor_old').dist
                # one-step bootstrapped target and advantage from the old critic
                target_v = b.returns.to(vs_next_old.device) \
                    + self._gamma * vs_next_old
                adv = (target_v - vs_old).detach()
                a = b.act.to(adv.device)
                # probability ratio between the current and the old policy
                ratio = torch.exp(dist.log_prob(a) - dist_old.log_prob(a))
                surr1 = ratio * adv
                surr2 = ratio.clamp(
                    1. - self._eps_clip, 1. + self._eps_clip) * adv
                clip_loss = -torch.min(surr1, surr2).mean()
                clip_losses.append(clip_loss.detach().cpu().numpy())
                vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
                vf_losses.append(vf_loss.detach().cpu().numpy())
                # entropy bonus encourages exploration
                e_loss = dist.entropy().mean()
                ent_losses.append(e_loss.detach().cpu().numpy())
                loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
losses.append(loss.detach().cpu().numpy())
self.optim.zero_grad()
loss.backward()
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters())
                    + list(self.critic.parameters()),
                    self._max_grad_norm)
                self.optim.step()
        # refresh the old networks with the updated weights
        self.sync_weight()
return {
'loss': losses,
'loss/clip': clip_losses,
'loss/vf': vf_losses,
'loss/ent': ent_losses,
}
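

# Minimal usage sketch (illustrative, not part of the original module): it
# wires a toy discrete-action actor/critic into PPOPolicy and runs a single
# forward pass. MLPActor / MLPCritic are hypothetical example networks; real
# experiments would plug in task-specific models and a training loop.
if __name__ == '__main__':

    class MLPActor(nn.Module):
        """Toy actor: maps observations to action probabilities."""

        def __init__(self, obs_dim, act_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            # output probabilities so torch.distributions.Categorical can be
            # used directly as dist_fn
            return F.softmax(self.net(obs), dim=-1), state

    class MLPCritic(nn.Module):
        """Toy critic: maps observations to state values."""

        def __init__(self, obs_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 1))

        def forward(self, obs, **kwargs):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.net(obs)

    obs_dim, act_dim = 8, 3
    actor, critic = MLPActor(obs_dim, act_dim), MLPCritic(obs_dim)
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=1e-3)
    policy = PPOPolicy(actor, critic, optim, torch.distributions.Categorical)

    # sample actions for a batch of 4 fake observations
    batch = Batch(obs=np.random.randn(4, obs_dim).astype(np.float32), info={})
    print(policy(batch).act)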