34 lines
968 B
Python
34 lines
968 B
Python
import torch
|
|
from torch import nn
|
|
from copy import deepcopy
|
|
|
|
from tianshou.data import Batch
|
|
from tianshou.policy import BasePolicy
|
|
|
|
|
|
class DQNPolicy(BasePolicy, nn.Module):
    """Deep Q-Network policy (work in progress).

    Wraps a Q-value ``model`` and, optionally, a frozen copy of it
    (``model_old``) used as the target network. ``act`` and
    ``process_fn`` are currently placeholders.
    """

    def __init__(self, model, discount_factor=0.99, estimation_step=1,
                 use_target_network=True):
        """Store the model and DQN hyper-parameters.

        :param model: the network to train; presumably maps observations
            to per-action Q-values (TODO confirm against callers).
        :param discount_factor: stored as ``self._gamma``; not yet used by
            any method in this class.
        :param estimation_step: stored as ``self._n_step``; not yet used by
            any method in this class.
        :param use_target_network: when true, a deep copy of ``model`` is
            kept as ``self.model_old`` and refreshed by ``sync_weights``.
        """
        super().__init__()
        self.model = model
        self._gamma = discount_factor
        self._n_step = estimation_step
        # ``sync_weights`` reads ``_use_target_network``; it was previously
        # never assigned, so every sync raised AttributeError. ``_target``
        # is kept as an alias for any existing code that reads it.
        self._use_target_network = use_target_network
        self._target = use_target_network
        if use_target_network:
            # Independent copy: later updates to ``self.model`` do not
            # affect ``self.model_old`` until ``sync_weights`` is called.
            self.model_old = deepcopy(self.model)

    def act(self, batch, hidden_state=None):
        """Return an empty ``Batch`` (placeholder).

        ``batch`` and ``hidden_state`` are accepted but not used yet.
        TODO: compute Q-values with ``self.model`` and select actions.
        """
        return Batch()

    def sync_weights(self):
        """Copy the current model's parameters into the target network.

        No-op when the policy was built with ``use_target_network=False``
        (in that case ``self.model_old`` does not exist). Only parameters
        are copied; module buffers are not.
        """
        if self._use_target_network:
            for old, new in zip(
                    self.model_old.parameters(), self.model.parameters()):
                # In-place copy on ``.data`` so autograd does not track it.
                old.data.copy_(new.data)

    def process_fn(self, batch, buffer, indice):
        """Return ``batch`` unchanged (placeholder).

        ``buffer`` and ``indice`` are accepted but not used yet.
        TODO: compute n-step return targets using ``self._gamma`` and
        ``self._n_step``.
        """
        return batch
|