Trinkle23897 64bab0b6a0 ddpg
2020-03-18 21:45:41 +08:00

25 lines
509 B
Python

from torch import nn
from abc import ABC, abstractmethod
class BasePolicy(ABC, nn.Module):
    """Abstract base class for policies.

    Combines ``abc.ABC`` with ``torch.nn.Module``. Concrete subclasses must
    implement :meth:`__call__` and :meth:`learn`; :meth:`process_fn` and
    :meth:`sync_weight` are optional hooks whose defaults are a pass-through
    and a no-op respectively.
    """

    def __init__(self):
        """Initialise the underlying ``nn.Module`` state."""
        super().__init__()

    def process_fn(self, batch, buffer, indice):
        """Hook for processing a batch; the default returns it unchanged.

        Args:
            batch: The batch to process.
            buffer: Unused by the default implementation (presumably the
                buffer ``batch`` was sampled from -- confirm against caller).
            indice: Unused by the default implementation (presumably the
                sampled indices -- confirm against caller).

        Returns:
            ``batch`` itself, unmodified.
        """
        return batch

    # NOTE(review): defining ``__call__`` directly replaces
    # ``nn.Module.__call__``, so the usual ``forward`` dispatch (and any
    # registered module hooks) is not involved -- confirm this is intended.
    @abstractmethod
    def __call__(self, batch, state=None):
        """Evaluate the policy on ``batch``; must be overridden.

        Args:
            batch: Input batch.
            state: Optional state carried between calls; defaults to ``None``.

        Returns:
            Implementations are expected to return something like
            ``Batch(logits=..., act=..., state=None, ...)``.
        """
        # return Batch(logits=..., act=..., state=None, ...)

    @abstractmethod
    def learn(self, batch, batch_size=None):
        """Update the policy from ``batch``; must be overridden.

        Args:
            batch: Data to learn from.
            batch_size: Optional size hint; defaults to ``None``.
        """

    def sync_weight(self):
        """Optional weight-synchronisation hook; the default does nothing."""