97 lines
3.3 KiB
Python
97 lines
3.3 KiB
Python
import torch
|
|
import numpy as np
|
|
from copy import deepcopy
|
|
import torch.nn.functional as F
|
|
|
|
from tianshou.data import Batch
|
|
from tianshou.policy import BasePolicy
|
|
|
|
|
|
class DQNPolicy(BasePolicy):
    """Q-learning policy with epsilon-greedy action selection.

    The policy wraps a Q-value model and its optimizer. Training targets
    are n-step returns computed in :meth:`process_fn`. When a target
    network is enabled, the bootstrap value is
    ``Q_old(s', argmax_a Q(s', a))``, where ``Q_old`` is a deep copy of the
    model that is kept in eval mode and refreshed only by
    :meth:`sync_weight`. Otherwise the bootstrap value is
    ``max_a Q(s', a)`` from the model itself.

    :param model: callable invoked as
        ``model(obs, state=state, info=batch.info)`` and returning a
        ``(q_values, hidden_state)`` pair.
    :param optim: optimizer; ``zero_grad()`` and ``step()`` are called on
        it in :meth:`learn`.
    :param discount_factor: discount ``gamma``, must lie in ``(0, 1]``.
    :param estimation_step: number of steps ``n`` used for the n-step
        return, must be greater than 0.
    :param use_target_network: whether to keep the separate ``model_old``
        copy for computing bootstrap values.
    :raises AssertionError: if ``discount_factor`` or ``estimation_step``
        is out of range.
    """

    def __init__(self, model, optim,
                 discount_factor=0.99,
                 estimation_step=1,
                 use_target_network=True):
        super().__init__()
        self.model = model
        self.optim = optim
        # Exploration rate; 0 means purely greedy until set_eps is called.
        self.eps = 0
        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
        self._gamma = discount_factor
        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step
        self._target = use_target_network
        if use_target_network:
            # Independent copy of the model; only sync_weight updates it.
            self.model_old = deepcopy(self.model)
            self.model_old.eval()

    def set_eps(self, eps):
        """Set the default exploration rate used by :meth:`__call__`."""
        self.eps = eps

    def train(self):
        """Mark the policy as training and switch the model to train mode."""
        self.training = True
        self.model.train()

    def eval(self):
        """Mark the policy as evaluating and switch the model to eval mode."""
        self.training = False
        self.model.eval()

    def sync_weight(self):
        """Copy the model's weights into ``model_old``.

        Does nothing when the policy was built without a target network.
        """
        if self._target:
            self.model_old.load_state_dict(self.model.state_dict())

    def process_fn(self, batch, buffer, indice):
        """Attach n-step return targets to ``batch`` under the key ``returns``.

        For every sampled index the discounted rewards of up to
        ``estimation_step`` consecutive buffer entries are accumulated,
        stopping at the first entry flagged as done. If no done flag is
        met, the discounted bootstrap Q-value of the last entry's
        ``obs_next`` is added.

        :param batch: sampled batch; updated in place and returned.
        :param buffer: replay buffer providing ``done``, ``rew``,
            ``len()`` and index access. Indices wrap around modulo
            ``len(buffer)``.
        :param indice: buffer positions of the sampled transitions
            (assumed to be a NumPy integer array -- TODO confirm with the
            caller).
        :return: the same ``batch`` with ``returns`` set.
        """
        buf_len = len(buffer)
        # Use an explicit float accumulator. Inheriting the integer dtype of
        # `indice` makes the final in-place add fail when both the rewards
        # and the discount factor are integers.
        returns = np.zeros_like(indice, dtype=np.float64)
        # gammas[i] is the step offset at which the episode ended, or
        # n_step when no done flag was met inside the window.
        gammas = np.zeros_like(indice) + self._n_step
        # Walk the window backwards so that the earliest done flag wins.
        for n in range(self._n_step - 1, -1, -1):
            now = (indice + n) % buf_len
            done = buffer.done[now] > 0
            gammas[done] = n
            # Rewards after a terminal step belong to another episode.
            returns[done] = 0
            returns = buffer.rew[now] + self._gamma * returns

        terminal = (indice + self._n_step - 1) % buf_len
        terminal_batch = buffer[terminal]
        # Bootstrap values are constants for the loss, so no autograd graph
        # is needed for these forward passes.
        with torch.no_grad():
            if self._target:
                # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
                a = self(terminal_batch, input='obs_next', eps=0).act
                target_q = self(
                    terminal_batch, model='model_old',
                    input='obs_next').logits
            else:
                target_q = self(terminal_batch, input='obs_next').logits
        if isinstance(target_q, torch.Tensor):
            target_q = target_q.detach().cpu().numpy()
        if self._target:
            target_q = target_q[np.arange(len(a)), a]
        else:
            target_q = target_q.max(axis=1)
        # No bootstrapping past the end of an episode.
        target_q[gammas != self._n_step] = 0
        returns += (self._gamma ** gammas) * target_q
        batch.update(returns=returns)
        return batch

    def __call__(self, batch, state=None,
                 model='model', input='obs', eps=None):
        """Compute Q-values and epsilon-greedy actions for ``batch``.

        :param batch: data holding the observation attribute named by
            ``input`` and an ``info`` attribute.
        :param state: optional state forwarded to the model.
        :param model: name of the attribute holding the network to
            evaluate (``'model'`` or ``'model_old'``).
        :param input: name of the batch attribute used as observation.
        :param eps: exploration rate; ``None`` falls back to ``self.eps``.
        :return: a ``Batch`` with ``logits`` (the model's Q-values),
            ``act`` (chosen action indices as a NumPy array) and ``state``
            (the second value returned by the model).
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        q, h = model(obs, state=state, info=batch.info)
        # Greedy action: index of the largest Q-value along dim 1.
        act = q.max(dim=1)[1].detach().cpu().numpy()
        # add eps to act
        if eps is None:
            eps = self.eps
        # With probability eps, replace each action by a uniformly random one.
        for i in range(len(q)):
            if np.random.rand() < eps:
                act[i] = np.random.randint(q.shape[1])
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch, batch_size=None):
        """Run one optimizer step on the MSE between Q(s, a) and the returns.

        :param batch: data with ``act`` and the ``returns`` produced by
            :meth:`process_fn`, plus whatever :meth:`__call__` reads.
        :param batch_size: accepted but not used by this method.
        :return: dict with key ``'loss'`` holding the loss as a NumPy value.
        """
        self.optim.zero_grad()
        q = self(batch).logits
        # Keep only the Q-value of the action stored in the batch.
        q = q[np.arange(len(q)), batch.act]
        r = batch.returns
        if isinstance(r, np.ndarray):
            r = torch.tensor(r, device=q.device, dtype=q.dtype)
        loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        return {'loss': loss.detach().cpu().numpy()}
|