32 lines
1.0 KiB
Python
Raw Normal View History

from __future__ import absolute_import
import tensorflow as tf
2018-05-20 22:36:04 +08:00
# An empty ``__all__`` means a star-import of this module exports no names.
__all__ = []
class ValueFunctionBase(object):
    """
    Base class for value functions, covering both state values V(s) and
    action values Q(s, a).

    The only mandatory method for a value function class is
    :func:`eval_value`, which runs the graph and evaluates the corresponding
    value.

    :param value_tensor: a Tensor. The tensor of V(s) or Q(s, a).
    :param observation_placeholder: a :class:`tf.placeholder`. The observation
        placeholder of the network graph.
    """

    def __init__(self, value_tensor, observation_placeholder):
        self.observation_placeholder = observation_placeholder
        # The canonical value has shape (batchsize, ), so size-1 dimensions
        # are squeezed out of the given tensor.
        # NOTE(review): ``tf.squeeze`` without ``axis`` removes *every* size-1
        # dimension, so a batch of one would collapse to a scalar instead of
        # shape (1, ) -- confirm whether callers can hit that case.
        self._value_tensor = tf.squeeze(value_tensor)

    def eval_value(self, **kwargs):
        """
        Runs the graph and evaluates the corresponding value.

        The base implementation accepts arbitrary keyword arguments but does
        not use them.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def value_tensor(self):
        """Tensor of the corresponding value (the squeezed ``value_tensor``)."""
        return self._value_tensor