2020-03-14 21:48:31 +08:00

34 lines
968 B
Python

import torch
from torch import nn
from copy import deepcopy
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
    """Policy wrapper around a Q-network, with an optional target network.

    The constructor stores the online network as ``self.model`` and, when
    ``use_target_network`` is true, keeps a deep copy of it as
    ``self.model_old``. ``sync_weights`` copies the online network's
    parameters into that copy.

    ``act`` and ``process_fn`` are placeholders in this version: ``act``
    returns an empty ``Batch`` and ``process_fn`` returns its input batch
    unchanged.
    """

    def __init__(self, model, discount_factor: float = 0.99,
                 estimation_step: int = 1,
                 use_target_network: bool = True) -> None:
        """Store the network and hyper-parameters.

        :param model: the online network. It must support ``deepcopy`` and
            expose ``parameters()`` (both are used by this class).
        :param discount_factor: stored as ``self._gamma``; not read by any
            method in this class yet.
        :param estimation_step: stored as ``self._n_step``; not read by any
            method in this class yet.
        :param use_target_network: when true, a deep copy of ``model`` is
            kept as ``self.model_old``.
        """
        # NOTE(review): assigning ``self.model`` below presumably relies on
        # ``nn.Module.__init__`` having run through this cooperative
        # ``super().__init__()`` call -- confirm that ``BasePolicy.__init__``
        # forwards to ``super().__init__()``.
        super().__init__()
        self.model = model
        self._gamma = discount_factor
        self._n_step = estimation_step
        self._target = use_target_network
        if use_target_network:
            self.model_old = deepcopy(self.model)

    def act(self, batch, hidden_state=None):
        """Return an empty ``Batch``.

        Placeholder: neither ``batch`` nor ``hidden_state`` is read.
        """
        batch_result = Batch()
        return batch_result

    def sync_weights(self) -> None:
        """Copy the online network's parameters into the target network.

        Does nothing when the policy was built with
        ``use_target_network=False`` (in that case ``self.model_old`` does
        not exist). Only tensors yielded by ``parameters()`` are copied.
        """
        # The flag is stored as ``self._target`` in ``__init__``; the
        # previous check of ``self._use_target_network`` referenced an
        # attribute that this class never sets.
        if not self._target:
            return
        for old, new in zip(
                self.model_old.parameters(), self.model.parameters()):
            old.data.copy_(new.data)

    def process_fn(self, batch, buffer, indice):
        """Return ``batch`` unchanged.

        Placeholder: ``buffer`` and ``indice`` are accepted but not read.
        """
        return batch